You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

245 lines
8.4 KiB

// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/asr_decoder.h"
#include <ctype.h>
#include <algorithm>
#include <limits>
#include <utility>
#include "utils/timer.h"
namespace wenet {
// Builds a decoder from a feature pipeline, a bundle of decoding resources
// and the decoding options.
//
// The ASR model held by `resource` is copied rather than shared, because
// decoding changes the inner status of the model. The searcher is a
// CtcPrefixBeamSearch when the resource carries no FST, and a
// CtcWfstBeamSearch over that FST otherwise.
AsrDecoder::AsrDecoder(std::shared_ptr<FeaturePipeline> feature_pipeline,
                       std::shared_ptr<DecodeResource> resource,
                       const DecodeOptions& opts)
    : feature_pipeline_(std::move(feature_pipeline)),
      // Work on a private copy of the ASR model: its inner status is
      // modified while decoding.
      model_(resource->model->Copy()),
      post_processor_(resource->post_processor),
      context_graph_(resource->context_graph),
      symbol_table_(resource->symbol_table),
      fst_(resource->fst),
      unit_table_(resource->unit_table),
      opts_(opts),
      ctc_endpointer_(new CtcEndpoint(opts.ctc_endpoint_config)) {
  // A positive reverse weight requires a model with a right-to-left decoder.
  if (opts_.reverse_weight > 0) {
    CHECK(model_->is_bidirectional_decoder());
  }

  // Pick the search strategy depending on whether a decoding graph is given.
  if (fst_ != nullptr) {
    searcher_.reset(new CtcWfstBeamSearch(*fst_, opts.ctc_wfst_search_opts,
                                          resource->context_graph));
  } else {
    searcher_.reset(new CtcPrefixBeamSearch(opts.ctc_prefix_search_opts,
                                            resource->context_graph));
  }

  // Hand the decoder's frame shift over to the endpointer.
  ctc_endpointer_->frame_shift_in_ms(frame_shift_in_ms());
}
// Brings the decoder back to a fresh state: frame counters are zeroed, the
// current result is dropped, and the model, the searcher, the feature
// pipeline and the endpointer are all reset.
void AsrDecoder::Reset() {
  // Plain bookkeeping members.
  num_frames_ = 0;
  global_frame_offset_ = 0;
  start_ = false;
  result_.clear();

  // Stateful components.
  model_->Reset();
  searcher_->Reset();
  feature_pipeline_->Reset();
  ctc_endpointer_->Reset();
}
// Prepares the decoder for the next segment of the same input.
//
// Unlike Reset(), the feature pipeline and num_frames_ are left untouched;
// the number of frames consumed so far is stored in global_frame_offset_,
// which UpdateResult() uses to shift the time stamps it reports.
void AsrDecoder::ResetContinuousDecoding() {
  start_ = false;
  result_.clear();
  global_frame_offset_ = num_frames_;

  model_->Reset();
  searcher_->Reset();
  ctc_endpointer_->Reset();
}
// Runs one decoding step by forwarding to AdvanceDecoding() and returns the
// state it reports. `block` is passed through unchanged.
DecodeState AsrDecoder::Decode(bool block) {
  const DecodeState state = AdvanceDecoding(block);
  return state;
}
void AsrDecoder::Rescoring() {
// Do attention rescoring
Timer timer;
AttentionRescoring();
VLOG(2) << "Rescoring cost latency: " << timer.Elapsed() << "ms.";
}
// Decodes one chunk of features: reads the frames needed for a chunk, runs
// the encoder, feeds the CTC log probabilities to the searcher and refreshes
// the partial result.
//
// @param block  When false, return kWaitFeats right away if the input is not
//               finished and fewer frames than one chunk are queued.
// @return kWaitFeats  not enough features yet (non-blocking mode only);
//         kEndFeats   the feature pipeline reported the end of the input;
//         kEndpoint   the CTC endpointer detected an endpoint;
//         kEndBatch   otherwise.
DecodeState AsrDecoder::AdvanceDecoding(bool block) {
  model_->set_chunk_size(opts_.chunk_size);
  model_->set_num_left_chunks(opts_.num_left_chunks);
  const int num_required_frames = model_->num_frames_for_chunk(start_);

  // Return immediately if we do not want to block.
  if (!block && !feature_pipeline_->input_finished() &&
      feature_pipeline_->NumQueuedFrames() < num_required_frames) {
    return DecodeState::kWaitFeats;
  }

  DecodeState state = DecodeState::kEndBatch;
  std::vector<std::vector<float>> chunk_feats;
  // If the read is not okay, we have reached the end of the input. Whatever
  // frames were returned are still decoded below.
  if (!feature_pipeline_->Read(num_required_frames, &chunk_feats)) {
    state = DecodeState::kEndFeats;
  }
  num_frames_ += chunk_feats.size();
  VLOG(2) << "Required " << num_required_frames << " get "
          << chunk_feats.size();

  Timer timer;
  std::vector<std::vector<float>> ctc_log_probs;
  model_->ForwardEncoder(chunk_feats, &ctc_log_probs);
  int forward_time = timer.Elapsed();

  // Scale the first entry of every frame (index 0, the entry blank_scale
  // applies to) in the log domain. The logarithm does not depend on the
  // frame, so it is computed once.
  const auto blank_scale = opts_.ctc_wfst_search_opts.blank_scale;
  if (blank_scale != 1.0) {
    const auto log_blank_scale = std::log(blank_scale);
    for (auto& frame_log_probs : ctc_log_probs) {
      frame_log_probs[0] += log_blank_scale;
    }
  }

  timer.Reset();
  searcher_->Search(ctc_log_probs);
  int search_time = timer.Elapsed();
  VLOG(3) << "forward takes " << forward_time << " ms, search takes "
          << search_time << " ms";
  UpdateResult();

  // Endpoint detection is skipped once the end of the features is reached.
  if (state != DecodeState::kEndFeats &&
      ctc_endpointer_->IsEndpoint(ctc_log_probs, DecodedSomething())) {
    VLOG(1) << "Endpoint is detected at " << num_frames_;
    state = DecodeState::kEndpoint;
  }

  start_ = true;
  return state;
}
// Rebuilds result_ from the current outputs of the searcher.
//
// For every hypothesis the sentence is assembled from symbol_table_, the
// score is taken from the searcher likelihood and, if configured, the
// sentence is run through the post processor. Word-piece time stamps are only
// produced when `finish` is true and a unit table is available. When
// something was decoded and a context graph is set, the contexts matched by
// the best hypothesis are collected and logged.
//
// @param finish  true when building the final result; enables time stamps and
//                is forwarded to the post processor.
void AsrDecoder::UpdateResult(bool finish) {
  const auto& hypotheses = searcher_->Outputs();
  const auto& inputs = searcher_->Inputs();
  const auto& likelihood = searcher_->Likelihood();
  const auto& times = searcher_->Times();
  result_.clear();
  CHECK_EQ(hypotheses.size(), likelihood.size());

  // Values that do not change across hypotheses are computed once.
  const bool is_wfst_search = searcher_->Type() == kWfstBeamSearch;
  const bool with_time_stamps = unit_table_ != nullptr && finish;
  const int offset = global_frame_offset_ * feature_frame_shift_in_ms();

  for (size_t i = 0; i < hypotheses.size(); i++) {
    const std::vector<int>& hypothesis = hypotheses[i];
    DecodeResult path;
    path.score = likelihood[i];

    for (const int symbol_id : hypothesis) {
      const std::string word = symbol_table_->Find(symbol_id);
      // A detailed explanation of this if-else branch can be found in
      // https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058
      if (is_wfst_search) {
        path.sentence += ' ';
      }
      path.sentence += word;
    }

    // TimeStamp is only supported in final result.
    // TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to
    // various FST operations when building the decoding graph. So here we use
    // time stamp of the input (e2e model unit), which is more accurate, and
    // it requires the symbol table of the e2e model used in training.
    if (with_time_stamps) {
      const std::vector<int>& input = inputs[i];
      const std::vector<int>& time_stamp = times[i];
      CHECK_EQ(input.size(), time_stamp.size());
      const auto shift_ms = frame_shift_in_ms();

      for (size_t j = 0; j < input.size(); j++) {
        std::string word = unit_table_->Find(input[j]);
        const auto unit_ms = time_stamp[j] * shift_ms;

        // By default a piece starts time_stamp_gap_ before its own time
        // stamp (never below zero) and ends at its own time stamp. When a
        // neighbour is closer than time_stamp_gap_, the shared boundary is
        // placed at the midpoint of the two time stamps instead.
        // NOTE(review): the midpoint divides the int frame indices by 2
        // before scaling, so it is rounded down to a whole frame.
        int start = unit_ms - time_stamp_gap_ > 0 ? unit_ms - time_stamp_gap_
                                                  : 0;
        if (j > 0 && (time_stamp[j] - time_stamp[j - 1]) * shift_ms <
                         time_stamp_gap_) {
          start = (time_stamp[j - 1] + time_stamp[j]) / 2 * shift_ms;
        }

        int end = unit_ms;
        if (j + 1 < input.size() &&
            (time_stamp[j + 1] - time_stamp[j]) * shift_ms <
                time_stamp_gap_) {
          end = (time_stamp[j + 1] + time_stamp[j]) / 2 * shift_ms;
        }

        WordPiece word_piece(word, offset + start, offset + end);
        // The local is not used again, so move it instead of copying.
        path.word_pieces.emplace_back(std::move(word_piece));
      }
    }

    if (post_processor_ != nullptr) {
      path.sentence = post_processor_->Process(path.sentence, finish);
    }
    // `path` is done; move it to avoid copying the sentence and word pieces.
    result_.emplace_back(std::move(path));
  }

  if (DecodedSomething()) {
    VLOG(1) << "Partial CTC result " << result_[0].sentence;
    if (context_graph_ != nullptr) {
      // Walk the context graph along the best input sequence, starting from
      // state 0, and collect the matched contexts into result_[0].
      int cur_state = 0;
      float score = 0;
      for (int ilabel : inputs[0]) {
        cur_state = context_graph_->GetNextState(cur_state, ilabel, &score,
                                                 &(result_[0].contexts));
      }
      std::string contexts;
      for (const auto& context : result_[0].contexts) {
        contexts += context + ", ";
      }
      VLOG(1) << "Contexts: " << contexts;
    }
  }
}
// Finalizes the search, builds the final result and, unless the rescoring
// weight is zero, re-ranks the N-best list with the attention decoder.
//
// The new score of each hypothesis is
//   rescoring_weight * rescoring_score + ctc_weight * ctc_score,
// after which result_ is sorted with DecodeResult::CompareFunc.
void AsrDecoder::AttentionRescoring() {
  searcher_->FinalizeSearch();
  UpdateResult(true);

  // No need to do rescoring
  if (0.0 == opts_.rescoring_weight) {
    return;
  }

  // Inputs() returns N-best input ids, which is the basic unit for rescoring
  // In CtcPrefixBeamSearch, inputs are the same to outputs
  const auto& hypotheses = searcher_->Inputs();
  // Keep the count unsigned so it matches the index type used below.
  const size_t num_hyps = hypotheses.size();
  if (num_hyps == 0) {
    return;
  }

  // TODO(zhendong.peng): Do we need rescoring while context matching?
  std::vector<float> rescoring_score;
  model_->AttentionRescoring(hypotheses, opts_.reverse_weight,
                             &rescoring_score);

  // Combine ctc score and rescoring score
  for (size_t i = 0; i < num_hyps; ++i) {
    result_[i].score = opts_.rescoring_weight * rescoring_score[i] +
                       opts_.ctc_weight * result_[i].score;
  }
  std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
}
} // namespace wenet