You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1582 lines
51 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Classes to allow matching labels leaving FST states.
  19. #ifndef FST_MATCHER_H_
  20. #define FST_MATCHER_H_
  21. #include <sys/types.h>
  22. #include <algorithm>
  23. #include <cstddef>
  24. #include <cstdint>
  25. #include <map>
  26. #include <memory>
  27. #include <optional>
  28. #include <tuple>
  29. #include <unordered_map>
  30. #include <utility>
  31. #include <fst/log.h>
  32. #include <fst/fst.h>
  33. #include <fst/mutable-fst.h> // for all internal FST accessors.
  34. #include <fst/properties.h>
  35. #include <fst/util.h>
  36. #include <unordered_map>
  37. #include <optional>
  38. namespace fst {
  39. // Matchers find and iterate through requested labels at FST states. In the
  40. // simplest form, these are just some associative map or search keyed on labels.
  41. // More generally, they may implement matching special labels that represent
  42. // sets of labels such as sigma (all), rho (rest), or phi (fail). The Matcher
  43. // interface is:
  44. //
  45. // template <class F>
  46. // class Matcher {
  47. // public:
  48. // using FST = F;
  49. // using Arc = typename FST::Arc;
  50. // using Label = typename Arc::Label;
  51. // using StateId = typename Arc::StateId;
  52. // using Weight = typename Arc::Weight;
  53. //
  54. // // Required constructors. Note:
  55. // // -- the constructors that copy the FST arg are useful for
  56. // // letting the matcher manage the FST through copies
  57. // // (esp with 'safe' copies); e.g. ComposeFst depends on this.
  58. // // -- the constructor that does not copy is useful when the
  59. // // FST is mutated during the lifetime of the matcher
  60. // // (o.w. the matcher would have its own unmutated deep copy).
  61. //
  62. // // This makes a copy of the FST.
  63. // Matcher(const FST &fst, MatchType type);
  64. // // This doesn't copy the FST.
  65. // Matcher(const FST *fst, MatchType type);
  66. // // This makes a copy of the FST.
  67. // // See Copy() below.
  68. // Matcher(const Matcher &matcher, bool safe = false);
  69. //
  70. // // If safe = true, the copy is thread-safe. See Fst<>::Copy() for
  71. // // further doc.
  72. // Matcher *Copy(bool safe = false) const override;
  73. //
  74. // // Returns the match type that can be provided (depending on compatibility
  75. // // of the input FST). It is either the requested match type, MATCH_NONE,
  76. // // or MATCH_UNKNOWN. If test is false, costly testing is avoided, but
  77. // // MATCH_UNKNOWN may be returned. If test is true, a definite answer is
  78. // // returned, but obtaining it may involve more costly computation (e.g.,
  79. // // visiting the FST).
  80. // MatchType Type(bool test) const override;
  81. //
  82. // // Specifies the current state.
  83. // void SetState(StateId s) final;
  84. //
  85. // // Finds matches to a label at the current state, returning true if a match
  86. // // is found. kNoLabel matches any non-consuming transitions, e.g., epsilon
  87. // // transitions, which do not require a matching symbol.
  88. // bool Find(Label label) final;
  89. //
  90. // // Iterator methods. Note that initially and after SetState() these have
  91. // // undefined behavior until Find() is called.
  92. //
  93. // bool Done() const final;
  94. //
  95. // const Arc &Value() const final;
  96. //
  97. // void Next() final;
  98. //
  99. // // Returns final weight of a state.
  100. // Weight Final(StateId) const final;
  101. //
  102. // // Indicates preference for being the side used for matching in
  103. // // composition. If the value is kRequirePriority, then it is
  104. // // mandatory that it be used. Calling this method without passing the
  105. // // current state of the matcher invalidates the state of the matcher.
  106. // ssize_t Priority(StateId s) final;
  107. //
  108. // // This specifies the known FST properties as viewed from this matcher. It
  109. // // takes as argument the input FST's known properties.
  110. // uint64_t Properties(uint64_t props) const override;
  111. //
  112. // // Returns matcher flags.
  113. // uint32_t Flags() const override;
  114. //
  115. // // Returns matcher FST.
  116. // const FST &GetFst() const override;
  117. // };
  // Basic matcher flags.

  // Set by a matcher that must be the matching side in composition for at
  // least one state, i.e. one whose Priority() can return kRequirePriority.
  inline constexpr uint32_t kRequireMatch{0x00000001};

  // All flags used by the basic matchers (lookahead.h defines more).
  inline constexpr uint32_t kMatcherFlags{kRequireMatch};

  // Priority value signalling that the matcher is mandatory.
  inline constexpr ssize_t kRequirePriority{-1};
  126. // Matcher interface, templated on the Arc definition; used for matcher
  127. // specializations that are returned by the InitMatcher FST method.
  128. template <class A>
  129. class MatcherBase {
  130. public:
  131. using Arc = A;
  132. using Label = typename Arc::Label;
  133. using StateId = typename Arc::StateId;
  134. using Weight = typename Arc::Weight;
  135. virtual ~MatcherBase() = default;
  136. // Virtual interface.
  137. virtual MatcherBase *Copy(bool safe = false) const = 0;
  138. virtual MatchType Type(bool) const = 0;
  139. virtual void SetState(StateId) = 0;
  140. virtual bool Find(Label) = 0;
  141. virtual bool Done() const = 0;
  142. virtual const Arc &Value() const = 0;
  143. virtual void Next() = 0;
  144. virtual const Fst<Arc> &GetFst() const = 0;
  145. virtual uint64_t Properties(uint64_t) const = 0;
  146. // Trivial implementations that can be used by derived classes. Full
  147. // devirtualization is expected for any derived class marked final.
  148. virtual uint32_t Flags() const { return 0; }
  149. virtual Weight Final(StateId s) const { return internal::Final(GetFst(), s); }
  150. virtual ssize_t Priority(StateId s) { return internal::NumArcs(GetFst(), s); }
  151. };
  152. // A matcher that expects sorted labels on the side to be matched.
  153. // If match_type == MATCH_INPUT, epsilons match the implicit self-loop
  154. // Arc(kNoLabel, 0, Weight::One(), current_state) as well as any
  155. // actual epsilon transitions. If match_type == MATCH_OUTPUT, then
  156. // Arc(0, kNoLabel, Weight::One(), current_state) is instead matched.
  157. template <class F>
  158. class SortedMatcher : public MatcherBase<typename F::Arc> {
  159. public:
  160. using FST = F;
  161. using Arc = typename FST::Arc;
  162. using Label = typename Arc::Label;
  163. using StateId = typename Arc::StateId;
  164. using Weight = typename Arc::Weight;
  165. using MatcherBase<Arc>::Flags;
  166. using MatcherBase<Arc>::Properties;
  167. // Labels >= binary_label will be searched for by binary search;
  168. // o.w. linear search is used.
  169. // This makes a copy of the FST.
  170. SortedMatcher(const FST &fst, MatchType match_type, Label binary_label = 1)
  171. : SortedMatcher(fst.Copy(), match_type, binary_label) {
  172. owned_fst_.reset(&fst_);
  173. }
  174. // Labels >= binary_label will be searched for by binary search;
  175. // o.w. linear search is used.
  176. // This doesn't copy the FST.
  177. SortedMatcher(const FST *fst, MatchType match_type, Label binary_label = 1)
  178. : fst_(*fst),
  179. state_(kNoStateId),
  180. aiter_(std::nullopt),
  181. match_type_(match_type),
  182. binary_label_(binary_label),
  183. match_label_(kNoLabel),
  184. narcs_(0),
  185. loop_(kNoLabel, 0, Weight::One(), kNoStateId),
  186. error_(false) {
  187. switch (match_type_) {
  188. case MATCH_INPUT:
  189. case MATCH_NONE:
  190. break;
  191. case MATCH_OUTPUT:
  192. std::swap(loop_.ilabel, loop_.olabel);
  193. break;
  194. default:
  195. FSTERROR() << "SortedMatcher: Bad match type";
  196. match_type_ = MATCH_NONE;
  197. error_ = true;
  198. }
  199. }
  200. // This makes a copy of the FST.
  201. SortedMatcher(const SortedMatcher &matcher, bool safe = false)
  202. : owned_fst_(matcher.fst_.Copy(safe)),
  203. fst_(*owned_fst_),
  204. state_(kNoStateId),
  205. aiter_(std::nullopt),
  206. match_type_(matcher.match_type_),
  207. binary_label_(matcher.binary_label_),
  208. match_label_(kNoLabel),
  209. narcs_(0),
  210. loop_(matcher.loop_),
  211. error_(matcher.error_) {}
  212. ~SortedMatcher() override = default;
  213. SortedMatcher *Copy(bool safe = false) const override {
  214. return new SortedMatcher(*this, safe);
  215. }
  216. MatchType Type(bool test) const override {
  217. if (match_type_ == MATCH_NONE) return match_type_;
  218. const auto true_prop =
  219. match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted;
  220. const auto false_prop =
  221. match_type_ == MATCH_INPUT ? kNotILabelSorted : kNotOLabelSorted;
  222. const auto props = fst_.Properties(true_prop | false_prop, test);
  223. if (props & true_prop) {
  224. return match_type_;
  225. } else if (props & false_prop) {
  226. return MATCH_NONE;
  227. } else {
  228. return MATCH_UNKNOWN;
  229. }
  230. }
  231. void SetState(StateId s) final {
  232. if (state_ == s) return;
  233. state_ = s;
  234. if (match_type_ == MATCH_NONE) {
  235. FSTERROR() << "SortedMatcher: Bad match type";
  236. error_ = true;
  237. }
  238. aiter_.emplace(fst_, s);
  239. aiter_->SetFlags(kArcNoCache, kArcNoCache);
  240. narcs_ = internal::NumArcs(fst_, s);
  241. loop_.nextstate = s;
  242. }
  243. bool Find(Label match_label) final {
  244. exact_match_ = true;
  245. if (error_) {
  246. current_loop_ = false;
  247. match_label_ = kNoLabel;
  248. return false;
  249. }
  250. current_loop_ = match_label == 0;
  251. match_label_ = match_label == kNoLabel ? 0 : match_label;
  252. if (Search()) {
  253. return true;
  254. } else {
  255. return current_loop_;
  256. }
  257. }
  258. // Positions matcher to the first position where inserting match_label would
  259. // maintain the sort order.
  260. void LowerBound(Label label) {
  261. exact_match_ = false;
  262. current_loop_ = false;
  263. if (error_) {
  264. match_label_ = kNoLabel;
  265. return;
  266. }
  267. match_label_ = label;
  268. Search();
  269. }
  270. // After Find(), returns false if no more exact matches.
  271. // After LowerBound(), returns false if no more arcs.
  272. bool Done() const final {
  273. if (current_loop_) return false;
  274. if (aiter_->Done()) return true;
  275. if (!exact_match_) return false;
  276. aiter_->SetFlags(
  277. match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
  278. kArcValueFlags);
  279. return GetLabel() != match_label_;
  280. }
  281. const Arc &Value() const final {
  282. if (current_loop_) return loop_;
  283. aiter_->SetFlags(kArcValueFlags, kArcValueFlags);
  284. return aiter_->Value();
  285. }
  286. void Next() final {
  287. if (current_loop_) {
  288. current_loop_ = false;
  289. } else {
  290. aiter_->Next();
  291. }
  292. }
  293. Weight Final(StateId s) const final { return MatcherBase<Arc>::Final(s); }
  294. ssize_t Priority(StateId s) final { return MatcherBase<Arc>::Priority(s); }
  295. const FST &GetFst() const override { return fst_; }
  296. uint64_t Properties(uint64_t inprops) const override {
  297. return inprops | (error_ ? kError : 0);
  298. }
  299. size_t Position() const { return aiter_ ? aiter_->Position() : 0; }
  300. private:
  301. Label GetLabel() const {
  302. const auto &arc = aiter_->Value();
  303. return match_type_ == MATCH_INPUT ? arc.ilabel : arc.olabel;
  304. }
  305. bool BinarySearch();
  306. bool LinearSearch();
  307. bool Search();
  308. std::unique_ptr<const FST> owned_fst_; // FST ptr if owned.
  309. const FST &fst_; // FST for matching.
  310. StateId state_; // Matcher state.
  311. mutable std::optional<ArcIterator<FST>>
  312. aiter_; // Iterator for current state.
  313. MatchType match_type_; // Type of match to perform.
  314. Label binary_label_; // Least label for binary search.
  315. Label match_label_; // Current label to be matched.
  316. size_t narcs_; // Current state arc count.
  317. Arc loop_; // For non-consuming symbols.
  318. bool current_loop_; // Current arc is the implicit loop.
  319. bool exact_match_; // Exact match or lower bound?
  320. bool error_; // Error encountered?
  321. };
  322. // Returns true iff match to match_label_. The arc iterator is positioned at the
  323. // lower bound, that is, the first element greater than or equal to
  324. // match_label_, or the end if all elements are less than match_label_.
  325. // If multiple elements are equal to the `match_label_`, returns the rightmost
  326. // one.
  327. template <class FST>
  328. inline bool SortedMatcher<FST>::BinarySearch() {
  329. size_t size = narcs_;
  330. if (size == 0) {
  331. return false;
  332. }
  333. size_t high = size - 1;
  334. while (size > 1) {
  335. const size_t half = size / 2;
  336. const size_t mid = high - half;
  337. aiter_->Seek(mid);
  338. if (GetLabel() >= match_label_) {
  339. high = mid;
  340. }
  341. size -= half;
  342. }
  343. aiter_->Seek(high);
  344. const auto label = GetLabel();
  345. if (label == match_label_) {
  346. return true;
  347. }
  348. if (label < match_label_) {
  349. aiter_->Next();
  350. }
  351. return false;
  352. }
  353. // Returns true iff match to match_label_, positioning arc iterator at lower
  354. // bound.
  355. template <class FST>
  356. inline bool SortedMatcher<FST>::LinearSearch() {
  357. for (aiter_->Reset(); !aiter_->Done(); aiter_->Next()) {
  358. const auto label = GetLabel();
  359. if (label == match_label_) return true;
  360. if (label > match_label_) break;
  361. }
  362. return false;
  363. }
  364. // Returns true iff match to match_label_, positioning arc iterator at lower
  365. // bound.
  366. template <class FST>
  367. inline bool SortedMatcher<FST>::Search() {
  368. aiter_->SetFlags(
  369. match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
  370. kArcValueFlags);
  371. if (match_label_ >= binary_label_) {
  372. return BinarySearch();
  373. } else {
  374. return LinearSearch();
  375. }
  376. }
  377. // A matcher that stores labels in a per-state hash table populated upon the
  378. // first visit to that state. Sorting is not required. Treatment of
  379. // epsilons are the same as with SortedMatcher.
  380. template <class F>
  381. class HashMatcher : public MatcherBase<typename F::Arc> {
  382. public:
  383. using FST = F;
  384. using Arc = typename FST::Arc;
  385. using Label = typename Arc::Label;
  386. using StateId = typename Arc::StateId;
  387. using Weight = typename Arc::Weight;
  388. using MatcherBase<Arc>::Flags;
  389. using MatcherBase<Arc>::Final;
  390. using MatcherBase<Arc>::Priority;
  391. // This makes a copy of the FST.
  392. HashMatcher(const FST &fst, MatchType match_type)
  393. : HashMatcher(fst.Copy(), match_type) {
  394. owned_fst_.reset(&fst_);
  395. }
  396. // This doesn't copy the FST.
  397. HashMatcher(const FST *fst, MatchType match_type)
  398. : fst_(*fst),
  399. state_(kNoStateId),
  400. match_type_(match_type),
  401. loop_(kNoLabel, 0, Weight::One(), kNoStateId),
  402. error_(false),
  403. state_table_(std::make_shared<StateTable>()) {
  404. switch (match_type_) {
  405. case MATCH_INPUT:
  406. case MATCH_NONE:
  407. break;
  408. case MATCH_OUTPUT:
  409. std::swap(loop_.ilabel, loop_.olabel);
  410. break;
  411. default:
  412. FSTERROR() << "HashMatcher: Bad match type";
  413. match_type_ = MATCH_NONE;
  414. error_ = true;
  415. }
  416. }
  417. // This makes a copy of the FST.
  418. HashMatcher(const HashMatcher &matcher, bool safe = false)
  419. : owned_fst_(matcher.fst_.Copy(safe)),
  420. fst_(*owned_fst_),
  421. state_(kNoStateId),
  422. match_type_(matcher.match_type_),
  423. loop_(matcher.loop_),
  424. error_(matcher.error_),
  425. state_table_(safe ? std::make_shared<StateTable>()
  426. : matcher.state_table_) {}
  427. HashMatcher *Copy(bool safe = false) const override {
  428. return new HashMatcher(*this, safe);
  429. }
  430. // The argument is ignored as there are no relevant properties to test.
  431. MatchType Type(bool test) const override { return match_type_; }
  432. void SetState(StateId s) final;
  433. bool Find(Label label) final {
  434. current_loop_ = label == 0;
  435. if (label == 0) {
  436. Search(label);
  437. return true;
  438. }
  439. if (label == kNoLabel) label = 0;
  440. return Search(label);
  441. }
  442. bool Done() const final {
  443. if (current_loop_) return false;
  444. return label_it_ == label_end_;
  445. }
  446. const Arc &Value() const final {
  447. if (current_loop_) return loop_;
  448. aiter_->Seek(label_it_->second);
  449. return aiter_->Value();
  450. }
  451. void Next() final {
  452. if (current_loop_) {
  453. current_loop_ = false;
  454. } else {
  455. ++label_it_;
  456. }
  457. }
  458. const FST &GetFst() const override { return fst_; }
  459. uint64_t Properties(uint64_t inprops) const override {
  460. return inprops | (error_ ? kError : 0);
  461. }
  462. private:
  463. Label GetLabel() const {
  464. const auto &arc = aiter_->Value();
  465. return match_type_ == MATCH_INPUT ? arc.ilabel : arc.olabel;
  466. }
  467. bool Search(Label match_label);
  468. using LabelTable = std::unordered_multimap<Label, size_t>;
  469. using StateTable = std::unordered_map<StateId, std::unique_ptr<LabelTable>>;
  470. std::unique_ptr<const FST> owned_fst_; // ptr to FST if owned.
  471. const FST &fst_; // FST for matching.
  472. StateId state_; // Matcher state.
  473. MatchType match_type_;
  474. Arc loop_; // The implicit loop itself.
  475. bool current_loop_; // Is the current arc the implicit loop?
  476. bool error_; // Error encountered?
  477. std::unique_ptr<ArcIterator<FST>> aiter_;
  478. std::shared_ptr<StateTable> state_table_; // Table from state to label table.
  479. LabelTable *label_table_; // Pointer to current state's label table.
  480. typename LabelTable::iterator label_it_; // Position for label.
  481. typename LabelTable::iterator label_end_; // Position for last label + 1.
  482. };
  483. template <class FST>
  484. void HashMatcher<FST>::SetState(typename FST::Arc::StateId s) {
  485. if (state_ == s) return;
  486. // Resets everything for the state.
  487. state_ = s;
  488. loop_.nextstate = state_;
  489. aiter_ = std::make_unique<ArcIterator<FST>>(fst_, state_);
  490. if (match_type_ == MATCH_NONE) {
  491. FSTERROR() << "HashMatcher: Bad match type";
  492. error_ = true;
  493. }
  494. // Attempts to insert a new label table.
  495. const auto &[it, success] =
  496. state_table_->emplace(state_, std::make_unique<LabelTable>());
  497. // Sets instance's pointer to the label table for this state.
  498. label_table_ = it->second.get();
  499. // If it already exists, no additional work is done and we simply return.
  500. if (!success) return;
  501. // Otherwise, populate this new table.
  502. // Populates the label table.
  503. label_table_->reserve(internal::NumArcs(fst_, state_));
  504. const auto aiter_flags =
  505. (match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue) |
  506. kArcNoCache;
  507. aiter_->SetFlags(aiter_flags, kArcFlags);
  508. for (; !aiter_->Done(); aiter_->Next()) {
  509. label_table_->emplace(GetLabel(), aiter_->Position());
  510. }
  511. aiter_->SetFlags(kArcValueFlags, kArcValueFlags);
  512. }
  513. template <class FST>
  514. inline bool HashMatcher<FST>::Search(typename FST::Arc::Label match_label) {
  515. std::tie(label_it_, label_end_) = label_table_->equal_range(match_label);
  516. if (label_it_ == label_end_) return false;
  517. aiter_->Seek(label_it_->second);
  518. return true;
  519. }
  // Specifies whether a special-label matcher (e.g. PhiMatcher, RhoMatcher)
  // rewrites both the input and output sides of an arc during matching.
  enum MatcherRewriteMode {
    // Rewrites both sides iff the FST is an acceptor.
    MATCHER_REWRITE_AUTO = 0,
    // Always rewrites both sides.
    MATCHER_REWRITE_ALWAYS = 1,
    // Never rewrites both sides; only the matched side is rewritten.
    MATCHER_REWRITE_NEVER = 2
  };
  526. // For any requested label that doesn't match at a state, this matcher
  527. // considers the *unique* transition that matches the label 'phi_label'
  528. // (phi = 'fail'), and recursively looks for a match at its
  529. // destination. When 'phi_loop' is true, if no match is found but a
  530. // phi self-loop is found, then the phi transition found is returned
  531. // with the phi_label rewritten as the requested label (both sides if
  532. // an acceptor, or if 'rewrite_both' is true and both input and output
  533. // labels of the found transition are 'phi_label'). If 'phi_label' is
  534. // kNoLabel, this special matching is not done. PhiMatcher is
  535. // templated itself on a matcher, which is used to perform the
  536. // underlying matching. By default, the underlying matcher is
  537. // constructed by PhiMatcher. The user can instead pass in this
  538. // object; in that case, PhiMatcher takes its ownership.
  539. // Phi non-determinism not supported. No non-consuming symbols other
  540. // than epsilon supported with the underlying template argument matcher.
  541. template <class M>
  542. class PhiMatcher : public MatcherBase<typename M::Arc> {
  543. public:
  544. using FST = typename M::FST;
  545. using Arc = typename FST::Arc;
  546. using Label = typename Arc::Label;
  547. using StateId = typename Arc::StateId;
  548. using Weight = typename Arc::Weight;
  549. // This makes a copy of the FST (w/o 'matcher' arg).
  550. PhiMatcher(const FST &fst, MatchType match_type, Label phi_label = kNoLabel,
  551. bool phi_loop = true,
  552. MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
  553. M *matcher = nullptr)
  554. : matcher_(matcher ? matcher : new M(fst, match_type)),
  555. match_type_(match_type),
  556. phi_label_(phi_label),
  557. state_(kNoStateId),
  558. phi_loop_(phi_loop),
  559. error_(false) {
  560. if (match_type == MATCH_BOTH) {
  561. FSTERROR() << "PhiMatcher: Bad match type";
  562. match_type_ = MATCH_NONE;
  563. error_ = true;
  564. }
  565. if (rewrite_mode == MATCHER_REWRITE_AUTO) {
  566. rewrite_both_ = fst.Properties(kAcceptor, true);
  567. } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) {
  568. rewrite_both_ = true;
  569. } else {
  570. rewrite_both_ = false;
  571. }
  572. }
  573. // This doesn't copy the FST.
  574. PhiMatcher(const FST *fst, MatchType match_type, Label phi_label = kNoLabel,
  575. bool phi_loop = true,
  576. MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
  577. M *matcher = nullptr)
  578. : PhiMatcher(*fst, match_type, phi_label, phi_loop, rewrite_mode,
  579. matcher ? matcher : new M(fst, match_type)) {}
  580. // This makes a copy of the FST.
  581. PhiMatcher(const PhiMatcher &matcher, bool safe = false)
  582. : matcher_(new M(*matcher.matcher_, safe)),
  583. match_type_(matcher.match_type_),
  584. phi_label_(matcher.phi_label_),
  585. rewrite_both_(matcher.rewrite_both_),
  586. state_(kNoStateId),
  587. phi_loop_(matcher.phi_loop_),
  588. error_(matcher.error_) {}
  589. PhiMatcher *Copy(bool safe = false) const override {
  590. return new PhiMatcher(*this, safe);
  591. }
  592. MatchType Type(bool test) const override { return matcher_->Type(test); }
  593. void SetState(StateId s) final {
  594. if (state_ == s) return;
  595. matcher_->SetState(s);
  596. state_ = s;
  597. has_phi_ = phi_label_ != kNoLabel;
  598. }
  599. bool Find(Label match_label) final;
  600. bool Done() const final { return matcher_->Done(); }
  601. const Arc &Value() const final {
  602. if ((phi_match_ == kNoLabel) && (phi_weight_ == Weight::One())) {
  603. return matcher_->Value();
  604. } else if (phi_match_ == 0) { // Virtual epsilon loop.
  605. phi_arc_ = Arc(kNoLabel, 0, Weight::One(), state_);
  606. if (match_type_ == MATCH_OUTPUT) {
  607. std::swap(phi_arc_.ilabel, phi_arc_.olabel);
  608. }
  609. return phi_arc_;
  610. } else {
  611. phi_arc_ = matcher_->Value();
  612. phi_arc_.weight = Times(phi_weight_, phi_arc_.weight);
  613. if (phi_match_ != kNoLabel) { // Phi loop match.
  614. if (rewrite_both_) {
  615. if (phi_arc_.ilabel == phi_label_) phi_arc_.ilabel = phi_match_;
  616. if (phi_arc_.olabel == phi_label_) phi_arc_.olabel = phi_match_;
  617. } else if (match_type_ == MATCH_INPUT) {
  618. phi_arc_.ilabel = phi_match_;
  619. } else {
  620. phi_arc_.olabel = phi_match_;
  621. }
  622. }
  623. return phi_arc_;
  624. }
  625. }
  626. void Next() final { matcher_->Next(); }
  627. Weight Final(StateId s) const final {
  628. auto weight = matcher_->Final(s);
  629. if (phi_label_ == kNoLabel || weight != Weight::Zero()) {
  630. return weight;
  631. }
  632. weight = Weight::One();
  633. matcher_->SetState(s);
  634. while (matcher_->Final(s) == Weight::Zero()) {
  635. if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_)) break;
  636. weight = Times(weight, matcher_->Value().weight);
  637. if (s == matcher_->Value().nextstate) {
  638. return Weight::Zero(); // Does not follow phi self-loops.
  639. }
  640. s = matcher_->Value().nextstate;
  641. matcher_->SetState(s);
  642. }
  643. weight = Times(weight, matcher_->Final(s));
  644. return weight;
  645. }
  646. ssize_t Priority(StateId s) final {
  647. if (phi_label_ != kNoLabel) {
  648. matcher_->SetState(s);
  649. const bool has_phi = matcher_->Find(phi_label_ == 0 ? -1 : phi_label_);
  650. return has_phi ? kRequirePriority : matcher_->Priority(s);
  651. } else {
  652. return matcher_->Priority(s);
  653. }
  654. }
  655. const FST &GetFst() const override { return matcher_->GetFst(); }
  656. uint64_t Properties(uint64_t props) const override;
  657. uint32_t Flags() const override {
  658. if (phi_label_ == kNoLabel || match_type_ == MATCH_NONE) {
  659. return matcher_->Flags();
  660. }
  661. return matcher_->Flags() | kRequireMatch;
  662. }
  663. Label PhiLabel() const { return phi_label_; }
  664. private:
  665. mutable std::unique_ptr<M> matcher_;
  666. MatchType match_type_; // Type of match requested.
  667. Label phi_label_; // Label that represents the phi transition.
  668. bool rewrite_both_; // Rewrite both sides when both are phi_label_?
  669. bool has_phi_; // Are there possibly phis at the current state?
  670. Label phi_match_; // Current label that matches phi loop.
  671. mutable Arc phi_arc_; // Arc to return.
  672. StateId state_; // Matcher state.
  673. Weight phi_weight_; // Product of the weights of phi transitions taken.
  674. bool phi_loop_; // When true, phi self-loop are allowed and treated
  675. // as rho (required for Aho-Corasick).
  676. bool error_; // Error encountered?
  677. PhiMatcher &operator=(const PhiMatcher &) = delete;
  678. };
  679. template <class M>
  680. inline bool PhiMatcher<M>::Find(Label label) {
  681. if (label == phi_label_ && phi_label_ != kNoLabel && phi_label_ != 0) {
  682. FSTERROR() << "PhiMatcher::Find: bad label (phi): " << phi_label_;
  683. error_ = true;
  684. return false;
  685. }
  686. matcher_->SetState(state_);
  687. phi_match_ = kNoLabel;
  688. phi_weight_ = Weight::One();
  689. // If phi_label_ == 0, there are no more true epsilon arcs.
  690. if (phi_label_ == 0) {
  691. if (label == kNoLabel) {
  692. return false;
  693. }
  694. if (label == 0) { // but a virtual epsilon loop needs to be returned.
  695. if (!matcher_->Find(kNoLabel)) {
  696. return matcher_->Find(0);
  697. } else {
  698. phi_match_ = 0;
  699. return true;
  700. }
  701. }
  702. }
  703. if (!has_phi_ || label == 0 || label == kNoLabel) {
  704. return matcher_->Find(label);
  705. }
  706. auto s = state_;
  707. while (!matcher_->Find(label)) {
  708. // Look for phi transition (if phi_label_ == 0, we need to look
  709. // for -1 to avoid getting the virtual self-loop)
  710. if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_)) return false;
  711. if (phi_loop_ && matcher_->Value().nextstate == s) {
  712. phi_match_ = label;
  713. return true;
  714. }
  715. phi_weight_ = Times(phi_weight_, matcher_->Value().weight);
  716. s = matcher_->Value().nextstate;
  717. matcher_->Next();
  718. if (!matcher_->Done()) {
  719. FSTERROR() << "PhiMatcher: Phi non-determinism not supported";
  720. error_ = true;
  721. }
  722. matcher_->SetState(s);
  723. }
  724. return true;
  725. }
  726. template <class M>
  727. inline uint64_t PhiMatcher<M>::Properties(uint64_t inprops) const {
  728. auto outprops = matcher_->Properties(inprops);
  729. if (error_) outprops |= kError;
  730. if (match_type_ == MATCH_NONE) {
  731. return outprops;
  732. } else if (match_type_ == MATCH_INPUT) {
  733. if (phi_label_ == 0) {
  734. outprops &= ~(kEpsilons | kIEpsilons | kOEpsilons);
  735. outprops |= kNoEpsilons | kNoIEpsilons;
  736. }
  737. if (rewrite_both_) {
  738. return outprops &
  739. ~(kODeterministic | kNonODeterministic | kString | kILabelSorted |
  740. kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
  741. } else {
  742. return outprops &
  743. ~(kODeterministic | kAcceptor | kString | kILabelSorted |
  744. kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
  745. }
  746. } else if (match_type_ == MATCH_OUTPUT) {
  747. if (phi_label_ == 0) {
  748. outprops &= ~(kEpsilons | kIEpsilons | kOEpsilons);
  749. outprops |= kNoEpsilons | kNoOEpsilons;
  750. }
  751. if (rewrite_both_) {
  752. return outprops &
  753. ~(kIDeterministic | kNonIDeterministic | kString | kILabelSorted |
  754. kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
  755. } else {
  756. return outprops &
  757. ~(kIDeterministic | kAcceptor | kString | kILabelSorted |
  758. kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
  759. }
  760. } else {
  761. // Shouldn't ever get here.
  762. FSTERROR() << "PhiMatcher: Bad match type: " << match_type_;
  763. return 0;
  764. }
  765. }
  766. // For any requested label that doesn't match at a state, this matcher
  767. // considers all transitions that match the label 'rho_label' (rho =
  768. // 'rest'). Each such rho transition found is returned with the
  769. // rho_label rewritten as the requested label (both sides if an
  770. // acceptor, or if 'rewrite_both' is true and both input and output
  771. // labels of the found transition are 'rho_label'). If 'rho_label' is
  772. // kNoLabel, this special matching is not done. RhoMatcher is
  773. // templated itself on a matcher, which is used to perform the
  774. // underlying matching. By default, the underlying matcher is
  775. // constructed by RhoMatcher. The user can instead pass in this
  776. // object; in that case, RhoMatcher takes its ownership.
  777. // No non-consuming symbols other than epsilon supported with
  778. // the underlying template argument matcher.
  779. template <class M>
  780. class RhoMatcher : public MatcherBase<typename M::Arc> {
  781. public:
  782. using FST = typename M::FST;
  783. using Arc = typename FST::Arc;
  784. using Label = typename Arc::Label;
  785. using StateId = typename Arc::StateId;
  786. using Weight = typename Arc::Weight;
  787. // This makes a copy of the FST (w/o 'matcher' arg).
  788. RhoMatcher(const FST &fst, MatchType match_type, Label rho_label = kNoLabel,
  789. MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
  790. M *matcher = nullptr)
  791. : matcher_(matcher ? matcher : new M(fst, match_type)),
  792. match_type_(match_type),
  793. rho_label_(rho_label),
  794. error_(false),
  795. state_(kNoStateId),
  796. has_rho_(false) {
  797. if (match_type == MATCH_BOTH) {
  798. FSTERROR() << "RhoMatcher: Bad match type";
  799. match_type_ = MATCH_NONE;
  800. error_ = true;
  801. }
  802. if (rho_label == 0) {
  803. FSTERROR() << "RhoMatcher: 0 cannot be used as rho_label";
  804. rho_label_ = kNoLabel;
  805. error_ = true;
  806. }
  807. if (rewrite_mode == MATCHER_REWRITE_AUTO) {
  808. rewrite_both_ = fst.Properties(kAcceptor, true);
  809. } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) {
  810. rewrite_both_ = true;
  811. } else {
  812. rewrite_both_ = false;
  813. }
  814. }
  815. // This doesn't copy the FST.
  816. RhoMatcher(const FST *fst, MatchType match_type, Label rho_label = kNoLabel,
  817. MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
  818. M *matcher = nullptr)
  819. : RhoMatcher(*fst, match_type, rho_label, rewrite_mode,
  820. matcher ? matcher : new M(fst, match_type)) {}
  821. // This makes a copy of the FST.
  822. RhoMatcher(const RhoMatcher &matcher, bool safe = false)
  823. : matcher_(new M(*matcher.matcher_, safe)),
  824. match_type_(matcher.match_type_),
  825. rho_label_(matcher.rho_label_),
  826. rewrite_both_(matcher.rewrite_both_),
  827. error_(matcher.error_),
  828. state_(kNoStateId),
  829. has_rho_(false) {}
  830. RhoMatcher *Copy(bool safe = false) const override {
  831. return new RhoMatcher(*this, safe);
  832. }
  833. MatchType Type(bool test) const override { return matcher_->Type(test); }
  834. void SetState(StateId s) final {
  835. if (state_ == s) return;
  836. state_ = s;
  837. matcher_->SetState(s);
  838. has_rho_ = rho_label_ != kNoLabel;
  839. }
  840. bool Find(Label label) final {
  841. if (label == rho_label_ && rho_label_ != kNoLabel) {
  842. FSTERROR() << "RhoMatcher::Find: bad label (rho)";
  843. error_ = true;
  844. return false;
  845. }
  846. if (matcher_->Find(label)) {
  847. rho_match_ = kNoLabel;
  848. return true;
  849. } else if (has_rho_ && label != 0 && label != kNoLabel &&
  850. (has_rho_ = matcher_->Find(rho_label_))) {
  851. rho_match_ = label;
  852. return true;
  853. } else {
  854. return false;
  855. }
  856. }
  857. bool Done() const final { return matcher_->Done(); }
  858. const Arc &Value() const final {
  859. if (rho_match_ == kNoLabel) {
  860. return matcher_->Value();
  861. } else {
  862. rho_arc_ = matcher_->Value();
  863. if (rewrite_both_) {
  864. if (rho_arc_.ilabel == rho_label_) rho_arc_.ilabel = rho_match_;
  865. if (rho_arc_.olabel == rho_label_) rho_arc_.olabel = rho_match_;
  866. } else if (match_type_ == MATCH_INPUT) {
  867. rho_arc_.ilabel = rho_match_;
  868. } else {
  869. rho_arc_.olabel = rho_match_;
  870. }
  871. return rho_arc_;
  872. }
  873. }
  874. void Next() final { matcher_->Next(); }
  875. Weight Final(StateId s) const final { return matcher_->Final(s); }
  876. ssize_t Priority(StateId s) final {
  877. state_ = s;
  878. matcher_->SetState(s);
  879. has_rho_ = matcher_->Find(rho_label_);
  880. if (has_rho_) {
  881. return kRequirePriority;
  882. } else {
  883. return matcher_->Priority(s);
  884. }
  885. }
  886. const FST &GetFst() const override { return matcher_->GetFst(); }
  887. uint64_t Properties(uint64_t props) const override;
  888. uint32_t Flags() const override {
  889. if (rho_label_ == kNoLabel || match_type_ == MATCH_NONE) {
  890. return matcher_->Flags();
  891. }
  892. return matcher_->Flags() | kRequireMatch;
  893. }
  894. Label RhoLabel() const { return rho_label_; }
  895. private:
  896. std::unique_ptr<M> matcher_;
  897. MatchType match_type_; // Type of match requested.
  898. Label rho_label_; // Label that represents the rho transition
  899. bool rewrite_both_; // Rewrite both sides when both are rho_label_?
  900. Label rho_match_; // Current label that matches rho transition.
  901. mutable Arc rho_arc_; // Arc to return when rho match.
  902. bool error_; // Error encountered?
  903. StateId state_; // Matcher state.
  904. bool has_rho_; // Are there possibly rhos at the current state?
  905. };
  906. template <class M>
  907. inline uint64_t RhoMatcher<M>::Properties(uint64_t inprops) const {
  908. auto outprops = matcher_->Properties(inprops);
  909. if (error_) outprops |= kError;
  910. if (match_type_ == MATCH_NONE) {
  911. return outprops;
  912. } else if (match_type_ == MATCH_INPUT) {
  913. if (rewrite_both_) {
  914. return outprops &
  915. ~(kODeterministic | kNonODeterministic | kString | kILabelSorted |
  916. kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
  917. } else {
  918. return outprops & ~(kODeterministic | kAcceptor | kString |
  919. kILabelSorted | kNotILabelSorted);
  920. }
  921. } else if (match_type_ == MATCH_OUTPUT) {
  922. if (rewrite_both_) {
  923. return outprops &
  924. ~(kIDeterministic | kNonIDeterministic | kString | kILabelSorted |
  925. kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
  926. } else {
  927. return outprops & ~(kIDeterministic | kAcceptor | kString |
  928. kOLabelSorted | kNotOLabelSorted);
  929. }
  930. } else {
  931. // Shouldn't ever get here.
  932. FSTERROR() << "RhoMatcher: Bad match type: " << match_type_;
  933. return 0;
  934. }
  935. }
  936. // For any requested label, this matcher considers all transitions
  937. // that match the label 'sigma_label' (sigma = "any"), and this in
  938. // additions to transitions with the requested label. Each such sigma
  939. // transition found is returned with the sigma_label rewritten as the
  940. // requested label (both sides if an acceptor, or if 'rewrite_both' is
  941. // true and both input and output labels of the found transition are
  942. // 'sigma_label'). If 'sigma_label' is kNoLabel, this special
  943. // matching is not done. SigmaMatcher is templated itself on a
  944. // matcher, which is used to perform the underlying matching. By
  945. // default, the underlying matcher is constructed by SigmaMatcher.
  946. // The user can instead pass in this object; in that case,
  947. // SigmaMatcher takes its ownership. No non-consuming symbols other
  948. // than epsilon supported with the underlying template argument matcher.
  949. template <class M>
  950. class SigmaMatcher : public MatcherBase<typename M::Arc> {
  951. public:
  952. using FST = typename M::FST;
  953. using Arc = typename FST::Arc;
  954. using Label = typename Arc::Label;
  955. using StateId = typename Arc::StateId;
  956. using Weight = typename Arc::Weight;
  957. // This makes a copy of the FST (w/o 'matcher' arg).
  958. SigmaMatcher(const FST &fst, MatchType match_type,
  959. Label sigma_label = kNoLabel,
  960. MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
  961. M *matcher = nullptr)
  962. : matcher_(matcher ? matcher : new M(fst, match_type)),
  963. match_type_(match_type),
  964. sigma_label_(sigma_label),
  965. error_(false),
  966. state_(kNoStateId) {
  967. if (match_type == MATCH_BOTH) {
  968. FSTERROR() << "SigmaMatcher: Bad match type";
  969. match_type_ = MATCH_NONE;
  970. error_ = true;
  971. }
  972. if (sigma_label == 0) {
  973. FSTERROR() << "SigmaMatcher: 0 cannot be used as sigma_label";
  974. sigma_label_ = kNoLabel;
  975. error_ = true;
  976. }
  977. if (rewrite_mode == MATCHER_REWRITE_AUTO) {
  978. rewrite_both_ = fst.Properties(kAcceptor, true);
  979. } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) {
  980. rewrite_both_ = true;
  981. } else {
  982. rewrite_both_ = false;
  983. }
  984. }
  985. // This doesn't copy the FST.
  986. SigmaMatcher(const FST *fst, MatchType match_type,
  987. Label sigma_label = kNoLabel,
  988. MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
  989. M *matcher = nullptr)
  990. : SigmaMatcher(*fst, match_type, sigma_label, rewrite_mode,
  991. matcher ? matcher : new M(fst, match_type)) {}
  992. // This makes a copy of the FST.
  993. SigmaMatcher(const SigmaMatcher &matcher, bool safe = false)
  994. : matcher_(new M(*matcher.matcher_, safe)),
  995. match_type_(matcher.match_type_),
  996. sigma_label_(matcher.sigma_label_),
  997. rewrite_both_(matcher.rewrite_both_),
  998. error_(matcher.error_),
  999. state_(kNoStateId) {}
  1000. SigmaMatcher *Copy(bool safe = false) const override {
  1001. return new SigmaMatcher(*this, safe);
  1002. }
  1003. MatchType Type(bool test) const override { return matcher_->Type(test); }
  1004. void SetState(StateId s) final {
  1005. if (state_ == s) return;
  1006. state_ = s;
  1007. matcher_->SetState(s);
  1008. has_sigma_ =
  1009. (sigma_label_ != kNoLabel) ? matcher_->Find(sigma_label_) : false;
  1010. }
  1011. bool Find(Label match_label) final {
  1012. match_label_ = match_label;
  1013. if (match_label == sigma_label_ && sigma_label_ != kNoLabel) {
  1014. FSTERROR() << "SigmaMatcher::Find: bad label (sigma)";
  1015. error_ = true;
  1016. return false;
  1017. }
  1018. if (matcher_->Find(match_label)) {
  1019. sigma_match_ = kNoLabel;
  1020. return true;
  1021. } else if (has_sigma_ && match_label != 0 && match_label != kNoLabel &&
  1022. matcher_->Find(sigma_label_)) {
  1023. sigma_match_ = match_label;
  1024. return true;
  1025. } else {
  1026. return false;
  1027. }
  1028. }
  1029. bool Done() const final { return matcher_->Done(); }
  1030. const Arc &Value() const final {
  1031. if (sigma_match_ == kNoLabel) {
  1032. return matcher_->Value();
  1033. } else {
  1034. sigma_arc_ = matcher_->Value();
  1035. if (rewrite_both_) {
  1036. if (sigma_arc_.ilabel == sigma_label_) sigma_arc_.ilabel = sigma_match_;
  1037. if (sigma_arc_.olabel == sigma_label_) sigma_arc_.olabel = sigma_match_;
  1038. } else if (match_type_ == MATCH_INPUT) {
  1039. sigma_arc_.ilabel = sigma_match_;
  1040. } else {
  1041. sigma_arc_.olabel = sigma_match_;
  1042. }
  1043. return sigma_arc_;
  1044. }
  1045. }
  1046. void Next() final {
  1047. matcher_->Next();
  1048. if (matcher_->Done() && has_sigma_ && (sigma_match_ == kNoLabel) &&
  1049. (match_label_ > 0)) {
  1050. matcher_->Find(sigma_label_);
  1051. sigma_match_ = match_label_;
  1052. }
  1053. }
  1054. Weight Final(StateId s) const final { return matcher_->Final(s); }
  1055. ssize_t Priority(StateId s) final {
  1056. if (sigma_label_ != kNoLabel) {
  1057. SetState(s);
  1058. return has_sigma_ ? kRequirePriority : matcher_->Priority(s);
  1059. } else {
  1060. return matcher_->Priority(s);
  1061. }
  1062. }
  1063. const FST &GetFst() const override { return matcher_->GetFst(); }
  1064. uint64_t Properties(uint64_t props) const override;
  1065. uint32_t Flags() const override {
  1066. if (sigma_label_ == kNoLabel || match_type_ == MATCH_NONE) {
  1067. return matcher_->Flags();
  1068. }
  1069. return matcher_->Flags() | kRequireMatch;
  1070. }
  1071. Label SigmaLabel() const { return sigma_label_; }
  1072. private:
  1073. std::unique_ptr<M> matcher_;
  1074. MatchType match_type_; // Type of match requested.
  1075. Label sigma_label_; // Label that represents the sigma transition.
  1076. bool rewrite_both_; // Rewrite both sides when both are sigma_label_?
  1077. bool has_sigma_; // Are there sigmas at the current state?
  1078. Label sigma_match_; // Current label that matches sigma transition.
  1079. mutable Arc sigma_arc_; // Arc to return when sigma match.
  1080. Label match_label_; // Label being matched.
  1081. bool error_; // Error encountered?
  1082. StateId state_; // Matcher state.
  1083. };
  1084. template <class M>
  1085. inline uint64_t SigmaMatcher<M>::Properties(uint64_t inprops) const {
  1086. auto outprops = matcher_->Properties(inprops);
  1087. if (error_) outprops |= kError;
  1088. if (match_type_ == MATCH_NONE) {
  1089. return outprops;
  1090. } else if (rewrite_both_) {
  1091. return outprops & ~(kIDeterministic | kNonIDeterministic | kODeterministic |
  1092. kNonODeterministic | kILabelSorted | kNotILabelSorted |
  1093. kOLabelSorted | kNotOLabelSorted | kString);
  1094. } else if (match_type_ == MATCH_INPUT) {
  1095. return outprops & ~(kIDeterministic | kNonIDeterministic | kODeterministic |
  1096. kNonODeterministic | kILabelSorted | kNotILabelSorted |
  1097. kString | kAcceptor);
  1098. } else if (match_type_ == MATCH_OUTPUT) {
  1099. return outprops & ~(kIDeterministic | kNonIDeterministic | kODeterministic |
  1100. kNonODeterministic | kOLabelSorted | kNotOLabelSorted |
  1101. kString | kAcceptor);
  1102. } else {
  1103. // Shouldn't ever get here.
  1104. FSTERROR() << "SigmaMatcher: Bad match type: " << match_type_;
  1105. return 0;
  1106. }
  1107. }
  1108. // Flags for MultiEpsMatcher.
  1109. // Return multi-epsilon arcs for Find(kNoLabel).
  1110. inline constexpr uint32_t kMultiEpsList = 0x00000001;
  1111. // Return a kNolabel loop for Find(multi_eps).
  1112. inline constexpr uint32_t kMultiEpsLoop = 0x00000002;
  1113. // MultiEpsMatcher: allows treating multiple non-0 labels as
  1114. // non-consuming labels in addition to 0 that is always
  1115. // non-consuming. Precise behavior controlled by 'flags' argument. By
  1116. // default, the underlying matcher is constructed by
  1117. // MultiEpsMatcher. The user can instead pass in this object; in that
  1118. // case, MultiEpsMatcher takes its ownership iff 'own_matcher' is
  1119. // true.
  1120. template <class M>
  1121. class MultiEpsMatcher {
  1122. public:
  1123. using FST = typename M::FST;
  1124. using Arc = typename FST::Arc;
  1125. using Label = typename Arc::Label;
  1126. using StateId = typename Arc::StateId;
  1127. using Weight = typename Arc::Weight;
  1128. // This makes a copy of the FST (w/o 'matcher' arg).
  1129. MultiEpsMatcher(const FST &fst, MatchType match_type,
  1130. uint32_t flags = (kMultiEpsLoop | kMultiEpsList),
  1131. M *matcher = nullptr, bool own_matcher = true)
  1132. : matcher_(matcher ? matcher : new M(fst, match_type)),
  1133. flags_(flags),
  1134. own_matcher_(matcher ? own_matcher : true) {
  1135. Init(match_type);
  1136. }
  1137. // This doesn't copy the FST.
  1138. MultiEpsMatcher(const FST *fst, MatchType match_type,
  1139. uint32_t flags = (kMultiEpsLoop | kMultiEpsList),
  1140. M *matcher = nullptr, bool own_matcher = true)
  1141. : matcher_(matcher ? matcher : new M(fst, match_type)),
  1142. flags_(flags),
  1143. own_matcher_(matcher ? own_matcher : true) {
  1144. Init(match_type);
  1145. }
  1146. // This makes a copy of the FST.
  1147. MultiEpsMatcher(const MultiEpsMatcher &matcher, bool safe = false)
  1148. : matcher_(new M(*matcher.matcher_, safe)),
  1149. flags_(matcher.flags_),
  1150. own_matcher_(true),
  1151. multi_eps_labels_(matcher.multi_eps_labels_),
  1152. loop_(matcher.loop_) {
  1153. loop_.nextstate = kNoStateId;
  1154. }
  1155. ~MultiEpsMatcher() {
  1156. if (own_matcher_) delete matcher_;
  1157. }
  1158. MultiEpsMatcher *Copy(bool safe = false) const {
  1159. return new MultiEpsMatcher(*this, safe);
  1160. }
  1161. MatchType Type(bool test) const { return matcher_->Type(test); }
  1162. void SetState(StateId state) {
  1163. matcher_->SetState(state);
  1164. loop_.nextstate = state;
  1165. }
  1166. bool Find(Label label);
  1167. bool Done() const { return done_; }
  1168. const Arc &Value() const { return current_loop_ ? loop_ : matcher_->Value(); }
  1169. void Next() {
  1170. if (!current_loop_) {
  1171. matcher_->Next();
  1172. done_ = matcher_->Done();
  1173. if (done_ && multi_eps_iter_ != multi_eps_labels_.End()) {
  1174. ++multi_eps_iter_;
  1175. while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
  1176. !matcher_->Find(*multi_eps_iter_)) {
  1177. ++multi_eps_iter_;
  1178. }
  1179. if (multi_eps_iter_ != multi_eps_labels_.End()) {
  1180. done_ = false;
  1181. } else {
  1182. done_ = !matcher_->Find(kNoLabel);
  1183. }
  1184. }
  1185. } else {
  1186. done_ = true;
  1187. }
  1188. }
  1189. const FST &GetFst() const { return matcher_->GetFst(); }
  1190. uint64_t Properties(uint64_t props) const {
  1191. return matcher_->Properties(props);
  1192. }
  1193. const M *GetMatcher() const { return matcher_; }
  1194. Weight Final(StateId s) const { return matcher_->Final(s); }
  1195. uint32_t Flags() const { return matcher_->Flags(); }
  1196. ssize_t Priority(StateId s) { return matcher_->Priority(s); }
  1197. void AddMultiEpsLabel(Label label) {
  1198. if (label == 0) {
  1199. FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
  1200. } else {
  1201. multi_eps_labels_.Insert(label);
  1202. }
  1203. }
  1204. void RemoveMultiEpsLabel(Label label) {
  1205. if (label == 0) {
  1206. FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
  1207. } else {
  1208. multi_eps_labels_.Erase(label);
  1209. }
  1210. }
  1211. void ClearMultiEpsLabels() { multi_eps_labels_.Clear(); }
  1212. private:
  1213. void Init(MatchType match_type) {
  1214. if (match_type == MATCH_INPUT) {
  1215. loop_.ilabel = kNoLabel;
  1216. loop_.olabel = 0;
  1217. } else {
  1218. loop_.ilabel = 0;
  1219. loop_.olabel = kNoLabel;
  1220. }
  1221. loop_.weight = Weight::One();
  1222. loop_.nextstate = kNoStateId;
  1223. }
  1224. M *matcher_;
  1225. uint32_t flags_;
  1226. bool own_matcher_; // Does this class delete the matcher?
  1227. // Multi-eps label set.
  1228. CompactSet<Label, kNoLabel> multi_eps_labels_;
  1229. typename CompactSet<Label, kNoLabel>::const_iterator multi_eps_iter_;
  1230. bool current_loop_; // Current arc is the implicit loop?
  1231. mutable Arc loop_; // For non-consuming symbols.
  1232. bool done_; // Matching done?
  1233. MultiEpsMatcher &operator=(const MultiEpsMatcher &) = delete;
  1234. };
  1235. template <class M>
  1236. inline bool MultiEpsMatcher<M>::Find(Label label) {
  1237. multi_eps_iter_ = multi_eps_labels_.End();
  1238. current_loop_ = false;
  1239. bool ret;
  1240. if (label == 0) {
  1241. ret = matcher_->Find(0);
  1242. } else if (label == kNoLabel) {
  1243. if (flags_ & kMultiEpsList) {
  1244. // Returns all non-consuming arcs (including epsilon).
  1245. multi_eps_iter_ = multi_eps_labels_.Begin();
  1246. while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
  1247. !matcher_->Find(*multi_eps_iter_)) {
  1248. ++multi_eps_iter_;
  1249. }
  1250. if (multi_eps_iter_ != multi_eps_labels_.End()) {
  1251. ret = true;
  1252. } else {
  1253. ret = matcher_->Find(kNoLabel);
  1254. }
  1255. } else {
  1256. // Returns all epsilon arcs.
  1257. ret = matcher_->Find(kNoLabel);
  1258. }
  1259. } else if ((flags_ & kMultiEpsLoop) &&
  1260. multi_eps_labels_.Find(label) != multi_eps_labels_.End()) {
  1261. // Returns implicit loop.
  1262. current_loop_ = true;
  1263. ret = true;
  1264. } else {
  1265. ret = matcher_->Find(label);
  1266. }
  1267. done_ = !ret;
  1268. return ret;
  1269. }
  1270. // This class discards any implicit matches (e.g., the implicit epsilon
  1271. // self-loops in the SortedMatcher). Matchers are most often used in
  1272. // composition/intersection where the implicit matches are needed
  1273. // e.g. for epsilon processing. However, if a matcher is simply being
  1274. // used to look-up explicit label matches, this class saves the user
  1275. // from having to check for and discard the unwanted implicit matches
  1276. // themselves.
  1277. template <class M>
  1278. class ExplicitMatcher : public MatcherBase<typename M::Arc> {
  1279. public:
  1280. using FST = typename M::FST;
  1281. using Arc = typename FST::Arc;
  1282. using Label = typename Arc::Label;
  1283. using StateId = typename Arc::StateId;
  1284. using Weight = typename Arc::Weight;
  1285. // This makes a copy of the FST.
  1286. ExplicitMatcher(const FST &fst, MatchType match_type, M *matcher = nullptr)
  1287. : matcher_(matcher ? matcher : new M(fst, match_type)),
  1288. match_type_(match_type),
  1289. error_(false) {}
  1290. // This doesn't copy the FST.
  1291. ExplicitMatcher(const FST *fst, MatchType match_type, M *matcher = nullptr)
  1292. : matcher_(matcher ? matcher : new M(fst, match_type)),
  1293. match_type_(match_type),
  1294. error_(false) {}
  1295. // This makes a copy of the FST.
  1296. ExplicitMatcher(const ExplicitMatcher &matcher, bool safe = false)
  1297. : matcher_(new M(*matcher.matcher_, safe)),
  1298. match_type_(matcher.match_type_),
  1299. error_(matcher.error_) {}
  1300. ExplicitMatcher *Copy(bool safe = false) const override {
  1301. return new ExplicitMatcher(*this, safe);
  1302. }
  1303. MatchType Type(bool test) const override { return matcher_->Type(test); }
  1304. void SetState(StateId s) final { matcher_->SetState(s); }
  1305. bool Find(Label label) final {
  1306. matcher_->Find(label);
  1307. CheckArc();
  1308. return !Done();
  1309. }
  1310. bool Done() const final { return matcher_->Done(); }
  1311. const Arc &Value() const final { return matcher_->Value(); }
  1312. void Next() final {
  1313. matcher_->Next();
  1314. CheckArc();
  1315. }
  1316. Weight Final(StateId s) const final { return matcher_->Final(s); }
  1317. ssize_t Priority(StateId s) final { return matcher_->Priority(s); }
  1318. const FST &GetFst() const final { return matcher_->GetFst(); }
  1319. uint64_t Properties(uint64_t inprops) const override {
  1320. return matcher_->Properties(inprops);
  1321. }
  1322. const M *GetMatcher() const { return matcher_.get(); }
  1323. uint32_t Flags() const override { return matcher_->Flags(); }
  1324. private:
  1325. // Checks current arc if available and explicit. If not available, stops. If
  1326. // not explicit, checks next ones.
  1327. void CheckArc() {
  1328. for (; !matcher_->Done(); matcher_->Next()) {
  1329. const auto label = match_type_ == MATCH_INPUT ? matcher_->Value().ilabel
  1330. : matcher_->Value().olabel;
  1331. if (label != kNoLabel) return;
  1332. }
  1333. }
  1334. std::unique_ptr<M> matcher_;
  1335. MatchType match_type_; // Type of match requested.
  1336. bool error_; // Error encountered?
  1337. };
  1338. // Generic matcher, templated on the FST definition.
  1339. //
  1340. // Here is a typical use:
  1341. //
  1342. // Matcher<StdFst> matcher(fst, MATCH_INPUT);
  1343. // matcher.SetState(state);
  1344. // if (matcher.Find(label))
  1345. // for (; !matcher.Done(); matcher.Next()) {
  1346. // auto &arc = matcher.Value();
  1347. // ...
  1348. // }
  1349. template <class F>
  1350. class Matcher {
  1351. public:
  1352. using FST = F;
  1353. using Arc = typename F::Arc;
  1354. using Label = typename Arc::Label;
  1355. using StateId = typename Arc::StateId;
  1356. using Weight = typename Arc::Weight;
  1357. // This makes a copy of the FST.
  1358. Matcher(const FST &fst, MatchType match_type)
  1359. : owned_fst_(fst.Copy()), base_(owned_fst_->InitMatcher(match_type)) {
  1360. if (!base_)
  1361. base_ =
  1362. std::make_unique<SortedMatcher<FST>>(owned_fst_.get(), match_type);
  1363. }
  1364. // This doesn't copy the FST.
  1365. Matcher(const FST *fst, MatchType match_type)
  1366. : base_(fst->InitMatcher(match_type)) {
  1367. if (!base_) base_ = std::make_unique<SortedMatcher<FST>>(fst, match_type);
  1368. }
  1369. // This makes a copy of the FST.
  1370. Matcher(const Matcher &matcher, bool safe = false)
  1371. : base_(matcher.base_->Copy(safe)) {}
  1372. // Takes ownership of the provided matcher.
  1373. explicit Matcher(MatcherBase<Arc> *base_matcher) : base_(base_matcher) {}
  1374. Matcher *Copy(bool safe = false) const { return new Matcher(*this, safe); }
  1375. MatchType Type(bool test) const { return base_->Type(test); }
  1376. void SetState(StateId s) { base_->SetState(s); }
  1377. bool Find(Label label) { return base_->Find(label); }
  1378. bool Done() const { return base_->Done(); }
  1379. const Arc &Value() const { return base_->Value(); }
  1380. void Next() { base_->Next(); }
  1381. const FST &GetFst() const { return down_cast<const FST &>(base_->GetFst()); }
  1382. uint64_t Properties(uint64_t props) const { return base_->Properties(props); }
  1383. Weight Final(StateId s) const { return base_->Final(s); }
  1384. uint32_t Flags() const { return base_->Flags() & kMatcherFlags; }
  1385. ssize_t Priority(StateId s) { return base_->Priority(s); }
  1386. private:
  1387. std::unique_ptr<const FST> owned_fst_;
  1388. std::unique_ptr<MatcherBase<Arc>> base_;
  1389. };
  1390. } // namespace fst
  1391. #endif // FST_MATCHER_H_