You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

103 lines
3.5 KiB

// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_CTC_WFST_BEAM_SEARCH_H_
#define DECODER_CTC_WFST_BEAM_SEARCH_H_
#include <memory>
#include <vector>
#include "decoder/context_graph.h"
#include "decoder/search_interface.h"
#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "utils/utils.h"
namespace wenet {
class DecodableTensorScaled : public kaldi::DecodableInterface {
public:
explicit DecodableTensorScaled(float scale = 1.0) : scale_(scale) { Reset(); }
void Reset();
int32 NumFramesReady() const override { return num_frames_ready_; }
bool IsLastFrame(int32 frame) const override;
float LogLikelihood(int32 frame, int32 index) override;
int32 NumIndices() const override;
void AcceptLoglikes(const std::vector<float>& logp);
void SetFinish() { done_ = true; }
private:
int num_frames_ready_ = 0;
float scale_ = 1.0;
bool done_ = false;
std::vector<float> logp_;
};
// Options for CtcWfstBeamSearch.
//
// Inherits kaldi::LatticeFasterDecoderConfig, whose key members include:
//   beam:         decoding beam
//   max_active:   maximum number of active decoder states
//   lattice_beam: lattice generation beam
struct CtcWfstBeamSearchOptions : public kaldi::LatticeFasterDecoderConfig {
  // Acoustic scale; presumably forwarded to DecodableTensorScaled — confirm
  // in the .cc file.
  float acoustic_scale = 1.0;
  // Number of hypotheses to keep in the n-best list. This is a count, so it is
  // an integer; a float would silently accept (and later truncate) fractional
  // values.
  int nbest = 10;
  // When the blank score is greater than this threshold, the frame is skipped
  // in the Viterbi search.
  float blank_skip_thresh = 0.98;
  // Scale factor for the blank score (applied in the .cc file — TODO confirm
  // exact usage).
  float blank_scale = 1.0;
  // Index of the <blank> token.
  int blank = 0;
};
// Beam search over CTC scores using a WFST and Kaldi's
// LatticeFasterOnlineDecoder.
//
// NOTE(review): Kaldi's lattice-faster decoders keep a pointer to the FST they
// are constructed with (assuming `fst` is handed to decoder_ in the .cc file),
// so `fst` must outlive this object. The options, by contrast, are copied into
// opts_ and need not outlive it.
class CtcWfstBeamSearch : public SearchInterface {
 public:
  explicit CtcWfstBeamSearch(
      const fst::Fst<fst::StdArc>& fst, const CtcWfstBeamSearchOptions& opts,
      const std::shared_ptr<ContextGraph>& context_graph);
  void Search(const std::vector<std::vector<float>>& logp) override;
  void Reset() override;
  void FinalizeSearch() override;
  SearchType Type() const override { return SearchType::kWfstBeamSearch; }

  // Unlike CTC prefix beam search, inputs and outputs are stored separately
  // here. Inputs are token-id sequences, presumably produced by
  // ConvertToInputs() from decoder alignments.
  const std::vector<std::vector<int>>& Inputs() const override {
    return inputs_;
  }
  // Presumably the WFST output-label (word) sequences, one per hypothesis —
  // TODO confirm in the .cc file.
  const std::vector<std::vector<int>>& Outputs() const override {
    return outputs_;
  }
  const std::vector<float>& Likelihood() const override { return likelihood_; }
  // Presumably per-token frame indices filled via ConvertToInputs()'s `time`
  // output — TODO confirm.
  const std::vector<std::vector<int>>& Times() const override { return times_; }

 private:
  // Converts an alignment into a token sequence: subtracts one from each id
  // and removes <blank>. If `time` is non-null it presumably receives the
  // frame index of each kept token.
  void ConvertToInputs(const std::vector<int>& alignment,
                       std::vector<int>* input,
                       std::vector<int>* time = nullptr);

  int num_frames_ = 0;
  std::vector<int> decoded_frames_mapping_;
  int last_best_ = 0;  // last non-blank best id
  std::vector<float> last_frame_prob_;
  bool is_last_frame_blank_ = false;
  std::vector<std::vector<int>> inputs_;
  std::vector<std::vector<int>> outputs_;
  std::vector<float> likelihood_;
  std::vector<std::vector<int>> times_;
  DecodableTensorScaled decodable_;
  kaldi::LatticeFasterOnlineDecoder decoder_;
  std::shared_ptr<ContextGraph> context_graph_;
  // Held by value: a `const&` member bound to the constructor argument would
  // dangle if the caller passed a temporary options object.
  const CtcWfstBeamSearchOptions opts_;
};
} // namespace wenet
#endif // DECODER_CTC_WFST_BEAM_SEARCH_H_