You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

185 lines
6.7 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Function to test equality of two FSTs.
  19. #ifndef FST_EQUAL_H_
  20. #define FST_EQUAL_H_
  21. #include <cstdint>
  22. #include <string>
  23. #include <fst/log.h>
  24. #include <fst/fst.h>
  25. #include <fst/properties.h>
  26. #include <fst/symbol-table.h>
  27. #include <fst/util.h>
  28. #include <fst/weight.h>
  29. namespace fst {
  // Bit flags for the etype argument of Equal(); each flag enables one group
  // of checks and flags may be OR-ed together.
  inline constexpr uint8_t kEqualFsts = 1 << 0;              // 0x01
  inline constexpr uint8_t kEqualFstTypes = 1 << 1;          // 0x02
  inline constexpr uint8_t kEqualCompatProperties = 1 << 2;  // 0x04
  inline constexpr uint8_t kEqualCompatSymbols = 1 << 3;     // 0x08
  // Union of every flag above (0x0F).
  inline constexpr uint8_t kEqualAll = kEqualCompatSymbols |
                                       kEqualCompatProperties |
                                       kEqualFstTypes | kEqualFsts;
  // Binary predicate over weights: forwards both operands, together with the
  // delta given at construction, to ApproxEqual and returns its verdict.
  class WeightApproxEqual {
   public:
    explicit WeightApproxEqual(float delta) : delta_{delta} {}

    // The two operands are deduced as independent types to avoid some
    // conflicts caused by conversions.
    template <class Weight1, class Weight2>
    bool operator()(const Weight1 &lhs, const Weight2 &rhs) const {
      const bool approx_equal = ApproxEqual(lhs, rhs, delta_);
      return approx_equal;
    }

   private:
    // Tolerance handed to ApproxEqual on every call.
    const float delta_;
  };
  48. // Tests if two FSTs have the same states and arcs in the same order (when
  49. // etype & kEqualFst); optionally, also checks equality of FST types
  50. // (etype & kEqualFstTypes) and compatibility of stored properties
  51. // (etype & kEqualCompatProperties) and of symbol tables
  52. // (etype & kEqualCompatSymbols).
  53. template <class Arc, class WeightEqual>
  54. bool Equal(const Fst<Arc> &fst1, const Fst<Arc> &fst2, WeightEqual weight_equal,
  55. uint8_t etype = kEqualFsts) {
  56. if ((etype & kEqualFstTypes) && (fst1.Type() != fst2.Type())) {
  57. VLOG(1) << "Equal: Mismatched FST types (" << fst1.Type()
  58. << " != " << fst2.Type() << ")";
  59. return false;
  60. }
  61. if ((etype & kEqualCompatProperties) &&
  62. !internal::CompatProperties(fst1.Properties(kCopyProperties, false),
  63. fst2.Properties(kCopyProperties, false))) {
  64. VLOG(1) << "Equal: Properties not compatible";
  65. return false;
  66. }
  67. if (etype & kEqualCompatSymbols) {
  68. if (!CompatSymbols(fst1.InputSymbols(), fst2.InputSymbols(), false)) {
  69. VLOG(1) << "Equal: Input symbols not compatible";
  70. return false;
  71. }
  72. if (!CompatSymbols(fst1.OutputSymbols(), fst2.OutputSymbols(), false)) {
  73. VLOG(1) << "Equal: Output symbols not compatible";
  74. return false;
  75. }
  76. }
  77. if (!(etype & kEqualFsts)) return true;
  78. if (fst1.Start() != fst2.Start()) {
  79. VLOG(1) << "Equal: Mismatched start states (" << fst1.Start()
  80. << " != " << fst2.Start() << ")";
  81. return false;
  82. }
  83. StateIterator<Fst<Arc>> siter1(fst1);
  84. StateIterator<Fst<Arc>> siter2(fst2);
  85. while (!siter1.Done() || !siter2.Done()) {
  86. if (siter1.Done() || siter2.Done()) {
  87. VLOG(1) << "Equal: Mismatched number of states";
  88. return false;
  89. }
  90. const auto s1 = siter1.Value();
  91. const auto s2 = siter2.Value();
  92. if (s1 != s2) {
  93. VLOG(1) << "Equal: Mismatched states (" << s1 << "!= " << s2 << ")";
  94. return false;
  95. }
  96. const auto &final1 = fst1.Final(s1);
  97. const auto &final2 = fst2.Final(s2);
  98. if (!weight_equal(final1, final2)) {
  99. VLOG(1) << "Equal: Mismatched final weights at state " << s1 << " ("
  100. << final1 << " != " << final2 << ")";
  101. return false;
  102. }
  103. ArcIterator<Fst<Arc>> aiter1(fst1, s1);
  104. ArcIterator<Fst<Arc>> aiter2(fst2, s2);
  105. for (auto a = 0; !aiter1.Done() || !aiter2.Done(); ++a) {
  106. if (aiter1.Done() || aiter2.Done()) {
  107. VLOG(1) << "Equal: Mismatched number of arcs at state " << s1;
  108. return false;
  109. }
  110. const auto &arc1 = aiter1.Value();
  111. const auto &arc2 = aiter2.Value();
  112. if (arc1.ilabel != arc2.ilabel) {
  113. VLOG(1) << "Equal: Mismatched arc input labels at state " << s1
  114. << ", arc " << a << " (" << arc1.ilabel << " != " << arc2.ilabel
  115. << ")";
  116. return false;
  117. } else if (arc1.olabel != arc2.olabel) {
  118. VLOG(1) << "Equal: Mismatched arc output labels at state " << s1
  119. << ", arc " << a << " (" << arc1.olabel << " != " << arc2.olabel
  120. << ")";
  121. return false;
  122. } else if (!weight_equal(arc1.weight, arc2.weight)) {
  123. VLOG(1) << "Equal: Mismatched arc weights at state " << s1 << ", arc "
  124. << a << " (" << arc1.weight << " != " << arc2.weight << ")";
  125. return false;
  126. } else if (arc1.nextstate != arc2.nextstate) {
  127. VLOG(1) << "Equal: Mismatched next state at state " << s1 << ", arc "
  128. << a << " (" << arc1.nextstate << " != " << arc2.nextstate
  129. << ")";
  130. return false;
  131. }
  132. aiter1.Next();
  133. aiter2.Next();
  134. }
  135. // Sanity checks: should never fail.
  136. if (fst1.NumArcs(s1) != fst2.NumArcs(s2)) {
  137. FSTERROR() << "Equal: Inconsistent arc counts at state " << s1 << " ("
  138. << fst1.NumArcs(s1) << " != " << fst2.NumArcs(s2) << ")";
  139. return false;
  140. }
  141. if (fst1.NumInputEpsilons(s1) != fst2.NumInputEpsilons(s2)) {
  142. FSTERROR() << "Equal: Inconsistent input epsilon counts at state " << s1
  143. << " (" << fst1.NumInputEpsilons(s1)
  144. << " != " << fst2.NumInputEpsilons(s2) << ")";
  145. return false;
  146. }
  147. if (fst1.NumOutputEpsilons(s1) != fst2.NumOutputEpsilons(s2)) {
  148. FSTERROR() << "Equal: Inconsistent output epsilon counts at state " << s1
  149. << " (" << fst1.NumOutputEpsilons(s1)
  150. << " != " << fst2.NumOutputEpsilons(s2) << ")";
  151. }
  152. siter1.Next();
  153. siter2.Next();
  154. }
  155. return true;
  156. }
  157. template <class Arc>
  158. bool Equal(const Fst<Arc> &fst1, const Fst<Arc> &fst2, float delta = kDelta,
  159. uint8_t etype = kEqualFsts) {
  160. return Equal(fst1, fst2, WeightApproxEqual(delta), etype);
  161. }
  162. // Support double deltas without forcing all clients to cast to float.
  163. // Without this overload, Equal<Arc, WeightEqual=double> will be chosen,
  164. // since it is a better match than double -> float narrowing, but
  165. // the instantiation will fail.
  166. template <class Arc>
  167. bool Equal(const Fst<Arc> &fst1, const Fst<Arc> &fst2, double delta,
  168. uint8_t etype = kEqualFsts) {
  169. return Equal(fst1, fst2, WeightApproxEqual(static_cast<float>(delta)), etype);
  170. }
  171. } // namespace fst
  172. #endif // FST_EQUAL_H_