You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

212 lines
6.4 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Class to draw a binary FST by producing a text file in dot format, a helper
  19. // class to fstdraw.cc.
  20. #ifndef FST_SCRIPT_DRAW_IMPL_H_
  21. #define FST_SCRIPT_DRAW_IMPL_H_
  22. #include <iomanip>
  23. #include <ios>
  24. #include <ostream>
  25. #include <sstream>
  26. #include <string>
  27. #include <fst/log.h>
  28. #include <fst/fst.h>
  29. #include <fst/properties.h>
  30. #include <fst/symbol-table.h>
  31. #include <fst/util.h>
  32. #include <fst/script/fst-class.h>
  33. #include <string_view>
  34. namespace fst {
  35. // Print a binary FST in GraphViz textual format (helper class for fstdraw.cc).
  36. // WARNING: Stand-alone use not recommended.
  37. template <class Arc>
  38. class FstDrawer {
  39. public:
  40. using Label = typename Arc::Label;
  41. using StateId = typename Arc::StateId;
  42. using Weight = typename Arc::Weight;
  43. FstDrawer(const Fst<Arc> &fst, const SymbolTable *isyms,
  44. const SymbolTable *osyms, const SymbolTable *ssyms, bool accep,
  45. std::string_view title, float width, float height, bool portrait,
  46. bool vertical, float ranksep, float nodesep, int fontsize,
  47. int precision, std::string_view float_format,
  48. bool show_weight_one)
  49. : fst_(fst),
  50. isyms_(isyms),
  51. osyms_(osyms),
  52. ssyms_(ssyms),
  53. accep_(accep && fst.Properties(kAcceptor, true)),
  54. title_(title),
  55. width_(width),
  56. height_(height),
  57. portrait_(portrait),
  58. vertical_(vertical),
  59. ranksep_(ranksep),
  60. nodesep_(nodesep),
  61. fontsize_(fontsize),
  62. precision_(precision),
  63. float_format_(float_format),
  64. show_weight_one_(show_weight_one) {}
  65. // Draws FST to an output buffer.
  66. void Draw(std::ostream &strm, std::string_view dest) {
  67. SetStreamState(strm);
  68. dest_ = std::string(dest);
  69. const auto start = fst_.Start();
  70. if (start == kNoStateId) return;
  71. strm << "digraph FST {\n";
  72. if (vertical_) {
  73. strm << "rankdir = BT;\n";
  74. } else {
  75. strm << "rankdir = LR;\n";
  76. }
  77. strm << "size = \"" << width_ << "," << height_ << "\";\n";
  78. if (!title_.empty()) strm << "label = \"" + title_ + "\";\n";
  79. strm << "center = 1;\n";
  80. if (portrait_) {
  81. strm << "orientation = Portrait;\n";
  82. } else {
  83. strm << "orientation = Landscape;\n";
  84. }
  85. strm << "ranksep = \"" << ranksep_ << "\";\n"
  86. << "nodesep = \"" << nodesep_ << "\";\n";
  87. // Initial state first.
  88. DrawState(strm, start);
  89. for (StateIterator<Fst<Arc>> siter(fst_); !siter.Done(); siter.Next()) {
  90. const auto s = siter.Value();
  91. if (s != start) DrawState(strm, s);
  92. }
  93. strm << "}\n";
  94. }
  95. private:
  96. void SetStreamState(std::ostream &strm) const {
  97. strm << std::setprecision(precision_);
  98. if (float_format_ == "e") strm << std::scientific;
  99. if (float_format_ == "f") strm << std::fixed;
  100. // O.w. defaults to "g" per standard lib.
  101. }
  102. // Escapes backslash and double quote if these occur in the string. Dot
  103. // will not deal gracefully with these if they are not escaped.
  104. static std::string Escape(std::string_view str) {
  105. std::string ns;
  106. for (char c : str) {
  107. if (c == '\\' || c == '"') ns.push_back('\\');
  108. ns.push_back(c);
  109. }
  110. return ns;
  111. }
  112. std::string FormatId(StateId id, const SymbolTable *syms) const {
  113. if (syms) {
  114. auto symbol = syms->Find(id);
  115. if (symbol.empty()) {
  116. FSTERROR() << "FstDrawer: Integer " << id
  117. << " is not mapped to any textual symbol"
  118. << ", symbol table = " << syms->Name()
  119. << ", destination = " << dest_;
  120. symbol = "?";
  121. }
  122. return Escape(symbol);
  123. } else {
  124. return std::to_string(id);
  125. }
  126. }
  127. std::string FormatStateId(StateId s) const { return FormatId(s, ssyms_); }
  128. std::string FormatILabel(Label label) const {
  129. return FormatId(label, isyms_);
  130. }
  131. std::string FormatOLabel(Label label) const {
  132. return FormatId(label, osyms_);
  133. }
  134. std::string FormatWeight(Weight w) const {
  135. std::stringstream ss;
  136. SetStreamState(ss);
  137. ss << w;
  138. // Weight may have double quote characters in it, so escape it.
  139. return Escape(ss.str());
  140. }
  141. void DrawState(std::ostream &strm, StateId s) const {
  142. strm << s << " [label = \"" << FormatStateId(s);
  143. const auto weight = fst_.Final(s);
  144. if (weight != Weight::Zero()) {
  145. if (show_weight_one_ || (weight != Weight::One())) {
  146. strm << "/" << FormatWeight(weight);
  147. }
  148. strm << "\", shape = doublecircle,";
  149. } else {
  150. strm << "\", shape = circle,";
  151. }
  152. if (s == fst_.Start()) {
  153. strm << " style = bold,";
  154. } else {
  155. strm << " style = solid,";
  156. }
  157. strm << " fontsize = " << fontsize_ << "]\n";
  158. for (ArcIterator<Fst<Arc>> aiter(fst_, s); !aiter.Done(); aiter.Next()) {
  159. const auto &arc = aiter.Value();
  160. strm << "\t" << s << " -> " << arc.nextstate << " [label = \""
  161. << FormatILabel(arc.ilabel);
  162. if (!accep_) {
  163. strm << ":" << FormatOLabel(arc.olabel);
  164. }
  165. if (show_weight_one_ || (arc.weight != Weight::One())) {
  166. strm << "/" << FormatWeight(arc.weight);
  167. }
  168. strm << "\", fontsize = " << fontsize_ << "];\n";
  169. }
  170. }
  171. const Fst<Arc> &fst_;
  172. const SymbolTable *isyms_; // ilabel symbol table.
  173. const SymbolTable *osyms_; // olabel symbol table.
  174. const SymbolTable *ssyms_; // slabel symbol table.
  175. bool accep_; // Print as acceptor when possible.
  176. std::string dest_; // Drawn FST destination name.
  177. std::string title_;
  178. float width_;
  179. float height_;
  180. bool portrait_;
  181. bool vertical_;
  182. float ranksep_;
  183. float nodesep_;
  184. int fontsize_;
  185. int precision_;
  186. std::string float_format_;
  187. bool show_weight_one_;
  188. FstDrawer(const FstDrawer &) = delete;
  189. FstDrawer &operator=(const FstDrawer &) = delete;
  190. };
  191. } // namespace fst
  192. #endif // FST_SCRIPT_DRAW_IMPL_H_