You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

538 lines
18 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Classes to factor weights in an FST.
  19. #ifndef FST_FACTOR_WEIGHT_H_
  20. #define FST_FACTOR_WEIGHT_H_
  21. #include <algorithm>
  22. #include <cstddef>
  23. #include <cstdint>
  24. #include <memory>
  25. #include <string>
  26. #include <utility>
  27. #include <vector>
  28. #include <fst/log.h>
  29. #include <fst/cache.h>
  30. #include <fst/fst.h>
  31. #include <fst/impl-to-fst.h>
  32. #include <fst/properties.h>
  33. #include <fst/string-weight.h>
  34. #include <fst/union-weight.h>
  35. #include <fst/weight.h>
  36. #include <unordered_map>
  37. namespace fst {
  38. inline constexpr uint8_t kFactorFinalWeights = 0x01;
  39. inline constexpr uint8_t kFactorArcWeights = 0x02;
  40. template <class Arc>
  41. struct FactorWeightOptions : CacheOptions {
  42. using Label = typename Arc::Label;
  43. float delta;
  44. uint8_t mode; // Factor arc weights and/or final weights.
  45. Label final_ilabel; // Input label of arc when factoring final weights.
  46. Label final_olabel; // Output label of arc when factoring final weights.
  47. bool increment_final_ilabel; // When factoring final w' results in > 1 arcs
  48. bool increment_final_olabel; // at state, increment labels to make distinct?
  49. explicit FactorWeightOptions(const CacheOptions &opts, float delta = kDelta,
  50. uint8_t mode = kFactorArcWeights |
  51. kFactorFinalWeights,
  52. Label final_ilabel = 0, Label final_olabel = 0,
  53. bool increment_final_ilabel = false,
  54. bool increment_final_olabel = false)
  55. : CacheOptions(opts),
  56. delta(delta),
  57. mode(mode),
  58. final_ilabel(final_ilabel),
  59. final_olabel(final_olabel),
  60. increment_final_ilabel(increment_final_ilabel),
  61. increment_final_olabel(increment_final_olabel) {}
  62. explicit FactorWeightOptions(float delta = kDelta,
  63. uint8_t mode = kFactorArcWeights |
  64. kFactorFinalWeights,
  65. Label final_ilabel = 0, Label final_olabel = 0,
  66. bool increment_final_ilabel = false,
  67. bool increment_final_olabel = false)
  68. : delta(delta),
  69. mode(mode),
  70. final_ilabel(final_ilabel),
  71. final_olabel(final_olabel),
  72. increment_final_ilabel(increment_final_ilabel),
  73. increment_final_olabel(increment_final_olabel) {}
  74. };
// A factor iterator takes as argument a weight w and returns a sequence of
// pairs of weights (xi, yi) such that the sum of the products xi times yi is
// equal to w. If w is fully factored, the iterator should return nothing.
//
// template <class W>
// class FactorIterator {
//  public:
//   explicit FactorIterator(W w);
//
//   bool Done() const;
//
//   void Next();
//
//   std::pair<W, W> Value() const;
//
//   void Reset();
// };
  92. // Factors trivially.
  93. template <class W>
  94. class IdentityFactor {
  95. public:
  96. explicit IdentityFactor(const W &weight) {}
  97. bool Done() const { return true; }
  98. void Next() {}
  99. std::pair<W, W> Value() const { return std::make_pair(W::One(), W::One()); }
  100. void Reset() {}
  101. };
  102. // Factor the Fst to unfold it as needed so that every two paths leading to the
  103. // same state have the same weight. Requires applying only to arc weights
  104. // (FactorWeightOptions::mode == kFactorArcWeights).
  105. template <class W>
  106. class OneFactor {
  107. public:
  108. explicit OneFactor(const W &w) : weight_(w), done_(w == W::One()) {}
  109. bool Done() const { return done_; }
  110. void Next() { done_ = true; }
  111. std::pair<W, W> Value() const { return std::make_pair(W::One(), weight_); }
  112. void Reset() { done_ = weight_ == W::One(); }
  113. private:
  114. W weight_;
  115. bool done_;
  116. };
  117. // Factors a StringWeight w as 'ab' where 'a' is a label.
  118. template <typename Label, StringType S = STRING_LEFT>
  119. class StringFactor {
  120. public:
  121. explicit StringFactor(const StringWeight<Label, S> &weight)
  122. : weight_(weight), done_(weight.Size() <= 1) {}
  123. bool Done() const { return done_; }
  124. void Next() { done_ = true; }
  125. std::pair<StringWeight<Label, S>, StringWeight<Label, S>> Value() const {
  126. using Weight = StringWeight<Label, S>;
  127. typename Weight::Iterator siter(weight_);
  128. Weight w1(siter.Value());
  129. Weight w2;
  130. for (siter.Next(); !siter.Done(); siter.Next()) w2.PushBack(siter.Value());
  131. return std::make_pair(w1, w2);
  132. }
  133. void Reset() { done_ = weight_.Size() <= 1; }
  134. private:
  135. const StringWeight<Label, S> weight_;
  136. bool done_;
  137. };
  138. // Factor a GallicWeight using StringFactor.
  139. template <class Label, class W, GallicType G = GALLIC_LEFT>
  140. class GallicFactor {
  141. public:
  142. using GW = GallicWeight<Label, W, G>;
  143. explicit GallicFactor(const GW &weight)
  144. : weight_(weight), done_(weight.Value1().Size() <= 1) {}
  145. bool Done() const { return done_; }
  146. void Next() { done_ = true; }
  147. std::pair<GW, GW> Value() const {
  148. StringFactor<Label, GallicStringType(G)> siter(weight_.Value1());
  149. GW w1(siter.Value().first, weight_.Value2());
  150. GW w2(siter.Value().second, W::One());
  151. return std::make_pair(w1, w2);
  152. }
  153. void Reset() { done_ = weight_.Value1().Size() <= 1; }
  154. private:
  155. const GW weight_;
  156. bool done_;
  157. };
  158. // Specialization for the (general) GALLIC type GallicWeight.
  159. template <class Label, class W>
  160. class GallicFactor<Label, W, GALLIC> {
  161. public:
  162. using GW = GallicWeight<Label, W, GALLIC>;
  163. using GRW = GallicWeight<Label, W, GALLIC_RESTRICT>;
  164. explicit GallicFactor(const GW &weight)
  165. : iter_(weight),
  166. done_(weight.Size() == 0 ||
  167. (weight.Size() == 1 && weight.Back().Value1().Size() <= 1)) {}
  168. bool Done() const { return done_ || iter_.Done(); }
  169. void Next() { iter_.Next(); }
  170. void Reset() { iter_.Reset(); }
  171. std::pair<GW, GW> Value() const {
  172. const auto weight = iter_.Value();
  173. StringFactor<Label, GallicStringType(GALLIC_RESTRICT)> siter(
  174. weight.Value1());
  175. GRW w1(siter.Value().first, weight.Value2());
  176. GRW w2(siter.Value().second, W::One());
  177. return std::make_pair(GW(w1), GW(w2));
  178. }
  179. private:
  180. UnionWeightIterator<GRW, GallicUnionWeightOptions<Label, W>> iter_;
  181. bool done_;
  182. };
  183. namespace internal {
  184. // Implementation class for FactorWeight
  185. template <class Arc, class FactorIterator>
  186. class FactorWeightFstImpl : public CacheImpl<Arc> {
  187. public:
  188. using Label = typename Arc::Label;
  189. using StateId = typename Arc::StateId;
  190. using Weight = typename Arc::Weight;
  191. using FstImpl<Arc>::SetType;
  192. using FstImpl<Arc>::SetProperties;
  193. using FstImpl<Arc>::SetInputSymbols;
  194. using FstImpl<Arc>::SetOutputSymbols;
  195. using CacheBaseImpl<CacheState<Arc>>::EmplaceArc;
  196. using CacheBaseImpl<CacheState<Arc>>::HasArcs;
  197. using CacheBaseImpl<CacheState<Arc>>::HasFinal;
  198. using CacheBaseImpl<CacheState<Arc>>::HasStart;
  199. using CacheBaseImpl<CacheState<Arc>>::SetArcs;
  200. using CacheBaseImpl<CacheState<Arc>>::SetFinal;
  201. using CacheBaseImpl<CacheState<Arc>>::SetStart;
  202. struct Element {
  203. Element() = default;
  204. Element(StateId s, Weight weight_) : state(s), weight(std::move(weight_)) {}
  205. StateId state; // Input state ID.
  206. Weight weight; // Residual weight.
  207. };
  208. FactorWeightFstImpl(const Fst<Arc> &fst, const FactorWeightOptions<Arc> &opts)
  209. : CacheImpl<Arc>(opts),
  210. fst_(fst.Copy()),
  211. delta_(opts.delta),
  212. mode_(opts.mode),
  213. final_ilabel_(opts.final_ilabel),
  214. final_olabel_(opts.final_olabel),
  215. increment_final_ilabel_(opts.increment_final_ilabel),
  216. increment_final_olabel_(opts.increment_final_olabel) {
  217. SetType("factor_weight");
  218. const auto props = fst.Properties(kFstProperties, false);
  219. SetProperties(FactorWeightProperties(props), kCopyProperties);
  220. SetInputSymbols(fst.InputSymbols());
  221. SetOutputSymbols(fst.OutputSymbols());
  222. if (mode_ == 0) {
  223. LOG(WARNING) << "FactorWeightFst: Factor mode is set to 0; "
  224. << "factoring neither arc weights nor final weights";
  225. }
  226. }
  227. FactorWeightFstImpl(const FactorWeightFstImpl<Arc, FactorIterator> &impl)
  228. : CacheImpl<Arc>(impl),
  229. fst_(impl.fst_->Copy(true)),
  230. delta_(impl.delta_),
  231. mode_(impl.mode_),
  232. final_ilabel_(impl.final_ilabel_),
  233. final_olabel_(impl.final_olabel_),
  234. increment_final_ilabel_(impl.increment_final_ilabel_),
  235. increment_final_olabel_(impl.increment_final_olabel_) {
  236. SetType("factor_weight");
  237. SetProperties(impl.Properties(), kCopyProperties);
  238. SetInputSymbols(impl.InputSymbols());
  239. SetOutputSymbols(impl.OutputSymbols());
  240. }
  241. StateId Start() {
  242. if (!HasStart()) {
  243. const auto s = fst_->Start();
  244. if (s == kNoStateId) return kNoStateId;
  245. SetStart(FindState(Element(fst_->Start(), Weight::One())));
  246. }
  247. return CacheImpl<Arc>::Start();
  248. }
  249. Weight Final(StateId s) {
  250. if (!HasFinal(s)) {
  251. const auto &element = elements_[s];
  252. const auto weight =
  253. element.state == kNoStateId
  254. ? element.weight
  255. : Times(element.weight, fst_->Final(element.state));
  256. FactorIterator siter(weight);
  257. if (!(mode_ & kFactorFinalWeights) || siter.Done()) {
  258. SetFinal(s, weight);
  259. } else {
  260. SetFinal(s, Weight::Zero());
  261. }
  262. }
  263. return CacheImpl<Arc>::Final(s);
  264. }
  265. size_t NumArcs(StateId s) {
  266. if (!HasArcs(s)) Expand(s);
  267. return CacheImpl<Arc>::NumArcs(s);
  268. }
  269. size_t NumInputEpsilons(StateId s) {
  270. if (!HasArcs(s)) Expand(s);
  271. return CacheImpl<Arc>::NumInputEpsilons(s);
  272. }
  273. size_t NumOutputEpsilons(StateId s) {
  274. if (!HasArcs(s)) Expand(s);
  275. return CacheImpl<Arc>::NumOutputEpsilons(s);
  276. }
  277. uint64_t Properties() const override { return Properties(kFstProperties); }
  278. // Sets error if found, and returns other FST impl properties.
  279. uint64_t Properties(uint64_t mask) const override {
  280. if ((mask & kError) && fst_->Properties(kError, false)) {
  281. SetProperties(kError, kError);
  282. }
  283. return FstImpl<Arc>::Properties(mask);
  284. }
  285. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
  286. if (!HasArcs(s)) Expand(s);
  287. CacheImpl<Arc>::InitArcIterator(s, data);
  288. }
  289. // Finds state corresponding to an element, creating new state if element not
  290. // found.
  291. StateId FindState(const Element &element) {
  292. if (!(mode_ & kFactorArcWeights) && element.weight == Weight::One() &&
  293. element.state != kNoStateId) {
  294. while (unfactored_.size() <= element.state)
  295. unfactored_.push_back(kNoStateId);
  296. if (unfactored_[element.state] == kNoStateId) {
  297. unfactored_[element.state] = elements_.size();
  298. elements_.push_back(element);
  299. }
  300. return unfactored_[element.state];
  301. } else {
  302. const auto insert_result =
  303. element_map_.emplace(element, elements_.size());
  304. if (insert_result.second) {
  305. elements_.push_back(element);
  306. }
  307. return insert_result.first->second;
  308. }
  309. }
  310. // Computes the outgoing transitions from a state, creating new destination
  311. // states as needed.
  312. void Expand(StateId s) {
  313. const auto element = elements_[s];
  314. if (element.state != kNoStateId) {
  315. for (ArcIterator<Fst<Arc>> ait(*fst_, element.state); !ait.Done();
  316. ait.Next()) {
  317. const auto &arc = ait.Value();
  318. auto weight = Times(element.weight, arc.weight);
  319. FactorIterator fiter(weight);
  320. if (!(mode_ & kFactorArcWeights) || fiter.Done()) {
  321. const auto dest = FindState(Element(arc.nextstate, Weight::One()));
  322. EmplaceArc(s, arc.ilabel, arc.olabel, std::move(weight), dest);
  323. } else {
  324. for (; !fiter.Done(); fiter.Next()) {
  325. auto pair = fiter.Value();
  326. const auto dest =
  327. FindState(Element(arc.nextstate, pair.second.Quantize(delta_)));
  328. EmplaceArc(s, arc.ilabel, arc.olabel, std::move(pair.first), dest);
  329. }
  330. }
  331. }
  332. }
  333. if ((mode_ & kFactorFinalWeights) &&
  334. ((element.state == kNoStateId) ||
  335. (fst_->Final(element.state) != Weight::Zero()))) {
  336. const auto weight =
  337. element.state == kNoStateId
  338. ? element.weight
  339. : Times(element.weight, fst_->Final(element.state));
  340. auto ilabel = final_ilabel_;
  341. auto olabel = final_olabel_;
  342. for (FactorIterator fiter(weight); !fiter.Done(); fiter.Next()) {
  343. auto pair = fiter.Value();
  344. const auto dest =
  345. FindState(Element(kNoStateId, pair.second.Quantize(delta_)));
  346. EmplaceArc(s, ilabel, olabel, std::move(pair.first), dest);
  347. if (increment_final_ilabel_) ++ilabel;
  348. if (increment_final_olabel_) ++olabel;
  349. }
  350. }
  351. SetArcs(s);
  352. }
  353. private:
  354. // Equality function for Elements, assume weights have been quantized.
  355. class ElementEqual {
  356. public:
  357. bool operator()(const Element &x, const Element &y) const {
  358. return x.state == y.state && x.weight == y.weight;
  359. }
  360. };
  361. // Hash function for Elements to Fst states.
  362. class ElementKey {
  363. public:
  364. size_t operator()(const Element &x) const {
  365. static constexpr auto prime = 7853;
  366. return static_cast<size_t>(x.state * prime + x.weight.Hash());
  367. }
  368. };
  369. using ElementMap =
  370. std::unordered_map<Element, StateId, ElementKey, ElementEqual>;
  371. std::unique_ptr<const Fst<Arc>> fst_;
  372. float delta_;
  373. uint8_t mode_; // Factoring arc and/or final weights.
  374. Label final_ilabel_; // ilabel of arc created when factoring final weights.
  375. Label final_olabel_; // olabel of arc created when factoring final weights.
  376. bool increment_final_ilabel_; // When factoring final weights results in
  377. bool increment_final_olabel_; // mutiple arcs, increment labels?
  378. std::vector<Element> elements_; // Mapping from FST state to Element.
  379. ElementMap element_map_; // Mapping from Element to FST state.
  380. // Mapping between old/new StateId for states that do not need to be factored
  381. // when mode_ is 0 or kFactorFinalWeights.
  382. std::vector<StateId> unfactored_;
  383. };
  384. } // namespace internal
  385. // FactorWeightFst takes as template parameter a FactorIterator as defined
  386. // above. The result of weight factoring is a transducer equivalent to the
  387. // input whose path weights have been factored according to the FactorIterator.
  388. // States and transitions will be added as necessary. The algorithm is a
  389. // generalization to arbitrary weights of the second step of the input
  390. // epsilon-normalization algorithm.
  391. //
  392. // This class attaches interface to implementation and handles reference
  393. // counting, delegating most methods to ImplToFst.
  394. template <class A, class FactorIterator>
  395. class FactorWeightFst
  396. : public ImplToFst<internal::FactorWeightFstImpl<A, FactorIterator>> {
  397. public:
  398. using Arc = A;
  399. using StateId = typename Arc::StateId;
  400. using Weight = typename Arc::Weight;
  401. using Store = DefaultCacheStore<Arc>;
  402. using State = typename Store::State;
  403. using Impl = internal::FactorWeightFstImpl<Arc, FactorIterator>;
  404. friend class ArcIterator<FactorWeightFst<Arc, FactorIterator>>;
  405. friend class StateIterator<FactorWeightFst<Arc, FactorIterator>>;
  406. explicit FactorWeightFst(const Fst<Arc> &fst)
  407. : ImplToFst<Impl>(
  408. std::make_shared<Impl>(fst, FactorWeightOptions<Arc>())) {}
  409. FactorWeightFst(const Fst<Arc> &fst, const FactorWeightOptions<Arc> &opts)
  410. : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
  411. // See Fst<>::Copy() for doc.
  412. FactorWeightFst(const FactorWeightFst &fst, bool copy)
  413. : ImplToFst<Impl>(fst, copy) {}
  414. // Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc.
  415. FactorWeightFst *Copy(bool copy = false) const override {
  416. return new FactorWeightFst(*this, copy);
  417. }
  418. inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
  419. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
  420. GetMutableImpl()->InitArcIterator(s, data);
  421. }
  422. private:
  423. using ImplToFst<Impl>::GetImpl;
  424. using ImplToFst<Impl>::GetMutableImpl;
  425. FactorWeightFst &operator=(const FactorWeightFst &) = delete;
  426. };
  427. // Specialization for FactorWeightFst.
  428. template <class Arc, class FactorIterator>
  429. class StateIterator<FactorWeightFst<Arc, FactorIterator>>
  430. : public CacheStateIterator<FactorWeightFst<Arc, FactorIterator>> {
  431. public:
  432. explicit StateIterator(const FactorWeightFst<Arc, FactorIterator> &fst)
  433. : CacheStateIterator<FactorWeightFst<Arc, FactorIterator>>(
  434. fst, fst.GetMutableImpl()) {}
  435. };
  436. // Specialization for FactorWeightFst.
  437. template <class Arc, class FactorIterator>
  438. class ArcIterator<FactorWeightFst<Arc, FactorIterator>>
  439. : public CacheArcIterator<FactorWeightFst<Arc, FactorIterator>> {
  440. public:
  441. using StateId = typename Arc::StateId;
  442. ArcIterator(const FactorWeightFst<Arc, FactorIterator> &fst, StateId s)
  443. : CacheArcIterator<FactorWeightFst<Arc, FactorIterator>>(
  444. fst.GetMutableImpl(), s) {
  445. if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
  446. }
  447. };
  448. template <class Arc, class FactorIterator>
  449. inline void FactorWeightFst<Arc, FactorIterator>::InitStateIterator(
  450. StateIteratorData<Arc> *data) const {
  451. data->base =
  452. std::make_unique<StateIterator<FactorWeightFst<Arc, FactorIterator>>>(
  453. *this);
  454. }
  455. } // namespace fst
  456. #endif // FST_FACTOR_WEIGHT_H_