You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

169 lines
6.5 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Class to reweight/push an FST, and utility functions to weigh and reweight
  19. // an FST.
  20. #ifndef FST_PUSH_H_
  21. #define FST_PUSH_H_
  22. #include <cstdint>
  23. #include <vector>
  24. #include <fst/log.h>
  25. #include <fst/arc-map.h>
  26. #include <fst/arc.h>
  27. #include <fst/factor-weight.h>
  28. #include <fst/fst.h>
  29. #include <fst/mutable-fst.h>
  30. #include <fst/reweight.h>
  31. #include <fst/shortest-distance.h>
  32. #include <fst/string-weight.h>
  33. #include <fst/vector-fst.h>
  34. #include <fst/weight.h>
  35. namespace fst {
  36. // Computes the total weight (sum of the weights of all accepting paths) from
  37. // the output of ShortestDistance, using the shortest distance from the final
  38. // state when reverse is true and from the initial state otherwise.
  39. template <class Arc>
  40. typename Arc::Weight ComputeTotalWeight(
  41. const Fst<Arc> &fst, const std::vector<typename Arc::Weight> &distance,
  42. bool reverse) {
  43. if (reverse) {
  44. return fst.Start() < distance.size() ? distance[fst.Start()]
  45. : Arc::Weight::Zero();
  46. }
  47. auto sum = Arc::Weight::Zero();
  48. for (typename Arc::StateId s = 0; s < distance.size(); ++s) {
  49. sum = Plus(sum, Times(distance[s], fst.Final(s)));
  50. }
  51. return sum;
  52. }
  53. // Divides the weight of every accepting path by a fixed weight. This weight
  54. // is also divided at the final state if at_final is true and at the initial
  55. // state otherwise.
  56. template <class Arc>
  57. void RemoveWeight(MutableFst<Arc> *fst, const typename Arc::Weight &weight,
  58. bool at_final) {
  59. using Weight = typename Arc::Weight;
  60. if ((weight == Weight::One()) || (weight == Weight::Zero())) return;
  61. if (at_final) {
  62. for (StateIterator<MutableFst<Arc>> siter(*fst); !siter.Done();
  63. siter.Next()) {
  64. fst->SetFinal(siter.Value(),
  65. Divide(fst->Final(siter.Value()), weight, DIVIDE_RIGHT));
  66. }
  67. } else {
  68. const auto start = fst->Start();
  69. for (MutableArcIterator<MutableFst<Arc>> aiter(fst, start); !aiter.Done();
  70. aiter.Next()) {
  71. auto arc = aiter.Value();
  72. arc.weight = Divide(arc.weight, weight, DIVIDE_LEFT);
  73. aiter.SetValue(arc);
  74. }
  75. fst->SetFinal(start, Divide(fst->Final(start), weight, DIVIDE_LEFT));
  76. }
  77. }
  78. // Pushes the weights in FST in the requested direction. If pushing towards the
  79. // initial state, the sum of the weight of the outgoing transitions and final
  80. // weight at a non-initial state is equal to One() in the resulting machine. If
  81. // pushing towards the final state, the same property holds on the reverse
  82. // machine.
  83. //
  84. // Weight needs to be left distributive when pushing towards the initial state
  85. // and right distributive when pushing towards the final states.
  86. template <class Arc>
  87. void Push(MutableFst<Arc> *fst, ReweightType type = REWEIGHT_TO_INITIAL,
  88. float delta = kShortestDelta, bool remove_total_weight = false) {
  89. using Weight = typename Arc::Weight;
  90. std::vector<Weight> distance;
  91. const bool reverse = type == REWEIGHT_TO_INITIAL;
  92. ShortestDistance(*fst, &distance, reverse, delta);
  93. if (remove_total_weight) {
  94. const auto total_weight = ComputeTotalWeight(*fst, distance, reverse);
  95. Reweight(fst, distance, type);
  96. RemoveWeight(fst, total_weight, !reverse);
  97. } else {
  98. Reweight(fst, distance, type);
  99. }
  100. }
  101. inline constexpr uint8_t kPushWeights = 0x01;
  102. inline constexpr uint8_t kPushLabels = 0x02;
  103. inline constexpr uint8_t kPushRemoveTotalWeight = 0x04;
  104. inline constexpr uint8_t kPushRemoveCommonAffix = 0x08;
  105. // Pushes the weights and/or labels of the input FST into the output mutable FST
  106. // by pushing weights and/or labels (as determined by the ptype argument)
  107. // towards the initial state or final states (as determined by the rtype
  108. // template parameter). The weight type must be left distributive when pushing
  109. // weights towards the initial state, and right distribution when pushing
  110. // weights towards the final states.
  111. template <class Arc, ReweightType rtype>
  112. void Push(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, uint8_t ptype,
  113. float delta = kShortestDelta) {
  114. using Label = typename Arc::Label;
  115. using Weight = typename Arc::Weight;
  116. if ((ptype & (kPushWeights | kPushLabels)) == kPushWeights) {
  117. *ofst = ifst;
  118. Push(ofst, rtype, delta, ptype & kPushRemoveTotalWeight);
  119. } else if (ptype & kPushLabels) {
  120. const auto gtype =
  121. rtype == REWEIGHT_TO_INITIAL ? GALLIC_LEFT : GALLIC_RIGHT;
  122. using GallicWeight = typename GallicArc<Arc, gtype>::Weight;
  123. std::vector<GallicWeight> gdistance;
  124. VectorFst<GallicArc<Arc, gtype>> gfst;
  125. ArcMap(ifst, &gfst, ToGallicMapper<Arc, gtype>());
  126. if (ptype & kPushWeights) {
  127. ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
  128. } else {
  129. ArcMapFst uwfst(ifst, RmWeightMapper<Arc>());
  130. ArcMapFst guwfst(uwfst, ToGallicMapper<Arc, gtype>());
  131. ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
  132. }
  133. auto total_weight = GallicWeight::One();
  134. if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) {
  135. total_weight =
  136. ComputeTotalWeight(gfst, gdistance, rtype == REWEIGHT_TO_INITIAL);
  137. total_weight = GallicWeight(
  138. ptype & kPushRemoveCommonAffix
  139. ? total_weight.Value1()
  140. : StringWeight<Label, GallicStringType(gtype)>::One(),
  141. ptype & kPushRemoveTotalWeight ? total_weight.Value2()
  142. : Weight::One());
  143. }
  144. Reweight(&gfst, gdistance, rtype);
  145. if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) {
  146. RemoveWeight(&gfst, total_weight, rtype == REWEIGHT_TO_FINAL);
  147. }
  148. FactorWeightFst<GallicArc<Arc, gtype>, GallicFactor<Label, Weight, gtype>>
  149. fwfst(gfst);
  150. ArcMap(fwfst, ofst, FromGallicMapper<Arc, gtype>());
  151. ofst->SetOutputSymbols(ifst.OutputSymbols());
  152. } else {
  153. LOG(WARNING) << "Push: pushing type is set to 0, so not pushing";
  154. *ofst = ifst;
  155. }
  156. }
  157. } // namespace fst
  158. #endif // FST_PUSH_H_