You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

640 lines
19 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Class to map over/transform states e.g., sort transitions.
  19. //
  20. // Consider using when operation does not change the number of states.
  21. #ifndef FST_STATE_MAP_H_
  22. #define FST_STATE_MAP_H_
#include <sys/types.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include <fst/log.h>
#include <fst/arc-map.h>
#include <fst/arc.h>
#include <fst/cache.h>
#include <fst/expanded-fst.h>
#include <fst/float-weight.h>
#include <fst/fst.h>
#include <fst/impl-to-fst.h>
#include <fst/mutable-fst.h>
#include <fst/properties.h>

  41. namespace fst {
  42. // StateMapper Interface. The class determines how states are mapped; useful for
  43. // implementing operations that do not change the number of states.
  44. //
  45. // class StateMapper {
  46. // public:
  47. // using FromArc = ...;
  48. // using ToArc = ...;
  49. //
  50. // // Typical constructor.
  51. // StateMapper(const Fst<FromArc> &fst);
  52. //
  53. // // Required copy constructor that allows updating FST argument;
  54. // // pass only if relevant and changed.
  55. // StateMapper(const StateMapper &mapper, const Fst<FromArc> *fst = 0);
  56. //
  57. // // Specifies initial state of result.
  58. // ToArc::StateId Start() const;
  59. // // Specifies state's final weight in result.
  60. // ToArc::Weight Final(ToArc::StateId state) const;
  61. //
  62. // // These methods iterate through a state's arcs in result.
  63. //
  64. // // Specifies state to iterate over.
  65. // void SetState(ToArc::StateId state);
  66. //
  67. // // End of arcs?
  68. // bool Done() const;
  69. //
  70. // // Current arc.
  71. // const ToArc &Value() const;
  72. //
  73. // // Advances to next arc (when !Done)
  74. // void Next();
  75. //
  76. // // Specifies input symbol table action the mapper requires (see above).
  77. // MapSymbolsAction InputSymbolsAction() const;
  78. //
  79. // // Specifies output symbol table action the mapper requires (see above).
  80. // MapSymbolsAction OutputSymbolsAction() const;
  81. //
  82. // // This specifies the known properties of an FST mapped by this
  83. // // mapper. It takes as argument the input FST's known properties.
  84. // uint64_t Properties(uint64_t props) const;
  85. // };
  86. //
  87. // We include various state map versions below. One dimension of variation is
  88. // whether the mapping mutates its input, writes to a new result FST, or is an
  89. // on-the-fly Fst. Another dimension is how we pass the mapper. We allow passing
  90. // the mapper by pointer for cases that we need to change the state of the
  91. // user's mapper. We also include map versions that pass the mapper by value or
  92. // const reference when this suffices.
  93. // Maps an arc type A using a mapper function object C, passed by pointer. This
  94. // version modifies the input FST.
  95. template <class A, class C>
  96. void StateMap(MutableFst<A> *fst, C *mapper) {
  97. if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
  98. fst->SetInputSymbols(nullptr);
  99. }
  100. if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
  101. fst->SetOutputSymbols(nullptr);
  102. }
  103. if (fst->Start() == kNoStateId) return;
  104. const auto props = fst->Properties(kFstProperties, false);
  105. fst->SetStart(mapper->Start());
  106. for (StateIterator<Fst<A>> siter(*fst); !siter.Done(); siter.Next()) {
  107. const auto state = siter.Value();
  108. mapper->SetState(state);
  109. fst->DeleteArcs(state);
  110. for (; !mapper->Done(); mapper->Next()) {
  111. fst->AddArc(state, mapper->Value());
  112. }
  113. fst->SetFinal(state, mapper->Final(state));
  114. }
  115. fst->SetProperties(mapper->Properties(props), kFstProperties);
  116. }
  117. // Maps an arc type A using a mapper function object C, passed by value.
  118. // This version modifies the input FST.
  119. template <class A, class C>
  120. void StateMap(MutableFst<A> *fst, C mapper) {
  121. StateMap(fst, &mapper);
  122. }
  123. // Maps an arc type A to an arc type B using mapper functor C, passed by
  124. // pointer. This version writes to an output FST.
  125. template <class A, class B, class C>
  126. void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C *mapper) {
  127. ofst->DeleteStates();
  128. if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS) {
  129. ofst->SetInputSymbols(ifst.InputSymbols());
  130. } else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
  131. ofst->SetInputSymbols(nullptr);
  132. }
  133. if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS) {
  134. ofst->SetOutputSymbols(ifst.OutputSymbols());
  135. } else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
  136. ofst->SetOutputSymbols(nullptr);
  137. }
  138. const auto iprops = ifst.Properties(kCopyProperties, false);
  139. if (ifst.Start() == kNoStateId) {
  140. if (iprops & kError) ofst->SetProperties(kError, kError);
  141. return;
  142. }
  143. // Adds all states.
  144. if (std::optional<typename A::StateId> num_states = ifst.NumStatesIfKnown()) {
  145. ofst->ReserveStates(*num_states);
  146. }
  147. for (StateIterator<Fst<A>> siter(ifst); !siter.Done(); siter.Next()) {
  148. ofst->AddState();
  149. }
  150. ofst->SetStart(mapper->Start());
  151. for (StateIterator<Fst<A>> siter(ifst); !siter.Done(); siter.Next()) {
  152. const auto state = siter.Value();
  153. mapper->SetState(state);
  154. for (; !mapper->Done(); mapper->Next()) {
  155. ofst->AddArc(state, mapper->Value());
  156. }
  157. ofst->SetFinal(state, mapper->Final(state));
  158. }
  159. const auto oprops = ofst->Properties(kFstProperties, false);
  160. ofst->SetProperties(mapper->Properties(iprops) | oprops, kFstProperties);
  161. }
  162. // Maps an arc type A to an arc type B using mapper functor object C, passed by
  163. // value. This version writes to an output FST.
  164. template <class A, class B, class C>
  165. void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C mapper) {
  166. StateMap(ifst, ofst, &mapper);
  167. }
