You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

430 lines
16 KiB

  1. // Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
  2. // 2022 ZeXuan Li (lizexuan@huya.com)
  3. // Xingchen Song(sxc19@mails.tsinghua.edu.cn)
  4. // hamddct@gmail.com (Mddct)
  5. //
  6. // Licensed under the Apache License, Version 2.0 (the "License");
  7. // you may not use this file except in compliance with the License.
  8. // You may obtain a copy of the License at
  9. //
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. //
  12. // Unless required by applicable law or agreed to in writing, software
  13. // distributed under the License is distributed on an "AS IS" BASIS,
  14. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. // See the License for the specific language governing permissions and
  16. // limitations under the License.
#include "decoder/onnx_asr_model.h"

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "utils/wn_string.h"
  22. namespace wenet {
// Process-wide ONNX Runtime environment (warning-level logging) and session
// options shared by every OnnxAsrModel instance; sessions created in Read()
// pick up whatever has been set on session_options_ at that point.
Ort::Env OnnxAsrModel::env_ = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "");
Ort::SessionOptions OnnxAsrModel::session_options_ = Ort::SessionOptions();
// Sets the intra-op thread count on the shared session options. Call before
// Read(); sessions already created are unaffected.
void OnnxAsrModel::InitEngineThreads(int num_threads) {
  session_options_.SetIntraOpNumThreads(num_threads);
}
  28. void OnnxAsrModel::GetInputOutputInfo(
  29. const std::shared_ptr<Ort::Session>& session,
  30. std::vector<const char*>* in_names, std::vector<const char*>* out_names) {
  31. Ort::AllocatorWithDefaultOptions allocator;
  32. // Input info
  33. int num_nodes = session->GetInputCount();
  34. in_names->resize(num_nodes);
  35. for (int i = 0; i < num_nodes; ++i) {
  36. char* name = session->GetInputName(i, allocator);
  37. Ort::TypeInfo type_info = session->GetInputTypeInfo(i);
  38. auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
  39. ONNXTensorElementDataType type = tensor_info.GetElementType();
  40. std::vector<int64_t> node_dims = tensor_info.GetShape();
  41. std::stringstream shape;
  42. for (auto j : node_dims) {
  43. shape << j;
  44. shape << " ";
  45. }
  46. LOG(INFO) << "\tInput " << i << " : name=" << name << " type=" << type
  47. << " dims=" << shape.str();
  48. (*in_names)[i] = name;
  49. }
  50. // Output info
  51. num_nodes = session->GetOutputCount();
  52. out_names->resize(num_nodes);
  53. for (int i = 0; i < num_nodes; ++i) {
  54. char* name = session->GetOutputName(i, allocator);
  55. Ort::TypeInfo type_info = session->GetOutputTypeInfo(i);
  56. auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
  57. ONNXTensorElementDataType type = tensor_info.GetElementType();
  58. std::vector<int64_t> node_dims = tensor_info.GetShape();
  59. std::stringstream shape;
  60. for (auto j : node_dims) {
  61. shape << j;
  62. shape << " ";
  63. }
  64. LOG(INFO) << "\tOutput " << i << " : name=" << name << " type=" << type
  65. << " dims=" << shape.str();
  66. (*out_names)[i] = name;
  67. }
  68. }
  69. void OnnxAsrModel::Read(const std::string& model_dir) {
  70. std::string encoder_onnx_path = model_dir + "/encoder.onnx";
  71. std::string rescore_onnx_path = model_dir + "/decoder.onnx";
  72. std::string ctc_onnx_path = model_dir + "/ctc.onnx";
  73. // 1. Load sessions
  74. try {
  75. #ifdef _MSC_VER
  76. encoder_session_ = std::make_shared<Ort::Session>(
  77. env_, ToWString(encoder_onnx_path).c_str(), session_options_);
  78. rescore_session_ = std::make_shared<Ort::Session>(
  79. env_, ToWString(rescore_onnx_path).c_str(), session_options_);
  80. ctc_session_ = std::make_shared<Ort::Session>(
  81. env_, ToWString(ctc_onnx_path).c_str(), session_options_);
  82. #else
  83. encoder_session_ = std::make_shared<Ort::Session>(
  84. env_, encoder_onnx_path.c_str(), session_options_);
  85. rescore_session_ = std::make_shared<Ort::Session>(
  86. env_, rescore_onnx_path.c_str(), session_options_);
  87. ctc_session_ = std::make_shared<Ort::Session>(env_, ctc_onnx_path.c_str(),
  88. session_options_);
  89. #endif
  90. } catch (std::exception const& e) {
  91. LOG(ERROR) << "error when load onnx model: " << e.what();
  92. exit(0);
  93. }
  94. // 2. Read metadata
  95. auto model_metadata = encoder_session_->GetModelMetadata();
  96. Ort::AllocatorWithDefaultOptions allocator;
  97. encoder_output_size_ =
  98. atoi(model_metadata.LookupCustomMetadataMap("output_size", allocator));
  99. num_blocks_ =
  100. atoi(model_metadata.LookupCustomMetadataMap("num_blocks", allocator));
  101. head_ = atoi(model_metadata.LookupCustomMetadataMap("head", allocator));
  102. cnn_module_kernel_ = atoi(
  103. model_metadata.LookupCustomMetadataMap("cnn_module_kernel", allocator));
  104. subsampling_rate_ = atoi(
  105. model_metadata.LookupCustomMetadataMap("subsampling_rate", allocator));
  106. right_context_ =
  107. atoi(model_metadata.LookupCustomMetadataMap("right_context", allocator));
  108. sos_ = atoi(model_metadata.LookupCustomMetadataMap("sos_symbol", allocator));
  109. eos_ = atoi(model_metadata.LookupCustomMetadataMap("eos_symbol", allocator));
  110. is_bidirectional_decoder_ = atoi(model_metadata.LookupCustomMetadataMap(
  111. "is_bidirectional_decoder", allocator));
  112. chunk_size_ =
  113. atoi(model_metadata.LookupCustomMetadataMap("chunk_size", allocator));
  114. num_left_chunks_ =
  115. atoi(model_metadata.LookupCustomMetadataMap("left_chunks", allocator));
  116. LOG(INFO) << "Onnx Model Info:";
  117. LOG(INFO) << "\tencoder_output_size " << encoder_output_size_;
  118. LOG(INFO) << "\tnum_blocks " << num_blocks_;
  119. LOG(INFO) << "\thead " << head_;
  120. LOG(INFO) << "\tcnn_module_kernel " << cnn_module_kernel_;
  121. LOG(INFO) << "\tsubsampling_rate " << subsampling_rate_;
  122. LOG(INFO) << "\tright_context " << right_context_;
  123. LOG(INFO) << "\tsos " << sos_;
  124. LOG(INFO) << "\teos " << eos_;
  125. LOG(INFO) << "\tis bidirectional decoder " << is_bidirectional_decoder_;
  126. LOG(INFO) << "\tchunk_size " << chunk_size_;
  127. LOG(INFO) << "\tnum_left_chunks " << num_left_chunks_;
  128. // 3. Read model nodes
  129. LOG(INFO) << "Onnx Encoder:";
  130. GetInputOutputInfo(encoder_session_, &encoder_in_names_, &encoder_out_names_);
  131. LOG(INFO) << "Onnx CTC:";
  132. GetInputOutputInfo(ctc_session_, &ctc_in_names_, &ctc_out_names_);
  133. LOG(INFO) << "Onnx Rescore:";
  134. GetInputOutputInfo(rescore_session_, &rescore_in_names_, &rescore_out_names_);
  135. }
  136. OnnxAsrModel::OnnxAsrModel(const OnnxAsrModel& other) {
  137. // metadatas
  138. encoder_output_size_ = other.encoder_output_size_;
  139. num_blocks_ = other.num_blocks_;
  140. head_ = other.head_;
  141. cnn_module_kernel_ = other.cnn_module_kernel_;
  142. right_context_ = other.right_context_;
  143. subsampling_rate_ = other.subsampling_rate_;
  144. sos_ = other.sos_;
  145. eos_ = other.eos_;
  146. is_bidirectional_decoder_ = other.is_bidirectional_decoder_;
  147. chunk_size_ = other.chunk_size_;
  148. num_left_chunks_ = other.num_left_chunks_;
  149. offset_ = other.offset_;
  150. // sessions
  151. encoder_session_ = other.encoder_session_;
  152. ctc_session_ = other.ctc_session_;
  153. rescore_session_ = other.rescore_session_;
  154. // node names
  155. encoder_in_names_ = other.encoder_in_names_;
  156. encoder_out_names_ = other.encoder_out_names_;
  157. ctc_in_names_ = other.ctc_in_names_;
  158. ctc_out_names_ = other.ctc_out_names_;
  159. rescore_in_names_ = other.rescore_in_names_;
  160. rescore_out_names_ = other.rescore_out_names_;
  161. }
  162. std::shared_ptr<AsrModel> OnnxAsrModel::Copy() const {
  163. auto asr_model = std::make_shared<OnnxAsrModel>(*this);
  164. // Reset the inner states for new decoding
  165. asr_model->Reset();
  166. return asr_model;
  167. }
  168. void OnnxAsrModel::Reset() {
  169. offset_ = 0;
  170. encoder_outs_.clear();
  171. cached_feature_.clear();
  172. // Reset att_cache
  173. Ort::MemoryInfo memory_info =
  174. Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
  175. if (num_left_chunks_ > 0) {
  176. int required_cache_size = chunk_size_ * num_left_chunks_;
  177. offset_ = required_cache_size;
  178. att_cache_.resize(num_blocks_ * head_ * required_cache_size *
  179. encoder_output_size_ / head_ * 2,
  180. 0.0);
  181. const int64_t att_cache_shape[] = {num_blocks_, head_, required_cache_size,
  182. encoder_output_size_ / head_ * 2};
  183. att_cache_ort_ = Ort::Value::CreateTensor<float>(
  184. memory_info, att_cache_.data(), att_cache_.size(), att_cache_shape, 4);
  185. } else {
  186. att_cache_.resize(0, 0.0);
  187. const int64_t att_cache_shape[] = {num_blocks_, head_, 0,
  188. encoder_output_size_ / head_ * 2};
  189. att_cache_ort_ = Ort::Value::CreateTensor<float>(
  190. memory_info, att_cache_.data(), att_cache_.size(), att_cache_shape, 4);
  191. }
  192. // Reset cnn_cache
  193. cnn_cache_.resize(
  194. num_blocks_ * encoder_output_size_ * (cnn_module_kernel_ - 1), 0.0);
  195. const int64_t cnn_cache_shape[] = {num_blocks_, 1, encoder_output_size_,
  196. cnn_module_kernel_ - 1};
  197. cnn_cache_ort_ = Ort::Value::CreateTensor<float>(
  198. memory_info, cnn_cache_.data(), cnn_cache_.size(), cnn_cache_shape, 4);
  199. }
