You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

338 lines
13 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Functions implementing pruning.
  19. #ifndef FST_PRUNE_H_
  20. #define FST_PRUNE_H_
  21. #include <cstddef>
  22. #include <cstdlib>
  23. #include <type_traits>
  24. #include <utility>
  25. #include <vector>
  26. #include <fst/log.h>
  27. #include <fst/arcfilter.h>
  28. #include <fst/fst.h>
  29. #include <fst/heap.h>
  30. #include <fst/mutable-fst.h>
  31. #include <fst/shortest-distance.h>
  32. #include <fst/weight.h>
  33. namespace fst {
  34. namespace internal {
  35. template <class StateId, class Weight>
  36. class PruneCompare {
  37. public:
  38. PruneCompare(const std::vector<Weight> &idistance,
  39. const std::vector<Weight> &fdistance)
  40. : idistance_(idistance), fdistance_(fdistance) {}
  41. bool operator()(const StateId x, const StateId y) const {
  42. const auto wx = Times(IDistance(x), FDistance(x));
  43. const auto wy = Times(IDistance(y), FDistance(y));
  44. return less_(wx, wy);
  45. }
  46. private:
  47. Weight IDistance(const StateId s) const {
  48. return s < idistance_.size() ? idistance_[s] : Weight::Zero();
  49. }
  50. Weight FDistance(const StateId s) const {
  51. return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
  52. }
  53. const std::vector<Weight> &idistance_;
  54. const std::vector<Weight> &fdistance_;
  55. NaturalLess<Weight> less_;
  56. };
  57. } // namespace internal
  58. template <class Arc, class ArcFilter>
  59. struct PruneOptions {
  60. using StateId = typename Arc::StateId;
  61. using Weight = typename Arc::Weight;
  62. explicit PruneOptions(const Weight &weight_threshold = Weight::Zero(),
  63. StateId state_threshold = kNoStateId,
  64. ArcFilter filter = ArcFilter(),
  65. std::vector<Weight> *distance = nullptr,
  66. float delta = kDelta, bool threshold_initial = false)
  67. : weight_threshold(std::move(weight_threshold)),
  68. state_threshold(state_threshold),
  69. filter(std::move(filter)),
  70. distance(distance),
  71. delta(delta),
  72. threshold_initial(threshold_initial) {}
  73. // Pruning weight threshold.
  74. Weight weight_threshold;
  75. // Pruning state threshold.
  76. StateId state_threshold;
  77. // Arc filter.
  78. ArcFilter filter;
  79. // If non-zero, passes in pre-computed shortest distance to final states.
  80. const std::vector<Weight> *distance;
  81. // Determines the degree of convergence required when computing shortest
  82. // distances.
  83. float delta;
  84. // Determines if the shortest path weight is left (true) or right
  85. // (false) multiplied by the threshold to get the limit for
  86. // keeping a state or arc (matters if the semiring is not
  87. // commutative).
  88. bool threshold_initial;
  89. };
  90. // Pruning algorithm: this version modifies its input and it takes an options
  91. // class as an argument. After pruning the FST contains states and arcs that
  92. // belong to a successful path in the FST whose weight is no more than the
  93. // weight of the shortest path Times() the provided weight threshold. When the
  94. // state threshold is not kNoStateId, the output FST is further restricted to
  95. // have no more than the number of states in opts.state_threshold. Weights must
  96. // have the path property. The weight of any cycle needs to be bounded; i.e.,
  97. //
  98. // Plus(weight, Weight::One()) == Weight::One()
  99. template <class Arc, class ArcFilter>
  100. void Prune(MutableFst<Arc> *fst, const PruneOptions<Arc, ArcFilter> &opts =
  101. PruneOptions<Arc, ArcFilter>()) {
  102. using StateId = typename Arc::StateId;
  103. using Weight = typename Arc::Weight;
  104. static_assert(IsPath<Weight>::value, "Weight must have path property.");
  105. using StateHeap = Heap<StateId, internal::PruneCompare<StateId, Weight>>;
  106. auto ns = fst->NumStates();
  107. if (ns < 1) return;
  108. std::vector<Weight> idistance(ns, Weight::Zero());
  109. std::vector<Weight> tmp;
  110. if (!opts.distance) {
  111. tmp.reserve(ns);
  112. ShortestDistance(*fst, &tmp, true, opts.delta);
  113. }
  114. const auto *fdistance = opts.distance ? opts.distance : &tmp;
  115. if ((opts.state_threshold == 0) || (fdistance->size() <= fst->Start()) ||
  116. ((*fdistance)[fst->Start()] == Weight::Zero())) {
  117. fst->DeleteStates();
  118. return;
  119. }
  120. internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
  121. StateHeap heap(compare);
  122. std::vector<bool> visited(ns, false);
  123. std::vector<size_t> enqueued(ns, StateHeap::kNoKey);
  124. std::vector<StateId> dead;
  125. dead.push_back(fst->AddState());
  126. NaturalLess<Weight> less;
  127. auto s = fst->Start();
  128. const auto limit = opts.threshold_initial
  129. ? Times(opts.weight_threshold, (*fdistance)[s])
  130. : Times((*fdistance)[s], opts.weight_threshold);
  131. StateId num_visited = 0;
  132. if (!less(limit, (*fdistance)[s])) {
  133. idistance[s] = Weight::One();
  134. enqueued[s] = heap.Insert(s);
  135. ++num_visited;
  136. }
  137. while (!heap.Empty()) {
  138. s = heap.Top();
  139. heap.Pop();
  140. enqueued[s] = StateHeap::kNoKey;
  141. visited[s] = true;
  142. if (less(limit, Times(idistance[s], fst->Final(s)))) {
  143. fst->SetFinal(s, Weight::Zero());
  144. }
  145. for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
  146. aiter.Next()) {
  147. auto arc = aiter.Value(); // Copy intended.
  148. if (!opts.filter(arc)) continue;
  149. const auto weight =
  150. Times(Times(idistance[s], arc.weight),
  151. arc.nextstate < fdistance->size() ? (*fdistance)[arc.nextstate]
  152. : Weight::Zero());
  153. if (less(limit, weight)) {
  154. arc.nextstate = dead[0];
  155. aiter.SetValue(arc);
  156. continue;
  157. }
  158. if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
  159. idistance[arc.nextstate] = Times(idistance[s], arc.weight);
  160. }
  161. if (visited[arc.nextstate]) continue;
  162. if ((opts.state_threshold != kNoStateId) &&
  163. (num_visited >= opts.state_threshold)) {
  164. continue;
  165. }
  166. if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
  167. enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
  168. ++num_visited;
  169. } else {
  170. heap.Update(enqueued[arc.nextstate], arc.nextstate);
  171. }
  172. }
  173. }
  174. for (StateId i = 0; i < visited.size(); ++i) {
  175. if (!visited[i]) dead.push_back(i);
  176. }
  177. fst->DeleteStates(dead);
  178. }
  179. // Pruning algorithm: this version modifies its input and takes the
  180. // pruning threshold as an argument. It deletes states and arcs in the
  181. // FST that do not belong to a successful path whose weight is more
  182. // than the weight of the shortest path Times() the provided weight
  183. // threshold. When the state threshold is not kNoStateId, the output
  184. // FST is further restricted to have no more than the number of states
  185. // in opts.state_threshold. Weights must have the path property. The
  186. // weight of any cycle needs to be bounded; i.e.,
  187. //
  188. // Plus(weight, Weight::One()) == Weight::One()
  189. template <class Arc>
  190. void Prune(MutableFst<Arc> *fst, typename Arc::Weight weight_threshold,
  191. typename Arc::StateId state_threshold = kNoStateId,
  192. float delta = kDelta) {
  193. const PruneOptions<Arc, AnyArcFilter<Arc>> opts(
  194. weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
  195. Prune(fst, opts);
  196. }
  197. // Pruning algorithm: this version writes the pruned input FST to an
  198. // output MutableFst and it takes an options class as an argument. The
  199. // output FST contains states and arcs that belong to a successful
  200. // path in the input FST whose weight is more than the weight of the
  201. // shortest path Times() the provided weight threshold. When the state
  202. // threshold is not kNoStateId, the output FST is further restricted
  203. // to have no more than the number of states in
  204. // opts.state_threshold. Weights have the path property. The weight
  205. // of any cycle needs to be bounded; i.e.,
  206. //
  207. // Plus(weight, Weight::One()) == Weight::One()
  208. template <class Arc, class ArcFilter>
  209. void Prune(
  210. const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
  211. const PruneOptions<Arc, ArcFilter> &opts = PruneOptions<Arc, ArcFilter>()) {
  212. using StateId = typename Arc::StateId;
  213. using Weight = typename Arc::Weight;
  214. static_assert(IsPath<Weight>::value, "Weight must have path property.");
  215. using StateHeap = Heap<StateId, internal::PruneCompare<StateId, Weight>>;
  216. ofst->DeleteStates();
  217. ofst->SetInputSymbols(ifst.InputSymbols());
  218. ofst->SetOutputSymbols(ifst.OutputSymbols());
  219. if (ifst.Start() == kNoStateId) return;
  220. NaturalLess<Weight> less;
  221. if (less(opts.weight_threshold, Weight::One()) ||
  222. (opts.state_threshold == 0)) {
  223. return;
  224. }
  225. std::vector<Weight> idistance;
  226. std::vector<Weight> tmp;
  227. if (!opts.distance) ShortestDistance(ifst, &tmp, true, opts.delta);
  228. const auto *fdistance = opts.distance ? opts.distance : &tmp;
  229. if ((fdistance->size() <= ifst.Start()) ||
  230. ((*fdistance)[ifst.Start()] == Weight::Zero())) {
  231. return;
  232. }
  233. internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
  234. StateHeap heap(compare);
  235. std::vector<StateId> copy;
  236. std::vector<size_t> enqueued;
  237. std::vector<bool> visited;
  238. auto s = ifst.Start();
  239. const auto limit = opts.threshold_initial
  240. ? Times(opts.weight_threshold, (*fdistance)[s])
  241. : Times((*fdistance)[s], opts.weight_threshold);
  242. while (copy.size() <= s) copy.push_back(kNoStateId);
  243. copy[s] = ofst->AddState();
  244. ofst->SetStart(copy[s]);
  245. while (idistance.size() <= s) idistance.push_back(Weight::Zero());
  246. idistance[s] = Weight::One();
  247. while (enqueued.size() <= s) {
  248. enqueued.push_back(StateHeap::kNoKey);
  249. visited.push_back(false);
  250. }
  251. enqueued[s] = heap.Insert(s);
  252. while (!heap.Empty()) {
  253. s = heap.Top();
  254. heap.Pop();
  255. enqueued[s] = StateHeap::kNoKey;
  256. visited[s] = true;
  257. if (!less(limit, Times(idistance[s], ifst.Final(s)))) {
  258. ofst->SetFinal(copy[s], ifst.Final(s));
  259. }
  260. for (ArcIterator<Fst<Arc>> aiter(ifst, s); !aiter.Done(); aiter.Next()) {
  261. const auto &arc = aiter.Value();
  262. if (!opts.filter(arc)) continue;
  263. const auto weight =
  264. Times(Times(idistance[s], arc.weight),
  265. arc.nextstate < fdistance->size() ? (*fdistance)[arc.nextstate]
  266. : Weight::Zero());
  267. if (less(limit, weight)) continue;
  268. if ((opts.state_threshold != kNoStateId) &&
  269. (ofst->NumStates() >= opts.state_threshold)) {
  270. continue;
  271. }
  272. while (idistance.size() <= arc.nextstate) {
  273. idistance.push_back(Weight::Zero());
  274. }
  275. if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
  276. idistance[arc.nextstate] = Times(idistance[s], arc.weight);
  277. }
  278. while (copy.size() <= arc.nextstate) copy.push_back(kNoStateId);
  279. if (copy[arc.nextstate] == kNoStateId) {
  280. copy[arc.nextstate] = ofst->AddState();
  281. }
  282. ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight,
  283. copy[arc.nextstate]));
  284. while (enqueued.size() <= arc.nextstate) {
  285. enqueued.push_back(StateHeap::kNoKey);
  286. visited.push_back(false);
  287. }
  288. if (visited[arc.nextstate]) continue;
  289. if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
  290. enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
  291. } else {
  292. heap.Update(enqueued[arc.nextstate], arc.nextstate);
  293. }
  294. }
  295. }
  296. }
  297. // Pruning algorithm: this version writes the pruned input FST to an
  298. // output MutableFst and simply takes the pruning threshold as an
  299. // argument. The output FST contains states and arcs that belong to a
  300. // successful path in the input FST whose weight is no more than the
  301. // weight of the shortest path Times() the provided weight
  302. // threshold. When the state threshold is not kNoStateId, the output
  303. // FST is further restricted to have no more than the number of states
  304. // in opts.state_threshold. Weights must have the path property. The
  305. // weight of any cycle needs to be bounded; i.e.,
  306. //
  307. // Plus(weight, Weight::One()) = Weight::One();
  308. template <class Arc>
  309. void Prune(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
  310. typename Arc::Weight weight_threshold,
  311. typename Arc::StateId state_threshold = kNoStateId,
  312. float delta = kDelta) {
  313. const PruneOptions<Arc, AnyArcFilter<Arc>> opts(
  314. weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
  315. Prune(ifst, ofst, opts);
  316. }
  317. } // namespace fst
  318. #endif // FST_PRUNE_H_