You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

229 lines
7.5 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Expectation semiring, as described in:
  19. //
  20. // Eisner, J. 2002. Parameter estimation for probabilistic finite-state
  21. // transducers. In Proceedings of the 40th Annual Meeting of the
  22. // Association for Computational Linguistics, pages 1-8.
  23. //
  24. // Multiplex semiring operations and identities:
  25. //
  26. // One: <One, Zero>
  27. // Zero: <Zero, Zero>
  28. // Plus: <a1, b1> + <a2, b2> = <(a1 + a2), (b1 + b2)>
  29. // Times: <a1, b1> * <a2, b2> = <(a1 * a2), [(a1 * b2) + (b1 * a2)]>
  30. // Division (see Divide() for proof):
  31. // For a left-semiring:
  32. // <a1, b1> / <a2, b2> = <a1 / a2, (b1 - b2 * (a1 / a2)) / a2>
  33. // For a right-semiring:
  34. // <a1, b1> / <a2, b2> = <a1 / a2, (b1 - (a1 / a2) * b2) / a2>
  35. //
  36. // It is commonly used to store a probability, random variable pair so that
  37. // the shortest distance gives the posterior probability and the associated
  38. // expected value.
  39. #ifndef FST_EXPECTATION_WEIGHT_H_
  40. #define FST_EXPECTATION_WEIGHT_H_
  41. #include <cstdint>
  42. #include <random>
  43. #include <string>
  44. #include <fst/log.h>
  45. #include <fst/pair-weight.h>
  46. #include <fst/weight.h>
  47. namespace fst {
  48. // W1 is usually a probability weight like LogWeight.
  49. // W2 is usually a random variable or vector (see SignedLogWeight or
  50. // SparsePowerWeight).
  51. //
  52. // If W1 is distinct from W2, it is required that there is an external product
  53. // between W1 and W2 (that is, both Times(W1, W2) -> W2 and Times(W2, W1) -> W2
  54. // must be defined) and if both semirings are commutative, or left or right
  55. // semirings, then the result must have those properties.
  56. template <class W1, class W2>
  57. class ExpectationWeight : public PairWeight<W1, W2> {
  58. public:
  59. using PairWeight<W1, W2>::Value1;
  60. using PairWeight<W1, W2>::Value2;
  61. using PairWeight<W1, W2>::Reverse;
  62. using PairWeight<W1, W2>::Quantize;
  63. using PairWeight<W1, W2>::Member;
  64. using ReverseWeight =
  65. ExpectationWeight<typename W1::ReverseWeight, typename W2::ReverseWeight>;
  66. ExpectationWeight() : PairWeight<W1, W2>(Zero()) {}
  67. explicit ExpectationWeight(const PairWeight<W1, W2> &weight)
  68. : PairWeight<W1, W2>(weight) {}
  69. ExpectationWeight(const W1 &w1, const W2 &w2) : PairWeight<W1, W2>(w1, w2) {}
  70. static const ExpectationWeight &Zero() {
  71. static const ExpectationWeight zero(W1::Zero(), W2::Zero());
  72. return zero;
  73. }
  74. static const ExpectationWeight &One() {
  75. static const ExpectationWeight one(W1::One(), W2::Zero());
  76. return one;
  77. }
  78. static const ExpectationWeight &NoWeight() {
  79. static const ExpectationWeight no_weight(W1::NoWeight(), W2::NoWeight());
  80. return no_weight;
  81. }
  82. static const std::string &Type() {
  83. static const std::string *const type =
  84. new std::string("expectation_" + W1::Type() + "_" + W2::Type());
  85. return *type;
  86. }
  87. ExpectationWeight Quantize(float delta = kDelta) const {
  88. return ExpectationWeight(PairWeight<W1, W2>::Quantize(delta));
  89. }
  90. ReverseWeight Reverse() const {
  91. return ReverseWeight(PairWeight<W1, W2>::Reverse());
  92. }
  93. bool Member() const { return PairWeight<W1, W2>::Member(); }
  94. static constexpr uint64_t Properties() {
  95. return W1::Properties() & W2::Properties() &
  96. (kLeftSemiring | kRightSemiring | kCommutative | kIdempotent);
  97. }
  98. };
  99. template <class W1, class W2>
  100. inline ExpectationWeight<W1, W2> Plus(const ExpectationWeight<W1, W2> &w1,
  101. const ExpectationWeight<W1, W2> &w2) {
  102. return ExpectationWeight<W1, W2>(Plus(w1.Value1(), w2.Value1()),
  103. Plus(w1.Value2(), w2.Value2()));
  104. }
  105. template <class W1, class W2>
  106. inline ExpectationWeight<W1, W2> Times(const ExpectationWeight<W1, W2> &w1,
  107. const ExpectationWeight<W1, W2> &w2) {
  108. return ExpectationWeight<W1, W2>(
  109. Times(w1.Value1(), w2.Value1()),
  110. Plus(Times(w1.Value1(), w2.Value2()), Times(w1.Value2(), w2.Value1())));
  111. }
  112. // Requires
  113. // * Divide(W1, W1) -> W1
  114. // * Divide(W2, W1) -> W2
  115. // * Times(W1, W2) -> W2
  116. // (already required by Times(ExpectationWeight, ExpectationWeight).)
  117. // * Minus(W2, W2) -> W2
  118. // (not part of the Weight interface, so Divide will not compile if
  119. // Minus is not defined).
  120. template <class W1, class W2>
  121. inline ExpectationWeight<W1, W2> Divide(const ExpectationWeight<W1, W2> &w1,
  122. const ExpectationWeight<W1, W2> &w2,
  123. DivideType typ) {
  124. // No special cases are required for !w1.Member(), !w2.Member(), or
  125. // w2 == Zero(), since Minus and Divide will already return NoWeight()
  126. // in these cases.
  127. // For a right-semiring, by the definition of Divide, we are looking for
  128. // z = x / y such that (x / y) * y = x.
  129. // Let <x1, x2> = x, <y1, y2> = y, <z1, z2> = z.
  130. // <z1, z2> * <y1, y2> = <x1, x2>.
  131. // By the definition of Times:
  132. // z1 * y1 = x1 and
  133. // z1 * y2 + z2 * y1 = x2.
  134. // So z1 = x1 / y1, and
  135. // z2 * y2 = x2 - z1 * y2
  136. // z2 = (x2 - z1 * y2) / y2.
  137. // The left-semiring case is symmetric. The commutative case allows
  138. // additional simplification to
  139. // z2 = z1 * (x2 / x1 - y2 / y1) if x1 != 0
  140. // z2 = x2 / y1 if x1 == 0, but this requires testing against 0
  141. // with ApproxEquals. We just use the right-semiring result in
  142. // this case.
  143. const auto w11 = w1.Value1();
  144. const auto w12 = w1.Value2();
  145. const auto w21 = w2.Value1();
  146. const auto w22 = w2.Value2();
  147. const W1 q1 = Divide(w11, w21, typ);
  148. if (typ == DIVIDE_LEFT) {
  149. const W2 q2 = Divide(Minus(w12, Times(w22, q1)), w21, typ);
  150. return ExpectationWeight<W1, W2>(q1, q2);
  151. } else {
  152. // Right or commutative semiring.
  153. const W2 q2 = Divide(Minus(w12, Times(q1, w22)), w21, typ);
  154. return ExpectationWeight<W1, W2>(q1, q2);
  155. }
  156. }
  157. // Specialization for expectation weight.
  158. template <class W1, class W2>
  159. class Adder<ExpectationWeight<W1, W2>> {
  160. public:
  161. using Weight = ExpectationWeight<W1, W2>;
  162. Adder() = default;
  163. explicit Adder(Weight w) : adder1_(w.Value1()), adder2_(w.Value2()) {}
  164. Weight Add(const Weight &w) {
  165. adder1_.Add(w.Value1());
  166. adder2_.Add(w.Value2());
  167. return Sum();
  168. }
  169. Weight Sum() const { return Weight(adder1_.Sum(), adder2_.Sum()); }
  170. void Reset(Weight w = Weight::Zero()) {
  171. adder1_.Reset(w.Value1());
  172. adder2_.Reset(w.Value2());
  173. }
  174. private:
  175. Adder<W1> adder1_;
  176. Adder<W2> adder2_;
  177. };
  178. // This function object generates weights by calling the underlying generators
  179. // for the template weight types, like all other pair weight types. This is
  180. // intended primarily for testing.
  181. template <class W1, class W2>
  182. class WeightGenerate<ExpectationWeight<W1, W2>> {
  183. public:
  184. using Weight = ExpectationWeight<W1, W2>;
  185. using Generate = WeightGenerate<PairWeight<W1, W2>>;
  186. explicit WeightGenerate(uint64_t seed = std::random_device()(),
  187. bool allow_zero = true)
  188. : generate_(seed, allow_zero) {}
  189. Weight operator()() const { return Weight(generate_()); }
  190. private:
  191. const Generate generate_;
  192. };
  193. } // namespace fst
  194. #endif // FST_EXPECTATION_WEIGHT_H_