You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

178 lines
4.7 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Tuple weight set operation definitions.
  19. #ifndef FST_TUPLE_WEIGHT_H_
  20. #define FST_TUPLE_WEIGHT_H_
  21. #include <algorithm>
  22. #include <array>
  23. #include <cstddef>
  24. #include <cstdint>
  25. #include <functional>
  26. #include <istream>
  27. #include <ostream>
  28. #include <string>
  29. #include <vector>
  30. #include <fst/flags.h>
  31. #include <fst/log.h>
  32. #include <fst/weight.h>
  33. namespace fst {
  34. // n-tuple weight, element of the n-th Cartesian power of W.
  35. template <class W, size_t n>
  36. class TupleWeight {
  37. public:
  38. using ReverseWeight = TupleWeight<typename W::ReverseWeight, n>;
  39. using Weight = W;
  40. using Index = size_t;
  41. template <class Iterator>
  42. TupleWeight(Iterator begin, Iterator end) {
  43. std::copy(begin, end, values_.begin());
  44. }
  45. explicit TupleWeight(const W &weight = W::Zero()) { values_.fill(weight); }
  46. // Initialize component `index` to `weight`; initialize all other components
  47. // to `default_weight`
  48. TupleWeight(Index index, const W &weight, const W &default_weight)
  49. : TupleWeight(default_weight) {
  50. values_[index] = weight;
  51. }
  52. static const TupleWeight<W, n> &Zero() {
  53. static const TupleWeight<W, n> zero(W::Zero());
  54. return zero;
  55. }
  56. static const TupleWeight<W, n> &One() {
  57. static const TupleWeight<W, n> one(W::One());
  58. return one;
  59. }
  60. static const TupleWeight<W, n> &NoWeight() {
  61. static const TupleWeight<W, n> no_weight(W::NoWeight());
  62. return no_weight;
  63. }
  64. constexpr static size_t Length() { return n; }
  65. std::istream &Read(std::istream &istrm) {
  66. for (size_t i = 0; i < n; ++i) values_[i].Read(istrm);
  67. return istrm;
  68. }
  69. std::ostream &Write(std::ostream &ostrm) const {
  70. for (size_t i = 0; i < n; ++i) values_[i].Write(ostrm);
  71. return ostrm;
  72. }
  73. bool Member() const {
  74. return std::all_of(values_.begin(), values_.end(), std::mem_fn(&W::Member));
  75. }
  76. size_t Hash() const {
  77. uint64_t hash = 0;
  78. for (size_t i = 0; i < n; ++i) hash = 5 * hash + values_[i].Hash();
  79. return size_t(hash);
  80. }
  81. TupleWeight<W, n> Quantize(float delta = kDelta) const {
  82. TupleWeight<W, n> weight;
  83. for (size_t i = 0; i < n; ++i) {
  84. weight.values_[i] = values_[i].Quantize(delta);
  85. }
  86. return weight;
  87. }
  88. ReverseWeight Reverse() const {
  89. TupleWeight<W, n> w;
  90. for (size_t i = 0; i < n; ++i) w.values_[i] = values_[i].Reverse();
  91. return w;
  92. }
  93. const W &Value(size_t i) const { return values_[i]; }
  94. void SetValue(size_t i, const W &w) { values_[i] = w; }
  95. private:
  96. std::array<W, n> values_;
  97. };
  98. template <class W, size_t n>
  99. inline bool operator==(const TupleWeight<W, n> &w1,
  100. const TupleWeight<W, n> &w2) {
  101. for (size_t i = 0; i < n; ++i) {
  102. if (w1.Value(i) != w2.Value(i)) return false;
  103. }
  104. return true;
  105. }
  106. template <class W, size_t n>
  107. inline bool operator!=(const TupleWeight<W, n> &w1,
  108. const TupleWeight<W, n> &w2) {
  109. for (size_t i = 0; i < n; ++i) {
  110. if (w1.Value(i) != w2.Value(i)) return true;
  111. }
  112. return false;
  113. }
  114. template <class W, size_t n>
  115. inline bool ApproxEqual(const TupleWeight<W, n> &w1,
  116. const TupleWeight<W, n> &w2, float delta = kDelta) {
  117. for (size_t i = 0; i < n; ++i) {
  118. if (!ApproxEqual(w1.Value(i), w2.Value(i), delta)) return false;
  119. }
  120. return true;
  121. }
  122. template <class W, size_t n>
  123. inline std::ostream &operator<<(std::ostream &strm,
  124. const TupleWeight<W, n> &w) {
  125. CompositeWeightWriter writer(strm);
  126. writer.WriteBegin();
  127. for (size_t i = 0; i < n; ++i) writer.WriteElement(w.Value(i));
  128. writer.WriteEnd();
  129. return strm;
  130. }
  131. template <class W, size_t n>
  132. inline std::istream &operator>>(std::istream &strm, TupleWeight<W, n> &w) {
  133. CompositeWeightReader reader(strm);
  134. reader.ReadBegin();
  135. W v;
  136. // Reads first n-1 elements.
  137. static_assert(n > 0, "Size must be positive.");
  138. for (size_t i = 0; i < n - 1; ++i) {
  139. reader.ReadElement(&v);
  140. w.SetValue(i, v);
  141. }
  142. // Reads n-th element.
  143. reader.ReadElement(&v, true);
  144. w.SetValue(n - 1, v);
  145. reader.ReadEnd();
  146. return strm;
  147. }
  148. } // namespace fst
  149. #endif // FST_TUPLE_WEIGHT_H_