You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

54 lines
1.7 KiB

  1. // Copyright 2022 Horizon Robotics. All Rights Reserved.
  2. // Author: binbin.zhang@horizon.ai (Binbin Zhang)
#include "decoder/asr_model.h"

#include <limits>
#include <memory>
#include <utility>
#include <vector>
  6. namespace wenet {
  7. int AsrModel::num_frames_for_chunk(bool start) const {
  8. int num_required_frames = 0;
  9. if (chunk_size_ > 0) {
  10. if (!start) { // First batch
  11. int context = right_context_ + 1; // Add current frame
  12. num_required_frames = (chunk_size_ - 1) * subsampling_rate_ + context;
  13. } else {
  14. num_required_frames = chunk_size_ * subsampling_rate_;
  15. }
  16. } else {
  17. num_required_frames = std::numeric_limits<int>::max();
  18. }
  19. return num_required_frames;
  20. }
  21. void AsrModel::CacheFeature(
  22. const std::vector<std::vector<float>>& chunk_feats) {
  23. // Cache feature for next chunk
  24. const int cached_feature_size = 1 + right_context_ - subsampling_rate_;
  25. if (chunk_feats.size() >= cached_feature_size) {
  26. // TODO(Binbin Zhang): Only deal the case when
  27. // chunk_feats.size() > cached_feature_size here, and it's consistent
  28. // with our current model, refine it later if we have new model or
  29. // new requirements
  30. cached_feature_.resize(cached_feature_size);
  31. for (int i = 0; i < cached_feature_size; ++i) {
  32. cached_feature_[i] =
  33. chunk_feats[chunk_feats.size() - cached_feature_size + i];
  34. }
  35. }
  36. }
  37. void AsrModel::ForwardEncoder(
  38. const std::vector<std::vector<float>>& chunk_feats,
  39. std::vector<std::vector<float>>* ctc_prob) {
  40. ctc_prob->clear();
  41. int num_frames = cached_feature_.size() + chunk_feats.size();
  42. if (num_frames >= right_context_ + 1) {
  43. this->ForwardEncoderFunc(chunk_feats, ctc_prob);
  44. this->CacheFeature(chunk_feats);
  45. }
  46. }
  47. } // namespace wenet