You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

229 lines
7.4 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Cartesian power weight semiring operation definitions, using
  19. // SparseTupleWeight as underlying representation.
  20. #ifndef FST_SPARSE_POWER_WEIGHT_H_
  21. #define FST_SPARSE_POWER_WEIGHT_H_
  22. #include <climits>
  23. #include <cstddef>
  24. #include <cstdint>
  25. #include <random>
  26. #include <string>
  27. #include <fst/sparse-tuple-weight.h>
  28. #include <fst/weight.h>
  29. namespace fst {
  30. // Sparse cartesian power semiring: W ^ n
  31. //
  32. // Forms:
  33. //
  34. // - a left semimodule when W is a left semiring,
  35. // - a right semimodule when W is a right semiring,
  36. // - a bisemimodule when W is a semiring,
  37. // the free semimodule of rank n over W
  38. //
  39. // The Times operation is overloaded to provide the left and right scalar
  40. // products.
  41. //
  42. // K is the key value type. kNoKey (-1) is reserved for internal use
  43. template <class W, class K = int>
  44. class SparsePowerWeight : public SparseTupleWeight<W, K> {
  45. public:
  46. using Base = SparseTupleWeight<W, K>;
  47. using ReverseWeight = SparsePowerWeight<typename W::ReverseWeight, K>;
  48. SparsePowerWeight() = default;
  49. explicit SparsePowerWeight(const Base &weight) : Base(weight) {}
  50. template <class Iterator>
  51. SparsePowerWeight(Iterator begin, Iterator end) : Base(begin, end) {}
  52. // Initialize component `key` to `weight`, with `default_weight` for all
  53. // other components.
  54. SparsePowerWeight(const K &key, const W &weight,
  55. const W &default_weight = W::Zero())
  56. : Base(key, weight, default_weight) {}
  57. static const SparsePowerWeight &Zero() {
  58. static const SparsePowerWeight zero(Base::Zero());
  59. return zero;
  60. }
  61. static const SparsePowerWeight &One() {
  62. static const SparsePowerWeight one(Base::One());
  63. return one;
  64. }
  65. static const SparsePowerWeight &NoWeight() {
  66. static const SparsePowerWeight no_weight(Base::NoWeight());
  67. return no_weight;
  68. }
  69. // Overide this: Overwrite the Type method to reflect the key type if using
  70. // a non-default key type.
  71. static const std::string &Type() {
  72. static const std::string *const type = [] {
  73. std::string type = W::Type() + "_^n";
  74. if (sizeof(K) != sizeof(uint32_t)) {
  75. type += "_" + std::to_string(CHAR_BIT * sizeof(K));
  76. }
  77. return new std::string(type);
  78. }();
  79. return *type;
  80. }
  81. static constexpr uint64_t Properties() {
  82. return W::Properties() &
  83. (kLeftSemiring | kRightSemiring | kCommutative | kIdempotent);
  84. }
  85. SparsePowerWeight Quantize(float delta = kDelta) const {
  86. return SparsePowerWeight(Base::Quantize(delta));
  87. }
  88. ReverseWeight Reverse() const { return ReverseWeight(Base::Reverse()); }
  89. };
  90. template <class W, class K, class M>
  91. inline SparsePowerWeight<W, K> SparsePowerWeightMap(
  92. const SparsePowerWeight<W, K> &w1, const SparsePowerWeight<W, K> &w2,
  93. const M &operator_mapper) {
  94. SparsePowerWeight<W, K> result;
  95. SparseTupleWeightMap(&result, w1, w2, operator_mapper);
  96. return result;
  97. }
  98. // Semimodule plus operation.
  99. template <class W, class K>
  100. inline SparsePowerWeight<W, K> Plus(const SparsePowerWeight<W, K> &w1,
  101. const SparsePowerWeight<W, K> &w2) {
  102. return SparsePowerWeightMap(w1, w2, [](const K &k, const W &v1, const W &v2) {
  103. return Plus(v1, v2);
  104. });
  105. }
  106. // Semimodule minus operation.
  107. template <class W, class K>
  108. inline SparsePowerWeight<W, K> Minus(const SparsePowerWeight<W, K> &w1,
  109. const SparsePowerWeight<W, K> &w2) {
  110. return SparsePowerWeightMap(w1, w2, [](const K &k, const W &v1, const W &v2) {
  111. return Minus(v1, v2);
  112. });
  113. }
  114. // Semimodule times operation.
  115. template <class W, class K>
  116. inline SparsePowerWeight<W, K> Times(const SparsePowerWeight<W, K> &w1,
  117. const SparsePowerWeight<W, K> &w2) {
  118. return SparsePowerWeightMap(w1, w2, [](const K &k, const W &v1, const W &v2) {
  119. return Times(v1, v2);
  120. });
  121. }
  122. // Semimodule divide operation.
  123. template <class W, class K>
  124. inline SparsePowerWeight<W, K> Divide(const SparsePowerWeight<W, K> &w1,
  125. const SparsePowerWeight<W, K> &w2,
  126. DivideType type = DIVIDE_ANY) {
  127. return SparsePowerWeightMap(w1, w2,
  128. [type](const K &k, const W &v1, const W &v2) {
  129. return Divide(v1, v2, type);
  130. });
  131. }
  132. // Semimodule dot product operation.
  133. template <class W, class K>
  134. inline const W &DotProduct(const SparsePowerWeight<W, K> &w1,
  135. const SparsePowerWeight<W, K> &w2) {
  136. const SparsePowerWeight<W, K> product = Times(w1, w2);
  137. W result(W::Zero());
  138. for (SparseTupleWeightIterator<W, K> it(product); !it.Done(); it.Next()) {
  139. result = Plus(result, it.Value().second);
  140. }
  141. return result;
  142. }
  143. template <class W, class K>
  144. inline bool ApproxEqual(const SparsePowerWeight<W, K> &w1,
  145. const SparsePowerWeight<W, K> &w2,
  146. float delta = kDelta) {
  147. auto result = SparsePowerWeightMap(
  148. w1, w2, [delta](const K &k, const W &v1, const W &v2) {
  149. return ApproxEqual(v1, v2, delta) ? W::One() : W::Zero();
  150. });
  151. return result == SparsePowerWeight<W, K>::One();
  152. }
  153. template <class W, class K>
  154. inline SparsePowerWeight<W, K> Times(const W &k,
  155. const SparsePowerWeight<W, K> &w2) {
  156. const SparseTupleWeight<W, K> t2(k);
  157. const SparsePowerWeight<W, K> w1(t2);
  158. return Times(w1, w2);
  159. }
  160. template <class W, class K>
  161. inline SparsePowerWeight<W, K> Times(const SparsePowerWeight<W, K> &w1,
  162. const W &k) {
  163. const SparseTupleWeight<W, K> t2(k);
  164. const SparsePowerWeight<W, K> w2(t2);
  165. return Times(w1, w2);
  166. }
  167. template <class W, class K>
  168. inline SparsePowerWeight<W, K> Divide(const SparsePowerWeight<W, K> &w1,
  169. const W &k,
  170. DivideType divide_type = DIVIDE_ANY) {
  171. const SparseTupleWeight<W, K> t2(k);
  172. const SparsePowerWeight<W, K> w2(t2);
  173. return Divide(w1, w2, divide_type);
  174. }
  175. // This function object generates weights over the Cartesian power of rank
  176. // n over the underlying weight. This is intended primarily for testing.
  177. template <class W, class K>
  178. class WeightGenerate<SparsePowerWeight<W, K>> {
  179. public:
  180. using Weight = SparsePowerWeight<W, K>;
  181. using Generate = WeightGenerate<W>;
  182. explicit WeightGenerate(uint64_t seed = std::random_device()(),
  183. bool allow_zero = true, size_t sparse_power_rank = 3)
  184. : generate_(seed, allow_zero), sparse_power_rank_(sparse_power_rank) {}
  185. Weight operator()() const {
  186. Weight weight;
  187. for (size_t i = 1; i <= sparse_power_rank_; ++i) {
  188. weight.PushBack(i, generate_(), true);
  189. }
  190. return weight;
  191. }
  192. private:
  193. const Generate generate_;
  194. const size_t sparse_power_rank_;
  195. };
  196. } // namespace fst
  197. #endif // FST_SPARSE_POWER_WEIGHT_H_