You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

183 lines
5.8 KiB

  1. // Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "decoder/ctc_wfst_beam_search.h"
  15. #include <utility>
  16. namespace wenet {
  17. void DecodableTensorScaled::Reset() {
  18. num_frames_ready_ = 0;
  19. done_ = false;
  20. // Give an empty initialization, will throw error when
  21. // AcceptLoglikes is not called
  22. logp_.clear();
  23. }
  24. void DecodableTensorScaled::AcceptLoglikes(const std::vector<float>& logp) {
  25. ++num_frames_ready_;
  26. // TODO(Binbin Zhang): Avoid copy here
  27. logp_ = logp;
  28. }
  29. float DecodableTensorScaled::LogLikelihood(int32 frame, int32 index) {
  30. CHECK_GT(index, 0);
  31. CHECK_LT(frame, num_frames_ready_);
  32. return scale_ * logp_[index - 1];
  33. }
  34. bool DecodableTensorScaled::IsLastFrame(int32 frame) const {
  35. CHECK_LT(frame, num_frames_ready_);
  36. return done_ && (frame == num_frames_ready_ - 1);
  37. }
  38. int32 DecodableTensorScaled::NumIndices() const {
  39. LOG(FATAL) << "Not implement";
  40. return 0;
  41. }
  42. CtcWfstBeamSearch::CtcWfstBeamSearch(
  43. const fst::Fst<fst::StdArc>& fst, const CtcWfstBeamSearchOptions& opts,
  44. const std::shared_ptr<ContextGraph>& context_graph)
  45. : decodable_(opts.acoustic_scale),
  46. decoder_(fst, opts, context_graph),
  47. context_graph_(context_graph),
  48. opts_(opts) {
  49. Reset();
  50. }
  51. void CtcWfstBeamSearch::Reset() {
  52. num_frames_ = 0;
  53. decoded_frames_mapping_.clear();
  54. is_last_frame_blank_ = false;
  55. last_best_ = 0;
  56. inputs_.clear();
  57. outputs_.clear();
  58. likelihood_.clear();
  59. times_.clear();
  60. decodable_.Reset();
  61. decoder_.InitDecoding();
  62. }
  63. void CtcWfstBeamSearch::Search(const std::vector<std::vector<float>>& logp) {
  64. if (0 == logp.size()) {
  65. return;
  66. }
  67. // Every time we get the log posterior, we decode it all before return
  68. for (int i = 0; i < logp.size(); i++) {
  69. float blank_score = std::exp(logp[i][opts_.blank]);
  70. if (blank_score > opts_.blank_skip_thresh * opts_.blank_scale) {
  71. VLOG(3) << "skipping frame " << num_frames_ << " score " << blank_score;
  72. is_last_frame_blank_ = true;
  73. last_frame_prob_ = logp[i];
  74. } else {
  75. // Get the best symbol
  76. int cur_best =
  77. std::max_element(logp[i].begin(), logp[i].end()) - logp[i].begin();
  78. // Optional, adding one blank frame if we has skipped it in two same
  79. // symbols
  80. if (cur_best != opts_.blank && is_last_frame_blank_ &&
  81. cur_best == last_best_) {
  82. decodable_.AcceptLoglikes(last_frame_prob_);
  83. decoder_.AdvanceDecoding(&decodable_, 1);
  84. decoded_frames_mapping_.push_back(num_frames_ - 1);
  85. VLOG(2) << "Adding blank frame at symbol " << cur_best;
  86. }
  87. last_best_ = cur_best;
  88. decodable_.AcceptLoglikes(logp[i]);
  89. decoder_.AdvanceDecoding(&decodable_, 1);
  90. decoded_frames_mapping_.push_back(num_frames_);
  91. is_last_frame_blank_ = false;
  92. }
  93. num_frames_++;
  94. }
  95. // Get the best path
  96. inputs_.clear();
  97. outputs_.clear();
  98. likelihood_.clear();
  99. if (decoded_frames_mapping_.size() > 0) {
  100. inputs_.resize(1);
  101. outputs_.resize(1);
  102. likelihood_.resize(1);
  103. kaldi::Lattice lat;
  104. decoder_.GetBestPath(&lat, true);
  105. std::vector<int> alignment;
  106. kaldi::LatticeWeight weight;
  107. fst::GetLinearSymbolSequence(lat, &alignment, &outputs_[0], &weight);
  108. ConvertToInputs(alignment, &inputs_[0]);
  109. VLOG(3) << weight.Value1() << " " << weight.Value2();
  110. likelihood_[0] = -(weight.Value1() + weight.Value2());
  111. }
  112. }
  113. void CtcWfstBeamSearch::FinalizeSearch() {
  114. decodable_.SetFinish();
  115. decoder_.FinalizeDecoding();
  116. inputs_.clear();
  117. outputs_.clear();
  118. likelihood_.clear();
  119. times_.clear();
  120. if (decoded_frames_mapping_.size() > 0) {
  121. std::vector<kaldi::Lattice> nbest_lats;
  122. if (opts_.nbest == 1) {
  123. kaldi::Lattice lat;
  124. decoder_.GetBestPath(&lat, true);
  125. nbest_lats.push_back(std::move(lat));
  126. } else {
  127. // Get N-best path by lattice(CompactLattice)
  128. kaldi::CompactLattice clat;
  129. decoder_.GetLattice(&clat, true);
  130. kaldi::Lattice lat, nbest_lat;
  131. fst::ConvertLattice(clat, &lat);
  132. // TODO(Binbin Zhang): it's n-best word lists here, not character n-best
  133. fst::ShortestPath(lat, &nbest_lat, opts_.nbest);
  134. fst::ConvertNbestToVector(nbest_lat, &nbest_lats);
  135. }
  136. int nbest = nbest_lats.size();
  137. inputs_.resize(nbest);
  138. outputs_.resize(nbest);
  139. likelihood_.resize(nbest);
  140. times_.resize(nbest);
  141. for (int i = 0; i < nbest; i++) {
  142. kaldi::LatticeWeight weight;
  143. std::vector<int> alignment;
  144. fst::GetLinearSymbolSequence(nbest_lats[i], &alignment, &outputs_[i],
  145. &weight);
  146. ConvertToInputs(alignment, &inputs_[i], &times_[i]);
  147. likelihood_[i] = -(weight.Value1() + weight.Value2());
  148. }
  149. }
  150. }
  151. void CtcWfstBeamSearch::ConvertToInputs(const std::vector<int>& alignment,
  152. std::vector<int>* input,
  153. std::vector<int>* time) {
  154. input->clear();
  155. if (time != nullptr) time->clear();
  156. for (int cur = 0; cur < alignment.size(); ++cur) {
  157. // ignore blank
  158. if (alignment[cur] - 1 == opts_.blank) continue;
  159. // merge continuous same label
  160. if (cur > 0 && alignment[cur] == alignment[cur - 1]) continue;
  161. input->push_back(alignment[cur] - 1);
  162. if (time != nullptr) {
  163. time->push_back(decoded_frames_mapping_[cur]);
  164. }
  165. }
  166. }
  167. } // namespace wenet