You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

876 lines
30 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Classes to add lookahead to FST matchers, useful for improving composition
  19. // efficiency with certain inputs.
  20. #ifndef FST_LOOKAHEAD_MATCHER_H_
  21. #define FST_LOOKAHEAD_MATCHER_H_
  22. #include <sys/types.h>
  23. #include <cstdint>
  24. #include <memory>
  25. #include <string>
  26. #include <utility>
  27. #include <vector>
  28. #include <fst/flags.h>
  29. #include <fst/log.h>
  30. #include <fst/accumulator.h>
  31. #include <fst/add-on.h>
  32. #include <fst/const-fst.h>
  33. #include <fst/fst.h>
  34. #include <fst/label-reachable.h>
  35. #include <fst/matcher.h>
  36. #include <fst/mutable-fst.h>
  37. #include <fst/properties.h>
  38. #include <fst/util.h>
  39. #include <fst/vector-fst.h>
  40. #include <string_view>
  41. DECLARE_string(save_relabel_ipairs);
  42. DECLARE_string(save_relabel_opairs);
  43. namespace fst {
  44. // Lookahead matchers extend the matcher interface with the following
  45. // additional methods:
  46. //
  47. // template <class FST>
  48. // class LookAheadMatcher {
  49. // public:
  50. // using Arc = typename FST::Arc;
  51. // using Label = typename Arc::Label;
  52. // using StateId = typename Arc::StateId;
  53. // using Weight = typename Arc::Weight;
  54. //
  55. // // Required constructors.
  56. // // This makes a copy of the FST.
  57. // LookAheadMatcher(const FST &fst, MatchType match_type);
  58. // // This doesn't copy the FST.
  59. // LookAheadMatcher(const FST *fst, MatchType match_type);
  60. // // This makes a copy of the FST.
  61. // // See Copy() below.
  62. // LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false);
  63. //
  64. // // If safe = true, the copy is thread-safe (except the lookahead FST is
  65. // // preserved). See Fst<>::Copy() for further doc.
  66. // LookAheadMatcher *Copy(bool safe = false) const override;
  67. // // Below are methods for looking ahead for a match to a label and more
  68. // // generally, to a rational set. Each returns false if there is definitely
  69. // // not a match and returns true if there possibly is a match.
  70. //
  71. // // Optionally pre-specifies the lookahead FST that will be passed to
  72. // // LookAheadFst() for possible precomputation. If copy is true, then the FST
  73. // // argument is a copy of the FST used in the previous call to this method
  74. // // (to avoid unnecessary updates).
  75. // void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override;
  76. //
  77. // // Are there paths from a state in the lookahead FST that can be read from
  78. // // the current matcher state?
  79. // bool LookAheadFst(const Fst<Arc> &fst, StateId s) override;
  80. //
  81. // // Can the label be read from the current matcher state after possibly
  82. // // following epsilon transitions?
  83. // bool LookAheadLabel(Label label) const override;
  84. //
  85. // // The following methods allow looking ahead for an arbitrary rational set
  86. // // of strings, specified by an FST and a state from which to begin the
  87. // // matching. If the lookahead FST is a transducer, this looks on the side
  88. // // different from the matcher's match_type (cf. composition).
  89. // // Is there a single non-epsilon arc found in the lookahead FST that
  90. // // begins the path (after possibly following any epsilons) in the last call
  91. // // to LookAheadFst? If so, return true and copy it to the arc argument;
  92. // // otherwise, return false. Non-trivial implementations are useful for
  93. // // label-pushing in composition.
  94. // bool LookAheadPrefix(Arc *arc) override;
  95. //
  96. // // Gives an estimate of the combined weight of the paths in the lookahead
  97. // // and matcher FSTs for the last call to LookAheadFst. Non-trivial
  98. // // implementations are useful for weight-pushing in composition.
  99. // Weight LookAheadWeight() const override;
  100. // };
  // Look-ahead flags. Each flag below occupies a single bit; kLookAheadFlags is
  // the mask covering the bit range used for them.

  // The matcher is a lookahead matcher when match_type is MATCH_INPUT.
  inline constexpr uint32_t kInputLookAheadMatcher = uint32_t{1} << 4;

  // The matcher is a lookahead matcher when match_type is MATCH_OUTPUT.
  inline constexpr uint32_t kOutputLookAheadMatcher = uint32_t{1} << 5;

  // Is a non-trivial implementation of the LookAheadWeight() method defined,
  // and if so, should it be used?
  inline constexpr uint32_t kLookAheadWeight = uint32_t{1} << 6;

  // Is a non-trivial implementation of the LookAheadPrefix() method defined,
  // and if so, should it be used?
  inline constexpr uint32_t kLookAheadPrefix = uint32_t{1} << 7;

  // Look ahead over the matcher FST's non-epsilon arcs?
  inline constexpr uint32_t kLookAheadNonEpsilons = uint32_t{1} << 8;

  // Look ahead over the matcher FST's epsilon arcs?
  inline constexpr uint32_t kLookAheadEpsilons = uint32_t{1} << 9;

  // Ignore epsilon paths for the lookahead prefix? This gives correct results
  // in composition only with an appropriate composition filter, since it
  // depends on the filter blocking the ignored paths.
  inline constexpr uint32_t kLookAheadNonEpsilonPrefix = uint32_t{1} << 10;

  // For LabelLookAheadMatcher: save relabeling data to file?
  inline constexpr uint32_t kLookAheadKeepRelabelData = uint32_t{1} << 11;

  // Mask of the flag bits used for lookahead matchers (bits 4 through 11).
  inline constexpr uint32_t kLookAheadFlags = 0xff0u;
  124. // LookAhead Matcher interface, templated on the Arc definition; used
  125. // for lookahead matcher specializations that are returned by the
  126. // InitMatcher() Fst method.
  127. template <class Arc>
  128. class LookAheadMatcherBase : public MatcherBase<Arc> {
  129. public:
  130. using Label = typename Arc::Label;
  131. using StateId = typename Arc::StateId;
  132. using Weight = typename Arc::Weight;
  133. virtual void InitLookAheadFst(const Fst<Arc> &, bool copy = false) = 0;
  134. virtual bool LookAheadFst(const Fst<Arc> &, StateId) = 0;
  135. virtual bool LookAheadLabel(Label) const = 0;
  136. // Suggested concrete implementation of lookahead methods.
  137. bool LookAheadPrefix(Arc *arc) const {
  138. if (prefix_arc_.nextstate != kNoStateId) {
  139. *arc = prefix_arc_;
  140. return true;
  141. } else {
  142. return false;
  143. }
  144. }
  145. Weight LookAheadWeight() const { return weight_; }
  146. protected:
  147. // Concrete implementations for lookahead helper methods.
  148. void ClearLookAheadWeight() { weight_ = Weight::One(); }
  149. void SetLookAheadWeight(Weight weight) { weight_ = std::move(weight); }
  150. void ClearLookAheadPrefix() { prefix_arc_.nextstate = kNoStateId; }
  151. void SetLookAheadPrefix(Arc arc) { prefix_arc_ = std::move(arc); }
  152. private:
  153. Arc prefix_arc_;
  154. Weight weight_;
  155. };
  156. // Doesn't actually lookahead, just declares that the future looks good.
  157. template <class M>
  158. class TrivialLookAheadMatcher
  159. : public LookAheadMatcherBase<typename M::FST::Arc> {
  160. public:
  161. using FST = typename M::FST;
  162. using Arc = typename FST::Arc;
  163. using Label = typename Arc::Label;
  164. using StateId = typename Arc::StateId;
  165. using Weight = typename Arc::Weight;
  166. // This makes a copy of the FST.
  167. TrivialLookAheadMatcher(const FST &fst, MatchType match_type)
  168. : matcher_(fst, match_type) {}
  169. // This doesn't copy the FST.
  170. TrivialLookAheadMatcher(const FST *fst, MatchType match_type)
  171. : matcher_(fst, match_type) {}
  172. // This makes a copy of the FST.
  173. TrivialLookAheadMatcher(const TrivialLookAheadMatcher &lmatcher,
  174. bool safe = false)
  175. : matcher_(lmatcher.matcher_, safe) {}
  176. TrivialLookAheadMatcher *Copy(bool safe = false) const override {
  177. return new TrivialLookAheadMatcher(*this, safe);
  178. }
  179. MatchType Type(bool test) const override { return matcher_.Type(test); }
  180. void SetState(StateId s) final { return matcher_.SetState(s); }
  181. bool Find(Label label) final { return matcher_.Find(label); }
  182. bool Done() const final { return matcher_.Done(); }
  183. const Arc &Value() const final { return matcher_.Value(); }
  184. void Next() final { matcher_.Next(); }
  185. Weight Final(StateId s) const final { return matcher_.Final(s); }
  186. ssize_t Priority(StateId s) final { return matcher_.Priority(s); }
  187. const FST &GetFst() const override { return matcher_.GetFst(); }
  188. uint64_t Properties(uint64_t props) const override {
  189. return matcher_.Properties(props);
  190. }
  191. uint32_t Flags() const override {
  192. return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher;
  193. }
  194. // Lookahead methods (all trivial).
  195. void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override {}
  196. bool LookAheadFst(const Fst<Arc> &, StateId) final { return true; }
  197. bool LookAheadLabel(Label) const final { return true; }
  198. bool LookAheadPrefix(Arc *) const { return false; }
  199. Weight LookAheadWeight() const { return Weight::One(); }
  200. private:
  201. M matcher_;
  202. };
  203. // Look-ahead of one transition. Template argument flags accepts flags to
  204. // control behavior.
  205. template <class M, uint32_t flags = kLookAheadNonEpsilons | kLookAheadEpsilons |
  206. kLookAheadWeight | kLookAheadPrefix>
  207. class ArcLookAheadMatcher : public LookAheadMatcherBase<typename M::FST::Arc> {
  208. public:
  209. using FST = typename M::FST;
  210. using Arc = typename FST::Arc;
  211. using Label = typename Arc::Label;
  212. using StateId = typename Arc::StateId;
  213. using Weight = typename Arc::Weight;
  214. using MatcherData = NullAddOn;
  215. using LookAheadMatcherBase<Arc>::ClearLookAheadWeight;
  216. using LookAheadMatcherBase<Arc>::LookAheadWeight;
  217. using LookAheadMatcherBase<Arc>::SetLookAheadWeight;
  218. using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix;
  219. using LookAheadMatcherBase<Arc>::LookAheadPrefix;
  220. using LookAheadMatcherBase<Arc>::SetLookAheadPrefix;
  221. static constexpr uint32_t kFlags = flags;
  222. // This makes a copy of the FST.
  223. ArcLookAheadMatcher(const FST &fst, MatchType match_type,
  224. std::shared_ptr<MatcherData> data = nullptr)
  225. : matcher_(fst, match_type),
  226. fst_(matcher_.GetFst()),
  227. lfst_(nullptr),
  228. state_(kNoStateId) {}
  229. // This doesn't copy the FST.
  230. ArcLookAheadMatcher(const FST *fst, MatchType match_type,
  231. std::shared_ptr<MatcherData> data = nullptr)
  232. : matcher_(fst, match_type),
  233. fst_(matcher_.GetFst()),
  234. lfst_(nullptr),
  235. state_(kNoStateId) {}
  236. // This makes a copy of the FST.
  237. ArcLookAheadMatcher(const ArcLookAheadMatcher &lmatcher, bool safe = false)
  238. : matcher_(lmatcher.matcher_, safe),
  239. fst_(matcher_.GetFst()),
  240. lfst_(lmatcher.lfst_),
  241. state_(kNoStateId) {}
  242. // General matcher methods.
  243. ArcLookAheadMatcher *Copy(bool safe = false) const override {
  244. return new ArcLookAheadMatcher(*this, safe);
  245. }
  246. MatchType Type(bool test) const override { return matcher_.Type(test); }
  247. void SetState(StateId s) final {
  248. state_ = s;
  249. matcher_.SetState(s);
  250. }
  251. bool Find(Label label) final { return matcher_.Find(label); }
  252. bool Done() const final { return matcher_.Done(); }
  253. const Arc &Value() const final { return matcher_.Value(); }
  254. void Next() final { matcher_.Next(); }
  255. Weight Final(StateId s) const final { return matcher_.Final(s); }
  256. ssize_t Priority(StateId s) final { return matcher_.Priority(s); }
  257. const FST &GetFst() const override { return fst_; }
  258. uint64_t Properties(uint64_t props) const override {
  259. return matcher_.Properties(props);
  260. }
  261. uint32_t Flags() const override {
  262. return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher |
  263. kFlags;
  264. }
  265. const MatcherData *GetData() const { return nullptr; }
  266. std::shared_ptr<MatcherData> GetSharedData() const { return nullptr; }
  267. // Look-ahead methods.
  268. void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override {
  269. lfst_ = &fst;
  270. }
  271. // Checks if there is a matching (possibly super-final) transition
  272. // at (state_, s).
  273. bool LookAheadFst(const Fst<Arc> &, StateId) final;
  274. bool LookAheadLabel(Label label) const final { return matcher_.Find(label); }
  275. private:
  276. mutable M matcher_;
  277. const FST &fst_; // Matcher FST.
  278. const Fst<Arc> *lfst_; // Look-ahead FST.
  279. StateId state_; // Matcher state.
  280. };
  281. template <class M, uint32_t flags>
  282. bool ArcLookAheadMatcher<M, flags>::LookAheadFst(const Fst<Arc> &fst,
  283. StateId s) {
  284. if (&fst != lfst_) InitLookAheadFst(fst);
  285. bool result = false;
  286. ssize_t nprefix = 0;
  287. if (kFlags & kLookAheadWeight) ClearLookAheadWeight();
  288. if (kFlags & kLookAheadPrefix) ClearLookAheadPrefix();
  289. if (fst_.Final(state_) != Weight::Zero() &&
  290. lfst_->Final(s) != Weight::Zero()) {
  291. if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
  292. ++nprefix;
  293. if (kFlags & kLookAheadWeight) {
  294. SetLookAheadWeight(
  295. Plus(LookAheadWeight(), Times(fst_.Final(state_), lfst_->Final(s))));
  296. }
  297. result = true;
  298. }
  299. if (matcher_.Find(kNoLabel)) {
  300. if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
  301. ++nprefix;
  302. if (kFlags & kLookAheadWeight) {
  303. for (; !matcher_.Done(); matcher_.Next()) {
  304. SetLookAheadWeight(Plus(LookAheadWeight(), matcher_.Value().weight));
  305. }
  306. }
  307. result = true;
  308. }
  309. for (ArcIterator<Fst<Arc>> aiter(*lfst_, s); !aiter.Done(); aiter.Next()) {
  310. const auto &arc = aiter.Value();
  311. Label label = kNoLabel;
  312. switch (matcher_.Type(false)) {
  313. case MATCH_INPUT:
  314. label = arc.olabel;
  315. break;
  316. case MATCH_OUTPUT:
  317. label = arc.ilabel;
  318. break;
  319. default:
  320. FSTERROR() << "ArcLookAheadMatcher::LookAheadFst: Bad match type";
  321. return true;
  322. }
  323. if (label == 0) {
  324. if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
  325. if (!(kFlags & kLookAheadNonEpsilonPrefix)) ++nprefix;
  326. if (kFlags & kLookAheadWeight) {
  327. SetLookAheadWeight(Plus(LookAheadWeight(), arc.weight));
  328. }
  329. result = true;
  330. } else if (matcher_.Find(label)) {
  331. if (!(kFlags & (kLookAheadWeight | kLookAheadPrefix))) return true;
  332. for (; !matcher_.Done(); matcher_.Next()) {
  333. ++nprefix;
  334. if (kFlags & kLookAheadWeight) {
  335. SetLookAheadWeight(Plus(LookAheadWeight(),
  336. Times(arc.weight, matcher_.Value().weight)));
  337. }
  338. if ((kFlags & kLookAheadPrefix) && nprefix == 1)
  339. SetLookAheadPrefix(arc);
  340. }
  341. result = true;
  342. }
  343. }
  344. if (kFlags & kLookAheadPrefix) {
  345. if (nprefix == 1) {
  346. ClearLookAheadWeight(); // Avoids double counting.
  347. } else {
  348. ClearLookAheadPrefix();
  349. }
  350. }
  351. return result;
  352. }
  353. // Template argument flags accepts flags to control behavior. It must include
  354. // precisely one of kInputLookAheadMatcher or kOutputLookAheadMatcher.
  355. template <class M,
  356. uint32_t flags = kLookAheadEpsilons | kLookAheadWeight |
  357. kLookAheadPrefix | kLookAheadNonEpsilonPrefix |
  358. kLookAheadKeepRelabelData,
  359. class Accum = DefaultAccumulator<typename M::Arc>,
  360. class R = LabelReachable<typename M::Arc, Accum>>
  361. class LabelLookAheadMatcher
  362. : public LookAheadMatcherBase<typename M::FST::Arc> {
  363. public:
  364. using Matcher = M;
  365. using Accumulator = Accum;
  366. using Reachable = R;
  367. using FST = typename M::FST;
  368. using Arc = typename FST::Arc;
  369. using Label = typename Arc::Label;
  370. using StateId = typename Arc::StateId;
  371. using Weight = typename Arc::Weight;
  372. using MatcherData = typename Reachable::Data;
  373. using LookAheadMatcherBase<Arc>::ClearLookAheadWeight;
  374. using LookAheadMatcherBase<Arc>::LookAheadWeight;
  375. using LookAheadMatcherBase<Arc>::SetLookAheadWeight;
  376. using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix;
  377. using LookAheadMatcherBase<Arc>::LookAheadPrefix;
  378. using LookAheadMatcherBase<Arc>::SetLookAheadPrefix;
  379. static_assert(!(flags & kInputLookAheadMatcher) !=
  380. !(flags & kOutputLookAheadMatcher),
  381. "Must include precisely one of kInputLookAheadMatcher and "
  382. "kOutputLookAheadMatcher");
  383. static constexpr uint32_t kFlags = flags;
  384. // This makes a copy of the FST.
  385. LabelLookAheadMatcher(const FST &fst, MatchType match_type,
  386. std::shared_ptr<MatcherData> data = nullptr,
  387. std::unique_ptr<Accumulator> accumulator = nullptr)
  388. : matcher_(fst, match_type),
  389. lfst_(nullptr),
  390. state_(kNoStateId),
  391. error_(false) {
  392. Init(fst, match_type, data, std::move(accumulator));
  393. }
  394. // This doesn't copy the FST.
  395. LabelLookAheadMatcher(const FST *fst, MatchType match_type,
  396. std::shared_ptr<MatcherData> data = nullptr,
  397. std::unique_ptr<Accumulator> accumulator = nullptr)
  398. : matcher_(fst, match_type),
  399. lfst_(nullptr),
  400. state_(kNoStateId),
  401. error_(false) {
  402. Init(*fst, match_type, data, std::move(accumulator));
  403. }
  404. // This makes a copy of the FST.
  405. LabelLookAheadMatcher(const LabelLookAheadMatcher &lmatcher,
  406. bool safe = false)
  407. : matcher_(lmatcher.matcher_, safe),
  408. lfst_(lmatcher.lfst_),
  409. label_reachable_(lmatcher.label_reachable_
  410. ? new Reachable(*lmatcher.label_reachable_, safe)
  411. : nullptr),
  412. state_(kNoStateId),
  413. error_(lmatcher.error_) {}
  414. LabelLookAheadMatcher *Copy(bool safe = false) const override {
  415. return new LabelLookAheadMatcher(*this, safe);
  416. }
  417. MatchType Type(bool test) const override { return matcher_.Type(test); }
  418. void SetState(StateId s) final {
  419. if (state_ == s) return;
  420. state_ = s;
  421. match_set_state_ = false;
  422. reach_set_state_ = false;
  423. }
  424. bool Find(Label label) final {
  425. if (!match_set_state_) {
  426. matcher_.SetState(state_);
  427. match_set_state_ = true;
  428. }
  429. return matcher_.Find(label);
  430. }
  431. bool Done() const final { return matcher_.Done(); }
  432. const Arc &Value() const final { return matcher_.Value(); }
  433. void Next() final { matcher_.Next(); }
  434. Weight Final(StateId s) const final { return matcher_.Final(s); }
  435. ssize_t Priority(StateId s) final { return matcher_.Priority(s); }
  436. const FST &GetFst() const override { return matcher_.GetFst(); }
  437. uint64_t Properties(uint64_t inprops) const override {
  438. auto outprops = matcher_.Properties(inprops);
  439. if (error_ || (label_reachable_ && label_reachable_->Error())) {
  440. outprops |= kError;
  441. }
  442. return outprops;
  443. }
  444. uint32_t Flags() const override {
  445. if (label_reachable_ && label_reachable_->GetData()->ReachInput()) {
  446. return matcher_.Flags() | kFlags | kInputLookAheadMatcher;
  447. } else if (label_reachable_ && !label_reachable_->GetData()->ReachInput()) {
  448. return matcher_.Flags() | kFlags | kOutputLookAheadMatcher;
  449. } else {
  450. return matcher_.Flags();
  451. }
  452. }
  453. const MatcherData *GetData() const {
  454. return label_reachable_ ? label_reachable_->GetData() : nullptr;
  455. }
  456. std::shared_ptr<MatcherData> GetSharedData() const {
  457. return label_reachable_ ? label_reachable_->GetSharedData() : nullptr;
  458. }
  459. // Checks if there is a matching (possibly super-final) transition at
  460. // (state_, s).
  461. template <class LFST>
  462. bool LookAheadFst(const LFST &fst, StateId s);
  463. // Required to make class concrete.
  464. bool LookAheadFst(const Fst<Arc> &fst, StateId s) final {
  465. return LookAheadFst<Fst<Arc>>(fst, s);
  466. }
  467. void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) override {
  468. lfst_ = &fst;
  469. if (label_reachable_) {
  470. const bool reach_input = Type(false) == MATCH_OUTPUT;
  471. label_reachable_->ReachInit(fst, reach_input, copy);
  472. }
  473. }
  474. template <class LFST>
  475. void InitLookAheadFst(const LFST &fst, bool copy = false) {
  476. lfst_ = &fst;
  477. if (label_reachable_) {
  478. const bool reach_input = Type(false) == MATCH_OUTPUT;
  479. label_reachable_->ReachInit(fst, reach_input, copy);
  480. }
  481. }
  482. bool LookAheadLabel(Label label) const final {
  483. if (label == 0) return true;
  484. if (label_reachable_) {
  485. if (!reach_set_state_) {
  486. label_reachable_->SetState(state_);
  487. reach_set_state_ = true;
  488. }
  489. return label_reachable_->Reach(label);
  490. } else {
  491. return true;
  492. }
  493. }
  494. private:
  495. void Init(const FST &fst, MatchType match_type,
  496. std::shared_ptr<MatcherData> data,
  497. std::unique_ptr<Accumulator> accumulator) {
  498. const bool reach_input = match_type == MATCH_INPUT;
  499. if (data) {
  500. if (reach_input == data->ReachInput()) {
  501. label_reachable_ =
  502. std::make_unique<Reachable>(data, std::move(accumulator));
  503. }
  504. } else if ((reach_input && (kFlags & kInputLookAheadMatcher)) ||
  505. (!reach_input && (kFlags & kOutputLookAheadMatcher))) {
  506. label_reachable_ =
  507. std::make_unique<Reachable>(fst, reach_input, std::move(accumulator),
  508. kFlags & kLookAheadKeepRelabelData);
  509. }
  510. }
  511. mutable M matcher_;
  512. const Fst<Arc> *lfst_; // Look-ahead FST.
  513. std::unique_ptr<Reachable> label_reachable_; // Label reachability info.
  514. StateId state_; // Matcher state.
  515. bool match_set_state_; // matcher_.SetState called?
  516. mutable bool reach_set_state_; // reachable_.SetState called?
  517. bool error_; // Error encountered?
  518. };
  519. template <class M, uint32_t flags, class Accumulator, class Reachable>
  520. template <class LFST>
  521. inline bool LabelLookAheadMatcher<M, flags, Accumulator,
  522. Reachable>::LookAheadFst(const LFST &fst,
  523. StateId s) {
  524. if (&fst != lfst_) InitLookAheadFst(fst);
  525. ClearLookAheadWeight();
  526. ClearLookAheadPrefix();
  527. if (!label_reachable_) return true;
  528. label_reachable_->SetState(state_, s);
  529. reach_set_state_ = true;
  530. bool compute_weight = kFlags & kLookAheadWeight;
  531. constexpr bool kComputePrefix = kFlags & kLookAheadPrefix;
  532. ArcIterator<LFST> aiter(fst, s);
  533. aiter.SetFlags(kArcNoCache, kArcNoCache); // Makes caching optional.
  534. const bool reach_arc = label_reachable_->Reach(
  535. &aiter, 0, internal::NumArcs(*lfst_, s), compute_weight);
  536. const auto lfinal = internal::Final(*lfst_, s);
  537. const bool reach_final =
  538. lfinal != Weight::Zero() && label_reachable_->ReachFinal();
  539. if (reach_arc) {
  540. const auto begin = label_reachable_->ReachBegin();
  541. const auto end = label_reachable_->ReachEnd();
  542. if (kComputePrefix && end - begin == 1 && !reach_final) {
  543. aiter.Seek(begin);
  544. SetLookAheadPrefix(aiter.Value());
  545. compute_weight = false;
  546. } else if (compute_weight) {
  547. SetLookAheadWeight(label_reachable_->ReachWeight());
  548. }
  549. }
  550. if (reach_final && compute_weight) {
  551. SetLookAheadWeight(reach_arc ? Plus(LookAheadWeight(), lfinal) : lfinal);
  552. }
  553. return reach_arc || reach_final;
  554. }
  // Relabels the FST using a Reachable built from the given data. If
  // data.First() is non-null, a Reachable built from data.SharedFirst() relabels
  // the input side; otherwise one built from data.SharedSecond() relabels the
  // output side. The relabeling pairs (with collision avoidance) are written to
  // save_relabel_ipairs or save_relabel_opairs respectively, when that name is
  // non-empty.
  template <class Reachable, class FST, class Data>
  void RelabelForReachable(FST *fst, const Data &data,
                           std::string_view save_relabel_ipairs,
                           std::string_view save_relabel_opairs) {
    using Label = typename FST::Arc::Label;
    using LabelPairs = std::vector<std::pair<Label, Label>>;

    // Shared recipe for both sides: relabel, then optionally dump the pairs.
    const auto relabel_side = [fst](auto reach_data, bool relabel_input,
                                    std::string_view save_path) {
      Reachable reachable(std::move(reach_data));
      reachable.Relabel(fst, relabel_input);
      if (save_path.empty()) return;
      LabelPairs pairs;
      reachable.RelabelPairs(&pairs, /*avoid_collisions=*/true);
      WriteLabelPairs(save_path, pairs);
    };

    if (data.First() != nullptr) {  // reach_input.
      relabel_side(data.SharedFirst(), /*relabel_input=*/true,
                   save_relabel_ipairs);
    } else {
      relabel_side(data.SharedSecond(), /*relabel_input=*/false,
                   save_relabel_opairs);
    }
  }
  582. // Label-lookahead relabeling class.
  583. template <class Arc, class Data = LabelReachableData<typename Arc::Label>>
  584. class LabelLookAheadRelabeler {
  585. public:
  586. using Label = typename Arc::Label;
  587. using Reachable = LabelReachable<Arc, DefaultAccumulator<Arc>, Data>;
  588. // Relabels matcher FST (initialization function object).
  589. template <typename Impl>
  590. explicit LabelLookAheadRelabeler(std::shared_ptr<Impl> *impl);
  591. // Relabels arbitrary FST. Class LFST should be a label-lookahead FST.
  592. template <class LFST>
  593. static void Relabel(MutableFst<Arc> *fst, const LFST &mfst,
  594. bool relabel_input) {
  595. const auto *data = mfst.GetAddOn();
  596. Reachable reachable(data->First() ? data->SharedFirst()
  597. : data->SharedSecond());
  598. reachable.Relabel(fst, relabel_input);
  599. }
  600. // Returns relabeling pairs (cf. relabel.h::Relabel()). Class LFST should be a
  601. // label-lookahead FST. If avoid_collisions is true, extra pairs are added to
  602. // ensure no collisions when relabeling automata that have labels unseen here.
  603. template <class LFST>
  604. static void RelabelPairs(const LFST &mfst,
  605. std::vector<std::pair<Label, Label>> *pairs,
  606. bool avoid_collisions = false) {
  607. const auto *data = mfst.GetAddOn();
  608. Reachable reachable(data->First() ? data->SharedFirst()
  609. : data->SharedSecond());
  610. reachable.RelabelPairs(pairs, avoid_collisions);
  611. }
  612. };
  613. template <class Arc, class Data>
  614. template <typename Impl>
  615. inline LabelLookAheadRelabeler<Arc, Data>::LabelLookAheadRelabeler(
  616. std::shared_ptr<Impl> *impl) {
  617. Fst<Arc> &fst = (*impl)->GetFst();
  618. auto data = (*impl)->GetSharedAddOn();
  619. const auto name = (*impl)->Type();
  620. const bool is_mutable = fst.Properties(kMutable, false);
  621. std::unique_ptr<MutableFst<Arc>> mfst;
  622. if (is_mutable) {
  623. // Borrow pointer from fst without increasing ref count; it will
  624. // be released below. We do not want to call Copy() since that would
  625. // do a deep copy when the Fst is modified.
  626. mfst.reset(down_cast<MutableFst<Arc> *>(&fst));
  627. } else {
  628. mfst = std::make_unique<VectorFst<Arc>>(fst);
  629. }
  630. RelabelForReachable<Reachable>(mfst.get(), *data,
  631. FST_FLAGS_save_relabel_ipairs,
  632. FST_FLAGS_save_relabel_opairs);
  633. if (is_mutable) {
  634. // Pointer was just borrowed, don't delete it.
  635. mfst.release();
  636. } else {
  637. *impl = std::make_shared<Impl>(*mfst, name);
  638. (*impl)->SetAddOn(data);
  639. }
  640. }
  641. // Generic lookahead matcher, templated on the FST definition (a wrapper around
  642. // a pointer to specific one).
  643. template <class F>
  644. class LookAheadMatcher {
  645. public:
  646. using FST = F;
  647. using Arc = typename FST::Arc;
  648. using Label = typename Arc::Label;
  649. using StateId = typename Arc::StateId;
  650. using Weight = typename Arc::Weight;
  651. using LBase = LookAheadMatcherBase<Arc>;
  652. // This makes a copy of the FST.
  653. LookAheadMatcher(const FST &fst, MatchType match_type)
  654. : owned_fst_(fst.Copy()),
  655. base_(owned_fst_->InitMatcher(match_type)),
  656. lookahead_(false) {
  657. if (!base_)
  658. base_ =
  659. std::make_unique<SortedMatcher<FST>>(owned_fst_.get(), match_type);
  660. }
  661. // This doesn't copy the FST.
  662. LookAheadMatcher(const FST *fst, MatchType match_type)
  663. : base_(fst->InitMatcher(match_type)), lookahead_(false) {
  664. if (!base_) base_ = std::make_unique<SortedMatcher<FST>>(fst, match_type);
  665. }
  666. // This makes a copy of the FST.
  667. LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false)
  668. : base_(matcher.base_->Copy(safe)), lookahead_(matcher.lookahead_) {}
  669. // Takes ownership of base.
  670. explicit LookAheadMatcher(MatcherBase<Arc> *base)
  671. : base_(base), lookahead_(false) {}
  672. LookAheadMatcher *Copy(bool safe = false) const {
  673. return new LookAheadMatcher(*this, safe);
  674. }
  675. MatchType Type(bool test) const { return base_->Type(test); }
  676. void SetState(StateId s) { base_->SetState(s); }
  677. bool Find(Label label) { return base_->Find(label); }
  678. bool Done() const { return base_->Done(); }
  679. const Arc &Value() const { return base_->Value(); }
  680. void Next() { base_->Next(); }
  681. Weight Final(StateId s) const { return base_->Final(s); }
  682. ssize_t Priority(StateId s) { return base_->Priority(s); }
  683. const FST &GetFst() const { return down_cast<const FST &>(base_->GetFst()); }
  684. uint64_t Properties(uint64_t props) const { return base_->Properties(props); }
  685. uint32_t Flags() const { return base_->Flags(); }
  686. bool LookAheadLabel(Label label) const {
  687. if (LookAheadCheck()) {
  688. return down_cast<LBase *>(base_.get())->LookAheadLabel(label);
  689. } else {
  690. return true;
  691. }
  692. }
  693. bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
  694. if (LookAheadCheck()) {
  695. return down_cast<LBase *>(base_.get())->LookAheadFst(fst, s);
  696. } else {
  697. return true;
  698. }
  699. }
  700. Weight LookAheadWeight() const {
  701. if (LookAheadCheck()) {
  702. return down_cast<LBase *>(base_.get())->LookAheadWeight();
  703. } else {
  704. return Weight::One();
  705. }
  706. }
  707. bool LookAheadPrefix(Arc *arc) const {
  708. if (LookAheadCheck()) {
  709. return down_cast<LBase *>(base_.get())->LookAheadPrefix(arc);
  710. } else {
  711. return false;
  712. }
  713. }
  714. void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) {
  715. if (LookAheadCheck()) {
  716. down_cast<LBase *>(base_.get())->InitLookAheadFst(fst, copy);
  717. }
  718. }
  719. private:
  720. bool LookAheadCheck() const {
  721. if (!lookahead_) {
  722. lookahead_ =
  723. base_->Flags() & (kInputLookAheadMatcher | kOutputLookAheadMatcher);
  724. if (!lookahead_) {
  725. FSTERROR() << "LookAheadMatcher: No look-ahead matcher defined";
  726. }
  727. }
  728. return lookahead_;
  729. }
  730. std::unique_ptr<const FST> owned_fst_;
  731. std::unique_ptr<MatcherBase<Arc>> base_;
  732. mutable bool lookahead_;
  733. LookAheadMatcher &operator=(const LookAheadMatcher &) = delete;
  734. };
  735. } // namespace fst
  736. #endif // FST_LOOKAHEAD_MATCHER_H_