You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1059 lines
39 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Class to compute the composition of two FSTs.
  19. #ifndef FST_COMPOSE_H_
  20. #define FST_COMPOSE_H_
  21. #include <sys/types.h>
  22. #include <algorithm>
  23. #include <cstddef>
  24. #include <cstdint>
  25. #include <memory>
  26. #include <utility>
  27. #include <fst/log.h>
  28. #include <fst/arc.h>
  29. #include <fst/cache.h>
  30. #include <fst/compose-filter.h>
  31. #include <fst/connect.h>
  32. #include <fst/float-weight.h>
  33. #include <fst/fst-decl.h> // For optional argument declarations
  34. #include <fst/fst.h>
  35. #include <fst/impl-to-fst.h>
  36. #include <fst/lookahead-filter.h>
  37. #include <fst/matcher.h>
  38. #include <fst/mutable-fst.h>
  39. #include <fst/properties.h>
  40. #include <fst/state-table.h>
  41. #include <fst/symbol-table.h>
  42. #include <fst/util.h>
  43. #include <fst/weight.h>
  44. namespace fst {
  45. // Delayed composition options templated on the arc type, the matcher,
  46. // the composition filter, and the composition state table. By
  47. // default, the matchers, filter, and state table are constructed by
  48. // composition. If set below, the user can instead pass in these
  49. // objects; in that case, ComposeFst takes their ownership. This
  50. // version controls composition implemented between generic Fst<Arc>
  51. // types and a shared matcher type M for Fst<Arc>. This should be
  52. // adequate for most applications, giving a reasonable tradeoff
  53. // between efficiency and code sharing (but see ComposeFstImplOptions).
  54. template <class Arc, class M = Matcher<Fst<Arc>>,
  55. class Filter = SequenceComposeFilter<M>,
  56. class StateTable =
  57. GenericComposeStateTable<Arc, typename Filter::FilterState>>
  58. struct ComposeFstOptions : public CacheOptions {
  59. M *matcher1; // FST1 matcher.
  60. M *matcher2; // FST2 matcher.
  61. Filter *filter; // Composition filter.
  62. StateTable *state_table; // Composition state table.
  63. explicit ComposeFstOptions(const CacheOptions &opts = CacheOptions(),
  64. M *matcher1 = nullptr, M *matcher2 = nullptr,
  65. Filter *filter = nullptr,
  66. StateTable *state_table = nullptr)
  67. : CacheOptions(opts),
  68. matcher1(matcher1),
  69. matcher2(matcher2),
  70. filter(filter),
  71. state_table(state_table) {}
  72. };
  73. // Forward declaration of ComposeFstMatcher.
  74. template <class C, class F, class T>
  75. class ComposeFstMatcher;
  76. // Delayed composition options templated on the two matcher types, the
  77. // composition filter, the composition state table and the cache store. By
  78. // default, the matchers, filter, state table and cache store are constructed
  79. // by composition. If set below, the user can instead pass in these objects; in
  80. // that case, ComposeFst takes their ownership. This version controls
  81. // composition implemented using arbitrary matchers (of the same arc type but
  82. // otherwise arbitrary FST type). The user must ensure the matchers are
  83. // compatible. These options permit the most efficient use, but share the
  84. // least code. This is for advanced use only in the most demanding or
  85. // specialized applications that can benefit from it; otherwise, prefer
  86. // ComposeFstOptions.
  87. template <class M1, class M2, class Filter = SequenceComposeFilter<M1, M2>,
  88. class StateTable = GenericComposeStateTable<
  89. typename M1::Arc, typename Filter::FilterState>,
  90. class CacheStore = DefaultCacheStore<typename M1::Arc>>
  91. struct ComposeFstImplOptions : public CacheImplOptions<CacheStore> {
  92. M1 *matcher1; // FST1 matcher (see matcher.h)....
  93. M2 *matcher2; // FST2 matcher.
  94. Filter *filter; // Composition filter (see compose-filter.h).
  95. StateTable
  96. *state_table; // Composition state table (see compose-state-table.h).
  97. bool own_state_table; // ComposeFstImpl takes ownership of 'state_table'?
  98. bool allow_noncommute; // Allow non-commutative weights
  99. explicit ComposeFstImplOptions(const CacheOptions &opts,
  100. M1 *matcher1 = nullptr, M2 *matcher2 = nullptr,
  101. Filter *filter = nullptr,
  102. StateTable *state_table = nullptr)
  103. : CacheImplOptions<CacheStore>(opts),
  104. matcher1(matcher1),
  105. matcher2(matcher2),
  106. filter(filter),
  107. state_table(state_table),
  108. own_state_table(true),
  109. allow_noncommute(false) {}
  110. explicit ComposeFstImplOptions(const CacheImplOptions<CacheStore> &opts,
  111. M1 *matcher1 = nullptr, M2 *matcher2 = nullptr,
  112. Filter *filter = nullptr,
  113. StateTable *state_table = nullptr)
  114. : CacheImplOptions<CacheStore>(opts),
  115. matcher1(matcher1),
  116. matcher2(matcher2),
  117. filter(filter),
  118. state_table(state_table),
  119. own_state_table(true),
  120. allow_noncommute(false) {}
  121. ComposeFstImplOptions()
  122. : matcher1(nullptr),
  123. matcher2(nullptr),
  124. filter(nullptr),
  125. state_table(nullptr),
  126. own_state_table(true),
  127. allow_noncommute(false) {}
  128. };
  129. namespace internal {
  130. // Implementation of delayed composition. This base class is common to the
  131. // variants with different matchers, composition filters and state tables.
  132. template <class Arc, class CacheStore = DefaultCacheStore<Arc>,
  133. class F = ComposeFst<Arc, CacheStore>>
  134. class ComposeFstImplBase
  135. : public CacheBaseImpl<typename CacheStore::State, CacheStore> {
  136. public:
  137. using FST = F;
  138. using Label = typename Arc::Label;
  139. using StateId = typename Arc::StateId;
  140. using Weight = typename Arc::Weight;
  141. using State = typename CacheStore::State;
  142. using CacheImpl = CacheBaseImpl<State, CacheStore>;
  143. using FstImpl<Arc>::SetType;
  144. using FstImpl<Arc>::SetProperties;
  145. using FstImpl<Arc>::Properties;
  146. using FstImpl<Arc>::SetInputSymbols;
  147. using FstImpl<Arc>::SetOutputSymbols;
  148. using CacheImpl::HasArcs;
  149. using CacheImpl::HasFinal;
  150. using CacheImpl::HasStart;
  151. using CacheImpl::SetFinal;
  152. using CacheImpl::SetStart;
  153. explicit ComposeFstImplBase(const CacheImplOptions<CacheStore> &opts)
  154. : CacheImpl(opts) {}
  155. explicit ComposeFstImplBase(const CacheOptions &opts) : CacheImpl(opts) {}
  156. ComposeFstImplBase(const ComposeFstImplBase &impl) : CacheImpl(impl, true) {
  157. SetType(impl.Type());
  158. SetProperties(impl.Properties(), kCopyProperties);
  159. SetInputSymbols(impl.InputSymbols());
  160. SetOutputSymbols(impl.OutputSymbols());
  161. }
  162. virtual ComposeFstImplBase *Copy() const = 0;
  163. ~ComposeFstImplBase() override = default;
  164. StateId Start() {
  165. if (!HasStart()) {
  166. const auto start = ComputeStart();
  167. if (start != kNoStateId) SetStart(start);
  168. }
  169. return CacheImpl::Start();
  170. }
  171. Weight Final(StateId s) {
  172. if (!HasFinal(s)) SetFinal(s, ComputeFinal(s));
  173. return CacheImpl::Final(s);
  174. }
  175. virtual void Expand(StateId s) = 0;
  176. size_t NumArcs(StateId s) {
  177. if (!HasArcs(s)) Expand(s);
  178. return CacheImpl::NumArcs(s);
  179. }
  180. size_t NumInputEpsilons(StateId s) {
  181. if (!HasArcs(s)) Expand(s);
  182. return CacheImpl::NumInputEpsilons(s);
  183. }
  184. size_t NumOutputEpsilons(StateId s) {
  185. if (!HasArcs(s)) Expand(s);
  186. return CacheImpl::NumOutputEpsilons(s);
  187. }
  188. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
  189. if (!HasArcs(s)) Expand(s);
  190. CacheImpl::InitArcIterator(s, data);
  191. }
  192. virtual MatcherBase<Arc> *InitMatcher(const F &fst,
  193. MatchType match_type) const {
  194. // Use the default matcher if no override is provided.
  195. return nullptr;
  196. }
  197. protected:
  198. virtual StateId ComputeStart() = 0;
  199. virtual Weight ComputeFinal(StateId s) = 0;
  200. };
  201. // Implementation of delayed composition templated on the matchers (see
  202. // matcher.h), composition filter (see compose-filter.h) and the composition
  203. // state table (see compose-state-table.h).
  204. template <class CacheStore, class Filter, class StateTable>
  205. class ComposeFstImpl
  206. : public ComposeFstImplBase<typename CacheStore::Arc, CacheStore> {
  207. public:
  208. using Matcher1 = typename Filter::Matcher1;
  209. using Matcher2 = typename Filter::Matcher2;
  210. using FST1 = typename Matcher1::FST;
  211. using FST2 = typename Matcher2::FST;
  212. using Arc = typename CacheStore::Arc;
  213. using Label = typename Arc::Label;
  214. using StateId = typename Arc::StateId;
  215. using Weight = typename Arc::Weight;
  216. using FilterState = typename Filter::FilterState;
  217. using State = typename CacheStore::State;
  218. using CacheImpl = CacheBaseImpl<State, CacheStore>;
  219. using StateTuple = typename StateTable::StateTuple;
  220. friend class ComposeFstMatcher<CacheStore, Filter, StateTable>;
  221. using FstImpl<Arc>::SetInputSymbols;
  222. using FstImpl<Arc>::SetOutputSymbols;
  223. using FstImpl<Arc>::SetType;
  224. using FstImpl<Arc>::SetProperties;
  225. template <class M1, class M2>
  226. ComposeFstImpl(const FST1 &fst1, const FST2 &fst2,
  227. const ComposeFstImplOptions<M1, M2, Filter, StateTable,
  228. CacheStore> &opts);
  229. ComposeFstImpl(const ComposeFstImpl &impl)
  230. : ComposeFstImplBase<Arc, CacheStore>(impl),
  231. filter_(new Filter(*impl.filter_, true)),
  232. matcher1_(filter_->GetMatcher1()),
  233. matcher2_(filter_->GetMatcher2()),
  234. fst1_(matcher1_->GetFst()),
  235. fst2_(matcher2_->GetFst()),
  236. state_table_(new StateTable(*impl.state_table_)),
  237. own_state_table_(true),
  238. match_type_(impl.match_type_) {}
  239. ~ComposeFstImpl() override {
  240. if (own_state_table_) delete state_table_;
  241. }
  242. ComposeFstImpl *Copy() const override { return new ComposeFstImpl(*this); }
  243. uint64_t Properties() const override { return Properties(kFstProperties); }
  244. // Sets error if found, and returns other FST impl properties.
  245. uint64_t Properties(uint64_t mask) const override {
  246. if ((mask & kError) &&
  247. (fst1_.Properties(kError, false) || fst2_.Properties(kError, false) ||
  248. (matcher1_->Properties(0) & kError) ||
  249. (matcher2_->Properties(0) & kError) |
  250. (filter_->Properties(0) & kError) ||
  251. state_table_->Error())) {
  252. SetProperties(kError, kError);
  253. }
  254. return FstImpl<Arc>::Properties(mask);
  255. }
  256. // Arranges it so that the first arg to OrderedExpand is the Fst
  257. // that will be matched on.
  258. void Expand(StateId s) override {
  259. const auto &tuple = state_table_->Tuple(s);
  260. const auto s1 = tuple.StateId1();
  261. const auto s2 = tuple.StateId2();
  262. filter_->SetState(s1, s2, tuple.GetFilterState());
  263. if (MatchInput(s1, s2)) {
  264. OrderedExpand(s, fst2_, s2, fst1_, s1, matcher2_, true);
  265. } else {
  266. OrderedExpand(s, fst1_, s1, fst2_, s2, matcher1_, false);
  267. }
  268. }
  269. const FST1 &GetFst1() const { return fst1_; }
  270. const FST2 &GetFst2() const { return fst2_; }
  271. const Matcher1 *GetMatcher1() const { return matcher1_; }
  272. Matcher1 *GetMatcher1() { return matcher1_; }
  273. const Matcher2 *GetMatcher2() const { return matcher2_; }
  274. Matcher2 *GetMatcher2() { return matcher2_; }
  275. const Filter *GetFilter() const { return filter_.get(); }
  276. Filter *GetFilter() { return filter_.get(); }
  277. const StateTable *GetStateTable() const { return state_table_; }
  278. StateTable *GetStateTable() { return state_table_; }
  279. MatcherBase<Arc> *InitMatcher(const ComposeFst<Arc, CacheStore> &fst,
  280. MatchType match_type) const override {
  281. const auto test_props = match_type == MATCH_INPUT
  282. ? kFstProperties & ~kILabelInvariantProperties
  283. : kFstProperties & ~kOLabelInvariantProperties;
  284. // If both matchers support 'match_type' and we have a guarantee that a
  285. // call to 'filter_->FilterArc(arc1, arc2)' will not modify the ilabel of
  286. // arc1 when MATCH_INPUT or the olabel or arc2 when MATCH_OUTPUT, then
  287. // ComposeFstMatcher can be used.
  288. if ((matcher1_->Type(false) == match_type) &&
  289. (matcher2_->Type(false) == match_type) &&
  290. (filter_->Properties(test_props) == test_props)) {
  291. return new ComposeFstMatcher<CacheStore, Filter, StateTable>(&fst,
  292. match_type);
  293. }
  294. return nullptr;
  295. }
  296. private:
  297. // This does that actual matching of labels in the composition. The
  298. // arguments are ordered so matching is called on state 'sa' of
  299. // 'fsta' for each arc leaving state 'sb' of 'fstb'. The 'match_input' arg
  300. // determines whether the input or output label of arcs at 'sb' is
  301. // the one to match on.
  302. template <class FST, class Matcher>
  303. void OrderedExpand(StateId s, const Fst<Arc> &, StateId sa, const FST &fstb,
  304. StateId sb, Matcher *matchera, bool match_input) {
  305. matchera->SetState(sa);
  306. // First processes non-consuming symbols (e.g., epsilons) on FSTA.
  307. const Arc loop(match_input ? 0 : kNoLabel, match_input ? kNoLabel : 0,
  308. Weight::One(), sb);
  309. MatchArc(s, matchera, loop, match_input);
  310. // Then processes matches on FSTB.
  311. for (ArcIterator<FST> iterb(fstb, sb); !iterb.Done(); iterb.Next()) {
  312. MatchArc(s, matchera, iterb.Value(), match_input);
  313. }
  314. CacheImpl::SetArcs(s);
  315. }
  316. // Matches a single transition from 'fstb' against 'fata' at 's'.
  317. template <class Matcher>
  318. void MatchArc(StateId s, Matcher *matchera, const Arc &arc,
  319. bool match_input) {
  320. if (matchera->Find(match_input ? arc.olabel : arc.ilabel)) {
  321. for (; !matchera->Done(); matchera->Next()) {
  322. auto arca = matchera->Value();
  323. auto arcb = arc;
  324. if (match_input) {
  325. const auto &fs = filter_->FilterArc(&arcb, &arca);
  326. if (fs != FilterState::NoState()) AddArc(s, arcb, arca, fs);
  327. } else {
  328. const auto &fs = filter_->FilterArc(&arca, &arcb);
  329. if (fs != FilterState::NoState()) AddArc(s, arca, arcb, fs);
  330. }
  331. }
  332. }
  333. }
  334. // Add a matching transition at 's'.
  335. void AddArc(StateId s, const Arc &arc1, const Arc &arc2,
  336. const FilterState &f) {
  337. const StateTuple tuple(arc1.nextstate, arc2.nextstate, f);
  338. CacheImpl::EmplaceArc(s, arc1.ilabel, arc2.olabel,
  339. Times(arc1.weight, arc2.weight),
  340. state_table_->FindState(tuple));
  341. }
  342. StateId ComputeStart() override {
  343. const auto s1 = fst1_.Start();
  344. if (s1 == kNoStateId) return kNoStateId;
  345. const auto s2 = fst2_.Start();
  346. if (s2 == kNoStateId) return kNoStateId;
  347. const auto &fs = filter_->Start();
  348. const StateTuple tuple(s1, s2, fs);
  349. return state_table_->FindState(tuple);
  350. }
  351. Weight ComputeFinal(StateId s) override {
  352. const auto &tuple = state_table_->Tuple(s);
  353. const auto s1 = tuple.StateId1();
  354. auto final1 = matcher1_->Final(s1);
  355. if (final1 == Weight::Zero()) return final1;
  356. const auto s2 = tuple.StateId2();
  357. auto final2 = matcher2_->Final(s2);
  358. if (final2 == Weight::Zero()) return final2;
  359. filter_->SetState(s1, s2, tuple.GetFilterState());
  360. filter_->FilterFinal(&final1, &final2);
  361. return Times(final1, final2);
  362. }
  363. // Determines which side to match on per composition state.
  364. bool MatchInput(StateId s1, StateId s2) {
  365. switch (match_type_) {
  366. case MATCH_INPUT:
  367. return true;
  368. case MATCH_OUTPUT:
  369. return false;
  370. default: // MATCH_BOTH
  371. const auto priority1 = matcher1_->Priority(s1);
  372. const auto priority2 = matcher2_->Priority(s2);
  373. if (priority1 == kRequirePriority && priority2 == kRequirePriority) {
  374. FSTERROR() << "ComposeFst: Both sides can't require match";
  375. SetProperties(kError, kError);
  376. return true;
  377. }
  378. if (priority1 == kRequirePriority) return false;
  379. if (priority2 == kRequirePriority) {
  380. return true;
  381. }
  382. return priority1 <= priority2;
  383. }
  384. }
  385. // Identifies and verifies the capabilities of the matcher to be used for
  386. // composition.
  387. void SetMatchType();
  388. std::unique_ptr<Filter> filter_;
  389. Matcher1 *matcher1_; // Borrowed reference.
  390. Matcher2 *matcher2_; // Borrowed reference.
  391. const FST1 &fst1_;
  392. const FST2 &fst2_;
  393. StateTable *state_table_;
  394. bool own_state_table_;
  395. MatchType match_type_;
  396. };
  397. template <class CacheStore, class Filter, class StateTable>
  398. template <class M1, class M2>
  399. ComposeFstImpl<CacheStore, Filter, StateTable>::ComposeFstImpl(
  400. const FST1 &fst1, const FST2 &fst2,
  401. const ComposeFstImplOptions<M1, M2, Filter, StateTable, CacheStore> &opts)
  402. : ComposeFstImplBase<Arc, CacheStore>(opts),
  403. filter_(opts.filter
  404. ? opts.filter
  405. : new Filter(fst1, fst2, opts.matcher1, opts.matcher2)),
  406. matcher1_(filter_->GetMatcher1()),
  407. matcher2_(filter_->GetMatcher2()),
  408. fst1_(matcher1_->GetFst()),
  409. fst2_(matcher2_->GetFst()),
  410. state_table_(opts.state_table ? opts.state_table
  411. : new StateTable(fst1_, fst2_)),
  412. own_state_table_(opts.state_table ? opts.own_state_table : true) {
  413. SetType("compose");
  414. if (!CompatSymbols(fst2.InputSymbols(), fst1.OutputSymbols())) {
  415. FSTERROR() << "ComposeFst: Output symbol table of 1st argument "
  416. << "does not match input symbol table of 2nd argument";
  417. SetProperties(kError, kError);
  418. }
  419. SetInputSymbols(fst1_.InputSymbols());
  420. SetOutputSymbols(fst2_.OutputSymbols());
  421. SetMatchType();
  422. VLOG(2) << "ComposeFstImpl: Match type: " << match_type_;
  423. if (match_type_ == MATCH_NONE) SetProperties(kError, kError);
  424. const auto fprops1 = fst1.Properties(kFstProperties, false);
  425. const auto fprops2 = fst2.Properties(kFstProperties, false);
  426. const auto mprops1 = matcher1_->Properties(fprops1);
  427. const auto mprops2 = matcher2_->Properties(fprops2);
  428. const auto cprops = ComposeProperties(mprops1, mprops2);
  429. SetProperties(filter_->Properties(cprops), kCopyProperties);
  430. if (state_table_->Error()) SetProperties(kError, kError);
  431. }
  432. template <class CacheStore, class Filter, class StateTable>
  433. void ComposeFstImpl<CacheStore, Filter, StateTable>::SetMatchType() {
  434. // Ensures any required matching is possible and known.
  435. if ((matcher1_->Flags() & kRequireMatch) &&
  436. matcher1_->Type(true) != MATCH_OUTPUT) {
  437. FSTERROR() << "ComposeFst: 1st argument cannot perform required matching "
  438. << "(sort?).";
  439. match_type_ = MATCH_NONE;
  440. return;
  441. }
  442. if ((matcher2_->Flags() & kRequireMatch) &&
  443. matcher2_->Type(true) != MATCH_INPUT) {
  444. FSTERROR() << "ComposeFst: 2nd argument cannot perform required matching "
  445. << "(sort?).";
  446. match_type_ = MATCH_NONE;
  447. return;
  448. }
  449. // Finds which sides to match on (favoring minimal testing of capabilities).
  450. const auto type1 = matcher1_->Type(false);
  451. const auto type2 = matcher2_->Type(false);
  452. if (type1 == MATCH_OUTPUT && type2 == MATCH_INPUT) {
  453. match_type_ = MATCH_BOTH;
  454. } else if (type1 == MATCH_OUTPUT) {
  455. match_type_ = MATCH_OUTPUT;
  456. } else if (type2 == MATCH_INPUT) {
  457. match_type_ = MATCH_INPUT;
  458. } else if (matcher1_->Type(true) == MATCH_OUTPUT) {
  459. match_type_ = MATCH_OUTPUT;
  460. } else if (matcher2_->Type(true) == MATCH_INPUT) {
  461. match_type_ = MATCH_INPUT;
  462. } else {
  463. FSTERROR() << "ComposeFst: 1st argument cannot match on output labels "
  464. << "and 2nd argument cannot match on input labels (sort?).";
  465. match_type_ = MATCH_NONE;
  466. }
  467. }
  468. } // namespace internal
  469. // Computes the composition of two transducers. This version is a delayed FST.
  470. // If FST1 transduces string x to y with weight a and FST2 transduces y to z
  471. // with weight b, then their composition transduces string x to z with weight
  472. // Times(a, b).
  473. //
  474. // The output labels of the first transducer or the input labels of the second
  475. // transducer must be sorted (with the default matcher). The weights need to
  476. // form a commutative semiring (valid for TropicalWeight and LogWeight).
  477. //
  478. // Complexity:
  479. //
  480. // Assuming the first FST is unsorted and the second is sorted,
  481. //
  482. // Time: O(v1 v2 d1 (log d2 + m2)),
  483. // Space: O(v1 v2)
  484. //
  485. // where vi = # of states visited, di = maximum out-degree, and mi the
  486. // maximum multiplicity of the states visited, for the ith FST. Constant time
  487. // and space to visit an input state or arc is assumed and exclusive of caching.
  488. //
  489. // Caveats:
  490. // - ComposeFst does not trim its output (since it is a delayed operation).
  491. // - The efficiency of composition can be strongly affected by several factors:
  492. // - the choice of which transducer is sorted - prefer sorting the FST
  493. // that has the greater average out-degree.
  494. // - the amount of non-determinism
  495. // - the presence and location of epsilon transitions - avoid epsilon
  496. // transitions on the output side of the first transducer or
  497. // the input side of the second transducer or prefer placing
  498. // them later in a path since they delay matching and can
  499. // introduce non-coaccessible states and transitions.
  500. //
  501. // This class attaches interface to implementation and handles reference
  502. // counting, delegating most methods to ImplToFst. The CacheStore specifies the
  503. // cache store (default declared in fst-decl.h).
  504. template <class A, class CacheStore /* = DefaultCacheStore<A> */>
  505. class ComposeFst
  506. : public ImplToFst<internal::ComposeFstImplBase<A, CacheStore>> {
  507. public:
  508. using Arc = A;
  509. using StateId = typename Arc::StateId;
  510. using Weight = typename Arc::Weight;
  511. using Store = CacheStore;
  512. using State = typename CacheStore::State;
  513. using Impl = internal::ComposeFstImplBase<A, CacheStore>;
  514. friend class ArcIterator<ComposeFst<Arc, CacheStore>>;
  515. friend class StateIterator<ComposeFst<Arc, CacheStore>>;
  516. template <class, class, class>
  517. friend class ComposeFstMatcher;
  518. // Compose specifying only caching options.
  519. ComposeFst(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
  520. const CacheOptions &opts = CacheOptions())
  521. : ImplToFst<Impl>(CreateBase(fst1, fst2, opts)) {}
  522. // Compose specifying one shared matcher type M. Requires that the input FSTs
  523. // and matcher FST types be Fst<Arc>. Recommended for best code-sharing and
  524. // matcher compatibility.
  525. template <class Matcher, class Filter, class StateTuple>
  526. ComposeFst(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
  527. const ComposeFstOptions<Arc, Matcher, Filter, StateTuple> &opts)
  528. : ImplToFst<Impl>(CreateBase1(fst1, fst2, opts)) {}
  529. // Compose specifying two matcher types Matcher1 and Matcher2. Requires input
  530. // FST (of the same Arc type, but o.w. arbitrary) match the corresponding
  531. // matcher FST types). Recommended only for advanced use in demanding or
  532. // specialized applications due to potential code bloat and matcher
  533. // incompatibilities.
  534. template <class Matcher1, class Matcher2, class Filter, class StateTuple>
  535. ComposeFst(const typename Matcher1::FST &fst1,
  536. const typename Matcher2::FST &fst2,
  537. const ComposeFstImplOptions<Matcher1, Matcher2, Filter, StateTuple,
  538. CacheStore> &opts)
  539. : ImplToFst<Impl>(CreateBase2(fst1, fst2, opts)) {}
  540. // See Fst<>::Copy() for doc.
  541. ComposeFst(const ComposeFst &fst, bool safe = false)
  542. : ImplToFst<Impl>(safe ? std::shared_ptr<Impl>(fst.GetImpl()->Copy())
  543. : fst.GetSharedImpl()) {}
  544. // Get a copy of this ComposeFst. See Fst<>::Copy() for further doc.
  545. ComposeFst *Copy(bool safe = false) const override {
  546. return new ComposeFst(*this, safe);
  547. }
  548. inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
  549. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
  550. GetMutableImpl()->InitArcIterator(s, data);
  551. }
  552. MatcherBase<Arc> *InitMatcher(MatchType match_type) const override {
  553. return GetImpl()->InitMatcher(*this, match_type);
  554. }
  555. protected:
  556. using ImplToFst<Impl>::GetImpl;
  557. using ImplToFst<Impl>::GetMutableImpl;
  558. explicit ComposeFst(std::shared_ptr<Impl> impl) : ImplToFst<Impl>(impl) {}
  559. // Create compose implementation specifying two matcher types.
  560. template <class Matcher1, class Matcher2, class Filter, class StateTuple>
  561. static std::shared_ptr<Impl> CreateBase2(
  562. const typename Matcher1::FST &fst1, const typename Matcher2::FST &fst2,
  563. const ComposeFstImplOptions<Matcher1, Matcher2, Filter, StateTuple,
  564. CacheStore> &opts) {
  565. auto impl = std::make_shared<
  566. internal::ComposeFstImpl<CacheStore, Filter, StateTuple>>(fst1, fst2,
  567. opts);
  568. if (!(Weight::Properties() & kCommutative) && !opts.allow_noncommute) {
  569. const auto props1 = fst1.Properties(kUnweighted, true);
  570. const auto props2 = fst2.Properties(kUnweighted, true);
  571. if (!(props1 & kUnweighted) && !(props2 & kUnweighted)) {
  572. FSTERROR() << "ComposeFst: Weights must be a commutative semiring: "
  573. << Weight::Type();
  574. impl->SetProperties(kError, kError);
  575. }
  576. }
  577. return impl;
  578. }
  579. // Create compose implementation specifying one matcher type; requires that
  580. // input and matcher FST types be Fst<Arc>.
  581. template <class Matcher, class Filter, class StateTuple>
  582. static std::shared_ptr<Impl> CreateBase1(
  583. const Fst<Arc> &fst1, const Fst<Arc> &fst2,
  584. const ComposeFstOptions<Arc, Matcher, Filter, StateTuple> &opts) {
  585. ComposeFstImplOptions<Matcher, Matcher, Filter, StateTuple, CacheStore>
  586. nopts(opts, opts.matcher1, opts.matcher2, opts.filter,
  587. opts.state_table);
  588. return CreateBase2(fst1, fst2, nopts);
  589. }
  590. // Create compose implementation specifying no matcher type.
  591. static std::shared_ptr<Impl> CreateBase(const Fst<Arc> &fst1,
  592. const Fst<Arc> &fst2,
  593. const CacheOptions &opts) {
  594. switch (LookAheadMatchType(fst1, fst2)) { // Check for lookahead matchers
  595. default:
  596. case MATCH_NONE: { // Default composition (no look-ahead).
  597. ComposeFstOptions<Arc> nopts(opts);
  598. return CreateBase1(fst1, fst2, nopts);
  599. }
  600. case MATCH_OUTPUT: { // Lookahead on fst1.
  601. using M = typename DefaultLookAhead<Arc, MATCH_OUTPUT>::FstMatcher;
  602. using F = typename DefaultLookAhead<Arc, MATCH_OUTPUT>::ComposeFilter;
  603. ComposeFstOptions<Arc, M, F> nopts(opts);
  604. return CreateBase1(fst1, fst2, nopts);
  605. }
  606. case MATCH_INPUT: { // Lookahead on fst2
  607. using M = typename DefaultLookAhead<Arc, MATCH_INPUT>::FstMatcher;
  608. using F = typename DefaultLookAhead<Arc, MATCH_INPUT>::ComposeFilter;
  609. ComposeFstOptions<Arc, M, F> nopts(opts);
  610. return CreateBase1(fst1, fst2, nopts);
  611. }
  612. }
  613. }
  614. private:
  615. ComposeFst &operator=(const ComposeFst &fst) = delete;
  616. };
  617. // Specialization for ComposeFst.
  618. template <class Arc, class CacheStore>
  619. class StateIterator<ComposeFst<Arc, CacheStore>>
  620. : public CacheStateIterator<ComposeFst<Arc, CacheStore>> {
  621. public:
  622. explicit StateIterator(const ComposeFst<Arc, CacheStore> &fst)
  623. : CacheStateIterator<ComposeFst<Arc, CacheStore>>(fst,
  624. fst.GetMutableImpl()) {}
  625. };
  626. // Specialization for ComposeFst.
  627. template <class Arc, class CacheStore>
  628. class ArcIterator<ComposeFst<Arc, CacheStore>>
  629. : public CacheArcIterator<ComposeFst<Arc, CacheStore>> {
  630. public:
  631. using StateId = typename Arc::StateId;
  632. ArcIterator(const ComposeFst<Arc, CacheStore> &fst, StateId s)
  633. : CacheArcIterator<ComposeFst<Arc, CacheStore>>(fst.GetMutableImpl(), s) {
  634. if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
  635. }
  636. };
  637. template <class Arc, class CacheStore>
  638. inline void ComposeFst<Arc, CacheStore>::InitStateIterator(
  639. StateIteratorData<Arc> *data) const {
  640. data->base =
  641. std::make_unique<StateIterator<ComposeFst<Arc, CacheStore>>>(*this);
  642. }
  643. // Specialized matcher for ComposeFst. Supports MATCH_INPUT or MATCH_OUTPUT,
  644. // iff the underlying matchers for the two FSTS being composed support
  645. // MATCH_INPUT or MATCH_OUTPUT, respectively.
  646. template <class CacheStore, class Filter, class StateTable>
  647. class ComposeFstMatcher : public MatcherBase<typename CacheStore::Arc> {
  648. public:
  649. using Arc = typename CacheStore::Arc;
  650. using Label = typename Arc::Label;
  651. using StateId = typename Arc::StateId;
  652. using Weight = typename Arc::Weight;
  653. using Matcher1 = typename Filter::Matcher1;
  654. using Matcher2 = typename Filter::Matcher2;
  655. using FilterState = typename Filter::FilterState;
  656. using StateTuple = typename StateTable::StateTuple;
  657. using Impl = internal::ComposeFstImpl<CacheStore, Filter, StateTable>;
  658. // The compose FST arg must match the filter and state table types.
  659. // This makes a copy of the FST.
  660. ComposeFstMatcher(const ComposeFst<Arc, CacheStore> &fst,
  661. MatchType match_type)
  662. : owned_fst_(fst.Copy()),
  663. fst_(*owned_fst_),
  664. impl_(down_cast<const Impl *>(fst_.GetImpl())),
  665. s_(kNoStateId),
  666. match_type_(match_type),
  667. matcher1_(impl_->matcher1_->Copy()),
  668. matcher2_(impl_->matcher2_->Copy()),
  669. current_loop_(false),
  670. loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
  671. if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
  672. }
  673. // The compose FST arg must match the filter and state table types.
  674. // This doesn't copy the FST (although it may copy components).
  675. ComposeFstMatcher(const ComposeFst<Arc, CacheStore> *fst,
  676. MatchType match_type)
  677. : fst_(*fst),
  678. impl_(down_cast<const Impl *>(fst_.GetImpl())),
  679. s_(kNoStateId),
  680. match_type_(match_type),
  681. matcher1_(impl_->matcher1_->Copy()),
  682. matcher2_(impl_->matcher2_->Copy()),
  683. current_loop_(false),
  684. loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
  685. if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
  686. }
  687. // This makes a copy of the FST.
  688. ComposeFstMatcher(
  689. const ComposeFstMatcher<CacheStore, Filter, StateTable> &matcher,
  690. bool safe = false)
  691. : owned_fst_(matcher.fst_.Copy(safe)),
  692. fst_(*owned_fst_),
  693. impl_(down_cast<const Impl *>(fst_.GetImpl())),
  694. s_(kNoStateId),
  695. match_type_(matcher.match_type_),
  696. matcher1_(matcher.matcher1_->Copy(safe)),
  697. matcher2_(matcher.matcher2_->Copy(safe)),
  698. current_loop_(false),
  699. loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
  700. if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
  701. }
  702. ComposeFstMatcher *Copy(bool safe = false) const override {
  703. return new ComposeFstMatcher(*this, safe);
  704. }
  705. MatchType Type(bool test) const override {
  706. if ((matcher1_->Type(test) == MATCH_NONE) ||
  707. (matcher2_->Type(test) == MATCH_NONE)) {
  708. return MATCH_NONE;
  709. }
  710. if (((matcher1_->Type(test) == MATCH_UNKNOWN) &&
  711. (matcher2_->Type(test) == MATCH_UNKNOWN)) ||
  712. ((matcher1_->Type(test) == MATCH_UNKNOWN) &&
  713. (matcher2_->Type(test) == match_type_)) ||
  714. ((matcher1_->Type(test) == match_type_) &&
  715. (matcher2_->Type(test) == MATCH_UNKNOWN))) {
  716. return MATCH_UNKNOWN;
  717. }
  718. if ((matcher1_->Type(test) == match_type_) &&
  719. (matcher2_->Type(test) == match_type_)) {
  720. return match_type_;
  721. }
  722. return MATCH_NONE;
  723. }
  724. const Fst<Arc> &GetFst() const override { return fst_; }
  725. uint64_t Properties(uint64_t inprops) const override { return inprops; }
  726. void SetState(StateId s) final {
  727. if (s_ == s) return;
  728. s_ = s;
  729. const auto &tuple = impl_->state_table_->Tuple(s);
  730. matcher1_->SetState(tuple.StateId1());
  731. matcher2_->SetState(tuple.StateId2());
  732. loop_.nextstate = s_;
  733. }
  734. bool Find(Label label) final {
  735. bool found = false;
  736. current_loop_ = false;
  737. if (label == 0) {
  738. current_loop_ = true;
  739. found = true;
  740. }
  741. if (match_type_ == MATCH_INPUT) {
  742. found = found || FindLabel(label, matcher1_.get(), matcher2_.get());
  743. } else { // match_type_ == MATCH_OUTPUT
  744. found = found || FindLabel(label, matcher2_.get(), matcher1_.get());
  745. }
  746. return found;
  747. }
  748. bool Done() const final {
  749. return !current_loop_ && matcher1_->Done() && matcher2_->Done();
  750. }
  751. const Arc &Value() const final { return current_loop_ ? loop_ : arc_; }
  752. void Next() final {
  753. if (current_loop_) {
  754. current_loop_ = false;
  755. } else if (match_type_ == MATCH_INPUT) {
  756. FindNext(matcher1_.get(), matcher2_.get());
  757. } else { // match_type_ == MATCH_OUTPUT
  758. FindNext(matcher2_.get(), matcher1_.get());
  759. }
  760. }
  761. ssize_t Priority(StateId s) final { return fst_.NumArcs(s); }
  762. private:
  763. // Processes a match with the filter and creates resulting arc.
  764. bool MatchArc(StateId s, Arc *arc1, Arc *arc2) {
  765. const auto &fs = impl_->filter_->FilterArc(arc1, arc2);
  766. if (fs == FilterState::NoState()) return false;
  767. const StateTuple tuple(arc1->nextstate, arc2->nextstate, fs);
  768. arc_.ilabel = arc1->ilabel;
  769. arc_.olabel = arc2->olabel;
  770. arc_.weight = Times(arc1->weight, arc2->weight);
  771. arc_.nextstate = impl_->state_table_->FindState(tuple);
  772. return true;
  773. }
  774. // Finds the first match allowed by the filter.
  775. template <class MatcherA, class MatcherB>
  776. bool FindLabel(Label label, MatcherA *matchera, MatcherB *matcherb) {
  777. if (matchera->Find(label)) {
  778. matcherb->Find(match_type_ == MATCH_INPUT ? matchera->Value().olabel
  779. : matchera->Value().ilabel);
  780. return FindNext(matchera, matcherb);
  781. }
  782. return false;
  783. }
  784. // Finds the next match allowed by the filter, returning true iff such a
  785. // match is found.
  786. template <class MatcherA, class MatcherB>
  787. bool FindNext(MatcherA *matchera, MatcherB *matcherb) {
  788. // State when entering this function:
  789. // 'matchera' is pointed to a match x, y for label x, and a match for y was
  790. // requested on 'matcherb'.
  791. while (!matchera->Done() || !matcherb->Done()) {
  792. if (matcherb->Done()) {
  793. // If no more matches for y on 'matcherb', moves forward on 'matchera'
  794. // until a match x, y' is found such that there is a match for y' on
  795. // 'matcherb'.
  796. matchera->Next();
  797. while (!matchera->Done() &&
  798. !matcherb->Find(match_type_ == MATCH_INPUT
  799. ? matchera->Value().olabel
  800. : matchera->Value().ilabel)) {
  801. matchera->Next();
  802. }
  803. }
  804. while (!matcherb->Done()) {
  805. // 'matchera' is pointing to a match x, y' ('arca') and 'matcherb' is
  806. // pointing to a match y', z' ('arcb'). If combining these two arcs is
  807. // allowed by the filter (hence resulting in an arc x, z') return true.
  808. // Position 'matcherb' on the next potential match for y' before
  809. // returning.
  810. auto arca = matchera->Value();
  811. auto arcb = matcherb->Value();
  812. // Position 'matcherb' on the next potential match for y'.
  813. matcherb->Next();
  814. // Returns true If combining these two arcs is allowed by the filter
  815. // (hence resulting in an arc x, z'); otherwise consider next match
  816. // for y' on 'matcherb'.
  817. if (match_type_ == MATCH_INPUT) {
  818. if (MatchArc(s_, &arca, &arcb)) return true;
  819. } else {
  820. if (MatchArc(s_, &arcb, &arca)) return true;
  821. }
  822. }
  823. }
  824. // Both 'matchera' and 'matcherb' are done, no more match to analyse.
  825. return false;
  826. }
  827. std::unique_ptr<const ComposeFst<Arc, CacheStore>> owned_fst_;
  828. const ComposeFst<Arc, CacheStore> &fst_;
  829. const Impl *impl_;
  830. StateId s_;
  831. MatchType match_type_;
  832. std::unique_ptr<Matcher1> matcher1_;
  833. std::unique_ptr<Matcher2> matcher2_;
  834. bool current_loop_;
  835. Arc loop_;
  836. Arc arc_;
  837. };
  // Convenience alias: ComposeFst instantiated with StdArc, leaving the
  // remaining template parameter at its default.
  using StdComposeFst = ComposeFst<StdArc>;
  // Pre-defined composition filters selectable through ComposeOptions.
  // The numeric values are the ones the enumerators receive implicitly from
  // their declaration order.
  enum ComposeFilter {
    AUTO_FILTER = 0,          // ComposeFst built without naming a filter.
    NULL_FILTER = 1,          // NullComposeFilter.
    TRIVIAL_FILTER = 2,       // TrivialComposeFilter.
    SEQUENCE_FILTER = 3,      // SequenceComposeFilter.
    ALT_SEQUENCE_FILTER = 4,  // AltSequenceComposeFilter.
    MATCH_FILTER = 5,         // MatchComposeFilter.
    NO_MATCH_FILTER = 6       // NoMatchComposeFilter.
  };
  849. struct ComposeOptions {
  850. bool connect; // Connect output?
  851. ComposeFilter filter_type; // Pre-defined filter to use.
  852. explicit ComposeOptions(bool connect = true,
  853. ComposeFilter filter_type = AUTO_FILTER)
  854. : connect(connect), filter_type(filter_type) {}
  855. };
  856. // Computes the composition of two transducers. This version writes
  857. // the composed FST into a MutableFst. If FST1 transduces string x to
  858. // y with weight a and FST2 transduces y to z with weight b, then
  859. // their composition transduces string x to z with weight
  860. // Times(a, b).
  861. //
  862. // The output labels of the first transducer or the input labels of
  863. // the second transducer must be sorted. The weights need to form a
  864. // commutative semiring (valid for TropicalWeight and LogWeight).
  865. //
  866. // Complexity:
  867. //
  868. // Assuming the first FST is unsorted and the second is sorted:
  869. //
  870. // Time: O(V1 V2 D1 (log D2 + M2)),
  871. // Space: O(V1 V2 D1 M2)
  872. //
  873. // where Vi = # of states, Di = maximum out-degree, and Mi is the maximum
  874. // multiplicity, for the ith FST.
  875. //
  876. // Caveats:
  877. //
  // - Compose trims its output when opts.connect is true (the default).
  879. // - The efficiency of composition can be strongly affected by several factors:
  880. // - the choice of which transducer is sorted - prefer sorting the FST
  881. // that has the greater average out-degree.
  882. // - the amount of non-determinism
  883. // - the presence and location of epsilon transitions - avoid epsilon
  884. // transitions on the output side of the first transducer or
  885. // the input side of the second transducer or prefer placing
  886. // them later in a path since they delay matching and can
  887. // introduce non-coaccessible states and transitions.
  888. template <class Arc>
  889. void Compose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
  890. MutableFst<Arc> *ofst,
  891. const ComposeOptions &opts = ComposeOptions()) {
  892. using M = Matcher<Fst<Arc>>;
  893. // In each case, we cache only the last state for fastest copy.
  894. switch (opts.filter_type) {
  895. case AUTO_FILTER: {
  896. CacheOptions nopts;
  897. nopts.gc_limit = 0;
  898. *ofst = ComposeFst<Arc>(ifst1, ifst2, nopts);
  899. break;
  900. }
  901. case NULL_FILTER: {
  902. ComposeFstOptions<Arc, M, NullComposeFilter<M>> copts;
  903. copts.gc_limit = 0;
  904. *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
  905. break;
  906. }
  907. case SEQUENCE_FILTER: {
  908. ComposeFstOptions<Arc, M, SequenceComposeFilter<M>> copts;
  909. copts.gc_limit = 0;
  910. *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
  911. break;
  912. }
  913. case ALT_SEQUENCE_FILTER: {
  914. ComposeFstOptions<Arc, M, AltSequenceComposeFilter<M>> copts;
  915. copts.gc_limit = 0;
  916. *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
  917. break;
  918. }
  919. case MATCH_FILTER: {
  920. ComposeFstOptions<Arc, M, MatchComposeFilter<M>> copts;
  921. copts.gc_limit = 0;
  922. *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
  923. break;
  924. }
  925. case NO_MATCH_FILTER: {
  926. ComposeFstOptions<Arc, M, NoMatchComposeFilter<M>> copts;
  927. copts.gc_limit = 0;
  928. *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
  929. break;
  930. }
  931. case TRIVIAL_FILTER: {
  932. ComposeFstOptions<Arc, M, TrivialComposeFilter<M>> copts;
  933. copts.gc_limit = 0;
  934. *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
  935. break;
  936. }
  937. }
  938. if (opts.connect) Connect(ofst);
  939. }
  940. } // namespace fst
  941. #endif // FST_COMPOSE_H_