You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

68 lines
2.0 KiB

  1. // Copyright 2022 Horizon Robotics. All Rights Reserved.
  2. // Author: binbin.zhang@horizon.ai (Binbin Zhang)
  3. #ifndef DECODER_ASR_MODEL_H_
  4. #define DECODER_ASR_MODEL_H_
  5. #include <limits>
  6. #include <memory>
  7. #include <string>
  8. #include <vector>
  9. #include "utils/timer.h"
  10. #include "utils/wn_utils.h"
  11. namespace wenet {
  12. class AsrModel {
  13. public:
  14. virtual int right_context() const { return right_context_; }
  15. virtual int subsampling_rate() const { return subsampling_rate_; }
  16. virtual int sos() const { return sos_; }
  17. virtual int eos() const { return eos_; }
  18. virtual bool is_bidirectional_decoder() const {
  19. return is_bidirectional_decoder_;
  20. }
  21. virtual int offset() const { return offset_; }
  22. // If chunk_size > 0, streaming case. Otherwise, none streaming case
  23. virtual void set_chunk_size(int chunk_size) { chunk_size_ = chunk_size; }
  24. virtual void set_num_left_chunks(int num_left_chunks) {
  25. num_left_chunks_ = num_left_chunks;
  26. }
  27. // start: if it is the start chunk of one sentence
  28. virtual int num_frames_for_chunk(bool start) const;
  29. virtual void Reset() = 0;
  30. virtual void ForwardEncoder(
  31. const std::vector<std::vector<float>>& chunk_feats,
  32. std::vector<std::vector<float>>* ctc_prob);
  33. virtual void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
  34. float reverse_weight,
  35. std::vector<float>* rescoring_score) = 0;
  36. virtual std::shared_ptr<AsrModel> Copy() const = 0;
  37. protected:
  38. virtual void ForwardEncoderFunc(
  39. const std::vector<std::vector<float>>& chunk_feats,
  40. std::vector<std::vector<float>>* ctc_prob) = 0;
  41. virtual void CacheFeature(const std::vector<std::vector<float>>& chunk_feats);
  42. int right_context_ = 1;
  43. int subsampling_rate_ = 1;
  44. int sos_ = 0;
  45. int eos_ = 0;
  46. bool is_bidirectional_decoder_ = false;
  47. int chunk_size_ = 16;
  48. int num_left_chunks_ = -1; // -1 means all left chunks
  49. int offset_ = 0;
  50. std::vector<std::vector<float>> cached_feature_;
  51. };
  52. } // namespace wenet
  53. #endif // DECODER_ASR_MODEL_H_