You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

186 lines
5.4 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Cartesian power weight semiring operation definitions.
  19. #ifndef FST_POWER_WEIGHT_H_
  20. #define FST_POWER_WEIGHT_H_
  21. #include <cstddef>
  22. #include <cstdint>
  23. #include <random>
  24. #include <string>
  25. #include <fst/tuple-weight.h>
  26. #include <fst/weight.h>
  27. namespace fst {
  28. // Cartesian power semiring: W ^ n
  29. //
  30. // Forms:
  31. // - a left semimodule when W is a left semiring,
  32. // - a right semimodule when W is a right semiring,
  33. // - a bisemimodule when W is a semiring,
  34. // the free semimodule of rank n over W
  35. // The Times operation is overloaded to provide the left and right scalar
  36. // products.
  37. template <class W, size_t n>
  38. class PowerWeight : public TupleWeight<W, n> {
  39. public:
  40. using ReverseWeight = PowerWeight<typename W::ReverseWeight, n>;
  41. PowerWeight() = default;
  42. explicit PowerWeight(const TupleWeight<W, n> &weight)
  43. : TupleWeight<W, n>(weight) {}
  44. template <class Iterator>
  45. PowerWeight(Iterator begin, Iterator end) : TupleWeight<W, n>(begin, end) {}
  46. // Initialize component `index` to `weight`; initialize all other components
  47. // to `default_weight`
  48. PowerWeight(size_t index, const W &weight,
  49. const W &default_weight = W::Zero())
  50. : TupleWeight<W, n>(index, weight, default_weight) {}
  51. static const PowerWeight &Zero() {
  52. static const PowerWeight zero(TupleWeight<W, n>::Zero());
  53. return zero;
  54. }
  55. static const PowerWeight &One() {
  56. static const PowerWeight one(TupleWeight<W, n>::One());
  57. return one;
  58. }
  59. static const PowerWeight &NoWeight() {
  60. static const PowerWeight no_weight(TupleWeight<W, n>::NoWeight());
  61. return no_weight;
  62. }
  63. static const std::string &Type() {
  64. static const std::string *const type =
  65. new std::string(W::Type() + "_^" + std::to_string(n));
  66. return *type;
  67. }
  68. static constexpr uint64_t Properties() {
  69. return W::Properties() &
  70. (kLeftSemiring | kRightSemiring | kCommutative | kIdempotent);
  71. }
  72. PowerWeight Quantize(float delta = kDelta) const {
  73. return PowerWeight(TupleWeight<W, n>::Quantize(delta));
  74. }
  75. ReverseWeight Reverse() const {
  76. return ReverseWeight(TupleWeight<W, n>::Reverse());
  77. }
  78. };
  79. // Semiring plus operation.
  80. template <class W, size_t n>
  81. inline PowerWeight<W, n> Plus(const PowerWeight<W, n> &w1,
  82. const PowerWeight<W, n> &w2) {
  83. PowerWeight<W, n> result;
  84. for (size_t i = 0; i < n; ++i) {
  85. result.SetValue(i, Plus(w1.Value(i), w2.Value(i)));
  86. }
  87. return result;
  88. }
  89. // Semiring times operation.
  90. template <class W, size_t n>
  91. inline PowerWeight<W, n> Times(const PowerWeight<W, n> &w1,
  92. const PowerWeight<W, n> &w2) {
  93. PowerWeight<W, n> result;
  94. for (size_t i = 0; i < n; ++i) {
  95. result.SetValue(i, Times(w1.Value(i), w2.Value(i)));
  96. }
  97. return result;
  98. }
  99. // Semiring divide operation.
  100. template <class W, size_t n>
  101. inline PowerWeight<W, n> Divide(const PowerWeight<W, n> &w1,
  102. const PowerWeight<W, n> &w2,
  103. DivideType type = DIVIDE_ANY) {
  104. PowerWeight<W, n> result;
  105. for (size_t i = 0; i < n; ++i) {
  106. result.SetValue(i, Divide(w1.Value(i), w2.Value(i), type));
  107. }
  108. return result;
  109. }
  110. // Semimodule left scalar product.
  111. template <class W, size_t n>
  112. inline PowerWeight<W, n> Times(const W &scalar,
  113. const PowerWeight<W, n> &weight) {
  114. PowerWeight<W, n> result;
  115. for (size_t i = 0; i < n; ++i) {
  116. result.SetValue(i, Times(scalar, weight.Value(i)));
  117. }
  118. return result;
  119. }
  120. // Semimodule right scalar product.
  121. template <class W, size_t n>
  122. inline PowerWeight<W, n> Times(const PowerWeight<W, n> &weight,
  123. const W &scalar) {
  124. PowerWeight<W, n> result;
  125. for (size_t i = 0; i < n; ++i) {
  126. result.SetValue(i, Times(weight.Value(i), scalar));
  127. }
  128. return result;
  129. }
  130. // Semimodule dot product.
  131. template <class W, size_t n>
  132. inline W DotProduct(const PowerWeight<W, n> &w1, const PowerWeight<W, n> &w2) {
  133. W result(W::Zero());
  134. for (size_t i = 0; i < n; ++i) {
  135. result = Plus(result, Times(w1.Value(i), w2.Value(i)));
  136. }
  137. return result;
  138. }
  139. // This function object generates weights over the Cartesian power of rank
  140. // n over the underlying weight. This is intended primarily for testing.
  141. template <class W, size_t n>
  142. class WeightGenerate<PowerWeight<W, n>> {
  143. public:
  144. using Weight = PowerWeight<W, n>;
  145. using Generate = WeightGenerate<W>;
  146. explicit WeightGenerate(uint64_t seed = std::random_device()(),
  147. bool allow_zero = true)
  148. : generate_(seed, allow_zero) {}
  149. Weight operator()() const {
  150. Weight result;
  151. for (size_t i = 0; i < n; ++i) result.SetValue(i, generate_());
  152. return result;
  153. }
  154. private:
  155. const Generate generate_;
  156. };
  157. } // namespace fst
  158. #endif // FST_POWER_WEIGHT_H_