You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

69 lines
2.2 KiB

  1. // Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
  2. // 2022 Binbin Zhang (binbzha@qq.com)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. //
  8. // http://www.apache.org/licenses/LICENSE-2.0
  9. //
  10. // Unless required by applicable law or agreed to in writing, software
  11. // distributed under the License is distributed on an "AS IS" BASIS,
  12. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. // See the License for the specific language governing permissions and
  14. // limitations under the License.
  15. #ifndef DECODER_TORCH_ASR_MODEL_H_
  16. #define DECODER_TORCH_ASR_MODEL_H_
  17. #include <memory>
  18. #include <string>
  19. #include <vector>
  20. #include "torch/script.h"
  21. #ifndef IOS
  22. #include "torch/torch.h"
  23. #endif
  24. #include "decoder/asr_model.h"
  25. #include "../utils/wn_utils.h"
  26. namespace wenet {
  27. class TorchAsrModel : public AsrModel {
  28. public:
  29. #ifndef IOS
  30. static void InitEngineThreads(int num_threads = 1);
  31. #endif
  32. public:
  33. using TorchModule = torch::jit::script::Module;
  34. TorchAsrModel() = default;
  35. TorchAsrModel(const TorchAsrModel& other);
  36. void Read(const std::string& model_path);
  37. std::shared_ptr<TorchModule> torch_model() const { return model_; }
  38. void Reset() override;
  39. void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
  40. float reverse_weight,
  41. std::vector<float>* rescoring_score) override;
  42. std::shared_ptr<AsrModel> Copy() const override;
  43. protected:
  44. void ForwardEncoderFunc(const std::vector<std::vector<float>>& chunk_feats,
  45. std::vector<std::vector<float>>* ctc_prob) override;
  46. float ComputeAttentionScore(const torch::Tensor& prob,
  47. const std::vector<int>& hyp, int eos);
  48. private:
  49. std::shared_ptr<TorchModule> model_ = nullptr;
  50. std::vector<torch::Tensor> encoder_outs_;
  51. // transformer/conformer attention cache
  52. torch::Tensor att_cache_ = torch::zeros({0, 0, 0, 0});
  53. // conformer-only conv_module cache
  54. torch::Tensor cnn_cache_ = torch::zeros({0, 0, 0, 0});
  55. };
  56. } // namespace wenet
  57. #endif // DECODER_TORCH_ASR_MODEL_H_