You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1133 lines
41 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Functions and classes to determinize an FST.
  19. #ifndef FST_DETERMINIZE_H_
  20. #define FST_DETERMINIZE_H_
  21. #include <algorithm>
  22. #include <climits>
  23. #include <cstddef>
  24. #include <cstdint>
  25. #include <forward_list>
  26. #include <map>
  27. #include <memory>
  28. #include <string>
  29. #include <utility>
  30. #include <vector>
  31. #include <fst/log.h>
  32. #include <fst/arc-map.h>
  33. #include <fst/arc.h>
  34. #include <fst/arcfilter.h>
  35. #include <fst/bi-table.h>
  36. #include <fst/cache.h>
  37. #include <fst/const-fst.h>
  38. #include <fst/factor-weight.h>
  39. #include <fst/filter-state.h>
  40. #include <fst/float-weight.h>
  41. #include <fst/fst.h>
  42. #include <fst/impl-to-fst.h>
  43. #include <fst/lexicographic-weight.h>
  44. #include <fst/mutable-fst.h>
  45. #include <fst/pair-weight.h>
  46. #include <fst/power-weight.h>
  47. #include <fst/product-weight.h>
  48. #include <fst/properties.h>
  49. #include <fst/prune.h>
  50. #include <fst/shortest-distance.h>
  51. #include <fst/string-weight.h>
  52. #include <fst/tuple-weight.h>
  53. #include <fst/union-weight.h>
  54. #include <fst/util.h>
  55. #include <fst/weight.h>
  56. namespace fst {
  57. // Common divisors are used in determinization to compute transition weights.
  58. // In the simplest case, it is the same as semiring Plus, but other choices
  59. // permit more efficient determinization when the output contains strings.
  60. // The default common divisor uses the semiring Plus.
  61. namespace internal {
  62. template <class Arc, class Relation>
  63. class RelationDeterminizeFilter;
  64. } // namespace internal
  65. struct PairArc;
  66. template <class W>
  67. struct DefaultCommonDivisor {
  68. public:
  69. using Weight = W;
  70. Weight operator()(const Weight &w1, const Weight &w2) const {
  71. return Plus(w1, w2);
  72. }
  73. };
  74. // The label common divisor for a (left) string semiring selects a single
  75. // letter common prefix or the empty string. This is used in the
  76. // determinization of output strings so that at most a single letter will
  77. // appear in the output of a transtion.
  78. template <typename Label, StringType S>
  79. struct LabelCommonDivisor {
  80. public:
  81. using Weight = StringWeight<Label, S>;
  82. Weight operator()(const Weight &w1, const Weight &w2) const {
  83. typename Weight::Iterator iter1(w1);
  84. typename Weight::Iterator iter2(w2);
  85. if (!(StringWeight<Label, S>::Properties() & kLeftSemiring)) {
  86. FSTERROR() << "LabelCommonDivisor: Weight needs to be left semiring";
  87. return Weight::NoWeight();
  88. } else if (w1.Size() == 0 || w2.Size() == 0) {
  89. return Weight::One();
  90. } else if (w1 == Weight::Zero()) {
  91. return Weight(iter2.Value());
  92. } else if (w2 == Weight::Zero()) {
  93. return Weight(iter1.Value());
  94. } else if (iter1.Value() == iter2.Value()) {
  95. return Weight(iter1.Value());
  96. } else {
  97. return Weight::One();
  98. }
  99. }
  100. };
  101. // The gallic common divisor uses the label common divisor on the string
  102. // component and the common divisor on the weight component, which defaults to
  103. // the default common divisor.
  104. template <class Label, class W, GallicType G,
  105. class CommonDivisor = DefaultCommonDivisor<W>>
  106. class GallicCommonDivisor {
  107. public:
  108. using Weight = GallicWeight<Label, W, G>;
  109. Weight operator()(const Weight &w1, const Weight &w2) const {
  110. return Weight(label_common_divisor_(w1.Value1(), w2.Value1()),
  111. weight_common_divisor_(w1.Value2(), w2.Value2()));
  112. }
  113. private:
  114. LabelCommonDivisor<Label, GallicStringType(G)> label_common_divisor_;
  115. CommonDivisor weight_common_divisor_;
  116. };
  117. // Specialization for general GALLIC weight.
  118. template <class Label, class W, class CommonDivisor>
  119. class GallicCommonDivisor<Label, W, GALLIC, CommonDivisor> {
  120. public:
  121. using Weight = GallicWeight<Label, W, GALLIC>;
  122. using GRWeight = GallicWeight<Label, W, GALLIC_RESTRICT>;
  123. using Iterator =
  124. UnionWeightIterator<GRWeight, GallicUnionWeightOptions<Label, W>>;
  125. Weight operator()(const Weight &w1, const Weight &w2) const {
  126. auto weight = GRWeight::Zero();
  127. for (Iterator iter(w1); !iter.Done(); iter.Next()) {
  128. weight = common_divisor_(weight, iter.Value());
  129. }
  130. for (Iterator iter(w2); !iter.Done(); iter.Next()) {
  131. weight = common_divisor_(weight, iter.Value());
  132. }
  133. return weight == GRWeight::Zero() ? Weight::Zero() : Weight(weight);
  134. }
  135. private:
  136. GallicCommonDivisor<Label, W, GALLIC_RESTRICT, CommonDivisor> common_divisor_;
  137. };
  138. namespace internal {
  139. // Represents an element in a subset
  140. template <class Arc>
  141. struct DeterminizeElement {
  142. using StateId = typename Arc::StateId;
  143. using Weight = typename Arc::Weight;
  144. DeterminizeElement(StateId s, Weight weight)
  145. : state_id(s), weight(std::move(weight)) {}
  146. inline bool operator==(const DeterminizeElement &element) const {
  147. return state_id == element.state_id && weight == element.weight;
  148. }
  149. inline bool operator!=(const DeterminizeElement &element) const {
  150. return !(*this == element);
  151. }
  152. inline bool operator<(const DeterminizeElement<Arc> &element) const {
  153. return state_id < element.state_id;
  154. }
  155. StateId state_id; // Input state ID.
  156. Weight weight; // Residual weight.
  157. };
  158. // Represents a weighted subset and determinization filter state
  159. template <typename A, typename FilterState>
  160. struct DeterminizeStateTuple {
  161. using Arc = A;
  162. using Element = DeterminizeElement<Arc>;
  163. using Subset = std::forward_list<Element>;
  164. DeterminizeStateTuple() : filter_state(FilterState::NoState()) {}
  165. inline bool operator==(const DeterminizeStateTuple &tuple) const {
  166. return (tuple.filter_state == filter_state) && (tuple.subset == subset);
  167. }
  168. inline bool operator!=(const DeterminizeStateTuple &tuple) const {
  169. return (tuple.filter_state != filter_state) || (tuple.subset != subset);
  170. }
  171. Subset subset;
  172. FilterState filter_state;
  173. };
  174. // Proto-transition for determinization.
  175. template <class StateTuple>
  176. struct DeterminizeArc {
  177. using Arc = typename StateTuple::Arc;
  178. using Label = typename Arc::Label;
  179. using Weight = typename Arc::Weight;
  180. DeterminizeArc() = default;
  181. explicit DeterminizeArc(const Arc &arc)
  182. : label(arc.ilabel),
  183. dest_tuple(fst::make_unique_for_overwrite<StateTuple>()) {}
  184. Label label = kNoLabel; // Arc label.
  185. Weight weight = Weight::Zero(); // Arc weight.
  186. std::unique_ptr<StateTuple>
  187. dest_tuple; // Destination subset and filter state.
  188. };
  189. } // namespace internal
  190. // Determinization filters are used to compute destination state tuples based
  191. // on the source tuple, transition, and destination element or on similar
  192. // super-final transition information. The filter operates on a map between a
  193. // label and the corresponding destination state tuples. It must define the map
  194. // type LabelMap. The default filter is used for weighted determinization.
  195. // A determinize filter for implementing weighted determinization.
  196. template <class Arc>
  197. class DefaultDeterminizeFilter {
  198. public:
  199. using Label = typename Arc::Label;
  200. using StateId = typename Arc::StateId;
  201. using Weight = typename Arc::Weight;
  202. using FilterState = CharFilterState;
  203. using Element = internal::DeterminizeElement<Arc>;
  204. using StateTuple = internal::DeterminizeStateTuple<Arc, FilterState>;
  205. using LabelMap = std::map<Label, internal::DeterminizeArc<StateTuple>>;
  206. // This is needed e.g. to go into the gallic domain for transducers.
  207. template <class A>
  208. struct rebind {
  209. using Other = DefaultDeterminizeFilter<A>;
  210. };
  211. explicit DefaultDeterminizeFilter(const Fst<Arc> &fst) : fst_(fst.Copy()) {}
  212. // This is needed (e.g.) to go into the gallic domain for transducers.
  213. template <class Filter>
  214. DefaultDeterminizeFilter(const Fst<Arc> &fst, std::unique_ptr<Filter> filter)
  215. : fst_(fst.Copy()) {}
  216. // Copy constructor; the FST can be passed if it has been deep-copied.
  217. DefaultDeterminizeFilter(const DefaultDeterminizeFilter &filter,
  218. const Fst<Arc> *fst = nullptr)
  219. : fst_(fst ? fst->Copy() : filter.fst_->Copy()) {}
  220. FilterState Start() const { return FilterState(0); }
  221. // Does no work.
  222. void SetState(StateId s, const StateTuple &tuple) {}
  223. // Filters transition, possibly modifying label map. Returns true if arc is
  224. // added to the label map.
  225. bool FilterArc(const Arc &arc, const Element &src_element,
  226. Element &&dest_element, LabelMap *label_map) const {
  227. // Adds element to unique state tuple for arc label.
  228. auto &det_arc = (*label_map)[arc.ilabel];
  229. if (det_arc.label == kNoLabel) {
  230. det_arc = internal::DeterminizeArc<StateTuple>(arc);
  231. det_arc.dest_tuple->filter_state = FilterState(0);
  232. }
  233. det_arc.dest_tuple->subset.push_front(std::move(dest_element));
  234. return true;
  235. }
  236. // Filters super-final transition, returning new final weight.
  237. Weight FilterFinal(Weight weight, const Element &element) { return weight; }
  238. static uint64_t Properties(uint64_t props) { return props; }
  239. private:
  240. std::unique_ptr<Fst<Arc>> fst_;
  241. };
  242. // Determinization state table interface:
  243. //
  244. // template <class Arc, class FilterState>
  245. // class DeterminizeStateTable {
  246. // public:
  247. // using StateId = typename Arc::StateId;
  248. // using StateTuple = internal::DeterminizeStateTuple<Arc, FilterState>;
  249. //
  250. // // Required sub-class. This is needed (e.g.) to go into the gallic domain.
  251. // template <class B, class G>
  252. // struct rebind {
  253. // using Other = DeterminizeStateTable<B, G>;
  254. // };
  255. //
  256. // // Required constructor.
  257. // DeterminizeStateTable();
  258. //
  259. // // Required copy constructor that does not copy state.
  260. // DeterminizeStateTable(const DeterminizeStateTable<Arc, FilterState>
  261. // &table);
  262. //
  263. // // Looks up state ID by state tuple; if it doesn't exist, then adds it.
  264. // // FindState takes ownership of the state tuple argument so that it
  265. // // doesn't have to copy it if it creates a new state.
  266. // StateId FindState(std::unique_ptr<StateTuple> tuple);
  267. //
  268. // // Looks up state tuple by ID.
  269. // const StateTuple *Tuple(StateId id) const;
  270. // };
  271. // The default determinization state table based on the compact hash bi-table.
  272. template <class Arc, class FilterState>
  273. class DefaultDeterminizeStateTable {
  274. public:
  275. using Label = typename Arc::Label;
  276. using StateId = typename Arc::StateId;
  277. using Weight = typename Arc::Weight;
  278. using StateTuple = internal::DeterminizeStateTuple<Arc, FilterState>;
  279. using Element = typename StateTuple::Element;
  280. using Subset = typename StateTuple::Subset;
  281. template <class B, class G>
  282. struct rebind {
  283. using Other = DefaultDeterminizeStateTable<B, G>;
  284. };
  285. explicit DefaultDeterminizeStateTable(size_t table_size = 0)
  286. : table_size_(table_size), tuples_(table_size_) {}
  287. DefaultDeterminizeStateTable(const DefaultDeterminizeStateTable &table)
  288. : table_size_(table.table_size_), tuples_(table_size_) {}
  289. ~DefaultDeterminizeStateTable() {
  290. for (StateId s = 0; s < tuples_.Size(); ++s) delete tuples_.FindEntry(s);
  291. }
  292. // Finds the state corresponding to a state tuple. Only creates a new state if
  293. // the tuple is not found. FindState takes ownership of the tuple argument so
  294. // that it doesn't have to copy it if it creates a new state.
  295. StateId FindState(std::unique_ptr<StateTuple> tuple) {
  296. StateTuple *raw_tuple = tuple.release();
  297. const StateId ns = tuples_.Size();
  298. // TODO(wolfsonkin): Make CompactHashBiTable support move semantics so we
  299. // can store a `std::unique_ptr` in `tuples_`.
  300. const auto s = tuples_.FindId(raw_tuple);
  301. if (s != ns) delete raw_tuple; // Tuple found.
  302. return s;
  303. }
  304. const StateTuple *Tuple(StateId s) { return tuples_.FindEntry(s); }
  305. private:
  306. // Comparison object for StateTuples.
  307. class StateTupleEqual {
  308. public:
  309. bool operator()(const StateTuple *tuple1, const StateTuple *tuple2) const {
  310. return *tuple1 == *tuple2;
  311. }
  312. };
  313. // Hash function for StateTuples.
  314. class StateTupleKey {
  315. public:
  316. size_t operator()(const StateTuple *tuple) const {
  317. size_t h = tuple->filter_state.Hash();
  318. for (auto &element : tuple->subset) {
  319. const size_t h1 = element.state_id;
  320. static constexpr auto lshift = 5;
  321. static constexpr auto rshift = CHAR_BIT * sizeof(size_t) - 5;
  322. h ^= h << 1 ^ h1 << lshift ^ h1 >> rshift ^ element.weight.Hash();
  323. }
  324. return h;
  325. }
  326. };
  327. size_t table_size_;
  328. CompactHashBiTable<StateId, StateTuple *, StateTupleKey, StateTupleEqual,
  329. HS_STL>
  330. tuples_;
  331. DefaultDeterminizeStateTable &operator=(
  332. const DefaultDeterminizeStateTable &) = delete;
  333. };
  334. // Determinization type.
  335. enum DeterminizeType {
  336. // Input transducer is known to be functional (or error).
  337. DETERMINIZE_FUNCTIONAL, // Input transducer is functional (error if not).
  338. // Input transducer is not known to be functional.
  339. DETERMINIZE_NONFUNCTIONAL,
  340. // Input transducer is not known to be functional but only keep the min of
  341. // of ambiguous outputs.
  342. DETERMINIZE_DISAMBIGUATE
  343. };
  344. // Options for finite-state transducer determinization templated on the arc
  345. // type, common divisor, the determinization filter and the state table.
  346. // DeterminizeFst takes ownership of the determinization filter and state table,
  347. // if provided.
  348. template <class Arc,
  349. class CommonDivisor = DefaultCommonDivisor<typename Arc::Weight>,
  350. class Filter = DefaultDeterminizeFilter<Arc>,
  351. class StateTable =
  352. DefaultDeterminizeStateTable<Arc, typename Filter::FilterState>>
  353. struct DeterminizeFstOptions : public CacheOptions {
  354. using Label = typename Arc::Label;
  355. float delta; // Quantization delta for subset weights.
  356. Label subsequential_label; // Label used for residual final output
  357. // when producing subsequential transducers.
  358. DeterminizeType type; // Determinization type.
  359. bool increment_subsequential_label; // When creating several subsequential
  360. // arcs at a given state, make their
  361. // label distinct by incrementing.
  362. Filter *filter; // Determinization filter;
  363. // DeterminizeFst takes ownership.
  364. StateTable *state_table; // Determinization state table;
  365. // DeterminizeFst takes ownership.
  366. explicit DeterminizeFstOptions(const CacheOptions &opts, float delta = kDelta,
  367. Label subsequential_label = 0,
  368. DeterminizeType type = DETERMINIZE_FUNCTIONAL,
  369. bool increment_subsequential_label = false,
  370. Filter *filter = nullptr,
  371. StateTable *state_table = nullptr)
  372. : CacheOptions(opts),
  373. delta(delta),
  374. subsequential_label(subsequential_label),
  375. type(type),
  376. increment_subsequential_label(increment_subsequential_label),
  377. filter(filter),
  378. state_table(state_table) {}
  379. explicit DeterminizeFstOptions(float delta = kDelta,
  380. Label subsequential_label = 0,
  381. DeterminizeType type = DETERMINIZE_FUNCTIONAL,
  382. bool increment_subsequential_label = false,
  383. Filter *filter = nullptr,
  384. StateTable *state_table = nullptr)
  385. : delta(delta),
  386. subsequential_label(subsequential_label),
  387. type(type),
  388. increment_subsequential_label(increment_subsequential_label),
  389. filter(filter),
  390. state_table(state_table) {}
  391. };
  392. namespace internal {
  393. // Implementation of delayed DeterminizeFst. This base class is
  394. // common to the variants that implement acceptor and transducer
  395. // determinization.
  396. template <class Arc>
  397. class DeterminizeFstImplBase : public CacheImpl<Arc> {
  398. public:
  399. using Label = typename Arc::Label;
  400. using StateId = typename Arc::StateId;
  401. using Weight = typename Arc::Weight;
  402. using Store = DefaultCacheStore<Arc>;
  403. using State = typename Store::State;
  404. using FstImpl<Arc>::SetType;
  405. using FstImpl<Arc>::SetProperties;
  406. using FstImpl<Arc>::Properties;
  407. using FstImpl<Arc>::SetInputSymbols;
  408. using FstImpl<Arc>::SetOutputSymbols;
  409. using CacheBaseImpl<CacheState<Arc>>::HasStart;
  410. using CacheBaseImpl<CacheState<Arc>>::HasFinal;
  411. using CacheBaseImpl<CacheState<Arc>>::HasArcs;
  412. using CacheBaseImpl<CacheState<Arc>>::SetFinal;
  413. using CacheBaseImpl<CacheState<Arc>>::SetStart;
  414. template <class CommonDivisor, class Filter, class StateTable>
  415. DeterminizeFstImplBase(
  416. const Fst<Arc> &fst,
  417. const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable> &opts)
  418. : CacheImpl<Arc>(opts), fst_(fst.Copy()) {
  419. SetType("determinize");
  420. const auto iprops = fst.Properties(kFstProperties, false);
  421. const auto dprops =
  422. DeterminizeProperties(iprops, opts.subsequential_label != 0,
  423. opts.type == DETERMINIZE_NONFUNCTIONAL
  424. ? opts.increment_subsequential_label
  425. : true);
  426. SetProperties(Filter::Properties(dprops), kCopyProperties);
  427. SetInputSymbols(fst.InputSymbols());
  428. SetOutputSymbols(fst.OutputSymbols());
  429. }
  430. DeterminizeFstImplBase(const DeterminizeFstImplBase &impl)
  431. : CacheImpl<Arc>(impl), fst_(impl.fst_->Copy(true)) {
  432. SetType("determinize");
  433. SetProperties(impl.Properties(), kCopyProperties);
  434. SetInputSymbols(impl.InputSymbols());
  435. SetOutputSymbols(impl.OutputSymbols());
  436. }
  437. virtual DeterminizeFstImplBase *Copy() const = 0;
  438. StateId Start() {
  439. if (!HasStart()) {
  440. const auto start = ComputeStart();
  441. if (start != kNoStateId) SetStart(start);
  442. }
  443. return CacheImpl<Arc>::Start();
  444. }
  445. Weight Final(StateId s) {
  446. if (!HasFinal(s)) SetFinal(s, ComputeFinal(s));
  447. return CacheImpl<Arc>::Final(s);
  448. }
  449. virtual void Expand(StateId s) = 0;
  450. size_t NumArcs(StateId s) {
  451. if (!HasArcs(s)) Expand(s);
  452. return CacheImpl<Arc>::NumArcs(s);
  453. }
  454. size_t NumInputEpsilons(StateId s) {
  455. if (!HasArcs(s)) Expand(s);
  456. return CacheImpl<Arc>::NumInputEpsilons(s);
  457. }
  458. size_t NumOutputEpsilons(StateId s) {
  459. if (!HasArcs(s)) Expand(s);
  460. return CacheImpl<Arc>::NumOutputEpsilons(s);
  461. }
  462. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
  463. if (!HasArcs(s)) Expand(s);
  464. CacheImpl<Arc>::InitArcIterator(s, data);
  465. }
  466. virtual StateId ComputeStart() = 0;
  467. virtual Weight ComputeFinal(StateId s) = 0;
  468. const Fst<Arc> &GetFst() const { return *fst_; }
  469. private:
  470. std::unique_ptr<const Fst<Arc>> fst_; // Input FST.
  471. };
  472. // Implementation of delayed determinization for weighted acceptors.
  473. template <class Arc, class CommonDivisor, class Filter, class StateTable>
  474. class DeterminizeFsaImpl : public DeterminizeFstImplBase<Arc> {
  475. public:
  476. using Label = typename Arc::Label;
  477. using StateId = typename Arc::StateId;
  478. using Weight = typename Arc::Weight;
  479. using FilterState = typename Filter::FilterState;
  480. using StateTuple = internal::DeterminizeStateTuple<Arc, FilterState>;
  481. using Element = typename StateTuple::Element;
  482. using Subset = typename StateTuple::Subset;
  483. using LabelMap = typename Filter::LabelMap;
  484. using FstImpl<Arc>::SetProperties;
  485. using DeterminizeFstImplBase<Arc>::GetFst;
  486. using DeterminizeFstImplBase<Arc>::SetArcs;
  487. DeterminizeFsaImpl(
  488. const Fst<Arc> &fst, const std::vector<Weight> *in_dist,
  489. std::vector<Weight> *out_dist,
  490. const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable> &opts)
  491. : DeterminizeFstImplBase<Arc>(fst, opts),
  492. delta_(opts.delta),
  493. in_dist_(in_dist),
  494. out_dist_(out_dist),
  495. filter_(opts.filter ? opts.filter : new Filter(fst)),
  496. state_table_(opts.state_table ? opts.state_table : new StateTable()) {
  497. if (!fst.Properties(kAcceptor, true)) {
  498. FSTERROR() << "DeterminizeFst: Argument not an acceptor";
  499. SetProperties(kError, kError);
  500. }
  501. if (!(Weight::Properties() & kLeftSemiring)) {
  502. FSTERROR() << "DeterminizeFst: Weight must be left distributive: "
  503. << Weight::Type();
  504. SetProperties(kError, kError);
  505. }
  506. if (out_dist_) out_dist_->clear();
  507. }
  508. DeterminizeFsaImpl(const DeterminizeFsaImpl &impl)
  509. : DeterminizeFstImplBase<Arc>(impl),
  510. delta_(impl.delta_),
  511. in_dist_(nullptr),
  512. out_dist_(nullptr),
  513. filter_(new Filter(*impl.filter_, &GetFst())),
  514. state_table_(new StateTable(*impl.state_table_)) {
  515. if (impl.out_dist_) {
  516. FSTERROR() << "DeterminizeFsaImpl: Cannot copy with out_dist vector";
  517. SetProperties(kError, kError);
  518. }
  519. }
  520. DeterminizeFsaImpl *Copy() const override {
  521. return new DeterminizeFsaImpl(*this);
  522. }
  523. uint64_t Properties() const override { return Properties(kFstProperties); }
  524. // Sets error if found, and returns other FST impl properties.
  525. uint64_t Properties(uint64_t mask) const override {
  526. if ((mask & kError) && (GetFst().Properties(kError, false))) {
  527. SetProperties(kError, kError);
  528. }
  529. return FstImpl<Arc>::Properties(mask);
  530. }
  531. StateId ComputeStart() override {
  532. const auto s = GetFst().Start();
  533. if (s == kNoStateId) return kNoStateId;
  534. auto tuple = fst::make_unique_for_overwrite<StateTuple>();
  535. tuple->subset.emplace_front(s, Weight::One());
  536. tuple->filter_state = filter_->Start();
  537. return FindState(std::move(tuple));
  538. }
  539. Weight ComputeFinal(StateId s) override {
  540. const auto *tuple = state_table_->Tuple(s);
  541. filter_->SetState(s, *tuple);
  542. auto final_weight = Weight::Zero();
  543. for (const auto &element : tuple->subset) {
  544. final_weight =
  545. Plus(final_weight,
  546. Times(element.weight, GetFst().Final(element.state_id)));
  547. final_weight = filter_->FilterFinal(final_weight, element);
  548. if (!final_weight.Member()) SetProperties(kError, kError);
  549. }
  550. return final_weight;
  551. }
  552. StateId FindState(std::unique_ptr<StateTuple> tuple) {
  553. const auto &subset = tuple->subset;
  554. const auto s = state_table_->FindState(std::move(tuple));
  555. if (in_dist_ && out_dist_->size() <= s) {
  556. out_dist_->push_back(ComputeDistance(subset));
  557. }
  558. return s;
  559. }
  560. // Computes distance from a state to the final states in the DFA given the
  561. // distances in the NFA.
  562. Weight ComputeDistance(const Subset &subset) {
  563. auto outd = Weight::Zero();
  564. for (const auto &element : subset) {
  565. const auto ind =
  566. (element.state_id < in_dist_->size() ? (*in_dist_)[element.state_id]
  567. : Weight::Zero());
  568. outd = Plus(outd, Times(element.weight, ind));
  569. }
  570. return outd;
  571. }
  572. // Computes the outgoing transitions from a state, creating new destination
  573. // states as needed.
  574. void Expand(StateId s) override {
  575. LabelMap label_map;
  576. GetLabelMap(s, &label_map);
  577. for (auto &[unused_label, arc] : label_map) {
  578. AddArc(s, std::move(arc));
  579. }
  580. SetArcs(s);
  581. }
  582. private:
  583. using DetArc = internal::DeterminizeArc<StateTuple>;
  584. // Constructs proto-determinization transition, including destination subset,
  585. // per label.
  586. void GetLabelMap(StateId s, LabelMap *label_map) {
  587. const auto *src_tuple = state_table_->Tuple(s);
  588. filter_->SetState(s, *src_tuple);
  589. for (const auto &src_element : src_tuple->subset) {
  590. for (ArcIterator<Fst<Arc>> aiter(GetFst(), src_element.state_id);
  591. !aiter.Done(); aiter.Next()) {
  592. const auto &arc = aiter.Value();
  593. Element dest_element(arc.nextstate,
  594. Times(src_element.weight, arc.weight));
  595. filter_->FilterArc(arc, src_element, std::move(dest_element),
  596. label_map);
  597. }
  598. }
  599. for (auto &[unused_label, arc] : *label_map) {
  600. NormArc(&arc);
  601. }
  602. }
  603. // Sorts subsets and removes duplicate elements, normalizing transition and
  604. // subset weights.
  605. void NormArc(DetArc *det_arc) {
  606. auto &dest_subset = det_arc->dest_tuple->subset;
  607. dest_subset.sort();
  608. auto piter = dest_subset.begin();
  609. for (auto diter = dest_subset.begin(); diter != dest_subset.end(); ) {
  610. auto &dest_element = *diter;
  611. auto &prev_element = *piter;
  612. // Computes arc weight.
  613. det_arc->weight = common_divisor_(det_arc->weight, dest_element.weight);
  614. if (piter != diter && dest_element.state_id == prev_element.state_id) {
  615. // Found duplicate state: sums state weight and deletes duplicate.
  616. prev_element.weight = Plus(prev_element.weight, dest_element.weight);
  617. if (!prev_element.weight.Member()) SetProperties(kError, kError);
  618. ++diter;
  619. dest_subset.erase_after(piter);
  620. } else {
  621. piter = diter;
  622. ++diter;
  623. }
  624. }
  625. // Divides out label weight from destination subset elements, quantizing to
  626. // ensure comparisons are effective.
  627. for (auto &dest_element : dest_subset) {
  628. dest_element.weight =
  629. Divide(dest_element.weight, det_arc->weight, DIVIDE_LEFT);
  630. dest_element.weight = dest_element.weight.Quantize(delta_);
  631. }
  632. }
  633. // Adds an arc from state S to the destination state associated with state
  634. // tuple in det_arc as created by GetLabelMap.
  635. void AddArc(StateId s, DetArc &&det_arc) {
  636. CacheImpl<Arc>::EmplaceArc(s, det_arc.label, det_arc.label,
  637. std::move(det_arc.weight),
  638. FindState(std::move(det_arc.dest_tuple)));
  639. }
  640. float delta_; // Quantization delta for weights.
  641. const std::vector<Weight> *in_dist_; // Distance to final NFA states.
  642. std::vector<Weight> *out_dist_; // Distance to final DFA states.
  643. static const CommonDivisor common_divisor_;
  644. std::unique_ptr<Filter> filter_;
  645. std::unique_ptr<StateTable> state_table_;
  646. };
  647. template <class Arc, class CommonDivisor, class Filter, class StateTable>
  648. const CommonDivisor DeterminizeFsaImpl<Arc, CommonDivisor, Filter,
  649. StateTable>::common_divisor_{};
  650. // Implementation of delayed determinization for transducers. Transducer
  651. // determinization is implemented by mapping the input to the Gallic semiring as
  652. // an acceptor whose weights contain the output strings and using acceptor
  653. // determinization above to determinize that acceptor.
  654. template <class Arc, GallicType G, class CommonDivisor, class Filter,
  655. class StateTable>
  656. class DeterminizeFstImpl : public DeterminizeFstImplBase<Arc> {
  657. public:
  658. using Label = typename Arc::Label;
  659. using StateId = typename Arc::StateId;
  660. using Weight = typename Arc::Weight;
  661. using ToMapper = ToGallicMapper<Arc, G>;
  662. using ToArc = typename ToMapper::ToArc;
  663. using ToFst = ArcMapFst<Arc, ToArc, ToMapper>;
  664. using FromMapper = FromGallicMapper<Arc, G>;
  665. using FromFst = ArcMapFst<ToArc, Arc, FromMapper>;
  666. using ToCommonDivisor = GallicCommonDivisor<Label, Weight, G, CommonDivisor>;
  667. using ToFilter = typename Filter::template rebind<ToArc>::Other;
  668. using ToFilterState = typename ToFilter::FilterState;
  669. using ToStateTable =
  670. typename StateTable::template rebind<ToArc, ToFilterState>::Other;
  671. using FactorIterator = GallicFactor<Label, Weight, G>;
  672. using FstImpl<Arc>::SetProperties;
  673. using DeterminizeFstImplBase<Arc>::GetFst;
  674. using CacheBaseImpl<CacheState<Arc>>::GetCacheGc;
  675. using CacheBaseImpl<CacheState<Arc>>::GetCacheLimit;
  676. DeterminizeFstImpl(
  677. const Fst<Arc> &fst,
  678. const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable> &opts)
  679. : DeterminizeFstImplBase<Arc>(fst, opts),
  680. delta_(opts.delta),
  681. subsequential_label_(opts.subsequential_label),
  682. increment_subsequential_label_(opts.increment_subsequential_label) {
  683. if (opts.state_table) {
  684. FSTERROR() << "DeterminizeFst: "
  685. << "A state table can not be passed with transducer input";
  686. SetProperties(kError, kError);
  687. return;
  688. }
  689. // Takes ownership of filter.
  690. Init(GetFst(), fst::WrapUnique(opts.filter));
  691. }
  692. DeterminizeFstImpl(const DeterminizeFstImpl &impl)
  693. : DeterminizeFstImplBase<Arc>(impl),
  694. delta_(impl.delta_),
  695. subsequential_label_(impl.subsequential_label_),
  696. increment_subsequential_label_(impl.increment_subsequential_label_) {
  697. Init(GetFst(), nullptr);
  698. }
  699. DeterminizeFstImpl *Copy() const override {
  700. return new DeterminizeFstImpl(*this);
  701. }
  702. uint64_t Properties() const override { return Properties(kFstProperties); }
  703. // Sets error if found, and returns other FST impl properties.
  704. uint64_t Properties(uint64_t mask) const override {
  705. if ((mask & kError) && (GetFst().Properties(kError, false) ||
  706. from_fst_->Properties(kError, false))) {
  707. SetProperties(kError, kError);
  708. }
  709. return FstImpl<Arc>::Properties(mask);
  710. }
  711. StateId ComputeStart() override { return from_fst_->Start(); }
  712. Weight ComputeFinal(StateId s) override { return from_fst_->Final(s); }
  713. void Expand(StateId s) override {
  714. for (ArcIterator<FromFst> aiter(*from_fst_, s); !aiter.Done();
  715. aiter.Next()) {
  716. CacheImpl<Arc>::PushArc(s, aiter.Value());
  717. }
  718. CacheImpl<Arc>::SetArcs(s);
  719. }
  720. private:
  721. // Initialization of transducer determinization implementation, which is
  722. // defined after DeterminizeFst since it calls it.
  723. void Init(const Fst<Arc> &fst, std::unique_ptr<Filter> filter);
  724. float delta_;
  725. Label subsequential_label_;
  726. bool increment_subsequential_label_;
  727. std::unique_ptr<FromFst> from_fst_;
  728. };
  729. } // namespace internal
  730. // Determinizes a weighted transducer. This version is a delayed
  731. // FST. The result will be an equivalent FST that has the property
  732. // that no state has two transitions with the same input label.
  733. // For this algorithm, epsilon transitions are treated as regular
  734. // symbols (cf. RmEpsilon).
  735. //
  736. // The transducer must be functional. The weights must be (weakly) left
  737. // divisible (valid for TropicalWeight and LogWeight for instance) and be
  738. // zero-sum-free if for all a, b: (Plus(a, b) == 0) => a = b = 0.
  739. //
  740. // Complexity:
  741. //
  742. // Determinizable: exponential (polynomial in the size of the output).
  743. // Non-determinizable: does not terminate.
  744. //
  745. // The determinizable automata include all unweighted and all acyclic input.
  746. //
  747. // For more information, see:
  748. //
  749. // Mohri, M. 1997. Finite-state transducers in language and speech processing.
  750. // Computational Linguistics 23(2): 269-311.
  751. //
  752. // This class attaches interface to implementation and handles reference
  753. // counting, delegating most methods to ImplToFst.
  754. template <class A>
  755. class DeterminizeFst : public ImplToFst<internal::DeterminizeFstImplBase<A>> {
  756. public:
  757. using Arc = A;
  758. using Label = typename Arc::Label;
  759. using StateId = typename Arc::StateId;
  760. using Weight = typename Arc::Weight;
  761. using Store = DefaultCacheStore<Arc>;
  762. using State = typename Store::State;
  763. using Impl = internal::DeterminizeFstImplBase<Arc>;
  764. friend class ArcIterator<DeterminizeFst<Arc>>;
  765. friend class StateIterator<DeterminizeFst<Arc>>;
  766. template <class B, GallicType G, class CommonDivisor, class Filter,
  767. class StateTable>
  768. friend class DeterminizeFstImpl;
  769. explicit DeterminizeFst(const Fst<A> &fst)
  770. : ImplToFst<Impl>(CreateImpl(fst)) {}
  771. template <class CommonDivisor, class Filter, class StateTable>
  772. explicit DeterminizeFst(
  773. const Fst<Arc> &fst,
  774. const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>
  775. &opts =
  776. DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>())
  777. : ImplToFst<Impl>(CreateImpl(fst, opts)) {}
  778. // This acceptor-only version additionally computes the distance to final
  779. // states in the output if provided with those distances for the input; this
  780. // is useful for e.g., computing the k-shortest unique paths.
  781. template <class CommonDivisor, class Filter, class StateTable>
  782. DeterminizeFst(
  783. const Fst<Arc> &fst, const std::vector<Weight> *in_dist,
  784. std::vector<Weight> *out_dist,
  785. const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>
  786. &opts =
  787. DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>())
  788. : ImplToFst<Impl>(
  789. std::make_shared<internal::DeterminizeFsaImpl<Arc, CommonDivisor,
  790. Filter, StateTable>>(
  791. fst, in_dist, out_dist, opts)) {
  792. if (!fst.Properties(kAcceptor, true)) {
  793. FSTERROR() << "DeterminizeFst: "
  794. << "Distance to final states computed for acceptors only";
  795. GetMutableImpl()->SetProperties(kError, kError);
  796. }
  797. }
  798. // See Fst<>::Copy() for doc.
  799. DeterminizeFst(const DeterminizeFst &fst, bool safe = false)
  800. : ImplToFst<Impl>(safe ? std::shared_ptr<Impl>(fst.GetImpl()->Copy())
  801. : fst.GetSharedImpl()) {}
  802. // Get a copy of this DeterminizeFst. See Fst<>::Copy() for further doc.
  803. DeterminizeFst *Copy(bool safe = false) const override {
  804. return new DeterminizeFst(*this, safe);
  805. }
  806. inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
  807. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
  808. GetMutableImpl()->InitArcIterator(s, data);
  809. }
  810. private:
  811. using ImplToFst<Impl>::GetImpl;
  812. using ImplToFst<Impl>::GetMutableImpl;
  813. static std::shared_ptr<Impl> CreateImpl(const Fst<Arc> &fst) {
  814. using D = DefaultCommonDivisor<Weight>;
  815. using F = DefaultDeterminizeFilter<Arc>;
  816. using T = DefaultDeterminizeStateTable<Arc, typename F::FilterState>;
  817. const DeterminizeFstOptions<Arc, D, F, T> opts;
  818. return CreateImpl(fst, opts);
  819. }
  820. template <class CommonDivisor, class Filter, class StateTable>
  821. static std::shared_ptr<Impl> CreateImpl(
  822. const Fst<Arc> &fst,
  823. const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>
  824. &opts) {
  825. if (fst.Properties(kAcceptor, true)) {
  826. // Calls implementation for acceptors.
  827. return std::make_shared<
  828. internal::DeterminizeFsaImpl<Arc, CommonDivisor, Filter, StateTable>>(
  829. fst, nullptr, nullptr, opts);
  830. } else if (opts.type == DETERMINIZE_DISAMBIGUATE) {
  831. if constexpr (IsPath<Weight>::value) {
  832. // Calls disambiguating implementation for non-functional transducers.
  833. return std::make_shared<internal::DeterminizeFstImpl<
  834. Arc, GALLIC_MIN, CommonDivisor, Filter, StateTable>>(fst, opts);
  835. } else {
  836. FSTERROR() << "DeterminizeFst: Weight needs to have the path "
  837. << "property to disambiguate output: " << Weight::Type();
  838. // Return an error Impl.
  839. const ConstFst<Arc> empty_fst;
  840. auto rv = std::make_shared<internal::DeterminizeFstImpl<
  841. Arc, GALLIC, CommonDivisor, Filter, StateTable>>(empty_fst, opts);
  842. rv->SetProperties(kError, kError);
  843. return rv;
  844. }
  845. } else if (opts.type == DETERMINIZE_FUNCTIONAL) {
  846. // Calls implementation for functional transducers.
  847. return std::make_shared<internal::DeterminizeFstImpl<
  848. Arc, GALLIC_RESTRICT, CommonDivisor, Filter, StateTable>>(fst, opts);
  849. } else { // opts.type == DETERMINIZE_NONFUNCTIONAL
  850. // Calls implementation for non functional transducers;
  851. return std::make_shared<internal::DeterminizeFstImpl<
  852. Arc, GALLIC, CommonDivisor, Filter, StateTable>>(fst, opts);
  853. }
  854. }
  855. DeterminizeFst &operator=(const DeterminizeFst &) = delete;
  856. };
  857. namespace internal {
  858. // Initialization of transducer determinization implementation, which is defined
  859. // after DeterminizeFst since it calls it.
  860. template <class A, GallicType G, class D, class F, class T>
  861. void DeterminizeFstImpl<A, G, D, F, T>::Init(const Fst<A> &fst,
  862. std::unique_ptr<F> filter) {
  863. // Mapper to an acceptor.
  864. const ToFst to_fst(fst);
  865. auto *to_filter = filter ? new ToFilter(to_fst, std::move(filter)) : nullptr;
  866. // This recursive call terminates since it is to a (non-recursive)
  867. // different constructor.
  868. const CacheOptions copts(GetCacheGc(), GetCacheLimit());
  869. const DeterminizeFstOptions<ToArc, ToCommonDivisor, ToFilter, ToStateTable>
  870. dopts(copts, delta_, 0, DETERMINIZE_FUNCTIONAL, false, to_filter);
  871. // Uses acceptor-only constructor to avoid template recursion.
  872. const DeterminizeFst<ToArc> det_fsa(to_fst, nullptr, nullptr, dopts);
  873. // Mapper back to transducer.
  874. const FactorWeightOptions<ToArc> fopts(
  875. CacheOptions(true, 0), delta_, kFactorFinalWeights, subsequential_label_,
  876. subsequential_label_, increment_subsequential_label_,
  877. increment_subsequential_label_);
  878. const FactorWeightFst<ToArc, FactorIterator> factored_fst(det_fsa, fopts);
  879. from_fst_ =
  880. std::make_unique<FromFst>(factored_fst, FromMapper(subsequential_label_));
  881. }
  882. } // namespace internal
  883. // Specialization for DeterminizeFst.
  884. template <class Arc>
  885. class StateIterator<DeterminizeFst<Arc>>
  886. : public CacheStateIterator<DeterminizeFst<Arc>> {
  887. public:
  888. explicit StateIterator(const DeterminizeFst<Arc> &fst)
  889. : CacheStateIterator<DeterminizeFst<Arc>>(fst, fst.GetMutableImpl()) {}
  890. };
  891. // Specialization for DeterminizeFst.
  892. template <class Arc>
  893. class ArcIterator<DeterminizeFst<Arc>>
  894. : public CacheArcIterator<DeterminizeFst<Arc>> {
  895. public:
  896. using StateId = typename Arc::StateId;
  897. ArcIterator(const DeterminizeFst<Arc> &fst, StateId s)
  898. : CacheArcIterator<DeterminizeFst<Arc>>(fst.GetMutableImpl(), s) {
  899. if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
  900. }
  901. };
  902. template <class Arc>
  903. inline void DeterminizeFst<Arc>::InitStateIterator(
  904. StateIteratorData<Arc> *data) const {
  905. data->base = std::make_unique<StateIterator<DeterminizeFst<Arc>>>(*this);
  906. }
  907. // Useful alias when using StdArc.
  908. using StdDeterminizeFst = DeterminizeFst<StdArc>;
  909. template <class Arc>
  910. struct DeterminizeOptions {
  911. using Label = typename Arc::Label;
  912. using StateId = typename Arc::StateId;
  913. using Weight = typename Arc::Weight;
  914. float delta; // Quantization delta for subset weights.
  915. Weight weight_threshold; // Pruning weight threshold.
  916. StateId state_threshold; // Pruning state threshold.
  917. Label subsequential_label; // Label used for residual final output.
  918. DeterminizeType type;
  919. bool increment_subsequential_label; // When creating several subsequential
  920. // arcs at a given state, make their
  921. // label distinct by incrementation?
  922. explicit DeterminizeOptions(float delta = kDelta,
  923. Weight weight_threshold = Weight::Zero(),
  924. StateId state_threshold = kNoStateId,
  925. Label subsequential_label = 0,
  926. DeterminizeType type = DETERMINIZE_FUNCTIONAL,
  927. bool increment_subsequential_label = false)
  928. : delta(delta),
  929. weight_threshold(std::move(weight_threshold)),
  930. state_threshold(state_threshold),
  931. subsequential_label(subsequential_label),
  932. type(type),
  933. increment_subsequential_label(increment_subsequential_label) {}
  934. };
  935. // Determinizes a weighted transducer. This version writes the
  936. // determinized Fst to an output MutableFst. The result will be an
  937. // equivalent FST that has the property that no state has two
  938. // transitions with the same input label. For this algorithm, epsilon
  939. // transitions are treated as regular symbols (cf. RmEpsilon).
  940. //
  941. // The transducer must be functional. The weights must be (weakly)
  942. // left divisible (valid for TropicalWeight and LogWeight).
  943. //
  944. // Complexity:
  945. //
  946. // Determinizable: exponential (polynomial in the size of the output)
  947. // Non-determinizable: does not terminate
  948. //
  949. // The determinizable automata include all unweighted and all acyclic input.
  950. template <class Arc>
  951. void Determinize(
  952. const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
  953. const DeterminizeOptions<Arc> &opts = DeterminizeOptions<Arc>()) {
  954. using Weight = typename Arc::Weight;
  955. DeterminizeFstOptions<Arc> nopts;
  956. nopts.delta = opts.delta;
  957. nopts.subsequential_label = opts.subsequential_label;
  958. nopts.type = opts.type;
  959. nopts.increment_subsequential_label = opts.increment_subsequential_label;
  960. nopts.gc_limit = 0; // Caches only the last state for fastest copy.
  961. if (opts.weight_threshold != Weight::Zero() ||
  962. opts.state_threshold != kNoStateId) {
  963. if constexpr (IsPath<Weight>::value) {
  964. if (ifst.Properties(kAcceptor, false)) {
  965. std::vector<Weight> idistance;
  966. std::vector<Weight> odistance;
  967. ShortestDistance(ifst, &idistance, true);
  968. DeterminizeFst<Arc> dfst(ifst, &idistance, &odistance, nopts);
  969. PruneOptions<Arc, AnyArcFilter<Arc>> popts(
  970. opts.weight_threshold, opts.state_threshold, AnyArcFilter<Arc>(),
  971. &odistance);
  972. Prune(dfst, ofst, popts);
  973. } else {
  974. *ofst = DeterminizeFst<Arc>(ifst, nopts);
  975. Prune(ofst, opts.weight_threshold, opts.state_threshold);
  976. }
  977. } else {
  978. FSTERROR() << "Determinize: Weight needs to have the path "
  979. << "property to use pruning options: " << Weight::Type();
  980. ofst->SetProperties(kError, kError);
  981. }
  982. } else {
  983. *ofst = DeterminizeFst<Arc>(ifst, nopts);
  984. }
  985. }
  986. } // namespace fst
  987. #endif // FST_DETERMINIZE_H_