You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

329 lines
11 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // An FST implementation and base interface for delayed unions, concatenations,
  19. // and closures.
  20. #ifndef FST_RATIONAL_H_
  21. #define FST_RATIONAL_H_
  22. #include <algorithm>
  23. #include <cstddef>
  24. #include <cstdint>
  25. #include <memory>
  26. #include <string>
  27. #include <utility>
  28. #include <vector>
  29. #include <fst/cache.h>
  30. #include <fst/fst.h>
  31. #include <fst/impl-to-fst.h>
  32. #include <fst/mutable-fst.h>
  33. #include <fst/properties.h>
  34. #include <fst/replace.h>
  35. #include <fst/vector-fst.h>
  36. namespace fst {
// Options accepted by the rational FST classes in this header. This is an
// alias of CacheOptions; no rational-specific options are added here.
using RationalFstOptions = CacheOptions;
  38. // This specifies whether to add the empty string.
  39. enum ClosureType {
  40. CLOSURE_STAR = 0, // Add the empty string.
  41. CLOSURE_PLUS = 1 // Don't add the empty string.
  42. };
// Forward declaration; the class itself is defined later in this header.
template <class Arc>
class RationalFst;

// Declarations of the operations that take a RationalFst by pointer. They are
// declared ahead of the RationalFst class so that it can name them as friends;
// their definitions are not part of this header.

// Union taking the RationalFst operand `fst1` by pointer and `fst2` by
// reference.
template <class Arc>
void Union(RationalFst<Arc> *fst1, const Fst<Arc> &fst2);

// Concatenation where the RationalFst is the first operand.
template <class Arc>
void Concat(RationalFst<Arc> *fst1, const Fst<Arc> &fst2);

// Concatenation where the RationalFst is the second operand.
template <class Arc>
void Concat(const Fst<Arc> &fst1, RationalFst<Arc> *fst2);

// Closure of the RationalFst `fst`; `closure_type` selects star or plus.
template <class Arc>
void Closure(RationalFst<Arc> *fst, ClosureType closure_type);
  53. namespace internal {
  54. // Implementation class for delayed unions, concatenations and closures.
  55. template <class A>
  56. class RationalFstImpl : public FstImpl<A> {
  57. public:
  58. using Arc = A;
  59. using Label = typename Arc::Label;
  60. using StateId = typename Arc::StateId;
  61. using Weight = typename Arc::Weight;
  62. using FstImpl<Arc>::SetType;
  63. using FstImpl<Arc>::SetProperties;
  64. using FstImpl<Arc>::WriteHeader;
  65. using FstImpl<Arc>::SetInputSymbols;
  66. using FstImpl<Arc>::SetOutputSymbols;
  67. explicit RationalFstImpl(const RationalFstOptions &opts)
  68. : nonterminals_(0), replace_options_(opts, 0) {
  69. SetType("rational");
  70. fst_tuples_.emplace_back(0, nullptr);
  71. }
  72. RationalFstImpl(const RationalFstImpl<Arc> &impl)
  73. : rfst_(impl.rfst_),
  74. nonterminals_(impl.nonterminals_),
  75. replace_(impl.replace_ ? impl.replace_->Copy(true) : nullptr),
  76. replace_options_(impl.replace_options_) {
  77. SetType("rational");
  78. fst_tuples_.reserve(impl.fst_tuples_.size());
  79. for (const auto &pair : impl.fst_tuples_) {
  80. fst_tuples_.emplace_back(pair.first,
  81. pair.second ? pair.second->Copy(true) : nullptr);
  82. }
  83. }
  84. ~RationalFstImpl() override {
  85. for (auto &tuple : fst_tuples_) delete tuple.second;
  86. }
  87. StateId Start() { return Replace()->Start(); }
  88. Weight Final(StateId s) { return Replace()->Final(s); }
  89. size_t NumArcs(StateId s) { return Replace()->NumArcs(s); }
  90. size_t NumInputEpsilons(StateId s) { return Replace()->NumInputEpsilons(s); }
  91. size_t NumOutputEpsilons(StateId s) {
  92. return Replace()->NumOutputEpsilons(s);
  93. }
  94. uint64_t Properties() const override { return Properties(kFstProperties); }
  95. // Sets error if found, and returns other FST impl properties.
  96. uint64_t Properties(uint64_t mask) const override {
  97. if ((mask & kError) && Replace()->Properties(kError, false)) {
  98. SetProperties(kError, kError);
  99. }
  100. return FstImpl<Arc>::Properties(mask);
  101. }
  102. // Implementation of UnionFst(fst1, fst2).
  103. void InitUnion(const Fst<Arc> &fst1, const Fst<Arc> &fst2) {
  104. replace_.reset();
  105. const auto props1 = fst1.Properties(kFstProperties, false);
  106. const auto props2 = fst2.Properties(kFstProperties, false);
  107. SetInputSymbols(fst1.InputSymbols());
  108. SetOutputSymbols(fst1.OutputSymbols());
  109. rfst_.AddState();
  110. rfst_.AddState();
  111. rfst_.SetStart(0);
  112. rfst_.SetFinal(1);
  113. rfst_.SetInputSymbols(fst1.InputSymbols());
  114. rfst_.SetOutputSymbols(fst1.OutputSymbols());
  115. nonterminals_ = 2;
  116. rfst_.EmplaceArc(0, 0, -1, Weight::One(), 1);
  117. rfst_.EmplaceArc(0, 0, -2, Weight::One(), 1);
  118. fst_tuples_.emplace_back(-1, fst1.Copy());
  119. fst_tuples_.emplace_back(-2, fst2.Copy());
  120. SetProperties(UnionProperties(props1, props2, true), kCopyProperties);
  121. }
  122. // Implementation of ConcatFst(fst1, fst2).
  123. void InitConcat(const Fst<Arc> &fst1, const Fst<Arc> &fst2) {
  124. replace_.reset();
  125. const auto props1 = fst1.Properties(kFstProperties, false);
  126. const auto props2 = fst2.Properties(kFstProperties, false);
  127. SetInputSymbols(fst1.InputSymbols());
  128. SetOutputSymbols(fst1.OutputSymbols());
  129. rfst_.AddState();
  130. rfst_.AddState();
  131. rfst_.AddState();
  132. rfst_.SetStart(0);
  133. rfst_.SetFinal(2);
  134. rfst_.SetInputSymbols(fst1.InputSymbols());
  135. rfst_.SetOutputSymbols(fst1.OutputSymbols());
  136. nonterminals_ = 2;
  137. rfst_.EmplaceArc(0, 0, -1, Weight::One(), 1);
  138. rfst_.EmplaceArc(1, 0, -2, Weight::One(), 2);
  139. fst_tuples_.emplace_back(-1, fst1.Copy());
  140. fst_tuples_.emplace_back(-2, fst2.Copy());
  141. SetProperties(ConcatProperties(props1, props2, true), kCopyProperties);
  142. }
  143. // Implementation of ClosureFst(fst, closure_type).
  144. void InitClosure(const Fst<Arc> &fst, ClosureType closure_type) {
  145. replace_.reset();
  146. const auto props = fst.Properties(kFstProperties, false);
  147. SetInputSymbols(fst.InputSymbols());
  148. SetOutputSymbols(fst.OutputSymbols());
  149. if (closure_type == CLOSURE_STAR) {
  150. rfst_.AddState();
  151. rfst_.SetStart(0);
  152. rfst_.SetFinal(0);
  153. rfst_.EmplaceArc(0, 0, -1, Weight::One(), 0);
  154. } else {
  155. rfst_.AddState();
  156. rfst_.AddState();
  157. rfst_.SetStart(0);
  158. rfst_.SetFinal(1);
  159. rfst_.EmplaceArc(0, 0, -1, Weight::One(), 1);
  160. rfst_.EmplaceArc(1, 0, 0, Weight::One(), 0);
  161. }
  162. rfst_.SetInputSymbols(fst.InputSymbols());
  163. rfst_.SetOutputSymbols(fst.OutputSymbols());
  164. fst_tuples_.emplace_back(-1, fst.Copy());
  165. nonterminals_ = 1;
  166. SetProperties(ClosureProperties(props, closure_type == CLOSURE_STAR, true),
  167. kCopyProperties);
  168. }
  169. // Implementation of Union(Fst &, RationalFst *).
  170. void AddUnion(const Fst<Arc> &fst) {
  171. replace_.reset();
  172. const auto props1 = FstImpl<A>::Properties();
  173. const auto props2 = fst.Properties(kFstProperties, false);
  174. VectorFst<Arc> afst;
  175. afst.AddState();
  176. afst.AddState();
  177. afst.SetStart(0);
  178. afst.SetFinal(1);
  179. ++nonterminals_;
  180. afst.EmplaceArc(0, 0, -nonterminals_, Weight::One(), 1);
  181. Union(&rfst_, afst);
  182. fst_tuples_.emplace_back(-nonterminals_, fst.Copy());
  183. SetProperties(UnionProperties(props1, props2, true), kCopyProperties);
  184. }
  185. // Implementation of Concat(Fst &, RationalFst *).
  186. void AddConcat(const Fst<Arc> &fst, bool append) {
  187. replace_.reset();
  188. const auto props1 = FstImpl<A>::Properties();
  189. const auto props2 = fst.Properties(kFstProperties, false);
  190. VectorFst<Arc> afst;
  191. afst.AddState();
  192. afst.AddState();
  193. afst.SetStart(0);
  194. afst.SetFinal(1);
  195. ++nonterminals_;
  196. afst.EmplaceArc(0, 0, -nonterminals_, Weight::One(), 1);
  197. if (append) {
  198. Concat(&rfst_, afst);
  199. } else {
  200. Concat(afst, &rfst_);
  201. }
  202. fst_tuples_.emplace_back(-nonterminals_, fst.Copy());
  203. SetProperties(ConcatProperties(props1, props2, true), kCopyProperties);
  204. }
  205. // Implementation of Closure(RationalFst *, closure_type).
  206. void AddClosure(ClosureType closure_type) {
  207. replace_.reset();
  208. const auto props = FstImpl<A>::Properties();
  209. Closure(&rfst_, closure_type);
  210. SetProperties(ClosureProperties(props, closure_type == CLOSURE_STAR, true),
  211. kCopyProperties);
  212. }
  213. // Returns the underlying ReplaceFst, preserving ownership of the underlying
  214. // object.
  215. ReplaceFst<Arc> *Replace() const {
  216. if (!replace_) {
  217. fst_tuples_[0].second = rfst_.Copy();
  218. replace_ =
  219. std::make_unique<ReplaceFst<Arc>>(fst_tuples_, replace_options_);
  220. }
  221. return replace_.get();
  222. }
  223. private:
  224. // Rational topology machine, using negative non-terminals.
  225. VectorFst<Arc> rfst_;
  226. // Number of nonterminals used.
  227. Label nonterminals_;
  228. // Contains the nonterminals and their corresponding FSTs.
  229. mutable std::vector<std::pair<Label, const Fst<Arc> *>> fst_tuples_;
  230. // Underlying ReplaceFst.
  231. mutable std::unique_ptr<ReplaceFst<Arc>> replace_;
  232. const ReplaceFstOptions<Arc> replace_options_;
  233. };
  234. } // namespace internal
  235. // Parent class for the delayed rational operations (union, concatenation, and
  236. // closure). This class attaches interface to implementation and handles
  237. // reference counting, delegating most methods to ImplToFst.
  238. template <class A>
  239. class RationalFst : public ImplToFst<internal::RationalFstImpl<A>> {
  240. public:
  241. using Arc = A;
  242. using StateId = typename Arc::StateId;
  243. using Impl = internal::RationalFstImpl<Arc>;
  244. friend class StateIterator<RationalFst<Arc>>;
  245. friend class ArcIterator<RationalFst<Arc>>;
  246. friend void Union<>(RationalFst<Arc> *fst1, const Fst<Arc> &fst2);
  247. friend void Concat<>(RationalFst<Arc> *fst1, const Fst<Arc> &fst2);
  248. friend void Concat<>(const Fst<Arc> &fst1, RationalFst<Arc> *fst2);
  249. friend void Closure<>(RationalFst<Arc> *fst, ClosureType closure_type);
  250. void InitStateIterator(StateIteratorData<Arc> *data) const override {
  251. GetImpl()->Replace()->InitStateIterator(data);
  252. }
  253. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
  254. GetImpl()->Replace()->InitArcIterator(s, data);
  255. }
  256. protected:
  257. using ImplToFst<Impl>::GetImpl;
  258. explicit RationalFst(const RationalFstOptions &opts = RationalFstOptions())
  259. : ImplToFst<Impl>(std::make_shared<Impl>(opts)) {}
  260. // See Fst<>::Copy() for doc.
  261. RationalFst(const RationalFst &fst, bool safe = false)
  262. : ImplToFst<Impl>(fst, safe) {}
  263. private:
  264. RationalFst &operator=(const RationalFst &) = delete;
  265. };
  266. // Specialization for RationalFst.
  267. template <class Arc>
  268. class StateIterator<RationalFst<Arc>> : public StateIterator<ReplaceFst<Arc>> {
  269. public:
  270. explicit StateIterator(const RationalFst<Arc> &fst)
  271. : StateIterator<ReplaceFst<Arc>>(*(fst.GetImpl()->Replace())) {}
  272. };
  273. // Specialization for RationalFst.
  274. template <class Arc>
  275. class ArcIterator<RationalFst<Arc>> : public CacheArcIterator<ReplaceFst<Arc>> {
  276. public:
  277. using StateId = typename Arc::StateId;
  278. ArcIterator(const RationalFst<Arc> &fst, StateId s)
  279. : ArcIterator<ReplaceFst<Arc>>(*(fst.GetImpl()->Replace()), s) {}
  280. };
  281. } // namespace fst
  282. #endif // FST_RATIONAL_H_