|
|
// Copyright 2022 Horizon Robotics. All Rights Reserved.
// Author: binbin.zhang@horizon.ai (Binbin Zhang)
#ifndef DECODER_ASR_MODEL_H_
#define DECODER_ASR_MODEL_H_
#include <limits>
#include <memory>
#include <string>
#include <vector>
#include "utils/timer.h"
#include "utils/wn_utils.h"
namespace wenet {
class AsrModel { public: virtual int right_context() const { return right_context_; } virtual int subsampling_rate() const { return subsampling_rate_; } virtual int sos() const { return sos_; } virtual int eos() const { return eos_; } virtual bool is_bidirectional_decoder() const { return is_bidirectional_decoder_; } virtual int offset() const { return offset_; }
// If chunk_size > 0, streaming case. Otherwise, none streaming case
virtual void set_chunk_size(int chunk_size) { chunk_size_ = chunk_size; } virtual void set_num_left_chunks(int num_left_chunks) { num_left_chunks_ = num_left_chunks; } // start: if it is the start chunk of one sentence
virtual int num_frames_for_chunk(bool start) const;
virtual void Reset() = 0;
virtual void ForwardEncoder( const std::vector<std::vector<float>>& chunk_feats, std::vector<std::vector<float>>* ctc_prob);
virtual void AttentionRescoring(const std::vector<std::vector<int>>& hyps, float reverse_weight, std::vector<float>* rescoring_score) = 0;
virtual std::shared_ptr<AsrModel> Copy() const = 0;
protected: virtual void ForwardEncoderFunc( const std::vector<std::vector<float>>& chunk_feats, std::vector<std::vector<float>>* ctc_prob) = 0; virtual void CacheFeature(const std::vector<std::vector<float>>& chunk_feats);
int right_context_ = 1; int subsampling_rate_ = 1; int sos_ = 0; int eos_ = 0; bool is_bidirectional_decoder_ = false; int chunk_size_ = 16; int num_left_chunks_ = -1; // -1 means all left chunks
int offset_ = 0;
std::vector<std::vector<float>> cached_feature_; };
} // namespace wenet
#endif // DECODER_ASR_MODEL_H_
|