// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
//               2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DECODER_ASR_DECODER_H_
#define DECODER_ASR_DECODER_H_

#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "fst/fstlib.h"
#include "fst/symbol-table.h"

#include "decoder/asr_model.h"
#include "decoder/context_graph.h"
#include "decoder/ctc_endpoint.h"
#include "decoder/ctc_prefix_beam_search.h"
#include "decoder/ctc_wfst_beam_search.h"
#include "decoder/search_interface.h"
#include "frontend/feature_pipeline.h"
#include "post_processor/post_processor.h"
#include "utils/utils.h"

namespace wenet {

struct DecodeOptions {
  // chunk_size is the frame number of one chunk after subsampling.
  // e.g. if the subsample rate is 4 and chunk_size = 16, the frames in
  // one chunk are 64 = 16 * 4
  int chunk_size = 16;
  int num_left_chunks = -1;

  // final_score = rescoring_weight * rescoring_score + ctc_weight * ctc_score;
  // rescoring_score = left_to_right_score * (1 - reverse_weight) +
  //                   right_to_left_score * reverse_weight
  // Please note that the concept of ctc_score differs between the two search
  // methods:
  // 1. For CtcPrefixBeamSearch, it is a sum(prefix) score + context score
  // 2. For CtcWfstBeamSearch, it is a max(viterbi) path score + context score
  // So ctc_weight should be set carefully according to the search method.
  float ctc_weight = 0.5;
  float rescoring_weight = 1.0;
  float reverse_weight = 0.0;
  CtcEndpointConfig ctc_endpoint_config;
  CtcPrefixBeamSearchOptions ctc_prefix_search_opts;
  CtcWfstBeamSearchOptions ctc_wfst_search_opts;
};
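
// Worked example (illustrative numbers, not from the original header): with
// the defaults above, a hypothesis whose rescoring_score is -3.0 and whose
// ctc_score is -4.0 gets
//
//   final_score = 1.0 * (-3.0) + 0.5 * (-4.0) = -5.0
//
// and with reverse_weight = 0.3, a left_to_right_score of -3.0 and a
// right_to_left_score of -2.0 combine into
//
//   rescoring_score = -3.0 * (1 - 0.3) + (-2.0) * 0.3 = -2.7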

struct WordPiece {
  std::string word;
  int start = -1;
  int end = -1;

  WordPiece(std::string word, int start, int end)
      : word(std::move(word)), start(start), end(end) {}
};

struct DecodeResult {
  float score = -kFloatMax;
  std::string sentence;
  std::unordered_set<std::string> contexts;
  std::vector<WordPiece> word_pieces;

  static bool CompareFunc(const DecodeResult& a, const DecodeResult& b) {
    return a.score > b.score;
  }
};
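
// For instance (an illustrative sketch; assumes <algorithm> is included and
// `results` holds a decoder's n-best output), hypotheses can be ordered
// best-first with the comparator above:
//
//   std::vector<DecodeResult> results = ...;
//   std::sort(results.begin(), results.end(), DecodeResult::CompareFunc);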

enum DecodeState {
  kEndBatch = 0x00,   // End of the current decoding batch, normal case
  kEndpoint = 0x01,   // An endpoint is detected
  kEndFeats = 0x02,   // All features are decoded
  kWaitFeats = 0x03   // Features are not enough for one chunk inference, wait
};
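
// A minimal sketch (illustrative, not part of this header) of a streaming
// decode loop driven by these states, where `decoder` is an AsrDecoder as
// declared below and its feature pipeline is fed from another thread:
//
//   DecodeState state = kEndBatch;
//   while (state != kEndFeats) {
//     state = decoder.Decode();             // block = true by default
//     if (state == kEndpoint) {
//       decoder.Rescoring();                // finalize the current utterance
//       decoder.ResetContinuousDecoding();  // keep going on the same stream
//     }
//   }
//   decoder.Rescoring();                    // finalize the last utterance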

// DecodeResource is thread safe, so it can be shared across multiple
// decoding threads.
struct DecodeResource {
  std::shared_ptr<AsrModel> model = nullptr;
  std::shared_ptr<fst::SymbolTable> symbol_table = nullptr;
  std::shared_ptr<fst::VectorFst<fst::StdArc>> fst = nullptr;
  std::shared_ptr<fst::SymbolTable> unit_table = nullptr;
  std::shared_ptr<ContextGraph> context_graph = nullptr;
  std::shared_ptr<PostProcessor> post_processor = nullptr;
};
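
// A minimal sketch (illustrative; model and symbol table loading elided) of
// sharing one resource across several decoding threads, each of which owns
// its own AsrDecoder and FeaturePipeline:
//
//   auto resource = std::make_shared<DecodeResource>();
//   resource->model = ...;         // any loaded AsrModel implementation
//   resource->symbol_table = ...;  // output symbol table
//   // then, in each decoding thread:
//   AsrDecoder decoder(feature_pipeline, resource, opts);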

// Torch ASR decoder
class AsrDecoder {
 public:
  AsrDecoder(std::shared_ptr<FeaturePipeline> feature_pipeline,
             std::shared_ptr<DecodeResource> resource,
             const DecodeOptions& opts);

  // @param block: if true, block when the features are not enough for one
  // chunk inference; otherwise, return kWaitFeats.
  DecodeState Decode(bool block = true);
  void Rescoring();
  void Reset();
  void ResetContinuousDecoding();
  bool DecodedSomething() const {
    return !result_.empty() && !result_[0].sentence.empty();
  }

  // This method is used for time benchmarking.
  int num_frames_in_current_chunk() const {
    return num_frames_in_current_chunk_;
  }
  int frame_shift_in_ms() const {
    return model_->subsampling_rate() *
           feature_pipeline_->config().frame_shift * 1000 /
           feature_pipeline_->config().sample_rate;
  }
  int feature_frame_shift_in_ms() const {
    return feature_pipeline_->config().frame_shift * 1000 /
           feature_pipeline_->config().sample_rate;
  }
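
  // Worked example (illustrative numbers, not from the original header): with
  // a 16 kHz sample_rate and a frame_shift of 160 samples, one feature frame
  // covers 160 * 1000 / 16000 = 10 ms, so feature_frame_shift_in_ms() is 10;
  // with subsampling_rate() == 4, frame_shift_in_ms() is 4 * 10 = 40 ms per
  // model output frame.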

  const std::vector<DecodeResult>& result() const { return result_; }

 private:
  DecodeState AdvanceDecoding(bool block = true);
  void AttentionRescoring();
  void UpdateResult(bool finish = false);

  std::shared_ptr<FeaturePipeline> feature_pipeline_;
  std::shared_ptr<AsrModel> model_;
  std::shared_ptr<PostProcessor> post_processor_;
  std::shared_ptr<ContextGraph> context_graph_;
  std::shared_ptr<fst::VectorFst<fst::StdArc>> fst_ = nullptr;
  // output symbol table
  std::shared_ptr<fst::SymbolTable> symbol_table_;
  // e2e unit symbol table
  std::shared_ptr<fst::SymbolTable> unit_table_ = nullptr;
  const DecodeOptions& opts_;
  // cache feature
  bool start_ = false;
  // For continuous decoding
  int num_frames_ = 0;
  int global_frame_offset_ = 0;
  const int time_stamp_gap_ = 100;  // timestamp gap between words in a sentence

  std::unique_ptr<SearchInterface> searcher_;
  std::unique_ptr<CtcEndpoint> ctc_endpointer_;
  int num_frames_in_current_chunk_ = 0;
  std::vector<DecodeResult> result_;

 public:
  WENET_DISALLOW_COPY_AND_ASSIGN(AsrDecoder);
};

}  // namespace wenet

#endif  // DECODER_ASR_DECODER_H_