You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

173 lines
4.8 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Pair weight templated base class for weight classes that contain two weights
  19. // (e.g. Product, Lexicographic).
  20. #ifndef FST_PAIR_WEIGHT_H_
  21. #define FST_PAIR_WEIGHT_H_
  22. #include <climits>
  23. #include <cstddef>
  24. #include <cstdint>
  25. #include <istream>
  26. #include <ostream>
  27. #include <random>
  28. #include <stack>
  29. #include <string>
  30. #include <utility>
  31. #include <fst/flags.h>
  32. #include <fst/log.h>
  33. #include <fst/weight.h>
  34. namespace fst {
  35. template <class W1, class W2>
  36. class PairWeight {
  37. public:
  38. using ReverseWeight =
  39. PairWeight<typename W1::ReverseWeight, typename W2::ReverseWeight>;
  40. PairWeight() = default;
  41. PairWeight(W1 w1, W2 w2) : value1_(std::move(w1)), value2_(std::move(w2)) {}
  42. static const PairWeight<W1, W2> &Zero() {
  43. static const PairWeight zero(W1::Zero(), W2::Zero());
  44. return zero;
  45. }
  46. static const PairWeight<W1, W2> &One() {
  47. static const PairWeight one(W1::One(), W2::One());
  48. return one;
  49. }
  50. static const PairWeight<W1, W2> &NoWeight() {
  51. static const PairWeight no_weight(W1::NoWeight(), W2::NoWeight());
  52. return no_weight;
  53. }
  54. std::istream &Read(std::istream &strm) {
  55. value1_.Read(strm);
  56. return value2_.Read(strm);
  57. }
  58. std::ostream &Write(std::ostream &strm) const {
  59. value1_.Write(strm);
  60. return value2_.Write(strm);
  61. }
  62. bool Member() const { return value1_.Member() && value2_.Member(); }
  63. size_t Hash() const {
  64. const auto h1 = value1_.Hash();
  65. const auto h2 = value2_.Hash();
  66. static constexpr int lshift = 5;
  67. static constexpr int rshift = CHAR_BIT * sizeof(size_t) - 5;
  68. return h1 << lshift ^ h1 >> rshift ^ h2;
  69. }
  70. PairWeight<W1, W2> Quantize(float delta = kDelta) const {
  71. return PairWeight<W1, W2>(value1_.Quantize(delta), value2_.Quantize(delta));
  72. }
  73. ReverseWeight Reverse() const {
  74. return ReverseWeight(value1_.Reverse(), value2_.Reverse());
  75. }
  76. const W1 &Value1() const { return value1_; }
  77. const W2 &Value2() const { return value2_; }
  78. void SetValue1(const W1 &weight) { value1_ = weight; }
  79. void SetValue2(const W2 &weight) { value2_ = weight; }
  80. private:
  81. W1 value1_;
  82. W2 value2_;
  83. };
  84. template <class W1, class W2>
  85. inline bool operator==(const PairWeight<W1, W2> &w1,
  86. const PairWeight<W1, W2> &w2) {
  87. return w1.Value1() == w2.Value1() && w1.Value2() == w2.Value2();
  88. }
  89. template <class W1, class W2>
  90. inline bool operator!=(const PairWeight<W1, W2> &w1,
  91. const PairWeight<W1, W2> &w2) {
  92. return w1.Value1() != w2.Value1() || w1.Value2() != w2.Value2();
  93. }
  94. template <class W1, class W2>
  95. inline bool ApproxEqual(const PairWeight<W1, W2> &w1,
  96. const PairWeight<W1, W2> &w2, float delta = kDelta) {
  97. return ApproxEqual(w1.Value1(), w2.Value1(), delta) &&
  98. ApproxEqual(w1.Value2(), w2.Value2(), delta);
  99. }
  100. template <class W1, class W2>
  101. inline std::ostream &operator<<(std::ostream &strm,
  102. const PairWeight<W1, W2> &weight) {
  103. CompositeWeightWriter writer(strm);
  104. writer.WriteBegin();
  105. writer.WriteElement(weight.Value1());
  106. writer.WriteElement(weight.Value2());
  107. writer.WriteEnd();
  108. return strm;
  109. }
  110. template <class W1, class W2>
  111. inline std::istream &operator>>(std::istream &strm,
  112. PairWeight<W1, W2> &weight) {
  113. CompositeWeightReader reader(strm);
  114. reader.ReadBegin();
  115. W1 w1;
  116. reader.ReadElement(&w1);
  117. weight.SetValue1(w1);
  118. W2 w2;
  119. reader.ReadElement(&w2, true);
  120. weight.SetValue2(w2);
  121. reader.ReadEnd();
  122. return strm;
  123. }
  124. // This function object returns weights by calling the underlying generators
  125. // and forming a pair. This is intended primarily for testing.
  126. template <class W1, class W2>
  127. class WeightGenerate<PairWeight<W1, W2>> {
  128. public:
  129. using Weight = PairWeight<W1, W2>;
  130. using Generate1 = WeightGenerate<W1>;
  131. using Generate2 = WeightGenerate<W2>;
  132. explicit WeightGenerate(uint64_t seed = std::random_device()(),
  133. bool allow_zero = true)
  134. : generate1_(seed, allow_zero), generate2_(seed, allow_zero) {}
  135. Weight operator()() const { return Weight(generate1_(), generate2_()); }
  136. private:
  137. const Generate1 generate1_;
  138. const Generate2 generate2_;
  139. };
  140. } // namespace fst
  141. #endif // FST_PAIR_WEIGHT_H_