// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
//               2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "decoder/asr_decoder.h"

#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "utils/timer.h"

namespace wenet {
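
// The decoder copies the shared ASR model so that each decoder instance owns
// its own runtime state, and it selects its search backend at construction
// time: plain CTC prefix beam search when no decoding graph (fst_) is given,
// WFST-based beam search over the graph otherwise.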
AsrDecoder::AsrDecoder(std::shared_ptr<FeaturePipeline> feature_pipeline,
                       std::shared_ptr<DecodeResource> resource,
                       const DecodeOptions& opts)
    : feature_pipeline_(std::move(feature_pipeline)),
      // Make a copy of the ASR model, since we will change its inner state
      model_(resource->model->Copy()),
      post_processor_(resource->post_processor),
      context_graph_(resource->context_graph),
      symbol_table_(resource->symbol_table),
      fst_(resource->fst),
      unit_table_(resource->unit_table),
      opts_(opts),
      ctc_endpointer_(new CtcEndpoint(opts.ctc_endpoint_config)) {
  if (opts_.reverse_weight > 0) {
    // Check if the model has a right-to-left decoder
    CHECK(model_->is_bidirectional_decoder());
  }
  if (nullptr == fst_) {
    searcher_.reset(new CtcPrefixBeamSearch(opts.ctc_prefix_search_opts,
                                            resource->context_graph));
  } else {
    searcher_.reset(new CtcWfstBeamSearch(*fst_, opts.ctc_wfst_search_opts,
                                          resource->context_graph));
  }
  ctc_endpointer_->frame_shift_in_ms(frame_shift_in_ms());
}
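
// Full reset for a new utterance: clears results, frame counters, the model
// and searcher state, the feature pipeline, and the endpointer.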
void AsrDecoder::Reset() {
  start_ = false;
  result_.clear();
  num_frames_ = 0;
  global_frame_offset_ = 0;
  model_->Reset();
  searcher_->Reset();
  feature_pipeline_->Reset();
  ctc_endpointer_->Reset();
}
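
// Unlike Reset(), continuous decoding keeps the feature pipeline and the
// accumulated frame count: global_frame_offset_ is advanced to the current
// frame so that the timestamps of the next segment stay on the global
// timeline.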
void AsrDecoder::ResetContinuousDecoding() {
  global_frame_offset_ = num_frames_;
  start_ = false;
  result_.clear();
  model_->Reset();
  searcher_->Reset();
  ctc_endpointer_->Reset();
}

DecodeState AsrDecoder::Decode(bool block) {
  return this->AdvanceDecoding(block);
}
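
// A minimal driver loop, as a sketch (it assumes the feature pipeline is fed
// from another thread; only names defined in this class are used, the loop
// itself is illustrative):
//
//   while (true) {
//     DecodeState state = decoder.Decode(true);
//     if (state == DecodeState::kEndFeats) {
//       decoder.Rescoring();  // final pass on the last segment
//       break;
//     } else if (state == DecodeState::kEndpoint) {
//       decoder.Rescoring();  // finalize the endpointed segment
//       decoder.ResetContinuousDecoding();
//     }
//     // kEndBatch / kWaitFeats: keep looping as more features arrive
//   }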
void AsrDecoder::Rescoring() {
  // Do attention rescoring
  Timer timer;
  AttentionRescoring();
  VLOG(2) << "Rescoring latency: " << timer.Elapsed() << "ms.";
}
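
// One step of streaming decoding: read up to one chunk of features, run the
// encoder to get frame-level CTC log-probabilities, advance the beam search
// on them, and check for an endpoint. In non-blocking mode this returns
// kWaitFeats when not enough frames are queued yet.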
DecodeState AsrDecoder::AdvanceDecoding(bool block) {
  DecodeState state = DecodeState::kEndBatch;
  model_->set_chunk_size(opts_.chunk_size);
  model_->set_num_left_chunks(opts_.num_left_chunks);
  int num_required_frames = model_->num_frames_for_chunk(start_);
  std::vector<std::vector<float>> chunk_feats;
  // Return immediately if we do not want to block
  if (!block && !feature_pipeline_->input_finished() &&
      feature_pipeline_->NumQueuedFrames() < num_required_frames) {
    return DecodeState::kWaitFeats;
  }
  // If the read fails, we have reached the end of the input
  if (!feature_pipeline_->Read(num_required_frames, &chunk_feats)) {
    state = DecodeState::kEndFeats;
  }

  num_frames_ += chunk_feats.size();
  VLOG(2) << "Required " << num_required_frames << " frames, got "
          << chunk_feats.size();

  Timer timer;
  std::vector<std::vector<float>> ctc_log_probs;
  model_->ForwardEncoder(chunk_feats, &ctc_log_probs);
  int forward_time = timer.Elapsed();
  if (opts_.ctc_wfst_search_opts.blank_scale != 1.0) {
    // Scale the blank posterior (index 0) in the log domain
    for (size_t i = 0; i < ctc_log_probs.size(); i++) {
      ctc_log_probs[i][0] += std::log(opts_.ctc_wfst_search_opts.blank_scale);
    }
  }
  timer.Reset();
  searcher_->Search(ctc_log_probs);
  int search_time = timer.Elapsed();
  VLOG(3) << "forward takes " << forward_time << " ms, search takes "
          << search_time << " ms";
  UpdateResult();

  if (state != DecodeState::kEndFeats) {
    if (ctc_endpointer_->IsEndpoint(ctc_log_probs, DecodedSomething())) {
      VLOG(1) << "Endpoint is detected at " << num_frames_;
      state = DecodeState::kEndpoint;
    }
  }
  start_ = true;
  return state;
}
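
// Build DecodeResult entries from the searcher's current N-best. When
// `finish` is true and a unit table is available, word-level timestamps are
// attached too. A worked example with illustrative numbers (assuming a 40 ms
// model frame shift and time_stamp_gap_ = 100): a unit emitted at frame 10
// gets end = 10 * 40 = 400 ms and start = max(400 - 100, 0) = 300 ms; if a
// neighboring unit is closer than the gap, the boundary moves to the
// midpoint between the two units instead.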
void AsrDecoder::UpdateResult(bool finish) {
  const auto& hypotheses = searcher_->Outputs();
  const auto& inputs = searcher_->Inputs();
  const auto& likelihood = searcher_->Likelihood();
  const auto& times = searcher_->Times();
  result_.clear();

  CHECK_EQ(hypotheses.size(), likelihood.size());
  for (size_t i = 0; i < hypotheses.size(); i++) {
    const std::vector<int>& hypothesis = hypotheses[i];

    DecodeResult path;
    path.score = likelihood[i];
    int offset = global_frame_offset_ * feature_frame_shift_in_ms();
    for (size_t j = 0; j < hypothesis.size(); j++) {
      std::string word = symbol_table_->Find(hypothesis[j]);
      // A detailed explanation of this if-else branch can be found in
      // https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058
      if (searcher_->Type() == kWfstBeamSearch) {
        path.sentence += (' ' + word);
      } else {
        path.sentence += (word);
      }
    }

    // Timestamps are only supported in the final result.
    // Timestamps from the output of CtcWfstBeamSearch may be inaccurate due
    // to various FST operations when building the decoding graph. So here we
    // use the timestamps of the input (e2e model units), which are more
    // accurate; this requires the symbol table of the e2e model used in
    // training.
    if (unit_table_ != nullptr && finish) {
      const std::vector<int>& input = inputs[i];
      const std::vector<int>& time_stamp = times[i];
      CHECK_EQ(input.size(), time_stamp.size());
      for (size_t j = 0; j < input.size(); j++) {
        std::string word = unit_table_->Find(input[j]);
        // Start time_stamp_gap_ ms before the unit's frame, clamped to 0
        int start = time_stamp[j] * frame_shift_in_ms() - time_stamp_gap_ > 0
                        ? time_stamp[j] * frame_shift_in_ms() - time_stamp_gap_
                        : 0;
        if (j > 0) {
          // If the previous unit is closer than the gap, move the start to
          // the midpoint between the two units
          start = (time_stamp[j] - time_stamp[j - 1]) * frame_shift_in_ms() <
                          time_stamp_gap_
                      ? (time_stamp[j - 1] + time_stamp[j]) / 2 *
                            frame_shift_in_ms()
                      : start;
        }
        int end = time_stamp[j] * frame_shift_in_ms();
        if (j < input.size() - 1) {
          // Same midpoint rule against the next unit for the end boundary
          end = (time_stamp[j + 1] - time_stamp[j]) * frame_shift_in_ms() <
                        time_stamp_gap_
                    ? (time_stamp[j + 1] + time_stamp[j]) / 2 *
                          frame_shift_in_ms()
                    : end;
        }
        WordPiece word_piece(word, offset + start, offset + end);
        path.word_pieces.emplace_back(word_piece);
      }
    }

    if (post_processor_ != nullptr) {
      path.sentence = post_processor_->Process(path.sentence, finish);
    }
    result_.emplace_back(path);
  }

  if (DecodedSomething()) {
    VLOG(1) << "Partial CTC result " << result_[0].sentence;

    // Match the context graph against the best path's input labels
    if (context_graph_ != nullptr) {
      int cur_state = 0;
      float score = 0;
      for (int ilabel : inputs[0]) {
        cur_state = context_graph_->GetNextState(cur_state, ilabel, &score,
                                                 &(result_[0].contexts));
      }
      std::string contexts;
      for (const auto& context : result_[0].contexts) {
        contexts += context + ", ";
      }
      VLOG(1) << "Contexts: " << contexts;
    }
  }
}
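
// Second-pass attention rescoring: the search is finalized, the attention
// decoder scores each N-best hypothesis, and the final score is a weighted
// combination of both passes:
//   score = rescoring_weight * attention_score + ctc_weight * ctc_score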
void AsrDecoder::AttentionRescoring() {
  searcher_->FinalizeSearch();
  UpdateResult(true);
  // No need to do rescoring
  if (0.0 == opts_.rescoring_weight) {
    return;
  }
  // Inputs() returns the N-best input ids, which are the basic units for
  // rescoring. In CtcPrefixBeamSearch, inputs are the same as outputs.
  const auto& hypotheses = searcher_->Inputs();
  size_t num_hyps = hypotheses.size();
  if (num_hyps == 0) {
    return;
  }
  // TODO(zhendong.peng): Do we need rescoring while context matching?
  std::vector<float> rescoring_score;
  model_->AttentionRescoring(hypotheses, opts_.reverse_weight,
                             &rescoring_score);

  // Combine CTC score and rescoring score
  for (size_t i = 0; i < num_hyps; ++i) {
    result_[i].score = opts_.rescoring_weight * rescoring_score[i] +
                       opts_.ctc_weight * result_[i].score;
  }
  std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
}

}  // namespace wenet