You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

154 lines
4.5 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Product weight set and associated semiring operation definitions.
  19. #ifndef FST_PRODUCT_WEIGHT_H_
  20. #define FST_PRODUCT_WEIGHT_H_
  21. #include <cstdint>
  22. #include <random>
  23. #include <string>
  24. #include <utility>
  25. #include <fst/pair-weight.h>
  26. #include <fst/weight.h>
  27. namespace fst {
  28. // Product semiring: W1 * W2.
  29. template <class W1, class W2>
  30. class ProductWeight : public PairWeight<W1, W2> {
  31. public:
  32. using ReverseWeight =
  33. ProductWeight<typename W1::ReverseWeight, typename W2::ReverseWeight>;
  34. ProductWeight() = default;
  35. explicit ProductWeight(const PairWeight<W1, W2> &weight)
  36. : PairWeight<W1, W2>(weight) {}
  37. ProductWeight(W1 w1, W2 w2)
  38. : PairWeight<W1, W2>(std::move(w1), std::move(w2)) {}
  39. static const ProductWeight &Zero() {
  40. static const ProductWeight zero(PairWeight<W1, W2>::Zero());
  41. return zero;
  42. }
  43. static const ProductWeight &One() {
  44. static const ProductWeight one(PairWeight<W1, W2>::One());
  45. return one;
  46. }
  47. static const ProductWeight &NoWeight() {
  48. static const ProductWeight no_weight(PairWeight<W1, W2>::NoWeight());
  49. return no_weight;
  50. }
  51. static const std::string &Type() {
  52. static const std::string *const type =
  53. new std::string(W1::Type() + "_X_" + W2::Type());
  54. return *type;
  55. }
  56. static constexpr uint64_t Properties() {
  57. return W1::Properties() & W2::Properties() &
  58. (kLeftSemiring | kRightSemiring | kCommutative | kIdempotent);
  59. }
  60. ProductWeight Quantize(float delta = kDelta) const {
  61. return ProductWeight(PairWeight<W1, W2>::Quantize(delta));
  62. }
  63. ReverseWeight Reverse() const {
  64. return ReverseWeight(PairWeight<W1, W2>::Reverse());
  65. }
  66. };
  67. template <class W1, class W2>
  68. inline ProductWeight<W1, W2> Plus(const ProductWeight<W1, W2> &w1,
  69. const ProductWeight<W1, W2> &w2) {
  70. return ProductWeight<W1, W2>(Plus(w1.Value1(), w2.Value1()),
  71. Plus(w1.Value2(), w2.Value2()));
  72. }
  73. template <class W1, class W2>
  74. inline ProductWeight<W1, W2> Times(const ProductWeight<W1, W2> &w1,
  75. const ProductWeight<W1, W2> &w2) {
  76. return ProductWeight<W1, W2>(Times(w1.Value1(), w2.Value1()),
  77. Times(w1.Value2(), w2.Value2()));
  78. }
  79. template <class W1, class W2>
  80. inline ProductWeight<W1, W2> Divide(const ProductWeight<W1, W2> &w1,
  81. const ProductWeight<W1, W2> &w2,
  82. DivideType typ = DIVIDE_ANY) {
  83. return ProductWeight<W1, W2>(Divide(w1.Value1(), w2.Value1(), typ),
  84. Divide(w1.Value2(), w2.Value2(), typ));
  85. }
  86. // Specialization for product weight
  87. template <class W1, class W2>
  88. class Adder<ProductWeight<W1, W2>> {
  89. public:
  90. using Weight = ProductWeight<W1, W2>;
  91. Adder() = default;
  92. explicit Adder(Weight w) : adder1_(w.Value1()), adder2_(w.Value2()) {}
  93. Weight Add(const Weight &w) {
  94. adder1_.Add(w.Value1());
  95. adder2_.Add(w.Value2());
  96. return Sum();
  97. }
  98. Weight Sum() const { return Weight(adder1_.Sum(), adder2_.Sum()); }
  99. void Reset(Weight w = Weight::Zero()) {
  100. adder1_.Reset(w.Value1());
  101. adder2_.Reset(w.Value2());
  102. }
  103. private:
  104. Adder<W1> adder1_;
  105. Adder<W2> adder2_;
  106. };
  107. // This function object generates weights by calling the underlying generators
  108. // for the template weight types, like all other pair weight types. This is
  109. // intended primarily for testing.
  110. template <class W1, class W2>
  111. class WeightGenerate<ProductWeight<W1, W2>> {
  112. public:
  113. using Weight = ProductWeight<W1, W2>;
  114. using Generate = WeightGenerate<PairWeight<W1, W2>>;
  115. explicit WeightGenerate(uint64_t seed = std::random_device()(),
  116. bool allow_zero = true)
  117. : generate_(seed, allow_zero) {}
  118. Weight operator()() const { return Weight(generate_()); }
  119. private:
  120. const Generate generate_;
  121. };
  122. } // namespace fst
  123. #endif // FST_PRODUCT_WEIGHT_H_