You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

237 lines
7.9 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
// Class to compile a binary FST from textual input.
  19. #ifndef FST_SCRIPT_COMPILE_IMPL_H_
  20. #define FST_SCRIPT_COMPILE_IMPL_H_
#include <cstddef>
#include <iostream>
#include <istream>
#include <limits>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

#include <fst/log.h>
#include <fst/fst.h>
#include <fst/properties.h>
#include <fst/symbol-table.h>
#include <fst/util.h>
#include <fst/vector-fst.h>
  37. DECLARE_string(fst_field_separator);
  38. namespace fst {
  39. // Compile a binary FST from textual input, helper class for fstcompile.cc.
  40. // WARNING: Stand-alone use of this class not recommended, most code should
  41. // read/write using the binary format which is much more efficient.
  42. template <class Arc>
  43. class FstCompiler {
  44. public:
  45. using Label = typename Arc::Label;
  46. using StateId = typename Arc::StateId;
  47. using Weight = typename Arc::Weight;
  48. // If add_symbols_ is true, then the symbols will be dynamically added to the
  49. // symbol tables. This is only useful if you set the (i/o)keep flag to attach
  50. // the final symbol table, or use the accessors. (The input symbol tables are
  51. // const and therefore not changed.)
  52. FstCompiler(std::istream &istrm, std::string_view source,
  53. const SymbolTable *isyms, const SymbolTable *osyms,
  54. const SymbolTable *ssyms, bool accep, bool ikeep, bool okeep,
  55. bool nkeep) {
  56. std::unique_ptr<SymbolTable> misyms(isyms ? isyms->Copy() : nullptr);
  57. std::unique_ptr<SymbolTable> mosyms(osyms ? osyms->Copy() : nullptr);
  58. std::unique_ptr<SymbolTable> mssyms(ssyms ? ssyms->Copy() : nullptr);
  59. Init(istrm, source, misyms.get(), mosyms.get(), mssyms.get(), accep, ikeep,
  60. okeep, nkeep, false);
  61. }
  62. FstCompiler(std::istream &istrm, std::string_view source, SymbolTable *isyms,
  63. SymbolTable *osyms, SymbolTable *ssyms, bool accep, bool ikeep,
  64. bool okeep, bool nkeep, bool add_symbols) {
  65. Init(istrm, source, isyms, osyms, ssyms, accep, ikeep, okeep, nkeep,
  66. add_symbols);
  67. }
  68. void Init(std::istream &istrm, std::string_view source, SymbolTable *isyms,
  69. SymbolTable *osyms, SymbolTable *ssyms, bool accep, bool ikeep,
  70. bool okeep, bool nkeep, bool add_symbols) {
  71. nline_ = 0;
  72. source_ = std::string(source);
  73. isyms_ = isyms;
  74. osyms_ = osyms;
  75. ssyms_ = ssyms;
  76. nstates_ = 0;
  77. keep_state_numbering_ = nkeep;
  78. add_symbols_ = add_symbols;
  79. bool start_state_populated = false;
  80. char line[kLineLen];
  81. const std::string separator =
  82. FST_FLAGS_fst_field_separator + "\n";
  83. while (istrm.getline(line, kLineLen)) {
  84. ++nline_;
  85. const std::vector<std::string_view> col =
  86. StrSplit(line, ByAnyChar(separator), SkipEmpty());
  87. if (col.empty() || col[0].empty()) continue;
  88. if (col.size() > 5 || (col.size() > 4 && accep) ||
  89. (col.size() == 3 && !accep)) {
  90. FSTERROR() << "FstCompiler: Bad number of columns, source = " << source_
  91. << ", line = " << nline_;
  92. fst_.SetProperties(kError, kError);
  93. return;
  94. }
  95. StateId s = StrToStateId(col[0]);
  96. while (s >= fst_.NumStates()) fst_.AddState();
  97. if (!start_state_populated) {
  98. fst_.SetStart(s);
  99. start_state_populated = true;
  100. }
  101. Arc arc;
  102. StateId d = s;
  103. switch (col.size()) {
  104. case 1:
  105. fst_.SetFinal(s, Weight::One());
  106. break;
  107. case 2:
  108. fst_.SetFinal(s, StrToWeight(col[1], true));
  109. break;
  110. case 3:
  111. arc.nextstate = d = StrToStateId(col[1]);
  112. arc.ilabel = StrToILabel(col[2]);
  113. arc.olabel = arc.ilabel;
  114. arc.weight = Weight::One();
  115. fst_.AddArc(s, arc);
  116. break;
  117. case 4:
  118. arc.nextstate = d = StrToStateId(col[1]);
  119. arc.ilabel = StrToILabel(col[2]);
  120. if (accep) {
  121. arc.olabel = arc.ilabel;
  122. arc.weight = StrToWeight(col[3], true);
  123. } else {
  124. arc.olabel = StrToOLabel(col[3]);
  125. arc.weight = Weight::One();
  126. }
  127. fst_.AddArc(s, arc);
  128. break;
  129. case 5:
  130. arc.nextstate = d = StrToStateId(col[1]);
  131. arc.ilabel = StrToILabel(col[2]);
  132. arc.olabel = StrToOLabel(col[3]);
  133. arc.weight = StrToWeight(col[4], true);
  134. fst_.AddArc(s, arc);
  135. }
  136. while (d >= fst_.NumStates()) fst_.AddState();
  137. }
  138. if (ikeep) fst_.SetInputSymbols(isyms);
  139. if (okeep) fst_.SetOutputSymbols(osyms);
  140. }
  141. const VectorFst<Arc> &Fst() const { return fst_; }
  142. private:
  143. // Maximum line length in text file.
  144. static constexpr int kLineLen = 8096;
  145. StateId StrToId(std::string_view s, SymbolTable *syms,
  146. std::string_view name) const {
  147. StateId n = 0;
  148. if (syms) {
  149. // n = (add_symbols_) ? syms->AddSymbol(s) : syms->Find(s);
  150. if (add_symbols_) {
  151. n = syms->AddSymbol(s);
  152. } else {
  153. n = syms->Find(s);
  154. }
  155. if (n == kNoSymbol) {
  156. FSTERROR() << "FstCompiler: Symbol \"" << s
  157. << "\" is not mapped to any integer " << name
  158. << ", symbol table = " << syms->Name()
  159. << ", source = " << source_ << ", line = " << nline_;
  160. fst_.SetProperties(kError, kError);
  161. }
  162. } else {
  163. auto maybe_n = ParseInt64(s);
  164. if (!maybe_n.has_value()) {
  165. FSTERROR() << "FstCompiler: Bad " << name << " integer = \"" << s
  166. << "\", source = " << source_ << ", line = " << nline_;
  167. fst_.SetProperties(kError, kError);
  168. }
  169. n = *maybe_n;
  170. }
  171. return n;
  172. }
  173. StateId StrToStateId(std::string_view s) {
  174. StateId n = StrToId(s, ssyms_, "state ID");
  175. if (keep_state_numbering_) return n;
  176. // Remaps state IDs to make dense set.
  177. const auto it = states_.find(n);
  178. if (it == states_.end()) {
  179. states_[n] = nstates_;
  180. return nstates_++;
  181. } else {
  182. return it->second;
  183. }
  184. }
  185. StateId StrToILabel(std::string_view s) const {
  186. return StrToId(s, isyms_, "arc ilabel");
  187. }
  188. StateId StrToOLabel(std::string_view s) const {
  189. return StrToId(s, osyms_, "arc olabel");
  190. }
  191. Weight StrToWeight(std::string_view s, bool allow_zero) const {
  192. Weight w;
  193. std::istringstream strm(std::string{s});
  194. strm >> w;
  195. if (!strm || (!allow_zero && w == Weight::Zero())) {
  196. FSTERROR() << "FstCompiler: Bad weight = \"" << s
  197. << "\", source = " << source_ << ", line = " << nline_;
  198. fst_.SetProperties(kError, kError);
  199. w = Weight::NoWeight();
  200. }
  201. return w;
  202. }
  203. mutable VectorFst<Arc> fst_;
  204. size_t nline_;
  205. std::string source_; // Text FST source name.
  206. SymbolTable *isyms_; // ilabel symbol table (not owned).
  207. SymbolTable *osyms_; // olabel symbol table (not owned).
  208. SymbolTable *ssyms_; // slabel symbol table (not owned).
  209. std::unordered_map<StateId, StateId> states_; // State ID map.
  210. StateId nstates_; // Number of seen states.
  211. bool keep_state_numbering_;
  212. bool add_symbols_; // Add to symbol tables on-the fly.
  213. FstCompiler(const FstCompiler &) = delete;
  214. FstCompiler &operator=(const FstCompiler &) = delete;
  215. };
  216. } // namespace fst
  217. #endif // FST_SCRIPT_COMPILE_IMPL_H_