You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

580 lines
19 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  // Functions and classes that implement epsilon-removal.
  19. #ifndef FST_RMEPSILON_H_
  20. #define FST_RMEPSILON_H_
  21. #include <cstddef>
  22. #include <cstdint>
  23. #include <memory>
  24. #include <stack>
  25. #include <string>
  26. #include <utility>
  27. #include <vector>
  28. #include <fst/log.h>
  29. #include <fst/arc.h>
  30. #include <fst/arcfilter.h>
  31. #include <fst/cache.h>
  32. #include <fst/cc-visitors.h>
  33. #include <fst/connect.h>
  34. #include <fst/dfs-visit.h>
  35. #include <fst/factor-weight.h>
  36. #include <fst/float-weight.h>
  37. #include <fst/fst.h>
  38. #include <fst/impl-to-fst.h>
  39. #include <fst/invert.h>
  40. #include <fst/mutable-fst.h>
  41. #include <fst/properties.h>
  42. #include <fst/prune.h>
  43. #include <fst/queue.h>
  44. #include <fst/shortest-distance.h>
  45. #include <fst/topsort.h>
  46. #include <fst/util.h>
  47. #include <fst/weight.h>
  48. #include <unordered_map>
  49. namespace fst {
  50. template <class Arc, class Queue>
  51. struct RmEpsilonOptions
  52. : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc>> {
  53. using StateId = typename Arc::StateId;
  54. using Weight = typename Arc::Weight;
  55. bool connect; // Connect output
  56. Weight weight_threshold; // Pruning weight threshold.
  57. StateId state_threshold; // Pruning state threshold.
  58. explicit RmEpsilonOptions(Queue *queue, float delta = kShortestDelta,
  59. bool connect = true,
  60. Weight weight_threshold = Weight::Zero(),
  61. StateId state_threshold = kNoStateId)
  62. : ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc>>(
  63. queue, EpsilonArcFilter<Arc>(), kNoStateId, delta),
  64. connect(connect),
  65. weight_threshold(std::move(weight_threshold)),
  66. state_threshold(state_threshold) {}
  67. };
  68. namespace internal {
  69. // Computation state of the epsilon-removal algorithm.
  70. template <class Arc, class Queue>
  71. class RmEpsilonState {
  72. public:
  73. using Label = typename Arc::Label;
  74. using StateId = typename Arc::StateId;
  75. using Weight = typename Arc::Weight;
  76. RmEpsilonState(const Fst<Arc> &fst, std::vector<Weight> *distance,
  77. const RmEpsilonOptions<Arc, Queue> &opts)
  78. : fst_(fst),
  79. distance_(distance),
  80. sd_state_(fst_, distance, opts, true),
  81. expand_id_(0) {}
  82. void Expand(StateId s);
  83. std::vector<Arc> &Arcs() { return arcs_; }
  84. const Weight &Final() const { return final_weight_; }
  85. bool Error() const { return sd_state_.Error(); }
  86. private:
  87. struct Element {
  88. Label ilabel;
  89. Label olabel;
  90. StateId nextstate;
  91. Element() = default;
  92. Element(Label ilabel, Label olabel, StateId nexstate)
  93. : ilabel(ilabel), olabel(olabel), nextstate(nexstate) {}
  94. };
  95. struct ElementHash {
  96. public:
  97. size_t operator()(const Element &element) const {
  98. static constexpr size_t prime0 = 7853;
  99. static constexpr size_t prime1 = 7867;
  100. return static_cast<size_t>(element.nextstate) +
  101. static_cast<size_t>(element.ilabel) * prime0 +
  102. static_cast<size_t>(element.olabel) * prime1;
  103. }
  104. };
  105. class ElementEqual {
  106. public:
  107. bool operator()(const Element &e1, const Element &e2) const {
  108. return (e1.ilabel == e2.ilabel) && (e1.olabel == e2.olabel) &&
  109. (e1.nextstate == e2.nextstate);
  110. }
  111. };
  112. using ElementMap = std::unordered_map<Element, std::pair<StateId, size_t>,
  113. ElementHash, ElementEqual>;
  114. const Fst<Arc> &fst_;
  115. // Distance from state being expanded in epsilon-closure.
  116. std::vector<Weight> *distance_;
  117. // Shortest distance algorithm computation state.
  118. internal::ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc>> sd_state_;
  119. // Maps an element to a pair corresponding to a position in the arcs vector
  120. // of the state being expanded. The element corresopnds to the position in
  121. // the arcs_ vector if p.first is equal to the state being expanded.
  122. ElementMap element_map_;
  123. EpsilonArcFilter<Arc> eps_filter_;
  124. std::stack<StateId, std::vector<StateId>>
  125. eps_queue_; // Queue used to visit the epsilon-closure.
  126. std::vector<bool> visited_; // True if the state has been visited.
  127. std::vector<StateId> visited_states_; // List of visited states.
  128. std::vector<Arc> arcs_; // Arcs of state being expanded.
  129. Weight final_weight_; // Final weight of state being expanded.
  130. StateId expand_id_; // Unique ID for each call to Expand
  131. RmEpsilonState(const RmEpsilonState &) = delete;
  132. RmEpsilonState &operator=(const RmEpsilonState &) = delete;
  133. };
  134. template <class Arc, class Queue>
  135. void RmEpsilonState<Arc, Queue>::Expand(typename Arc::StateId source) {
  136. final_weight_ = Weight::Zero();
  137. arcs_.clear();
  138. sd_state_.ShortestDistance(source);
  139. if (sd_state_.Error()) return;
  140. eps_queue_.push(source);
  141. while (!eps_queue_.empty()) {
  142. const auto state = eps_queue_.top();
  143. eps_queue_.pop();
  144. if (static_cast<decltype(state)>(visited_.size()) <= state) {
  145. visited_.resize(state + 1, false);
  146. }
  147. if (visited_[state]) continue;
  148. visited_[state] = true;
  149. visited_states_.push_back(state);
  150. for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
  151. aiter.Next()) {
  152. auto arc = aiter.Value();
  153. arc.weight = Times((*distance_)[state], arc.weight);
  154. if (eps_filter_(arc)) {
  155. if (static_cast<decltype(arc.nextstate)>(visited_.size()) <=
  156. arc.nextstate) {
  157. visited_.resize(arc.nextstate + 1, false);
  158. }
  159. if (!visited_[arc.nextstate]) eps_queue_.push(arc.nextstate);
  160. } else if (auto [insert_it, success] = element_map_.emplace(
  161. Element(arc.ilabel, arc.olabel, arc.nextstate),
  162. std::make_pair(expand_id_, arcs_.size()));
  163. success) {
  164. arcs_.push_back(std::move(arc));
  165. } else if (auto &[xid, arc_idx] = insert_it->second; xid == expand_id_) {
  166. auto &weight = arcs_[arc_idx].weight;
  167. weight = Plus(weight, arc.weight);
  168. } else {
  169. xid = expand_id_;
  170. arc_idx = arcs_.size();
  171. arcs_.push_back(std::move(arc));
  172. }
  173. }
  174. final_weight_ =
  175. Plus(final_weight_, Times((*distance_)[state], fst_.Final(state)));
  176. }
  177. for (const auto state_id : visited_states_) visited_[state_id] = false;
  178. visited_states_.clear();
  179. ++expand_id_;
  180. }
  181. } // namespace internal
  182. // Removes epsilon-transitions (when both the input and output label are an
  183. // epsilon) from a transducer. The result will be an equivalent FST that has no
  184. // such epsilon transitions. This version modifies its input. It allows fine
  185. // control via the options argument; see below for a simpler interface.
  186. //
  187. // The distance vector will be used to hold the shortest distances during the
  188. // epsilon-closure computation. The state queue discipline and convergence delta
  189. // are taken in the options argument.
  190. template <class Arc, class Queue>
  191. void RmEpsilon(MutableFst<Arc> *fst,
  192. std::vector<typename Arc::Weight> *distance,
  193. const RmEpsilonOptions<Arc, Queue> &opts) {
  194. using StateId = typename Arc::StateId;
  195. using Weight = typename Arc::Weight;
  196. if (fst->Start() == kNoStateId) return;
  197. // noneps_in[s] will be set to true iff s admits a non-epsilon incoming
  198. // transition or is the start state.
  199. std::vector<bool> noneps_in(fst->NumStates(), false);
  200. noneps_in[fst->Start()] = true;
  201. for (size_t i = 0; i < fst->NumStates(); ++i) {
  202. for (ArcIterator<Fst<Arc>> aiter(*fst, i); !aiter.Done(); aiter.Next()) {
  203. const auto &arc = aiter.Value();
  204. if (arc.ilabel != 0 || arc.olabel != 0) {
  205. noneps_in[arc.nextstate] = true;
  206. }
  207. }
  208. }
  209. // States sorted in topological order when (acyclic) or generic topological
  210. // order (cyclic).
  211. std::vector<StateId> states;
  212. states.reserve(fst->NumStates());
  213. if (fst->Properties(kTopSorted, false) & kTopSorted) {
  214. for (size_t i = 0; i < fst->NumStates(); i++) states.push_back(i);
  215. } else if (fst->Properties(kAcyclic, false) & kAcyclic) {
  216. std::vector<StateId> order;
  217. bool acyclic;
  218. TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
  219. DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>());
  220. // Sanity check: should be acyclic if property bit is set.
  221. if (!acyclic) {
  222. FSTERROR() << "RmEpsilon: Inconsistent acyclic property bit";
  223. fst->SetProperties(kError, kError);
  224. return;
  225. }
  226. states.resize(order.size());
  227. for (StateId i = 0; i < order.size(); i++) states[order[i]] = i;
  228. } else {
  229. uint64_t props;
  230. std::vector<StateId> scc;
  231. SccVisitor<Arc> scc_visitor(&scc, nullptr, nullptr, &props);
  232. DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>());
  233. std::vector<StateId> first(scc.size(), kNoStateId);
  234. std::vector<StateId> next(scc.size(), kNoStateId);
  235. for (StateId i = 0; i < scc.size(); i++) {
  236. if (first[scc[i]] != kNoStateId) next[i] = first[scc[i]];
  237. first[scc[i]] = i;
  238. }
  239. for (StateId i = 0; i < first.size(); i++) {
  240. for (auto j = first[i]; j != kNoStateId; j = next[j]) {
  241. states.push_back(j);
  242. }
  243. }
  244. }
  245. internal::RmEpsilonState<Arc, Queue> rmeps_state(*fst, distance, opts);
  246. while (!states.empty()) {
  247. const auto state = states.back();
  248. states.pop_back();
  249. if (!noneps_in[state] &&
  250. (opts.connect || opts.weight_threshold != Weight::Zero() ||
  251. opts.state_threshold != kNoStateId)) {
  252. continue;
  253. }
  254. rmeps_state.Expand(state);
  255. fst->SetFinal(state, rmeps_state.Final());
  256. fst->DeleteArcs(state);
  257. auto &arcs = rmeps_state.Arcs();
  258. fst->ReserveArcs(state, arcs.size());
  259. while (!arcs.empty()) {
  260. fst->AddArc(state, arcs.back());
  261. arcs.pop_back();
  262. }
  263. }
  264. if (opts.connect || opts.weight_threshold != Weight::Zero() ||
  265. opts.state_threshold != kNoStateId) {
  266. for (size_t s = 0; s < fst->NumStates(); ++s) {
  267. if (!noneps_in[s]) fst->DeleteArcs(s);
  268. }
  269. }
  270. if (rmeps_state.Error()) fst->SetProperties(kError, kError);
  271. fst->SetProperties(
  272. RmEpsilonProperties(fst->Properties(kFstProperties, false)),
  273. kFstProperties);
  274. if (opts.weight_threshold != Weight::Zero() ||
  275. opts.state_threshold != kNoStateId) {
  276. if constexpr (IsPath<Weight>::value) {
  277. Prune(fst, opts.weight_threshold, opts.state_threshold);
  278. } else {
  279. FSTERROR() << "RmEpsilon: Weight must have path property: "
  280. << Weight::Type();
  281. fst->SetProperties(kError, kError);
  282. return;
  283. }
  284. }
  285. if (opts.connect && opts.weight_threshold == Weight::Zero() &&
  286. opts.state_threshold == kNoStateId) {
  287. Connect(fst);
  288. }
  289. }
  290. // Removes epsilon-transitions (when both the input and output label
  291. // are an epsilon) from a transducer. The result will be an equivalent
  292. // FST that has no such epsilon transitions. This version modifies its
  293. // input. It has a simplified interface; see above for a version that
  294. // allows finer control.
  295. //
  296. // Complexity:
  297. //
  298. // - Time:
  299. //
  300. // Unweighted: O(v^2 + ve).
  301. // Acyclic: O(v^2 + V e).
  302. // Tropical semiring: O(v^2 log V + ve).
  303. // General: exponential.
  304. //
  305. // - Space: O(vE)
  306. //
  307. // where v is the number of states visited and e is the number of arcs visited.
  308. //
  309. // For more information, see:
  310. //
  311. // Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization
  312. // algorithms for weighted transducers. International Journal of Computer
  313. // Science 13(1): 129-143.
  314. template <class Arc>
  315. void RmEpsilon(MutableFst<Arc> *fst, bool connect = true,
  316. typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
  317. typename Arc::StateId state_threshold = kNoStateId,
  318. float delta = kShortestDelta) {
  319. using StateId = typename Arc::StateId;
  320. using Weight = typename Arc::Weight;
  321. std::vector<Weight> distance;
  322. AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>());
  323. RmEpsilonOptions<Arc, AutoQueue<StateId>> opts(
  324. &state_queue, delta, connect, weight_threshold, state_threshold);
  325. RmEpsilon(fst, &distance, opts);
  326. }
  327. struct RmEpsilonFstOptions : CacheOptions {
  328. float delta;
  329. explicit RmEpsilonFstOptions(const CacheOptions &opts,
  330. float delta = kShortestDelta)
  331. : CacheOptions(opts), delta(delta) {}
  332. explicit RmEpsilonFstOptions(float delta = kShortestDelta) : delta(delta) {}
  333. };
  334. namespace internal {
  335. // Implementation of delayed RmEpsilonFst.
  336. template <class Arc>
  337. class RmEpsilonFstImpl : public CacheImpl<Arc> {
  338. public:
  339. using StateId = typename Arc::StateId;
  340. using Weight = typename Arc::Weight;
  341. using Store = DefaultCacheStore<Arc>;
  342. using State = typename Store::State;
  343. using FstImpl<Arc>::Properties;
  344. using FstImpl<Arc>::SetType;
  345. using FstImpl<Arc>::SetProperties;
  346. using FstImpl<Arc>::SetInputSymbols;
  347. using FstImpl<Arc>::SetOutputSymbols;
  348. using CacheBaseImpl<CacheState<Arc>>::HasArcs;
  349. using CacheBaseImpl<CacheState<Arc>>::HasFinal;
  350. using CacheBaseImpl<CacheState<Arc>>::HasStart;
  351. using CacheBaseImpl<CacheState<Arc>>::PushArc;
  352. using CacheBaseImpl<CacheState<Arc>>::SetArcs;
  353. using CacheBaseImpl<CacheState<Arc>>::SetFinal;
  354. using CacheBaseImpl<CacheState<Arc>>::SetStart;
  355. RmEpsilonFstImpl(const Fst<Arc> &fst, const RmEpsilonFstOptions &opts)
  356. : CacheImpl<Arc>(opts),
  357. fst_(fst.Copy()),
  358. delta_(opts.delta),
  359. rmeps_state_(
  360. *fst_, &distance_,
  361. RmEpsilonOptions<Arc, FifoQueue<StateId>>(&queue_, delta_, false)) {
  362. SetType("rmepsilon");
  363. SetProperties(
  364. RmEpsilonProperties(fst.Properties(kFstProperties, false), true),
  365. kCopyProperties);
  366. SetInputSymbols(fst.InputSymbols());
  367. SetOutputSymbols(fst.OutputSymbols());
  368. }
  369. RmEpsilonFstImpl(const RmEpsilonFstImpl &impl)
  370. : CacheImpl<Arc>(impl),
  371. fst_(impl.fst_->Copy(true)),
  372. delta_(impl.delta_),
  373. rmeps_state_(
  374. *fst_, &distance_,
  375. RmEpsilonOptions<Arc, FifoQueue<StateId>>(&queue_, delta_, false)) {
  376. SetType("rmepsilon");
  377. SetProperties(impl.Properties(), kCopyProperties);
  378. SetInputSymbols(impl.InputSymbols());
  379. SetOutputSymbols(impl.OutputSymbols());
  380. }
  381. StateId Start() {
  382. if (!HasStart()) SetStart(fst_->Start());
  383. return CacheImpl<Arc>::Start();
  384. }
  385. Weight Final(StateId s) {
  386. if (!HasFinal(s)) Expand(s);
  387. return CacheImpl<Arc>::Final(s);
  388. }
  389. size_t NumArcs(StateId s) {
  390. if (!HasArcs(s)) Expand(s);
  391. return CacheImpl<Arc>::NumArcs(s);
  392. }
  393. size_t NumInputEpsilons(StateId s) {
  394. if (!HasArcs(s)) Expand(s);
  395. return CacheImpl<Arc>::NumInputEpsilons(s);
  396. }
  397. size_t NumOutputEpsilons(StateId s) {
  398. if (!HasArcs(s)) Expand(s);
  399. return CacheImpl<Arc>::NumOutputEpsilons(s);
  400. }
  401. uint64_t Properties() const override { return Properties(kFstProperties); }
  402. // Sets error if found and returns other FST impl properties.
  403. uint64_t Properties(uint64_t mask) const override {
  404. if ((mask & kError) &&
  405. (fst_->Properties(kError, false) || rmeps_state_.Error())) {
  406. SetProperties(kError, kError);
  407. }
  408. return FstImpl<Arc>::Properties(mask);
  409. }
  410. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
  411. if (!HasArcs(s)) Expand(s);
  412. CacheImpl<Arc>::InitArcIterator(s, data);
  413. }
  414. void Expand(StateId s) {
  415. rmeps_state_.Expand(s);
  416. SetFinal(s, rmeps_state_.Final());
  417. auto &arcs = rmeps_state_.Arcs();
  418. while (!arcs.empty()) {
  419. PushArc(s, std::move(arcs.back()));
  420. arcs.pop_back();
  421. }
  422. SetArcs(s);
  423. }
  424. private:
  425. std::unique_ptr<const Fst<Arc>> fst_;
  426. float delta_;
  427. std::vector<Weight> distance_;
  428. FifoQueue<StateId> queue_;
  429. internal::RmEpsilonState<Arc, FifoQueue<StateId>> rmeps_state_;
  430. };
  431. } // namespace internal
  432. // Removes epsilon-transitions (when both the input and output label are an
  433. // epsilon) from a transducer. The result will be an equivalent FST that has no
  434. // such epsilon transitions. This version is a
  435. // delayed FST.
  436. //
  437. // Complexity:
  438. //
  439. // - Time:
  440. // Unweighted: O(v^2 + ve).
  441. // General: exponential.
  442. //
  443. // - Space: O(vE)
  444. //
  445. // where v is the number of states visited and e is the number of arcs visited.
  446. // Constant time to visit an input state or arc is assumed and exclusive of
  447. // caching.
  448. //
  449. // For more information, see:
  450. //
  451. // Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization
  452. // algorithms for weighted transducers. International Journal of Computer
  453. // Science 13(1): 129-143.
  454. //
  455. // This class attaches interface to implementation and handles
  456. // reference counting, delegating most methods to ImplToFst.
  457. template <class A>
  458. class RmEpsilonFst : public ImplToFst<internal::RmEpsilonFstImpl<A>> {
  459. public:
  460. using Arc = A;
  461. using StateId = typename Arc::StateId;
  462. using Store = DefaultCacheStore<Arc>;
  463. using State = typename Store::State;
  464. using Impl = internal::RmEpsilonFstImpl<Arc>;
  465. friend class ArcIterator<RmEpsilonFst<Arc>>;
  466. friend class StateIterator<RmEpsilonFst<Arc>>;
  467. explicit RmEpsilonFst(const Fst<Arc> &fst)
  468. : ImplToFst<Impl>(std::make_shared<Impl>(fst, RmEpsilonFstOptions())) {}
  469. RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts)
  470. : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
  471. // See Fst<>::Copy() for doc.
  472. RmEpsilonFst(const RmEpsilonFst &fst, bool safe = false)
  473. : ImplToFst<Impl>(fst, safe) {}
  474. // Get a copy of this RmEpsilonFst. See Fst<>::Copy() for further doc.
  475. RmEpsilonFst *Copy(bool safe = false) const override {
  476. return new RmEpsilonFst(*this, safe);
  477. }
  478. inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
  479. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
  480. GetMutableImpl()->InitArcIterator(s, data);
  481. }
  482. private:
  483. using ImplToFst<Impl>::GetImpl;
  484. using ImplToFst<Impl>::GetMutableImpl;
  485. RmEpsilonFst &operator=(const RmEpsilonFst &) = delete;
  486. };
  487. // Specialization for RmEpsilonFst.
  488. template <class Arc>
  489. class StateIterator<RmEpsilonFst<Arc>>
  490. : public CacheStateIterator<RmEpsilonFst<Arc>> {
  491. public:
  492. explicit StateIterator(const RmEpsilonFst<Arc> &fst)
  493. : CacheStateIterator<RmEpsilonFst<Arc>>(fst, fst.GetMutableImpl()) {}
  494. };
  495. // Specialization for RmEpsilonFst.
  496. template <class Arc>
  497. class ArcIterator<RmEpsilonFst<Arc>>
  498. : public CacheArcIterator<RmEpsilonFst<Arc>> {
  499. public:
  500. using StateId = typename Arc::StateId;
  501. ArcIterator(const RmEpsilonFst<Arc> &fst, StateId s)
  502. : CacheArcIterator<RmEpsilonFst<Arc>>(fst.GetMutableImpl(), s) {
  503. if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
  504. }
  505. };
  506. template <class Arc>
  507. inline void RmEpsilonFst<Arc>::InitStateIterator(
  508. StateIteratorData<Arc> *data) const {
  509. data->base = std::make_unique<StateIterator<RmEpsilonFst<Arc>>>(*this);
  510. }
  511. // Useful alias when using StdArc.
  512. using StdRmEpsilonFst = RmEpsilonFst<StdArc>;
  513. } // namespace fst
  514. #endif // FST_RMEPSILON_H_