You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

241 lines
7.3 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Utility class for regression testing of FST weights.
  19. #ifndef FST_TEST_WEIGHT_TESTER_H_
  20. #define FST_TEST_WEIGHT_TESTER_H_
  21. #include <sstream>
  22. #include <utility>
  23. #include <fst/log.h>
  24. #include <fst/weight.h>
  25. namespace fst {
  26. // This class tests a variety of identities and properties that must
  27. // hold for the Weight class to be well-defined. It calls function object
  28. // WEIGHT_GENERATOR to select weights that are used in the tests.
  29. template <class Weight, class WeightGenerator = WeightGenerate<Weight>>
  30. class WeightTester {
  31. public:
  32. explicit WeightTester(WeightGenerator generator)
  33. : weight_generator_(std::move(generator)) {}
  34. void Test(int iterations) {
  35. for (int i = 0; i < iterations; ++i) {
  36. // Selects the test weights.
  37. const Weight w1(weight_generator_());
  38. const Weight w2(weight_generator_());
  39. const Weight w3(weight_generator_());
  40. VLOG(1) << "weight type = " << Weight::Type();
  41. VLOG(1) << "w1 = " << w1;
  42. VLOG(1) << "w2 = " << w2;
  43. VLOG(1) << "w3 = " << w3;
  44. TestSemiring(w1, w2, w3);
  45. TestDivision(w1, w2);
  46. TestReverse(w1, w2);
  47. TestEquality(w1, w2, w3);
  48. TestIO(w1);
  49. TestCopy(w1);
  50. }
  51. }
  52. private:
  53. // Note in the tests below we use ApproxEqual rather than == and add
  54. // kDelta to inequalities where the weights might be inexact.
  55. // Tests (Plus, Times, Zero, One) defines a commutative semiring.
  56. void TestSemiring(Weight w1, Weight w2, Weight w3) {
  57. // Checks that the operations are closed.
  58. CHECK(Plus(w1, w2).Member());
  59. CHECK(Times(w1, w2).Member());
  60. // Checks that the operations are associative.
  61. CHECK(ApproxEqual(Plus(w1, Plus(w2, w3)), Plus(Plus(w1, w2), w3)));
  62. CHECK(ApproxEqual(Times(w1, Times(w2, w3)), Times(Times(w1, w2), w3)));
  63. // Checks the identity elements.
  64. CHECK(Plus(w1, Weight::Zero()) == w1);
  65. CHECK(Plus(Weight::Zero(), w1) == w1);
  66. CHECK(Times(w1, Weight::One()) == w1);
  67. CHECK(Times(Weight::One(), w1) == w1);
  68. // Check the no weight element.
  69. CHECK(!Weight::NoWeight().Member());
  70. CHECK(!Plus(w1, Weight::NoWeight()).Member());
  71. CHECK(!Plus(Weight::NoWeight(), w1).Member());
  72. CHECK(!Times(w1, Weight::NoWeight()).Member());
  73. CHECK(!Times(Weight::NoWeight(), w1).Member());
  74. // Checks that the operations commute.
  75. CHECK(ApproxEqual(Plus(w1, w2), Plus(w2, w1)));
  76. if (Weight::Properties() & kCommutative)
  77. CHECK(ApproxEqual(Times(w1, w2), Times(w2, w1)));
  78. // Checks Zero() is the annihilator.
  79. CHECK(Times(w1, Weight::Zero()) == Weight::Zero());
  80. CHECK(Times(Weight::Zero(), w1) == Weight::Zero());
  81. // Check Power(w, 0) is Weight::One()
  82. CHECK(Power(w1, 0) == Weight::One());
  83. // Check Power(w, 1) is w
  84. CHECK(Power(w1, 1) == w1);
  85. // Check Power(w, 2) is Times(w, w)
  86. CHECK(Power(w1, 2) == Times(w1, w1));
  87. // Check Power(w, 3) is Times(w, Times(w, w))
  88. CHECK(Power(w1, 3) == Times(w1, Times(w1, w1)));
  89. // Checks distributivity.
  90. if (Weight::Properties() & kLeftSemiring) {
  91. CHECK(ApproxEqual(Times(w1, Plus(w2, w3)),
  92. Plus(Times(w1, w2), Times(w1, w3))));
  93. }
  94. if (Weight::Properties() & kRightSemiring)
  95. CHECK(ApproxEqual(Times(Plus(w1, w2), w3),
  96. Plus(Times(w1, w3), Times(w2, w3))));
  97. if (Weight::Properties() & kIdempotent) CHECK(Plus(w1, w1) == w1);
  98. if (Weight::Properties() & kPath)
  99. CHECK(Plus(w1, w2) == w1 || Plus(w1, w2) == w2);
  100. // Ensure weights form a left or right semiring.
  101. CHECK(Weight::Properties() & (kLeftSemiring | kRightSemiring));
  102. // Check when Times() is commutative that it is marked as a semiring.
  103. if (Weight::Properties() & kCommutative)
  104. CHECK(Weight::Properties() & kSemiring);
  105. }
  106. // Tests division operation.
  107. void TestDivision(Weight w1, Weight w2) {
  108. Weight p = Times(w1, w2);
  109. VLOG(1) << "TestDivision: p = " << p;
  110. if (Weight::Properties() & kLeftSemiring) {
  111. Weight d = Divide(p, w1, DIVIDE_LEFT);
  112. if (d.Member()) CHECK(ApproxEqual(p, Times(w1, d)));
  113. CHECK(!Divide(w1, Weight::NoWeight(), DIVIDE_LEFT).Member());
  114. CHECK(!Divide(Weight::NoWeight(), w1, DIVIDE_LEFT).Member());
  115. }
  116. if (Weight::Properties() & kRightSemiring) {
  117. Weight d = Divide(p, w2, DIVIDE_RIGHT);
  118. if (d.Member()) CHECK(ApproxEqual(p, Times(d, w2)));
  119. CHECK(!Divide(w1, Weight::NoWeight(), DIVIDE_RIGHT).Member());
  120. CHECK(!Divide(Weight::NoWeight(), w1, DIVIDE_RIGHT).Member());
  121. }
  122. if (Weight::Properties() & kCommutative) {
  123. Weight d1 = Divide(p, w1, DIVIDE_ANY);
  124. if (d1.Member()) CHECK(ApproxEqual(p, Times(d1, w1)));
  125. Weight d2 = Divide(p, w2, DIVIDE_ANY);
  126. if (d2.Member()) CHECK(ApproxEqual(p, Times(w2, d2)));
  127. }
  128. }
  129. // Tests reverse operation.
  130. void TestReverse(Weight w1, Weight w2) {
  131. using ReverseWeight = typename Weight::ReverseWeight;
  132. ReverseWeight rw1 = w1.Reverse();
  133. ReverseWeight rw2 = w2.Reverse();
  134. CHECK(rw1.Reverse() == w1);
  135. CHECK(Plus(w1, w2).Reverse() == Plus(rw1, rw2));
  136. CHECK(Times(w1, w2).Reverse() == Times(rw2, rw1));
  137. }
  138. // Tests == is an equivalence relation.
  139. void TestEquality(Weight w1, Weight w2, Weight w3) {
  140. // Checks reflexivity.
  141. CHECK(w1 == w1);
  142. // Checks symmetry.
  143. CHECK((w1 == w2) == (w2 == w1));
  144. // Checks transitivity.
  145. if (w1 == w2 && w2 == w3) CHECK(w1 == w3);
  146. // Checks that two weights are either equal or not equal.
  147. CHECK((w1 == w2) ^ (w1 != w2));
  148. if (w1 == w2) {
  149. // Checks that equal weights have identical hashes.
  150. CHECK(w1.Hash() == w2.Hash());
  151. // Checks that equal weights are also approximately equal.
  152. CHECK(ApproxEqual(w1, w2));
  153. }
  154. // Checks that weights which are not even approximately equal are also
  155. // strictly unequal.
  156. if (!ApproxEqual(w1, w2)) {
  157. CHECK(w1 != w2);
  158. }
  159. }
  160. // Tests binary serialization and textual I/O.
  161. void TestIO(Weight w) {
  162. // Tests binary I/O
  163. {
  164. std::ostringstream os;
  165. w.Write(os);
  166. os.flush();
  167. std::istringstream is(os.str());
  168. Weight v;
  169. v.Read(is);
  170. CHECK_EQ(w, v);
  171. }
  172. // Tests textual I/O.
  173. {
  174. std::ostringstream os;
  175. os << w;
  176. std::istringstream is(os.str());
  177. Weight v(Weight::One());
  178. is >> v;
  179. CHECK(ApproxEqual(w, v));
  180. }
  181. }
  182. // Tests copy constructor and assignment operator
  183. void TestCopy(Weight w) {
  184. Weight x = w;
  185. CHECK(w == x);
  186. x = Weight(w);
  187. CHECK(w == x);
  188. x.operator=(x);
  189. CHECK(w == x);
  190. }
  191. // Generates weights used in testing.
  192. WeightGenerator weight_generator_;
  193. };
  194. } // namespace fst
  195. #endif // FST_TEST_WEIGHT_TESTER_H_