You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1533 lines
57 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Functions and classes for the recursive replacement of FSTs.
  19. #ifndef FST_REPLACE_H_
  20. #define FST_REPLACE_H_
  21. #include <sys/types.h>
  22. #include <cstddef>
  23. #include <cstdint>
  24. #include <memory>
  25. #include <set>
  26. #include <string>
  27. #include <utility>
  28. #include <vector>
  29. #include <fst/log.h>
  30. #include <fst/arc.h>
  31. #include <fst/bi-table.h>
  32. #include <fst/cache.h>
  33. #include <fst/expanded-fst.h>
  34. #include <fst/float-weight.h>
  35. #include <fst/fst-decl.h> // For optional argument declarations.
  36. #include <fst/fst.h>
  37. #include <fst/impl-to-fst.h>
  38. #include <fst/matcher.h>
  39. #include <fst/mutable-fst.h>
  40. #include <fst/properties.h>
  41. #include <fst/replace-util.h>
  42. #include <fst/state-table.h>
  43. #include <fst/symbol-table.h>
  44. #include <fst/util.h>
  45. #include <unordered_map>
  46. namespace fst {
  47. // Replace state tables have the form:
  48. //
  49. // template <class Arc, class P>
  50. // class ReplaceStateTable {
  51. // public:
// using Label = typename Arc::Label;
  53. // using StateId = typename Arc::StateId;
  54. //
  55. // using PrefixId = P;
  56. // using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
  57. // using StackPrefix = ReplaceStackPrefix<Label, StateId>;
  58. //
  59. // // Required constructor.
  60. // ReplaceStateTable(
  61. // const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_list,
  62. // Label root);
  63. //
  64. // // Required copy constructor that does not copy state.
  65. // ReplaceStateTable(const ReplaceStateTable<Arc, PrefixId> &table);
  66. //
  67. // // Looks up state ID by tuple, adding it if it doesn't exist.
  68. // StateId FindState(const StateTuple &tuple);
  69. //
  70. // // Looks up state tuple by ID.
  71. // const StateTuple &Tuple(StateId id) const;
  72. //
// // Looks up prefix ID by stack prefix, adding it if it doesn't exist.
  74. // PrefixId FindPrefixId(const StackPrefix &stack_prefix);
  75. //
  76. // // Looks up stack prefix by ID.
  77. // const StackPrefix &GetStackPrefix(PrefixId id) const;
  78. // };
  79. // Tuple that uniquely defines a state in replace.
  80. template <class S, class P>
  81. struct ReplaceStateTuple {
  82. using StateId = S;
  83. using PrefixId = P;
  84. ReplaceStateTuple(PrefixId prefix_id = -1, StateId fst_id = kNoStateId,
  85. StateId fst_state = kNoStateId)
  86. : prefix_id(prefix_id), fst_id(fst_id), fst_state(fst_state) {}
  87. template <typename H>
  88. friend H AbslHashValue(H h, const ReplaceStateTuple &t) {
  89. return H::combine(std::move(h), t.prefix_id, t.fst_id, t.fst_state);
  90. }
  91. PrefixId prefix_id; // Index in prefix table.
  92. StateId fst_id; // Current FST being walked.
  93. StateId fst_state; // Current state in FST being walked (not to be
  94. // confused with the thse StateId of the combined FST).
  95. };
  96. // Equality of replace state tuples.
  97. template <class StateId, class PrefixId>
  98. inline bool operator==(const ReplaceStateTuple<StateId, PrefixId> &x,
  99. const ReplaceStateTuple<StateId, PrefixId> &y) {
  100. return x.prefix_id == y.prefix_id && x.fst_id == y.fst_id &&
  101. x.fst_state == y.fst_state;
  102. }
  103. // Functor returning true for tuples corresponding to states in the root FST.
  104. template <class StateId, class PrefixId>
  105. class ReplaceRootSelector {
  106. public:
  107. bool operator()(const ReplaceStateTuple<StateId, PrefixId> &tuple) const {
  108. return tuple.prefix_id == 0;
  109. }
  110. };
  111. // Functor for fingerprinting replace state tuples.
  112. template <class StateId, class PrefixId>
  113. class ReplaceFingerprint {
  114. public:
  115. explicit ReplaceFingerprint(const std::vector<uint64_t> *size_array)
  116. : size_array_(size_array) {}
  117. uint64_t operator()(const ReplaceStateTuple<StateId, PrefixId> &tuple) const {
  118. return tuple.prefix_id * size_array_->back() +
  119. size_array_->at(tuple.fst_id - 1) + tuple.fst_state;
  120. }
  121. private:
  122. const std::vector<uint64_t> *size_array_;
  123. };
  124. // Useful when the fst_state uniquely define the tuple.
  125. template <class StateId, class PrefixId>
  126. class ReplaceFstStateFingerprint {
  127. public:
  128. uint64_t operator()(const ReplaceStateTuple<StateId, PrefixId> &tuple) const {
  129. return tuple.fst_state;
  130. }
  131. };
  132. // A generic hash function for replace state tuples.
  133. template <typename S, typename P>
  134. class ReplaceHash {
  135. public:
  136. size_t operator()(const ReplaceStateTuple<S, P> &t) const {
  137. // We want three prime numbers that are all reasonably large and whose
  138. // differences are far from each other. (E.g., we want prime1-prime0 to be
  139. // far from prime2-prime1). It would be safer still to use large prime
  140. // numbers (i.e., prime numbers that use all 64 bits on a 64-bit machine and
  141. // all 32-bits on a 32-bit machine), so that all 64-bits (respectively
  142. // 32-bits) of the resulting hash would look random. However, these
  143. // modest-sized prime numbers are good enough for hash tables (such as
  144. // std::unordered_set and std::unordered_set) that use the low-order bits
  145. // of the hash.
  146. //
  147. // It is important that all three components are multiplied by a prime
  148. // number. E.g., don't compute
  149. // t.prefix_id + t.fst_id * prime1 + t.fst_state * prime2
  150. // which is just the identity on t.prefix_id. Using the identity will
  151. // result in long probe sequences in open-addressed hash tables (such as
  152. // std::unordered_map).
  153. static constexpr size_t prime0 = 7853;
  154. static constexpr size_t prime1 = 9001;
  155. static constexpr size_t prime2 = 100003;
  156. return t.prefix_id * prime0 + t.fst_id * prime1 + t.fst_state * prime2;
  157. }
  158. };
  159. // Container for stack prefix.
  160. template <class Label, class StateId>
  161. class ReplaceStackPrefix {
  162. public:
  163. struct PrefixTuple {
  164. PrefixTuple(Label fst_id = kNoLabel, StateId nextstate = kNoStateId)
  165. : fst_id(fst_id), nextstate(nextstate) {}
  166. Label fst_id;
  167. StateId nextstate;
  168. };
  169. ReplaceStackPrefix() = default;
  170. ReplaceStackPrefix(const ReplaceStackPrefix &other)
  171. : prefix_(other.prefix_) {}
  172. void Push(StateId fst_id, StateId nextstate) {
  173. prefix_.push_back(PrefixTuple(fst_id, nextstate));
  174. }
  175. void Pop() { prefix_.pop_back(); }
  176. const PrefixTuple &Top() const { return prefix_[prefix_.size() - 1]; }
  177. size_t Depth() const { return prefix_.size(); }
  178. public:
  179. std::vector<PrefixTuple> prefix_;
  180. };
  181. // Equality stack prefix classes.
  182. template <class Label, class StateId>
  183. inline bool operator==(const ReplaceStackPrefix<Label, StateId> &x,
  184. const ReplaceStackPrefix<Label, StateId> &y) {
  185. if (x.prefix_.size() != y.prefix_.size()) return false;
  186. for (size_t i = 0; i < x.prefix_.size(); ++i) {
  187. if (x.prefix_[i].fst_id != y.prefix_[i].fst_id ||
  188. x.prefix_[i].nextstate != y.prefix_[i].nextstate) {
  189. return false;
  190. }
  191. }
  192. return true;
  193. }
  194. // Hash function for stack prefix to prefix id.
  195. template <class Label, class StateId>
  196. class ReplaceStackPrefixHash {
  197. public:
  198. size_t operator()(const ReplaceStackPrefix<Label, StateId> &prefix) const {
  199. size_t sum = 0;
  200. for (const auto &pair : prefix.prefix_) {
  201. static constexpr size_t prime = 7863;
  202. sum += pair.fst_id + pair.nextstate * prime;
  203. }
  204. return sum;
  205. }
  206. };
  207. // Replace state tables.
  208. // A two-level state table for replace. Warning: calls CountStates to compute
  209. // the number of states of each component FST.
  210. template <class Arc, class P = ssize_t>
  211. class VectorHashReplaceStateTable {
  212. public:
  213. using Label = typename Arc::Label;
  214. using StateId = typename Arc::StateId;
  215. using PrefixId = P;
  216. using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
  217. using StateTable =
  218. VectorHashStateTable<ReplaceStateTuple<StateId, PrefixId>,
  219. ReplaceRootSelector<StateId, PrefixId>,
  220. ReplaceFstStateFingerprint<StateId, PrefixId>,
  221. ReplaceFingerprint<StateId, PrefixId>>;
  222. using StackPrefix = ReplaceStackPrefix<Label, StateId>;
  223. using StackPrefixTable =
  224. CompactHashBiTable<PrefixId, StackPrefix,
  225. ReplaceStackPrefixHash<Label, StateId>>;
  226. VectorHashReplaceStateTable(
  227. const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_list,
  228. Label root)
  229. : root_size_(0) {
  230. size_array_.push_back(0);
  231. for (const auto &[label, fst] : fst_list) {
  232. if (label == root) {
  233. root_size_ = CountStates(*fst);
  234. size_array_.push_back(size_array_.back());
  235. } else {
  236. size_array_.push_back(size_array_.back() + CountStates(*fst));
  237. }
  238. }
  239. state_table_ = std::make_unique<StateTable>(
  240. ReplaceRootSelector<StateId, PrefixId>(),
  241. ReplaceFstStateFingerprint<StateId, PrefixId>(),
  242. ReplaceFingerprint<StateId, PrefixId>(&size_array_), root_size_,
  243. root_size_ + size_array_.back());
  244. }
  245. VectorHashReplaceStateTable(
  246. const VectorHashReplaceStateTable<Arc, PrefixId> &table)
  247. : root_size_(table.root_size_),
  248. size_array_(table.size_array_),
  249. prefix_table_(table.prefix_table_) {
  250. state_table_ = std::make_unique<StateTable>(
  251. ReplaceRootSelector<StateId, PrefixId>(),
  252. ReplaceFstStateFingerprint<StateId, PrefixId>(),
  253. ReplaceFingerprint<StateId, PrefixId>(&size_array_), root_size_,
  254. root_size_ + size_array_.back());
  255. }
  256. StateId FindState(const StateTuple &tuple) {
  257. return state_table_->FindState(tuple);
  258. }
  259. const StateTuple &Tuple(StateId id) const { return state_table_->Tuple(id); }
  260. PrefixId FindPrefixId(const StackPrefix &prefix) {
  261. return prefix_table_.FindId(prefix);
  262. }
  263. const StackPrefix &GetStackPrefix(PrefixId id) const {
  264. return prefix_table_.FindEntry(id);
  265. }
  266. private:
  267. StateId root_size_;
  268. std::vector<uint64_t> size_array_;
  269. std::unique_ptr<StateTable> state_table_;
  270. StackPrefixTable prefix_table_;
  271. };
  272. // Default replace state table.
  273. template <class Arc, class P /* = size_t */>
  274. class DefaultReplaceStateTable
  275. : public CompactHashStateTable<ReplaceStateTuple<typename Arc::StateId, P>,
  276. ReplaceHash<typename Arc::StateId, P>> {
  277. public:
  278. using Label = typename Arc::Label;
  279. using StateId = typename Arc::StateId;
  280. using PrefixId = P;
  281. using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
  282. using StateTable =
  283. CompactHashStateTable<StateTuple, ReplaceHash<StateId, PrefixId>>;
  284. using StackPrefix = ReplaceStackPrefix<Label, StateId>;
  285. using StackPrefixTable =
  286. CompactHashBiTable<PrefixId, StackPrefix,
  287. ReplaceStackPrefixHash<Label, StateId>>;
  288. using StateTable::FindState;
  289. using StateTable::Tuple;
  290. DefaultReplaceStateTable(
  291. const std::vector<std::pair<Label, const Fst<Arc> *>> &, Label) {}
  292. DefaultReplaceStateTable(const DefaultReplaceStateTable<Arc, PrefixId> &table)
  293. : StateTable(), prefix_table_(table.prefix_table_) {}
  294. PrefixId FindPrefixId(const StackPrefix &prefix) {
  295. return prefix_table_.FindId(prefix);
  296. }
  297. const StackPrefix &GetStackPrefix(PrefixId id) const {
  298. return prefix_table_.FindEntry(id);
  299. }
  300. private:
  301. StackPrefixTable prefix_table_;
  302. };
  303. // By default ReplaceFst will copy the input label of the replace arc.
  304. // The call_label_type and return_label_type options specify how to manage
  305. // the labels of the call arc and the return arc of the replace FST
  306. template <class Arc, class StateTable = DefaultReplaceStateTable<Arc>,
  307. class CacheStore = DefaultCacheStore<Arc>>
  308. struct ReplaceFstOptions : CacheImplOptions<CacheStore> {
  309. using Label = typename Arc::Label;
  310. // Index of root rule for expansion.
  311. Label root;
  312. // How to label call arc.
  313. ReplaceLabelType call_label_type = REPLACE_LABEL_INPUT;
  314. // How to label return arc.
  315. ReplaceLabelType return_label_type = REPLACE_LABEL_NEITHER;
  316. // Specifies output label to put on call arc; if kNoLabel, use existing label
  317. // on call arc. Otherwise, use this field as the output label.
  318. Label call_output_label = kNoLabel;
  319. // Specifies label to put on return arc.
  320. Label return_label = 0;
  321. // Take ownership of input FSTs?
  322. bool take_ownership = false;
  323. // Pointer to optional pre-constructed state table.
  324. StateTable *state_table = nullptr;
  325. explicit ReplaceFstOptions(const CacheImplOptions<CacheStore> &opts,
  326. Label root = kNoLabel)
  327. : CacheImplOptions<CacheStore>(opts), root(root) {}
  328. explicit ReplaceFstOptions(const CacheOptions &opts, Label root = kNoLabel)
  329. : CacheImplOptions<CacheStore>(opts), root(root) {}
  330. // FIXME(kbg): There are too many constructors here. Come up with a consistent
  331. // position for call_output_label (probably the very end) so that it is
  332. // possible to express all the remaining constructors with a single
  333. // default-argument constructor. Also move clients off of the "backwards
  334. // compatibility" constructor, for good.
  335. explicit ReplaceFstOptions(Label root) : root(root) {}
  336. explicit ReplaceFstOptions(Label root, ReplaceLabelType call_label_type,
  337. ReplaceLabelType return_label_type,
  338. Label return_label)
  339. : root(root),
  340. call_label_type(call_label_type),
  341. return_label_type(return_label_type),
  342. return_label(return_label) {}
  343. explicit ReplaceFstOptions(Label root, ReplaceLabelType call_label_type,
  344. ReplaceLabelType return_label_type,
  345. Label call_output_label, Label return_label)
  346. : root(root),
  347. call_label_type(call_label_type),
  348. return_label_type(return_label_type),
  349. call_output_label(call_output_label),
  350. return_label(return_label) {}
  351. explicit ReplaceFstOptions(const ReplaceUtilOptions &opts)
  352. : ReplaceFstOptions(opts.root, opts.call_label_type,
  353. opts.return_label_type, opts.return_label) {}
  354. ReplaceFstOptions() : root(kNoLabel) {}
  355. // For backwards compatibility.
  356. ReplaceFstOptions(int64_t root, bool epsilon_replace_arc)
  357. : root(root),
  358. call_label_type(epsilon_replace_arc ? REPLACE_LABEL_NEITHER
  359. : REPLACE_LABEL_INPUT),
  360. call_output_label(epsilon_replace_arc ? 0 : kNoLabel) {}
  361. };
  362. // Forward declaration.
  363. template <class Arc, class StateTable, class CacheStore>
  364. class ReplaceFstMatcher;
  365. template <class Arc>
  366. using FstList = std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>;
  367. // Returns true if label type on arc results in epsilon input label.
  368. inline bool EpsilonOnInput(ReplaceLabelType label_type) {
  369. return label_type == REPLACE_LABEL_NEITHER ||
  370. label_type == REPLACE_LABEL_OUTPUT;
  371. }
  372. // Returns true if label type on arc results in epsilon input label.
  373. inline bool EpsilonOnOutput(ReplaceLabelType label_type) {
  374. return label_type == REPLACE_LABEL_NEITHER ||
  375. label_type == REPLACE_LABEL_INPUT;
  376. }
  377. // Returns true if for either the call or return arc ilabel != olabel.
  378. template <class Label>
  379. bool ReplaceTransducer(ReplaceLabelType call_label_type,
  380. ReplaceLabelType return_label_type,
  381. Label call_output_label) {
  382. return call_label_type == REPLACE_LABEL_INPUT ||
  383. call_label_type == REPLACE_LABEL_OUTPUT ||
  384. (call_label_type == REPLACE_LABEL_BOTH &&
  385. call_output_label != kNoLabel) ||
  386. return_label_type == REPLACE_LABEL_INPUT ||
  387. return_label_type == REPLACE_LABEL_OUTPUT;
  388. }
  389. template <class Arc>
  390. uint64_t ReplaceFstProperties(typename Arc::Label root_label,
  391. const FstList<Arc> &fst_list,
  392. ReplaceLabelType call_label_type,
  393. ReplaceLabelType return_label_type,
  394. typename Arc::Label call_output_label,
  395. bool *sorted_and_non_empty) {
  396. using Label = typename Arc::Label;
  397. std::vector<uint64_t> inprops;
  398. bool all_ilabel_sorted = true;
  399. bool all_olabel_sorted = true;
  400. bool all_non_empty = true;
  401. // All nonterminals are negative?
  402. bool all_negative = true;
  403. // All nonterminals are positive and form a dense range containing 1?
  404. bool dense_range = true;
  405. Label root_fst_idx = 0;
  406. for (Label i = 0; i < fst_list.size(); ++i) {
  407. const auto label = fst_list[i].first;
  408. if (label >= 0) all_negative = false;
  409. if (label > fst_list.size() || label <= 0) dense_range = false;
  410. if (label == root_label) root_fst_idx = i;
  411. const auto *fst = fst_list[i].second;
  412. if (fst->Start() == kNoStateId) all_non_empty = false;
  413. if (!fst->Properties(kILabelSorted, false)) all_ilabel_sorted = false;
  414. if (!fst->Properties(kOLabelSorted, false)) all_olabel_sorted = false;
  415. inprops.push_back(fst->Properties(kCopyProperties, false));
  416. }
  417. const auto props = ReplaceProperties(
  418. inprops, root_fst_idx, EpsilonOnInput(call_label_type),
  419. EpsilonOnInput(return_label_type), EpsilonOnOutput(call_label_type),
  420. EpsilonOnOutput(return_label_type),
  421. ReplaceTransducer(call_label_type, return_label_type, call_output_label),
  422. all_non_empty, all_ilabel_sorted, all_olabel_sorted,
  423. all_negative || dense_range);
  424. const bool sorted = props & (kILabelSorted | kOLabelSorted);
  425. *sorted_and_non_empty = all_non_empty && sorted;
  426. return props;
  427. }
  428. namespace internal {
  429. // The replace implementation class supports a dynamic expansion of a recursive
  430. // transition network represented as label/FST pairs with dynamic replacable
  431. // arcs.
  432. template <class Arc, class StateTable, class CacheStore>
  433. class ReplaceFstImpl
  434. : public CacheBaseImpl<typename CacheStore::State, CacheStore> {
  435. public:
  436. using Label = typename Arc::Label;
  437. using StateId = typename Arc::StateId;
  438. using Weight = typename Arc::Weight;
  439. using State = typename CacheStore::State;
  440. using CacheImpl = CacheBaseImpl<State, CacheStore>;
  441. using PrefixId = typename StateTable::PrefixId;
  442. using StateTuple = ReplaceStateTuple<StateId, PrefixId>;
  443. using StackPrefix = ReplaceStackPrefix<Label, StateId>;
  444. using NonTerminalHash = std::unordered_map<Label, Label>;
  445. using FstImpl<Arc>::SetType;
  446. using FstImpl<Arc>::SetProperties;
  447. using FstImpl<Arc>::WriteHeader;
  448. using FstImpl<Arc>::SetInputSymbols;
  449. using FstImpl<Arc>::SetOutputSymbols;
  450. using FstImpl<Arc>::InputSymbols;
  451. using FstImpl<Arc>::OutputSymbols;
  452. using CacheImpl::HasArcs;
  453. using CacheImpl::HasFinal;
  454. using CacheImpl::HasStart;
  455. using CacheImpl::PushArc;
  456. using CacheImpl::SetArcs;
  457. using CacheImpl::SetFinal;
  458. using CacheImpl::SetStart;
  459. friend class ReplaceFstMatcher<Arc, StateTable, CacheStore>;
  460. ReplaceFstImpl(const FstList<Arc> &fst_list,
  461. const ReplaceFstOptions<Arc, StateTable, CacheStore> &opts)
  462. : CacheImpl(opts),
  463. call_label_type_(opts.call_label_type),
  464. return_label_type_(opts.return_label_type),
  465. call_output_label_(opts.call_output_label),
  466. return_label_(opts.return_label),
  467. state_table_(opts.state_table ? opts.state_table
  468. : new StateTable(fst_list, opts.root)) {
  469. SetType("replace");
  470. // If the label is epsilon, then all replace label options are equivalent,
  471. // so we set the label types to NEITHER for simplicity.
  472. if (call_output_label_ == 0) call_label_type_ = REPLACE_LABEL_NEITHER;
  473. if (return_label_ == 0) return_label_type_ = REPLACE_LABEL_NEITHER;
  474. if (!fst_list.empty()) {
  475. SetInputSymbols(fst_list[0].second->InputSymbols());
  476. SetOutputSymbols(fst_list[0].second->OutputSymbols());
  477. }
  478. fst_array_.push_back(nullptr);
  479. for (Label i = 0; i < fst_list.size(); ++i) {
  480. const auto label = fst_list[i].first;
  481. const auto *fst = fst_list[i].second;
  482. nonterminal_hash_[label] = fst_array_.size();
  483. nonterminal_set_.insert(label);
  484. fst_array_.emplace_back(opts.take_ownership ? fst : fst->Copy());
  485. if (i) {
  486. if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) {
  487. FSTERROR() << "ReplaceFstImpl: Input symbols of FST " << i
  488. << " do not match input symbols of base FST (0th FST)";
  489. SetProperties(kError, kError);
  490. }
  491. if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) {
  492. FSTERROR() << "ReplaceFstImpl: Output symbols of FST " << i
  493. << " do not match output symbols of base FST (0th FST)";
  494. SetProperties(kError, kError);
  495. }
  496. }
  497. }
  498. const auto nonterminal = nonterminal_hash_[opts.root];
  499. if ((nonterminal == 0) && (fst_array_.size() > 1)) {
  500. FSTERROR() << "ReplaceFstImpl: No FST corresponding to root label "
  501. << opts.root << " in the input tuple vector";
  502. SetProperties(kError, kError);
  503. }
  504. root_ = (nonterminal > 0) ? nonterminal : 1;
  505. bool all_non_empty_and_sorted = false;
  506. SetProperties(ReplaceFstProperties(opts.root, fst_list, call_label_type_,
  507. return_label_type_, call_output_label_,
  508. &all_non_empty_and_sorted));
  509. // Enables optional caching as long as sorted and all non-empty.
  510. always_cache_ = !all_non_empty_and_sorted;
  511. VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = "
  512. << (always_cache_ ? "true" : "false");
  513. }
  514. ReplaceFstImpl(const ReplaceFstImpl &impl)
  515. : CacheImpl(impl),
  516. call_label_type_(impl.call_label_type_),
  517. return_label_type_(impl.return_label_type_),
  518. call_output_label_(impl.call_output_label_),
  519. return_label_(impl.return_label_),
  520. always_cache_(impl.always_cache_),
  521. state_table_(new StateTable(*(impl.state_table_))),
  522. nonterminal_set_(impl.nonterminal_set_),
  523. nonterminal_hash_(impl.nonterminal_hash_),
  524. root_(impl.root_) {
  525. SetType("replace");
  526. SetProperties(impl.Properties(), kCopyProperties);
  527. SetInputSymbols(impl.InputSymbols());
  528. SetOutputSymbols(impl.OutputSymbols());
  529. fst_array_.reserve(impl.fst_array_.size());
  530. fst_array_.emplace_back(nullptr);
  531. for (Label i = 1; i < impl.fst_array_.size(); ++i) {
  532. fst_array_.emplace_back(impl.fst_array_[i]->Copy(true));
  533. }
  534. }
  535. // Computes the dependency graph of the replace class and returns
  536. // true if the dependencies are cyclic. Cyclic dependencies will result
  537. // in an un-expandable FST.
  538. bool CyclicDependencies() const {
  539. const ReplaceUtilOptions opts(root_);
  540. ReplaceUtil<Arc> replace_util(fst_array_, nonterminal_hash_, opts);
  541. return replace_util.CyclicDependencies();
  542. }
  543. StateId Start() {
  544. if (!HasStart()) {
  545. if (fst_array_.size() == 1) {
  546. SetStart(kNoStateId);
  547. return kNoStateId;
  548. } else {
  549. const auto fst_start = fst_array_[root_]->Start();
  550. if (fst_start == kNoStateId) return kNoStateId;
  551. const auto prefix = GetPrefixId(StackPrefix());
  552. const auto start =
  553. state_table_->FindState(StateTuple(prefix, root_, fst_start));
  554. SetStart(start);
  555. return start;
  556. }
  557. } else {
  558. return CacheImpl::Start();
  559. }
  560. }
  561. Weight Final(StateId s) {
  562. if (HasFinal(s)) return CacheImpl::Final(s);
  563. const auto &tuple = state_table_->Tuple(s);
  564. auto weight = Weight::Zero();
  565. if (tuple.prefix_id == 0) {
  566. const auto fst_state = tuple.fst_state;
  567. weight = fst_array_[tuple.fst_id]->Final(fst_state);
  568. }
  569. if (always_cache_ || HasArcs(s)) SetFinal(s, weight);
  570. return weight;
  571. }
  572. size_t NumArcs(StateId s) {
  573. if (HasArcs(s)) {
  574. return CacheImpl::NumArcs(s);
  575. } else if (always_cache_) { // If always caching, expands and caches state.
  576. Expand(s);
  577. return CacheImpl::NumArcs(s);
  578. } else { // Otherwise computes the number of arcs without expanding.
  579. const auto tuple = state_table_->Tuple(s);
  580. if (tuple.fst_state == kNoStateId) return 0;
  581. auto num_arcs = fst_array_[tuple.fst_id]->NumArcs(tuple.fst_state);
  582. if (ComputeFinalArc(tuple, nullptr)) ++num_arcs;
  583. return num_arcs;
  584. }
  585. }
  586. // Returns whether a given label is a non-terminal.
  587. bool IsNonTerminal(Label label) const {
  588. if (label < *nonterminal_set_.begin() ||
  589. label > *nonterminal_set_.rbegin()) {
  590. return false;
  591. } else {
  592. return nonterminal_hash_.count(label);
  593. }
  594. // TODO(allauzen): be smarter and take advantage of all_dense or
  595. // all_negative. Also use this in ComputeArc. This would require changes to
  596. // Replace so that recursing into an empty FST lead to a non co-accessible
  597. // state instead of deleting the arc as done currently. The current use
  598. // correct, since labels are sorted if all_non_empty is true.
  599. }
  600. size_t NumInputEpsilons(StateId s) {
  601. if (HasArcs(s)) {
  602. return CacheImpl::NumInputEpsilons(s);
  603. } else if (always_cache_ || !Properties(kILabelSorted)) {
  604. // If always caching or if the number of input epsilons is too expensive
  605. // to compute without caching (i.e., not ilabel-sorted), then expands and
  606. // caches state.
  607. Expand(s);
  608. return CacheImpl::NumInputEpsilons(s);
  609. } else {
  610. // Otherwise, computes the number of input epsilons without caching.
  611. const auto tuple = state_table_->Tuple(s);
  612. if (tuple.fst_state == kNoStateId) return 0;
  613. size_t num = 0;
  614. if (!EpsilonOnInput(call_label_type_)) {
  615. // If EpsilonOnInput(c) is false, all input epsilon arcs
  616. // are also input epsilons arcs in the underlying machine.
  617. num = fst_array_[tuple.fst_id]->NumInputEpsilons(tuple.fst_state);
  618. } else {
  619. // Otherwise, one need to consider that all non-terminal arcs
  620. // in the underlying machine also become input epsilon arc.
  621. ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
  622. for (; !aiter.Done() && ((aiter.Value().ilabel == 0) ||
  623. IsNonTerminal(aiter.Value().olabel));
  624. aiter.Next()) {
  625. ++num;
  626. }
  627. }
  628. if (EpsilonOnInput(return_label_type_) &&
  629. ComputeFinalArc(tuple, nullptr)) {
  630. ++num;
  631. }
  632. return num;
  633. }
  634. }
  635. size_t NumOutputEpsilons(StateId s) {
  636. if (HasArcs(s)) {
  637. return CacheImpl::NumOutputEpsilons(s);
  638. } else if (always_cache_ || !Properties(kOLabelSorted)) {
  639. // If always caching or if the number of output epsilons is too expensive
  640. // to compute without caching (i.e., not olabel-sorted), then expands and
  641. // caches state.
  642. Expand(s);
  643. return CacheImpl::NumOutputEpsilons(s);
  644. } else {
  645. // Otherwise, computes the number of output epsilons without caching.
  646. const auto tuple = state_table_->Tuple(s);
  647. if (tuple.fst_state == kNoStateId) return 0;
  648. size_t num = 0;
  649. if (!EpsilonOnOutput(call_label_type_)) {
  650. // If EpsilonOnOutput(c) is false, all output epsilon arcs are also
  651. // output epsilons arcs in the underlying machine.
  652. num = fst_array_[tuple.fst_id]->NumOutputEpsilons(tuple.fst_state);
  653. } else {
  654. // Otherwise, one need to consider that all non-terminal arcs in the
  655. // underlying machine also become output epsilon arc.
  656. ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
  657. for (; !aiter.Done() && ((aiter.Value().olabel == 0) ||
  658. IsNonTerminal(aiter.Value().olabel));
  659. aiter.Next()) {
  660. ++num;
  661. }
  662. }
  663. if (EpsilonOnOutput(return_label_type_) &&
  664. ComputeFinalArc(tuple, nullptr)) {
  665. ++num;
  666. }
  667. return num;
  668. }
  669. }
  670. uint64_t Properties() const override { return Properties(kFstProperties); }
  671. // Sets error if found, and returns other FST impl properties.
  672. uint64_t Properties(uint64_t mask) const override {
  673. if (mask & kError) {
  674. for (Label i = 1; i < fst_array_.size(); ++i) {
  675. if (fst_array_[i]->Properties(kError, false)) {
  676. SetProperties(kError, kError);
  677. }
  678. }
  679. }
  680. return FstImpl<Arc>::Properties(mask);
  681. }
  682. // Returns the base arc iterator, and if arcs have not been computed yet,
  683. // extends and recurses for new arcs.
  684. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
  685. if (!HasArcs(s)) Expand(s);
  686. CacheImpl::InitArcIterator(s, data);
  687. // TODO(allauzen): Set behaviour of generic iterator.
  688. // Warning: ArcIterator<ReplaceFst<A>>::InitCache() relies on current
  689. // behaviour.
  690. }
  691. // Extends current state (walk arcs one level deep).
  692. void Expand(StateId s) {
  693. const auto tuple = state_table_->Tuple(s);
  694. if (tuple.fst_state == kNoStateId) { // Local FST is empty.
  695. SetArcs(s);
  696. return;
  697. }
  698. ArcIterator<Fst<Arc>> aiter(*fst_array_[tuple.fst_id], tuple.fst_state);
  699. Arc arc;
  700. // Creates a final arc when needed.
  701. if (ComputeFinalArc(tuple, &arc)) PushArc(s, std::move(arc));
  702. // Expands all arcs leaving the state.
  703. for (; !aiter.Done(); aiter.Next()) {
  704. if (ComputeArc(tuple, aiter.Value(), &arc)) PushArc(s, std::move(arc));
  705. }
  706. SetArcs(s);
  707. }
  708. void Expand(StateId s, const StateTuple &tuple,
  709. const ArcIteratorData<Arc> &data) {
  710. if (tuple.fst_state == kNoStateId) { // Local FST is empty.
  711. SetArcs(s);
  712. return;
  713. }
  714. ArcIterator<Fst<Arc>> aiter(data);
  715. Arc arc;
  716. // Creates a final arc when needed.
  717. if (ComputeFinalArc(tuple, &arc)) AddArc(s, arc);
  718. // Expands all arcs leaving the state.
  719. for (; !aiter.Done(); aiter.Next()) {
  720. if (ComputeArc(tuple, aiter.Value(), &arc)) AddArc(s, arc);
  721. }
  722. SetArcs(s);
  723. }
  724. // If acpp is null, only returns true if a final arcp is required, but does
  725. // not actually compute it.
  726. bool ComputeFinalArc(const StateTuple &tuple, Arc *arcp,
  727. uint8_t flags = kArcValueFlags) {
  728. const auto fst_state = tuple.fst_state;
  729. if (fst_state == kNoStateId) return false;
  730. // If state is final, pops the stack.
  731. if (fst_array_[tuple.fst_id]->Final(fst_state) != Weight::Zero() &&
  732. tuple.prefix_id) {
  733. if (arcp) {
  734. arcp->ilabel = (EpsilonOnInput(return_label_type_)) ? 0 : return_label_;
  735. arcp->olabel =
  736. (EpsilonOnOutput(return_label_type_)) ? 0 : return_label_;
  737. if (flags & kArcNextStateValue) {
  738. const auto &stack = state_table_->GetStackPrefix(tuple.prefix_id);
  739. const auto prefix_id = PopPrefix(stack);
  740. const auto &top = stack.Top();
  741. arcp->nextstate = state_table_->FindState(
  742. StateTuple(prefix_id, top.fst_id, top.nextstate));
  743. }
  744. if (flags & kArcWeightValue) {
  745. arcp->weight = fst_array_[tuple.fst_id]->Final(fst_state);
  746. }
  747. }
  748. return true;
  749. } else {
  750. return false;
  751. }
  752. }
  753. // Computes an arc in the FST corresponding to one in the underlying machine.
  754. // Returns false if the underlying arc corresponds to no arc in the resulting
  755. // FST.
  756. bool ComputeArc(const StateTuple &tuple, const Arc &arc, Arc *arcp,
  757. uint8_t flags = kArcValueFlags) {
  758. if (!EpsilonOnInput(call_label_type_) &&
  759. (flags == (flags & (kArcILabelValue | kArcWeightValue)))) {
  760. *arcp = arc;
  761. return true;
  762. }
  763. if (arc.olabel == 0 || arc.olabel < *nonterminal_set_.begin() ||
  764. arc.olabel > *nonterminal_set_.rbegin()) { // Expands local FST.
  765. const auto nextstate =
  766. flags & kArcNextStateValue
  767. ? state_table_->FindState(
  768. StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
  769. : kNoStateId;
  770. *arcp = Arc(arc.ilabel, arc.olabel, arc.weight, nextstate);
  771. } else {
  772. // Checks for non-terminal.
  773. if (const auto it = nonterminal_hash_.find(arc.olabel);
  774. it != nonterminal_hash_.end()) { // Recurses into non-terminal.
  775. const auto nonterminal = it->second;
  776. const auto nt_prefix =
  777. PushPrefix(state_table_->GetStackPrefix(tuple.prefix_id),
  778. tuple.fst_id, arc.nextstate);
  779. // If the start state is valid, replace; othewise, the arc is implicitly
  780. // deleted.
  781. const auto nt_start = fst_array_[nonterminal]->Start();
  782. if (nt_start != kNoStateId) {
  783. const auto nt_nextstate = flags & kArcNextStateValue
  784. ? state_table_->FindState(StateTuple(
  785. nt_prefix, nonterminal, nt_start))
  786. : kNoStateId;
  787. const auto ilabel =
  788. (EpsilonOnInput(call_label_type_)) ? 0 : arc.ilabel;
  789. const auto olabel =
  790. (EpsilonOnOutput(call_label_type_))
  791. ? 0
  792. : ((call_output_label_ == kNoLabel) ? arc.olabel
  793. : call_output_label_);
  794. *arcp = Arc(ilabel, olabel, arc.weight, nt_nextstate);
  795. } else {
  796. return false;
  797. }
  798. } else {
  799. const auto nextstate =
  800. flags & kArcNextStateValue
  801. ? state_table_->FindState(
  802. StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
  803. : kNoStateId;
  804. *arcp = Arc(arc.ilabel, arc.olabel, arc.weight, nextstate);
  805. }
  806. }
  807. return true;
  808. }
  809. // Returns the arc iterator flags supported by this FST.
  810. uint8_t ArcIteratorFlags() const {
  811. uint8_t flags = kArcValueFlags;
  812. if (!always_cache_) flags |= kArcNoCache;
  813. return flags;
  814. }
  815. StateTable *GetStateTable() const { return state_table_.get(); }
  816. const Fst<Arc> *GetFst(Label fst_id) const {
  817. return fst_array_[fst_id].get();
  818. }
  819. Label GetFstId(Label nonterminal) const {
  820. const auto it = nonterminal_hash_.find(nonterminal);
  821. if (it == nonterminal_hash_.end()) {
  822. FSTERROR() << "ReplaceFstImpl::GetFstId: Nonterminal not found: "
  823. << nonterminal;
  824. }
  825. return it->second;
  826. }
  827. // Returns true if label type on call arc results in epsilon input label.
  828. bool EpsilonOnCallInput() { return EpsilonOnInput(call_label_type_); }
  829. private:
  830. // The unique index into stack prefix table.
  831. PrefixId GetPrefixId(const StackPrefix &prefix) {
  832. return state_table_->FindPrefixId(prefix);
  833. }
  834. // The prefix ID after a stack pop.
  835. PrefixId PopPrefix(StackPrefix prefix) {
  836. prefix.Pop();
  837. return GetPrefixId(prefix);
  838. }
  839. // The prefix ID after a stack push.
  840. PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
  841. prefix.Push(fst_id, nextstate);
  842. return GetPrefixId(prefix);
  843. }
  // Runtime options.
  // How to label call arcs (consulted through EpsilonOnInput/EpsilonOnOutput
  // in ComputeArc).
  ReplaceLabelType call_label_type_;
  // How to label return arcs (consulted in ComputeFinalArc).
  ReplaceLabelType return_label_type_;
  // Output label put on call arcs when their output side is not epsilon;
  // kNoLabel keeps the arc's own non-terminal label (see ComputeArc).
  int64_t call_output_label_;
  // Label put on each non-epsilon side of return arcs (see ComputeFinalArc).
  int64_t return_label_;
  // If true, ArcIteratorFlags() does not report kArcNoCache, and the
  // ReplaceFst arc iterator then expands (caches) every state it visits.
  bool always_cache_;
  // State table: maps replace-FST states to (prefix ID, FST ID, component
  // state) tuples and resolves stack-prefix IDs (GetStackPrefix/FindPrefixId).
  std::unique_ptr<StateTable> state_table_;
  // Replace components.
  // Ordered set of non-terminal labels. ComputeArc uses its smallest and
  // largest elements as a quick range test before the hash lookup.
  std::set<Label> nonterminal_set_;
  // Maps a non-terminal label to the index of its FST in fst_array_.
  NonTerminalHash nonterminal_hash_;
  // Component FSTs, indexed by FST ID (StateTuple::fst_id).
  std::vector<std::unique_ptr<const Fst<Arc>>> fst_array_;
  // Root non-terminal; presumably initialized from the constructor options
  // (not visible here) -- TODO confirm.
  Label root_;
  857. };
  858. } // namespace internal
  859. //
  860. // ReplaceFst supports dynamic replacement of arcs in one FST with another FST.
  861. // This replacement is recursive. ReplaceFst can be used to support a variety of
  862. // delayed constructions such as recursive
  863. // transition networks, union, or closure. It is constructed with an array of
  864. // FST(s). One FST represents the root (or topology) machine. The root FST
  865. // refers to other FSTs by recursively replacing arcs labeled as non-terminals
  866. // with the matching non-terminal FST. Currently the ReplaceFst uses the output
  867. // symbols of the arcs to determine whether the arc is a non-terminal arc or
  868. // not. A non-terminal can be any label that is not a non-zero terminal label in
  869. // the output alphabet.
  870. //
  871. // Note that the constructor uses a vector of pairs. These correspond to the
  872. // tuple of non-terminal Label and corresponding FST. For example to implement
  873. // the closure operation we need 2 FSTs. The first root FST is a single
  874. // self-loop arc on the start state.
  875. //
  876. // The ReplaceFst class supports an optionally caching arc iterator.
  877. //
  878. // The ReplaceFst needs to be built such that it is known to be ilabel- or
  879. // olabel-sorted (see usage below).
  880. //
  881. // Observe that Matcher<Fst<A>> will use the optionally caching arc iterator
  882. // when available (the FST is ilabel-sorted and matching on the input, or the
  883. // FST is olabel -orted and matching on the output). In order to obtain the
  884. // most efficient behaviour, it is recommended to set call_label_type to
  885. // REPLACE_LABEL_INPUT or REPLACE_LABEL_BOTH and return_label_type to
  886. // REPLACE_LABEL_OUTPUT or REPLACE_LABEL_NEITHER. This means that the call arc
  887. // does not have epsilon on the input side and the return arc has epsilon on the
  888. // input side) and matching on the input side.
  889. //
  890. // This class attaches interface to implementation and handles reference
  891. // counting, delegating most methods to ImplToFst.
  892. template <class A, class T /* = DefaultReplaceStateTable<A> */,
  893. class CacheStore /* = DefaultCacheStore<A> */>
  894. class ReplaceFst
  895. : public ImplToFst<internal::ReplaceFstImpl<A, T, CacheStore>> {
  896. public:
  897. using Arc = A;
  898. using Label = typename Arc::Label;
  899. using StateId = typename Arc::StateId;
  900. using Weight = typename Arc::Weight;
  901. using StateTable = T;
  902. using Store = CacheStore;
  903. using State = typename CacheStore::State;
  904. using Impl = internal::ReplaceFstImpl<Arc, StateTable, CacheStore>;
  905. using CacheImpl = internal::CacheBaseImpl<State, CacheStore>;
  906. using ImplToFst<Impl>::Properties;
  907. friend class ArcIterator<ReplaceFst<Arc, StateTable, CacheStore>>;
  908. friend class StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>;
  909. friend class ReplaceFstMatcher<Arc, StateTable, CacheStore>;
  910. ReplaceFst(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_array,
  911. Label root)
  912. : ImplToFst<Impl>(std::make_shared<Impl>(
  913. fst_array, ReplaceFstOptions<Arc, StateTable, CacheStore>(root))) {}
  914. ReplaceFst(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_array,
  915. const ReplaceFstOptions<Arc, StateTable, CacheStore> &opts)
  916. : ImplToFst<Impl>(std::make_shared<Impl>(fst_array, opts)) {}
  917. // See Fst<>::Copy() for doc.
  918. ReplaceFst(const ReplaceFst &fst, bool safe = false)
  919. : ImplToFst<Impl>(fst, safe) {}
  920. // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc.
  921. ReplaceFst *Copy(bool safe = false) const override {
  922. return new ReplaceFst(*this, safe);
  923. }
  924. inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
  925. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
  926. GetMutableImpl()->InitArcIterator(s, data);
  927. }
  928. MatcherBase<Arc> *InitMatcher(MatchType match_type) const override {
  929. if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
  930. ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) ||
  931. (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) {
  932. return new ReplaceFstMatcher<Arc, StateTable, CacheStore>(this,
  933. match_type);
  934. } else {
  935. VLOG(2) << "Not using replace matcher";
  936. return nullptr;
  937. }
  938. }
  939. bool CyclicDependencies() const { return GetImpl()->CyclicDependencies(); }
  940. const StateTable &GetStateTable() const {
  941. return *GetImpl()->GetStateTable();
  942. }
  943. const Fst<Arc> &GetFst(Label nonterminal) const {
  944. return *GetImpl()->GetFst(GetImpl()->GetFstId(nonterminal));
  945. }
  946. private:
  947. using ImplToFst<Impl>::GetImpl;
  948. using ImplToFst<Impl>::GetMutableImpl;
  949. ReplaceFst &operator=(const ReplaceFst &) = delete;
  950. };
  951. // Specialization for ReplaceFst.
  952. template <class Arc, class StateTable, class CacheStore>
  953. class StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>
  954. : public CacheStateIterator<ReplaceFst<Arc, StateTable, CacheStore>> {
  955. public:
  956. explicit StateIterator(const ReplaceFst<Arc, StateTable, CacheStore> &fst)
  957. : CacheStateIterator<ReplaceFst<Arc, StateTable, CacheStore>>(
  958. fst, fst.GetMutableImpl()) {}
  959. };
  960. // Specialization for ReplaceFst, implementing optional caching. It is be used
  961. // as follows:
  962. //
  963. // ReplaceFst<A> replace;
  964. // ArcIterator<ReplaceFst<A>> aiter(replace, s);
  965. // // Note: ArcIterator< Fst<A>> is always a caching arc iterator.
  966. // aiter.SetFlags(kArcNoCache, kArcNoCache);
  967. // // Uses the arc iterator, no arc will be cached, no state will be expanded.
  968. // // Arc flags can be used to decide which component of the arc need to be
  969. // computed.
  970. // aiter.SetFlags(kArcILabelValue, kArcValueFlags);
  971. // // Wants the ilabel for this arc.
  972. // aiter.Value(); // Does not compute the destination state.
  973. // aiter.Next();
  974. // aiter.SetFlags(kArcNextStateValue, kArcNextStateValue);
  975. // // Wants the ilabel and next state for this arc.
  976. // aiter.Value(); // Does compute the destination state and inserts it
  977. // // in the replace state table.
  978. // // No additional arcs have been cached at this point.
  979. template <class Arc, class StateTable, class CacheStore>
  980. class ArcIterator<ReplaceFst<Arc, StateTable, CacheStore>> {
  981. public:
  982. using StateId = typename Arc::StateId;
  983. using StateTuple = typename StateTable::StateTuple;
  984. ArcIterator(const ReplaceFst<Arc, StateTable, CacheStore> &fst, StateId s)
  985. : fst_(fst),
  986. s_(s),
  987. pos_(0),
  988. offset_(0),
  989. flags_(kArcValueFlags),
  990. arcs_(nullptr),
  991. data_flags_(0),
  992. final_flags_(0) {
  993. cache_data_.ref_count = nullptr;
  994. local_data_.ref_count = nullptr;
  995. // If FST does not support optional caching, forces caching.
  996. if (!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
  997. !(fst_.GetImpl()->HasArcs(s_))) {
  998. fst_.GetMutableImpl()->Expand(s_);
  999. }
  1000. // If state is already cached, use cached arcs array.
  1001. if (fst_.GetImpl()->HasArcs(s_)) {
  1002. (fst_.GetImpl())
  1003. ->internal::template CacheBaseImpl<
  1004. typename CacheStore::State,
  1005. CacheStore>::InitArcIterator(s_, &cache_data_);
  1006. num_arcs_ = cache_data_.narcs;
  1007. arcs_ = cache_data_.arcs; // arcs_ is a pointer to the cached arcs.
  1008. data_flags_ = kArcValueFlags; // All the arc member values are valid.
  1009. } else { // Otherwise delay decision until Value() is called.
  1010. tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(s_);
  1011. if (tuple_.fst_state == kNoStateId) {
  1012. num_arcs_ = 0;
  1013. } else {
  1014. // The decision to cache or not to cache has been defered until Value()
  1015. // or
  1016. // SetFlags() is called. However, the arc iterator is set up now to be
  1017. // ready for non-caching in order to keep the Value() method simple and
  1018. // efficient.
  1019. const auto *rfst = fst_.GetImpl()->GetFst(tuple_.fst_id);
  1020. rfst->InitArcIterator(tuple_.fst_state, &local_data_);
  1021. // arcs_ is a pointer to the arcs in the underlying machine.
  1022. arcs_ = local_data_.arcs;
  1023. // Computes the final arc (but not its destination state) if a final arc
  1024. // is required.
  1025. bool has_final_arc = fst_.GetMutableImpl()->ComputeFinalArc(
  1026. tuple_, &final_arc_, kArcValueFlags & ~kArcNextStateValue);
  1027. // Sets the arc value flags that hold for final_arc_.
  1028. final_flags_ = kArcValueFlags & ~kArcNextStateValue;
  1029. // Computes the number of arcs.
  1030. num_arcs_ = local_data_.narcs;
  1031. if (has_final_arc) ++num_arcs_;
  1032. // Sets the offset between the underlying arc positions and the
  1033. // positions
  1034. // in the arc iterator.
  1035. offset_ = num_arcs_ - local_data_.narcs;
  1036. // Defers the decision to cache or not until Value() or SetFlags() is
  1037. // called.
  1038. data_flags_ = 0;
  1039. }
  1040. }
  1041. }
  1042. ~ArcIterator() {
  1043. if (cache_data_.ref_count) --(*cache_data_.ref_count);
  1044. if (local_data_.ref_count) --(*local_data_.ref_count);
  1045. }
  1046. void ExpandAndCache() const {
  1047. // TODO(allauzen): revisit this.
  1048. // fst_.GetImpl()->Expand(s_, tuple_, local_data_);
  1049. // (fst_.GetImpl())->CacheImpl<A>*>::InitArcIterator(s_,
  1050. // &cache_data_);
  1051. //
  1052. fst_.InitArcIterator(s_, &cache_data_); // Expand and cache state.
  1053. arcs_ = cache_data_.arcs; // arcs_ is a pointer to the cached arcs.
  1054. data_flags_ = kArcValueFlags; // All the arc member values are valid.
  1055. offset_ = 0; // No offset.
  1056. }
  1057. void Init() {
  1058. if (flags_ & kArcNoCache) { // If caching is disabled
  1059. // arcs_ is a pointer to the arcs in the underlying machine.
  1060. arcs_ = local_data_.arcs;
  1061. // Sets the arcs value flags that hold for arcs_.
  1062. data_flags_ = kArcWeightValue;
  1063. if (!fst_.GetMutableImpl()->EpsilonOnCallInput()) {
  1064. data_flags_ |= kArcILabelValue;
  1065. }
  1066. // Sets the offset between the underlying arc positions and the positions
  1067. // in the arc iterator.
  1068. offset_ = num_arcs_ - local_data_.narcs;
  1069. } else {
  1070. ExpandAndCache();
  1071. }
  1072. }
  1073. bool Done() const { return pos_ >= num_arcs_; }
  1074. const Arc &Value() const {
  1075. // If data_flags_ is 0, non-caching was not requested.
  1076. if (!data_flags_) {
  1077. // TODO(allauzen): Revisit this.
  1078. if (flags_ & kArcNoCache) {
  1079. // Should never happen.
  1080. FSTERROR() << "ReplaceFst: Inconsistent arc iterator flags";
  1081. }
  1082. ExpandAndCache();
  1083. }
  1084. if (pos_ - offset_ >= 0) { // The requested arc is not the final arc.
  1085. const auto &arc = arcs_[pos_ - offset_];
  1086. if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) {
  1087. // If the value flags match the recquired value flags then returns the
  1088. // arc.
  1089. return arc;
  1090. } else {
  1091. // Otherwise, compute the corresponding arc on-the-fly.
  1092. fst_.GetMutableImpl()->ComputeArc(tuple_, arc, &arc_,
  1093. flags_ & kArcValueFlags);
  1094. return arc_;
  1095. }
  1096. } else { // The requested arc is the final arc.
  1097. if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) {
  1098. // If the arc value flags that hold for the final arc do not match the
  1099. // requested value flags, then
  1100. // final_arc_ needs to be updated.
  1101. fst_.GetMutableImpl()->ComputeFinalArc(tuple_, &final_arc_,
  1102. flags_ & kArcValueFlags);
  1103. final_flags_ = flags_ & kArcValueFlags;
  1104. }
  1105. return final_arc_;
  1106. }
  1107. }
  1108. void Next() { ++pos_; }
  1109. size_t Position() const { return pos_; }
  1110. void Reset() { pos_ = 0; }
  1111. void Seek(size_t pos) { pos_ = pos; }
  1112. uint8_t Flags() const { return flags_; }
  1113. void SetFlags(uint8_t flags, uint8_t mask) {
  1114. // Updates the flags taking into account what flags are supported
  1115. // by the FST.
  1116. flags_ &= ~mask;
  1117. flags_ |= (flags & fst_.GetImpl()->ArcIteratorFlags());
  1118. // If non-caching is not requested (and caching has not already been
  1119. // performed), then flush data_flags_ to request caching during the next
  1120. // call to Value().
  1121. if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) {
  1122. if (!fst_.GetImpl()->HasArcs(s_)) data_flags_ = 0;
  1123. }
  1124. // If data_flags_ has been flushed but non-caching is requested before
  1125. // calling Value(), then set up the iterator for non-caching.
  1126. if ((flags & kArcNoCache) && (!data_flags_)) Init();
  1127. }
  1128. private:
  1129. const ReplaceFst<Arc, StateTable, CacheStore> &fst_; // Reference to the FST.
  1130. StateId s_; // State in the FST.
  1131. mutable StateTuple tuple_; // Tuple corresponding to state_.
  1132. ssize_t pos_; // Current position.
  1133. mutable ssize_t offset_; // Offset between position in iterator and in arcs_.
  1134. ssize_t num_arcs_; // Number of arcs at state_.
  1135. uint8_t flags_; // Behavorial flags for the arc iterator
  1136. mutable Arc arc_; // Memory to temporarily store computed arcs.
  1137. mutable ArcIteratorData<Arc> cache_data_; // Arc iterator data in cache.
  1138. mutable ArcIteratorData<Arc> local_data_; // Arc iterator data in local FST.
  1139. mutable const Arc *arcs_; // Array of arcs.
  1140. mutable uint8_t data_flags_; // Arc value flags valid for data in arcs_.
  1141. mutable Arc final_arc_; // Final arc (when required).
  1142. mutable uint8_t final_flags_; // Arc value flags valid for final_arc_.
  1143. ArcIterator(const ArcIterator &) = delete;
  1144. ArcIterator &operator=(const ArcIterator &) = delete;
  1145. };
  1146. template <class Arc, class StateTable, class CacheStore>
  1147. class ReplaceFstMatcher : public MatcherBase<Arc> {
  1148. public:
  1149. using Label = typename Arc::Label;
  1150. using StateId = typename Arc::StateId;
  1151. using Weight = typename Arc::Weight;
  1152. using FST = ReplaceFst<Arc, StateTable, CacheStore>;
  1153. using LocalMatcher = MultiEpsMatcher<Matcher<Fst<Arc>>>;
  1154. using StateTuple = typename StateTable::StateTuple;
  1155. // This makes a copy of the FST.
  1156. ReplaceFstMatcher(const ReplaceFst<Arc, StateTable, CacheStore> &fst,
  1157. MatchType match_type)
  1158. : owned_fst_(fst.Copy()),
  1159. fst_(*owned_fst_),
  1160. impl_(fst_.GetMutableImpl()),
  1161. s_(fst::kNoStateId),
  1162. match_type_(match_type),
  1163. current_loop_(false),
  1164. final_arc_(false),
  1165. loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
  1166. if (match_type_ == fst::MATCH_OUTPUT) {
  1167. std::swap(loop_.ilabel, loop_.olabel);
  1168. }
  1169. InitMatchers();
  1170. }
  1171. // This doesn't copy the FST.
  1172. ReplaceFstMatcher(const ReplaceFst<Arc, StateTable, CacheStore> *fst,
  1173. MatchType match_type)
  1174. : fst_(*fst),
  1175. impl_(fst_.GetMutableImpl()),
  1176. s_(fst::kNoStateId),
  1177. match_type_(match_type),
  1178. current_loop_(false),
  1179. final_arc_(false),
  1180. loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
  1181. if (match_type_ == fst::MATCH_OUTPUT) {
  1182. std::swap(loop_.ilabel, loop_.olabel);
  1183. }
  1184. InitMatchers();
  1185. }
  1186. // This makes a copy of the FST.
  1187. ReplaceFstMatcher(const ReplaceFstMatcher &matcher, bool safe = false)
  1188. : owned_fst_(matcher.fst_.Copy(safe)),
  1189. fst_(*owned_fst_),
  1190. impl_(fst_.GetMutableImpl()),
  1191. s_(fst::kNoStateId),
  1192. match_type_(matcher.match_type_),
  1193. current_loop_(false),
  1194. final_arc_(false),
  1195. loop_(fst::kNoLabel, 0, Weight::One(), fst::kNoStateId) {
  1196. if (match_type_ == fst::MATCH_OUTPUT) {
  1197. std::swap(loop_.ilabel, loop_.olabel);
  1198. }
  1199. InitMatchers();
  1200. }
  1201. // Creates a local matcher for each component FST in the RTN. LocalMatcher is
  1202. // a multi-epsilon wrapper matcher. MultiEpsilonMatcher is used to match each
  1203. // non-terminal arc, since these non-terminal
  1204. // turn into epsilons on recursion.
  1205. void InitMatchers() {
  1206. const auto &fst_array = impl_->fst_array_;
  1207. matcher_.resize(fst_array.size());
  1208. for (Label i = 0; i < fst_array.size(); ++i) {
  1209. if (fst_array[i]) {
  1210. matcher_[i] = std::make_unique<LocalMatcher>(*fst_array[i], match_type_,
  1211. kMultiEpsList);
  1212. auto it = impl_->nonterminal_set_.begin();
  1213. for (; it != impl_->nonterminal_set_.end(); ++it) {
  1214. matcher_[i]->AddMultiEpsLabel(*it);
  1215. }
  1216. }
  1217. }
  1218. }
  1219. ReplaceFstMatcher *Copy(bool safe = false) const override {
  1220. return new ReplaceFstMatcher(*this, safe);
  1221. }
  1222. MatchType Type(bool test) const override {
  1223. if (match_type_ == MATCH_NONE) return match_type_;
  1224. const auto true_prop =
  1225. match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted;
  1226. const auto false_prop =
  1227. match_type_ == MATCH_INPUT ? kNotILabelSorted : kNotOLabelSorted;
  1228. const auto props = fst_.Properties(true_prop | false_prop, test);
  1229. if (props & true_prop) {
  1230. return match_type_;
  1231. } else if (props & false_prop) {
  1232. return MATCH_NONE;
  1233. } else {
  1234. return MATCH_UNKNOWN;
  1235. }
  1236. }
  1237. const Fst<Arc> &GetFst() const override { return fst_; }
  1238. uint64_t Properties(uint64_t props) const override { return props; }
  1239. // Sets the state from which our matching happens.
  1240. void SetState(StateId s) final {
  1241. if (s_ == s) return;
  1242. s_ = s;
  1243. tuple_ = impl_->GetStateTable()->Tuple(s_);
  1244. if (tuple_.fst_state == kNoStateId) {
  1245. done_ = true;
  1246. return;
  1247. }
  1248. // Gets current matcher, used for non-epsilon matching.
  1249. current_matcher_ = matcher_[tuple_.fst_id].get();
  1250. current_matcher_->SetState(tuple_.fst_state);
  1251. loop_.nextstate = s_;
  1252. final_arc_ = false;
  1253. }
  1254. // Searches for label from previous set state. If label == 0, first
  1255. // hallucinate an epsilon loop; otherwise use the underlying matcher to
  1256. // search for the label or epsilons. Note since the ReplaceFst recursion
  1257. // on non-terminal arcs causes epsilon transitions to be created we use
  1258. // MultiEpsilonMatcher to search for possible matches of non-terminals. If the
  1259. // component FST
  1260. // reaches a final state we also need to add the exiting final arc.
  1261. bool Find(Label label) final {
  1262. bool found = false;
  1263. label_ = label;
  1264. if (label_ == 0 || label_ == kNoLabel) {
  1265. // Computes loop directly, avoiding Replace::ComputeArc.
  1266. if (label_ == 0) {
  1267. current_loop_ = true;
  1268. found = true;
  1269. }
  1270. // Searches for matching multi-epsilons.
  1271. final_arc_ = impl_->ComputeFinalArc(tuple_, nullptr);
  1272. found = current_matcher_->Find(kNoLabel) || final_arc_ || found;
  1273. } else {
  1274. // Searches on a sub machine directly using sub machine matcher.
  1275. found = current_matcher_->Find(label_);
  1276. }
  1277. return found;
  1278. }
  1279. bool Done() const final {
  1280. return !current_loop_ && !final_arc_ && current_matcher_->Done();
  1281. }
  1282. const Arc &Value() const final {
  1283. if (current_loop_) return loop_;
  1284. if (final_arc_) {
  1285. impl_->ComputeFinalArc(tuple_, &arc_);
  1286. return arc_;
  1287. }
  1288. const auto &component_arc = current_matcher_->Value();
  1289. impl_->ComputeArc(tuple_, component_arc, &arc_);
  1290. return arc_;
  1291. }
  1292. void Next() final {
  1293. if (current_loop_) {
  1294. current_loop_ = false;
  1295. return;
  1296. }
  1297. if (final_arc_) {
  1298. final_arc_ = false;
  1299. return;
  1300. }
  1301. current_matcher_->Next();
  1302. }
  1303. ssize_t Priority(StateId s) final { return fst_.NumArcs(s); }
  1304. private:
  1305. std::unique_ptr<const ReplaceFst<Arc, StateTable, CacheStore>> owned_fst_;
  1306. const ReplaceFst<Arc, StateTable, CacheStore> &fst_;
  1307. internal::ReplaceFstImpl<Arc, StateTable, CacheStore> *impl_;
  1308. LocalMatcher *current_matcher_;
  1309. std::vector<std::unique_ptr<LocalMatcher>> matcher_;
  1310. StateId s_; // Current state.
  1311. Label label_; // Current label.
  1312. MatchType match_type_; // Supplied by caller.
  1313. mutable bool done_;
  1314. mutable bool current_loop_; // Current arc is the implicit loop.
  1315. mutable bool final_arc_; // Current arc for exiting recursion.
  1316. mutable StateTuple tuple_; // Tuple corresponding to state_.
  1317. mutable Arc arc_;
  1318. Arc loop_;
  1319. ReplaceFstMatcher &operator=(const ReplaceFstMatcher &) = delete;
  1320. };
  1321. template <class Arc, class StateTable, class CacheStore>
  1322. inline void ReplaceFst<Arc, StateTable, CacheStore>::InitStateIterator(
  1323. StateIteratorData<Arc> *data) const {
  1324. data->base =
  1325. std::make_unique<StateIterator<ReplaceFst<Arc, StateTable, CacheStore>>>(
  1326. *this);
  1327. }
  1328. using StdReplaceFst = ReplaceFst<StdArc>;
  1329. // Recursively replaces arcs in the root FSTs with other FSTs.
  1330. // This version writes the result of replacement to an output MutableFst.
  1331. //
  1332. // Replace supports replacement of arcs in one Fst with another FST. This
  1333. // replacement is recursive. Replace takes an array of FST(s). One FST
  1334. // represents the root (or topology) machine. The root FST refers to other FSTs
  1335. // by recursively replacing arcs labeled as non-terminals with the matching
  1336. // non-terminal FST. Currently Replace uses the output symbols of the arcs to
  1337. // determine whether the arc is a non-terminal arc or not. A non-terminal can be
  1338. // any label that is not a non-zero terminal label in the output alphabet.
  1339. //
  1340. // Note that input argument is a vector of pairs. These correspond to the tuple
  1341. // of non-terminal Label and corresponding FST.
  1342. template <class Arc>
  1343. void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
  1344. &ifst_array,
  1345. MutableFst<Arc> *ofst,
  1346. ReplaceFstOptions<Arc> opts = ReplaceFstOptions<Arc>()) {
  1347. opts.gc = true;
  1348. opts.gc_limit = 0; // Caches only the last state for fastest copy.
  1349. *ofst = ReplaceFst<Arc>(ifst_array, opts);
  1350. }
  1351. template <class Arc>
  1352. void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
  1353. &ifst_array,
  1354. MutableFst<Arc> *ofst, const ReplaceUtilOptions &opts) {
  1355. Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(opts));
  1356. }
  1357. // For backwards compatibility.
  1358. template <class Arc>
  1359. void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
  1360. &ifst_array,
  1361. MutableFst<Arc> *ofst, typename Arc::Label root,
  1362. bool epsilon_on_replace) {
  1363. Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root, epsilon_on_replace));
  1364. }
  1365. template <class Arc>
  1366. void Replace(const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
  1367. &ifst_array,
  1368. MutableFst<Arc> *ofst, typename Arc::Label root) {
  1369. Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root));
  1370. }
  1371. } // namespace fst
  1372. #endif // FST_REPLACE_H_