You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

181 lines
6.2 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Lexicographic weight set and associated semiring operation definitions.
  19. //
  20. // A lexicographic weight is a sequence of weights, each of which must have the
  21. // path property and Times() must be (strongly) cancellative
  22. // (for all a,b,c != Zero(): Times(c, a) = Times(c, b) => a = b,
  23. // Times(a, c) = Times(b, c) => a = b).
  24. // The + operation on two weights a and b is the lexicographically
  25. // prior of a and b.
  26. #ifndef FST_LEXICOGRAPHIC_WEIGHT_H_
  27. #define FST_LEXICOGRAPHIC_WEIGHT_H_
  28. #include <cstddef>
  29. #include <cstdint>
  30. #include <random>
  31. #include <string>
  32. #include <fst/log.h>
  33. #include <fst/pair-weight.h>
  34. #include <fst/weight.h>
  35. namespace fst {
  36. template <class W1, class W2>
  37. class LexicographicWeight : public PairWeight<W1, W2> {
  38. public:
  39. static_assert(IsPath<W1>::value, "W1 must have path property.");
  40. static_assert(IsPath<W2>::value, "W2 must have path property.");
  41. using ReverseWeight = LexicographicWeight<typename W1::ReverseWeight,
  42. typename W2::ReverseWeight>;
  43. using PairWeight<W1, W2>::Value1;
  44. using PairWeight<W1, W2>::Value2;
  45. using PairWeight<W1, W2>::SetValue1;
  46. using PairWeight<W1, W2>::SetValue2;
  47. using PairWeight<W1, W2>::Zero;
  48. using PairWeight<W1, W2>::One;
  49. using PairWeight<W1, W2>::NoWeight;
  50. using PairWeight<W1, W2>::Quantize;
  51. using PairWeight<W1, W2>::Reverse;
  52. LexicographicWeight() = default;
  53. explicit LexicographicWeight(const PairWeight<W1, W2> &w)
  54. : PairWeight<W1, W2>(w) {}
  55. LexicographicWeight(W1 w1, W2 w2) : PairWeight<W1, W2>(w1, w2) {}
  56. static const LexicographicWeight &Zero() {
  57. static const LexicographicWeight zero(PairWeight<W1, W2>::Zero());
  58. return zero;
  59. }
  60. static const LexicographicWeight &One() {
  61. static const LexicographicWeight one(PairWeight<W1, W2>::One());
  62. return one;
  63. }
  64. static const LexicographicWeight &NoWeight() {
  65. static const LexicographicWeight no_weight(PairWeight<W1, W2>::NoWeight());
  66. return no_weight;
  67. }
  68. static const std::string &Type() {
  69. static const std::string *const type =
  70. new std::string(W1::Type() + "_LT_" + W2::Type());
  71. return *type;
  72. }
  73. bool Member() const {
  74. if (!Value1().Member() || !Value2().Member()) return false;
  75. // Lexicographic weights cannot mix zeroes and non-zeroes.
  76. if (Value1() == W1::Zero() && Value2() == W2::Zero()) return true;
  77. if (Value1() != W1::Zero() && Value2() != W2::Zero()) return true;
  78. return false;
  79. }
  80. LexicographicWeight Quantize(float delta = kDelta) const {
  81. return LexicographicWeight(PairWeight<W1, W2>::Quantize());
  82. }
  83. ReverseWeight Reverse() const {
  84. return ReverseWeight(PairWeight<W1, W2>::Reverse());
  85. }
  86. static constexpr uint64_t Properties() {
  87. return W1::Properties() & W2::Properties() &
  88. (kLeftSemiring | kRightSemiring | kPath | kIdempotent |
  89. kCommutative);
  90. }
  91. };
  92. template <class W1, class W2>
  93. inline LexicographicWeight<W1, W2> Plus(const LexicographicWeight<W1, W2> &w,
  94. const LexicographicWeight<W1, W2> &v) {
  95. if (!w.Member() || !v.Member()) {
  96. return LexicographicWeight<W1, W2>::NoWeight();
  97. }
  98. NaturalLess<W1> less1;
  99. NaturalLess<W2> less2;
  100. if (less1(w.Value1(), v.Value1())) return w;
  101. if (less1(v.Value1(), w.Value1())) return v;
  102. if (less2(w.Value2(), v.Value2())) return w;
  103. if (less2(v.Value2(), w.Value2())) return v;
  104. return w;
  105. }
  106. template <class W1, class W2>
  107. inline LexicographicWeight<W1, W2> Times(const LexicographicWeight<W1, W2> &w,
  108. const LexicographicWeight<W1, W2> &v) {
  109. return LexicographicWeight<W1, W2>(Times(w.Value1(), v.Value1()),
  110. Times(w.Value2(), v.Value2()));
  111. }
  112. template <class W1, class W2>
  113. inline LexicographicWeight<W1, W2> Divide(const LexicographicWeight<W1, W2> &w,
  114. const LexicographicWeight<W1, W2> &v,
  115. DivideType typ = DIVIDE_ANY) {
  116. return LexicographicWeight<W1, W2>(Divide(w.Value1(), v.Value1(), typ),
  117. Divide(w.Value2(), v.Value2(), typ));
  118. }
  119. // This function object generates weights by calling the underlying generators
  120. // for the templated weight types, like all other pair weight types. However,
  121. // for lexicographic weights, we cannot generate zeroes for the two subweights
  122. // separately: weights are members iff both members are zero or both members
  123. // are non-zero. This is intended primarily for testing.
  124. template <class W1, class W2>
  125. class WeightGenerate<LexicographicWeight<W1, W2>> {
  126. public:
  127. using Weight = LexicographicWeight<W1, W1>;
  128. using Generate1 = WeightGenerate<W1>;
  129. using Generate2 = WeightGenerate<W2>;
  130. explicit WeightGenerate(uint64_t seed = std::random_device()(),
  131. bool allow_zero = true,
  132. size_t num_random_weights = kNumRandomWeights)
  133. : rand_(seed),
  134. allow_zero_(allow_zero),
  135. num_random_weights_(num_random_weights),
  136. generator1_(seed, false, num_random_weights),
  137. generator2_(seed, false, num_random_weights) {}
  138. Weight operator()() const {
  139. if (allow_zero_) {
  140. const int sample =
  141. std::uniform_int_distribution<>(0, num_random_weights_)(rand_);
  142. if (sample == num_random_weights_) return Weight(W1::Zero(), W2::Zero());
  143. }
  144. return Weight(generator1_(), generator2_());
  145. }
  146. private:
  147. mutable std::mt19937_64 rand_;
  148. const bool allow_zero_;
  149. const size_t num_random_weights_;
  150. const Generate1 generator1_;
  151. const Generate2 generator2_;
  152. };
  153. } // namespace fst
  154. #endif // FST_LEXICOGRAPHIC_WEIGHT_H_