You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

149 lines
5.7 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Function to reweight an FST.
  19. #ifndef FST_REWEIGHT_H_
  20. #define FST_REWEIGHT_H_
  21. #include <cstdint>
  22. #include <vector>
  23. #include <fst/log.h>
  24. #include <fst/fst.h>
  25. #include <fst/mutable-fst.h>
  26. #include <fst/properties.h>
  27. #include <fst/util.h>
  28. #include <fst/weight.h>
  29. namespace fst {
  // Direction used by Reweight(): push weight towards the initial state or
  // towards the final states.
  enum ReweightType {
    REWEIGHT_TO_INITIAL = 0,
    REWEIGHT_TO_FINAL = 1,
  };
  31. // Reweights an FST according to a vector of potentials in a given direction.
  32. // The weight must be left distributive when reweighting towards the initial
  33. // state and right distributive when reweighting towards the final states.
  34. //
  35. // An arc of weight w, with an origin state of potential p and destination state
  36. // of potential q, is reweighted by p^-1 \otimes (w \otimes q) when reweighting
  37. // torwards the initial state, and by (p \otimes w) \otimes q^-1 when
  38. // reweighting towards the final states.
  39. template <class Arc>
  40. void Reweight(MutableFst<Arc> *fst,
  41. const std::vector<typename Arc::Weight> &potential,
  42. ReweightType type) {
  43. using Weight = typename Arc::Weight;
  44. if (fst->NumStates() == 0) return;
  45. // TODO(kbg): Make this a compile-time static_assert once we have a pleasant
  46. // way to "deregister" this operation for non-distributive semirings so an
  47. // informative error message is produced.
  48. if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring)) {
  49. FSTERROR() << "Reweight: Reweighting to the final states requires "
  50. << "Weight to be right distributive: " << Weight::Type();
  51. fst->SetProperties(kError, kError);
  52. return;
  53. }
  54. // TODO(kbg): Make this a compile-time static_assert once we have a pleasant
  55. // way to "deregister" this operation for non-distributive semirings so an
  56. // informative error message is produced.
  57. if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring)) {
  58. FSTERROR() << "Reweight: Reweighting to the initial state requires "
  59. << "Weight to be left distributive: " << Weight::Type();
  60. fst->SetProperties(kError, kError);
  61. return;
  62. }
  63. const uint64_t input_props = fst->Properties(kFstProperties, false);
  64. StateIterator<MutableFst<Arc>> siter(*fst);
  65. for (; !siter.Done(); siter.Next()) {
  66. const auto s = siter.Value();
  67. if (s == potential.size()) break;
  68. const auto &weight = potential[s];
  69. if (weight != Weight::Zero()) {
  70. for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
  71. aiter.Next()) {
  72. auto arc = aiter.Value();
  73. if (arc.nextstate >= potential.size()) continue;
  74. const auto &nextweight = potential[arc.nextstate];
  75. if (nextweight == Weight::Zero()) continue;
  76. if (type == REWEIGHT_TO_INITIAL) {
  77. arc.weight =
  78. Divide(Times(arc.weight, nextweight), weight, DIVIDE_LEFT);
  79. }
  80. if (type == REWEIGHT_TO_FINAL) {
  81. arc.weight =
  82. Divide(Times(weight, arc.weight), nextweight, DIVIDE_RIGHT);
  83. }
  84. aiter.SetValue(arc);
  85. }
  86. if (type == REWEIGHT_TO_INITIAL) {
  87. fst->SetFinal(s, Divide(fst->Final(s), weight, DIVIDE_LEFT));
  88. }
  89. }
  90. if (type == REWEIGHT_TO_FINAL) {
  91. fst->SetFinal(s, Times(weight, fst->Final(s)));
  92. }
  93. }
  94. // This handles elements past the end of the potentials array.
  95. for (; !siter.Done(); siter.Next()) {
  96. const auto s = siter.Value();
  97. if (type == REWEIGHT_TO_FINAL) {
  98. fst->SetFinal(s, Times(Weight::Zero(), fst->Final(s)));
  99. }
  100. }
  101. const auto startweight = fst->Start() < potential.size()
  102. ? potential[fst->Start()]
  103. : Weight::Zero();
  104. bool added_start_epsilon = false;
  105. if ((startweight != Weight::One()) && (startweight != Weight::Zero())) {
  106. if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) {
  107. const auto s = fst->Start();
  108. for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
  109. aiter.Next()) {
  110. auto arc = aiter.Value();
  111. if (type == REWEIGHT_TO_INITIAL) {
  112. arc.weight = Times(startweight, arc.weight);
  113. } else {
  114. arc.weight = Times(Divide(Weight::One(), startweight, DIVIDE_RIGHT),
  115. arc.weight);
  116. }
  117. aiter.SetValue(arc);
  118. }
  119. if (type == REWEIGHT_TO_INITIAL) {
  120. fst->SetFinal(s, Times(startweight, fst->Final(s)));
  121. } else {
  122. fst->SetFinal(s, Times(Divide(Weight::One(), startweight, DIVIDE_RIGHT),
  123. fst->Final(s)));
  124. }
  125. } else {
  126. const auto s = fst->AddState();
  127. const auto weight =
  128. (type == REWEIGHT_TO_INITIAL)
  129. ? startweight
  130. : Divide(Weight::One(), startweight, DIVIDE_RIGHT);
  131. fst->AddArc(s, Arc(0, 0, weight, fst->Start()));
  132. fst->SetStart(s);
  133. added_start_epsilon = true;
  134. }
  135. }
  136. fst->SetProperties(ReweightProperties(input_props, added_start_epsilon) |
  137. fst->Properties(kFstProperties, false),
  138. kFstProperties);
  139. }
  140. } // namespace fst
  141. #endif // FST_REWEIGHT_H_