You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

129 lines
4.2 KiB

  1. // Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #ifndef DECODER_CTC_PREFIX_BEAM_SEARCH_H_
  15. #define DECODER_CTC_PREFIX_BEAM_SEARCH_H_
  16. #include <memory>
  17. #include <unordered_map>
  18. #include <utility>
  19. #include <vector>
  20. #include "decoder/context_graph.h"
  21. #include "decoder/search_interface.h"
  22. #include "../utils/wn_utils.h"
  23. namespace wenet {
  // Tunable parameters for CTC prefix beam search.
  struct CtcPrefixBeamSearchOptions {
    // Id of the CTC blank symbol.
    int blank = 0;
    // Width of the first pruning stage.
    // NOTE(review): presumably the number of top-scoring tokens kept per
    // frame -- confirm against the Search() implementation.
    int first_beam_size = 10;
    // Width of the second pruning stage.
    // NOTE(review): presumably the number of prefixes kept after each frame --
    // confirm against the Search() implementation.
    int second_beam_size = 10;
  };
  29. struct PrefixScore {
  30. float s = -kFloatMax; // blank ending score
  31. float ns = -kFloatMax; // none blank ending score
  32. float v_s = -kFloatMax; // viterbi blank ending score
  33. float v_ns = -kFloatMax; // viterbi none blank ending score
  34. float cur_token_prob = -kFloatMax; // prob of current token
  35. std::vector<int> times_s; // times of viterbi blank path
  36. std::vector<int> times_ns; // times of viterbi none blank path
  37. float score() const { return LogAdd(s, ns); }
  38. float viterbi_score() const { return v_s > v_ns ? v_s : v_ns; }
  39. const std::vector<int>& times() const {
  40. return v_s > v_ns ? times_s : times_ns;
  41. }
  42. bool has_context = false;
  43. int context_state = 0;
  44. float context_score = 0;
  45. void CopyContext(const PrefixScore& prefix_score) {
  46. context_state = prefix_score.context_state;
  47. context_score = prefix_score.context_score;
  48. }
  49. void UpdateContext(const std::shared_ptr<ContextGraph>& context_graph,
  50. const PrefixScore& prefix_score, int word_id) {
  51. this->CopyContext(prefix_score);
  52. float score = 0;
  53. context_state = context_graph->GetNextState(prefix_score.context_state,
  54. word_id, &score);
  55. context_score += score;
  56. }
  57. float total_score() const { return score() + context_score; }
  58. };
  // Hash functor for token-id prefixes, used as the hasher of an
  // unordered_map keyed by std::vector<int>.
  struct PrefixHash {
    // Returns a BKDR-style polynomial hash (multiplier 31) of `prefix`.
    // An empty prefix hashes to 0. The arithmetic is done in size_t, so
    // overflow wraps around instead of being undefined behavior.
    size_t operator()(const std::vector<int>& prefix) const {
      size_t hash_code = 0;
      for (int id : prefix) {
        // Explicit cast: a negative id wraps to a large unsigned value,
        // exactly as the implicit conversion would.
        hash_code = static_cast<size_t>(id) + 31 * hash_code;
      }
      return hash_code;
    }
  };
  69. class CtcPrefixBeamSearch : public SearchInterface {
  70. public:
  71. explicit CtcPrefixBeamSearch(
  72. const CtcPrefixBeamSearchOptions& opts,
  73. const std::shared_ptr<ContextGraph>& context_graph = nullptr);
  74. void Search(const std::vector<std::vector<float>>& logp) override;
  75. void Reset() override;
  76. void FinalizeSearch() override;
  77. SearchType Type() const override { return SearchType::kPrefixBeamSearch; }
  78. void UpdateHypotheses(
  79. const std::vector<std::pair<std::vector<int>, PrefixScore>>& hpys);
  80. const std::vector<float>& viterbi_likelihood() const {
  81. return viterbi_likelihood_;
  82. }
  83. const std::vector<std::vector<int>>& Inputs() const override {
  84. return hypotheses_;
  85. }
  86. const std::vector<std::vector<int>>& Outputs() const override {
  87. return outputs_;
  88. }
  89. const std::vector<float>& Likelihood() const override { return likelihood_; }
  90. const std::vector<std::vector<int>>& Times() const override { return times_; }
  91. private:
  92. int abs_time_step_ = 0;
  93. // N-best list and corresponding likelihood_, in sorted order
  94. std::vector<std::vector<int>> hypotheses_;
  95. std::vector<float> likelihood_;
  96. std::vector<float> viterbi_likelihood_;
  97. std::vector<std::vector<int>> times_;
  98. std::unordered_map<std::vector<int>, PrefixScore, PrefixHash> cur_hyps_;
  99. std::shared_ptr<ContextGraph> context_graph_ = nullptr;
  100. // Outputs contain the hypotheses_ and tags like: <context> and </context>
  101. std::vector<std::vector<int>> outputs_;
  102. const CtcPrefixBeamSearchOptions& opts_;
  103. public:
  104. WENET_DISALLOW_COPY_AND_ASSIGN(CtcPrefixBeamSearch);
  105. };
  106. } // namespace wenet
  107. #endif // DECODER_CTC_PREFIX_BEAM_SEARCH_H_