You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

287 lines
9.7 KiB

  1. // Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
  2. // 2022 Binbin Zhang (binbzha@qq.com)
  3. //
  4. // Licensed under the Apache License, Version 2.0 (the "License");
  5. // you may not use this file except in compliance with the License.
  6. // You may obtain a copy of the License at
  7. //
  8. // http://www.apache.org/licenses/LICENSE-2.0
  9. //
  10. // Unless required by applicable law or agreed to in writing, software
  11. // distributed under the License is distributed on an "AS IS" BASIS,
  12. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. // See the License for the specific language governing permissions and
  14. // limitations under the License.
  15. #include "decoder/torch_asr_model.h"
  16. #include <algorithm>
  17. #include <memory>
  18. #include <stdexcept>
  19. #include <utility>
  20. #include "torch/script.h"
  21. #ifndef IOS
  22. #include "torch/torch.h"
  23. #endif
  24. #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
  25. namespace wenet {
  26. #ifndef IOS
  27. void TorchAsrModel::InitEngineThreads(int num_threads) {
  28. // For multi-thread performance
  29. at::set_num_threads(num_threads);
  30. VLOG(1) << "Num intra-op threads: " << at::get_num_threads();
  31. }
  32. #endif
  33. void TorchAsrModel::Read(const std::string& model_path) {
  34. torch::DeviceType device = at::kCPU;
  35. #ifdef USE_GPU
  36. if (!torch::cuda::is_available()) {
  37. VLOG(1) << "CUDA is not available! Please check your GPU settings";
  38. throw std::runtime_error("CUDA is not available!");
  39. } else {
  40. VLOG(1) << "CUDA available! Running on GPU";
  41. device = at::kCUDA;
  42. }
  43. #endif
  44. #ifdef USE_IPEX
  45. torch::jit::setTensorExprFuserEnabled(false);
  46. #endif
  47. torch::jit::script::Module model = torch::jit::load(model_path, device);
  48. model_ = std::make_shared<TorchModule>(std::move(model));
  49. torch::NoGradGuard no_grad;
  50. model_->eval();
  51. torch::jit::IValue o1 = model_->run_method("subsampling_rate");
  52. CHECK_EQ(o1.isInt(), true);
  53. subsampling_rate_ = o1.toInt();
  54. torch::jit::IValue o2 = model_->run_method("right_context");
  55. CHECK_EQ(o2.isInt(), true);
  56. right_context_ = o2.toInt();
  57. torch::jit::IValue o3 = model_->run_method("sos_symbol");
  58. CHECK_EQ(o3.isInt(), true);
  59. sos_ = o3.toInt();
  60. torch::jit::IValue o4 = model_->run_method("eos_symbol");
  61. CHECK_EQ(o4.isInt(), true);
  62. eos_ = o4.toInt();
  63. torch::jit::IValue o5 = model_->run_method("is_bidirectional_decoder");
  64. CHECK_EQ(o5.isBool(), true);
  65. is_bidirectional_decoder_ = o5.toBool();
  66. torch::jit::setGraphExecutorOptimize(false);
  67. torch::jit::FusionStrategy static0 = {
  68. {torch::jit::FusionBehavior::STATIC, 0}};
  69. torch::jit::setFusionStrategy(static0);
  70. VLOG(1) << "Torch Model Info:";
  71. VLOG(1) << "\tsubsampling_rate " << subsampling_rate_;
  72. VLOG(1) << "\tright context " << right_context_;
  73. VLOG(1) << "\tsos " << sos_;
  74. VLOG(1) << "\teos " << eos_;
  75. VLOG(1) << "\tis bidirectional decoder " << is_bidirectional_decoder_;
  76. }
  77. TorchAsrModel::TorchAsrModel(const TorchAsrModel& other) {
  78. // 1. Init the model info
  79. right_context_ = other.right_context_;
  80. subsampling_rate_ = other.subsampling_rate_;
  81. sos_ = other.sos_;
  82. eos_ = other.eos_;
  83. is_bidirectional_decoder_ = other.is_bidirectional_decoder_;
  84. chunk_size_ = other.chunk_size_;
  85. num_left_chunks_ = other.num_left_chunks_;
  86. offset_ = other.offset_;
  87. // 2. Model copy, just copy the model ptr since:
  88. // PyTorch allows using multiple CPU threads during TorchScript model
  89. // inference, please see https://pytorch.org/docs/stable/notes/cpu_
  90. // threading_torchscript_inference.html
  91. model_ = other.model_;
  92. // NOTE(Binbin Zhang):
  93. // inner states for forward are not copied here.
  94. }
  95. std::shared_ptr<AsrModel> TorchAsrModel::Copy() const {
  96. auto asr_model = std::make_shared<TorchAsrModel>(*this);
  97. // Reset the inner states for new decoding
  98. asr_model->Reset();
  99. return asr_model;
  100. }
  101. void TorchAsrModel::Reset() {
  102. offset_ = 0;
  103. att_cache_ = std::move(torch::zeros({0, 0, 0, 0}));
  104. cnn_cache_ = std::move(torch::zeros({0, 0, 0, 0}));
  105. encoder_outs_.clear();
  106. cached_feature_.clear();
  107. }
  108. void TorchAsrModel::ForwardEncoderFunc(
  109. const std::vector<std::vector<float>>& chunk_feats,
  110. std::vector<std::vector<float>>* out_prob) {
  111. // 1. Prepare libtorch required data, splice cached_feature_ and chunk_feats
  112. // The first dimension is for batchsize, which is 1.
  113. int num_frames = cached_feature_.size() + chunk_feats.size();
  114. const int feature_dim = chunk_feats[0].size();
  115. torch::Tensor feats =
  116. torch::zeros({1, num_frames, feature_dim}, torch::kFloat);
  117. for (size_t i = 0; i < cached_feature_.size(); ++i) {
  118. torch::Tensor row =
  119. torch::from_blob(const_cast<float*>(cached_feature_[i].data()),
  120. {feature_dim}, torch::kFloat)
  121. .clone();
  122. feats[0][i] = std::move(row);
  123. }
  124. for (size_t i = 0; i < chunk_feats.size(); ++i) {
  125. torch::Tensor row =
  126. torch::from_blob(const_cast<float*>(chunk_feats[i].data()),
  127. {feature_dim}, torch::kFloat)
  128. .clone();
  129. feats[0][cached_feature_.size() + i] = std::move(row);
  130. }
  131. // 2. Encoder chunk forward
  132. #ifdef USE_GPU
  133. feats = feats.to(at::kCUDA);
  134. att_cache_ = att_cache_.to(at::kCUDA);
  135. cnn_cache_ = cnn_cache_.to(at::kCUDA);
  136. #endif
  137. int required_cache_size = chunk_size_ * num_left_chunks_;
  138. torch::NoGradGuard no_grad;
  139. std::vector<torch::jit::IValue> inputs = {feats, offset_, required_cache_size,
  140. att_cache_, cnn_cache_};
  141. // Refer interfaces in wenet/transformer/asr_model.py
  142. auto outputs =
  143. model_->get_method("forward_encoder_chunk")(inputs).toTuple()->elements();
  144. CHECK_EQ(outputs.size(), 3);
  145. #ifdef USE_GPU
  146. torch::Tensor chunk_out = outputs[0].toTensor().to(at::kCPU);
  147. att_cache_ = outputs[1].toTensor().to(at::kCPU);
  148. cnn_cache_ = outputs[2].toTensor().to(at::kCPU);
  149. #else
  150. torch::Tensor chunk_out = outputs[0].toTensor();
  151. att_cache_ = outputs[1].toTensor();
  152. cnn_cache_ = outputs[2].toTensor();
  153. #endif
  154. offset_ += chunk_out.size(1);
  155. // The first dimension of returned value is for batchsize, which is 1
  156. #ifdef USE_GPU
  157. chunk_out = chunk_out.to(at::kCUDA);
  158. torch::Tensor ctc_log_probs =
  159. model_->run_method("ctc_activation", chunk_out).toTensor();
  160. ctc_log_probs = ctc_log_probs.to(at::kCPU)[0];
  161. encoder_outs_.push_back(std::move(chunk_out.to(at::kCPU)));
  162. #else
  163. torch::Tensor ctc_log_probs =
  164. model_->run_method("ctc_activation", chunk_out).toTensor()[0];
  165. encoder_outs_.push_back(std::move(chunk_out));
  166. #endif
  167. // Copy to output
  168. int num_outputs = ctc_log_probs.size(0);
  169. int output_dim = ctc_log_probs.size(1);
  170. out_prob->resize(num_outputs);
  171. for (int i = 0; i < num_outputs; i++) {
  172. (*out_prob)[i].resize(output_dim);
  173. memcpy((*out_prob)[i].data(), ctc_log_probs[i].data_ptr(),
  174. sizeof(float) * output_dim);
  175. }
  176. }
  177. float TorchAsrModel::ComputeAttentionScore(const torch::Tensor& prob,
  178. const std::vector<int>& hyp,
  179. int eos) {
  180. float score = 0.0f;
  181. auto accessor = prob.accessor<float, 2>();
  182. for (size_t j = 0; j < hyp.size(); ++j) {
  183. score += accessor[j][hyp[j]];
  184. }
  185. score += accessor[hyp.size()][eos];
  186. return score;
  187. }
  188. void TorchAsrModel::AttentionRescoring(
  189. const std::vector<std::vector<int>>& hyps, float reverse_weight,
  190. std::vector<float>* rescoring_score) {
  191. CHECK(rescoring_score != nullptr);
  192. int num_hyps = hyps.size();
  193. rescoring_score->resize(num_hyps, 0.0f);
  194. if (num_hyps == 0) {
  195. return;
  196. }
  197. // No encoder output
  198. if (encoder_outs_.size() == 0) {
  199. return;
  200. }
  201. torch::NoGradGuard no_grad;
  202. // Step 1: Prepare input for libtorch
  203. torch::Tensor hyps_length = torch::zeros({num_hyps}, torch::kLong);
  204. int max_hyps_len = 0;
  205. for (size_t i = 0; i < num_hyps; ++i) {
  206. int length = hyps[i].size() + 1;
  207. max_hyps_len = std::max(length, max_hyps_len);
  208. hyps_length[i] = static_cast<int64_t>(length);
  209. }
  210. torch::Tensor hyps_tensor =
  211. torch::zeros({num_hyps, max_hyps_len}, torch::kLong);
  212. for (size_t i = 0; i < num_hyps; ++i) {
  213. const std::vector<int>& hyp = hyps[i];
  214. hyps_tensor[i][0] = sos_;
  215. for (size_t j = 0; j < hyp.size(); ++j) {
  216. hyps_tensor[i][j + 1] = hyp[j];
  217. }
  218. }
  219. // Step 2: Forward attention decoder by hyps and corresponding encoder_outs_
  220. torch::Tensor encoder_out = torch::cat(encoder_outs_, 1);
  221. #ifdef USE_GPU
  222. hyps_tensor = hyps_tensor.to(at::kCUDA);
  223. hyps_length = hyps_length.to(at::kCUDA);
  224. encoder_out = encoder_out.to(at::kCUDA);
  225. #endif
  226. auto outputs = model_
  227. ->run_method("forward_attention_decoder", hyps_tensor,
  228. hyps_length, encoder_out, reverse_weight)
  229. .toTuple()
  230. ->elements();
  231. #ifdef USE_GPU
  232. auto probs = outputs[0].toTensor().to(at::kCPU);
  233. auto r_probs = outputs[1].toTensor().to(at::kCPU);
  234. #else
  235. auto probs = outputs[0].toTensor();
  236. auto r_probs = outputs[1].toTensor();
  237. #endif
  238. CHECK_EQ(probs.size(0), num_hyps);
  239. CHECK_EQ(probs.size(1), max_hyps_len);
  240. // Step 3: Compute rescoring score
  241. for (size_t i = 0; i < num_hyps; ++i) {
  242. const std::vector<int>& hyp = hyps[i];
  243. float score = 0.0f;
  244. // left-to-right decoder score
  245. score = ComputeAttentionScore(probs[i], hyp, eos_);
  246. // Optional: Used for right to left score
  247. float r_score = 0.0f;
  248. if (is_bidirectional_decoder_ && reverse_weight > 0) {
  249. // right-to-left score
  250. CHECK_EQ(r_probs.size(0), num_hyps);
  251. CHECK_EQ(r_probs.size(1), max_hyps_len);
  252. std::vector<int> r_hyp(hyp.size());
  253. std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin());
  254. // right to left decoder score
  255. r_score = ComputeAttentionScore(r_probs[i], r_hyp, eos_);
  256. }
  257. // combined left-to-right and right-to-left score
  258. (*rescoring_score)[i] =
  259. score * (1 - reverse_weight) + r_score * reverse_weight;
  260. }
  261. }
  262. } // namespace wenet