// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
//               2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DECODER_PARAMS_H_
#define DECODER_PARAMS_H_

#include <fstream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "decoder/asr_decoder.h"
#ifdef USE_ONNX
#include "decoder/onnx_asr_model.h"
#endif
#ifdef USE_TORCH
#include "decoder/torch_asr_model.h"
#endif
#ifdef USE_XPU
#include "xpu/xpu_asr_model.h"
#endif
#ifdef USE_BPU
#include "bpu/bpu_asr_model.h"
#endif
#ifdef USE_OPENVINO
#include "ov/ov_asr_model.h"
#endif
#include "frontend/feature_pipeline.h"
#include "../post_processor/processor/post_processor.h"
#include "utils/file.h"
#include "utils/flags.h"
#include "utils/string.h"

DEFINE_int32(device_id, 0, "set XPU DeviceID for ASR model");
// TorchAsrModel flags
DEFINE_string(model_path, "", "pytorch exported model path");
// OnnxAsrModel flags
DEFINE_string(onnx_dir, "", "directory where the onnx model is saved");
// XPUAsrModel flags
DEFINE_string(xpu_model_dir, "",
              "directory where the XPU model and weights are saved");
// BPUAsrModel flags
DEFINE_string(bpu_model_dir, "",
              "directory where the HORIZON BPU model is saved");
// OVAsrModel flags
DEFINE_string(openvino_dir, "", "directory where the OV model is saved");
DEFINE_int32(core_number, 1,
             "number of cores (threads) used for OpenVINO inference");
// FeaturePipelineConfig flags
DEFINE_int32(num_bins, 80, "num mel bins for fbank feature");
DEFINE_int32(sample_rate, 16000, "sample rate for audio");
DEFINE_string(feat_type, "kaldi", "Type of feature extraction: kaldi, whisper");
// TLG fst
DEFINE_string(fst_path, "", "TLG fst path");
// ITN fst
DEFINE_string(itn_model_dir, "",
              "fst based ITN model dir, should contain "
              "zh_itn_tagger.fst and zh_itn_verbalizer.fst");
// DecodeOptions flags
DEFINE_int32(chunk_size, 16, "decoding chunk size");
DEFINE_int32(num_left_chunks, -1, "left chunks in decoding");
DEFINE_double(ctc_weight, 0.5,
              "ctc weight when combining ctc score and rescoring score");
DEFINE_double(rescoring_weight, 1.0,
              "rescoring weight when combining ctc score and rescoring score");
DEFINE_double(reverse_weight, 0.0,
              "used for bitransformer rescoring; it must be 0.0 if the "
              "decoder is a conventional transformer decoder, and only when "
              "reverse_weight > 0.0 is the right-to-left decoder computed "
              "and used");
DEFINE_int32(max_active, 7000, "max active states in ctc wfst search");
DEFINE_int32(min_active, 200, "min active states in ctc wfst search");
DEFINE_double(beam, 16.0, "beam in ctc wfst search");
DEFINE_double(lattice_beam, 10.0, "lattice beam in ctc wfst search");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale for ctc wfst search");
DEFINE_int32(blank_id, 0,
             "blank token idx for ctc wfst search and ctc prefix beam search");
DEFINE_double(blank_skip_thresh, 1.0,
              "blank skip thresh for ctc wfst search, 1.0 means no skip");
DEFINE_double(blank_scale, 1.0, "blank scale for ctc wfst search");
DEFINE_double(length_penalty, 0.0,
              "length penalty for ctc wfst search; it is not applied on "
              "self-loop arcs. Used for balancing the del/ins ratio; "
              "-3.0 is a suggested value");
DEFINE_int32(nbest, 10, "nbest for ctc wfst or prefix search");
// SymbolTable flags
DEFINE_string(dict_path, "",
              "dict symbol table path, required when LM is enabled");
DEFINE_string(unit_path, "",
              "e2e model unit symbol table, it is used in both "
              "with/without LM scenarios for context/timestamp");
// Context flags
DEFINE_string(context_path, "",
              "context path, used to build the context graph");
DEFINE_double(context_score, 3.0, "used to rescore the decoded result");
// PostProcessOptions flags
DEFINE_int32(language_type, 0,
             "remove spaces according to language type: "
             "0x00 = kMandarinEnglish, "
             "0x01 = kIndoEuropean");
DEFINE_bool(lowercase, true, "lowercase final result if needed");
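
// The helpers below turn the flags above into WeNet config/resource objects.
// Illustrative sketch of how a decoder binary consumes them, assuming
// utils/flags.h pulls in gflags; the surrounding main(), audio reading, and
// decoding calls are the caller's responsibility and are not defined here:
//
//   gflags::ParseCommandLineFlags(&argc, &argv, false);
//   auto feature_config = wenet::InitFeaturePipelineConfigFromFlags();
//   auto decode_config = wenet::InitDecodeOptionsFromFlags();
//   auto decode_resource = wenet::InitDecodeResourceFromFlags();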

namespace wenet {

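// Maps the --feat_type flag value to the FeatureType enum used by the
// feature pipeline; throws std::invalid_argument for unknown values.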
FeatureType StringToFeatureType(const std::string& feat_type_str) {
  if (feat_type_str == "kaldi")
    return FeatureType::kKaldi;
  else if (feat_type_str == "whisper")
    return FeatureType::kWhisper;
  else
    throw std::invalid_argument("Unsupported feat type!");
}
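
// Builds the feature extraction config (mel bins, sample rate, feature type)
// from the corresponding flags.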
std::shared_ptr<FeaturePipelineConfig> InitFeaturePipelineConfigFromFlags() {
  FeatureType feat_type = StringToFeatureType(FLAGS_feat_type);
  auto feature_config = std::make_shared<FeaturePipelineConfig>(
      FLAGS_num_bins, FLAGS_sample_rate, feat_type);
  return feature_config;
}
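
// Builds DecodeOptions from the flags, covering chunked (streaming) decoding,
// CTC/attention rescoring weights, CTC WFST search options, CTC prefix beam
// search options, and CTC endpoint detection.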
std::shared_ptr<DecodeOptions> InitDecodeOptionsFromFlags() {
  auto decode_config = std::make_shared<DecodeOptions>();
  decode_config->chunk_size = FLAGS_chunk_size;
  decode_config->num_left_chunks = FLAGS_num_left_chunks;
  decode_config->ctc_weight = FLAGS_ctc_weight;
  decode_config->reverse_weight = FLAGS_reverse_weight;
  decode_config->rescoring_weight = FLAGS_rescoring_weight;
  decode_config->ctc_wfst_search_opts.max_active = FLAGS_max_active;
  decode_config->ctc_wfst_search_opts.min_active = FLAGS_min_active;
  decode_config->ctc_wfst_search_opts.beam = FLAGS_beam;
  decode_config->ctc_wfst_search_opts.lattice_beam = FLAGS_lattice_beam;
  decode_config->ctc_wfst_search_opts.acoustic_scale = FLAGS_acoustic_scale;
  decode_config->ctc_wfst_search_opts.blank = FLAGS_blank_id;
  decode_config->ctc_wfst_search_opts.blank_skip_thresh =
      FLAGS_blank_skip_thresh;
  decode_config->ctc_wfst_search_opts.blank_scale = FLAGS_blank_scale;
  decode_config->ctc_wfst_search_opts.length_penalty = FLAGS_length_penalty;
  decode_config->ctc_wfst_search_opts.nbest = FLAGS_nbest;
  decode_config->ctc_prefix_search_opts.first_beam_size = FLAGS_nbest;
  decode_config->ctc_prefix_search_opts.second_beam_size = FLAGS_nbest;
  decode_config->ctc_prefix_search_opts.blank = FLAGS_blank_id;
  decode_config->ctc_endpoint_config.blank = FLAGS_blank_id;
  return decode_config;
}
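
// Builds the DecodeResource: loads exactly one ASR model backend (checked in
// the order onnx -> torch -> XPU -> BPU -> OpenVINO), the unit/dict symbol
// tables, an optional TLG fst (LM), an optional context-biasing graph, and
// the post processor (optionally with fst-based ITN).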
std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
  auto resource = std::make_shared<DecodeResource>();
  const int kNumGemmThreads = 1;
  if (!FLAGS_onnx_dir.empty()) {
#ifdef USE_ONNX
    LOG(INFO) << "Reading onnx model " << FLAGS_onnx_dir;
    OnnxAsrModel::InitEngineThreads(kNumGemmThreads);
    auto model = std::make_shared<OnnxAsrModel>();
    model->Read(FLAGS_onnx_dir);
    resource->model = model;
#else
    LOG(FATAL) << "Please rebuild with cmake-vcpkg options '-DONNX=ON'.";
#endif
  } else if (!FLAGS_model_path.empty()) {
#ifdef USE_TORCH
    LOG(INFO) << "Reading torch model " << FLAGS_model_path;
    TorchAsrModel::InitEngineThreads(kNumGemmThreads);
    auto model = std::make_shared<TorchAsrModel>();
    model->Read(FLAGS_model_path);
    resource->model = model;
#else
    LOG(FATAL) << "Please rebuild with cmake-vcpkg options '-DTORCH=ON'.";
#endif
  } else if (!FLAGS_xpu_model_dir.empty()) {
#ifdef USE_XPU
    LOG(INFO) << "Reading XPU WeNet model weight from " << FLAGS_xpu_model_dir;
    auto model = std::make_shared<XPUAsrModel>();
    model->SetEngineThreads(kNumGemmThreads);
    model->SetDeviceId(FLAGS_device_id);
    model->Read(FLAGS_xpu_model_dir);
    resource->model = model;
#else
    LOG(FATAL) << "Please rebuild with cmake-vcpkg options '-DXPU=ON'.";
#endif
  } else if (!FLAGS_bpu_model_dir.empty()) {
#ifdef USE_BPU
    LOG(INFO) << "Reading Horizon BPU model from " << FLAGS_bpu_model_dir;
    auto model = std::make_shared<BPUAsrModel>();
    model->Read(FLAGS_bpu_model_dir);
    resource->model = model;
#else
    LOG(FATAL) << "Please rebuild with cmake-vcpkg options '-DBPU=ON'.";
#endif
  } else if (!FLAGS_openvino_dir.empty()) {
#ifdef USE_OPENVINO
    LOG(INFO) << "Reading OpenVINO model " << FLAGS_openvino_dir;
    auto model = std::make_shared<OVAsrModel>();
    model->InitEngineThreads(FLAGS_core_number);
    model->Read(FLAGS_openvino_dir);
    resource->model = model;
#else
    LOG(FATAL) << "Please rebuild with cmake-vcpkg options '-DOPENVINO=ON'.";
#endif
  } else {
    LOG(FATAL) << "Please set ONNX, TORCH, XPU, BPU or OpenVINO model path!!!";
  }
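
  // The e2e modeling-unit symbol table is always required; it maps model
  // outputs to units and is also used for context/timestamp handling.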
  202. LOG(INFO) << "Reading unit table " << FLAGS_unit_path;
  203. auto unit_table = std::shared_ptr<fst::SymbolTable>(
  204. fst::SymbolTable::ReadText(FLAGS_unit_path));
  205. CHECK(unit_table != nullptr);
  206. resource->unit_table = unit_table;
  207. if (!FLAGS_fst_path.empty()) { // With LM
  208. CHECK(!FLAGS_dict_path.empty());
  209. LOG(INFO) << "Reading fst " << FLAGS_fst_path;
  210. auto fst = std::shared_ptr<fst::VectorFst<fst::StdArc>>(
  211. fst::VectorFst<fst::StdArc>::Read(FLAGS_fst_path));
  212. CHECK(fst != nullptr);
  213. resource->fst = fst;
  214. LOG(INFO) << "Reading symbol table " << FLAGS_dict_path;
  215. auto symbol_table = std::shared_ptr<fst::SymbolTable>(
  216. fst::SymbolTable::ReadText(FLAGS_dict_path));
  217. CHECK(symbol_table != nullptr);
  218. resource->symbol_table = symbol_table;
  219. } else { // Without LM, symbol_table is the same as unit_table
  220. resource->symbol_table = unit_table;
  221. }
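
  // Optional context biasing: one phrase per line in --context_path, scored
  // with --context_score through a context graph during decoding.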
  if (!FLAGS_context_path.empty()) {
    LOG(INFO) << "Reading context " << FLAGS_context_path;
    std::vector<std::string> contexts;
    std::ifstream infile(FLAGS_context_path);
    std::string context;
    while (getline(infile, context)) {
      contexts.emplace_back(Trim(context));
    }
    ContextConfig config;
    config.context_score = FLAGS_context_score;
    resource->context_graph = std::make_shared<ContextGraph>(config);
    resource->context_graph->BuildContextGraph(contexts, unit_table);
  }
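
  // Post-processing: space handling by language type, optional lowercasing,
  // and, if an ITN model dir is given, fst-based inverse text normalization.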
  PostProcessOptions post_process_opts;
  post_process_opts.language_type =
      FLAGS_language_type == 0 ? kMandarinEnglish : kIndoEuropean;
  post_process_opts.lowercase = FLAGS_lowercase;
  resource->post_processor =
      std::make_shared<PostProcessor>(std::move(post_process_opts));
  if (!FLAGS_itn_model_dir.empty()) {  // With ITN
    std::string itn_tagger_path =
        wenet::JoinPath(FLAGS_itn_model_dir, "zh_itn_tagger.fst");
    std::string itn_verbalizer_path =
        wenet::JoinPath(FLAGS_itn_model_dir, "zh_itn_verbalizer.fst");
    if (wenet::FileExists(itn_tagger_path) &&
        wenet::FileExists(itn_verbalizer_path)) {
      LOG(INFO) << "Reading ITN fst from " << FLAGS_itn_model_dir;
      // Rebuild the post processor with ITN enabled.
      post_process_opts.itn = true;
      auto postprocessor =
          std::make_shared<wenet::PostProcessor>(std::move(post_process_opts));
      postprocessor->InitITNResource(itn_tagger_path, itn_verbalizer_path);
      resource->post_processor = postprocessor;
    }
  }
  return resource;
}

}  // namespace wenet

#endif  // DECODER_PARAMS_H_