You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

208 lines
7.9 KiB

  1. // Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "decoder/ctc_prefix_beam_search.h"
  15. #include <algorithm>
  16. #include <tuple>
  17. #include <unordered_map>
  18. #include <utility>
  19. #include "utils/log.h"
  20. #include "../utils/wn_utils.h"
  21. namespace wenet {
  22. CtcPrefixBeamSearch::CtcPrefixBeamSearch(
  23. const CtcPrefixBeamSearchOptions& opts,
  24. const std::shared_ptr<ContextGraph>& context_graph)
  25. : opts_(opts), context_graph_(context_graph) {
  26. Reset();
  27. }
  28. void CtcPrefixBeamSearch::Reset() {
  29. hypotheses_.clear();
  30. likelihood_.clear();
  31. cur_hyps_.clear();
  32. viterbi_likelihood_.clear();
  33. times_.clear();
  34. outputs_.clear();
  35. abs_time_step_ = 0;
  36. PrefixScore prefix_score;
  37. prefix_score.s = 0.0;
  38. prefix_score.ns = -kFloatMax;
  39. prefix_score.v_s = 0.0;
  40. prefix_score.v_ns = 0.0;
  41. std::vector<int> empty;
  42. cur_hyps_[empty] = prefix_score;
  43. outputs_.emplace_back(empty);
  44. hypotheses_.emplace_back(empty);
  45. likelihood_.emplace_back(prefix_score.total_score());
  46. times_.emplace_back(empty);
  47. }
  48. static bool PrefixScoreCompare(
  49. const std::pair<std::vector<int>, PrefixScore>& a,
  50. const std::pair<std::vector<int>, PrefixScore>& b) {
  51. return a.second.total_score() > b.second.total_score();
  52. }
  53. void CtcPrefixBeamSearch::UpdateHypotheses(
  54. const std::vector<std::pair<std::vector<int>, PrefixScore>>& hpys) {
  55. cur_hyps_.clear();
  56. outputs_.clear();
  57. hypotheses_.clear();
  58. likelihood_.clear();
  59. viterbi_likelihood_.clear();
  60. times_.clear();
  61. for (auto& item : hpys) {
  62. cur_hyps_[item.first] = item.second;
  63. hypotheses_.emplace_back(item.first);
  64. outputs_.emplace_back(std::move(item.first));
  65. likelihood_.emplace_back(item.second.total_score());
  66. viterbi_likelihood_.emplace_back(item.second.viterbi_score());
  67. times_.emplace_back(item.second.times());
  68. }
  69. }
  70. // Please refer https://robin1001.github.io/2020/12/11/ctc-search
  71. // for how CTC prefix beam search works, and there is a simple graph demo in
  72. // it.
  73. void CtcPrefixBeamSearch::Search(const std::vector<std::vector<float>>& logp) {
  74. if (logp.size() == 0) return;
  75. int first_beam_size =
  76. std::min(static_cast<int>(logp[0].size()), opts_.first_beam_size);
  77. for (int t = 0; t < logp.size(); ++t, ++abs_time_step_) {
  78. const std::vector<float>& logp_t = logp[t];
  79. std::unordered_map<std::vector<int>, PrefixScore, PrefixHash> next_hyps;
  80. // 1. First beam prune, only select topk candidates
  81. std::vector<float> topk_score;
  82. std::vector<int32_t> topk_index;
  83. TopK(logp_t, first_beam_size, &topk_score, &topk_index);
  84. // 2. Token passing
  85. for (int i = 0; i < topk_index.size(); ++i) {
  86. int id = topk_index[i];
  87. auto prob = topk_score[i];
  88. for (const auto& it : cur_hyps_) {
  89. const std::vector<int>& prefix = it.first;
  90. const PrefixScore& prefix_score = it.second;
  91. // If prefix doesn't exist in next_hyps, next_hyps[prefix] will insert
  92. // PrefixScore(-inf, -inf) by default, since the default constructor
  93. // of PrefixScore will set fields s(blank ending score) and
  94. // ns(none blank ending score) to -inf, respectively.
  95. if (id == opts_.blank) {
  96. // Case 0: *a + ε => *a
  97. PrefixScore& next_score = next_hyps[prefix];
  98. next_score.s = LogAdd(next_score.s, prefix_score.score() + prob);
  99. next_score.v_s = prefix_score.viterbi_score() + prob;
  100. next_score.times_s = prefix_score.times();
  101. // Prefix not changed, copy the context from prefix.
  102. if (context_graph_ && !next_score.has_context) {
  103. next_score.CopyContext(prefix_score);
  104. next_score.has_context = true;
  105. }
  106. } else if (!prefix.empty() && id == prefix.back()) {
  107. // Case 1: *a + a => *a
  108. PrefixScore& next_score1 = next_hyps[prefix];
  109. next_score1.ns = LogAdd(next_score1.ns, prefix_score.ns + prob);
  110. if (next_score1.v_ns < prefix_score.v_ns + prob) {
  111. next_score1.v_ns = prefix_score.v_ns + prob;
  112. if (next_score1.cur_token_prob < prob) {
  113. next_score1.cur_token_prob = prob;
  114. next_score1.times_ns = prefix_score.times_ns;
  115. CHECK_GT(next_score1.times_ns.size(), 0);
  116. next_score1.times_ns.back() = abs_time_step_;
  117. }
  118. }
  119. if (context_graph_ && !next_score1.has_context) {
  120. next_score1.CopyContext(prefix_score);
  121. next_score1.has_context = true;
  122. }
  123. // Case 2: *aε + a => *aa
  124. std::vector<int> new_prefix(prefix);
  125. new_prefix.emplace_back(id);
  126. PrefixScore& next_score2 = next_hyps[new_prefix];
  127. next_score2.ns = LogAdd(next_score2.ns, prefix_score.s + prob);
  128. if (next_score2.v_ns < prefix_score.v_s + prob) {
  129. next_score2.v_ns = prefix_score.v_s + prob;
  130. next_score2.cur_token_prob = prob;
  131. next_score2.times_ns = prefix_score.times_s;
  132. next_score2.times_ns.emplace_back(abs_time_step_);
  133. }
  134. if (context_graph_ && !next_score2.has_context) {
  135. // Prefix changed, calculate the context score.
  136. next_score2.UpdateContext(context_graph_, prefix_score, id);
  137. next_score2.has_context = true;
  138. }
  139. } else {
  140. // Case 3: *a + b => *ab, *aε + b => *ab
  141. std::vector<int> new_prefix(prefix);
  142. new_prefix.emplace_back(id);
  143. PrefixScore& next_score = next_hyps[new_prefix];
  144. next_score.ns = LogAdd(next_score.ns, prefix_score.score() + prob);
  145. if (next_score.v_ns < prefix_score.viterbi_score() + prob) {
  146. next_score.v_ns = prefix_score.viterbi_score() + prob;
  147. next_score.cur_token_prob = prob;
  148. next_score.times_ns = prefix_score.times();
  149. next_score.times_ns.emplace_back(abs_time_step_);
  150. }
  151. if (context_graph_ && !next_score.has_context) {
  152. // Calculate the context score.
  153. next_score.UpdateContext(context_graph_, prefix_score, id);
  154. next_score.has_context = true;
  155. }
  156. }
  157. }
  158. }
  159. // 3. Second beam prune, only keep top n best paths
  160. std::vector<std::pair<std::vector<int>, PrefixScore>> arr(next_hyps.begin(),
  161. next_hyps.end());
  162. int second_beam_size =
  163. std::min(static_cast<int>(arr.size()), opts_.second_beam_size);
  164. std::nth_element(arr.begin(), arr.begin() + second_beam_size, arr.end(),
  165. PrefixScoreCompare);
  166. arr.resize(second_beam_size);
  167. std::sort(arr.begin(), arr.end(), PrefixScoreCompare);
  168. // 4. Update cur_hyps_ and get new result
  169. UpdateHypotheses(arr);
  170. }
  171. }
  172. void CtcPrefixBeamSearch::FinalizeSearch() {
  173. if (context_graph_ == nullptr) return;
  174. CHECK_EQ(hypotheses_.size(), cur_hyps_.size());
  175. CHECK_EQ(hypotheses_.size(), likelihood_.size());
  176. // We should backoff the context score/state when the context is
  177. // not fully matched at the last time.
  178. for (const auto& prefix : hypotheses_) {
  179. PrefixScore& prefix_score = cur_hyps_[prefix];
  180. if (prefix_score.context_state != 0) {
  181. prefix_score.UpdateContext(context_graph_, prefix_score, -1);
  182. }
  183. }
  184. std::vector<std::pair<std::vector<int>, PrefixScore>> arr(cur_hyps_.begin(),
  185. cur_hyps_.end());
  186. std::sort(arr.begin(), arr.end(), PrefixScoreCompare);
  187. // Update cur_hyps_ and get new result
  188. UpdateHypotheses(arr);
  189. }
  190. } // namespace wenet