// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef DECODER_CTC_PREFIX_BEAM_SEARCH_H_ #define DECODER_CTC_PREFIX_BEAM_SEARCH_H_ #include #include #include #include #include "decoder/context_graph.h" #include "decoder/search_interface.h" #include "../utils/wn_utils.h" namespace wenet { struct CtcPrefixBeamSearchOptions { int blank = 0; // blank id int first_beam_size = 10; int second_beam_size = 10; }; struct PrefixScore { float s = -kFloatMax; // blank ending score float ns = -kFloatMax; // none blank ending score float v_s = -kFloatMax; // viterbi blank ending score float v_ns = -kFloatMax; // viterbi none blank ending score float cur_token_prob = -kFloatMax; // prob of current token std::vector times_s; // times of viterbi blank path std::vector times_ns; // times of viterbi none blank path float score() const { return LogAdd(s, ns); } float viterbi_score() const { return v_s > v_ns ? v_s : v_ns; } const std::vector& times() const { return v_s > v_ns ? times_s : times_ns; } bool has_context = false; int context_state = 0; float context_score = 0; void CopyContext(const PrefixScore& prefix_score) { context_state = prefix_score.context_state; context_score = prefix_score.context_score; } void UpdateContext(const std::shared_ptr& context_graph, const PrefixScore& prefix_score, int word_id) { this->CopyContext(prefix_score); float score = 0; context_state = context_graph->GetNextState(prefix_score.context_state, word_id, &score); context_score += score; } float total_score() const { return score() + context_score; } }; struct PrefixHash { size_t operator()(const std::vector& prefix) const { size_t hash_code = 0; // here we use KB&DR hash code for (int id : prefix) { hash_code = id + 31 * hash_code; } return hash_code; } }; class CtcPrefixBeamSearch : public SearchInterface { public: explicit CtcPrefixBeamSearch( const CtcPrefixBeamSearchOptions& opts, const std::shared_ptr& context_graph = nullptr); void Search(const std::vector>& logp) override; void Reset() override; void FinalizeSearch() override; SearchType Type() const override { return SearchType::kPrefixBeamSearch; } void UpdateHypotheses( const std::vector, PrefixScore>>& hpys); const std::vector& viterbi_likelihood() const { return viterbi_likelihood_; } const std::vector>& Inputs() const override { return hypotheses_; } const std::vector>& Outputs() const override { return outputs_; } const std::vector& Likelihood() const override { return likelihood_; } const std::vector>& Times() const override { return times_; } private: int abs_time_step_ = 0; // N-best list and corresponding likelihood_, in sorted order std::vector> hypotheses_; std::vector likelihood_; std::vector viterbi_likelihood_; std::vector> times_; std::unordered_map, PrefixScore, PrefixHash> cur_hyps_; std::shared_ptr context_graph_ = nullptr; // Outputs contain the hypotheses_ and tags like: and std::vector> outputs_; const CtcPrefixBeamSearchOptions& opts_; public: WENET_DISALLOW_COPY_AND_ASSIGN(CtcPrefixBeamSearch); }; } // namespace wenet #endif // DECODER_CTC_PREFIX_BEAM_SEARCH_H_