// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
//               2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DECODER_PARAMS_H_
#define DECODER_PARAMS_H_

#include <fstream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "decoder/asr_decoder.h"
#ifdef USE_ONNX
#include "decoder/onnx_asr_model.h"
#endif
#ifdef USE_TORCH
#include "decoder/torch_asr_model.h"
#endif
#ifdef USE_XPU
#include "xpu/xpu_asr_model.h"
#endif
#ifdef USE_BPU
#include "bpu/bpu_asr_model.h"
#endif
#ifdef USE_OPENVINO
#include "ov/ov_asr_model.h"
#endif
#include "frontend/feature_pipeline.h"
#include "post_processor/post_processor.h"
#include "utils/file.h"
#include "utils/flags.h"
#include "utils/wn_string.h"

DEFINE_int32(device_id, 0, "set XPU DeviceID for ASR model");

// TorchAsrModel flags
DEFINE_string(model_path, "", "pytorch exported model path");
// OnnxAsrModel flags
DEFINE_string(onnx_dir, "", "directory where the onnx model is saved");
// XPUAsrModel flags
DEFINE_string(xpu_model_dir, "",
              "directory where the XPU model and weights are saved");
// BPUAsrModel flags
DEFINE_string(bpu_model_dir, "",
              "directory where the HORIZON BPU model is saved");
// OVAsrModel flags
DEFINE_string(openvino_dir, "", "directory where the OV model is saved");
DEFINE_int32(core_number, 1, "number of cores used by the inference engine");

// FeaturePipelineConfig flags
DEFINE_int32(num_bins, 80, "num mel bins for fbank feature");
DEFINE_int32(sample_rate, 16000, "sample rate for audio");
DEFINE_string(feat_type, "kaldi", "Type of feature extraction: kaldi, whisper");

// TLG fst
DEFINE_string(fst_path, "", "TLG fst path");
// ITN fst
DEFINE_string(itn_model_dir, "",
              "fst based ITN model dir, should contain "
              "zh_itn_tagger.fst and zh_itn_verbalizer.fst");

// DecodeOptions flags
DEFINE_int32(chunk_size, 16, "decoding chunk size");
DEFINE_int32(num_left_chunks, -1, "left chunks in decoding");
DEFINE_double(ctc_weight, 0.5,
              "ctc weight when combining ctc score and rescoring score");
DEFINE_double(rescoring_weight, 1.0,
              "rescoring weight when combining ctc score and rescoring score");
DEFINE_double(reverse_weight, 0.0,
              "used for bitransformer rescoring. It must be 0.0 if the "
              "decoder is a conventional transformer decoder; only when "
              "reverse_weight > 0.0 will the right-to-left decoder be "
              "computed and used");
DEFINE_int32(max_active, 7000, "max active states in ctc wfst search");
DEFINE_int32(min_active, 200, "min active states in ctc wfst search");
DEFINE_double(beam, 16.0, "beam in ctc wfst search");
DEFINE_double(lattice_beam, 10.0, "lattice beam in ctc wfst search");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale for ctc wfst search");
DEFINE_int32(blank_id, 0,
             "blank token idx for ctc wfst search and ctc prefix beam search");
DEFINE_double(blank_skip_thresh, 1.0,
              "blank skip thresh for ctc wfst search, 1.0 means no skip");
DEFINE_double(blank_scale, 1.0, "blank scale for ctc wfst search");
DEFINE_double(length_penalty, 0.0,
              "length penalty for ctc wfst search; not applied on self-loop "
              "arcs. Used to balance the del/ins ratio; -3.0 is suggested");
DEFINE_int32(nbest, 10, "nbest for ctc wfst or prefix search");

// SymbolTable flags
DEFINE_string(dict_path, "",
              "dict symbol table path, required when LM is enabled");
DEFINE_string(unit_path, "",
              "e2e model unit symbol table; used in both with-LM and "
              "without-LM scenarios for context/timestamp");

// Context flags
DEFINE_string(context_path, "", "context path, used to build the context graph");
DEFINE_double(context_score, 3.0, "used to rescore the decoded result");

// PostProcessOptions flags
DEFINE_int32(language_type, 0,
             "remove spaces according to language type: "
             "0x00 = kMandarinEnglish, "
             "0x01 = kIndoEuropean");
DEFINE_bool(lowercase, true, "lowercase final result if needed");
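// Example invocation (a sketch only; the binary name and all flag values
// below are illustrative, not prescribed by this file):
//
//   ./decoder_main --model_path=final.zip --unit_path=units.txt \
//       --chunk_size=16 --ctc_weight=0.5 --rescoring_weight=1.0 \
//       --fst_path=TLG.fst --dict_path=words.txt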

namespace wenet {

FeatureType StringToFeatureType(const std::string& feat_type_str) {
  if (feat_type_str == "kaldi")
    return FeatureType::kKaldi;
  else if (feat_type_str == "whisper")
    return FeatureType::kWhisper;
  else
    throw std::invalid_argument("Unsupported feat type!");
}

std::shared_ptr<FeaturePipelineConfig> InitFeaturePipelineConfigFromFlags() {
  FeatureType feat_type = StringToFeatureType(FLAGS_feat_type);
  auto feature_config = std::make_shared<FeaturePipelineConfig>(
      FLAGS_num_bins, FLAGS_sample_rate, feat_type);
  return feature_config;
}

std::shared_ptr<DecodeOptions> InitDecodeOptionsFromFlags() {
  auto decode_config = std::make_shared<DecodeOptions>();
  decode_config->chunk_size = FLAGS_chunk_size;
  decode_config->num_left_chunks = FLAGS_num_left_chunks;
  decode_config->ctc_weight = FLAGS_ctc_weight;
  decode_config->reverse_weight = FLAGS_reverse_weight;
  decode_config->rescoring_weight = FLAGS_rescoring_weight;
  decode_config->ctc_wfst_search_opts.max_active = FLAGS_max_active;
  decode_config->ctc_wfst_search_opts.min_active = FLAGS_min_active;
  decode_config->ctc_wfst_search_opts.beam = FLAGS_beam;
  decode_config->ctc_wfst_search_opts.lattice_beam = FLAGS_lattice_beam;
  decode_config->ctc_wfst_search_opts.acoustic_scale = FLAGS_acoustic_scale;
  decode_config->ctc_wfst_search_opts.blank = FLAGS_blank_id;
  decode_config->ctc_wfst_search_opts.blank_skip_thresh =
      FLAGS_blank_skip_thresh;
  decode_config->ctc_wfst_search_opts.blank_scale = FLAGS_blank_scale;
  decode_config->ctc_wfst_search_opts.length_penalty = FLAGS_length_penalty;
  decode_config->ctc_wfst_search_opts.nbest = FLAGS_nbest;
  decode_config->ctc_prefix_search_opts.first_beam_size = FLAGS_nbest;
  decode_config->ctc_prefix_search_opts.second_beam_size = FLAGS_nbest;
  decode_config->ctc_prefix_search_opts.blank = FLAGS_blank_id;
  decode_config->ctc_endpoint_config.blank = FLAGS_blank_id;
  return decode_config;
}
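// Sketch of how the helpers in this header are typically wired together by a
// caller such as a decoder_main-style binary (this assumes the caller has
// already parsed the flags with gflags and initialized glog):
//
//   auto feature_config = wenet::InitFeaturePipelineConfigFromFlags();
//   auto decode_config = wenet::InitDecodeOptionsFromFlags();
//   auto decode_resource = wenet::InitDecodeResourceFromFlags();
//
// InitDecodeResourceFromFlags() is defined below.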

std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
  auto resource = std::make_shared<DecodeResource>();

  const int kNumGemmThreads = 1;
  if (!FLAGS_onnx_dir.empty()) {
#ifdef USE_ONNX
    LOG(INFO) << "Reading onnx model ";
    OnnxAsrModel::InitEngineThreads(kNumGemmThreads);
    auto model = std::make_shared<OnnxAsrModel>();
    model->Read(FLAGS_onnx_dir);
    resource->model = model;
#else
    LOG(FATAL) << "Please rebuild with cmake-vcpkg options '-DONNX=ON'.";
#endif
  } else if (!FLAGS_model_path.empty()) {
#ifdef USE_TORCH
    LOG(INFO) << "Reading torch model " << FLAGS_model_path;
    TorchAsrModel::InitEngineThreads(kNumGemmThreads);
    auto model = std::make_shared<TorchAsrModel>();
    model->Read(FLAGS_model_path);
    resource->model = model;
#else
    LOG(FATAL) << "Please rebuild with cmake-vcpkg options '-DTORCH=ON'.";
#endif
  } else if (!FLAGS_xpu_model_dir.empty()) {
#ifdef USE_XPU
    LOG(INFO) << "Reading XPU WeNet model weight from " << FLAGS_xpu_model_dir;
    auto model = std::make_shared<XPUAsrModel>();
    model->SetEngineThreads(kNumGemmThreads);
    model->SetDeviceId(FLAGS_device_id);
    model->Read(FLAGS_xpu_model_dir);
    resource->model = model;
#else
    LOG(FATAL) << "Please rebuild with cmake-vcpkg options '-DXPU=ON'.";
#endif
  } else if (!FLAGS_bpu_model_dir.empty()) {
#ifdef USE_BPU
    LOG(INFO) << "Reading Horizon BPU model from " << FLAGS_bpu_model_dir;
    auto model = std::make_shared<BPUAsrModel>();
    model->Read(FLAGS_bpu_model_dir);
    resource->model = model;
#else
    LOG(FATAL) << "Please rebuild with cmake-vcpkg options '-DBPU=ON'.";
#endif
  } else if (!FLAGS_openvino_dir.empty()) {
#ifdef USE_OPENVINO
    LOG(INFO) << "Reading OpenVINO model ";
    auto model = std::make_shared<OVAsrModel>();
    model->InitEngineThreads(FLAGS_core_number);
    model->Read(FLAGS_openvino_dir);
    resource->model = model;
#else
    LOG(FATAL) << "Please rebuild with cmake-vcpkg options '-DOPENVINO=ON'.";
#endif
  } else {
    LOG(FATAL) << "Please set ONNX, TORCH, XPU, BPU or OpenVINO model path!";
  }

  LOG(INFO) << "Reading unit table " << FLAGS_unit_path;
  auto unit_table = std::shared_ptr<fst::SymbolTable>(
      fst::SymbolTable::ReadText(FLAGS_unit_path));
  CHECK(unit_table != nullptr);
  resource->unit_table = unit_table;

  if (!FLAGS_fst_path.empty()) {  // With LM
    CHECK(!FLAGS_dict_path.empty());
    LOG(INFO) << "Reading fst " << FLAGS_fst_path;
    auto fst = std::shared_ptr<fst::VectorFst<fst::StdArc>>(
        fst::VectorFst<fst::StdArc>::Read(FLAGS_fst_path));
    CHECK(fst != nullptr);
    resource->fst = fst;

    LOG(INFO) << "Reading symbol table " << FLAGS_dict_path;
    auto symbol_table = std::shared_ptr<fst::SymbolTable>(
        fst::SymbolTable::ReadText(FLAGS_dict_path));
    CHECK(symbol_table != nullptr);
    resource->symbol_table = symbol_table;
  } else {  // Without LM, symbol_table is the same as unit_table
    resource->symbol_table = unit_table;
  }
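  // The file named by --context_path is read below as one biasing phrase per
  // line; leading/trailing whitespace is trimmed. Illustrative content only:
  //   wenet
  //   mobvoi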
  if (!FLAGS_context_path.empty()) {
    LOG(INFO) << "Reading context " << FLAGS_context_path;
    std::vector<std::string> contexts;
    std::ifstream infile(FLAGS_context_path);
    std::string context;
    while (getline(infile, context)) {
      contexts.emplace_back(Trim(context));
    }
    ContextConfig config;
    config.context_score = FLAGS_context_score;
    resource->context_graph = std::make_shared<ContextGraph>(config);
    resource->context_graph->BuildContextGraph(contexts, unit_table);
  }

  PostProcessOptions post_process_opts;
  post_process_opts.language_type =
      FLAGS_language_type == 0 ? kMandarinEnglish : kIndoEuropean;
  post_process_opts.lowercase = FLAGS_lowercase;
  resource->post_processor =
      std::make_shared<PostProcessor>(std::move(post_process_opts));

  if (!FLAGS_itn_model_dir.empty()) {  // With ITN
    std::string itn_tagger_path =
        wenet::JoinPath(FLAGS_itn_model_dir, "zh_itn_tagger.fst");
    std::string itn_verbalizer_path =
        wenet::JoinPath(FLAGS_itn_model_dir, "zh_itn_verbalizer.fst");
    if (wenet::FileExists(itn_tagger_path) &&
        wenet::FileExists(itn_verbalizer_path)) {
      LOG(INFO) << "Reading ITN fst " << FLAGS_itn_model_dir;
      post_process_opts.itn = true;
      auto postprocessor =
          std::make_shared<PostProcessor>(std::move(post_process_opts));
      postprocessor->InitITNResource(itn_tagger_path, itn_verbalizer_path);
      resource->post_processor = postprocessor;
    }
  }

  return resource;
}

}  // namespace wenet

#endif  // DECODER_PARAMS_H_