You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

252 lines
8.3 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Functions and classes to compute the concatenation of two FSTs.
  19. #ifndef FST_CONCAT_H_
  20. #define FST_CONCAT_H_
  21. #include <algorithm>
  22. #include <vector>
  23. #include <fst/log.h>
  24. #include <fst/arc.h>
  25. #include <fst/cache.h>
  26. #include <fst/expanded-fst.h>
  27. #include <fst/float-weight.h>
  28. #include <fst/fst.h>
  29. #include <fst/impl-to-fst.h>
  30. #include <fst/mutable-fst.h>
  31. #include <fst/properties.h>
  32. #include <fst/rational.h>
  33. #include <fst/symbol-table.h>
  34. #include <fst/util.h>
  35. namespace fst {
  36. // Computes the concatenation (product) of two FSTs. If FST1 transduces string
  37. // x to y with weight a and FST2 transduces string w to v with weight b, then
  38. // their concatenation transduces string xw to yv with weight Times(a, b).
  39. //
  40. // This version modifies its MutableFst argument (in first position).
  41. //
  42. // Complexity:
  43. //
  44. // Time: O(V1 + V2 + E2)
  45. // Space: O(V1 + V2 + E2)
  46. //
  47. // where Vi is the number of states, and Ei is the number of arcs, of the ith
  48. // FST.
  49. template <class Arc>
  50. void Concat(MutableFst<Arc> *fst1, const Fst<Arc> &fst2) {
  51. using StateId = typename Arc::StateId;
  52. using Weight = typename Arc::Weight;
  53. // Checks that the symbol table are compatible.
  54. if (!CompatSymbols(fst1->InputSymbols(), fst2.InputSymbols()) ||
  55. !CompatSymbols(fst1->OutputSymbols(), fst2.OutputSymbols())) {
  56. FSTERROR() << "Concat: Input/output symbol tables of 1st argument "
  57. << "does not match input/output symbol tables of 2nd argument";
  58. fst1->SetProperties(kError, kError);
  59. return;
  60. }
  61. const auto props1 = fst1->Properties(kFstProperties, false);
  62. const auto props2 = fst2.Properties(kFstProperties, false);
  63. const auto start1 = fst1->Start();
  64. if (start1 == kNoStateId) {
  65. if (props2 & kError) fst1->SetProperties(kError, kError);
  66. return;
  67. }
  68. const auto numstates1 = fst1->NumStates();
  69. if (std::optional<StateId> numstates2 = fst2.NumStatesIfKnown()) {
  70. fst1->ReserveStates(numstates1 + *numstates2);
  71. }
  72. for (StateIterator<Fst<Arc>> siter2(fst2); !siter2.Done(); siter2.Next()) {
  73. const auto s1 = fst1->AddState();
  74. const auto s2 = siter2.Value();
  75. fst1->SetFinal(s1, fst2.Final(s2));
  76. fst1->ReserveArcs(s1, fst2.NumArcs(s2));
  77. for (ArcIterator<Fst<Arc>> aiter(fst2, s2); !aiter.Done(); aiter.Next()) {
  78. auto arc = aiter.Value();
  79. arc.nextstate += numstates1;
  80. fst1->AddArc(s1, arc);
  81. }
  82. }
  83. const auto start2 = fst2.Start();
  84. for (StateId s1 = 0; s1 < numstates1; ++s1) {
  85. const auto weight = fst1->Final(s1);
  86. if (weight != Weight::Zero()) {
  87. fst1->SetFinal(s1, Weight::Zero());
  88. if (start2 != kNoStateId) {
  89. fst1->AddArc(s1, Arc(0, 0, weight, start2 + numstates1));
  90. }
  91. }
  92. }
  93. if (start2 != kNoStateId) {
  94. fst1->SetProperties(ConcatProperties(props1, props2), kFstProperties);
  95. }
  96. }
  97. // Computes the concatentation of two FSTs. This version modifies its
  98. // RationalFst input (in first position).
  99. template <class Arc>
  100. void Concat(RationalFst<Arc> *fst1, const Fst<Arc> &fst2) {
  101. fst1->GetMutableImpl()->AddConcat(fst2, true);
  102. }
  103. // Computes the concatentation of two FSTs. This version modifies its
  104. // MutableFst argument (in second position).
  105. //
  106. // Complexity:
  107. //
  108. // Time: O(V1 + E1)
  109. // Space: O(V1 + E1)
  110. //
  111. // where Vi is the number of states, and Ei is the number of arcs, of the ith
  112. // FST.
  113. template <class Arc>
  114. void Concat(const Fst<Arc> &fst1, MutableFst<Arc> *fst2) {
  115. using StateId = typename Arc::StateId;
  116. using Weight = typename Arc::Weight;
  117. // Checks that the symbol table are compatible.
  118. if (!CompatSymbols(fst1.InputSymbols(), fst2->InputSymbols()) ||
  119. !CompatSymbols(fst1.OutputSymbols(), fst2->OutputSymbols())) {
  120. FSTERROR() << "Concat: Input/output symbol tables of 1st argument "
  121. << "does not match input/output symbol tables of 2nd argument";
  122. fst2->SetProperties(kError, kError);
  123. return;
  124. }
  125. const auto props1 = fst1.Properties(kFstProperties, false);
  126. const auto props2 = fst2->Properties(kFstProperties, false);
  127. const auto start2 = fst2->Start();
  128. if (start2 == kNoStateId) {
  129. if (props1 & kError) fst2->SetProperties(kError, kError);
  130. return;
  131. }
  132. const auto numstates2 = fst2->NumStates();
  133. if (std::optional<StateId> numstates1 = fst1.NumStatesIfKnown()) {
  134. fst2->ReserveStates(numstates2 + *numstates1);
  135. }
  136. for (StateIterator<Fst<Arc>> siter(fst1); !siter.Done(); siter.Next()) {
  137. const auto s1 = siter.Value();
  138. const auto s2 = fst2->AddState();
  139. const auto weight = fst1.Final(s1);
  140. if (weight != Weight::Zero()) {
  141. fst2->ReserveArcs(s2, fst1.NumArcs(s1) + 1);
  142. fst2->AddArc(s2, Arc(0, 0, weight, start2));
  143. } else {
  144. fst2->ReserveArcs(s2, fst1.NumArcs(s1));
  145. }
  146. for (ArcIterator<Fst<Arc>> aiter(fst1, s1); !aiter.Done(); aiter.Next()) {
  147. auto arc = aiter.Value();
  148. arc.nextstate += numstates2;
  149. fst2->AddArc(s2, arc);
  150. }
  151. }
  152. const auto start1 = fst1.Start();
  153. if (start1 != kNoStateId) {
  154. fst2->SetStart(start1 + numstates2);
  155. fst2->SetProperties(ConcatProperties(props1, props2), kFstProperties);
  156. } else {
  157. fst2->SetStart(fst2->AddState());
  158. }
  159. }
  160. // Same as the above but can handle arbitrarily many left-hand-side FSTs,
  161. // preallocating the states.
  162. template <class Arc>
  163. void Concat(const std::vector<const Fst<Arc> *> &fsts1, MutableFst<Arc> *fst2) {
  164. fst2->ReserveStates(CountStates(fsts1) + fst2->NumStates());
  165. for (const auto *fst1 : fsts1) Concat(*fst1, fst2);
  166. }
  167. // Computes the concatentation of two FSTs. This version modifies its
  168. // RationalFst input (in second position).
  169. template <class Arc>
  170. void Concat(const Fst<Arc> &fst1, RationalFst<Arc> *fst2) {
  171. fst2->GetMutableImpl()->AddConcat(fst1, false);
  172. }
  173. using ConcatFstOptions = RationalFstOptions;
  174. // Computes the concatenation (product) of two FSTs; this version is a delayed
  175. // FST. If FST1 transduces string x to y with weight a and FST2 transduces
  176. // string w to v with weight b, then their concatenation transduces string xw
  177. // to yv with Times(a, b).
  178. //
  179. // Complexity:
  180. //
  181. // Time: O(v1 + e1 + v2 + e2),
  182. // Space: O(v1 + v2)
  183. //
  184. // where vi is the number of states visited, and ei is the number of arcs
  185. // visited, of the ith FST. Constant time and space to visit an input state or
  186. // arc is assumed and exclusive of caching.
  187. template <class A>
  188. class ConcatFst : public RationalFst<A> {
  189. public:
  190. using Arc = A;
  191. using StateId = typename Arc::StateId;
  192. using Weight = typename Arc::Weight;
  193. ConcatFst(const Fst<Arc> &fst1, const Fst<Arc> &fst2) {
  194. GetMutableImpl()->InitConcat(fst1, fst2);
  195. }
  196. ConcatFst(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
  197. const ConcatFstOptions &opts)
  198. : RationalFst<Arc>(opts) {
  199. GetMutableImpl()->InitConcat(fst1, fst2);
  200. }
  201. // See Fst<>::Copy() for doc.
  202. ConcatFst(const ConcatFst &fst, bool safe = false)
  203. : RationalFst<Arc>(fst, safe) {}
  204. // Get a copy of this ConcatFst. See Fst<>::Copy() for further doc.
  205. ConcatFst *Copy(bool safe = false) const override {
  206. return new ConcatFst(*this, safe);
  207. }
  208. private:
  209. using ImplToFst<internal::RationalFstImpl<Arc>>::GetImpl;
  210. using ImplToFst<internal::RationalFstImpl<Arc>>::GetMutableImpl;
  211. };
  212. // Specialization for ConcatFst.
  213. template <class Arc>
  214. class StateIterator<ConcatFst<Arc>> : public StateIterator<RationalFst<Arc>> {
  215. public:
  216. explicit StateIterator(const ConcatFst<Arc> &fst)
  217. : StateIterator<RationalFst<Arc>>(fst) {}
  218. };
  219. // Specialization for ConcatFst.
  220. template <class Arc>
  221. class ArcIterator<ConcatFst<Arc>> : public ArcIterator<RationalFst<Arc>> {
  222. public:
  223. using StateId = typename Arc::StateId;
  224. ArcIterator(const ConcatFst<Arc> &fst, StateId s)
  225. : ArcIterator<RationalFst<Arc>>(fst, s) {}
  226. };
  227. // Useful alias when using StdArc.
  228. using StdConcatFst = ConcatFst<StdArc>;
  229. } // namespace fst
  230. #endif // FST_CONCAT_H_