You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

87 lines
2.8 KiB

  1. // Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn)
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "wetext_processor.h"
  15. #include <fst/string.h>
  16. #include <fst/compat.h>
  17. #include <fst/extensions/mpdt/compose.h>
  18. #include <fst/extensions/mpdt/mpdt.h>
  19. #include <fst/extensions/pdt/compose.h>
  20. #include <fst/extensions/pdt/pdt.h>
  21. #include <fst/extensions/pdt/shortest-path.h>
  22. #include <fst/arc.h>
  23. #include <fst/fstlib.h>
  24. #include <fst/fst.h>
  25. #include <fst/vector-fst.h>
  26. namespace wetext {
  27. Processor::Processor(const std::string& tagger_path,
  28. const std::string& verbalizer_path) {
  29. tagger_.reset(StdVectorFst::Read(tagger_path));
  30. verbalizer_.reset(StdVectorFst::Read(verbalizer_path));
  31. compiler_ = std::make_shared<StringCompiler<StdArc>>(fst::StringTokenType::BYTE);
  32. printer_ = std::make_shared<StringPrinter<StdArc>>(fst::StringTokenType::BYTE);
  33. if (tagger_path.find("_tn_") != tagger_path.npos) {
  34. parse_type_ = ParseType::kTN;
  35. } else if (tagger_path.find("_itn_") != tagger_path.npos) {
  36. parse_type_ = ParseType::kITN;
  37. } else {
  38. LOG(FATAL) << "Invalid fst prefix, prefix should contain"
  39. << " either \"_tn_\" or \"_itn_\".";
  40. }
  41. }
  42. std::string Processor::ShortestPath(const StdVectorFst& lattice) {
  43. StdVectorFst shortest_path;
  44. fst::ShortestPath(lattice, &shortest_path, 1, true);
  45. std::string output;
  46. printer_->operator()(shortest_path, &output);
  47. return output;
  48. }
  49. std::string Processor::Compose(const std::string& input,
  50. const StdVectorFst* fst) {
  51. StdVectorFst input_fst;
  52. compiler_->operator()(input, &input_fst);
  53. StdVectorFst lattice;
  54. fst::Compose(input_fst, *fst, &lattice);
  55. return ShortestPath(lattice);
  56. }
  57. std::string Processor::Tag(const std::string& input) {
  58. if (input.empty()) {
  59. return "";
  60. }
  61. return Compose(input, tagger_.get());
  62. }
  63. std::string Processor::Verbalize(const std::string& input) {
  64. if (input.empty()) {
  65. return "";
  66. }
  67. TokenParser parser(parse_type_);
  68. std::string output = parser.Reorder(input);
  69. output = Compose(output, verbalizer_.get());
  70. output.erase(std::remove(output.begin(), output.end(), '\0'), output.end());
  71. return output;
  72. }
  73. std::string Processor::Normalize(const std::string& input) {
  74. return Verbalize(Tag(input));
  75. }
  76. } // namespace wetext