You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

138 lines
5.0 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Tests if two FSTs are equivalent by checking if random strings from one FST
  19. // are transduced the same by both FSTs.
  20. #ifndef FST_RANDEQUIVALENT_H_
  21. #define FST_RANDEQUIVALENT_H_
  22. #include <cstdint>
  23. #include <limits>
  24. #include <random>
  25. #include <fst/log.h>
  26. #include <fst/arcsort.h>
  27. #include <fst/compose.h>
  28. #include <fst/connect.h>
  29. #include <fst/fst.h>
  30. #include <fst/project.h>
  31. #include <fst/properties.h>
  32. #include <fst/randgen.h>
  33. #include <fst/shortest-distance.h>
  34. #include <fst/symbol-table.h>
  35. #include <fst/util.h>
  36. #include <fst/vector-fst.h>
  37. #include <fst/weight.h>
  38. namespace fst {
  39. // Test if two FSTs are stochastically equivalent by randomly generating
  40. // random paths through the FSTs.
  41. //
  42. // For each randomly generated path, the algorithm computes for each
  43. // of the two FSTs the sum of the weights of all the successful paths
  44. // sharing the same input and output labels as the considered randomly
  45. // generated path and checks that these two values are within a user-specified
  46. // delta. Returns optional error value (when FST_FLAGS_error_fatal = false).
  47. template <class Arc, class ArcSelector>
  48. bool RandEquivalent(const Fst<Arc> &fst1, const Fst<Arc> &fst2, int32_t npath,
  49. const RandGenOptions<ArcSelector> &opts,
  50. float delta = kDelta,
  51. uint64_t seed = std::random_device()(),
  52. bool *error = nullptr) {
  53. using Weight = typename Arc::Weight;
  54. if (error) *error = false;
  55. // Checks that the symbol table are compatible.
  56. if (!CompatSymbols(fst1.InputSymbols(), fst2.InputSymbols()) ||
  57. !CompatSymbols(fst1.OutputSymbols(), fst2.OutputSymbols())) {
  58. FSTERROR() << "RandEquivalent: Input/output symbol tables of 1st "
  59. << "argument do not match input/output symbol tables of 2nd "
  60. << "argument";
  61. if (error) *error = true;
  62. return false;
  63. }
  64. static const ILabelCompare<Arc> icomp;
  65. static const OLabelCompare<Arc> ocomp;
  66. VectorFst<Arc> sfst1(fst1);
  67. VectorFst<Arc> sfst2(fst2);
  68. Connect(&sfst1);
  69. Connect(&sfst2);
  70. ArcSort(&sfst1, icomp);
  71. ArcSort(&sfst2, icomp);
  72. bool result = true;
  73. std::mt19937 rand(seed);
  74. std::bernoulli_distribution coin(.5);
  75. for (int32_t n = 0; n < npath; ++n) {
  76. VectorFst<Arc> path;
  77. const auto &fst = coin(rand) ? sfst1 : sfst2;
  78. RandGen(fst, &path, opts);
  79. VectorFst<Arc> ipath(path);
  80. VectorFst<Arc> opath(path);
  81. Project(&ipath, ProjectType::INPUT);
  82. Project(&opath, ProjectType::OUTPUT);
  83. VectorFst<Arc> cfst1, pfst1;
  84. Compose(ipath, sfst1, &cfst1);
  85. ArcSort(&cfst1, ocomp);
  86. Compose(cfst1, opath, &pfst1);
  87. // Gives up if there are epsilon cycles in a non-idempotent semiring.
  88. if (!IsIdempotent<Weight>::value && pfst1.Properties(kCyclic, true)) {
  89. continue;
  90. }
  91. const auto sum1 = ShortestDistance(pfst1);
  92. VectorFst<Arc> cfst2;
  93. Compose(ipath, sfst2, &cfst2);
  94. ArcSort(&cfst2, ocomp);
  95. VectorFst<Arc> pfst2;
  96. Compose(cfst2, opath, &pfst2);
  97. // Gives up if there are epsilon cycles in a non-idempotent semiring.
  98. if (!IsIdempotent<Weight>::value && pfst2.Properties(kCyclic, true)) {
  99. continue;
  100. }
  101. const auto sum2 = ShortestDistance(pfst2);
  102. if (!ApproxEqual(sum1, sum2, delta)) {
  103. VLOG(1) << "Sum1 = " << sum1;
  104. VLOG(1) << "Sum2 = " << sum2;
  105. result = false;
  106. break;
  107. }
  108. }
  109. if (fst1.Properties(kError, false) || fst2.Properties(kError, false)) {
  110. if (error) *error = true;
  111. return false;
  112. }
  113. return result;
  114. }
  115. // Tests if two FSTs are equivalent by randomly generating a nnpath paths
  116. // (no longer than the path_length) using a user-specified seed, optionally
  117. // indicating an error setting an optional error argument to true.
  118. template <class Arc>
  119. bool RandEquivalent(const Fst<Arc> &fst1, const Fst<Arc> &fst2, int32_t npath,
  120. float delta = kDelta,
  121. uint64_t seed = std::random_device()(),
  122. int32_t max_length = std::numeric_limits<int32_t>::max(),
  123. bool *error = nullptr) {
  124. const UniformArcSelector<Arc> uniform_selector(seed);
  125. const RandGenOptions<UniformArcSelector<Arc>> opts(uniform_selector,
  126. max_length);
  127. return RandEquivalent(fst1, fst2, npath, opts, delta, seed, error);
  128. }
  129. } // namespace fst
  130. #endif // FST_RANDEQUIVALENT_H_