You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

91 lines
3.3 KiB

  1. // Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
  2. // 2022 ZeXuan Li (lizexuan@huya.com)
  3. // Xingchen Song(sxc19@mails.tsinghua.edu.cn)
  4. // hamddct@gmail.com (Mddct)
  5. //
  6. // Licensed under the Apache License, Version 2.0 (the "License");
  7. // you may not use this file except in compliance with the License.
  8. // You may obtain a copy of the License at
  9. //
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. //
  12. // Unless required by applicable law or agreed to in writing, software
  13. // distributed under the License is distributed on an "AS IS" BASIS,
  14. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. // See the License for the specific language governing permissions and
  16. // limitations under the License.
  17. #ifndef DECODER_ONNX_ASR_MODEL_H_
  18. #define DECODER_ONNX_ASR_MODEL_H_
  19. #include <memory>
  20. #include <string>
  21. #include <vector>
  22. #include "onnxruntime_cxx_api.h" // NOLINT
  23. #include "decoder/asr_model.h"
  24. #include "../utils/wn_utils.h"
  25. namespace wenet {
  26. class OnnxAsrModel : public AsrModel {
  27. public:
  28. static void InitEngineThreads(int num_threads = 1);
  29. public:
  30. OnnxAsrModel() = default;
  31. OnnxAsrModel(const OnnxAsrModel& other);
  32. void Read(const std::string& model_dir);
  33. void Reset() override;
  34. void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
  35. float reverse_weight,
  36. std::vector<float>* rescoring_score) override;
  37. std::shared_ptr<AsrModel> Copy() const override;
  38. void GetInputOutputInfo(const std::shared_ptr<Ort::Session>& session,
  39. std::vector<const char*>* in_names,
  40. std::vector<const char*>* out_names);
  41. protected:
  42. void ForwardEncoderFunc(const std::vector<std::vector<float>>& chunk_feats,
  43. std::vector<std::vector<float>>* ctc_prob) override;
  44. float ComputeAttentionScore(const float* prob, const std::vector<int>& hyp,
  45. int eos, int decode_out_len);
  46. private:
  47. int encoder_output_size_ = 0;
  48. int num_blocks_ = 0;
  49. int cnn_module_kernel_ = 0;
  50. int head_ = 0;
  51. // sessions
  52. // NOTE(Mddct): The Env holds the logging state used by all other objects.
  53. // One Env must be created before using any other Onnxruntime functionality.
  54. static Ort::Env env_; // shared environment across threads.
  55. static Ort::SessionOptions session_options_;
  56. std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
  57. std::shared_ptr<Ort::Session> rescore_session_ = nullptr;
  58. std::shared_ptr<Ort::Session> ctc_session_ = nullptr;
  59. // node names
  60. std::vector<const char*> encoder_in_names_, encoder_out_names_;
  61. std::vector<const char*> ctc_in_names_, ctc_out_names_;
  62. std::vector<const char*> rescore_in_names_, rescore_out_names_;
  63. // caches
  64. Ort::Value att_cache_ort_{nullptr};
  65. Ort::Value cnn_cache_ort_{nullptr};
  66. std::vector<Ort::Value> encoder_outs_;
  67. // NOTE: Instead of making a copy of the xx_cache, ONNX only maintains
  68. // its data pointer when initializing xx_cache_ort (see https://github.com/
  69. // microsoft/onnxruntime/blob/master/onnxruntime/core/framework
  70. // /tensor.cc#L102-L129), so we need the following variables to keep
  71. // our data "alive" during the lifetime of decoder.
  72. std::vector<float> att_cache_;
  73. std::vector<float> cnn_cache_;
  74. };
  75. } // namespace wenet
  76. #endif // DECODER_ONNX_ASR_MODEL_H_