You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1322 lines
40 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
// Classes to map over/transform arcs, e.g., to change semirings or to
// implement project/invert. Consider using them when an operation does
// not change the number of arcs (except possibly superfinal arcs).
  21. #ifndef FST_ARC_MAP_H_
  22. #define FST_ARC_MAP_H_
  23. #include <cstddef>
  24. #include <cstdint>
  25. #include <memory>
  26. #include <string>
  27. #include <type_traits>
  28. #include <utility>
  29. #include <fst/log.h>
  30. #include <fst/arc.h>
  31. #include <fst/cache.h>
  32. #include <fst/expanded-fst.h>
  33. #include <fst/float-weight.h>
  34. #include <fst/fst.h>
  35. #include <fst/impl-to-fst.h>
  36. #include <fst/mutable-fst.h>
  37. #include <fst/properties.h>
  38. #include <fst/string-weight.h>
  39. #include <fst/symbol-table.h>
  40. #include <fst/util.h>
  41. #include <fst/weight.h>
  42. #include <unordered_map>
  43. namespace fst {
  44. // Determines how final weights are mapped.
  45. enum MapFinalAction {
  46. // A final weight is mapped into a final weight. An error is raised if this
  47. // is not possible.
  48. MAP_NO_SUPERFINAL,
  49. // A final weight is mapped to an arc to the superfinal state when the result
  50. // cannot be represented as a final weight. The superfinal state will be
  51. // added only if it is needed.
  52. MAP_ALLOW_SUPERFINAL,
  53. // A final weight is mapped to an arc to the superfinal state unless the
  54. // result can be represented as a final weight of weight Zero(). The
  55. // superfinal state is always added (if the input is not the empty FST).
  56. MAP_REQUIRE_SUPERFINAL
  57. };
  58. // Determines how symbol tables are mapped.
  59. enum MapSymbolsAction {
  60. // Symbols should be cleared in the result by the map.
  61. MAP_CLEAR_SYMBOLS,
  62. // Symbols should be copied from the input FST by the map.
  63. MAP_COPY_SYMBOLS,
  64. // Symbols should not be modified in the result by the map itself.
  65. // (They may set by the mapper).
  66. MAP_NOOP_SYMBOLS
  67. };
// The ArcMapper interface defines how arcs and final weights are mapped.
  69. // This is useful for implementing operations that apply to each arc separately
  70. // and do not change the number of arcs (except possibly superfinal arcs).
  71. //
  72. // template <class A, class B>
  73. // class ArcMapper {
  74. // public:
  75. // using FromArc = A;
  76. // using ToArc = B;
  77. //
  78. // // Maps an arc type FromArc to arc type ToArc.
  79. // ToArc operator()(const FromArc &arc);
  80. //
  81. // // Specifies final action the mapper requires (see above).
  82. // // The mapper will be passed final weights as arcs of the form
  83. // // Arc(0, 0, weight, kNoStateId).
  84. // MapFinalAction FinalAction() const;
  85. //
  86. // // Specifies input symbol table action the mapper requires (see above).
  87. // MapSymbolsAction InputSymbolsAction() const;
  88. //
  89. // // Specifies output symbol table action the mapper requires (see above).
  90. // MapSymbolsAction OutputSymbolsAction() const;
  91. //
// // This specifies the known properties of an FST mapped by this mapper. It
// // takes as argument the input FST's known properties.
  94. // uint64_t Properties(uint64_t props) const;
  95. // };
  96. //
  97. // The ArcMap functions and classes below will use the FinalAction()
  98. // method of the mapper to determine how to treat final weights, e.g., whether
  99. // to add a superfinal state. They will use the Properties() method to set the
  100. // result FST properties.
  101. //
// We include various map versions below. One dimension of variation is
// whether the mapping mutates its input, writes to a new result FST, or is an
// on-the-fly FST. Another dimension is how we pass the mapper. We allow passing
// the mapper by pointer for cases where we need to change the state of the
// user's mapper. This is the case with the EncodeMapper, which is reused
// during decoding. We also include map versions that pass the mapper by value
// or const reference when this suffices.
  109. // Maps an arc type A using a mapper function object C, passed
  110. // by pointer. This version modifies its Fst input.
  111. template <class A, class C>
  112. void ArcMap(MutableFst<A> *fst, C *mapper) {
  113. using FromArc = A;
  114. using ToArc = A;
  115. using Weight = typename FromArc::Weight;
  116. if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
  117. fst->SetInputSymbols(nullptr);
  118. }
  119. if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
  120. fst->SetOutputSymbols(nullptr);
  121. }
  122. if (fst->Start() == kNoStateId) return;
  123. const auto props = fst->Properties(kFstProperties, false);
  124. const auto final_action = mapper->FinalAction();
  125. auto superfinal = kNoStateId;
  126. if (final_action == MAP_REQUIRE_SUPERFINAL) {
  127. superfinal = fst->AddState();
  128. fst->SetFinal(superfinal);
  129. }
  130. for (StateIterator<MutableFst<FromArc>> siter(*fst); !siter.Done();
  131. siter.Next()) {
  132. const auto state = siter.Value();
  133. for (MutableArcIterator<MutableFst<FromArc>> aiter(fst, state);
  134. !aiter.Done(); aiter.Next()) {
  135. const auto &arc = aiter.Value();
  136. aiter.SetValue((*mapper)(arc));
  137. }
  138. switch (final_action) {
  139. case MAP_NO_SUPERFINAL:
  140. default: {
  141. const FromArc arc(0, 0, fst->Final(state), kNoStateId);
  142. const auto final_arc = (*mapper)(arc);
  143. if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
  144. FSTERROR() << "ArcMap: Non-zero arc labels for superfinal arc";
  145. fst->SetProperties(kError, kError);
  146. }
  147. fst->SetFinal(state, final_arc.weight);
  148. break;
  149. }
  150. case MAP_ALLOW_SUPERFINAL: {
  151. if (state != superfinal) {
  152. const FromArc arc(0, 0, fst->Final(state), kNoStateId);
  153. auto final_arc = (*mapper)(arc);
  154. if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
  155. // Add a superfinal state if not already done.
  156. if (superfinal == kNoStateId) {
  157. superfinal = fst->AddState();
  158. fst->SetFinal(superfinal);
  159. }
  160. final_arc.nextstate = superfinal;
  161. fst->AddArc(state, std::move(final_arc));
  162. fst->SetFinal(state, Weight::Zero());
  163. } else {
  164. fst->SetFinal(state, final_arc.weight);
  165. }
  166. }
  167. break;
  168. }
  169. case MAP_REQUIRE_SUPERFINAL: {
  170. if (state != superfinal) {
  171. const FromArc arc(0, 0, fst->Final(state), kNoStateId);
  172. const auto final_arc = (*mapper)(arc);
  173. if (final_arc.ilabel != 0 || final_arc.olabel != 0 ||
  174. final_arc.weight != Weight::Zero()) {
  175. fst->AddArc(state, ToArc(final_arc.ilabel, final_arc.olabel,
  176. final_arc.weight, superfinal));
  177. }
  178. fst->SetFinal(state, Weight::Zero());
  179. }
  180. break;
  181. }
  182. }
  183. }
  184. fst->SetProperties(mapper->Properties(props), kFstProperties);
  185. }
  186. // Maps an arc type A using a mapper function object C, passed by value. This
  187. // version modifies its FST input.
  188. template <class A, class C>
  189. void ArcMap(MutableFst<A> *fst, C mapper) {
  190. ArcMap(fst, &mapper);
  191. }
  192. // Maps an arc type A to an arc type B using mapper function object C,
  193. // passed by pointer. This version writes the mapped input FST to an
  194. // output MutableFst.
  195. template <class A, class B, class C>
  196. void ArcMap(const Fst<A> &ifst, MutableFst<B> *ofst, C *mapper) {
  197. using FromArc = A;
  198. using StateId = typename FromArc::StateId;
  199. ofst->DeleteStates();
  200. if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS) {
  201. ofst->SetInputSymbols(ifst.InputSymbols());
  202. } else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
  203. ofst->SetInputSymbols(nullptr);
  204. }
  205. if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS) {
  206. ofst->SetOutputSymbols(ifst.OutputSymbols());
  207. } else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
  208. ofst->SetOutputSymbols(nullptr);
  209. }
  210. const auto iprops = ifst.Properties(kCopyProperties, false);
  211. if (ifst.Start() == kNoStateId) {
  212. if (iprops & kError) ofst->SetProperties(kError, kError);
  213. return;
  214. }
  215. const auto final_action = mapper->FinalAction();
  216. if (std::optional<StateId> num_states = ifst.NumStatesIfKnown()) {
  217. ofst->ReserveStates(*num_states +
  218. (final_action == MAP_NO_SUPERFINAL ? 0 : 1));
  219. }
  220. // Adds all states.
  221. for (StateIterator<Fst<A>> siter(ifst); !siter.Done(); siter.Next()) {
  222. ofst->AddState();
  223. }
  224. StateId superfinal = kNoStateId;
  225. if (final_action == MAP_REQUIRE_SUPERFINAL) {
  226. superfinal = ofst->AddState();
  227. ofst->SetFinal(superfinal);
  228. }
  229. for (StateIterator<Fst<A>> siter(ifst); !siter.Done(); siter.Next()) {
  230. StateId s = siter.Value();
  231. if (s == ifst.Start()) ofst->SetStart(s);
  232. ofst->ReserveArcs(
  233. s, ifst.NumArcs(s) + (final_action != MAP_NO_SUPERFINAL ? 1 : 0));
  234. for (ArcIterator<Fst<A>> aiter(ifst, s); !aiter.Done(); aiter.Next()) {
  235. ofst->AddArc(s, (*mapper)(aiter.Value()));
  236. }
  237. switch (final_action) {
  238. case MAP_NO_SUPERFINAL:
  239. default: {
  240. B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId));
  241. if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
  242. FSTERROR() << "ArcMap: Non-zero arc labels for superfinal arc";
  243. ofst->SetProperties(kError, kError);
  244. }
  245. ofst->SetFinal(s, final_arc.weight);
  246. break;
  247. }
  248. case MAP_ALLOW_SUPERFINAL: {
  249. B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId));
  250. if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
  251. // Add a superfinal state if not already done.
  252. if (superfinal == kNoStateId) {
  253. superfinal = ofst->AddState();
  254. ofst->SetFinal(superfinal);
  255. }
  256. final_arc.nextstate = superfinal;
  257. ofst->AddArc(s, std::move(final_arc));
  258. ofst->SetFinal(s, B::Weight::Zero());
  259. } else {
  260. ofst->SetFinal(s, final_arc.weight);
  261. }
  262. break;
  263. }
  264. case MAP_REQUIRE_SUPERFINAL: {
  265. B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId));
  266. if (final_arc.ilabel != 0 || final_arc.olabel != 0 ||
  267. final_arc.weight != B::Weight::Zero()) {
  268. ofst->AddArc(s, B(final_arc.ilabel, final_arc.olabel,
  269. final_arc.weight, superfinal));
  270. }
  271. ofst->SetFinal(s, B::Weight::Zero());
  272. break;
  273. }
  274. }
  275. }
  276. const auto oprops = ofst->Properties(kFstProperties, false);
  277. ofst->SetProperties(mapper->Properties(iprops) | oprops, kFstProperties);
  278. }
  279. // Maps an arc type A to an arc type B using mapper function
  280. // object C, passed by value. This version writes the mapped input
  281. // Fst to an output MutableFst.
  282. template <class A, class B, class C>
  283. void ArcMap(const Fst<A> &ifst, MutableFst<B> *ofst, C mapper) {
  284. ArcMap(ifst, ofst, &mapper);
  285. }
  286. struct ArcMapFstOptions : public CacheOptions {
  287. // ArcMapFst default caching behaviour is to do no caching. Most mappers are
  288. // cheap and therefore we save memory by not doing caching.
  289. ArcMapFstOptions() : CacheOptions(true, 0) {}
  290. explicit ArcMapFstOptions(const CacheOptions &opts) : CacheOptions(opts) {}
  291. };
  292. template <class A, class B, class C>
  293. class ArcMapFst;
  294. namespace internal {
  295. // Implementation of delayed ArcMapFst.
  296. template <class A, class B, class C>
  297. class ArcMapFstImpl : public CacheImpl<B> {
  298. public:
  299. using Arc = B;
  300. using StateId = typename Arc::StateId;
  301. using Weight = typename Arc::Weight;
  302. using FstImpl<B>::SetType;
  303. using FstImpl<B>::SetProperties;
  304. using FstImpl<B>::SetInputSymbols;
  305. using FstImpl<B>::SetOutputSymbols;
  306. using CacheImpl<B>::EmplaceArc;
  307. using CacheImpl<B>::HasArcs;
  308. using CacheImpl<B>::HasFinal;
  309. using CacheImpl<B>::HasStart;
  310. using CacheImpl<B>::PushArc;
  311. using CacheImpl<B>::SetArcs;
  312. using CacheImpl<B>::SetFinal;
  313. using CacheImpl<B>::SetStart;
  314. friend class StateIterator<ArcMapFst<A, B, C>>;
  315. ArcMapFstImpl(const Fst<A> &fst, const C &mapper,
  316. const ArcMapFstOptions &opts)
  317. : CacheImpl<B>(opts),
  318. fst_(fst.Copy()),
  319. mapper_(new C(mapper)),
  320. own_mapper_(true),
  321. superfinal_(kNoStateId),
  322. nstates_(0) {
  323. Init();
  324. }
  325. ArcMapFstImpl(const Fst<A> &fst, C *mapper, const ArcMapFstOptions &opts)
  326. : CacheImpl<B>(opts),
  327. fst_(fst.Copy()),
  328. mapper_(mapper),
  329. own_mapper_(false),
  330. superfinal_(kNoStateId),
  331. nstates_(0) {
  332. Init();
  333. }
  334. ArcMapFstImpl(const ArcMapFstImpl<A, B, C> &impl)
  335. : CacheImpl<B>(impl),
  336. fst_(impl.fst_->Copy(true)),
  337. mapper_(new C(*impl.mapper_)),
  338. own_mapper_(true),
  339. superfinal_(kNoStateId),
  340. nstates_(0) {
  341. Init();
  342. }
  343. ~ArcMapFstImpl() override {
  344. if (own_mapper_) delete mapper_;
  345. }
  346. StateId Start() {
  347. if (!HasStart()) SetStart(FindOState(fst_->Start()));
  348. return CacheImpl<B>::Start();
  349. }
  350. Weight Final(StateId s) {
  351. if (!HasFinal(s)) {
  352. switch (final_action_) {
  353. case MAP_NO_SUPERFINAL:
  354. default: {
  355. const auto final_arc =
  356. (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId));
  357. if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
  358. FSTERROR() << "ArcMapFst: Non-zero arc labels for superfinal arc";
  359. SetProperties(kError, kError);
  360. }
  361. SetFinal(s, final_arc.weight);
  362. break;
  363. }
  364. case MAP_ALLOW_SUPERFINAL: {
  365. if (s == superfinal_) {
  366. SetFinal(s);
  367. } else {
  368. const auto final_arc =
  369. (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId));
  370. if (final_arc.ilabel == 0 && final_arc.olabel == 0) {
  371. SetFinal(s, final_arc.weight);
  372. } else {
  373. SetFinal(s, Weight::Zero());
  374. }
  375. }
  376. break;
  377. }
  378. case MAP_REQUIRE_SUPERFINAL: {
  379. SetFinal(s, s == superfinal_ ? Weight::One() : Weight::Zero());
  380. break;
  381. }
  382. }
  383. }
  384. return CacheImpl<B>::Final(s);
  385. }
  386. size_t NumArcs(StateId s) {
  387. if (!HasArcs(s)) Expand(s);
  388. return CacheImpl<B>::NumArcs(s);
  389. }
  390. size_t NumInputEpsilons(StateId s) {
  391. if (!HasArcs(s)) Expand(s);
  392. return CacheImpl<B>::NumInputEpsilons(s);
  393. }
  394. size_t NumOutputEpsilons(StateId s) {
  395. if (!HasArcs(s)) Expand(s);
  396. return CacheImpl<B>::NumOutputEpsilons(s);
  397. }
  398. uint64_t Properties() const override { return Properties(kFstProperties); }
  399. // Sets error if found, and returns other FST impl properties.
  400. uint64_t Properties(uint64_t mask) const override {
  401. if ((mask & kError) && (fst_->Properties(kError, false) ||
  402. (mapper_->Properties(0) & kError))) {
  403. SetProperties(kError, kError);
  404. }
  405. return FstImpl<Arc>::Properties(mask);
  406. }
  407. void InitArcIterator(StateId s, ArcIteratorData<B> *data) {
  408. if (!HasArcs(s)) Expand(s);
  409. CacheImpl<B>::InitArcIterator(s, data);
  410. }
  411. void Expand(StateId s) {
  412. // Add exiting arcs.
  413. if (s == superfinal_) {
  414. SetArcs(s);
  415. return;
  416. }
  417. for (ArcIterator<Fst<A>> aiter(*fst_, FindIState(s)); !aiter.Done();
  418. aiter.Next()) {
  419. auto aarc = aiter.Value();
  420. aarc.nextstate = FindOState(aarc.nextstate);
  421. PushArc(s, (*mapper_)(aarc));
  422. }
  423. // Check for superfinal arcs.
  424. if (!HasFinal(s) || Final(s) == Weight::Zero()) {
  425. switch (final_action_) {
  426. case MAP_NO_SUPERFINAL:
  427. default:
  428. break;
  429. case MAP_ALLOW_SUPERFINAL: {
  430. auto final_arc =
  431. (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId));
  432. if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
  433. if (superfinal_ == kNoStateId) superfinal_ = nstates_++;
  434. final_arc.nextstate = superfinal_;
  435. PushArc(s, std::move(final_arc));
  436. }
  437. break;
  438. }
  439. case MAP_REQUIRE_SUPERFINAL: {
  440. const auto final_arc =
  441. (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId));
  442. if (final_arc.ilabel != 0 || final_arc.olabel != 0 ||
  443. final_arc.weight != B::Weight::Zero()) {
  444. EmplaceArc(s, final_arc.ilabel, final_arc.olabel, final_arc.weight,
  445. superfinal_);
  446. }
  447. break;
  448. }
  449. }
  450. }
  451. SetArcs(s);
  452. }
  453. private:
  454. void Init() {
  455. SetType("map");
  456. if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS) {
  457. SetInputSymbols(fst_->InputSymbols());
  458. } else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
  459. SetInputSymbols(nullptr);
  460. }
  461. if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS) {
  462. SetOutputSymbols(fst_->OutputSymbols());
  463. } else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
  464. SetOutputSymbols(nullptr);
  465. }
  466. if (fst_->Start() == kNoStateId) {
  467. final_action_ = MAP_NO_SUPERFINAL;
  468. SetProperties(kNullProperties);
  469. } else {
  470. final_action_ = mapper_->FinalAction();
  471. uint64_t props = fst_->Properties(kCopyProperties, false);
  472. SetProperties(mapper_->Properties(props));
  473. if (final_action_ == MAP_REQUIRE_SUPERFINAL) superfinal_ = 0;
  474. }
  475. }
  476. // Maps from output state to input state.
  477. StateId FindIState(StateId s) {
  478. if (superfinal_ == kNoStateId || s < superfinal_) {
  479. return s;
  480. } else {
  481. return s - 1;
  482. }
  483. }
  484. // Maps from input state to output state.
  485. StateId FindOState(StateId is) {
  486. auto os = is;
  487. if (!(superfinal_ == kNoStateId || is < superfinal_)) ++os;
  488. if (os >= nstates_) nstates_ = os + 1;
  489. return os;
  490. }
  491. std::unique_ptr<const Fst<A>> fst_;
  492. C *mapper_;
  493. const bool own_mapper_;
  494. MapFinalAction final_action_;
  495. StateId superfinal_;
  496. StateId nstates_;
  497. };
  498. } // namespace internal
  499. // Maps an arc type A to an arc type B using Mapper function object
  500. // C. This version is a delayed FST.
  501. template <class A, class B, class C>
  502. class ArcMapFst : public ImplToFst<internal::ArcMapFstImpl<A, B, C>> {
  503. public:
  504. using Arc = B;
  505. using StateId = typename Arc::StateId;
  506. using Weight = typename Arc::Weight;
  507. using Store = DefaultCacheStore<B>;
  508. using State = typename Store::State;
  509. using Impl = internal::ArcMapFstImpl<A, B, C>;
  510. friend class ArcIterator<ArcMapFst<A, B, C>>;
  511. friend class StateIterator<ArcMapFst<A, B, C>>;
  512. explicit ArcMapFst(const Fst<A> &fst, const C &mapper = C(),
  513. const ArcMapFstOptions &opts = ArcMapFstOptions())
  514. : ImplToFst<Impl>(std::make_shared<Impl>(fst, mapper, opts)) {}
  515. ArcMapFst(const Fst<A> &fst, C *mapper,
  516. const ArcMapFstOptions &opts = ArcMapFstOptions())
  517. : ImplToFst<Impl>(std::make_shared<Impl>(fst, mapper, opts)) {}
  518. // See Fst<>::Copy() for doc.
  519. ArcMapFst(const ArcMapFst &fst, bool safe = false)
  520. : ImplToFst<Impl>(fst, safe) {}
  521. // Get a copy of this ArcMapFst. See Fst<>::Copy() for further doc.
  522. ArcMapFst *Copy(bool safe = false) const override {
  523. return new ArcMapFst(*this, safe);
  524. }
  525. inline void InitStateIterator(StateIteratorData<B> *data) const override;
  526. void InitArcIterator(StateId s, ArcIteratorData<B> *data) const override {
  527. GetMutableImpl()->InitArcIterator(s, data);
  528. }
  529. protected:
  530. using ImplToFst<Impl>::GetImpl;
  531. using ImplToFst<Impl>::GetMutableImpl;
  532. private:
  533. ArcMapFst &operator=(const ArcMapFst &) = delete;
  534. };
  535. // Specialization for ArcMapFst.
  536. //
  537. // This may be derived from.
  538. template <class A, class B, class C>
  539. class StateIterator<ArcMapFst<A, B, C>> : public StateIteratorBase<B> {
  540. public:
  541. using StateId = typename B::StateId;
  542. explicit StateIterator(const ArcMapFst<A, B, C> &fst)
  543. : impl_(fst.GetImpl()),
  544. siter_(*impl_->fst_),
  545. s_(0),
  546. superfinal_(impl_->final_action_ == MAP_REQUIRE_SUPERFINAL) {
  547. CheckSuperfinal();
  548. }
  549. bool Done() const final { return siter_.Done() && !superfinal_; }
  550. StateId Value() const final { return s_; }
  551. void Next() final {
  552. ++s_;
  553. if (!siter_.Done()) {
  554. siter_.Next();
  555. CheckSuperfinal();
  556. } else if (superfinal_) {
  557. superfinal_ = false;
  558. }
  559. }
  560. void Reset() final {
  561. s_ = 0;
  562. siter_.Reset();
  563. superfinal_ = impl_->final_action_ == MAP_REQUIRE_SUPERFINAL;
  564. CheckSuperfinal();
  565. }
  566. private:
  567. void CheckSuperfinal() {
  568. if (impl_->final_action_ != MAP_ALLOW_SUPERFINAL || superfinal_) return;
  569. if (!siter_.Done()) {
  570. const auto final_arc =
  571. (*impl_->mapper_)(A(0, 0, impl_->fst_->Final(s_), kNoStateId));
  572. if (final_arc.ilabel != 0 || final_arc.olabel != 0) superfinal_ = true;
  573. }
  574. }
  575. const internal::ArcMapFstImpl<A, B, C> *impl_;
  576. StateIterator<Fst<A>> siter_;
  577. StateId s_;
  578. bool superfinal_; // True if there is a superfinal state and not done.
  579. };
  580. // Specialization for ArcMapFst.
  581. template <class A, class B, class C>
  582. class ArcIterator<ArcMapFst<A, B, C>>
  583. : public CacheArcIterator<ArcMapFst<A, B, C>> {
  584. public:
  585. using StateId = typename A::StateId;
  586. ArcIterator(const ArcMapFst<A, B, C> &fst, StateId s)
  587. : CacheArcIterator<ArcMapFst<A, B, C>>(fst.GetMutableImpl(), s) {
  588. if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
  589. }
  590. };
  591. template <class A, class B, class C>
  592. inline void ArcMapFst<A, B, C>::InitStateIterator(
  593. StateIteratorData<B> *data) const {
  594. data->base = std::make_unique<StateIterator<ArcMapFst<A, B, C>>>(*this);
  595. }
  596. // CTAD deduction guides
  597. // This allows constructing ArcMapFsts without specifying all the types.
  598. template <class ArcMapper>
  599. ArcMapFst(const Fst<typename ArcMapper::FromArc> &, const ArcMapper &)
  600. -> ArcMapFst<typename ArcMapper::FromArc, typename ArcMapper::ToArc,
  601. ArcMapper>;
  602. // As above, but using the ArcMapFst(..., ArcMapper *) constructor.
  603. template <class ArcMapper>
  604. ArcMapFst(const Fst<typename ArcMapper::FromArc> &, ArcMapper *)
  605. -> ArcMapFst<typename ArcMapper::FromArc, typename ArcMapper::ToArc,
  606. ArcMapper>;
  607. // Utility Mappers.
  608. // Mapper that returns its input.
  609. template <class A>
  610. class IdentityArcMapper {
  611. public:
  612. using FromArc = A;
  613. using ToArc = A;
  614. constexpr ToArc operator()(const FromArc &arc) const { return arc; }
  615. constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
  616. constexpr MapSymbolsAction InputSymbolsAction() const {
  617. return MAP_COPY_SYMBOLS;
  618. }
  619. constexpr MapSymbolsAction OutputSymbolsAction() const {
  620. return MAP_COPY_SYMBOLS;
  621. }
  622. constexpr uint64_t Properties(uint64_t props) const { return props; }
  623. };
  624. // Mapper that converts all input symbols to epsilon.
  625. template <class A>
  626. class InputEpsilonMapper {
  627. public:
  628. using FromArc = A;
  629. using ToArc = A;
  630. constexpr ToArc operator()(const FromArc &arc) const {
  631. return ToArc(0, arc.olabel, arc.weight, arc.nextstate);
  632. }
  633. constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
  634. constexpr MapSymbolsAction InputSymbolsAction() const {
  635. return MAP_CLEAR_SYMBOLS;
  636. }
  637. constexpr MapSymbolsAction OutputSymbolsAction() const {
  638. return MAP_COPY_SYMBOLS;
  639. }
  640. constexpr uint64_t Properties(uint64_t props) const {
  641. return (props & kSetArcProperties) | kIEpsilons | kILabelSorted;
  642. }
  643. };
  644. // Mapper that converts all output symbols to epsilon.
  645. template <class A>
  646. class OutputEpsilonMapper {
  647. public:
  648. using FromArc = A;
  649. using ToArc = A;
  650. constexpr ToArc operator()(const FromArc &arc) const {
  651. return ToArc(arc.ilabel, 0, arc.weight, arc.nextstate);
  652. }
  653. constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
  654. constexpr MapSymbolsAction InputSymbolsAction() const {
  655. return MAP_COPY_SYMBOLS;
  656. }
  657. constexpr MapSymbolsAction OutputSymbolsAction() const {
  658. return MAP_CLEAR_SYMBOLS;
  659. }
  660. constexpr uint64_t Properties(uint64_t props) const {
  661. return (props & kSetArcProperties) | kOEpsilons | kOLabelSorted;
  662. }
  663. };
  664. // Mapper that returns its input with final states redirected to a single
  665. // super-final state.
  666. template <class A>
  667. class SuperFinalMapper {
  668. public:
  669. using FromArc = A;
  670. using ToArc = A;
  671. using Label = typename FromArc::Label;
  672. using Weight = typename FromArc::Weight;
  673. // Arg allows setting super-final label.
  674. constexpr explicit SuperFinalMapper(Label final_label = 0)
  675. : final_label_(final_label) {}
  676. ToArc operator()(const FromArc &arc) const {
  677. // Super-final arc.
  678. if (arc.nextstate == kNoStateId && arc.weight != Weight::Zero()) {
  679. return ToArc(final_label_, final_label_, arc.weight, kNoStateId);
  680. } else {
  681. return arc;
  682. }
  683. }
  684. constexpr MapFinalAction FinalAction() const {
  685. return MAP_REQUIRE_SUPERFINAL;
  686. }
  687. constexpr MapSymbolsAction InputSymbolsAction() const {
  688. return MAP_COPY_SYMBOLS;
  689. }
  690. constexpr MapSymbolsAction OutputSymbolsAction() const {
  691. return MAP_COPY_SYMBOLS;
  692. }
  693. uint64_t Properties(uint64_t props) const {
  694. if (final_label_ == 0) {
  695. return props & kAddSuperFinalProperties;
  696. } else {
  697. return props & kAddSuperFinalProperties & kILabelInvariantProperties &
  698. kOLabelInvariantProperties;
  699. }
  700. }
  701. private:
  702. const Label final_label_;
  703. };
  704. // Mapper that leaves labels and nextstate unchanged and constructs a new weight
  705. // from the underlying value of the arc weight. If no weight converter is
  706. // explictly specified, requires that there is a WeightConvert class
  707. // specialization that converts the weights.
  708. template <class A, class B,
  709. class C = WeightConvert<typename A::Weight, typename B::Weight>>
  710. class WeightConvertMapper {
  711. public:
  712. using FromArc = A;
  713. using ToArc = B;
  714. using Converter = C;
  715. using FromWeight = typename FromArc::Weight;
  716. using ToWeight = typename ToArc::Weight;
  717. constexpr explicit WeightConvertMapper(const Converter &c = Converter())
  718. : convert_weight_(c) {}
  719. constexpr ToArc operator()(const FromArc &arc) const {
  720. return ToArc(arc.ilabel, arc.olabel, convert_weight_(arc.weight),
  721. arc.nextstate);
  722. }
  723. constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
  724. constexpr MapSymbolsAction InputSymbolsAction() const {
  725. return MAP_COPY_SYMBOLS;
  726. }
  727. constexpr MapSymbolsAction OutputSymbolsAction() const {
  728. return MAP_COPY_SYMBOLS;
  729. }
  730. constexpr uint64_t Properties(uint64_t props) const { return props; }
  731. private:
  732. const Converter convert_weight_;
  733. };
  734. // Non-precision-changing weight conversions; consider using more efficient
  735. // Cast method instead.
  736. using StdToLogMapper = WeightConvertMapper<StdArc, LogArc>;
  737. using LogToStdMapper = WeightConvertMapper<LogArc, StdArc>;
  738. // Precision-changing weight conversions.
  739. using StdToLog64Mapper = WeightConvertMapper<StdArc, Log64Arc>;
  740. using LogToLog64Mapper = WeightConvertMapper<LogArc, Log64Arc>;
  741. using Log64ToStdMapper = WeightConvertMapper<Log64Arc, StdArc>;
  742. using Log64ToLogMapper = WeightConvertMapper<Log64Arc, LogArc>;
  743. // Mapper from A to GallicArc<A>.
  744. template <class A, GallicType G = GALLIC_LEFT>
  745. class ToGallicMapper {
  746. public:
  747. using FromArc = A;
  748. using ToArc = GallicArc<A, G>;
  749. using SW = StringWeight<typename A::Label, GallicStringType(G)>;
  750. using AW = typename FromArc::Weight;
  751. using GW = typename ToArc::Weight;
  752. ToArc operator()(const FromArc &arc) const {
  753. // Super-final arc.
  754. if (arc.nextstate == kNoStateId && arc.weight != AW::Zero()) {
  755. return ToArc(0, 0, GW(SW::One(), arc.weight), kNoStateId);
  756. // Super-non-final arc.
  757. } else if (arc.nextstate == kNoStateId) {
  758. return ToArc(0, 0, GW::Zero(), kNoStateId);
  759. // Epsilon label.
  760. } else if (arc.olabel == 0) {
  761. return ToArc(arc.ilabel, arc.ilabel, GW(SW::One(), arc.weight),
  762. arc.nextstate);
  763. // Regular label.
  764. } else {
  765. return ToArc(arc.ilabel, arc.ilabel, GW(SW(arc.olabel), arc.weight),
  766. arc.nextstate);
  767. }
  768. }
  769. constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
  770. constexpr MapSymbolsAction InputSymbolsAction() const {
  771. return MAP_COPY_SYMBOLS;
  772. }
  773. constexpr MapSymbolsAction OutputSymbolsAction() const {
  774. return MAP_CLEAR_SYMBOLS;
  775. }
  776. uint64_t Properties(uint64_t props) const {
  777. return ProjectProperties(props, true) & kWeightInvariantProperties;
  778. }
  779. };
  780. // Mapper from GallicArc<A> to A.
  781. template <class A, GallicType G = GALLIC_LEFT>
  782. class FromGallicMapper {
  783. public:
  784. using FromArc = GallicArc<A, G>;
  785. using ToArc = A;
  786. using Label = typename ToArc::Label;
  787. using AW = typename ToArc::Weight;
  788. using GW = typename FromArc::Weight;
  789. explicit FromGallicMapper(Label superfinal_label = 0)
  790. : superfinal_label_(superfinal_label), error_(false) {}
  791. ToArc operator()(const FromArc &arc) const {
  792. // 'Super-non-final' arc.
  793. if (arc.nextstate == kNoStateId && arc.weight == GW::Zero()) {
  794. return A(arc.ilabel, 0, AW::Zero(), kNoStateId);
  795. }
  796. Label l = kNoLabel;
  797. AW weight = AW::Zero();
  798. if (!Extract(arc.weight, &weight, &l) || arc.ilabel != arc.olabel) {
  799. FSTERROR() << "FromGallicMapper: Unrepresentable weight: " << arc.weight
  800. << " for arc with ilabel = " << arc.ilabel
  801. << ", olabel = " << arc.olabel
  802. << ", nextstate = " << arc.nextstate;
  803. error_ = true;
  804. }
  805. if (arc.ilabel == 0 && l != 0 && arc.nextstate == kNoStateId) {
  806. return ToArc(superfinal_label_, l, weight, arc.nextstate);
  807. } else {
  808. return ToArc(arc.ilabel, l, weight, arc.nextstate);
  809. }
  810. }
  811. constexpr MapFinalAction FinalAction() const { return MAP_ALLOW_SUPERFINAL; }
  812. constexpr MapSymbolsAction InputSymbolsAction() const {
  813. return MAP_COPY_SYMBOLS;
  814. }
  815. constexpr MapSymbolsAction OutputSymbolsAction() const {
  816. return MAP_CLEAR_SYMBOLS;
  817. }
  818. uint64_t Properties(uint64_t inprops) const {
  819. uint64_t outprops = inprops & kOLabelInvariantProperties &
  820. kWeightInvariantProperties & kAddSuperFinalProperties;
  821. if (error_) outprops |= kError;
  822. return outprops;
  823. }
  824. private:
  825. template <GallicType GT>
  826. static bool Extract(const GallicWeight<Label, AW, GT> &gallic_weight,
  827. typename A::Weight *weight, typename A::Label *label) {
  828. using GW = StringWeight<Label, GallicStringType(GT)>;
  829. const GW &w1 = gallic_weight.Value1();
  830. const AW &w2 = gallic_weight.Value2();
  831. typename GW::Iterator iter1(w1);
  832. const Label l = w1.Size() == 1 ? iter1.Value() : 0;
  833. if (l == kStringInfinity || l == kStringBad || w1.Size() > 1) return false;
  834. *label = l;
  835. *weight = w2;
  836. return true;
  837. }
  838. static bool Extract(const GallicWeight<Label, AW, GALLIC> &gallic_weight,
  839. typename A::Weight *weight, typename A::Label *label) {
  840. if (gallic_weight.Size() > 1) return false;
  841. if (gallic_weight.Size() == 0) {
  842. *label = 0;
  843. *weight = A::Weight::Zero();
  844. return true;
  845. }
  846. return Extract<GALLIC_RESTRICT>(gallic_weight.Back(), weight, label);
  847. }
  848. const Label superfinal_label_;
  849. mutable bool error_;
  850. };
  851. // Mapper from GallicArc<A> to A.
  852. template <class A, GallicType G = GALLIC_LEFT>
  853. class GallicToNewSymbolsMapper {
  854. public:
  855. using FromArc = GallicArc<A, G>;
  856. using ToArc = A;
  857. using Label = typename ToArc::Label;
  858. using StateId = typename ToArc::StateId;
  859. using AW = typename ToArc::Weight;
  860. using GW = typename FromArc::Weight;
  861. using SW = StringWeight<Label, GallicStringType(G)>;
  862. explicit GallicToNewSymbolsMapper(MutableFst<ToArc> *fst)
  863. : fst_(fst),
  864. lmax_(0),
  865. osymbols_(fst->OutputSymbols()),
  866. isymbols_(nullptr),
  867. error_(false) {
  868. fst_->DeleteStates();
  869. state_ = fst_->AddState();
  870. fst_->SetStart(state_);
  871. fst_->SetFinal(state_);
  872. if (osymbols_) {
  873. std::string name = osymbols_->Name() + "_from_gallic";
  874. fst_->SetInputSymbols(new SymbolTable(name));
  875. isymbols_ = fst_->MutableInputSymbols();
  876. const int64_t zero = 0;
  877. isymbols_->AddSymbol(osymbols_->Find(zero), 0);
  878. } else {
  879. fst_->SetInputSymbols(nullptr);
  880. }
  881. }
  882. ToArc operator()(const FromArc &arc) {
  883. // Super-non-final arc.
  884. if (arc.nextstate == kNoStateId && arc.weight == GW::Zero()) {
  885. return ToArc(arc.ilabel, 0, AW::Zero(), kNoStateId);
  886. }
  887. SW w1 = arc.weight.Value1();
  888. AW w2 = arc.weight.Value2();
  889. Label l;
  890. if (w1.Size() == 0) {
  891. l = 0;
  892. } else if (auto [it, inserted] = map_.emplace(w1, kNoLabel); !inserted) {
  893. l = it->second;
  894. } else {
  895. l = ++lmax_;
  896. it->second = l;
  897. StringWeightIterator<SW> iter1(w1);
  898. StateId n;
  899. std::string s;
  900. for (size_t i = 0, p = state_; i < w1.Size(); ++i, iter1.Next(), p = n) {
  901. n = i == w1.Size() - 1 ? state_ : fst_->AddState();
  902. fst_->AddArc(p, ToArc(i ? 0 : l, iter1.Value(), n));
  903. if (isymbols_) {
  904. if (i) s = s + "_";
  905. s = s + osymbols_->Find(iter1.Value());
  906. }
  907. }
  908. if (isymbols_) isymbols_->AddSymbol(s, l);
  909. }
  910. if (l == kStringInfinity || l == kStringBad || arc.ilabel != arc.olabel) {
  911. FSTERROR() << "GallicToNewSymbolMapper: Unrepresentable weight: " << l;
  912. error_ = true;
  913. }
  914. return ToArc(arc.ilabel, l, w2, arc.nextstate);
  915. }
  916. constexpr MapFinalAction FinalAction() const { return MAP_ALLOW_SUPERFINAL; }
  917. constexpr MapSymbolsAction InputSymbolsAction() const {
  918. return MAP_COPY_SYMBOLS;
  919. }
  920. constexpr MapSymbolsAction OutputSymbolsAction() const {
  921. return MAP_CLEAR_SYMBOLS;
  922. }
  923. uint64_t Properties(uint64_t inprops) const {
  924. uint64_t outprops = inprops & kOLabelInvariantProperties &
  925. kWeightInvariantProperties & kAddSuperFinalProperties;
  926. if (error_) outprops |= kError;
  927. return outprops;
  928. }
  929. private:
  930. class StringKey {
  931. public:
  932. size_t operator()(const SW &x) const { return x.Hash(); }
  933. };
  934. using Map = std::unordered_map<SW, Label, StringKey>;
  935. MutableFst<ToArc> *fst_;
  936. Map map_;
  937. Label lmax_;
  938. StateId state_;
  939. const SymbolTable *osymbols_;
  940. SymbolTable *isymbols_;
  941. mutable bool error_;
  942. };
  943. // TODO(kbg): Add common base class for those mappers which do nothing except
  944. // mutate their weights.
  945. // Mapper to add a constant to all weights.
  946. template <class A>
  947. class PlusMapper {
  948. public:
  949. using FromArc = A;
  950. using ToArc = A;
  951. using Weight = typename FromArc::Weight;
  952. constexpr explicit PlusMapper(Weight weight) : weight_(std::move(weight)) {}
  953. ToArc operator()(const FromArc &arc) const {
  954. if (arc.weight == Weight::Zero()) return arc;
  955. return ToArc(arc.ilabel, arc.olabel, Plus(arc.weight, weight_),
  956. arc.nextstate);
  957. }
  958. constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
  959. constexpr MapSymbolsAction InputSymbolsAction() const {
  960. return MAP_COPY_SYMBOLS;
  961. }
  962. constexpr MapSymbolsAction OutputSymbolsAction() const {
  963. return MAP_COPY_SYMBOLS;
  964. }
  965. constexpr uint64_t Properties(uint64_t props) const {
  966. return props & kWeightInvariantProperties;
  967. }
  968. private:
  969. const Weight weight_;
  970. };
  971. // Mapper to (right) multiply a constant to all weights.
  972. template <class A>
  973. class TimesMapper {
  974. public:
  975. using FromArc = A;
  976. using ToArc = A;
  977. using Weight = typename FromArc::Weight;
  978. constexpr explicit TimesMapper(Weight weight) : weight_(std::move(weight)) {}
  979. ToArc operator()(const FromArc &arc) const {
  980. if (arc.weight == Weight::Zero()) return arc;
  981. return ToArc(arc.ilabel, arc.olabel, Times(arc.weight, weight_),
  982. arc.nextstate);
  983. }
  984. constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
  985. constexpr MapSymbolsAction InputSymbolsAction() const {
  986. return MAP_COPY_SYMBOLS;
  987. }
  988. constexpr MapSymbolsAction OutputSymbolsAction() const {
  989. return MAP_COPY_SYMBOLS;
  990. }
  991. constexpr uint64_t Properties(uint64_t props) const {
  992. return props & kWeightInvariantProperties;
  993. }
  994. private:
  995. const Weight weight_;
  996. };
  997. // Mapper to take all weights to a constant power. The power argument is stored
  998. // as a double, so if there is a floating-point power implementation for this
  999. // weight type, it will take precedence. Otherwise, the power argument's 53 bits
  1000. // of integer precision will be implicitly converted to a size_t and the default
  1001. // power implementation (iterated multiplication) will be used instead.
  1002. template <class A>
  1003. class PowerMapper {
  1004. public:
  1005. using FromArc = A;
  1006. using ToArc = A;
  1007. using Weight = typename FromArc::Weight;
  1008. explicit PowerMapper(double power) : power_(power) {}
  1009. ToArc operator()(const FromArc &arc) const {
  1010. return ToArc(arc.ilabel, arc.olabel, Power(arc.weight, power_),
  1011. arc.nextstate);
  1012. }
  1013. constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
  1014. constexpr MapSymbolsAction InputSymbolsAction() const {
  1015. return MAP_COPY_SYMBOLS;
  1016. }
  1017. constexpr MapSymbolsAction OutputSymbolsAction() const {
  1018. return MAP_COPY_SYMBOLS;
  1019. }
  1020. constexpr uint64_t Properties(uint64_t props) const {
  1021. return props & kWeightInvariantProperties;
  1022. }
  1023. private:
  1024. const double power_;
  1025. };
  1026. // Mapper to reciprocate all non-Zero() weights.
  1027. template <class A>
  1028. class InvertWeightMapper {
  1029. public:
  1030. using FromArc = A;
  1031. using ToArc = A;
  1032. using Weight = typename FromArc::Weight;
  1033. ToArc operator()(const FromArc &arc) const {
  1034. if (arc.weight == Weight::Zero()) return arc;
  1035. return ToArc(arc.ilabel, arc.olabel, Divide(Weight::One(), arc.weight),
  1036. arc.nextstate);
  1037. }
  1038. constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
  1039. constexpr MapSymbolsAction InputSymbolsAction() const {
  1040. return MAP_COPY_SYMBOLS;
  1041. }
  1042. constexpr MapSymbolsAction OutputSymbolsAction() const {
  1043. return MAP_COPY_SYMBOLS;
  1044. }
  1045. constexpr uint64_t Properties(uint64_t props) const {
  1046. return props & kWeightInvariantProperties;
  1047. }
  1048. };
  1049. // Mapper to map all non-Zero() weights to One().
  1050. template <class A, class B = A>
  1051. class RmWeightMapper {
  1052. public:
  1053. using FromArc = A;
  1054. using ToArc = B;
  1055. using FromWeight = typename FromArc::Weight;
  1056. using ToWeight = typename ToArc::Weight;
  1057. ToArc operator()(const FromArc &arc) const {
  1058. return ToArc(
  1059. arc.ilabel, arc.olabel,
  1060. arc.weight != FromWeight::Zero() ? ToWeight::One() : ToWeight::Zero(),
  1061. arc.nextstate);
  1062. }
  1063. constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
  1064. constexpr MapSymbolsAction InputSymbolsAction() const {
  1065. return MAP_COPY_SYMBOLS;
  1066. }
  1067. constexpr MapSymbolsAction OutputSymbolsAction() const {
  1068. return MAP_COPY_SYMBOLS;
  1069. }
  1070. constexpr uint64_t Properties(uint64_t props) const {
  1071. return (props & kWeightInvariantProperties) | kUnweighted;
  1072. }
  1073. };
  1074. // Mapper to quantize all weights.
  1075. template <class A, class B = A>
  1076. class QuantizeMapper {
  1077. public:
  1078. using FromArc = A;
  1079. using ToArc = B;
  1080. using FromWeight = typename FromArc::Weight;
  1081. using ToWeight = typename ToArc::Weight;
  1082. QuantizeMapper() : delta_(kDelta) {}
  1083. explicit QuantizeMapper(float d) : delta_(d) {}
  1084. ToArc operator()(const FromArc &arc) const {
  1085. return ToArc(arc.ilabel, arc.olabel, arc.weight.Quantize(delta_),
  1086. arc.nextstate);
  1087. }
  1088. constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
  1089. constexpr MapSymbolsAction InputSymbolsAction() const {
  1090. return MAP_COPY_SYMBOLS;
  1091. }
  1092. constexpr MapSymbolsAction OutputSymbolsAction() const {
  1093. return MAP_COPY_SYMBOLS;
  1094. }
  1095. constexpr uint64_t Properties(uint64_t props) const {
  1096. return props & kWeightInvariantProperties;
  1097. }
  1098. private:
  1099. const float delta_;
  1100. };
  1101. // Mapper from A to B under the assumption:
  1102. //
  1103. // B::Weight = A::Weight::ReverseWeight
  1104. // B::Label == A::Label
  1105. // B::StateId == A::StateId
  1106. //
  1107. // The weight is reversed, while the label and nextstate are preserved.
  1108. template <class A, class B>
  1109. class ReverseWeightMapper {
  1110. public:
  1111. using FromArc = A;
  1112. using ToArc = B;
  1113. static_assert(std::is_same_v<typename ToArc::Weight,
  1114. typename FromArc::Weight::ReverseWeight>,
  1115. "ToArc::Weight must be FromArc::Weight::ReverseWeight");
  1116. static_assert(std::is_same_v<typename ToArc::Label, typename FromArc::Label>,
  1117. "ToArc::Label must be FromArc::Label");
  1118. static_assert(
  1119. std::is_same_v<typename ToArc::StateId, typename FromArc::StateId>,
  1120. "ToArc::StateId must be FromArc::StateId");
  1121. constexpr ToArc operator()(const FromArc &arc) const {
  1122. return ToArc(arc.ilabel, arc.olabel, arc.weight.Reverse(), arc.nextstate);
  1123. }
  1124. constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
  1125. constexpr MapSymbolsAction InputSymbolsAction() const {
  1126. return MAP_COPY_SYMBOLS;
  1127. }
  1128. constexpr MapSymbolsAction OutputSymbolsAction() const {
  1129. return MAP_COPY_SYMBOLS;
  1130. }
  1131. constexpr uint64_t Properties(uint64_t props) const { return props; }
  1132. };
  1133. } // namespace fst
  1134. #endif // FST_ARC_MAP_H_