// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef DECODER_CTC_WFST_BEAM_SEARCH_H_ #define DECODER_CTC_WFST_BEAM_SEARCH_H_ #include #include #include "decoder/context_graph.h" #include "decoder/search_interface.h" #include "kaldi/decoder/lattice-faster-online-decoder.h" #include "../utils/wn_utils.h" namespace wenet { class DecodableTensorScaled : public kaldi::DecodableInterface { public: explicit DecodableTensorScaled(float scale = 1.0) : scale_(scale) { Reset(); } void Reset(); int32 NumFramesReady() const override { return num_frames_ready_; } bool IsLastFrame(int32 frame) const override; float LogLikelihood(int32 frame, int32 index) override; int32 NumIndices() const override; void AcceptLoglikes(const std::vector& logp); void SetFinish() { done_ = true; } private: int num_frames_ready_ = 0; float scale_ = 1.0; bool done_ = false; std::vector logp_; }; // LatticeFasterDecoderConfig has the following key members // beam: decoding beam // max_active: Decoder max active states // lattice_beam: Lattice generation beam struct CtcWfstBeamSearchOptions : public kaldi::LatticeFasterDecoderConfig { float acoustic_scale = 1.0; float nbest = 10; // When blank score is greater than this thresh, skip the frame in viterbi // search float blank_skip_thresh = 0.98; float blank_scale = 1.0; int blank = 0; }; class CtcWfstBeamSearch : public SearchInterface { public: explicit CtcWfstBeamSearch( const fst::Fst& fst, const CtcWfstBeamSearchOptions& opts, const std::shared_ptr& context_graph); void Search(const std::vector>& logp) override; void Reset() override; void FinalizeSearch() override; SearchType Type() const override { return SearchType::kWfstBeamSearch; } // For CTC prefix beam search, both inputs and outputs are hypotheses_ const std::vector>& Inputs() const override { return inputs_; } const std::vector>& Outputs() const override { return outputs_; } const std::vector& Likelihood() const override { return likelihood_; } const std::vector>& Times() const override { return times_; } private: // Sub one and remove void ConvertToInputs(const std::vector& alignment, std::vector* input, std::vector* time = nullptr); int num_frames_ = 0; std::vector decoded_frames_mapping_; int last_best_ = 0; // last none blank best id std::vector last_frame_prob_; bool is_last_frame_blank_ = false; std::vector> inputs_, outputs_; std::vector likelihood_; std::vector> times_; DecodableTensorScaled decodable_; kaldi::LatticeFasterOnlineDecoder decoder_; std::shared_ptr context_graph_; const CtcWfstBeamSearchOptions& opts_; }; } // namespace wenet #endif // DECODER_CTC_WFST_BEAM_SEARCH_H_