You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

213 lines
6.3 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
// Generic FST augmented with a state-count interface: class definition.
  19. #ifndef FST_EXPANDED_FST_H_
  20. #define FST_EXPANDED_FST_H_
#include <sys/types.h>

#include <cstddef>
#include <fstream>
#include <ios>
#include <iostream>
#include <istream>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include <fst/log.h>
#include <fst/arc.h>
#include <fst/fst.h>
#include <fst/impl-to-fst.h>
#include <fst/properties.h>
#include <fst/register.h>
  36. #include <fst/register.h>
  37. namespace fst {
  38. // A generic FST plus state count.
  39. template <class A>
  40. class ExpandedFst : public Fst<A> {
  41. public:
  42. using Arc = A;
  43. using StateId = typename Arc::StateId;
  44. virtual StateId NumStates() const = 0; // State count
  45. std::optional<StateId> NumStatesIfKnown() const override {
  46. return NumStates();
  47. }
  48. // Get a copy of this ExpandedFst. See Fst<>::Copy() for further doc.
  49. ExpandedFst *Copy(bool safe = false) const override = 0;
  50. // Read an ExpandedFst from an input stream; return NULL on error.
  51. static ExpandedFst *Read(std::istream &strm, const FstReadOptions &opts) {
  52. FstReadOptions ropts(opts);
  53. FstHeader hdr;
  54. if (ropts.header) {
  55. hdr = *opts.header;
  56. } else {
  57. if (!hdr.Read(strm, opts.source)) return nullptr;
  58. ropts.header = &hdr;
  59. }
  60. if (!(hdr.Properties() & kExpanded)) {
  61. LOG(ERROR) << "ExpandedFst::Read: Not an ExpandedFst: " << ropts.source;
  62. return nullptr;
  63. }
  64. const auto reader =
  65. FstRegister<Arc>::GetRegister()->GetReader(hdr.FstType());
  66. if (!reader) {
  67. LOG(ERROR) << "ExpandedFst::Read: Unknown FST type \"" << hdr.FstType()
  68. << "\" (arc type = \"" << A::Type() << "\"): " << ropts.source;
  69. return nullptr;
  70. }
  71. auto *fst = reader(strm, ropts);
  72. if (!fst) return nullptr;
  73. return down_cast<ExpandedFst *>(fst);
  74. }
  75. // Read an ExpandedFst from a file; return NULL on error.
  76. // Empty source reads from standard input.
  77. static ExpandedFst *Read(std::string_view source) {
  78. if (!source.empty()) {
  79. std::ifstream strm(std::string(source),
  80. std::ios_base::in | std::ios_base::binary);
  81. if (!strm) {
  82. LOG(ERROR) << "ExpandedFst::Read: Can't open file: " << source;
  83. return nullptr;
  84. }
  85. return Read(strm, FstReadOptions(source));
  86. } else {
  87. return Read(std::cin, FstReadOptions("standard input"));
  88. }
  89. }
  90. };
  91. namespace internal {
  92. // ExpandedFst<A> case - abstract methods.
  93. template <class Arc>
  94. inline typename Arc::Weight Final(const ExpandedFst<Arc> &fst,
  95. typename Arc::StateId s) {
  96. return fst.Final(s);
  97. }
  98. template <class Arc>
  99. inline ssize_t NumArcs(const ExpandedFst<Arc> &fst, typename Arc::StateId s) {
  100. return fst.NumArcs(s);
  101. }
  102. template <class Arc>
  103. inline ssize_t NumInputEpsilons(const ExpandedFst<Arc> &fst,
  104. typename Arc::StateId s) {
  105. return fst.NumInputEpsilons(s);
  106. }
  107. template <class Arc>
  108. inline ssize_t NumOutputEpsilons(const ExpandedFst<Arc> &fst,
  109. typename Arc::StateId s) {
  110. return fst.NumOutputEpsilons(s);
  111. }
  112. } // namespace internal
  113. // A useful alias when using StdArc.
  114. using StdExpandedFst = ExpandedFst<StdArc>;
  115. // This is a helper class template useful for attaching an ExpandedFst
  116. // interface to its implementation, handling reference counting. It
  117. // delegates to ImplToFst the handling of the Fst interface methods.
  118. template <class Impl, class FST = ExpandedFst<typename Impl::Arc>>
  119. class ImplToExpandedFst : public ImplToFst<Impl, FST> {
  120. public:
  121. using Arc = typename FST::Arc;
  122. using StateId = typename Arc::StateId;
  123. using Weight = typename Arc::Weight;
  124. StateId NumStates() const override { return GetImpl()->NumStates(); }
  125. protected:
  126. using ImplToFst<Impl, FST>::GetImpl;
  127. explicit ImplToExpandedFst(std::shared_ptr<Impl> impl)
  128. : ImplToFst<Impl, FST>(impl) {}
  129. ImplToExpandedFst(const ImplToExpandedFst &fst, bool safe)
  130. : ImplToFst<Impl, FST>(fst, safe) {}
  131. static Impl *Read(std::istream &strm, const FstReadOptions &opts) {
  132. return Impl::Read(strm, opts);
  133. }
  134. // Read FST implementation from a file; return NULL on error.
  135. // Empty source reads from standard input.
  136. static Impl *Read(std::string_view source) {
  137. if (!source.empty()) {
  138. std::ifstream strm(std::string(source),
  139. std::ios_base::in | std::ios_base::binary);
  140. if (!strm) {
  141. LOG(ERROR) << "ExpandedFst::Read: Can't open file: " << source;
  142. return nullptr;
  143. }
  144. return Impl::Read(strm, FstReadOptions(source));
  145. } else {
  146. return Impl::Read(std::cin, FstReadOptions("standard input"));
  147. }
  148. }
  149. };
  150. // Function to return the number of states in an FST, counting them
  151. // if necessary.
  152. template <class Arc>
  153. typename Arc::StateId CountStates(const Fst<Arc> &fst) {
  154. if (std::optional<typename Arc::StateId> num_states =
  155. fst.NumStatesIfKnown()) {
  156. return *num_states;
  157. } else {
  158. typename Arc::StateId nstates = 0;
  159. for (StateIterator<Fst<Arc>> siter(fst); !siter.Done(); siter.Next()) {
  160. ++nstates;
  161. }
  162. return nstates;
  163. }
  164. }
  165. // Function to return the number of states in a vector of FSTs, counting them if
  166. // necessary.
  167. template <class Arc>
  168. typename Arc::StateId CountStates(const std::vector<const Fst<Arc> *> &fsts) {
  169. typename Arc::StateId nstates = 0;
  170. for (const auto *fst : fsts) nstates += CountStates(*fst);
  171. return nstates;
  172. }
  173. // Function to return the number of arcs in an FST.
  174. template <class F>
  175. size_t CountArcs(const F &fst) {
  176. size_t narcs = 0;
  177. for (StateIterator<F> siter(fst); !siter.Done(); siter.Next()) {
  178. narcs += fst.NumArcs(siter.Value());
  179. }
  180. return narcs;
  181. }
  182. } // namespace fst
  183. #endif // FST_EXPANDED_FST_H_