You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

243 lines
9.1 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Functions and classes to determine the equivalence of two FSTs.
  19. #ifndef FST_EQUIVALENT_H_
  20. #define FST_EQUIVALENT_H_
  21. #include <algorithm>
  22. #include <cstdint>
  23. #include <queue>
  24. #include <utility>
  25. #include <vector>
  26. #include <fst/log.h>
  27. #include <fst/arc-map.h>
  28. #include <fst/encode.h>
  29. #include <fst/fst.h>
  30. #include <fst/properties.h>
  31. #include <fst/push.h>
  32. #include <fst/reweight.h>
  33. #include <fst/symbol-table.h>
  34. #include <fst/union-find.h>
  35. #include <fst/util.h>
  36. #include <fst/vector-fst.h>
  37. #include <fst/weight.h>
  38. #include <unordered_map>
  39. namespace fst {
  40. namespace internal {
  41. // Traits-like struct holding utility functions/typedefs/constants for
  42. // the equivalence algorithm.
  43. //
  44. // Encoding device: in order to make the statesets of the two acceptors
  45. // disjoint, we map Arc::StateId on the type MappedId. The states of
  46. // the first acceptor are mapped on odd numbers (s -> 2s + 1), and
  47. // those of the second one on even numbers (s -> 2s + 2). The number 0
  48. // is reserved for an implicit (non-final) dead state (required for
  49. // the correct treatment of non-coaccessible states; kNoStateId is mapped to
  50. // kDeadState for both acceptors). The union-find algorithm operates on the
  51. // mapped IDs.
  52. template <class Arc>
  53. struct EquivalenceUtil {
  54. using StateId = typename Arc::StateId;
  55. using Weight = typename Arc::Weight;
  56. using MappedId = StateId; // ID for an equivalence class.
  57. // MappedId for an implicit dead state.
  58. static constexpr MappedId kDeadState = 0;
  59. // MappedId for lookup failure.
  60. static constexpr MappedId kInvalidId = -1;
  61. // Maps state ID to the representative of the corresponding
  62. // equivalence class. The parameter 'which_fst' takes the values 1
  63. // and 2, identifying the input FST.
  64. static MappedId MapState(StateId s, int32_t which_fst) {
  65. return (kNoStateId == s) ? kDeadState
  66. : (static_cast<MappedId>(s) << 1) + which_fst;
  67. }
  68. // Maps set ID to State ID.
  69. static StateId UnMapState(MappedId id) {
  70. return static_cast<StateId>((--id) >> 1);
  71. }
  72. // Convenience function: checks if state with MappedId s is final in
  73. // acceptor fa.
  74. static bool IsFinal(const Fst<Arc> &fa, MappedId s) {
  75. return (kDeadState == s) ? false
  76. : (fa.Final(UnMapState(s)) != Weight::Zero());
  77. }
  78. // Convenience function: returns the representative of ID in sets,
  79. // creating a new set if needed.
  80. static MappedId FindSet(UnionFind<MappedId> *sets, MappedId id) {
  81. const auto repr = sets->FindSet(id);
  82. if (repr != kInvalidId) {
  83. return repr;
  84. } else {
  85. sets->MakeSet(id);
  86. return id;
  87. }
  88. }
  89. };
  90. } // namespace internal
  91. // Equivalence checking algorithm: determines if the two FSTs fst1 and fst2
  92. // are equivalent. The input FSTs must be deterministic input-side epsilon-free
  93. // acceptors, unweighted or with weights over a left semiring. Two acceptors are
  94. // considered equivalent if they accept exactly the same set of strings (with
  95. // the same weights).
  96. //
  97. // The algorithm (cf. Aho, Hopcroft and Ullman, "The Design and Analysis of
  98. // Computer Programs") successively constructs sets of states that can be
  99. // reached by the same prefixes, starting with a set containing the start states
  100. // of both acceptors. A disjoint tree forest (the union-find algorithm) is used
  101. // to represent the sets of states. The algorithm returns false if one of the
  102. // constructed sets contains both final and non-final states. Returns an
  103. // optional error value (useful when FST_FLAGS_error_fatal = false).
  104. //
  105. // Complexity:
  106. //
  107. // Quasi-linear, i.e., O(n G(n)), where
  108. //
  109. // n = |S1| + |S2| is the number of states in both acceptors
  110. //
  111. // G(n) is a very slowly growing function that can be approximated
  112. // by 4 by all practical purposes.
  113. template <class Arc>
  114. bool Equivalent(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
  115. float delta = kDelta, bool *error = nullptr) {
  116. using Weight = typename Arc::Weight;
  117. if (error) *error = false;
  118. // Check that the symbol table are compatible.
  119. if (!CompatSymbols(fst1.InputSymbols(), fst2.InputSymbols()) ||
  120. !CompatSymbols(fst1.OutputSymbols(), fst2.OutputSymbols())) {
  121. FSTERROR() << "Equivalent: Input/output symbol tables of 1st argument "
  122. << "do not match input/output symbol tables of 2nd argument";
  123. if (error) *error = true;
  124. return false;
  125. }
  126. // Check properties first.
  127. static constexpr auto props = kNoEpsilons | kIDeterministic | kAcceptor;
  128. if (fst1.Properties(props, true) != props) {
  129. FSTERROR() << "Equivalent: 1st argument not an"
  130. << " epsilon-free deterministic acceptor";
  131. if (error) *error = true;
  132. return false;
  133. }
  134. if (fst2.Properties(props, true) != props) {
  135. FSTERROR() << "Equivalent: 2nd argument not an"
  136. << " epsilon-free deterministic acceptor";
  137. if (error) *error = true;
  138. return false;
  139. }
  140. if ((fst1.Properties(kUnweighted, true) != kUnweighted) ||
  141. (fst2.Properties(kUnweighted, true) != kUnweighted)) {
  142. VectorFst<Arc> efst1(fst1);
  143. VectorFst<Arc> efst2(fst2);
  144. Push(&efst1, REWEIGHT_TO_INITIAL, delta);
  145. Push(&efst2, REWEIGHT_TO_INITIAL, delta);
  146. ArcMap(&efst1, QuantizeMapper<Arc>(delta));
  147. ArcMap(&efst2, QuantizeMapper<Arc>(delta));
  148. EncodeMapper<Arc> mapper(kEncodeWeights | kEncodeLabels, ENCODE);
  149. ArcMap(&efst1, &mapper);
  150. ArcMap(&efst2, &mapper);
  151. return Equivalent(efst1, efst2);
  152. }
  153. using Util = internal::EquivalenceUtil<Arc>;
  154. using MappedId = typename Util::MappedId;
  155. enum { FST1 = 1, FST2 = 2 }; // Required by Util::MapState(...)
  156. auto s1 = Util::MapState(fst1.Start(), FST1);
  157. auto s2 = Util::MapState(fst2.Start(), FST2);
  158. // The union-find structure.
  159. UnionFind<MappedId> eq_classes(1000, Util::kInvalidId);
  160. // Initializes the union-find structure.
  161. eq_classes.MakeSet(s1);
  162. eq_classes.MakeSet(s2);
  163. // Data structure for the (partial) acceptor transition function of fst1 and
  164. // fst2: input labels mapped to pairs of MappedIds representing destination
  165. // states of the corresponding arcs in fst1 and fst2, respectively.
  166. using Label2StatePairMap =
  167. std::unordered_map<typename Arc::Label, std::pair<MappedId, MappedId>>;
  168. Label2StatePairMap arc_pairs;
  169. // Pairs of MappedId's to be processed, organized in a queue.
  170. std::queue<std::pair<MappedId, MappedId>> q;
  171. bool ret = true;
  172. // Returns early if the start states differ w.r.t. finality.
  173. if (Util::IsFinal(fst1, s1) != Util::IsFinal(fst2, s2)) ret = false;
  174. // Main loop: explores the two acceptors in a breadth-first manner, updating
  175. // the equivalence relation on the statesets. Loop invariant: each block of
  176. // the states contains either final states only or non-final states only.
  177. for (q.emplace(s1, s2); ret && !q.empty(); q.pop()) {
  178. s1 = q.front().first;
  179. s2 = q.front().second;
  180. // Representatives of the equivalence classes of s1/s2.
  181. const auto rep1 = Util::FindSet(&eq_classes, s1);
  182. const auto rep2 = Util::FindSet(&eq_classes, s2);
  183. if (rep1 != rep2) {
  184. eq_classes.Union(rep1, rep2);
  185. arc_pairs.clear();
  186. // Copies outgoing arcs starting at s1 into the hash-table.
  187. if (Util::kDeadState != s1) {
  188. ArcIterator<Fst<Arc>> arc_iter(fst1, Util::UnMapState(s1));
  189. for (; !arc_iter.Done(); arc_iter.Next()) {
  190. const auto &arc = arc_iter.Value();
  191. // Zero-weight arcs are treated as if they did not exist.
  192. if (arc.weight != Weight::Zero()) {
  193. arc_pairs[arc.ilabel].first = Util::MapState(arc.nextstate, FST1);
  194. }
  195. }
  196. }
  197. // Copies outgoing arcs starting at s2 into the hashtable.
  198. if (Util::kDeadState != s2) {
  199. ArcIterator<Fst<Arc>> arc_iter(fst2, Util::UnMapState(s2));
  200. for (; !arc_iter.Done(); arc_iter.Next()) {
  201. const auto &arc = arc_iter.Value();
  202. // Zero-weight arcs are treated as if they did not exist.
  203. if (arc.weight != Weight::Zero()) {
  204. arc_pairs[arc.ilabel].second = Util::MapState(arc.nextstate, FST2);
  205. }
  206. }
  207. }
  208. // Iterates through the hashtable and process pairs of target states.
  209. for (const auto &arc_iter : arc_pairs) {
  210. const auto &pair = arc_iter.second;
  211. if (Util::IsFinal(fst1, pair.first) !=
  212. Util::IsFinal(fst2, pair.second)) {
  213. // Detected inconsistency: return false.
  214. ret = false;
  215. break;
  216. }
  217. q.push(pair);
  218. }
  219. }
  220. }
  221. if (fst1.Properties(kError, false) || fst2.Properties(kError, false)) {
  222. if (error) *error = true;
  223. return false;
  224. }
  225. return ret;
  226. }
  227. } // namespace fst
  228. #endif // FST_EQUIVALENT_H_