|
// Copyright 2022 Horizon Robotics. All Rights Reserved.
|
|
// Author: binbin.zhang@horizon.ai (Binbin Zhang)
|
|
|
|
#include "decoder/asr_model.h"

#include <limits>
#include <memory>
#include <utility>
#include <vector>
|
|
|
|
namespace wenet {
|
|
|
|
int AsrModel::num_frames_for_chunk(bool start) const {
|
|
int num_required_frames = 0;
|
|
if (chunk_size_ > 0) {
|
|
if (!start) { // First batch
|
|
int context = right_context_ + 1; // Add current frame
|
|
num_required_frames = (chunk_size_ - 1) * subsampling_rate_ + context;
|
|
} else {
|
|
num_required_frames = chunk_size_ * subsampling_rate_;
|
|
}
|
|
} else {
|
|
num_required_frames = std::numeric_limits<int>::max();
|
|
}
|
|
return num_required_frames;
|
|
}
|
|
|
|
void AsrModel::CacheFeature(
|
|
const std::vector<std::vector<float>>& chunk_feats) {
|
|
// Cache feature for next chunk
|
|
const int cached_feature_size = 1 + right_context_ - subsampling_rate_;
|
|
if (chunk_feats.size() >= cached_feature_size) {
|
|
// TODO(Binbin Zhang): Only deal the case when
|
|
// chunk_feats.size() > cached_feature_size here, and it's consistent
|
|
// with our current model, refine it later if we have new model or
|
|
// new requirements
|
|
cached_feature_.resize(cached_feature_size);
|
|
for (int i = 0; i < cached_feature_size; ++i) {
|
|
cached_feature_[i] =
|
|
chunk_feats[chunk_feats.size() - cached_feature_size + i];
|
|
}
|
|
}
|
|
}
|
|
|
|
void AsrModel::ForwardEncoder(
|
|
const std::vector<std::vector<float>>& chunk_feats,
|
|
std::vector<std::vector<float>>* ctc_prob) {
|
|
ctc_prob->clear();
|
|
int num_frames = cached_feature_.size() + chunk_feats.size();
|
|
if (num_frames >= right_context_ + 1) {
|
|
this->ForwardEncoderFunc(chunk_feats, ctc_prob);
|
|
this->CacheFeature(chunk_feats);
|
|
}
|
|
}
|
|
|
|
} // namespace wenet
|