You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

268 lines
9.7 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Function to test whether two FSTs are isomorphic, i.e., they are equal up to a state
  19. // and arc re-ordering. FSTs should be deterministic when viewed as
  20. // unweighted automata. False negatives (but not false positives) are possible
  21. // when the inputs are nondeterministic (when viewed as unweighted automata).
  22. #ifndef FST_ISOMORPHIC_H_
  23. #define FST_ISOMORPHIC_H_
  24. #include <algorithm>
  25. #include <cstddef>
  26. #include <memory>
  27. #include <queue>
  28. #include <type_traits>
  29. #include <utility>
  30. #include <vector>
  31. #include <fst/log.h>
  32. #include <fst/fst.h>
  33. #include <fst/util.h>
  34. #include <fst/weight.h>
  35. namespace fst {
  36. namespace internal {
  37. // Orders weights for equality checking; delta is ignored.
  38. template <class Weight,
  39. typename std::enable_if_t<IsIdempotent<Weight>::value> * = nullptr>
  40. bool WeightCompare(const Weight &w1, const Weight &w2, float, bool *) {
  41. static const NaturalLess<Weight> less;
  42. return less(w1, w2);
  43. }
  44. template <class Weight,
  45. typename std::enable_if_t<!IsIdempotent<Weight>::value> * = nullptr>
  46. bool WeightCompare(const Weight &w1, const Weight &w2, float delta,
  47. bool *error) {
  48. // No natural order; use hash.
  49. const auto q1 = w1.Quantize(delta);
  50. const auto q2 = w2.Quantize(delta);
  51. const auto n1 = q1.Hash();
  52. const auto n2 = q2.Hash();
  53. // Hash not unique; very unlikely to happen.
  54. if (n1 == n2 && q1 != q2) {
  55. VLOG(1) << "Isomorphic: Weight hash collision";
  56. *error = true;
  57. }
  58. return n1 < n2;
  59. }
  60. template <class Arc>
  61. class Isomorphism {
  62. using StateId = typename Arc::StateId;
  63. public:
  64. Isomorphism(const Fst<Arc> &fst1, const Fst<Arc> &fst2, float delta)
  65. : fst1_(fst1.Copy()),
  66. fst2_(fst2.Copy()),
  67. delta_(delta),
  68. error_(false),
  69. nondet_(false),
  70. comp_(delta, &error_) {}
  71. // Checks if input FSTs are isomorphic.
  72. bool IsIsomorphic() {
  73. if (fst1_->Start() == kNoStateId && fst2_->Start() == kNoStateId) {
  74. return true;
  75. }
  76. if (fst1_->Start() == kNoStateId || fst2_->Start() == kNoStateId) {
  77. VLOG(1) << "Isomorphic: Only one of the FSTs is empty.";
  78. return false;
  79. }
  80. PairState(fst1_->Start(), fst2_->Start());
  81. while (!queue_.empty()) {
  82. const auto &[state1, state2] = queue_.front();
  83. if (!IsIsomorphicState(state1, state2)) {
  84. if (nondet_) {
  85. VLOG(1) << "Isomorphic: Non-determinism as an unweighted automaton. "
  86. << "state1: " << state1 << " state2: " << state2;
  87. error_ = true;
  88. }
  89. return false;
  90. }
  91. queue_.pop();
  92. }
  93. return true;
  94. }
  95. bool Error() const { return error_; }
  96. private:
  97. // Orders arcs for equality checking.
  98. class ArcCompare {
  99. public:
  100. ArcCompare(float delta, bool *error) : delta_(delta), error_(error) {}
  101. bool operator()(const Arc &arc1, const Arc &arc2) const {
  102. if (arc1.ilabel < arc2.ilabel) return true;
  103. if (arc1.ilabel > arc2.ilabel) return false;
  104. if (arc1.olabel < arc2.olabel) return true;
  105. if (arc1.olabel > arc2.olabel) return false;
  106. if (!ApproxEqual(arc1.weight, arc2.weight, delta_)) {
  107. return WeightCompare(arc1.weight, arc2.weight, delta_, error_);
  108. } else {
  109. return arc1.nextstate < arc2.nextstate;
  110. }
  111. }
  112. private:
  113. const float delta_;
  114. bool *error_;
  115. };
  116. // Maintains state correspondences and queue.
  117. bool PairState(StateId s1, StateId s2) {
  118. if (state_pairs_.size() <= s1) state_pairs_.resize(s1 + 1, kNoStateId);
  119. if (state_pairs_[s1] == s2) {
  120. return true; // Already seen this pair.
  121. } else if (state_pairs_[s1] != kNoStateId) {
  122. return false; // s1 already paired with another s2.
  123. }
  124. VLOG(3) << "Pairing states: (" << s1 << ", " << s2 << ")";
  125. state_pairs_[s1] = s2;
  126. queue_.emplace(s1, s2);
  127. return true;
  128. }
  129. // Checks if state pair is isomorphic.
  130. bool IsIsomorphicState(StateId s1, StateId s2);
  131. std::unique_ptr<Fst<Arc>> fst1_;
  132. std::unique_ptr<Fst<Arc>> fst2_;
  133. float delta_; // Weight equality delta.
  134. std::vector<Arc> arcs1_; // For sorting arcs on FST1.
  135. std::vector<Arc> arcs2_; // For sorting arcs on FST2.
  136. std::vector<StateId> state_pairs_; // Maintains state correspondences.
  137. std::queue<std::pair<StateId, StateId>> queue_; // Queue of state pairs.
  138. bool error_; // Error flag.
  139. bool nondet_; // Nondeterminism detected.
  140. ArcCompare comp_;
  141. };
  142. template <class Arc>
  143. bool Isomorphism<Arc>::IsIsomorphicState(StateId s1, StateId s2) {
  144. if (!ApproxEqual(fst1_->Final(s1), fst2_->Final(s2), delta_)) {
  145. VLOG(1) << "Isomorphic: Final weights not equal to within delta=" << delta_
  146. << ": "
  147. << "fst1.Final(" << s1 << ") = " << fst1_->Final(s1) << ", "
  148. << "fst2.Final(" << s2 << ") = " << fst2_->Final(s2);
  149. return false;
  150. }
  151. const auto narcs1 = fst1_->NumArcs(s1);
  152. const auto narcs2 = fst2_->NumArcs(s2);
  153. if (narcs1 != narcs2) {
  154. VLOG(1) << "Isomorphic: NumArcs not equal. "
  155. << "fst1.NumArcs(" << s1 << ") = " << narcs1 << ", "
  156. << "fst2.NumArcs(" << s2 << ") = " << narcs2;
  157. return false;
  158. }
  159. ArcIterator<Fst<Arc>> aiter1(*fst1_, s1);
  160. ArcIterator<Fst<Arc>> aiter2(*fst2_, s2);
  161. arcs1_.clear();
  162. arcs1_.reserve(narcs1);
  163. arcs2_.clear();
  164. arcs2_.reserve(narcs2);
  165. for (; !aiter1.Done(); aiter1.Next(), aiter2.Next()) {
  166. arcs1_.push_back(aiter1.Value());
  167. arcs2_.push_back(aiter2.Value());
  168. }
  169. std::sort(arcs1_.begin(), arcs1_.end(), comp_);
  170. std::sort(arcs2_.begin(), arcs2_.end(), comp_);
  171. for (size_t i = 0; i < arcs1_.size(); ++i) {
  172. const auto &arc1 = arcs1_[i];
  173. const auto &arc2 = arcs2_[i];
  174. if (arc1.ilabel != arc2.ilabel) {
  175. VLOG(1) << "Isomorphic: ilabels not equal. "
  176. << "state1: " << s1 << " arc1: *" << arc1.ilabel << "* "
  177. << arc1.olabel << " " << arc1.weight << " " << arc1.nextstate
  178. << " state2: " << s2 << " arc2: *" << arc2.ilabel << "* "
  179. << arc2.olabel << " " << arc2.weight << " " << arc2.nextstate;
  180. return false;
  181. }
  182. if (arc1.olabel != arc2.olabel) {
  183. VLOG(1) << "Isomorphic: olabels not equal. "
  184. << "state1: " << s1 << " arc1: " << arc1.ilabel << " *"
  185. << arc1.olabel << "* " << arc1.weight << " " << arc1.nextstate
  186. << " state2: " << s2 << " arc2: " << arc2.ilabel << " *"
  187. << arc2.olabel << "* " << arc2.weight << " " << arc2.nextstate;
  188. return false;
  189. }
  190. if (!ApproxEqual(arc1.weight, arc2.weight, delta_)) {
  191. VLOG(1) << "Isomorphic: weights not ApproxEqual. "
  192. << "state1: " << s1 << " arc1: " << arc1.ilabel << " "
  193. << arc1.olabel << " *" << arc1.weight << "* " << arc1.nextstate
  194. << " state2: " << s2 << " arc2: " << arc2.ilabel << " "
  195. << arc2.olabel << " *" << arc2.weight << "* " << arc2.nextstate;
  196. return false;
  197. }
  198. if (!PairState(arc1.nextstate, arc2.nextstate)) {
  199. VLOG(1) << "Isomorphic: nextstates could not be paired. "
  200. << "state1: " << s1 << " arc1: " << arc1.ilabel << " "
  201. << arc1.olabel << " " << arc1.weight << " *" << arc1.nextstate
  202. << "* "
  203. << "state2: " << s2 << " arc2: " << arc2.ilabel << " "
  204. << arc2.olabel << " " << arc2.weight << " *" << arc2.nextstate
  205. << "*";
  206. return false;
  207. }
  208. if (i > 0) { // Checks for non-determinism.
  209. const auto &arc0 = arcs1_[i - 1];
  210. if (arc1.ilabel == arc0.ilabel && arc1.olabel == arc0.olabel &&
  211. ApproxEqual(arc1.weight, arc0.weight, delta_)) {
  212. // Any subsequent matching failure maybe a false negative
  213. // since we only consider one permutation when pairing destination
  214. // states of nondeterministic transitions.
  215. VLOG(1) << "Isomorphic: Detected non-determinism as an unweighted "
  216. << "automaton; deferring error. "
  217. << "state: " << s1 << " arc1: " << arc1.ilabel << " "
  218. << arc1.olabel << " " << arc1.weight << " " << arc1.nextstate
  219. << " arc2: " << arc2.ilabel << " " << arc2.olabel << " "
  220. << arc2.weight << " " << arc2.nextstate;
  221. nondet_ = true;
  222. }
  223. }
  224. }
  225. return true;
  226. }
  227. } // namespace internal
  228. // Tests if two FSTs have the same states and arcs up to a reordering.
  229. // Inputs should be deterministic when viewed as unweighted automata.
  230. // When the inputs are nondeterministic, the algorithm only considers one
  231. // permutation for each set of equivalent nondeterministic transitions
  232. // (the permutation that preserves state ID ordering) and hence might return
  233. // false negatives (but it never returns false positives).
  234. template <class Arc>
  235. bool Isomorphic(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
  236. float delta = kDelta) {
  237. internal::Isomorphism<Arc> iso(fst1, fst2, delta);
  238. const bool result = iso.IsIsomorphic();
  239. if (iso.Error()) {
  240. FSTERROR() << "Isomorphic: Cannot determine if inputs are isomorphic";
  241. return false;
  242. } else {
  243. return result;
  244. }
  245. }
  246. } // namespace fst
  247. #endif // FST_ISOMORPHIC_H_