You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

221 lines
7.3 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Depth-first search visitation. See visit.h for more general search queue
  19. // disciplines.
  20. #ifndef FST_DFS_VISIT_H_
  21. #define FST_DFS_VISIT_H_
  22. #include <cstddef>
  23. #include <cstdint>
  24. #include <new>
  25. #include <stack>
  26. #include <vector>
  27. #include <fst/arcfilter.h>
  28. #include <fst/fst.h>
  29. #include <fst/memory.h>
  30. #include <fst/properties.h>
  31. namespace fst {
  32. // Visitor Interface: class determining actions taken during a depth-first
  33. // search-style visit. If any of the boolean member functions return false, the
  34. // DFS is aborted by first calling FinishState() on all currently grey states
  35. // and then calling FinishVisit().
  36. //
  37. // This is similar to the more general visitor interface in visit.h, except
  38. // that FinishState returns additional information appropriate only for a DFS
  39. // and some method names here are better suited to a DFS.
  40. //
  41. // template <class Arc>
  42. // class Visitor {
  43. // public:
  44. // using StateId = typename Arc::StateId;
  45. //
  46. // Visitor(T *return_data);
  47. //
  48. // // Invoked before DFS visit.
  49. // void InitVisit(const Fst<Arc> &fst);
  50. //
  51. // // Invoked when state discovered (2nd arg is DFS tree root).
  52. // bool InitState(StateId s, StateId root);
  53. //
  54. // // Invoked when tree arc to white/undiscovered state examined.
  55. // bool TreeArc(StateId s, const Arc &arc);
  56. //
  57. // // Invoked when back arc to grey/unfinished state examined.
  58. // bool BackArc(StateId s, const Arc &arc);
  59. //
  60. // // Invoked when forward or cross arc to black/finished state examined.
  61. // bool ForwardOrCrossArc(StateId s, const Arc &arc);
  62. //
  63. // Invoked when state finished (if 's' is a tree root, 'parent' is kNoStateId
  64. // and 'arc' is nullptr).
  65. // void FinishState(StateId s, StateId parent, const Arc *arc);
  66. //
  67. // // Invoked after DFS visit.
  68. // void FinishVisit();
  69. // };
  70. namespace internal {
  71. // An FST state's DFS stack state.
  72. template <class FST>
  73. struct DfsState {
  74. using Arc = typename FST::Arc;
  75. using StateId = typename Arc::StateId;
  76. DfsState(const FST &fst, StateId s) : state_id(s), arc_iter(fst, s) {}
  77. void *operator new(size_t size, MemoryPool<DfsState<FST>> *pool) {
  78. return pool->Allocate();
  79. }
  80. static void Destroy(DfsState<FST> *dfs_state,
  81. MemoryPool<DfsState<FST>> *pool) {
  82. if (dfs_state) {
  83. dfs_state->~DfsState<FST>();
  84. pool->Free(dfs_state);
  85. }
  86. }
  87. StateId state_id; // FST state.
  88. ArcIterator<FST> arc_iter; // The corresponding arcs.
  89. };
  90. } // namespace internal
  91. // Performs depth-first visitation. Visitor class argument determines actions
  92. // and contains any return data. ArcFilter determines arcs that are considered.
  93. // If 'access_only' is true, performs visitation only to states accessible from
  94. // the initial state.
  95. //
  96. // Note this is similar to Visit() in visit.h called with a LIFO queue, except
  97. // this version has a Visitor class specialized and augmented for a DFS.
  98. template <class FST, class Visitor, class ArcFilter>
  99. void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter,
  100. bool access_only = false) {
  101. using Arc = typename FST::Arc;
  102. using StateId = typename Arc::StateId;
  103. visitor->InitVisit(fst);
  104. const auto start = fst.Start();
  105. if (start == kNoStateId) {
  106. visitor->FinishVisit();
  107. return;
  108. }
  109. // An FST state's DFS status
  110. enum class StateColor : uint8_t {
  111. kWhite = 0, // Undiscovered.
  112. kGrey = 1, // Discovered but unfinished.
  113. kBlack = 2, // Finished.
  114. };
  115. std::vector<StateColor> state_color;
  116. std::stack<internal::DfsState<FST> *> state_stack; // DFS execution stack.
  117. MemoryPool<internal::DfsState<FST>> state_pool; // Pool for DFSStates.
  118. // Exact number of states if known, otherwise lower bound.
  119. StateId nstates = fst.NumStatesIfKnown().value_or(start + 1);
  120. const bool expanded = fst.Properties(kExpanded, false);
  121. state_color.resize(nstates, StateColor::kWhite);
  122. StateIterator<FST> siter(fst);
  123. // Continue DFS while true.
  124. bool dfs = true;
  125. // Iterate over trees in DFS forest.
  126. for (auto root = start; dfs && root < nstates;) {
  127. state_color[root] = StateColor::kGrey;
  128. state_stack.push(new (&state_pool) internal::DfsState<FST>(fst, root));
  129. dfs = visitor->InitState(root, root);
  130. while (!state_stack.empty()) {
  131. auto *dfs_state = state_stack.top();
  132. const auto s = dfs_state->state_id;
  133. if (s >= static_cast<decltype(s)>(state_color.size())) {
  134. nstates = s + 1;
  135. state_color.resize(nstates, StateColor::kWhite);
  136. }
  137. ArcIterator<FST> &aiter = dfs_state->arc_iter;
  138. if (!dfs || aiter.Done()) {
  139. state_color[s] = StateColor::kBlack;
  140. internal::DfsState<FST>::Destroy(dfs_state, &state_pool);
  141. state_stack.pop();
  142. if (!state_stack.empty()) {
  143. auto *parent_state = state_stack.top();
  144. auto &piter = parent_state->arc_iter;
  145. visitor->FinishState(s, parent_state->state_id, &piter.Value());
  146. piter.Next();
  147. } else {
  148. visitor->FinishState(s, kNoStateId, nullptr);
  149. }
  150. continue;
  151. }
  152. const auto &arc = aiter.Value();
  153. if (arc.nextstate >=
  154. static_cast<decltype(arc.nextstate)>(state_color.size())) {
  155. nstates = arc.nextstate + 1;
  156. state_color.resize(nstates, StateColor::kWhite);
  157. }
  158. if (!filter(arc)) {
  159. aiter.Next();
  160. continue;
  161. }
  162. const auto next_color = state_color[arc.nextstate];
  163. switch (next_color) {
  164. case StateColor::kWhite:
  165. dfs = visitor->TreeArc(s, arc);
  166. if (!dfs) break;
  167. state_color[arc.nextstate] = StateColor::kGrey;
  168. state_stack.push(new (&state_pool)
  169. internal::DfsState<FST>(fst, arc.nextstate));
  170. dfs = visitor->InitState(arc.nextstate, root);
  171. break;
  172. case StateColor::kGrey:
  173. dfs = visitor->BackArc(s, arc);
  174. aiter.Next();
  175. break;
  176. case StateColor::kBlack:
  177. dfs = visitor->ForwardOrCrossArc(s, arc);
  178. aiter.Next();
  179. break;
  180. }
  181. }
  182. if (access_only) break;
  183. // Finds next tree root.
  184. for (root = root == start ? 0 : root + 1;
  185. root < nstates && state_color[root] != StateColor::kWhite; ++root) {
  186. }
  187. // Checks for a state beyond the largest known state.
  188. if (!expanded && root == nstates) {
  189. for (; !siter.Done(); siter.Next()) {
  190. if (siter.Value() == nstates) {
  191. ++nstates;
  192. state_color.push_back(StateColor::kWhite);
  193. break;
  194. }
  195. }
  196. }
  197. }
  198. visitor->FinishVisit();
  199. }
  200. template <class Arc, class Visitor>
  201. void DfsVisit(const Fst<Arc> &fst, Visitor *visitor) {
  202. DfsVisit(fst, visitor, AnyArcFilter<Arc>());
  203. }
  204. } // namespace fst
  205. #endif // FST_DFS_VISIT_H_