// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang) // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "decoder/ctc_wfst_beam_search.h" #include namespace wenet { void DecodableTensorScaled::Reset() { num_frames_ready_ = 0; done_ = false; // Give an empty initialization, will throw error when // AcceptLoglikes is not called logp_.clear(); } void DecodableTensorScaled::AcceptLoglikes(const std::vector& logp) { ++num_frames_ready_; // TODO(Binbin Zhang): Avoid copy here logp_ = logp; } float DecodableTensorScaled::LogLikelihood(int32 frame, int32 index) { CHECK_GT(index, 0); CHECK_LT(frame, num_frames_ready_); return scale_ * logp_[index - 1]; } bool DecodableTensorScaled::IsLastFrame(int32 frame) const { CHECK_LT(frame, num_frames_ready_); return done_ && (frame == num_frames_ready_ - 1); } int32 DecodableTensorScaled::NumIndices() const { LOG(FATAL) << "Not implement"; return 0; } CtcWfstBeamSearch::CtcWfstBeamSearch( const fst::Fst& fst, const CtcWfstBeamSearchOptions& opts, const std::shared_ptr& context_graph) : decodable_(opts.acoustic_scale), decoder_(fst, opts, context_graph), context_graph_(context_graph), opts_(opts) { Reset(); } void CtcWfstBeamSearch::Reset() { num_frames_ = 0; decoded_frames_mapping_.clear(); is_last_frame_blank_ = false; last_best_ = 0; inputs_.clear(); outputs_.clear(); likelihood_.clear(); times_.clear(); decodable_.Reset(); decoder_.InitDecoding(); } void CtcWfstBeamSearch::Search(const std::vector>& logp) { if (0 == logp.size()) { return; } // Every time we get the log posterior, we decode it all before return for (int i = 0; i < logp.size(); i++) { float blank_score = std::exp(logp[i][opts_.blank]); if (blank_score > opts_.blank_skip_thresh * opts_.blank_scale) { VLOG(3) << "skipping frame " << num_frames_ << " score " << blank_score; is_last_frame_blank_ = true; last_frame_prob_ = logp[i]; } else { // Get the best symbol int cur_best = std::max_element(logp[i].begin(), logp[i].end()) - logp[i].begin(); // Optional, adding one blank frame if we has skipped it in two same // symbols if (cur_best != opts_.blank && is_last_frame_blank_ && cur_best == last_best_) { decodable_.AcceptLoglikes(last_frame_prob_); decoder_.AdvanceDecoding(&decodable_, 1); decoded_frames_mapping_.push_back(num_frames_ - 1); VLOG(2) << "Adding blank frame at symbol " << cur_best; } last_best_ = cur_best; decodable_.AcceptLoglikes(logp[i]); decoder_.AdvanceDecoding(&decodable_, 1); decoded_frames_mapping_.push_back(num_frames_); is_last_frame_blank_ = false; } num_frames_++; } // Get the best path inputs_.clear(); outputs_.clear(); likelihood_.clear(); if (decoded_frames_mapping_.size() > 0) { inputs_.resize(1); outputs_.resize(1); likelihood_.resize(1); kaldi::Lattice lat; decoder_.GetBestPath(&lat, true); std::vector alignment; kaldi::LatticeWeight weight; fst::GetLinearSymbolSequence(lat, &alignment, &outputs_[0], &weight); ConvertToInputs(alignment, &inputs_[0]); VLOG(3) << weight.Value1() << " " << weight.Value2(); likelihood_[0] = -(weight.Value1() + weight.Value2()); } } void CtcWfstBeamSearch::FinalizeSearch() { decodable_.SetFinish(); decoder_.FinalizeDecoding(); inputs_.clear(); outputs_.clear(); likelihood_.clear(); times_.clear(); if (decoded_frames_mapping_.size() > 0) { std::vector nbest_lats; if (opts_.nbest == 1) { kaldi::Lattice lat; decoder_.GetBestPath(&lat, true); nbest_lats.push_back(std::move(lat)); } else { // Get N-best path by lattice(CompactLattice) kaldi::CompactLattice clat; decoder_.GetLattice(&clat, true); kaldi::Lattice lat, nbest_lat; fst::ConvertLattice(clat, &lat); // TODO(Binbin Zhang): it's n-best word lists here, not character n-best fst::ShortestPath(lat, &nbest_lat, opts_.nbest); fst::ConvertNbestToVector(nbest_lat, &nbest_lats); } int nbest = nbest_lats.size(); inputs_.resize(nbest); outputs_.resize(nbest); likelihood_.resize(nbest); times_.resize(nbest); for (int i = 0; i < nbest; i++) { kaldi::LatticeWeight weight; std::vector alignment; fst::GetLinearSymbolSequence(nbest_lats[i], &alignment, &outputs_[i], &weight); ConvertToInputs(alignment, &inputs_[i], ×_[i]); likelihood_[i] = -(weight.Value1() + weight.Value2()); } } } void CtcWfstBeamSearch::ConvertToInputs(const std::vector& alignment, std::vector* input, std::vector* time) { input->clear(); if (time != nullptr) time->clear(); for (int cur = 0; cur < alignment.size(); ++cur) { // ignore blank if (alignment[cur] - 1 == opts_.blank) continue; // merge continuous same label if (cur > 0 && alignment[cur] == alignment[cur - 1]) continue; input->push_back(alignment[cur] - 1); if (time != nullptr) { time->push_back(decoded_frames_mapping_[cur]); } } } } // namespace wenet