// Options accepted by the delayed StateMapFst constructors; this is simply an
// alias for CacheOptions.
using StateMapFstOptions = CacheOptions;

// Forward declaration of the delayed state-mapping FST, which is defined
// further down in this file. It is named by a friend declaration in
// internal::StateMapFstImpl before its definition appears.
template <class A, class B, class C>
class StateMapFst;
  171. // Facade around StateIteratorBase<A> inheriting from StateIteratorBase<B>.
  172. template <class A, class B>
  173. class StateMapStateIteratorBase : public StateIteratorBase<B> {
  174. public:
  175. using Arc = B;
  176. using StateId = typename Arc::StateId;
  177. explicit StateMapStateIteratorBase(std::unique_ptr<StateIteratorBase<A>> base)
  178. : base_(std::move(base)) {}
  179. bool Done() const final { return base_->Done(); }
  180. StateId Value() const final { return base_->Value(); }
  181. void Next() final { base_->Next(); }
  182. void Reset() final { base_->Reset(); }
  183. private:
  184. std::unique_ptr<StateIteratorBase<A>> base_;
  185. StateMapStateIteratorBase() = delete;
  186. };
  187. namespace internal {
  188. // Implementation of delayed StateMapFst.
  189. template <class A, class B, class C>
  190. class StateMapFstImpl : public CacheImpl<B> {
  191. public:
  192. using Arc = B;
  193. using StateId = typename Arc::StateId;
  194. using Weight = typename Arc::Weight;
  195. using FstImpl<B>::SetType;
  196. using FstImpl<B>::SetProperties;
  197. using FstImpl<B>::SetInputSymbols;
  198. using FstImpl<B>::SetOutputSymbols;
  199. using CacheImpl<B>::PushArc;
  200. using CacheImpl<B>::HasArcs;
  201. using CacheImpl<B>::HasFinal;
  202. using CacheImpl<B>::HasStart;
  203. using CacheImpl<B>::SetArcs;
  204. using CacheImpl<B>::SetFinal;
  205. using CacheImpl<B>::SetStart;
  206. friend class StateIterator<StateMapFst<A, B, C>>;
  207. StateMapFstImpl(const Fst<A> &fst, const C &mapper,
  208. const StateMapFstOptions &opts)
  209. : CacheImpl<B>(opts),
  210. fst_(fst.Copy()),
  211. mapper_(new C(mapper, fst_.get())),
  212. own_mapper_(true) {
  213. Init();
  214. }
  215. StateMapFstImpl(const Fst<A> &fst, C *mapper, const StateMapFstOptions &opts)
  216. : CacheImpl<B>(opts),
  217. fst_(fst.Copy()),
  218. mapper_(mapper),
  219. own_mapper_(false) {
  220. Init();
  221. }
  222. StateMapFstImpl(const StateMapFstImpl<A, B, C> &impl)
  223. : CacheImpl<B>(impl),
  224. fst_(impl.fst_->Copy(true)),
  225. mapper_(new C(*impl.mapper_, fst_.get())),
  226. own_mapper_(true) {
  227. Init();
  228. }
  229. ~StateMapFstImpl() override {
  230. if (own_mapper_) delete mapper_;
  231. }
  232. StateId Start() {
  233. if (!HasStart()) SetStart(mapper_->Start());
  234. return CacheImpl<B>::Start();
  235. }
  236. Weight Final(StateId state) {
  237. if (!HasFinal(state)) SetFinal(state, mapper_->Final(state));
  238. return CacheImpl<B>::Final(state);
  239. }
  240. size_t NumArcs(StateId state) {
  241. if (!HasArcs(state)) Expand(state);
  242. return CacheImpl<B>::NumArcs(state);
  243. }
  244. size_t NumInputEpsilons(StateId state) {
  245. if (!HasArcs(state)) Expand(state);
  246. return CacheImpl<B>::NumInputEpsilons(state);
  247. }
  248. size_t NumOutputEpsilons(StateId state) {
  249. if (!HasArcs(state)) Expand(state);
  250. return CacheImpl<B>::NumOutputEpsilons(state);
  251. }
  252. void InitStateIterator(StateIteratorData<B> *datb) const {
  253. StateIteratorData<A> data;
  254. fst_->InitStateIterator(&data);
  255. datb->base = data.base ? std::make_unique<StateMapStateIteratorBase<A, B>>(
  256. std::move(data.base))
  257. : nullptr;
  258. datb->nstates = data.nstates;
  259. }
  260. void InitArcIterator(StateId state, ArcIteratorData<B> *data) {
  261. if (!HasArcs(state)) Expand(state);
  262. CacheImpl<B>::InitArcIterator(state, data);
  263. }
  264. uint64_t Properties() const override { return Properties(kFstProperties); }
  265. uint64_t Properties(uint64_t mask) const override {
  266. if ((mask & kError) && (fst_->Properties(kError, false) ||
  267. (mapper_->Properties(0) & kError))) {
  268. SetProperties(kError, kError);
  269. }
  270. return FstImpl<Arc>::Properties(mask);
  271. }
  272. void Expand(StateId state) {
  273. // Adds exiting arcs.
  274. for (mapper_->SetState(state); !mapper_->Done(); mapper_->Next()) {
  275. PushArc(state, mapper_->Value());
  276. }
  277. SetArcs(state);
  278. }
  279. const Fst<A> *GetFst() const { return fst_.get(); }
  280. private:
  281. void Init() {
  282. SetType("statemap");
  283. if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS) {
  284. SetInputSymbols(fst_->InputSymbols());
  285. } else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
  286. SetInputSymbols(nullptr);
  287. }
  288. if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS) {
  289. SetOutputSymbols(fst_->OutputSymbols());
  290. } else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
  291. SetOutputSymbols(nullptr);
  292. }
  293. const auto props = fst_->Properties(kCopyProperties, false);
  294. SetProperties(mapper_->Properties(props));
  295. }
  296. std::unique_ptr<const Fst<A>> fst_;
  297. C *mapper_;
  298. bool own_mapper_;
  299. };
  300. } // namespace internal
  301. // Maps an arc type A to an arc type B using Mapper function object
  302. // C. This version is a delayed FST.
  303. template <class A, class B, class C>
  304. class StateMapFst : public ImplToFst<internal::StateMapFstImpl<A, B, C>> {
  305. public:
  306. friend class ArcIterator<StateMapFst<A, B, C>>;
  307. using Arc = B;
  308. using StateId = typename Arc::StateId;
  309. using Weight = typename Arc::Weight;
  310. using Store = DefaultCacheStore<Arc>;
  311. using State = typename Store::State;
  312. using Impl = internal::StateMapFstImpl<A, B, C>;
  313. StateMapFst(const Fst<A> &fst, const C &mapper,
  314. const StateMapFstOptions &opts)
  315. : ImplToFst<Impl>(std::make_shared<Impl>(fst, mapper, opts)) {}
  316. StateMapFst(const Fst<A> &fst, C *mapper, const StateMapFstOptions &opts)
  317. : ImplToFst<Impl>(std::make_shared<Impl>(fst, mapper, opts)) {}
  318. StateMapFst(const Fst<A> &fst, const C &mapper)
  319. : ImplToFst<Impl>(
  320. std::make_shared<Impl>(fst, mapper, StateMapFstOptions())) {}
  321. StateMapFst(const Fst<A> &fst, C *mapper)
  322. : ImplToFst<Impl>(
  323. std::make_shared<Impl>(fst, mapper, StateMapFstOptions())) {}
  324. // See Fst<>::Copy() for doc.
  325. StateMapFst(const StateMapFst &fst, bool safe = false)
  326. : ImplToFst<Impl>(fst, safe) {}
  327. // Get a copy of this StateMapFst. See Fst<>::Copy() for further doc.
  328. StateMapFst *Copy(bool safe = false) const override {
  329. return new StateMapFst(*this, safe);
  330. }
  331. void InitStateIterator(StateIteratorData<B> *data) const override {
  332. GetImpl()->InitStateIterator(data);
  333. }
  334. void InitArcIterator(StateId state, ArcIteratorData<B> *data) const override {
  335. GetMutableImpl()->InitArcIterator(state, data);
  336. }
  337. protected:
  338. using ImplToFst<Impl>::GetImpl;
  339. using ImplToFst<Impl>::GetMutableImpl;
  340. private:
  341. StateMapFst &operator=(const StateMapFst &) = delete;
  342. };
  343. // Specialization for StateMapFst.
  344. template <class A, class B, class C>
  345. class ArcIterator<StateMapFst<A, B, C>>
  346. : public CacheArcIterator<StateMapFst<A, B, C>> {
  347. public:
  348. using StateId = typename A::StateId;
  349. ArcIterator(const StateMapFst<A, B, C> &fst, StateId state)
  350. : CacheArcIterator<StateMapFst<A, B, C>>(fst.GetMutableImpl(), state) {
  351. if (!fst.GetImpl()->HasArcs(state)) fst.GetMutableImpl()->Expand(state);
  352. }
  353. };
  354. // Utility mappers.
  355. // Mapper that returns its input.
  356. template <class Arc>
  357. class IdentityStateMapper {
  358. public:
  359. using FromArc = Arc;
  360. using ToArc = Arc;
  361. using StateId = typename Arc::StateId;
  362. using Weight = typename Arc::Weight;
  363. explicit IdentityStateMapper(const Fst<Arc> &fst) : fst_(fst) {}
  364. // Allows updating FST argument; pass only if changed.
  365. IdentityStateMapper(const IdentityStateMapper<Arc> &mapper,
  366. const Fst<Arc> *fst = nullptr)
  367. : fst_(fst ? *fst : mapper.fst_) {}
  368. StateId Start() const { return fst_.Start(); }
  369. Weight Final(StateId state) const { return fst_.Final(state); }
  370. void SetState(StateId state) {
  371. aiter_ = std::make_unique<ArcIterator<Fst<Arc>>>(fst_, state);
  372. }
  373. bool Done() const { return aiter_->Done(); }
  374. const Arc &Value() const { return aiter_->Value(); }
  375. void Next() { aiter_->Next(); }
  376. constexpr MapSymbolsAction InputSymbolsAction() const {
  377. return MAP_COPY_SYMBOLS;
  378. }
  379. constexpr MapSymbolsAction OutputSymbolsAction() const {
  380. return MAP_COPY_SYMBOLS;
  381. }
  382. uint64_t Properties(uint64_t props) const { return props; }
  383. private:
  384. const Fst<Arc> &fst_;
  385. std::unique_ptr<ArcIterator<Fst<Arc>>> aiter_;
  386. };
  387. template <class Arc>
  388. class ArcSumMapper {
  389. public:
  390. using FromArc = Arc;
  391. using ToArc = Arc;
  392. using StateId = typename Arc::StateId;
  393. using Weight = typename Arc::Weight;
  394. explicit ArcSumMapper(const Fst<Arc> &fst) : fst_(fst), i_(0) {}
  395. // Allows updating FST argument; pass only if changed.
  396. ArcSumMapper(const ArcSumMapper<Arc> &mapper, const Fst<Arc> *fst = nullptr)
  397. : fst_(fst ? *fst : mapper.fst_), i_(0) {}
  398. StateId Start() const { return fst_.Start(); }
  399. Weight Final(StateId state) const { return fst_.Final(state); }
  400. void SetState(StateId state) {
  401. i_ = 0;
  402. arcs_.clear();
  403. arcs_.reserve(fst_.NumArcs(state));
  404. for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
  405. aiter.Next()) {
  406. arcs_.push_back(aiter.Value());
  407. }
  408. // First sorts the exiting arcs by input label, output label and destination
  409. // state and then sums weights of arcs with the same input label, output
  410. // label, and destination state.
  411. std::sort(arcs_.begin(), arcs_.end(), comp_);
  412. size_t narcs = 0;
  413. for (const auto &arc : arcs_) {
  414. if (narcs > 0 && equal_(arc, arcs_[narcs - 1])) {
  415. arcs_[narcs - 1].weight = Plus(arcs_[narcs - 1].weight, arc.weight);
  416. } else {
  417. arcs_[narcs] = arc;
  418. ++narcs;
  419. }
  420. }
  421. arcs_.resize(narcs);
  422. }
  423. bool Done() const { return i_ >= arcs_.size(); }
  424. const Arc &Value() const { return arcs_[i_]; }
  425. void Next() { ++i_; }
  426. constexpr MapSymbolsAction InputSymbolsAction() const {
  427. return MAP_COPY_SYMBOLS;
  428. }
  429. constexpr MapSymbolsAction OutputSymbolsAction() const {
  430. return MAP_COPY_SYMBOLS;
  431. }
  432. uint64_t Properties(uint64_t props) const {
  433. return props & kArcSortProperties & kDeleteArcsProperties &
  434. kWeightInvariantProperties;
  435. }
  436. private:
  437. struct Compare {
  438. bool operator()(const Arc &x, const Arc &y) const {
  439. if (x.ilabel < y.ilabel) return true;
  440. if (x.ilabel > y.ilabel) return false;
  441. if (x.olabel < y.olabel) return true;
  442. if (x.olabel > y.olabel) return false;
  443. if (x.nextstate < y.nextstate) return true;
  444. if (x.nextstate > y.nextstate) return false;
  445. return false;
  446. }
  447. };
  448. struct Equal {
  449. bool operator()(const Arc &x, const Arc &y) const {
  450. return (x.ilabel == y.ilabel && x.olabel == y.olabel &&
  451. x.nextstate == y.nextstate);
  452. }
  453. };
  454. const Fst<Arc> &fst_;
  455. Compare comp_;
  456. Equal equal_;
  457. std::vector<Arc> arcs_;
  458. ssize_t i_; // Current arc position.
  459. ArcSumMapper &operator=(const ArcSumMapper &) = delete;
  460. };
  461. template <class Arc>
  462. class ArcUniqueMapper {
  463. public:
  464. using FromArc = Arc;
  465. using ToArc = Arc;
  466. using StateId = typename Arc::StateId;
  467. using Weight = typename Arc::Weight;
  468. explicit ArcUniqueMapper(const Fst<Arc> &fst) : fst_(fst), i_(0) {}
  469. // Allows updating FST argument; pass only if changed.
  470. ArcUniqueMapper(const ArcUniqueMapper<Arc> &mapper,
  471. const Fst<Arc> *fst = nullptr)
  472. : fst_(fst ? *fst : mapper.fst_), i_(0) {}
  473. StateId Start() const { return fst_.Start(); }
  474. Weight Final(StateId state) const { return fst_.Final(state); }
  475. void SetState(StateId state) {
  476. i_ = 0;
  477. arcs_.clear();
  478. arcs_.reserve(fst_.NumArcs(state));
  479. for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
  480. aiter.Next()) {
  481. arcs_.push_back(aiter.Value());
  482. }
  483. // First sorts the exiting arcs by input label, output label and destination
  484. // state and then uniques identical arcs.
  485. std::sort(arcs_.begin(), arcs_.end(), comp_);
  486. arcs_.erase(std::unique(arcs_.begin(), arcs_.end(), equal_), arcs_.end());
  487. }
  488. bool Done() const { return i_ >= arcs_.size(); }
  489. const Arc &Value() const { return arcs_[i_]; }
  490. void Next() { ++i_; }
  491. constexpr MapSymbolsAction InputSymbolsAction() const {
  492. return MAP_COPY_SYMBOLS;
  493. }
  494. constexpr MapSymbolsAction OutputSymbolsAction() const {
  495. return MAP_COPY_SYMBOLS;
  496. }
  497. uint64_t Properties(uint64_t props) const {
  498. return props & kArcSortProperties & kDeleteArcsProperties;
  499. }
  500. private:
  501. struct Compare {
  502. bool operator()(const Arc &x, const Arc &y) const {
  503. if (x.ilabel < y.ilabel) return true;
  504. if (x.ilabel > y.ilabel) return false;
  505. if (x.olabel < y.olabel) return true;
  506. if (x.olabel > y.olabel) return false;
  507. if (x.nextstate < y.nextstate) return true;
  508. if (x.nextstate > y.nextstate) return false;
  509. return false;
  510. }
  511. };
  512. struct Equal {
  513. bool operator()(const Arc &x, const Arc &y) const {
  514. return (x.ilabel == y.ilabel && x.olabel == y.olabel &&
  515. x.nextstate == y.nextstate && x.weight == y.weight);
  516. }
  517. };
  518. const Fst<Arc> &fst_;
  519. Compare comp_;
  520. Equal equal_;
  521. std::vector<Arc> arcs_;
  522. size_t i_; // Current arc position.
  523. ArcUniqueMapper &operator=(const ArcUniqueMapper &) = delete;
  524. };
// Useful aliases when using StdArc: the arc-summing and arc-uniquing state
// mappers instantiated with StdArc as their arc type.
using StdArcSumMapper = ArcSumMapper<StdArc>;
using StdArcUniqueMapper = ArcUniqueMapper<StdArc>;
  528. } // namespace fst
  529. #endif // FST_STATE_MAP_H_