|
|
// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_CTC_WFST_BEAM_SEARCH_H_
#define DECODER_CTC_WFST_BEAM_SEARCH_H_
#include <memory>
#include <vector>
#include "decoder/context_graph.h"
#include "decoder/search_interface.h"
#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "../utils/wn_utils.h"
namespace wenet {
class DecodableTensorScaled : public kaldi::DecodableInterface { public: explicit DecodableTensorScaled(float scale = 1.0) : scale_(scale) { Reset(); }
void Reset(); int32 NumFramesReady() const override { return num_frames_ready_; } bool IsLastFrame(int32 frame) const override; float LogLikelihood(int32 frame, int32 index) override; int32 NumIndices() const override; void AcceptLoglikes(const std::vector<float>& logp); void SetFinish() { done_ = true; }
private: int num_frames_ready_ = 0; float scale_ = 1.0; bool done_ = false; std::vector<float> logp_; };
// LatticeFasterDecoderConfig has the following key members
// beam: decoding beam
// max_active: Decoder max active states
// lattice_beam: Lattice generation beam
struct CtcWfstBeamSearchOptions : public kaldi::LatticeFasterDecoderConfig { float acoustic_scale = 1.0; float nbest = 10; // When blank score is greater than this thresh, skip the frame in viterbi
// search
float blank_skip_thresh = 0.98; float blank_scale = 1.0; int blank = 0; };
class CtcWfstBeamSearch : public SearchInterface { public: explicit CtcWfstBeamSearch( const fst::Fst<fst::StdArc>& fst, const CtcWfstBeamSearchOptions& opts, const std::shared_ptr<ContextGraph>& context_graph); void Search(const std::vector<std::vector<float>>& logp) override; void Reset() override; void FinalizeSearch() override; SearchType Type() const override { return SearchType::kWfstBeamSearch; } // For CTC prefix beam search, both inputs and outputs are hypotheses_
const std::vector<std::vector<int>>& Inputs() const override { return inputs_; } const std::vector<std::vector<int>>& Outputs() const override { return outputs_; } const std::vector<float>& Likelihood() const override { return likelihood_; } const std::vector<std::vector<int>>& Times() const override { return times_; }
private: // Sub one and remove <blank>
void ConvertToInputs(const std::vector<int>& alignment, std::vector<int>* input, std::vector<int>* time = nullptr);
int num_frames_ = 0; std::vector<int> decoded_frames_mapping_;
int last_best_ = 0; // last none blank best id
std::vector<float> last_frame_prob_; bool is_last_frame_blank_ = false; std::vector<std::vector<int>> inputs_, outputs_; std::vector<float> likelihood_; std::vector<std::vector<int>> times_; DecodableTensorScaled decodable_; kaldi::LatticeFasterOnlineDecoder decoder_; std::shared_ptr<ContextGraph> context_graph_; const CtcWfstBeamSearchOptions& opts_; };
} // namespace wenet
#endif // DECODER_CTC_WFST_BEAM_SEARCH_H_
|