You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

650 lines
23 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Utility classes for the recursive replacement of FSTs (RTNs).
  19. #ifndef FST_REPLACE_UTIL_H_
  20. #define FST_REPLACE_UTIL_H_
  21. #include <cstddef>
  22. #include <cstdint>
  23. #include <map>
  24. #include <memory>
  25. #include <utility>
  26. #include <vector>
  27. #include <fst/log.h>
  28. #include <fst/cc-visitors.h>
  29. #include <fst/connect.h>
  30. #include <fst/dfs-visit.h>
  31. #include <fst/fst.h>
  32. #include <fst/mutable-fst.h>
  33. #include <fst/properties.h>
  34. #include <fst/topsort.h>
  35. #include <fst/util.h>
  36. #include <fst/vector-fst.h>
  37. #include <unordered_map>
  38. #include <unordered_set>
  39. namespace fst {
  // Selects which side(s) of a call or return arc carry a non-epsilon label.
  // Caution: REPLACE_LABEL_INPUT and REPLACE_LABEL_OUTPUT produce transducers
  // when applied to acceptors.
  enum ReplaceLabelType {
    REPLACE_LABEL_NEITHER = 1,  // Epsilon on both input and output.
    REPLACE_LABEL_INPUT = 2,    // Non-epsilon on input, epsilon on output.
    REPLACE_LABEL_OUTPUT = 3,   // Epsilon on input, non-epsilon on output.
    REPLACE_LABEL_BOTH = 4      // Non-epsilon on both input and output.
  };
  53. // By default ReplaceUtil will copy the input label of the replace arc.
  54. // The call_label_type and return_label_type options specify how to manage
  55. // the labels of the call arc and the return arc of the replace FST
  56. struct ReplaceUtilOptions {
  57. int64_t root; // Root rule for expansion.
  58. ReplaceLabelType call_label_type; // How to label call arc.
  59. ReplaceLabelType return_label_type; // How to label return arc.
  60. int64_t return_label; // Label to put on return arc.
  61. explicit ReplaceUtilOptions(
  62. int64_t root = kNoLabel,
  63. ReplaceLabelType call_label_type = REPLACE_LABEL_INPUT,
  64. ReplaceLabelType return_label_type = REPLACE_LABEL_NEITHER,
  65. int64_t return_label = 0)
  66. : root(root),
  67. call_label_type(call_label_type),
  68. return_label_type(return_label_type),
  69. return_label(return_label) {}
  70. // For backwards compatibility.
  71. ReplaceUtilOptions(int64_t root, bool epsilon_replace_arc)
  72. : ReplaceUtilOptions(root, epsilon_replace_arc ? REPLACE_LABEL_NEITHER
  73. : REPLACE_LABEL_INPUT) {}
  74. };
  // Bit flags describing one SCC of the replace dependency graph.

  // Set when, in every FST associated with the SCC, every non-terminal on a
  // path appears as the first label on that path. This would hold if the SCC
  // were formed from left-linear grammar rules.
  inline constexpr uint8_t kReplaceSCCLeftLinear = 1 << 0;

  // Set when, in every FST associated with the SCC, every non-terminal on a
  // path appears as the final label on that path. This would hold if the SCC
  // were formed from right-linear grammar rules.
  inline constexpr uint8_t kReplaceSCCRightLinear = 1 << 1;

  // Set when the SCC has more than one state or a self-loop.
  inline constexpr uint8_t kReplaceSCCNonTrivial = 1 << 2;
  86. // Defined in replace.h.
  87. template <class Arc>
  88. void Replace(
  89. const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>> &,
  90. MutableFst<Arc> *, const ReplaceUtilOptions &);
  91. // Utility class for the recursive replacement of FSTs (RTNs). The user provides
  92. // a set of label/FST pairs at construction. These are used by methods for
  93. // testing cyclic dependencies and connectedness and doing RTN connection and
  94. // specific FST replacement by label or for various optimization properties. The
  95. // modified results can be obtained with the GetFstPairs() or
  96. // GetMutableFstPairs() methods.
  97. template <class Arc>
  98. class ReplaceUtil {
  99. public:
  100. using Label = typename Arc::Label;
  101. using StateId = typename Arc::StateId;
  102. using Weight = typename Arc::Weight;
  103. using FstPair = std::pair<Label, const Fst<Arc> *>;
  104. using MutableFstPair = std::pair<Label, MutableFst<Arc> *>;
  105. using NonTerminalHash = std::unordered_map<Label, Label>;
  106. // Constructs from mutable FSTs; FST ownership is given to ReplaceUtil.
  107. ReplaceUtil(const std::vector<MutableFstPair> &fst_pairs,
  108. const ReplaceUtilOptions &opts);
  109. // Constructs from FSTs; FST ownership is retained by caller.
  110. ReplaceUtil(const std::vector<FstPair> &fst_pairs,
  111. const ReplaceUtilOptions &opts);
  112. // Constructs from ReplaceFst internals; FST ownership is retained by caller.
  113. ReplaceUtil(const std::vector<std::unique_ptr<const Fst<Arc>>> &fst_array,
  114. const NonTerminalHash &nonterminal_hash,
  115. const ReplaceUtilOptions &opts);
  116. ~ReplaceUtil() {
  117. for (Label i = 0; i < fst_array_.size(); ++i) delete fst_array_[i];
  118. }
  119. // True if the non-terminal dependencies are cyclic. Cyclic dependencies will
  120. // result in an unexpandable FST.
  121. bool CyclicDependencies() const {
  122. GetDependencies(false);
  123. return depprops_ & kCyclic;
  124. }
  125. // Returns the strongly-connected component ID in the dependency graph of the
  126. // replace FSTS.
  127. StateId SCC(Label label) const {
  128. GetDependencies(false);
  129. if (const auto it = nonterminal_hash_.find(label);
  130. it != nonterminal_hash_.end()) {
  131. return depscc_[it->second];
  132. }
  133. return kNoStateId;
  134. }
  135. // Returns properties for the strongly-connected component in the dependency
  136. // graph of the replace FSTs. If the SCC is kReplaceSCCLeftLinear or
  137. // kReplaceSCCRightLinear, that SCC can be represented as finite-state despite
  138. // any cyclic dependencies, but not by the usual replacement operation (see
  139. // fst/extensions/pdt/replace.h).
  140. uint8_t SCCProperties(StateId scc_id) {
  141. GetSCCProperties();
  142. return depsccprops_[scc_id];
  143. }
  144. // Returns true if no useless FSTs, states or transitions are present in the
  145. // RTN.
  146. bool Connected() const {
  147. GetDependencies(false);
  148. uint64_t props = kAccessible | kCoAccessible;
  149. for (Label i = 0; i < fst_array_.size(); ++i) {
  150. if (!fst_array_[i]) continue;
  151. if (fst_array_[i]->Properties(props, true) != props || !depaccess_[i]) {
  152. return false;
  153. }
  154. }
  155. return true;
  156. }
  157. // Removes useless FSTs, states and transitions from the RTN.
  158. void Connect();
  159. // Replaces FSTs specified by labels, unless there are cyclic dependencies.
  160. void ReplaceLabels(const std::vector<Label> &labels);
  161. // Replaces FSTs that have at most nstates states, narcs arcs and nnonterm
  162. // non-terminals (updating in reverse dependency order), unless there are
  163. // cyclic dependencies.
  164. void ReplaceBySize(size_t nstates, size_t narcs, size_t nnonterms);
  165. // Replaces singleton FSTS, unless there are cyclic dependencies.
  166. void ReplaceTrivial() { ReplaceBySize(2, 1, 1); }
  167. // Replaces non-terminals that have at most ninstances instances (updating in
  168. // dependency order), unless there are cyclic dependencies.
  169. void ReplaceByInstances(size_t ninstances);
  170. // Replaces non-terminals that have only one instance, unless there are cyclic
  171. // dependencies.
  172. void ReplaceUnique() { ReplaceByInstances(1); }
  173. // Returns label/FST pairs, retaining FST ownership.
  174. void GetFstPairs(std::vector<FstPair> *fst_pairs);
  175. // Returns label/mutable FST pairs, giving FST ownership over to the caller.
  176. void GetMutableFstPairs(std::vector<MutableFstPair> *mutable_fst_pairs);
  177. private:
  178. // FST statistics.
  179. struct ReplaceStats {
  180. StateId nstates; // Number of states.
  181. StateId nfinal; // Number of final states.
  182. size_t narcs; // Number of arcs.
  183. Label nnonterms; // Number of non-terminals in FST.
  184. size_t nref; // Number of non-terminal instances referring to this FST.
  185. // Number of times that ith FST references this FST
  186. std::map<Label, size_t> inref;
  187. // Number of times that this FST references the ith FST
  188. std::map<Label, size_t> outref;
  189. ReplaceStats() : nstates(0), nfinal(0), narcs(0), nnonterms(0), nref(0) {}
  190. };
  191. // Checks that Mutable FSTs exists, creating them if necessary.
  192. void CheckMutableFsts();
  193. // Computes the dependency graph for the RTN, computing dependency statistics
  194. // if stats is true.
  195. void GetDependencies(bool stats) const;
  196. void ClearDependencies() const {
  197. depfst_.DeleteStates();
  198. stats_.clear();
  199. depprops_ = 0;
  200. depsccprops_.clear();
  201. have_stats_ = false;
  202. }
  203. // Gets topological order of dependencies, returning false with cyclic input.
  204. bool GetTopOrder(const Fst<Arc> &fst, std::vector<Label> *toporder) const;
  205. // Updates statistics to reflect the replacement of the jth FST.
  206. void UpdateStats(Label j);
  207. // Computes the properties for the strongly-connected component in the
  208. // dependency graph of the replace FSTs.
  209. void GetSCCProperties() const;
  210. Label root_label_; // Root non-terminal.
  211. Label root_fst_; // Root FST ID.
  212. ReplaceLabelType call_label_type_; // See Replace().
  213. ReplaceLabelType return_label_type_; // See Replace().
  214. int64_t return_label_; // See Replace().
  215. std::vector<const Fst<Arc> *> fst_array_; // FST per ID.
  216. std::vector<MutableFst<Arc> *> mutable_fst_array_; // Mutable FST per ID.
  217. std::vector<Label> nonterminal_array_; // FST ID to non-terminal.
  218. NonTerminalHash nonterminal_hash_; // Non-terminal to FST ID.
  219. mutable VectorFst<Arc> depfst_; // FST ID dependencies.
  220. mutable std::vector<StateId> depscc_; // FST SCC ID.
  221. mutable std::vector<bool> depaccess_; // FST ID accessibility.
  222. mutable uint64_t depprops_; // Dependency FST props.
  223. mutable bool have_stats_; // Have dependency statistics?
  224. mutable std::vector<ReplaceStats> stats_; // Per-FST statistics.
  225. mutable std::vector<uint8_t> depsccprops_; // SCC properties.
  226. ReplaceUtil(const ReplaceUtil &) = delete;
  227. ReplaceUtil &operator=(const ReplaceUtil &) = delete;
  228. };
  229. template <class Arc>
  230. ReplaceUtil<Arc>::ReplaceUtil(const std::vector<MutableFstPair> &fst_pairs,
  231. const ReplaceUtilOptions &opts)
  232. : root_label_(opts.root),
  233. call_label_type_(opts.call_label_type),
  234. return_label_type_(opts.return_label_type),
  235. return_label_(opts.return_label),
  236. depprops_(0),
  237. have_stats_(false) {
  238. fst_array_.push_back(nullptr);
  239. mutable_fst_array_.push_back(nullptr);
  240. nonterminal_array_.push_back(kNoLabel);
  241. for (const auto &fst_pair : fst_pairs) {
  242. const auto label = fst_pair.first;
  243. auto *fst = fst_pair.second;
  244. nonterminal_hash_[label] = fst_array_.size();
  245. nonterminal_array_.push_back(label);
  246. fst_array_.push_back(fst);
  247. mutable_fst_array_.push_back(fst);
  248. }
  249. root_fst_ = nonterminal_hash_[root_label_];
  250. if (!root_fst_) {
  251. FSTERROR() << "ReplaceUtil: No root FST for label: " << root_label_;
  252. }
  253. }
  254. template <class Arc>
  255. ReplaceUtil<Arc>::ReplaceUtil(const std::vector<FstPair> &fst_pairs,
  256. const ReplaceUtilOptions &opts)
  257. : root_label_(opts.root),
  258. call_label_type_(opts.call_label_type),
  259. return_label_type_(opts.return_label_type),
  260. return_label_(opts.return_label),
  261. depprops_(0),
  262. have_stats_(false) {
  263. fst_array_.push_back(nullptr);
  264. nonterminal_array_.push_back(kNoLabel);
  265. for (const auto &fst_pair : fst_pairs) {
  266. const auto label = fst_pair.first;
  267. const auto *fst = fst_pair.second;
  268. nonterminal_hash_[label] = fst_array_.size();
  269. nonterminal_array_.push_back(label);
  270. fst_array_.push_back(fst->Copy());
  271. }
  272. root_fst_ = nonterminal_hash_[root_label_];
  273. if (!root_fst_) {
  274. FSTERROR() << "ReplaceUtil: No root FST for label: " << root_label_;
  275. }
  276. }
  277. template <class Arc>
  278. ReplaceUtil<Arc>::ReplaceUtil(
  279. const std::vector<std::unique_ptr<const Fst<Arc>>> &fst_array,
  280. const NonTerminalHash &nonterminal_hash, const ReplaceUtilOptions &opts)
  281. : root_fst_(opts.root),
  282. call_label_type_(opts.call_label_type),
  283. return_label_type_(opts.return_label_type),
  284. return_label_(opts.return_label),
  285. nonterminal_array_(fst_array.size()),
  286. nonterminal_hash_(nonterminal_hash),
  287. depprops_(0),
  288. have_stats_(false) {
  289. fst_array_.push_back(nullptr);
  290. for (size_t i = 1; i < fst_array.size(); ++i) {
  291. fst_array_.push_back(fst_array[i]->Copy());
  292. }
  293. for (auto it = nonterminal_hash.begin(); it != nonterminal_hash.end(); ++it) {
  294. nonterminal_array_[it->second] = it->first;
  295. }
  296. root_label_ = nonterminal_array_[root_fst_];
  297. }
  298. template <class Arc>
  299. void ReplaceUtil<Arc>::GetDependencies(bool stats) const {
  300. if (depfst_.NumStates() > 0) {
  301. if (stats && !have_stats_) {
  302. ClearDependencies();
  303. } else {
  304. return;
  305. }
  306. }
  307. have_stats_ = stats;
  308. if (have_stats_) stats_.reserve(fst_array_.size());
  309. for (Label ilabel = 0; ilabel < fst_array_.size(); ++ilabel) {
  310. depfst_.AddState();
  311. depfst_.SetFinal(ilabel);
  312. if (have_stats_) stats_.push_back(ReplaceStats());
  313. }
  314. depfst_.SetStart(root_fst_);
  315. // An arc from each state (representing the FST) to the state representing the
  316. // FST being replaced
  317. for (Label ilabel = 0; ilabel < fst_array_.size(); ++ilabel) {
  318. const auto *ifst = fst_array_[ilabel];
  319. if (!ifst) continue;
  320. for (StateIterator<Fst<Arc>> siter(*ifst); !siter.Done(); siter.Next()) {
  321. const auto s = siter.Value();
  322. if (have_stats_) {
  323. ++stats_[ilabel].nstates;
  324. if (ifst->Final(s) != Weight::Zero()) ++stats_[ilabel].nfinal;
  325. }
  326. for (ArcIterator<Fst<Arc>> aiter(*ifst, s); !aiter.Done(); aiter.Next()) {
  327. if (have_stats_) ++stats_[ilabel].narcs;
  328. const auto &arc = aiter.Value();
  329. if (auto it = nonterminal_hash_.find(arc.olabel);
  330. it != nonterminal_hash_.end()) {
  331. const auto nextstate = it->second;
  332. depfst_.EmplaceArc(ilabel, arc.olabel, arc.olabel, nextstate);
  333. if (have_stats_) {
  334. ++stats_[ilabel].nnonterms;
  335. ++stats_[nextstate].nref;
  336. ++stats_[nextstate].inref[ilabel];
  337. ++stats_[ilabel].outref[nextstate];
  338. }
  339. }
  340. }
  341. }
  342. }
  343. // Computes accessibility info.
  344. SccVisitor<Arc> scc_visitor(&depscc_, &depaccess_, nullptr, &depprops_);
  345. DfsVisit(depfst_, &scc_visitor);
  346. }
  347. template <class Arc>
  348. void ReplaceUtil<Arc>::UpdateStats(Label j) {
  349. if (!have_stats_) {
  350. FSTERROR() << "ReplaceUtil::UpdateStats: Stats not available";
  351. return;
  352. }
  353. if (j == root_fst_) return; // Can't replace root.
  354. for (auto in = stats_[j].inref.begin(); in != stats_[j].inref.end(); ++in) {
  355. const auto i = in->first;
  356. const auto ni = in->second;
  357. stats_[i].nstates += stats_[j].nstates * ni;
  358. stats_[i].narcs += (stats_[j].narcs + 1) * ni;
  359. stats_[i].nnonterms += (stats_[j].nnonterms - 1) * ni;
  360. stats_[i].outref.erase(j);
  361. for (auto out = stats_[j].outref.begin(); out != stats_[j].outref.end();
  362. ++out) {
  363. const auto k = out->first;
  364. const auto nk = out->second;
  365. stats_[i].outref[k] += ni * nk;
  366. }
  367. }
  368. for (auto out = stats_[j].outref.begin(); out != stats_[j].outref.end();
  369. ++out) {
  370. const auto k = out->first;
  371. const auto nk = out->second;
  372. stats_[k].nref -= nk;
  373. stats_[k].inref.erase(j);
  374. for (auto in = stats_[j].inref.begin(); in != stats_[j].inref.end(); ++in) {
  375. const auto i = in->first;
  376. const auto ni = in->second;
  377. stats_[k].inref[i] += ni * nk;
  378. stats_[k].nref += ni * nk;
  379. }
  380. }
  381. }
  382. template <class Arc>
  383. void ReplaceUtil<Arc>::CheckMutableFsts() {
  384. if (mutable_fst_array_.empty()) {
  385. for (Label i = 0; i < fst_array_.size(); ++i) {
  386. if (!fst_array_[i]) {
  387. mutable_fst_array_.push_back(nullptr);
  388. } else {
  389. mutable_fst_array_.push_back(new VectorFst<Arc>(*fst_array_[i]));
  390. delete fst_array_[i];
  391. fst_array_[i] = mutable_fst_array_[i];
  392. }
  393. }
  394. }
  395. }
  396. template <class Arc>
  397. void ReplaceUtil<Arc>::Connect() {
  398. CheckMutableFsts();
  399. static constexpr auto props = kAccessible | kCoAccessible;
  400. for (auto *mutable_fst : mutable_fst_array_) {
  401. if (!mutable_fst) continue;
  402. if (mutable_fst->Properties(props, false) != props) {
  403. fst::Connect(mutable_fst);
  404. }
  405. }
  406. GetDependencies(false);
  407. for (Label i = 0; i < mutable_fst_array_.size(); ++i) {
  408. auto *fst = mutable_fst_array_[i];
  409. if (fst && !depaccess_[i]) {
  410. delete fst;
  411. fst_array_[i] = nullptr;
  412. mutable_fst_array_[i] = nullptr;
  413. }
  414. }
  415. ClearDependencies();
  416. }
  417. template <class Arc>
  418. bool ReplaceUtil<Arc>::GetTopOrder(const Fst<Arc> &fst,
  419. std::vector<Label> *toporder) const {
  420. // Finds topological order of dependencies.
  421. std::vector<StateId> order;
  422. bool acyclic = false;
  423. TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
  424. DfsVisit(fst, &top_order_visitor);
  425. if (!acyclic) {
  426. LOG(WARNING) << "ReplaceUtil::GetTopOrder: Cyclical label dependencies";
  427. return false;
  428. }
  429. toporder->resize(order.size());
  430. for (Label i = 0; i < order.size(); ++i) (*toporder)[order[i]] = i;
  431. return true;
  432. }
  433. template <class Arc>
  434. void ReplaceUtil<Arc>::ReplaceLabels(const std::vector<Label> &labels) {
  435. CheckMutableFsts();
  436. std::unordered_set<Label> label_set;
  437. for (const auto label : labels) {
  438. // Can't replace root.
  439. if (label != root_label_) label_set.insert(label);
  440. }
  441. // Finds FST dependencies restricted to the labels requested.
  442. GetDependencies(false);
  443. VectorFst<Arc> pfst(depfst_);
  444. for (StateId i = 0; i < pfst.NumStates(); ++i) {
  445. std::vector<Arc> arcs;
  446. for (ArcIterator<VectorFst<Arc>> aiter(pfst, i); !aiter.Done();
  447. aiter.Next()) {
  448. const auto &arc = aiter.Value();
  449. const auto label = nonterminal_array_[arc.nextstate];
  450. if (label_set.count(label) > 0) arcs.push_back(arc);
  451. }
  452. pfst.DeleteArcs(i);
  453. for (auto &arc : arcs) pfst.AddArc(i, std::move(arc));
  454. }
  455. std::vector<Label> toporder;
  456. if (!GetTopOrder(pfst, &toporder)) {
  457. ClearDependencies();
  458. return;
  459. }
  460. // Visits FSTs in reverse topological order of dependencies and performs
  461. // replacements.
  462. for (Label o = toporder.size() - 1; o >= 0; --o) {
  463. std::vector<FstPair> fst_pairs;
  464. auto s = toporder[o];
  465. for (ArcIterator<VectorFst<Arc>> aiter(pfst, s); !aiter.Done();
  466. aiter.Next()) {
  467. const auto &arc = aiter.Value();
  468. const auto label = nonterminal_array_[arc.nextstate];
  469. const auto *fst = fst_array_[arc.nextstate];
  470. fst_pairs.emplace_back(label, fst);
  471. }
  472. if (fst_pairs.empty()) continue;
  473. const auto label = nonterminal_array_[s];
  474. const auto *fst = fst_array_[s];
  475. fst_pairs.emplace_back(label, fst);
  476. const ReplaceUtilOptions opts(label, call_label_type_, return_label_type_,
  477. return_label_);
  478. Replace(fst_pairs, mutable_fst_array_[s], opts);
  479. }
  480. ClearDependencies();
  481. }
  482. template <class Arc>
  483. void ReplaceUtil<Arc>::ReplaceBySize(size_t nstates, size_t narcs,
  484. size_t nnonterms) {
  485. std::vector<Label> labels;
  486. GetDependencies(true);
  487. std::vector<Label> toporder;
  488. if (!GetTopOrder(depfst_, &toporder)) {
  489. ClearDependencies();
  490. return;
  491. }
  492. for (Label o = toporder.size() - 1; o >= 0; --o) {
  493. const auto j = toporder[o];
  494. if (stats_[j].nstates <= nstates && stats_[j].narcs <= narcs &&
  495. stats_[j].nnonterms <= nnonterms) {
  496. labels.push_back(nonterminal_array_[j]);
  497. UpdateStats(j);
  498. }
  499. }
  500. ReplaceLabels(labels);
  501. }
  502. template <class Arc>
  503. void ReplaceUtil<Arc>::ReplaceByInstances(size_t ninstances) {
  504. std::vector<Label> labels;
  505. GetDependencies(true);
  506. std::vector<Label> toporder;
  507. if (!GetTopOrder(depfst_, &toporder)) {
  508. ClearDependencies();
  509. return;
  510. }
  511. for (Label o = 0; o < toporder.size(); ++o) {
  512. const auto j = toporder[o];
  513. if (stats_[j].nref <= ninstances) {
  514. labels.push_back(nonterminal_array_[j]);
  515. UpdateStats(j);
  516. }
  517. }
  518. ReplaceLabels(labels);
  519. }
  520. template <class Arc>
  521. void ReplaceUtil<Arc>::GetFstPairs(std::vector<FstPair> *fst_pairs) {
  522. CheckMutableFsts();
  523. fst_pairs->clear();
  524. for (Label i = 0; i < fst_array_.size(); ++i) {
  525. const auto label = nonterminal_array_[i];
  526. const auto *fst = fst_array_[i];
  527. if (!fst) continue;
  528. fst_pairs->emplace_back(label, fst);
  529. }
  530. }
  531. template <class Arc>
  532. void ReplaceUtil<Arc>::GetMutableFstPairs(
  533. std::vector<MutableFstPair> *mutable_fst_pairs) {
  534. CheckMutableFsts();
  535. mutable_fst_pairs->clear();
  536. for (Label i = 0; i < mutable_fst_array_.size(); ++i) {
  537. const auto label = nonterminal_array_[i];
  538. const auto *fst = mutable_fst_array_[i];
  539. if (!fst) continue;
  540. mutable_fst_pairs->emplace_back(label, fst->Copy());
  541. }
  542. }
  543. template <class Arc>
  544. void ReplaceUtil<Arc>::GetSCCProperties() const {
  545. if (!depsccprops_.empty()) return;
  546. GetDependencies(false);
  547. if (depscc_.empty()) return;
  548. for (StateId scc = 0; scc < depscc_.size(); ++scc) {
  549. depsccprops_.push_back(kReplaceSCCLeftLinear | kReplaceSCCRightLinear);
  550. }
  551. if (!(depprops_ & kCyclic)) return; // No cyclic dependencies.
  552. // Checks for self-loops in the dependency graph.
  553. for (StateId scc = 0; scc < depscc_.size(); ++scc) {
  554. for (ArcIterator<Fst<Arc>> aiter(depfst_, scc); !aiter.Done();
  555. aiter.Next()) {
  556. const auto &arc = aiter.Value();
  557. if (arc.nextstate == scc) { // SCC has a self loop.
  558. depsccprops_[scc] |= kReplaceSCCNonTrivial;
  559. }
  560. }
  561. }
  562. std::vector<bool> depscc_visited(depscc_.size(), false);
  563. for (Label i = 0; i < fst_array_.size(); ++i) {
  564. const auto *fst = fst_array_[i];
  565. if (!fst) continue;
  566. const auto depscc = depscc_[i];
  567. if (depscc_visited[depscc]) { // SCC has more than one state.
  568. depsccprops_[depscc] |= kReplaceSCCNonTrivial;
  569. }
  570. depscc_visited[depscc] = true;
  571. std::vector<StateId> fstscc; // SCCs of the current FST.
  572. uint64_t fstprops;
  573. SccVisitor<Arc> scc_visitor(&fstscc, nullptr, nullptr, &fstprops);
  574. DfsVisit(*fst, &scc_visitor);
  575. for (StateIterator<Fst<Arc>> siter(*fst); !siter.Done(); siter.Next()) {
  576. const auto s = siter.Value();
  577. for (ArcIterator<Fst<Arc>> aiter(*fst, s); !aiter.Done(); aiter.Next()) {
  578. const auto &arc = aiter.Value();
  579. auto it = nonterminal_hash_.find(arc.olabel);
  580. if (it == nonterminal_hash_.end() || depscc_[it->second] != depscc) {
  581. continue; // Skips if a terminal or a non-terminal not in SCC.
  582. }
  583. const bool arc_in_cycle = fstscc[s] == fstscc[arc.nextstate];
  584. // Left linear iff all non-terminals are initial.
  585. if (s != fst->Start() || arc_in_cycle) {
  586. depsccprops_[depscc] &= ~kReplaceSCCLeftLinear;
  587. }
  588. // Right linear iff all non-terminals are final.
  589. if (fst->Final(arc.nextstate) == Weight::Zero() || arc_in_cycle) {
  590. depsccprops_[depscc] &= ~kReplaceSCCRightLinear;
  591. }
  592. }
  593. }
  594. }
  595. }
  596. } // namespace fst
  597. #endif // FST_REPLACE_UTIL_H_