You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

414 lines
14 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Utilities to convert strings into FSTs.
  19. #ifndef FST_STRING_H_
  20. #define FST_STRING_H_
  21. #include <cstdint>
  22. #include <memory>
  23. #include <optional>
  24. #include <ostream>
  25. #include <sstream>
  26. #include <string>
  27. #include <utility>
  28. #include <vector>
  29. #include <fst/flags.h>
  30. #include <fst/log.h>
  31. #include <fst/arc.h>
  32. #include <fst/compact-fst.h>
  33. #include <fst/fst.h>
  34. #include <fst/icu.h>
  35. #include <fst/mutable-fst.h>
  36. #include <fst/properties.h>
  37. #include <fst/symbol-table.h>
  38. #include <fst/util.h>
  39. #include <fst/compat.h>
  40. #include <string_view>
  41. DECLARE_string(fst_field_separator);
  42. namespace fst {
  // Token kinds accepted by the string/label conversion utilities in this file.
  enum class TokenType : uint8_t { SYMBOL = 1, BYTE = 2, UTF8 = 3 };

  // Writes the lowercase name of token_type ("symbol", "byte" or "utf8") to
  // strm and returns strm. A value that matches no enumerator writes nothing.
  inline std::ostream &operator<<(std::ostream &strm,
                                  const TokenType &token_type) {
    const char *name = nullptr;
    switch (token_type) {
      case TokenType::SYMBOL:
        name = "symbol";
        break;
      case TokenType::BYTE:
        name = "byte";
        break;
      case TokenType::UTF8:
        name = "utf8";
        break;
    }
    // name stays null only for an out-of-range value; leave strm untouched then.
    if (name != nullptr) strm << name;
    return strm;
  }
  56. namespace internal {
  57. template <class Label>
  58. bool ConvertSymbolToLabel(std::string_view str, const SymbolTable *syms,
  59. Label unknown_label, Label *output) {
  60. int64_t n;
  61. if (syms) {
  62. n = syms->Find(str);
  63. if ((n == kNoSymbol) && (unknown_label != kNoLabel)) n = unknown_label;
  64. if (n == kNoSymbol) {
  65. LOG(ERROR) << "ConvertSymbolToLabel: Symbol \"" << str
  66. << "\" is not mapped to any integer label, symbol table = "
  67. << syms->Name();
  68. return false;
  69. }
  70. } else {
  71. const auto maybe_n = ParseInt64(str);
  72. if (!maybe_n.has_value()) {
  73. LOG(ERROR) << "ConvertSymbolToLabel: Bad label integer "
  74. << "= \"" << str << "\"";
  75. return false;
  76. }
  77. n = *maybe_n;
  78. }
  79. *output = n;
  80. return true;
  81. }
  82. template <class Label>
  83. bool ConvertStringToLabels(
  84. std::string_view str, TokenType token_type, const SymbolTable *syms,
  85. Label unknown_label, std::vector<Label> *labels,
  86. std::string_view sep = FST_FLAGS_fst_field_separator) {
  87. labels->clear();
  88. switch (token_type) {
  89. case TokenType::BYTE: {
  90. labels->reserve(str.size());
  91. return ByteStringToLabels(str, labels);
  92. }
  93. case TokenType::UTF8: {
  94. return UTF8StringToLabels(str, labels);
  95. }
  96. case TokenType::SYMBOL: {
  97. const std::string separator = fst::StrCat("\n", sep);
  98. for (std::string_view c :
  99. StrSplit(str, ByAnyChar(separator), SkipEmpty())) {
  100. Label label;
  101. if (!ConvertSymbolToLabel(c, syms, unknown_label, &label)) return false;
  102. labels->push_back(label);
  103. }
  104. return true;
  105. }
  106. }
  107. return false; // Unreachable.
  108. }
  109. // The last character of 'sep' is used as a separator between symbols.
  110. // Additionally, epsilon symbols will be printed only if omit_epsilon
  111. // is false.
  112. template <class Label>
  113. bool LabelsToSymbolString(const std::vector<Label> &labels, std::string *str,
  114. const SymbolTable &syms, std::string_view sep,
  115. bool omit_epsilon) {
  116. std::stringstream ostrm;
  117. sep.remove_prefix(sep.size() - 1); // We only respect the final char of sep.
  118. std::string_view delim = "";
  119. for (auto label : labels) {
  120. if (omit_epsilon && !label) continue;
  121. ostrm << delim;
  122. const std::string &symbol = syms.Find(label);
  123. if (symbol.empty()) {
  124. LOG(ERROR) << "LabelsToSymbolString: Label " << label
  125. << " is not mapped onto any textual symbol in symbol table "
  126. << syms.Name();
  127. return false;
  128. }
  129. ostrm << symbol;
  130. delim = sep;
  131. }
  132. *str = ostrm.str();
  133. return !!ostrm;
  134. }
  // Renders labels as decimal integers joined by a separator and stores the
  // result in *str.
  //
  // The last character of 'sep' is used as the separator between labels; an
  // empty 'sep' joins the numbers with no separator. Epsilon (zero) labels are
  // printed only if omit_epsilon is false. Returns the state of the underlying
  // stream after formatting.
  template <class Label>
  bool LabelsToNumericString(const std::vector<Label> &labels, std::string *str,
                             std::string_view sep, bool omit_epsilon) {
    std::stringstream ostrm;
    // We only respect the final char of sep. The emptiness check matters:
    // sep.size() - 1 wraps around for an empty view, and remove_prefix() past
    // the end of a string_view is undefined behaviour.
    if (!sep.empty()) sep.remove_prefix(sep.size() - 1);
    // Empty before the first printed label, the separator afterwards.
    std::string_view delim = "";
    for (auto label : labels) {
      if (omit_epsilon && !label) continue;
      ostrm << delim;
      ostrm << label;
      delim = sep;
    }
    *str = ostrm.str();
    return !!ostrm;
  }
  153. } // namespace internal
  154. // Functor for compiling a string in an FST.
  155. template <class Arc>
  156. class OPENFST_DEPRECATED("allow_negative is no-op") StringCompiler {
  157. public:
  158. using Label = typename Arc::Label;
  159. using StateId = typename Arc::StateId;
  160. using Weight = typename Arc::Weight;
  161. explicit StringCompiler(TokenType token_type = TokenType::BYTE,
  162. const SymbolTable *syms = nullptr,
  163. Label unknown_label = kNoLabel)
  164. : token_type_(token_type), syms_(syms), unknown_label_(unknown_label) {}
  165. // Compiles string into an FST. With SYMBOL token type, sep is used to
  166. // specify the set of char separators between symbols, in addition
  167. // of '\n' which is always treated as a separator.
  168. // Returns true on success.
  169. template <class FST>
  170. bool operator()(
  171. std::string_view str, FST *fst,
  172. std::string_view sep = FST_FLAGS_fst_field_separator) const {
  173. std::vector<Label> labels;
  174. if (!internal::ConvertStringToLabels(str, token_type_, syms_,
  175. unknown_label_, &labels, sep)) {
  176. return false;
  177. }
  178. Compile(labels, fst);
  179. return true;
  180. }
  181. // Same as above but allows to specify a weight for the string.
  182. template <class FST>
  183. bool operator()(
  184. std::string_view str, FST *fst, Weight weight,
  185. std::string_view sep = FST_FLAGS_fst_field_separator) const {
  186. std::vector<Label> labels;
  187. if (!internal::ConvertStringToLabels(str, token_type_, syms_,
  188. unknown_label_, &labels, sep)) {
  189. return false;
  190. }
  191. Compile(labels, fst, std::move(weight));
  192. return true;
  193. }
  194. private:
  195. void Compile(const std::vector<Label> &labels, MutableFst<Arc> *fst,
  196. Weight weight = Weight::One()) const {
  197. fst->DeleteStates();
  198. auto state = fst->AddState();
  199. fst->SetStart(state);
  200. fst->AddStates(labels.size());
  201. for (auto label : labels) {
  202. fst->AddArc(state, Arc(label, label, state + 1));
  203. ++state;
  204. }
  205. fst->SetFinal(state, std::move(weight));
  206. fst->SetProperties(kCompiledStringProperties, kCompiledStringProperties);
  207. }
  208. template <class Unsigned>
  209. void Compile(const std::vector<Label> &labels,
  210. CompactStringFst<Arc, Unsigned> *fst) const {
  211. using Compactor = typename CompactStringFst<Arc, Unsigned>::Compactor;
  212. fst->SetCompactor(
  213. std::make_shared<Compactor>(labels.begin(), labels.end()));
  214. }
  215. template <class Unsigned>
  216. void Compile(const std::vector<Label> &labels,
  217. CompactWeightedStringFst<Arc, Unsigned> *fst,
  218. Weight weight = Weight::One()) const {
  219. std::vector<std::pair<Label, Weight>> compacts;
  220. compacts.reserve(labels.size() + 1);
  221. for (StateId i = 0; i < static_cast<StateId>(labels.size()) - 1; ++i) {
  222. compacts.emplace_back(labels[i], Weight::One());
  223. }
  224. compacts.emplace_back(!labels.empty() ? labels.back() : kNoLabel, weight);
  225. using Compactor =
  226. typename CompactWeightedStringFst<Arc, Unsigned>::Compactor;
  227. fst->SetCompactor(
  228. std::make_shared<Compactor>(compacts.begin(), compacts.end()));
  229. }
  230. const TokenType token_type_;
  231. const SymbolTable *syms_; // Symbol table (used when token type is symbol).
  232. const Label unknown_label_; // Label for token missing from symbol table.
  233. StringCompiler(const StringCompiler &) = delete;
  234. StringCompiler &operator=(const StringCompiler &) = delete;
  235. };
  236. // A useful alias when using StdArc.
  237. using StdStringCompiler = StringCompiler<StdArc>;
  238. // Helpers for StringPrinter.
  239. // Converts an FST to a vector of output labels. To get input labels, use
  240. // Project or Invert. Returns true on success. Use only with string FSTs; may
  241. // loop for non-string FSTs.
  242. template <class Arc>
  243. bool StringFstToOutputLabels(const Fst<Arc> &fst,
  244. std::vector<typename Arc::Label> *labels) {
  245. labels->clear();
  246. auto s = fst.Start();
  247. if (s == kNoStateId) {
  248. LOG(ERROR) << "StringFstToOutputLabels: Invalid start state";
  249. return false;
  250. }
  251. while (fst.Final(s) == Arc::Weight::Zero()) {
  252. ArcIterator<Fst<Arc>> aiter(fst, s);
  253. if (aiter.Done()) {
  254. LOG(ERROR) << "StringFstToOutputLabels: Does not reach final state";
  255. return false;
  256. }
  257. const auto &arc = aiter.Value();
  258. labels->push_back(arc.olabel);
  259. s = arc.nextstate;
  260. aiter.Next();
  261. if (!aiter.Done()) {
  262. LOG(ERROR) << "StringFstToOutputLabels: State " << s
  263. << " has multiple outgoing arcs";
  264. return false;
  265. }
  266. }
  267. if (fst.NumArcs(s) != 0) {
  268. LOG(ERROR) << "StringFstToOutputLabels: Final state " << s
  269. << " has outgoing arc(s)";
  270. return false;
  271. }
  272. return true;
  273. }
  274. // Same as above but also computes the path weight. The output weight parameter
  275. // is only set if labels extraction is successful.
  276. template <class Arc>
  277. bool StringFstToOutputLabels(const Fst<Arc> &fst,
  278. std::vector<typename Arc::Label> *labels,
  279. typename Arc::Weight *weight) {
  280. labels->clear();
  281. auto path_weight = Arc::Weight::One();
  282. auto s = fst.Start();
  283. if (s == kNoStateId) {
  284. LOG(ERROR) << "StringFstToOutputLabels: Invalid start state";
  285. return false;
  286. }
  287. auto final_weight = fst.Final(s);
  288. while (final_weight == Arc::Weight::Zero()) {
  289. ArcIterator<Fst<Arc>> aiter(fst, s);
  290. if (aiter.Done()) {
  291. LOG(ERROR) << "StringFstToOutputLabels: Does not reach final state";
  292. return false;
  293. }
  294. const auto &arc = aiter.Value();
  295. labels->push_back(arc.olabel);
  296. path_weight = Times(path_weight, arc.weight);
  297. s = arc.nextstate;
  298. aiter.Next();
  299. if (!aiter.Done()) {
  300. LOG(ERROR) << "StringFstToOutputLabels: State " << s
  301. << " has multiple outgoing arcs";
  302. return false;
  303. }
  304. final_weight = fst.Final(s);
  305. }
  306. if (fst.NumArcs(s) != 0) {
  307. LOG(ERROR) << "StringFstToOutputLabels: Final state " << s
  308. << " has outgoing arc(s)";
  309. return false;
  310. }
  311. *weight = Times(path_weight, final_weight);
  312. return true;
  313. }
  314. // Converts a list of symbols to a string. If the token type is SYMBOL, the last
  315. // character of sep is used to separate textual symbols. Additionally, if the
  316. // token type is SYMBOL, epsilon symbols will be printed only if omit_epsilon
  317. // is false. Returns true on success.
  318. template <class Label>
  319. bool LabelsToString(
  320. const std::vector<Label> &labels, std::string *str,
  321. TokenType ttype = TokenType::BYTE, const SymbolTable *syms = nullptr,
  322. std::string_view sep = FST_FLAGS_fst_field_separator,
  323. bool omit_epsilon = true) {
  324. switch (ttype) {
  325. case TokenType::BYTE: {
  326. return LabelsToByteString(labels, str);
  327. }
  328. case TokenType::UTF8: {
  329. return LabelsToUTF8String(labels, str);
  330. }
  331. case TokenType::SYMBOL: {
  332. return syms ? internal::LabelsToSymbolString(labels, str, *syms, sep,
  333. omit_epsilon)
  334. : internal::LabelsToNumericString(labels, str, sep,
  335. omit_epsilon);
  336. }
  337. }
  338. return false;
  339. }
  340. // Functor for printing a string FST as a string.
  341. template <class Arc>
  342. class StringPrinter {
  343. public:
  344. using Label = typename Arc::Label;
  345. using Weight = typename Arc::Weight;
  346. explicit StringPrinter(TokenType token_type = TokenType::BYTE,
  347. const SymbolTable *syms = nullptr,
  348. bool omit_epsilon = true)
  349. : token_type_(token_type), syms_(syms), omit_epsilon_(omit_epsilon) {}
  350. // Converts the FST into a string. With SYMBOL token type, the last character
  351. // of sep is used as a separator between symbols. Returns true on success.
  352. bool operator()(
  353. const Fst<Arc> &fst, std::string *str,
  354. std::string_view sep = FST_FLAGS_fst_field_separator) const {
  355. std::vector<Label> labels;
  356. return StringFstToOutputLabels(fst, &labels) &&
  357. LabelsToString(labels, str, token_type_, syms_, sep, omit_epsilon_);
  358. }
  359. // Same as above but also computes the path weight. The output weight
  360. // parameter is only set if labels extraction is successful.
  361. bool operator()(
  362. const Fst<Arc> &fst, std::string *str, Weight *weight,
  363. std::string_view sep = FST_FLAGS_fst_field_separator) const {
  364. std::vector<Label> labels;
  365. return StringFstToOutputLabels(fst, &labels, weight) &&
  366. LabelsToString(labels, str, token_type_, syms_, sep, omit_epsilon_);
  367. }
  368. private:
  369. const TokenType token_type_;
  370. const SymbolTable *syms_;
  371. const bool omit_epsilon_;
  372. StringPrinter(const StringPrinter &) = delete;
  373. StringPrinter &operator=(const StringPrinter &) = delete;
  374. };
  375. // A useful alias when using StdArc.
  376. using StdStringPrinter = StringPrinter<StdArc>;
  377. } // namespace fst
  378. #endif // FST_STRING_H_