// Runs one streaming encoder step over |chunk_feats| (preceded by any cached
// feature frames), feeds the encoder output through the CTC session, and
// writes the per-frame CTC output rows into |out_prob|.
// Side effects: advances offset_, replaces att_cache_ort_/cnn_cache_ort_
// with the updated caches, and appends the encoder output to encoder_outs_
// for later attention rescoring.
void OnnxAsrModel::ForwardEncoderFunc(
    const std::vector<std::vector<float>>& chunk_feats,
    std::vector<std::vector<float>>* out_prob) {
  Ort::MemoryInfo memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
  // 1. Prepare onnx required data, splice cached_feature_ and chunk_feats
  // chunk: flatten cached + new frames into one (1, num_frames, feature_dim)
  // tensor. NOTE(review): assumes chunk_feats is non-empty — confirm callers
  // guarantee this before indexing chunk_feats[0].
  int num_frames = cached_feature_.size() + chunk_feats.size();
  const int feature_dim = chunk_feats[0].size();
  std::vector<float> feats;
  for (size_t i = 0; i < cached_feature_.size(); ++i) {
    feats.insert(feats.end(), cached_feature_[i].begin(),
                 cached_feature_[i].end());
  }
  for (size_t i = 0; i < chunk_feats.size(); ++i) {
    feats.insert(feats.end(), chunk_feats[i].begin(), chunk_feats[i].end());
  }
  const int64_t feats_shape[3] = {1, num_frames, feature_dim};
  Ort::Value feats_ort = Ort::Value::CreateTensor<float>(
      memory_info, feats.data(), feats.size(), feats_shape, 3);
  // offset: rank-0 (scalar) tensor holding the current frame offset.
  int64_t offset_int64 = static_cast<int64_t>(offset_);
  Ort::Value offset_ort = Ort::Value::CreateTensor<int64_t>(
      memory_info, &offset_int64, 1, std::vector<int64_t>{}.data(), 0);
  // required_cache_size: rank-0 tensor, attention-cache length in frames.
  int64_t required_cache_size = chunk_size_ * num_left_chunks_;
  Ort::Value required_cache_size_ort = Ort::Value::CreateTensor<int64_t>(
      memory_info, &required_cache_size, 1, std::vector<int64_t>{}.data(), 0);
  // att_mask: (1, 1, cache + chunk) boolean tensor; positions before the
  // first valid left chunk are zeroed out. Only built when a left-chunk
  // limit is configured, otherwise att_mask_ort stays null.
  Ort::Value att_mask_ort{nullptr};
  std::vector<uint8_t> att_mask(required_cache_size + chunk_size_, 1);
  if (num_left_chunks_ > 0) {
    int chunk_idx = offset_ / chunk_size_ - num_left_chunks_;
    if (chunk_idx < num_left_chunks_) {
      for (int i = 0; i < (num_left_chunks_ - chunk_idx) * chunk_size_; ++i) {
        att_mask[i] = 0;
      }
    }
    const int64_t att_mask_shape[] = {1, 1, required_cache_size + chunk_size_};
    // uint8_t storage reinterpreted as bool to match the ONNX bool dtype.
    att_mask_ort = Ort::Value::CreateTensor<bool>(
        memory_info, reinterpret_cast<bool*>(att_mask.data()), att_mask.size(),
        att_mask_shape, 3);
  }
  // 2. Encoder chunk forward: bind inputs by node name so the tensor order
  // matches encoder_in_names_ regardless of how the model was exported.
  std::vector<Ort::Value> inputs;
  for (auto name : encoder_in_names_) {
    if (!strcmp(name, "chunk")) {
      inputs.emplace_back(std::move(feats_ort));
    } else if (!strcmp(name, "offset")) {
      inputs.emplace_back(std::move(offset_ort));
    } else if (!strcmp(name, "required_cache_size")) {
      inputs.emplace_back(std::move(required_cache_size_ort));
    } else if (!strcmp(name, "att_cache")) {
      inputs.emplace_back(std::move(att_cache_ort_));
    } else if (!strcmp(name, "cnn_cache")) {
      inputs.emplace_back(std::move(cnn_cache_ort_));
    } else if (!strcmp(name, "att_mask")) {
      inputs.emplace_back(std::move(att_mask_ort));
    }
  }
  std::vector<Ort::Value> ort_outputs = encoder_session_->Run(
      Ort::RunOptions{nullptr}, encoder_in_names_.data(), inputs.data(),
      inputs.size(), encoder_out_names_.data(), encoder_out_names_.size());
  // Advance the offset by the number of encoder output frames (dim 1) and
  // keep the updated caches (outputs 1 and 2) for the next chunk.
  offset_ += static_cast<int>(
      ort_outputs[0].GetTensorTypeAndShapeInfo().GetShape()[1]);
  att_cache_ort_ = std::move(ort_outputs[1]);
  cnn_cache_ort_ = std::move(ort_outputs[2]);
  // CTC forward on the encoder output (output 0).
  std::vector<Ort::Value> ctc_inputs;
  ctc_inputs.emplace_back(std::move(ort_outputs[0]));
  std::vector<Ort::Value> ctc_ort_outputs = ctc_session_->Run(
      Ort::RunOptions{nullptr}, ctc_in_names_.data(), ctc_inputs.data(),
      ctc_inputs.size(), ctc_out_names_.data(), ctc_out_names_.size());
  // Keep the encoder output tensor for attention rescoring later.
  encoder_outs_.push_back(std::move(ctc_inputs[0]));
  // Copy the (1, num_outputs, output_dim) CTC output into out_prob, one
  // vector per frame.
  float* logp_data = ctc_ort_outputs[0].GetTensorMutableData<float>();
  auto type_info = ctc_ort_outputs[0].GetTensorTypeAndShapeInfo();
  int num_outputs = type_info.GetShape()[1];
  int output_dim = type_info.GetShape()[2];
  out_prob->resize(num_outputs);
  for (int i = 0; i < num_outputs; i++) {
    (*out_prob)[i].resize(output_dim);
    memcpy((*out_prob)[i].data(), logp_data + i * output_dim,
           sizeof(float) * output_dim);
  }
}
  284. float OnnxAsrModel::ComputeAttentionScore(const float* prob,
  285. const std::vector<int>& hyp, int eos,
  286. int decode_out_len) {
  287. float score = 0.0f;
  288. for (size_t j = 0; j < hyp.size(); ++j) {
  289. score += *(prob + j * decode_out_len + hyp[j]);
  290. }
  291. score += *(prob + hyp.size() * decode_out_len + eos);
  292. return score;
  293. }
