diff --git a/decoder/params.h b/decoder/params.h index 59c8ff9..c3672c8 100644 --- a/decoder/params.h +++ b/decoder/params.h @@ -261,21 +261,21 @@ std::shared_ptr InitDecodeResourceFromFlags() { resource->post_processor = std::make_shared(std::move(post_process_opts)); -// if (!FLAGS_itn_model_dir.empty()) { // With ITN -// std::string itn_tagger_path = -// wenet::JoinPath(FLAGS_itn_model_dir, "zh_itn_tagger.fst"); -// std::string itn_verbalizer_path = -// wenet::JoinPath(FLAGS_itn_model_dir, "zh_itn_verbalizer.fst"); -// if (wenet::FileExists(itn_tagger_path) && -// wenet::FileExists(itn_verbalizer_path)) { -// LOG(INFO) << "Reading ITN fst" << FLAGS_itn_model_dir; -// post_process_opts.itn = true; -// auto postprocessor = -// std::make_shared(std::move(post_process_opts)); -// postprocessor->InitITNResource(itn_tagger_path, itn_verbalizer_path); -// resource->post_processor = postprocessor; -// } -// } + if (!FLAGS_itn_model_dir.empty()) { // With ITN + std::string itn_tagger_path = + wenet::JoinPath(FLAGS_itn_model_dir, "zh_itn_tagger.fst"); + std::string itn_verbalizer_path = + wenet::JoinPath(FLAGS_itn_model_dir, "zh_itn_verbalizer.fst"); + if (wenet::FileExists(itn_tagger_path) && + wenet::FileExists(itn_verbalizer_path)) { + LOG(INFO) << "Reading ITN fst" << FLAGS_itn_model_dir; + post_process_opts.itn = true; + auto postprocessor = + std::make_shared(std::move(post_process_opts)); + postprocessor->InitITNResource(itn_tagger_path, itn_verbalizer_path); + resource->post_processor = postprocessor; + } + } return resource; } diff --git a/post_processor/post_processor.cc b/post_processor/post_processor.cc index 8b752f3..4f070f8 100644 --- a/post_processor/post_processor.cc +++ b/post_processor/post_processor.cc @@ -16,9 +16,16 @@ #include "post_processor/post_processor.h" #include #include +#include "processor/wetext_processor.h" #include "utils/string.h" namespace wenet { +void PostProcessor::InitITNResource(const std::string& tagger_path, + const std::string& verbalizer_path) { + auto itn_processor = + std::make_shared(tagger_path, verbalizer_path); + itn_resource = itn_processor; +} std::string PostProcessor::ProcessSpace(const std::string& str) { std::string result = str; @@ -78,11 +85,11 @@ std::string PostProcessor::Process(const std::string& str, bool finish) { result = ProcessSymbols(str); result = ProcessSpace(result); // TODO(xcsong): do punctuation if finish == true -// if (finish == true && opts_.itn) { -// if (nullptr != itn_resource) { -// result = itn_resource->Normalize(result); -// } -// } + if (finish == true && opts_.itn) { + if (nullptr != itn_resource) { + result = itn_resource->Normalize(result); + } + } return result; } diff --git a/post_processor/post_processor.h b/post_processor/post_processor.h index 2699cbf..b05aee6 100644 --- a/post_processor/post_processor.h +++ b/post_processor/post_processor.h @@ -19,6 +19,7 @@ #include #include #include +#include "processor/wetext_processor.h" #include "utils/utils.h" namespace wenet { @@ -51,8 +52,8 @@ struct PostProcessOptions { // Post Processor class PostProcessor { public: -// explicit PostProcessor(PostProcessOptions&& opts) : opts_(std::move(opts)) {} -// explicit PostProcessor(const PostProcessOptions& opts) : opts_(opts) {} + explicit PostProcessor(PostProcessOptions&& opts) : opts_(std::move(opts)) {} + explicit PostProcessor(const PostProcessOptions& opts) : opts_(opts) {} // call other functions to do post processing std::string Process(const std::string& str, bool finish); // process spaces according to configurations @@ -63,10 +64,12 @@ class PostProcessor { void InitITNResource(const std::string& tagger_path, const std::string& verbalizer_path); - private: -// std::shared_ptr itn_resource = nullptr; - const PostProcessOptions opts_; - public: + + private: + std::shared_ptr itn_resource = nullptr; + const PostProcessOptions opts_; + + public: WENET_DISALLOW_COPY_AND_ASSIGN(PostProcessor); };