// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
//               2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "decoder/asr_decoder.h"

#include <ctype.h>

#include <algorithm>
#include <limits>
#include <utility>

#include "utils/timer.h"

namespace wenet {
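
// Typical driving loop (a sketch only; the exact wiring lives in the runtime
// code that owns the decoder): feed audio into the FeaturePipeline, then
//
//   decoder.Reset();
//   while (true) {
//     DecodeState state = decoder.Decode(/*block=*/true);
//     if (state == DecodeState::kEndFeats) {
//       decoder.Rescoring();
//       break;
//     } else if (state == DecodeState::kEndpoint) {
//       decoder.Rescoring();
//       decoder.ResetContinuousDecoding();
//     }
//   }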

AsrDecoder::AsrDecoder(std::shared_ptr<FeaturePipeline> feature_pipeline,
                       std::shared_ptr<DecodeResource> resource,
                       const DecodeOptions& opts)
    : feature_pipeline_(std::move(feature_pipeline)),
      // Make a copy of the ASR model, since we will change the inner
      // state of the model during decoding
      model_(resource->model->Copy()),
      post_processor_(resource->post_processor),
      context_graph_(resource->context_graph),
      symbol_table_(resource->symbol_table),
      fst_(resource->fst),
      unit_table_(resource->unit_table),
      opts_(opts),
      ctc_endpointer_(new CtcEndpoint(opts.ctc_endpoint_config)) {
  if (opts_.reverse_weight > 0) {
    // Check that the model has a right-to-left decoder
    CHECK(model_->is_bidirectional_decoder());
  }
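  // Without a decoding graph, fall back to plain CTC prefix beam search;
  // with one, run CTC beam search over the WFST graph.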
  if (nullptr == fst_) {
    searcher_.reset(new CtcPrefixBeamSearch(opts.ctc_prefix_search_opts,
                                            resource->context_graph));
  } else {
    searcher_.reset(new CtcWfstBeamSearch(*fst_, opts.ctc_wfst_search_opts,
                                          resource->context_graph));
  }
  ctc_endpointer_->frame_shift_in_ms(frame_shift_in_ms());
}

void AsrDecoder::Reset() {
  start_ = false;
  result_.clear();
  num_frames_ = 0;
  global_frame_offset_ = 0;
  model_->Reset();
  searcher_->Reset();
  feature_pipeline_->Reset();
  ctc_endpointer_->Reset();
}
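
// Unlike Reset(), continuous decoding keeps the feature pipeline alive and
// records the frames consumed so far in global_frame_offset_, so timestamps
// of later segments stay absolute within the whole audio stream.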
void AsrDecoder::ResetContinuousDecoding() {
  global_frame_offset_ = num_frames_;
  start_ = false;
  result_.clear();
  model_->Reset();
  searcher_->Reset();
  ctc_endpointer_->Reset();
}

DecodeState AsrDecoder::Decode(bool block) {
  return this->AdvanceDecoding(block);
}

void AsrDecoder::Rescoring() {
  // Do attention rescoring
  Timer timer;
  AttentionRescoring();
  VLOG(2) << "Rescoring cost latency: " << timer.Elapsed() << "ms.";
}
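
// One step of streaming decoding: read one chunk of features (blocking or
// not), run the encoder to get per-frame CTC log posteriors, feed them to
// the beam search, and finally check for an endpoint.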
DecodeState AsrDecoder::AdvanceDecoding(bool block) {
  DecodeState state = DecodeState::kEndBatch;
  model_->set_chunk_size(opts_.chunk_size);
  model_->set_num_left_chunks(opts_.num_left_chunks);
  int num_required_frames = model_->num_frames_for_chunk(start_);
  std::vector<std::vector<float>> chunk_feats;
  // Return immediately if we do not want to block
  if (!block && !feature_pipeline_->input_finished() &&
      feature_pipeline_->NumQueuedFrames() < num_required_frames) {
    return DecodeState::kWaitFeats;
  }
  // If the read fails, we have reached the end of the input
  if (!feature_pipeline_->Read(num_required_frames, &chunk_feats)) {
    state = DecodeState::kEndFeats;
  }

  num_frames_ += chunk_feats.size();
  VLOG(2) << "Required " << num_required_frames << " get "
          << chunk_feats.size();
  Timer timer;
  std::vector<std::vector<float>> ctc_log_probs;
  model_->ForwardEncoder(chunk_feats, &ctc_log_probs);
  int forward_time = timer.Elapsed();
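  // Optionally rescale the blank posterior (unit 0 is assumed to be blank):
  // adding log(blank_scale) to the log posterior is equivalent to multiplying
  // the blank probability by blank_scale.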
  if (opts_.ctc_wfst_search_opts.blank_scale != 1.0) {
    for (size_t i = 0; i < ctc_log_probs.size(); i++) {
      ctc_log_probs[i][0] = ctc_log_probs[i][0] +
                            std::log(opts_.ctc_wfst_search_opts.blank_scale);
    }
  }
  timer.Reset();
  searcher_->Search(ctc_log_probs);
  int search_time = timer.Elapsed();
  VLOG(3) << "forward takes " << forward_time << " ms, search takes "
          << search_time << " ms";
  UpdateResult();

  if (state != DecodeState::kEndFeats) {
    if (ctc_endpointer_->IsEndpoint(ctc_log_probs, DecodedSomething())) {
      VLOG(1) << "Endpoint is detected at " << num_frames_;
      state = DecodeState::kEndpoint;
    }
  }
  start_ = true;
  return state;
}

void AsrDecoder::UpdateResult(bool finish) {
  const auto& hypotheses = searcher_->Outputs();
  const auto& inputs = searcher_->Inputs();
  const auto& likelihood = searcher_->Likelihood();
  const auto& times = searcher_->Times();
  result_.clear();

  CHECK_EQ(hypotheses.size(), likelihood.size());
  for (size_t i = 0; i < hypotheses.size(); i++) {
    const std::vector<int>& hypothesis = hypotheses[i];

    DecodeResult path;
    path.score = likelihood[i];
    int offset = global_frame_offset_ * feature_frame_shift_in_ms();
    for (size_t j = 0; j < hypothesis.size(); j++) {
      std::string word = symbol_table_->Find(hypothesis[j]);
      // A detailed explanation of this if-else branch can be found in
      // https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058
      if (searcher_->Type() == kWfstBeamSearch) {
        path.sentence += (' ' + word);
      } else {
        path.sentence += (word);
      }
    }

    // Timestamps are only supported in the final result.
    // Timestamps of the output of CtcWfstBeamSearch may be inaccurate due to
    // the various FST operations applied when building the decoding graph. So
    // here we use the timestamps of the input (e2e model units), which are
    // more accurate, though this requires the symbol table of the e2e model
    // used in training.
    if (unit_table_ != nullptr && finish) {
      const std::vector<int>& input = inputs[i];
      const std::vector<int>& time_stamp = times[i];
      CHECK_EQ(input.size(), time_stamp.size());
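      // Each unit's emission frame gives its nominal end time; the start is
      // padded back by time_stamp_gap_ ms (clamped at 0). When neighboring
      // units are closer than time_stamp_gap_, the boundary is placed at the
      // midpoint of their emission frames instead.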
      for (size_t j = 0; j < input.size(); j++) {
        std::string word = unit_table_->Find(input[j]);
        int start = time_stamp[j] * frame_shift_in_ms() - time_stamp_gap_ > 0
                        ? time_stamp[j] * frame_shift_in_ms() - time_stamp_gap_
                        : 0;
        if (j > 0) {
          start = (time_stamp[j] - time_stamp[j - 1]) * frame_shift_in_ms() <
                          time_stamp_gap_
                      ? (time_stamp[j - 1] + time_stamp[j]) / 2 *
                            frame_shift_in_ms()
                      : start;
        }
        int end = time_stamp[j] * frame_shift_in_ms();
        if (j < input.size() - 1) {
          end = (time_stamp[j + 1] - time_stamp[j]) * frame_shift_in_ms() <
                        time_stamp_gap_
                    ? (time_stamp[j + 1] + time_stamp[j]) / 2 *
                          frame_shift_in_ms()
                    : end;
        }
        WordPiece word_piece(word, offset + start, offset + end);
        path.word_pieces.emplace_back(word_piece);
      }
    }

    if (post_processor_ != nullptr) {
      path.sentence = post_processor_->Process(path.sentence, finish);
    }
    result_.emplace_back(path);
  }

  if (DecodedSomething()) {
    VLOG(1) << "Partial CTC result " << result_[0].sentence;
    if (context_graph_ != nullptr) {
      int cur_state = 0;
      float score = 0;
      for (int ilabel : inputs[0]) {
        cur_state = context_graph_->GetNextState(cur_state, ilabel, &score,
                                                 &(result_[0].contexts));
      }
      std::string contexts;
      for (const auto& context : result_[0].contexts) {
        contexts += context + ", ";
      }
      VLOG(1) << "Contexts: " << contexts;
    }
  }
}
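
// Second-pass rescoring: finalize the first-pass search to get the n-best,
// score each hypothesis with the attention decoder, interpolate with the
// CTC score, and re-sort the n-best by the combined score.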
void AsrDecoder::AttentionRescoring() {
  searcher_->FinalizeSearch();
  UpdateResult(true);
  // No need to do rescoring
  if (0.0 == opts_.rescoring_weight) {
    return;
  }
  // Inputs() returns the n-best input ids, which are the basic units for
  // rescoring. In CtcPrefixBeamSearch, inputs are the same as outputs.
  const auto& hypotheses = searcher_->Inputs();
  int num_hyps = hypotheses.size();
  if (num_hyps <= 0) {
    return;
  }
  // TODO(zhendong.peng): Do we need rescoring while context matching?
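  // With reverse_weight > 0, the model also scores each hypothesis with its
  // right-to-left decoder (whose existence is checked in the constructor)
  // and blends the two decoder scores.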
  std::vector<float> rescoring_score;
  model_->AttentionRescoring(hypotheses, opts_.reverse_weight,
                             &rescoring_score);

  // Combine ctc score and rescoring score
  for (int i = 0; i < num_hyps; ++i) {
    result_[i].score = opts_.rescoring_weight * rescoring_score[i] +
                       opts_.ctc_weight * result_[i].score;
  }
  std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
}

}  // namespace wenet