You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

68 lines
2.0 KiB

// Copyright 2022 Horizon Robotics. All Rights Reserved.
// Author: binbin.zhang@horizon.ai (Binbin Zhang)
#ifndef DECODER_ASR_MODEL_H_
#define DECODER_ASR_MODEL_H_
#include <limits>
#include <memory>
#include <string>
#include <vector>
#include "utils/timer.h"
#include "utils/wn_utils.h"
namespace wenet {
// Abstract base class for the ASR models driven by the decoder.
//
// Holds model metadata (right context, subsampling rate, sos/eos ids,
// decoder direction, offset) and the streaming chunk configuration.
// Concrete back-ends implement the pure virtual hooks Reset,
// ForwardEncoderFunc, AttentionRescoring and Copy. num_frames_for_chunk,
// ForwardEncoder and CacheFeature are declared here but defined out of line.
class AsrModel {
 public:
  AsrModel() = default;
  // Polymorphic base: the destructor must be virtual so that deleting a
  // concrete model through an AsrModel pointer (e.g. via
  // std::unique_ptr<AsrModel>) runs the derived destructor instead of being
  // undefined behavior.
  virtual ~AsrModel() = default;
  // Declaring the destructor would suppress the implicit move operations and
  // deprecate the implicit copies, so all of them are explicitly defaulted
  // (Rule of Five) to keep copying/moving the base state working for
  // derived classes exactly as before.
  AsrModel(const AsrModel&) = default;
  AsrModel& operator=(const AsrModel&) = default;
  AsrModel(AsrModel&&) = default;
  AsrModel& operator=(AsrModel&&) = default;

  // Model metadata accessors. They return the protected members below; the
  // in-class defaults are presumably overwritten by a concrete model when it
  // is loaded — TODO confirm against the derived implementations.
  virtual int right_context() const { return right_context_; }
  virtual int subsampling_rate() const { return subsampling_rate_; }
  // Presumably the start-of-sentence / end-of-sentence token ids.
  virtual int sos() const { return sos_; }
  virtual int eos() const { return eos_; }
  // Whether the model has a bidirectional attention decoder; presumably
  // relevant to reverse_weight in AttentionRescoring — verify.
  virtual bool is_bidirectional_decoder() const {
    return is_bidirectional_decoder_;
  }
  // NOTE(review): the meaning of offset_ is not visible in this header
  // (possibly a streaming position offset) — confirm.
  virtual int offset() const { return offset_; }

  // If chunk_size > 0, streaming case. Otherwise, non-streaming case.
  virtual void set_chunk_size(int chunk_size) { chunk_size_ = chunk_size; }
  // Sets how many left (previous) chunks are used; -1 means all left chunks.
  virtual void set_num_left_chunks(int num_left_chunks) {
    num_left_chunks_ = num_left_chunks;
  }
  // Presumably the number of input feature frames required for one chunk.
  // start: if it is the start chunk of one sentence.
  virtual int num_frames_for_chunk(bool start) const;

  // Resets per-sentence decoding state (presumably encoder caches and
  // cached_feature_) before a new sentence.
  virtual void Reset() = 0;
  // Runs the encoder on one chunk of features and writes the output to
  // ctc_prob (presumably per-frame CTC probabilities). Defined out of line;
  // presumably wraps ForwardEncoderFunc together with CacheFeature — confirm.
  virtual void ForwardEncoder(
      const std::vector<std::vector<float>>& chunk_feats,
      std::vector<std::vector<float>>* ctc_prob);
  // Rescores the hypotheses in hyps and writes the scores to rescoring_score
  // (presumably one score per hypothesis). reverse_weight presumably weights
  // a right-to-left decoder when is_bidirectional_decoder() is true.
  virtual void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
                                  float reverse_weight,
                                  std::vector<float>* rescoring_score) = 0;
  // Returns a new model instance; whether underlying resources are shared
  // with *this is left to the implementation.
  virtual std::shared_ptr<AsrModel> Copy() const = 0;

 protected:
  // Back-end specific encoder forward pass over chunk_feats; fills ctc_prob.
  virtual void ForwardEncoderFunc(
      const std::vector<std::vector<float>>& chunk_feats,
      std::vector<std::vector<float>>* ctc_prob) = 0;
  // Defined out of line; presumably stores trailing frames of chunk_feats in
  // cached_feature_ for use with the next chunk — confirm.
  virtual void CacheFeature(const std::vector<std::vector<float>>& chunk_feats);

  int right_context_ = 1;
  int subsampling_rate_ = 1;
  int sos_ = 0;
  int eos_ = 0;
  bool is_bidirectional_decoder_ = false;
  int chunk_size_ = 16;
  int num_left_chunks_ = -1;  // -1 means all left chunks
  int offset_ = 0;
  std::vector<std::vector<float>> cached_feature_;
};
} // namespace wenet
#endif // DECODER_ASR_MODEL_H_