You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

339 lines
9.8 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Queue-dependent visitation of finite-state transducers. See also dfs-visit.h.
  19. #ifndef FST_VISIT_H_
  20. #define FST_VISIT_H_
  21. #include <cstdint>
  22. #include <new>
  23. #include <vector>
  24. #include <fst/arcfilter.h>
  25. #include <fst/fst.h>
  26. #include <fst/memory.h>
  27. #include <fst/mutable-fst.h>
  28. #include <fst/properties.h>
  29. namespace fst {
  30. // Visitor Interface: class determining actions taken during a visit. If any of
  31. // the boolean member functions return false, the visit is aborted by first
  32. // calling FinishState() on all unfinished (grey) states and then calling
  33. // FinishVisit().
  34. //
  35. // Note this is more general than the visitor interface in dfs-visit.h but lacks
  36. // some DFS-specific behavior.
  37. //
  38. // template <class Arc>
  39. // class Visitor {
  40. // public:
  41. // using StateId = typename Arc::StateId;
  42. //
  43. // Visitor(T *return_data);
  44. //
  45. // // Invoked before visit.
  46. // void InitVisit(const Fst<Arc> &fst);
  47. //
  48. // // Invoked when state discovered (2nd arg is visitation root).
  49. // bool InitState(StateId s, StateId root);
  50. //
  51. // // Invoked when arc to white/undiscovered state examined.
  52. // bool WhiteArc(StateId s, const Arc &arc);
  53. //
  54. // // Invoked when arc to grey/unfinished state examined.
  55. // bool GreyArc(StateId s, const Arc &arc);
  56. //
  57. // // Invoked when arc to black/finished state examined.
  58. // bool BlackArc(StateId s, const Arc &arc);
  59. //
  60. // // Invoked when state finished.
  61. // void FinishState(StateId s);
  62. //
  63. // // Invoked after visit.
  64. // void FinishVisit();
  65. // };
  66. // Performs queue-dependent visitation. Visitor class argument determines
  67. // actions and contains any return data. ArcFilter determines arcs that are
  68. // considered. If 'access_only' is true, performs visitation only to states
  69. // accessible from the initial state.
  70. template <class FST, class Visitor, class Queue, class ArcFilter>
  71. void Visit(const FST &fst, Visitor *visitor, Queue *queue, ArcFilter filter,
  72. bool access_only = false) {
  73. using Arc = typename FST::Arc;
  74. using StateId = typename Arc::StateId;
  75. visitor->InitVisit(fst);
  76. const auto start = fst.Start();
  77. if (start == kNoStateId) {
  78. visitor->FinishVisit();
  79. return;
  80. }
  81. // An FST's state's visit color.
  82. static constexpr uint8_t kWhiteState = 0x01; // Undiscovered.
  83. static constexpr uint8_t kGreyState = 0x02; // Discovered & unfinished.
  84. static constexpr uint8_t kBlackState = 0x04; // Finished.
  85. // We destroy an iterator as soon as possible and mark it so.
  86. static constexpr uint8_t kArcIterDone = 0x08;
  87. std::vector<uint8_t> state_status;
  88. std::vector<ArcIterator<FST> *> arc_iterator;
  89. MemoryPool<ArcIterator<FST>> aiter_pool;
  90. // Exact number of states if known, otherwise lower bound.
  91. StateId nstates = fst.NumStatesIfKnown().value_or(start + 1);
  92. const bool expanded = fst.Properties(kExpanded, false);
  93. state_status.resize(nstates, kWhiteState);
  94. arc_iterator.resize(nstates);
  95. StateIterator<Fst<Arc>> siter(fst);
  96. // Continues visit while true.
  97. bool visit = true;
  98. // Iterates over trees in visit forest.
  99. for (auto root = start; visit && root < nstates;) {
  100. visit = visitor->InitState(root, root);
  101. state_status[root] = kGreyState;
  102. queue->Enqueue(root);
  103. while (!queue->Empty()) {
  104. auto state = queue->Head();
  105. if (state >= state_status.size()) {
  106. nstates = state + 1;
  107. state_status.resize(nstates, kWhiteState);
  108. arc_iterator.resize(nstates);
  109. }
  110. // Creates arc iterator if needed.
  111. if (!arc_iterator[state] && !(state_status[state] & kArcIterDone) &&
  112. visit) {
  113. arc_iterator[state] = new (&aiter_pool) ArcIterator<FST>(fst, state);
  114. }
  115. // Deletes arc iterator if done.
  116. auto *aiter = arc_iterator[state];
  117. if ((aiter && aiter->Done()) || !visit) {
  118. Destroy(aiter, &aiter_pool);
  119. arc_iterator[state] = nullptr;
  120. state_status[state] |= kArcIterDone;
  121. }
  122. // Dequeues state and marks black if done.
  123. if (state_status[state] & kArcIterDone) {
  124. queue->Dequeue();
  125. visitor->FinishState(state);
  126. state_status[state] = kBlackState;
  127. continue;
  128. }
  129. const auto &arc = aiter->Value();
  130. if (arc.nextstate >= state_status.size()) {
  131. nstates = arc.nextstate + 1;
  132. state_status.resize(nstates, kWhiteState);
  133. arc_iterator.resize(nstates);
  134. }
  135. // Visits respective arc types.
  136. if (filter(arc)) {
  137. // Enqueues destination state and marks grey if white.
  138. if (state_status[arc.nextstate] == kWhiteState) {
  139. visit = visitor->WhiteArc(state, arc);
  140. if (!visit) continue;
  141. visit = visitor->InitState(arc.nextstate, root);
  142. state_status[arc.nextstate] = kGreyState;
  143. queue->Enqueue(arc.nextstate);
  144. } else if (state_status[arc.nextstate] == kBlackState) {
  145. visit = visitor->BlackArc(state, arc);
  146. } else {
  147. visit = visitor->GreyArc(state, arc);
  148. }
  149. }
  150. aiter->Next();
  151. // Destroys an iterator ASAP for efficiency.
  152. if (aiter->Done()) {
  153. Destroy(aiter, &aiter_pool);
  154. arc_iterator[state] = nullptr;
  155. state_status[state] |= kArcIterDone;
  156. }
  157. }
  158. if (access_only) break;
  159. // Finds next tree root.
  160. for (root = (root == start) ? 0 : root + 1;
  161. root < nstates && state_status[root] != kWhiteState; ++root) {
  162. }
  163. // Check for a state beyond the largest known state.
  164. if (!expanded && root == nstates) {
  165. for (; !siter.Done(); siter.Next()) {
  166. if (siter.Value() == nstates) {
  167. ++nstates;
  168. state_status.push_back(kWhiteState);
  169. arc_iterator.push_back(nullptr);
  170. break;
  171. }
  172. }
  173. }
  174. }
  175. visitor->FinishVisit();
  176. }
  177. template <class Arc, class Visitor, class Queue>
  178. inline void Visit(const Fst<Arc> &fst, Visitor *visitor, Queue *queue) {
  179. Visit(fst, visitor, queue, AnyArcFilter<Arc>());
  180. }
  181. // Copies input FST to mutable FST following queue order.
  182. template <class A>
  183. class CopyVisitor {
  184. public:
  185. using Arc = A;
  186. using StateId = typename Arc::StateId;
  187. explicit CopyVisitor(MutableFst<Arc> *ofst) : ifst_(nullptr), ofst_(ofst) {}
  188. void InitVisit(const Fst<A> &ifst) {
  189. ifst_ = &ifst;
  190. ofst_->DeleteStates();
  191. ofst_->SetStart(ifst_->Start());
  192. }
  193. bool InitState(StateId state, StateId) {
  194. while (ofst_->NumStates() <= state) ofst_->AddState();
  195. return true;
  196. }
  197. bool WhiteArc(StateId state, const Arc &arc) {
  198. ofst_->AddArc(state, arc);
  199. return true;
  200. }
  201. bool GreyArc(StateId state, const Arc &arc) {
  202. ofst_->AddArc(state, arc);
  203. return true;
  204. }
  205. bool BlackArc(StateId state, const Arc &arc) {
  206. ofst_->AddArc(state, arc);
  207. return true;
  208. }
  209. void FinishState(StateId state) {
  210. ofst_->SetFinal(state, ifst_->Final(state));
  211. }
  212. void FinishVisit() {}
  213. private:
  214. const Fst<Arc> *ifst_;
  215. MutableFst<Arc> *ofst_;
  216. };
  217. // Visits input FST up to a state limit following queue order.
  218. template <class A>
  219. class PartialVisitor {
  220. public:
  221. using Arc = A;
  222. using StateId = typename Arc::StateId;
  223. explicit PartialVisitor(StateId maxvisit)
  224. : fst_(nullptr), maxvisit_(maxvisit) {}
  225. void InitVisit(const Fst<A> &ifst) {
  226. fst_ = &ifst;
  227. ninit_ = 0;
  228. nfinish_ = 0;
  229. }
  230. bool InitState(StateId state, StateId root) {
  231. ++ninit_;
  232. return ninit_ <= maxvisit_;
  233. }
  234. bool WhiteArc(StateId state, const Arc &arc) { return true; }
  235. bool GreyArc(StateId state, const Arc &arc) { return true; }
  236. bool BlackArc(StateId state, const Arc &arc) { return true; }
  237. void FinishState(StateId state) {
  238. fst_->Final(state); // Visits super-final arc.
  239. ++nfinish_;
  240. }
  241. void FinishVisit() {}
  242. StateId NumInitialized() { return ninit_; }
  243. StateId NumFinished() { return nfinish_; }
  244. private:
  245. const Fst<Arc> *fst_;
  246. StateId maxvisit_;
  247. StateId ninit_;
  248. StateId nfinish_;
  249. };
  250. // Copies input FST to mutable FST up to a state limit following queue order.
  251. template <class A>
  252. class PartialCopyVisitor : public CopyVisitor<A> {
  253. public:
  254. using Arc = A;
  255. using StateId = typename Arc::StateId;
  256. using CopyVisitor<A>::WhiteArc;
  257. PartialCopyVisitor(MutableFst<Arc> *ofst, StateId maxvisit,
  258. bool copy_grey = true, bool copy_black = true)
  259. : CopyVisitor<A>(ofst),
  260. maxvisit_(maxvisit),
  261. copy_grey_(copy_grey),
  262. copy_black_(copy_black) {}
  263. void InitVisit(const Fst<A> &ifst) {
  264. CopyVisitor<A>::InitVisit(ifst);
  265. ninit_ = 0;
  266. nfinish_ = 0;
  267. }
  268. bool InitState(StateId state, StateId root) {
  269. CopyVisitor<A>::InitState(state, root);
  270. ++ninit_;
  271. return ninit_ <= maxvisit_;
  272. }
  273. bool GreyArc(StateId state, const Arc &arc) {
  274. if (copy_grey_) return CopyVisitor<A>::GreyArc(state, arc);
  275. return true;
  276. }
  277. bool BlackArc(StateId state, const Arc &arc) {
  278. if (copy_black_) return CopyVisitor<A>::BlackArc(state, arc);
  279. return true;
  280. }
  281. void FinishState(StateId state) {
  282. CopyVisitor<A>::FinishState(state);
  283. ++nfinish_;
  284. }
  285. void FinishVisit() {}
  286. StateId NumInitialized() { return ninit_; }
  287. StateId NumFinished() { return nfinish_; }
  288. private:
  289. StateId maxvisit_;
  290. StateId ninit_;
  291. StateId nfinish_;
  292. const bool copy_grey_;
  293. const bool copy_black_;
  294. };
  295. } // namespace fst
  296. #endif // FST_VISIT_H_