You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

240 lines
7.7 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Class to determine whether a given (final) state can be reached from some
  19. // other given state.
  20. #ifndef FST_STATE_REACHABLE_H_
  21. #define FST_STATE_REACHABLE_H_
  22. #include <cstddef>
  23. #include <cstdlib>
  24. #include <vector>
  25. #include <fst/log.h>
  26. #include <fst/connect.h>
  27. #include <fst/dfs-visit.h>
  28. #include <fst/fst.h>
  29. #include <fst/interval-set.h>
  30. #include <fst/properties.h>
  31. #include <fst/util.h>
  32. #include <fst/vector-fst.h>
  33. namespace fst {
  34. // Computes the (final) states reachable from a given state in an FST. After
  35. // this visitor has been called, a final state f can be reached from a state
  36. // s iff (*isets)[s].Member(state2index[f]) is true, where (*isets[s]) is a
  37. // set of half-open interval of final state indices and state2index[f] maps from
  38. // a final state to its index. If state2index is empty, it is filled-in with
  39. // suitable indices. If it is non-empty, those indices are used; in this case,
  40. // the final states must have out-degree 0.
  41. template <class Arc, class I = typename Arc::StateId, class S = IntervalSet<I>>
  42. class IntervalReachVisitor {
  43. public:
  44. using Label = typename Arc::Label;
  45. using StateId = typename Arc::StateId;
  46. using Weight = typename Arc::Weight;
  47. using Index = I;
  48. using ISet = S;
  49. using Interval = typename ISet::Interval;
  50. IntervalReachVisitor(const Fst<Arc> &fst, std::vector<S> *isets,
  51. std::vector<Index> *state2index)
  52. : fst_(fst),
  53. isets_(isets),
  54. state2index_(state2index),
  55. index_(state2index->empty() ? 1 : -1),
  56. error_(false) {
  57. isets_->clear();
  58. }
  59. void InitVisit(const Fst<Arc> &) { error_ = false; }
  60. bool InitState(StateId s, StateId r) {
  61. while (isets_->size() <= s) isets_->push_back(S());
  62. while (state2index_->size() <= s) state2index_->push_back(-1);
  63. if (fst_.Final(s) != Weight::Zero()) {
  64. // Create tree interval.
  65. auto *intervals = (*isets_)[s].MutableIntervals();
  66. if (index_ < 0) { // Uses state2index_ map to set index.
  67. if (fst_.NumArcs(s) > 0) {
  68. FSTERROR() << "IntervalReachVisitor: state2index map must be empty "
  69. << "for this FST";
  70. error_ = true;
  71. return false;
  72. }
  73. const auto index = (*state2index_)[s];
  74. if (index < 0) {
  75. FSTERROR() << "IntervalReachVisitor: state2index map incomplete";
  76. error_ = true;
  77. return false;
  78. }
  79. intervals->push_back(Interval(index, index + 1));
  80. } else { // Use pre-order index.
  81. intervals->push_back(Interval(index_, index_ + 1));
  82. (*state2index_)[s] = index_++;
  83. }
  84. }
  85. return true;
  86. }
  87. constexpr bool TreeArc(StateId, const Arc &) const { return true; }
  88. bool BackArc(StateId s, const Arc &arc) {
  89. FSTERROR() << "IntervalReachVisitor: Cyclic input";
  90. error_ = true;
  91. return false;
  92. }
  93. bool ForwardOrCrossArc(StateId s, const Arc &arc) {
  94. // Non-tree interval.
  95. (*isets_)[s].Union((*isets_)[arc.nextstate]);
  96. return true;
  97. }
  98. void FinishState(StateId s, StateId p, const Arc *) {
  99. if (index_ >= 0 && fst_.Final(s) != Weight::Zero()) {
  100. auto *intervals = (*isets_)[s].MutableIntervals();
  101. (*intervals)[0].end = index_; // Updates tree interval end.
  102. }
  103. (*isets_)[s].Normalize();
  104. if (p != kNoStateId) {
  105. (*isets_)[p].Union((*isets_)[s]); // Propagates intervals to parent.
  106. }
  107. }
  108. void FinishVisit() {}
  109. bool Error() const { return error_; }
  110. private:
  111. const Fst<Arc> &fst_;
  112. std::vector<ISet> *isets_;
  113. std::vector<Index> *state2index_;
  114. Index index_;
  115. bool error_;
  116. };
  117. // Tests reachability of final states from a given state. To test for
  118. // reachability from a state s, first do SetState(s). Then a final state f can
  119. // be reached from state s of FST iff Reach(f) is true. The input can be cyclic,
  120. // but no cycle may contain a final state.
  121. template <class Arc, class I = typename Arc::StateId, class S = IntervalSet<I>>
  122. class StateReachable {
  123. public:
  124. using Label = typename Arc::Label;
  125. using StateId = typename Arc::StateId;
  126. using Weight = typename Arc::Weight;
  127. using Index = I;
  128. using ISet = S;
  129. using Interval = typename ISet::Interval;
  130. explicit StateReachable(const Fst<Arc> &fst) : error_(false) {
  131. if (fst.Properties(kAcyclic, true)) {
  132. AcyclicStateReachable(fst);
  133. } else {
  134. CyclicStateReachable(fst);
  135. }
  136. }
  137. explicit StateReachable(const StateReachable<Arc> &reachable) {
  138. FSTERROR() << "Copy constructor for state reachable class "
  139. << "not implemented.";
  140. error_ = true;
  141. }
  142. // Sets current state.
  143. void SetState(StateId s) { s_ = s; }
  144. // Can reach this final state from current state?
  145. bool Reach(StateId s) {
  146. if (s >= state2index_.size()) return false;
  147. const auto i = state2index_[s];
  148. if (i < 0) {
  149. FSTERROR() << "StateReachable: State non-final: " << s;
  150. error_ = true;
  151. return false;
  152. }
  153. return isets_[s_].Member(i);
  154. }
  155. // Access to the state-to-index mapping. Unassigned states have index -1.
  156. std::vector<Index> &State2Index() { return state2index_; }
  157. // Access to the interval sets. These specify the reachability to the final
  158. // states as intervals of the final state indices.
  159. const std::vector<ISet> &IntervalSets() { return isets_; }
  160. bool Error() const { return error_; }
  161. private:
  162. void AcyclicStateReachable(const Fst<Arc> &fst) {
  163. IntervalReachVisitor<Arc, StateId, ISet> reach_visitor(fst, &isets_,
  164. &state2index_);
  165. DfsVisit(fst, &reach_visitor);
  166. if (reach_visitor.Error()) error_ = true;
  167. }
  168. void CyclicStateReachable(const Fst<Arc> &fst) {
  169. // Finds state reachability on the acyclic condensation FST.
  170. VectorFst<Arc> cfst;
  171. std::vector<StateId> scc;
  172. Condense(fst, &cfst, &scc);
  173. StateReachable reachable(cfst);
  174. if (reachable.Error()) {
  175. error_ = true;
  176. return;
  177. }
  178. // Gets the number of states per SCC.
  179. std::vector<size_t> nscc;
  180. for (StateId s = 0; s < scc.size(); ++s) {
  181. const auto c = scc[s];
  182. while (c >= nscc.size()) nscc.push_back(0);
  183. ++nscc[c];
  184. }
  185. // Constructs the interval sets and state index mapping for the original
  186. // FST from the condensation FST.
  187. state2index_.resize(scc.size(), -1);
  188. isets_.resize(scc.size());
  189. for (StateId s = 0; s < scc.size(); ++s) {
  190. const auto c = scc[s];
  191. isets_[s] = reachable.IntervalSets()[c];
  192. state2index_[s] = reachable.State2Index()[c];
  193. // Checks that each final state in an input FST is not contained in a
  194. // cycle (i.e., not in a non-trivial SCC).
  195. if (cfst.Final(c) != Weight::Zero() && nscc[c] > 1) {
  196. FSTERROR() << "StateReachable: Final state contained in a cycle";
  197. error_ = true;
  198. return;
  199. }
  200. }
  201. }
  202. StateId s_; // Current state.
  203. std::vector<ISet> isets_; // Interval sets per state.
  204. std::vector<Index> state2index_; // Finds index for a final state.
  205. bool error_;
  206. StateReachable &operator=(const StateReachable &) = delete;
  207. };
  208. } // namespace fst
  209. #endif // FST_STATE_REACHABLE_H_