// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
//               2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "decoder/asr_decoder.h"

#include <algorithm>
#include <cmath>
#include <utility>

#include "utils/timer.h"

namespace wenet {

AsrDecoder::AsrDecoder(std::shared_ptr<FeaturePipeline> feature_pipeline,
                       std::shared_ptr<DecodeResource> resource,
                       const DecodeOptions& opts)
    : feature_pipeline_(std::move(feature_pipeline)),
      // Make a copy of the ASR model since we will change its inner
      // state during decoding
      model_(resource->model->Copy()),
      post_processor_(resource->post_processor),
      context_graph_(resource->context_graph),
      symbol_table_(resource->symbol_table),
      fst_(resource->fst),
      unit_table_(resource->unit_table),
      opts_(opts),
      ctc_endpointer_(new CtcEndpoint(opts.ctc_endpoint_config)) {
  if (opts_.reverse_weight > 0) {
    // Check that the model has a right-to-left decoder
    CHECK(model_->is_bidirectional_decoder());
  }
  if (nullptr == fst_) {
    searcher_.reset(new CtcPrefixBeamSearch(opts.ctc_prefix_search_opts,
                                            resource->context_graph));
  } else {
    searcher_.reset(new CtcWfstBeamSearch(*fst_, opts.ctc_wfst_search_opts,
                                          resource->context_graph));
  }
  ctc_endpointer_->frame_shift_in_ms(frame_shift_in_ms());
}

void AsrDecoder::Reset() {
  start_ = false;
  result_.clear();
  num_frames_ = 0;
  global_frame_offset_ = 0;
  model_->Reset();
  searcher_->Reset();
  feature_pipeline_->Reset();
  ctc_endpointer_->Reset();
}

void AsrDecoder::ResetContinuousDecoding() {
  global_frame_offset_ = num_frames_;
  start_ = false;
  result_.clear();
  model_->Reset();
  searcher_->Reset();
  ctc_endpointer_->Reset();
}

DecodeState AsrDecoder::Decode(bool block) {
  return this->AdvanceDecoding(block);
}

void AsrDecoder::Rescoring() {
  // Do attention rescoring
  Timer timer;
  AttentionRescoring();
  VLOG(2) << "Rescoring cost latency: " << timer.Elapsed() << "ms.";
}

DecodeState AsrDecoder::AdvanceDecoding(bool block) {
  DecodeState state = DecodeState::kEndBatch;
  model_->set_chunk_size(opts_.chunk_size);
  model_->set_num_left_chunks(opts_.num_left_chunks);
  int num_required_frames = model_->num_frames_for_chunk(start_);
  std::vector<std::vector<float>> chunk_feats;
  // Return immediately if we do not want to block
  if (!block && !feature_pipeline_->input_finished() &&
      feature_pipeline_->NumQueuedFrames() < num_required_frames) {
    return DecodeState::kWaitFeats;
  }
  // If the read is not okay, we have reached the end of the input
  if (!feature_pipeline_->Read(num_required_frames, &chunk_feats)) {
    state = DecodeState::kEndFeats;
  }

  num_frames_ += chunk_feats.size();
  VLOG(2) << "Required " << num_required_frames << " get "
          << chunk_feats.size();
  Timer timer;
  std::vector<std::vector<float>> ctc_log_probs;
  model_->ForwardEncoder(chunk_feats, &ctc_log_probs);
  int forward_time = timer.Elapsed();
  if (opts_.ctc_wfst_search_opts.blank_scale != 1.0) {
    // Scale the blank posterior (index 0) in the log domain
    for (size_t i = 0; i < ctc_log_probs.size(); i++) {
      ctc_log_probs[i][0] += std::log(opts_.ctc_wfst_search_opts.blank_scale);
    }
  }
  timer.Reset();
  searcher_->Search(ctc_log_probs);
  int search_time = timer.Elapsed();
  VLOG(3) << "forward takes " << forward_time << " ms, search takes "
          << search_time << " ms";
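
  // Refresh the partial result from the searcher's current hypotheses; the
  // final UpdateResult(true) call, which also attaches timestamps, happens
  // in AttentionRescoring().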
  UpdateResult();

  if (state != DecodeState::kEndFeats) {
    if (ctc_endpointer_->IsEndpoint(ctc_log_probs, DecodedSomething())) {
      VLOG(1) << "Endpoint is detected at " << num_frames_;
      state = DecodeState::kEndpoint;
    }
  }

  start_ = true;
  return state;
}

void AsrDecoder::UpdateResult(bool finish) {
  const auto& hypotheses = searcher_->Outputs();
  const auto& inputs = searcher_->Inputs();
  const auto& likelihood = searcher_->Likelihood();
  const auto& times = searcher_->Times();
  result_.clear();

  CHECK_EQ(hypotheses.size(), likelihood.size());
  for (size_t i = 0; i < hypotheses.size(); i++) {
    const std::vector<int>& hypothesis = hypotheses[i];

    DecodeResult path;
    path.score = likelihood[i];
    int offset = global_frame_offset_ * feature_frame_shift_in_ms();
    for (size_t j = 0; j < hypothesis.size(); j++) {
      std::string word = symbol_table_->Find(hypothesis[j]);
      // A detailed explanation of this if-else branch can be found in
      // https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058
      if (searcher_->Type() == kWfstBeamSearch) {
        path.sentence += (' ' + word);
      } else {
        path.sentence += (word);
      }
    }

    // Timestamps are only supported in the final result.
    // Timestamps from the output of CtcWfstBeamSearch may be inaccurate due
    // to the various FST operations applied when building the decoding
    // graph, so here we use the timestamps of the input (e2e model units),
    // which are more accurate. This requires the symbol table of the e2e
    // model used in training.
    if (unit_table_ != nullptr && finish) {
      const std::vector<int>& input = inputs[i];
      const std::vector<int>& time_stamp = times[i];
      CHECK_EQ(input.size(), time_stamp.size());
      for (size_t j = 0; j < input.size(); j++) {
        std::string word = unit_table_->Find(input[j]);
        // Pad the start time by time_stamp_gap_ ms, clamped at 0; when two
        // adjacent units are closer than time_stamp_gap_, split the gap at
        // the midpoint between them.
        int start = time_stamp[j] * frame_shift_in_ms() - time_stamp_gap_ > 0
                        ? time_stamp[j] * frame_shift_in_ms() - time_stamp_gap_
                        : 0;
        if (j > 0) {
          start =
              (time_stamp[j] - time_stamp[j - 1]) * frame_shift_in_ms() <
                      time_stamp_gap_
                  ? (time_stamp[j - 1] + time_stamp[j]) / 2 *
                        frame_shift_in_ms()
                  : start;
        }
        int end = time_stamp[j] * frame_shift_in_ms();
        if (j < input.size() - 1) {
          end = (time_stamp[j + 1] - time_stamp[j]) * frame_shift_in_ms() <
                        time_stamp_gap_
                    ? (time_stamp[j + 1] + time_stamp[j]) / 2 *
                          frame_shift_in_ms()
                    : end;
        }
        WordPiece word_piece(word, offset + start, offset + end);
        path.word_pieces.emplace_back(word_piece);
      }
    }

    if (post_processor_ != nullptr) {
      path.sentence = post_processor_->Process(path.sentence, finish);
    }
    result_.emplace_back(path);
  }

  if (DecodedSomething()) {
    VLOG(1) << "Partial CTC result " << result_[0].sentence;
    if (context_graph_ != nullptr) {
      int cur_state = 0;
      float score = 0;
      for (int ilabel : inputs[0]) {
        cur_state = context_graph_->GetNextState(cur_state, ilabel, &score,
                                                 &(result_[0].contexts));
      }
      std::string contexts;
      for (const auto& context : result_[0].contexts) {
        contexts += context + ", ";
      }
      VLOG(1) << "Contexts: " << contexts;
    }
  }
}

void AsrDecoder::AttentionRescoring() {
  searcher_->FinalizeSearch();
  UpdateResult(true);
  // No need to do rescoring
  if (0.0 == opts_.rescoring_weight) {
    return;
  }
  // Inputs() returns the N-best input ids, which are the basic units for
  // rescoring. In CtcPrefixBeamSearch, inputs are the same as outputs.
  const auto& hypotheses = searcher_->Inputs();
  int num_hyps = hypotheses.size();
  if (num_hyps <= 0) {
    return;
  }
  // TODO(zhendong.peng): Do we need rescoring while context matching?
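
  // Score each of the N-best hypotheses with the attention decoder. With
  // opts_.reverse_weight > 0, the bidirectional decoder (checked in the
  // constructor) blends left-to-right and right-to-left decoder scores.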
  std::vector<float> rescoring_score;
  model_->AttentionRescoring(hypotheses, opts_.reverse_weight,
                             &rescoring_score);

  // Combine the CTC score and the rescoring score
  for (int i = 0; i < num_hyps; ++i) {
    result_[i].score = opts_.rescoring_weight * rescoring_score[i] +
                       opts_.ctc_weight * result_[i].score;
  }
  std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
}

}  // namespace wenet
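
// A minimal usage sketch (not part of this file; assumes feature_pipeline,
// resource, and opts were built elsewhere, with the types declared in
// decoder/asr_decoder.h):
//
//   auto decoder = std::make_shared<wenet::AsrDecoder>(feature_pipeline,
//                                                      resource, opts);
//   while (true) {
//     wenet::DecodeState state = decoder->Decode(/*block=*/true);
//     if (state == wenet::DecodeState::kEndFeats) {
//       decoder->Rescoring();  // finalize the search and attention-rescore
//       break;
//     } else if (state == wenet::DecodeState::kEndpoint) {
//       decoder->Rescoring();
//       decoder->ResetContinuousDecoding();  // keep the global frame offset
//     }
//   }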