// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "decoder/ctc_wfst_beam_search.h"
|
|
|
|
#include <utility>
|
|
|
|
namespace wenet {
|
|
|
|
void DecodableTensorScaled::Reset() {
|
|
num_frames_ready_ = 0;
|
|
done_ = false;
|
|
// Give an empty initialization, will throw error when
|
|
// AcceptLoglikes is not called
|
|
logp_.clear();
|
|
}
|
|
|
|
void DecodableTensorScaled::AcceptLoglikes(const std::vector<float>& logp) {
|
|
++num_frames_ready_;
|
|
// TODO(Binbin Zhang): Avoid copy here
|
|
logp_ = logp;
|
|
}
|
|
|
|
float DecodableTensorScaled::LogLikelihood(int32 frame, int32 index) {
|
|
CHECK_GT(index, 0);
|
|
CHECK_LT(frame, num_frames_ready_);
|
|
return scale_ * logp_[index - 1];
|
|
}
|
|
|
|
bool DecodableTensorScaled::IsLastFrame(int32 frame) const {
|
|
CHECK_LT(frame, num_frames_ready_);
|
|
return done_ && (frame == num_frames_ready_ - 1);
|
|
}
|
|
|
|
int32 DecodableTensorScaled::NumIndices() const {
|
|
LOG(FATAL) << "Not implement";
|
|
return 0;
|
|
}
|
|
|
|
CtcWfstBeamSearch::CtcWfstBeamSearch(
|
|
const fst::Fst<fst::StdArc>& fst, const CtcWfstBeamSearchOptions& opts,
|
|
const std::shared_ptr<ContextGraph>& context_graph)
|
|
: decodable_(opts.acoustic_scale),
|
|
decoder_(fst, opts, context_graph),
|
|
context_graph_(context_graph),
|
|
opts_(opts) {
|
|
Reset();
|
|
}
|
|
|
|
void CtcWfstBeamSearch::Reset() {
|
|
num_frames_ = 0;
|
|
decoded_frames_mapping_.clear();
|
|
is_last_frame_blank_ = false;
|
|
last_best_ = 0;
|
|
inputs_.clear();
|
|
outputs_.clear();
|
|
likelihood_.clear();
|
|
times_.clear();
|
|
decodable_.Reset();
|
|
decoder_.InitDecoding();
|
|
}
|
|
|
|
void CtcWfstBeamSearch::Search(const std::vector<std::vector<float>>& logp) {
|
|
if (0 == logp.size()) {
|
|
return;
|
|
}
|
|
// Every time we get the log posterior, we decode it all before return
|
|
for (int i = 0; i < logp.size(); i++) {
|
|
float blank_score = std::exp(logp[i][opts_.blank]);
|
|
if (blank_score > opts_.blank_skip_thresh * opts_.blank_scale) {
|
|
VLOG(3) << "skipping frame " << num_frames_ << " score " << blank_score;
|
|
is_last_frame_blank_ = true;
|
|
last_frame_prob_ = logp[i];
|
|
} else {
|
|
// Get the best symbol
|
|
int cur_best =
|
|
std::max_element(logp[i].begin(), logp[i].end()) - logp[i].begin();
|
|
// Optional, adding one blank frame if we has skipped it in two same
|
|
// symbols
|
|
if (cur_best != opts_.blank && is_last_frame_blank_ &&
|
|
cur_best == last_best_) {
|
|
decodable_.AcceptLoglikes(last_frame_prob_);
|
|
decoder_.AdvanceDecoding(&decodable_, 1);
|
|
decoded_frames_mapping_.push_back(num_frames_ - 1);
|
|
VLOG(2) << "Adding blank frame at symbol " << cur_best;
|
|
}
|
|
last_best_ = cur_best;
|
|
|
|
decodable_.AcceptLoglikes(logp[i]);
|
|
decoder_.AdvanceDecoding(&decodable_, 1);
|
|
decoded_frames_mapping_.push_back(num_frames_);
|
|
is_last_frame_blank_ = false;
|
|
}
|
|
num_frames_++;
|
|
}
|
|
// Get the best path
|
|
inputs_.clear();
|
|
outputs_.clear();
|
|
likelihood_.clear();
|
|
if (decoded_frames_mapping_.size() > 0) {
|
|
inputs_.resize(1);
|
|
outputs_.resize(1);
|
|
likelihood_.resize(1);
|
|
kaldi::Lattice lat;
|
|
decoder_.GetBestPath(&lat, true);
|
|
std::vector<int> alignment;
|
|
kaldi::LatticeWeight weight;
|
|
fst::GetLinearSymbolSequence(lat, &alignment, &outputs_[0], &weight);
|
|
ConvertToInputs(alignment, &inputs_[0]);
|
|
VLOG(3) << weight.Value1() << " " << weight.Value2();
|
|
likelihood_[0] = -(weight.Value1() + weight.Value2());
|
|
}
|
|
}
|
|
|
|
void CtcWfstBeamSearch::FinalizeSearch() {
|
|
decodable_.SetFinish();
|
|
decoder_.FinalizeDecoding();
|
|
inputs_.clear();
|
|
outputs_.clear();
|
|
likelihood_.clear();
|
|
times_.clear();
|
|
if (decoded_frames_mapping_.size() > 0) {
|
|
std::vector<kaldi::Lattice> nbest_lats;
|
|
if (opts_.nbest == 1) {
|
|
kaldi::Lattice lat;
|
|
decoder_.GetBestPath(&lat, true);
|
|
nbest_lats.push_back(std::move(lat));
|
|
} else {
|
|
// Get N-best path by lattice(CompactLattice)
|
|
kaldi::CompactLattice clat;
|
|
decoder_.GetLattice(&clat, true);
|
|
kaldi::Lattice lat, nbest_lat;
|
|
fst::ConvertLattice(clat, &lat);
|
|
// TODO(Binbin Zhang): it's n-best word lists here, not character n-best
|
|
fst::ShortestPath(lat, &nbest_lat, opts_.nbest);
|
|
fst::ConvertNbestToVector(nbest_lat, &nbest_lats);
|
|
}
|
|
int nbest = nbest_lats.size();
|
|
inputs_.resize(nbest);
|
|
outputs_.resize(nbest);
|
|
likelihood_.resize(nbest);
|
|
times_.resize(nbest);
|
|
for (int i = 0; i < nbest; i++) {
|
|
kaldi::LatticeWeight weight;
|
|
std::vector<int> alignment;
|
|
fst::GetLinearSymbolSequence(nbest_lats[i], &alignment, &outputs_[i],
|
|
&weight);
|
|
ConvertToInputs(alignment, &inputs_[i], ×_[i]);
|
|
likelihood_[i] = -(weight.Value1() + weight.Value2());
|
|
}
|
|
}
|
|
}
|
|
|
|
void CtcWfstBeamSearch::ConvertToInputs(const std::vector<int>& alignment,
|
|
std::vector<int>* input,
|
|
std::vector<int>* time) {
|
|
input->clear();
|
|
if (time != nullptr) time->clear();
|
|
for (int cur = 0; cur < alignment.size(); ++cur) {
|
|
// ignore blank
|
|
if (alignment[cur] - 1 == opts_.blank) continue;
|
|
// merge continuous same label
|
|
if (cur > 0 && alignment[cur] == alignment[cur - 1]) continue;
|
|
|
|
input->push_back(alignment[cur] - 1);
|
|
if (time != nullptr) {
|
|
time->push_back(decoded_frames_mapping_[cur]);
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace wenet
|