You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

352 lines
9.5 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Class to compute various information about FSTs, a helper class for
  19. // fstinfo.cc.
  20. #ifndef FST_SCRIPT_INFO_IMPL_H_
  21. #define FST_SCRIPT_INFO_IMPL_H_
  22. #include <cstddef>
  23. #include <cstdint>
  24. #include <map>
  25. #include <ostream>
  26. #include <string>
  27. #include <vector>
  28. #include <fst/log.h>
  29. #include <fst/arcfilter.h>
  30. #include <fst/cc-visitors.h>
  31. #include <fst/dfs-visit.h>
  32. #include <fst/fst.h>
  33. #include <fst/lookahead-matcher.h>
  34. #include <fst/matcher.h>
  35. #include <fst/properties.h>
  36. #include <fst/queue.h>
  37. #include <fst/util.h>
  38. #include <fst/verify.h>
  39. #include <fst/visit.h>
  40. #include <fst/script/arcfilter-impl.h>
  41. #include <string_view>
  42. namespace fst {
  43. // Compute various information about FSTs, helper class for fstinfo.cc.
  44. // WARNING: Stand-alone use of this class is not recommended, most code
  45. // should call directly the relevant library functions: Fst<Arc>::NumStates,
  46. // Fst<Arc>::NumArcs, TestProperties, etc.
  47. class FstInfo {
  48. public:
  49. // When info_type is "short" (or "auto" and not an ExpandedFst) then only
  50. // minimal info is computed and can be requested.
  51. template <typename Arc>
  52. FstInfo(const Fst<Arc> &fst, bool test_properties,
  53. script::ArcFilterType arc_filter_type = script::ArcFilterType::ANY,
  54. std::string_view info_type = "auto", bool verify = true)
  55. : fst_type_(fst.Type()),
  56. input_symbols_(fst.InputSymbols() ? fst.InputSymbols()->Name()
  57. : "none"),
  58. output_symbols_(fst.OutputSymbols() ? fst.OutputSymbols()->Name()
  59. : "none"),
  60. nstates_(0),
  61. narcs_(0),
  62. start_(kNoStateId),
  63. nfinal_(0),
  64. nepsilons_(0),
  65. niepsilons_(0),
  66. noepsilons_(0),
  67. ilabel_mult_(0.0),
  68. olabel_mult_(0.0),
  69. naccess_(0),
  70. ncoaccess_(0),
  71. nconnect_(0),
  72. ncc_(0),
  73. nscc_(0),
  74. input_match_type_(MATCH_NONE),
  75. output_match_type_(MATCH_NONE),
  76. input_lookahead_(false),
  77. output_lookahead_(false),
  78. properties_(0),
  79. arc_filter_type_(arc_filter_type),
  80. long_info_(true),
  81. arc_type_(Arc::Type()) {
  82. using Label = typename Arc::Label;
  83. using StateId = typename Arc::StateId;
  84. using Weight = typename Arc::Weight;
  85. if (info_type == "long") {
  86. long_info_ = true;
  87. } else if (info_type == "short") {
  88. long_info_ = false;
  89. } else if (info_type == "auto") {
  90. long_info_ = fst.Properties(kExpanded, false);
  91. } else {
  92. FSTERROR() << "Bad info type: " << info_type;
  93. return;
  94. }
  95. if (!long_info_) return;
  96. // If the FST is not sane, we return.
  97. if (verify && !Verify(fst)) {
  98. FSTERROR() << "FstInfo: Verify: FST not well-formed";
  99. return;
  100. }
  101. start_ = fst.Start();
  102. properties_ = fst.Properties(kFstProperties, test_properties);
  103. for (StateIterator<Fst<Arc>> siter(fst); !siter.Done(); siter.Next()) {
  104. ++nstates_;
  105. const auto s = siter.Value();
  106. if (fst.Final(s) != Weight::Zero()) ++nfinal_;
  107. std::map<Label, size_t> ilabel_count;
  108. std::map<Label, size_t> olabel_count;
  109. for (ArcIterator<Fst<Arc>> aiter(fst, s); !aiter.Done(); aiter.Next()) {
  110. const auto &arc = aiter.Value();
  111. ++narcs_;
  112. if (arc.ilabel == 0 && arc.olabel == 0) ++nepsilons_;
  113. if (arc.ilabel == 0) ++niepsilons_;
  114. if (arc.olabel == 0) ++noepsilons_;
  115. ++ilabel_count[arc.ilabel];
  116. ++olabel_count[arc.olabel];
  117. }
  118. for (auto it = ilabel_count.begin(); it != ilabel_count.end(); ++it) {
  119. ilabel_mult_ += it->second * it->second;
  120. }
  121. for (auto it = olabel_count.begin(); it != olabel_count.end(); ++it) {
  122. olabel_mult_ += it->second * it->second;
  123. }
  124. }
  125. if (narcs_ > 0) {
  126. ilabel_mult_ /= narcs_;
  127. olabel_mult_ /= narcs_;
  128. }
  129. {
  130. std::vector<StateId> cc;
  131. CcVisitor<Arc> cc_visitor(&cc);
  132. FifoQueue<StateId> fifo_queue;
  133. switch (arc_filter_type) {
  134. case script::ArcFilterType::ANY: {
  135. Visit(fst, &cc_visitor, &fifo_queue);
  136. break;
  137. }
  138. case script::ArcFilterType::EPSILON: {
  139. Visit(fst, &cc_visitor, &fifo_queue, EpsilonArcFilter<Arc>());
  140. break;
  141. }
  142. case script::ArcFilterType::INPUT_EPSILON: {
  143. Visit(fst, &cc_visitor, &fifo_queue, InputEpsilonArcFilter<Arc>());
  144. break;
  145. }
  146. case script::ArcFilterType::OUTPUT_EPSILON: {
  147. Visit(fst, &cc_visitor, &fifo_queue, OutputEpsilonArcFilter<Arc>());
  148. break;
  149. }
  150. }
  151. for (StateId s = 0; s < cc.size(); ++s) {
  152. if (cc[s] >= ncc_) ncc_ = cc[s] + 1;
  153. }
  154. }
  155. {
  156. std::vector<StateId> scc;
  157. std::vector<bool> access, coaccess;
  158. uint64_t props = 0;
  159. SccVisitor<Arc> scc_visitor(&scc, &access, &coaccess, &props);
  160. switch (arc_filter_type) {
  161. case script::ArcFilterType::ANY: {
  162. DfsVisit(fst, &scc_visitor);
  163. break;
  164. }
  165. case script::ArcFilterType::EPSILON: {
  166. DfsVisit(fst, &scc_visitor, EpsilonArcFilter<Arc>());
  167. break;
  168. }
  169. case script::ArcFilterType::INPUT_EPSILON: {
  170. DfsVisit(fst, &scc_visitor, InputEpsilonArcFilter<Arc>());
  171. break;
  172. }
  173. case script::ArcFilterType::OUTPUT_EPSILON: {
  174. DfsVisit(fst, &scc_visitor, OutputEpsilonArcFilter<Arc>());
  175. break;
  176. }
  177. }
  178. for (StateId s = 0; s < scc.size(); ++s) {
  179. if (access[s]) ++naccess_;
  180. if (coaccess[s]) ++ncoaccess_;
  181. if (access[s] && coaccess[s]) ++nconnect_;
  182. if (scc[s] >= nscc_) nscc_ = scc[s] + 1;
  183. }
  184. }
  185. LookAheadMatcher<Fst<Arc>> imatcher(fst, MATCH_INPUT);
  186. input_match_type_ = imatcher.Type(test_properties);
  187. input_lookahead_ = imatcher.Flags() & kInputLookAheadMatcher;
  188. LookAheadMatcher<Fst<Arc>> omatcher(fst, MATCH_OUTPUT);
  189. output_match_type_ = omatcher.Type(test_properties);
  190. output_lookahead_ = omatcher.Flags() & kOutputLookAheadMatcher;
  191. }
  192. // Short info.
  193. const std::string &FstType() const { return fst_type_; }
  194. const std::string &ArcType() const { return arc_type_; }
  195. const std::string &InputSymbols() const { return input_symbols_; }
  196. const std::string &OutputSymbols() const { return output_symbols_; }
  197. bool LongInfo() const { return long_info_; }
  198. script::ArcFilterType ArcFilterType() const { return arc_filter_type_; }
  199. // Long info.
  200. MatchType InputMatchType() const {
  201. CheckLong();
  202. return input_match_type_;
  203. }
  204. MatchType OutputMatchType() const {
  205. CheckLong();
  206. return output_match_type_;
  207. }
  208. bool InputLookAhead() const {
  209. CheckLong();
  210. return input_lookahead_;
  211. }
  212. bool OutputLookAhead() const {
  213. CheckLong();
  214. return output_lookahead_;
  215. }
  216. int64_t NumStates() const {
  217. CheckLong();
  218. return nstates_;
  219. }
  220. size_t NumArcs() const {
  221. CheckLong();
  222. return narcs_;
  223. }
  224. int64_t Start() const {
  225. CheckLong();
  226. return start_;
  227. }
  228. size_t NumFinal() const {
  229. CheckLong();
  230. return nfinal_;
  231. }
  232. size_t NumEpsilons() const {
  233. CheckLong();
  234. return nepsilons_;
  235. }
  236. size_t NumInputEpsilons() const {
  237. CheckLong();
  238. return niepsilons_;
  239. }
  240. size_t NumOutputEpsilons() const {
  241. CheckLong();
  242. return noepsilons_;
  243. }
  244. double InputLabelMultiplicity() const {
  245. CheckLong();
  246. return ilabel_mult_;
  247. }
  248. double OutputLabelMultiplicity() const {
  249. CheckLong();
  250. return olabel_mult_;
  251. }
  252. size_t NumAccessible() const {
  253. CheckLong();
  254. return naccess_;
  255. }
  256. size_t NumCoAccessible() const {
  257. CheckLong();
  258. return ncoaccess_;
  259. }
  260. size_t NumConnected() const {
  261. CheckLong();
  262. return nconnect_;
  263. }
  264. size_t NumCc() const {
  265. CheckLong();
  266. return ncc_;
  267. }
  268. size_t NumScc() const {
  269. CheckLong();
  270. return nscc_;
  271. }
  272. uint64_t Properties() const {
  273. CheckLong();
  274. return properties_;
  275. }
  276. void Info() const;
  277. private:
  278. void CheckLong() const {
  279. if (!long_info_)
  280. FSTERROR() << "FstInfo: Method only available with long info signature";
  281. }
  282. std::string fst_type_;
  283. std::string input_symbols_;
  284. std::string output_symbols_;
  285. int64_t nstates_;
  286. size_t narcs_;
  287. int64_t start_;
  288. size_t nfinal_;
  289. size_t nepsilons_;
  290. size_t niepsilons_;
  291. size_t noepsilons_;
  292. double ilabel_mult_;
  293. double olabel_mult_;
  294. size_t naccess_;
  295. size_t ncoaccess_;
  296. size_t nconnect_;
  297. size_t ncc_;
  298. size_t nscc_;
  299. MatchType input_match_type_;
  300. MatchType output_match_type_;
  301. bool input_lookahead_;
  302. bool output_lookahead_;
  303. uint64_t properties_;
  304. script::ArcFilterType arc_filter_type_;
  305. bool long_info_;
  306. std::string arc_type_;
  307. };
  308. // Prints `properties` to `ostrm` in a user-friendly multi-line format.
  309. void PrintProperties(std::ostream &ostrm, uint64_t properties);
  310. // Prints `header` to `ostrm` in a user-friendly multi-line format.
  311. void PrintHeader(std::ostream &ostrm, const FstHeader &header);
  312. } // namespace fst
  313. #endif // FST_SCRIPT_INFO_IMPL_H_