You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

266 lines
7.7 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Classes to visit connected components of an FST.
  19. #ifndef FST_CC_VISITORS_H_
  20. #define FST_CC_VISITORS_H_
  21. #include <cstdint>
  22. #include <vector>
  23. #include <fst/fst.h>
  24. #include <fst/union-find.h>
  25. namespace fst {
  26. // Finds and returns connected components. Use with Visit().
  27. template <class Arc>
  28. class CcVisitor {
  29. public:
  30. using Weight = typename Arc::Weight;
  31. using StateId = typename Arc::StateId;
  32. // cc[i]: connected component number for state i.
  33. explicit CcVisitor(std::vector<StateId> *cc)
  34. : comps_(new UnionFind<StateId>(0, kNoStateId)), cc_(cc), nstates_(0) {}
  35. // comps: connected components equiv classes.
  36. explicit CcVisitor(UnionFind<StateId> *comps)
  37. : comps_(comps), cc_(nullptr), nstates_(0) {}
  38. ~CcVisitor() {
  39. if (cc_) delete comps_;
  40. }
  41. void InitVisit(const Fst<Arc> &fst) {}
  42. bool InitState(StateId s, StateId root) {
  43. ++nstates_;
  44. if (comps_->FindSet(s) == kNoStateId) comps_->MakeSet(s);
  45. return true;
  46. }
  47. bool WhiteArc(StateId s, const Arc &arc) {
  48. comps_->MakeSet(arc.nextstate);
  49. comps_->Union(s, arc.nextstate);
  50. return true;
  51. }
  52. bool GreyArc(StateId s, const Arc &arc) {
  53. comps_->Union(s, arc.nextstate);
  54. return true;
  55. }
  56. bool BlackArc(StateId s, const Arc &arc) {
  57. comps_->Union(s, arc.nextstate);
  58. return true;
  59. }
  60. void FinishState(StateId s) {}
  61. void FinishVisit() {
  62. if (cc_) GetCcVector(cc_);
  63. }
  64. // Returns number of components.
  65. // cc[i]: connected component number for state i.
  66. int GetCcVector(std::vector<StateId> *cc) {
  67. cc->clear();
  68. cc->resize(nstates_, kNoStateId);
  69. StateId ncomp = 0;
  70. for (StateId s = 0; s < nstates_; ++s) {
  71. const auto rep = comps_->FindSet(s);
  72. auto &comp = (*cc)[rep];
  73. if (comp == kNoStateId) {
  74. comp = ncomp;
  75. ++ncomp;
  76. }
  77. (*cc)[s] = comp;
  78. }
  79. return ncomp;
  80. }
  81. private:
  82. UnionFind<StateId> *comps_; // Components.
  83. std::vector<StateId> *cc_; // State's cc number.
  84. StateId nstates_; // State count.
  85. };
  86. // Finds and returns strongly-connected components, accessible and
  87. // coaccessible states and related properties. Uses Tarjan's single
  88. // DFS SCC algorithm (see Aho, et al, "Design and Analysis of Computer
  89. // Algorithms", 189pp). Use with DfsVisit();
  90. template <class Arc>
  91. class SccVisitor {
  92. public:
  93. using StateId = typename Arc::StateId;
  94. using Weight = typename Arc::Weight;
  95. // scc[i]: strongly-connected component number for state i.
  96. // SCC numbers will be in topological order for acyclic input.
  97. // access[i]: accessibility of state i.
  98. // coaccess[i]: coaccessibility of state i.
  99. // Any of above can be NULL.
  100. // props: related property bits (cyclicity, initial cyclicity,
  101. // accessibility, coaccessibility) set/cleared (o.w. unchanged).
  102. SccVisitor(std::vector<StateId> *scc, std::vector<bool> *access,
  103. std::vector<bool> *coaccess, uint64_t *props)
  104. : scc_(scc), access_(access), coaccess_(coaccess), props_(props) {}
  105. explicit SccVisitor(uint64_t *props)
  106. : scc_(nullptr), access_(nullptr), coaccess_(nullptr), props_(props) {}
  107. void InitVisit(const Fst<Arc> &fst);
  108. bool InitState(StateId s, StateId root);
  109. bool TreeArc(StateId s, const Arc &arc) { return true; }
  110. bool BackArc(StateId s, const Arc &arc) {
  111. const auto t = arc.nextstate;
  112. if (dfnumber_[t] < lowlink_[s]) lowlink_[s] = dfnumber_[t];
  113. if ((*coaccess_)[t]) (*coaccess_)[s] = true;
  114. *props_ |= kCyclic;
  115. *props_ &= ~kAcyclic;
  116. if (t == start_) {
  117. *props_ |= kInitialCyclic;
  118. *props_ &= ~kInitialAcyclic;
  119. }
  120. return true;
  121. }
  122. bool ForwardOrCrossArc(StateId s, const Arc &arc) {
  123. const auto t = arc.nextstate;
  124. if (dfnumber_[t] < dfnumber_[s] /* cross edge */ && onstack_[t] &&
  125. dfnumber_[t] < lowlink_[s]) {
  126. lowlink_[s] = dfnumber_[t];
  127. }
  128. if ((*coaccess_)[t]) (*coaccess_)[s] = true;
  129. return true;
  130. }
  131. // Last argument always ignored, but required by the interface.
  132. void FinishState(StateId state, StateId p, const Arc *);
  133. void FinishVisit() {
  134. // Numbers SCCs in topological order when acyclic.
  135. if (scc_) {
  136. for (size_t s = 0; s < scc_->size(); ++s) {
  137. (*scc_)[s] = nscc_ - 1 - (*scc_)[s];
  138. }
  139. }
  140. if (coaccess_internal_) delete coaccess_;
  141. }
  142. private:
  143. std::vector<StateId> *scc_; // State's scc number.
  144. std::vector<bool> *access_; // State's accessibility.
  145. std::vector<bool> *coaccess_; // State's coaccessibility.
  146. uint64_t *props_;
  147. const Fst<Arc> *fst_;
  148. StateId start_;
  149. StateId nstates_; // State count.
  150. StateId nscc_; // SCC count.
  151. bool coaccess_internal_;
  152. std::vector<StateId> dfnumber_; // State discovery times.
  153. std::vector<StateId>
  154. lowlink_; // lowlink[state] == dfnumber[state] => SCC root
  155. std::vector<bool> onstack_; // Is a state on the SCC stack?
  156. std::vector<StateId> scc_stack_; // SCC stack, with random access.
  157. };
  158. template <class Arc>
  159. inline void SccVisitor<Arc>::InitVisit(const Fst<Arc> &fst) {
  160. if (scc_) scc_->clear();
  161. if (access_) access_->clear();
  162. if (coaccess_) {
  163. coaccess_->clear();
  164. coaccess_internal_ = false;
  165. } else {
  166. coaccess_ = new std::vector<bool>;
  167. coaccess_internal_ = true;
  168. }
  169. *props_ |= kAcyclic | kInitialAcyclic | kAccessible | kCoAccessible;
  170. *props_ &= ~(kCyclic | kInitialCyclic | kNotAccessible | kNotCoAccessible);
  171. fst_ = &fst;
  172. start_ = fst.Start();
  173. nstates_ = 0;
  174. nscc_ = 0;
  175. dfnumber_.clear();
  176. lowlink_.clear();
  177. onstack_.clear();
  178. scc_stack_.clear();
  179. }
  180. template <class Arc>
  181. inline bool SccVisitor<Arc>::InitState(StateId s, StateId root) {
  182. scc_stack_.push_back(s);
  183. if (static_cast<StateId>(dfnumber_.size()) <= s) {
  184. if (scc_) scc_->resize(s + 1, -1);
  185. if (access_) access_->resize(s + 1, false);
  186. coaccess_->resize(s + 1, false);
  187. dfnumber_.resize(s + 1, -1);
  188. lowlink_.resize(s + 1, -1);
  189. onstack_.resize(s + 1, false);
  190. }
  191. dfnumber_[s] = nstates_;
  192. lowlink_[s] = nstates_;
  193. onstack_[s] = true;
  194. if (root == start_) {
  195. if (access_) (*access_)[s] = true;
  196. } else {
  197. if (access_) (*access_)[s] = false;
  198. *props_ |= kNotAccessible;
  199. *props_ &= ~kAccessible;
  200. }
  201. ++nstates_;
  202. return true;
  203. }
  204. template <class Arc>
  205. inline void SccVisitor<Arc>::FinishState(StateId s, StateId p, const Arc *) {
  206. if (fst_->Final(s) != Weight::Zero()) (*coaccess_)[s] = true;
  207. if (dfnumber_[s] == lowlink_[s]) { // Root of new SCC.
  208. bool scc_coaccess = false;
  209. auto i = scc_stack_.size();
  210. StateId t;
  211. do {
  212. t = scc_stack_[--i];
  213. if ((*coaccess_)[t]) scc_coaccess = true;
  214. } while (s != t);
  215. do {
  216. t = scc_stack_.back();
  217. if (scc_) (*scc_)[t] = nscc_;
  218. if (scc_coaccess) (*coaccess_)[t] = true;
  219. onstack_[t] = false;
  220. scc_stack_.pop_back();
  221. } while (s != t);
  222. if (!scc_coaccess) {
  223. *props_ |= kNotCoAccessible;
  224. *props_ &= ~kCoAccessible;
  225. }
  226. ++nscc_;
  227. }
  228. if (p != kNoStateId) {
  229. if ((*coaccess_)[s]) (*coaccess_)[p] = true;
  230. if (lowlink_[s] < lowlink_[p]) lowlink_[p] = lowlink_[s];
  231. }
  232. }
  233. } // namespace fst
  234. #endif // FST_CC_VISITORS_H_