// Rescores the n-best CTC hypotheses |hyps| with the attention decoder and
// writes one score per hypothesis into |rescoring_score|.  When the decoder
// is bidirectional and |reverse_weight| > 0, left-to-right and right-to-left
// scores are blended as (1 - w) * score + w * r_score.
// Requires encoder output accumulated by ForwardEncoderFunc(); with no
// hypotheses or no encoder output, all scores are left at 0.
void OnnxAsrModel::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
                                      float reverse_weight,
                                      std::vector<float>* rescoring_score) {
  Ort::MemoryInfo memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
  CHECK(rescoring_score != nullptr);
  int num_hyps = hyps.size();
  rescoring_score->resize(num_hyps, 0.0f);
  if (num_hyps == 0) {
    return;
  }
  // No encoder output
  if (encoder_outs_.size() == 0) {
    return;
  }
  // Hypothesis lengths; +1 accounts for the prepended <sos> / trailing <eos>
  // position used by the decoder.
  std::vector<int64_t> hyps_lens;
  int max_hyps_len = 0;
  for (size_t i = 0; i < num_hyps; ++i) {
    int length = hyps[i].size() + 1;
    max_hyps_len = std::max(length, max_hyps_len);
    hyps_lens.emplace_back(static_cast<int64_t>(length));
  }
  // Concatenate all cached encoder chunk outputs along the time axis into a
  // single flat buffer of shape (1, encoder_len, encoder_output_size_).
  std::vector<float> rescore_input;
  int encoder_len = 0;
  for (int i = 0; i < encoder_outs_.size(); i++) {
    float* encoder_outs_data = encoder_outs_[i].GetTensorMutableData<float>();
    auto type_info = encoder_outs_[i].GetTensorTypeAndShapeInfo();
    for (int j = 0; j < type_info.GetElementCount(); j++) {
      rescore_input.emplace_back(encoder_outs_data[j]);
    }
    encoder_len += type_info.GetShape()[1];
  }
  const int64_t decode_input_shape[] = {1, encoder_len, encoder_output_size_};
  // Build hyps_pad: each row is <sos> followed by the hypothesis tokens,
  // zero-padded to max_hyps_len.
  std::vector<int64_t> hyps_pad;
  for (size_t i = 0; i < num_hyps; ++i) {
    const std::vector<int>& hyp = hyps[i];
    hyps_pad.emplace_back(sos_);
    size_t j = 0;
    for (; j < hyp.size(); ++j) {
      hyps_pad.emplace_back(hyp[j]);
    }
    // Longest hypothesis needs no padding.
    if (j == max_hyps_len - 1) {
      continue;
    }
    for (; j < max_hyps_len - 1; ++j) {
      hyps_pad.emplace_back(0);
    }
  }
  // Wrap the prepared buffers as Ort tensors and run the rescoring decoder
  // with inputs (hyps_pad, hyps_lens, encoder_out).
  const int64_t hyps_pad_shape[] = {num_hyps, max_hyps_len};
  const int64_t hyps_lens_shape[] = {num_hyps};
  Ort::Value decode_input_tensor_ = Ort::Value::CreateTensor<float>(
      memory_info, rescore_input.data(), rescore_input.size(),
      decode_input_shape, 3);
  Ort::Value hyps_pad_tensor_ = Ort::Value::CreateTensor<int64_t>(
      memory_info, hyps_pad.data(), hyps_pad.size(), hyps_pad_shape, 2);
  Ort::Value hyps_lens_tensor_ = Ort::Value::CreateTensor<int64_t>(
      memory_info, hyps_lens.data(), hyps_lens.size(), hyps_lens_shape, 1);
  std::vector<Ort::Value> rescore_inputs;
  rescore_inputs.emplace_back(std::move(hyps_pad_tensor_));
  rescore_inputs.emplace_back(std::move(hyps_lens_tensor_));
  rescore_inputs.emplace_back(std::move(decode_input_tensor_));
  std::vector<Ort::Value> rescore_outputs = rescore_session_->Run(
      Ort::RunOptions{nullptr}, rescore_in_names_.data(), rescore_inputs.data(),
      rescore_inputs.size(), rescore_out_names_.data(),
      rescore_out_names_.size());
  // Output 0: left-to-right decoder scores; output 1: right-to-left scores.
  // NOTE(review): assumes the exported decoder always has a second output,
  // even for unidirectional models — confirm against the export script.
  float* decoder_outs_data = rescore_outputs[0].GetTensorMutableData<float>();
  float* r_decoder_outs_data = rescore_outputs[1].GetTensorMutableData<float>();
  auto type_info = rescore_outputs[0].GetTensorTypeAndShapeInfo();
  int decode_out_len = type_info.GetShape()[2];
  for (size_t i = 0; i < num_hyps; ++i) {
    const std::vector<int>& hyp = hyps[i];
    float score = 0.0f;
    // left to right decoder score
    score = ComputeAttentionScore(
        decoder_outs_data + max_hyps_len * decode_out_len * i, hyp, eos_,
        decode_out_len);
    // Optional: Used for right to left score
    float r_score = 0.0f;
    if (is_bidirectional_decoder_ && reverse_weight > 0) {
      // Score the reversed token sequence against the right-to-left output.
      std::vector<int> r_hyp(hyp.size());
      std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin());
      // right to left decoder score
      r_score = ComputeAttentionScore(
          r_decoder_outs_data + max_hyps_len * decode_out_len * i, r_hyp, eos_,
          decode_out_len);
    }
    // combined left-to-right and right-to-left score
    (*rescoring_score)[i] =
        score * (1 - reverse_weight) + r_score * reverse_weight;
  }
}
  385. } // namespace wenet