You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

530 lines
22 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Functions to find shortest paths in an FST.
  19. #ifndef FST_SHORTEST_PATH_H_
  20. #define FST_SHORTEST_PATH_H_
  21. #include <algorithm>
  22. #include <cstddef>
  23. #include <cstdint>
  24. #include <functional>
  25. #include <type_traits>
  26. #include <utility>
  27. #include <vector>
  28. #include <fst/log.h>
  29. #include <fst/arc.h>
  30. #include <fst/arcfilter.h>
  31. #include <fst/cache.h>
  32. #include <fst/connect.h>
  33. #include <fst/determinize.h>
  34. #include <fst/fst.h>
  35. #include <fst/mutable-fst.h>
  36. #include <fst/properties.h>
  37. #include <fst/queue.h>
  38. #include <fst/reverse.h>
  39. #include <fst/shortest-distance.h>
  40. #include <fst/vector-fst.h>
  41. #include <fst/weight.h>
  42. namespace fst {
  43. template <class Arc, class Queue, class ArcFilter>
  44. struct ShortestPathOptions
  45. : public ShortestDistanceOptions<Arc, Queue, ArcFilter> {
  46. using StateId = typename Arc::StateId;
  47. using Weight = typename Arc::Weight;
  48. int32_t nshortest; // Returns n-shortest paths.
  49. bool unique; // Only returns paths with distinct input strings.
  50. bool has_distance; // Distance vector already contains the
  51. // shortest distance from the initial state.
  52. bool first_path; // Single shortest path stops after finding the first
  53. // path to a final state; that path is the shortest path
  54. // only when:
  55. // (1) using the ShortestFirstQueue with all the weights
  56. // in the FST being between One() and Zero() according to
  57. // NaturalLess or when
  58. // (2) using the NaturalAStarQueue with an admissible
  59. // and consistent estimate.
  60. Weight weight_threshold; // Pruning weight threshold.
  61. StateId state_threshold; // Pruning state threshold.
  62. ShortestPathOptions(Queue *queue, ArcFilter filter, int32_t nshortest = 1,
  63. bool unique = false, bool has_distance = false,
  64. float delta = kShortestDelta, bool first_path = false,
  65. Weight weight_threshold = Weight::Zero(),
  66. StateId state_threshold = kNoStateId)
  67. : ShortestDistanceOptions<Arc, Queue, ArcFilter>(queue, filter,
  68. kNoStateId, delta),
  69. nshortest(nshortest),
  70. unique(unique),
  71. has_distance(has_distance),
  72. first_path(first_path),
  73. weight_threshold(std::move(weight_threshold)),
  74. state_threshold(state_threshold) {}
  75. };
  76. namespace internal {
  // Sentinel arc position (the largest size_t value) stored in the backtrace
  // parent vector for states that have no recorded incoming arc.
  inline constexpr size_t kNoArc = static_cast<size_t>(-1);
  78. // Helper function for SingleShortestPath building the shortest path as a left-
  79. // to-right machine backwards from the best final state. It takes the input
  80. // FST passed to SingleShortestPath and the parent vector and f_parent returned
  81. // by that function, and builds the result into the provided output mutable FS
  82. // This is not normally called by users; see ShortestPath instead.
  83. template <class Arc>
  84. void SingleShortestPathBacktrace(
  85. const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
  86. const std::vector<std::pair<typename Arc::StateId, size_t>> &parent,
  87. typename Arc::StateId f_parent) {
  88. using StateId = typename Arc::StateId;
  89. ofst->DeleteStates();
  90. ofst->SetInputSymbols(ifst.InputSymbols());
  91. ofst->SetOutputSymbols(ifst.OutputSymbols());
  92. StateId s_p = kNoStateId;
  93. StateId d_p = kNoStateId;
  94. for (StateId state = f_parent, d = kNoStateId; state != kNoStateId;
  95. d = state, state = parent[state].first) {
  96. d_p = s_p;
  97. s_p = ofst->AddState();
  98. if (d == kNoStateId) {
  99. ofst->SetFinal(s_p, ifst.Final(f_parent));
  100. } else {
  101. ArcIterator<Fst<Arc>> aiter(ifst, state);
  102. aiter.Seek(parent[d].second);
  103. auto arc = aiter.Value();
  104. arc.nextstate = d_p;
  105. ofst->AddArc(s_p, std::move(arc));
  106. }
  107. }
  108. ofst->SetStart(s_p);
  109. if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
  110. ofst->SetProperties(
  111. ShortestPathProperties(ofst->Properties(kFstProperties, false), true),
  112. kFstProperties);
  113. }
  // Stopping criterion used when ShortestPathOptions::first_path is true:
  //
  //   operator()(s, d, f) == true
  //
  // iff every successful path through state 's' has a cost greater or equal
  // to 'f', under the assumption that 'd' is the shortest distance to state
  // 's'. Correct when using the ShortestFirstQueue with all the weights in the
  // FST being between One() and Zero() according to NaturalLess.
  template <typename S, typename W, typename Queue>
  struct FirstPathSelect {
    // The queue is not needed by the generic criterion.
    FirstPathSelect(const Queue &) {}

    // True when adding 'd' to 'f' leaves 'f' unchanged.
    bool operator()(S s, W d, W f) const {
      const auto combined = Plus(d, f);
      return f == combined;
    }
  };
  126. // Specialisation for A*.
  127. // Correct when the estimate is admissible and consistent.
  128. template <typename S, typename W, typename Estimate>
  129. class FirstPathSelect<S, W, NaturalAStarQueue<S, W, Estimate>> {
  130. public:
  131. using Queue = NaturalAStarQueue<S, W, Estimate>;
  132. FirstPathSelect(const Queue &state_queue)
  133. : estimate_(state_queue.GetCompare().GetEstimate()) {}
  134. bool operator()(S s, W d, W f) const {
  135. return f == Plus(Times(d, estimate_(s)), f);
  136. }
  137. private:
  138. const Estimate &estimate_;
  139. };
  140. // Shortest-path algorithm. It builds the output mutable FST so that it contains
  141. // the shortest path in the input FST; distance returns the shortest distances
  142. // from the source state to each state in the input FST, and the options struct
  143. // is
  144. // used to specify options such as the queue discipline, the arc filter and
  145. // delta. The super_final option is an output parameter indicating the final
  146. // state, and the parent argument is used for the storage of the backtrace path
  147. // for each state 1 to n, (i.e., the best previous state and the arc that
  148. // transition to state n.) The shortest path is the lowest weight path w.r.t.
  149. // the natural semiring order. The weights need to be right distributive and
  150. // have the path (kPath) property. False is returned if an error is encountered.
  151. //
  152. // This is not normally called by users; see ShortestPath instead (with n = 1).
  153. template <class Arc, class Queue, class ArcFilter>
  154. bool SingleShortestPath(
  155. const Fst<Arc> &ifst, std::vector<typename Arc::Weight> *distance,
  156. const ShortestPathOptions<Arc, Queue, ArcFilter> &opts,
  157. typename Arc::StateId *f_parent,
  158. std::vector<std::pair<typename Arc::StateId, size_t>> *parent) {
  159. using StateId = typename Arc::StateId;
  160. using Weight = typename Arc::Weight;
  161. static_assert(IsPath<Weight>::value, "Weight must have path property.");
  162. static_assert((Weight::Properties() & kRightSemiring) == kRightSemiring,
  163. "Weight must be right distributive.");
  164. parent->clear();
  165. *f_parent = kNoStateId;
  166. if (ifst.Start() == kNoStateId) return true;
  167. std::vector<bool> enqueued;
  168. auto state_queue = opts.state_queue;
  169. const auto source = (opts.source == kNoStateId) ? ifst.Start() : opts.source;
  170. bool final_seen = false;
  171. auto f_distance = Weight::Zero();
  172. distance->clear();
  173. state_queue->Clear();
  174. while (distance->size() < source) {
  175. distance->push_back(Weight::Zero());
  176. enqueued.push_back(false);
  177. parent->emplace_back(kNoStateId, kNoArc);
  178. }
  179. distance->push_back(Weight::One());
  180. parent->emplace_back(kNoStateId, kNoArc);
  181. state_queue->Enqueue(source);
  182. enqueued.push_back(true);
  183. while (!state_queue->Empty()) {
  184. const auto s = state_queue->Head();
  185. state_queue->Dequeue();
  186. enqueued[s] = false;
  187. const auto sd = (*distance)[s];
  188. // If we are using a shortest queue, no other path is going to be shorter
  189. // than f_distance at this point.
  190. using FirstPath = FirstPathSelect<StateId, Weight, Queue>;
  191. if (opts.first_path && final_seen &&
  192. FirstPath(*state_queue)(s, sd, f_distance)) {
  193. break;
  194. }
  195. if (ifst.Final(s) != Weight::Zero()) {
  196. const auto plus = Plus(f_distance, Times(sd, ifst.Final(s)));
  197. if (f_distance != plus) {
  198. f_distance = plus;
  199. *f_parent = s;
  200. }
  201. if (!f_distance.Member()) return false;
  202. final_seen = true;
  203. }
  204. for (ArcIterator<Fst<Arc>> aiter(ifst, s); !aiter.Done(); aiter.Next()) {
  205. const auto &arc = aiter.Value();
  206. while (distance->size() <= arc.nextstate) {
  207. distance->push_back(Weight::Zero());
  208. enqueued.push_back(false);
  209. parent->emplace_back(kNoStateId, kNoArc);
  210. }
  211. auto &nd = (*distance)[arc.nextstate];
  212. const auto weight = Times(sd, arc.weight);
  213. if (nd != Plus(nd, weight)) {
  214. nd = Plus(nd, weight);
  215. if (!nd.Member()) return false;
  216. (*parent)[arc.nextstate] = std::make_pair(s, aiter.Position());
  217. if (!enqueued[arc.nextstate]) {
  218. state_queue->Enqueue(arc.nextstate);
  219. enqueued[arc.nextstate] = true;
  220. } else {
  221. state_queue->Update(arc.nextstate);
  222. }
  223. }
  224. }
  225. }
  226. return true;
  227. }
  228. template <class StateId, class Weight>
  229. class ShortestPathCompare {
  230. public:
  231. ShortestPathCompare(const std::vector<std::pair<StateId, Weight>> &pairs,
  232. const std::vector<Weight> &distance, StateId superfinal,
  233. float delta)
  234. : pairs_(pairs),
  235. distance_(distance),
  236. superfinal_(superfinal),
  237. delta_(delta) {}
  238. bool operator()(const StateId x, const StateId y) const {
  239. const auto &px = pairs_[x];
  240. const auto &py = pairs_[y];
  241. const auto wx = Times(PWeight(px.first), px.second);
  242. const auto wy = Times(PWeight(py.first), py.second);
  243. // Penalize complete paths to ensure correct results with inexact weights.
  244. // This forms a strict weak order so long as ApproxEqual(a, b) =>
  245. // ApproxEqual(a, c) for all c s.t. less_(a, c) && less_(c, b).
  246. if (px.first == superfinal_ && py.first != superfinal_) {
  247. return less_(wy, wx) || ApproxEqual(wx, wy, delta_);
  248. } else if (py.first == superfinal_ && px.first != superfinal_) {
  249. return less_(wy, wx) && !ApproxEqual(wx, wy, delta_);
  250. } else {
  251. return less_(wy, wx);
  252. }
  253. }
  254. private:
  255. Weight PWeight(StateId state) const {
  256. return (state == superfinal_) ? Weight::One()
  257. : (state < distance_.size()) ? distance_[state]
  258. : Weight::Zero();
  259. }
  260. const std::vector<std::pair<StateId, Weight>> &pairs_;
  261. const std::vector<Weight> &distance_;
  262. const StateId superfinal_;
  263. const float delta_;
  264. NaturalLess<Weight> less_;
  265. };
  266. // N-Shortest-path algorithm: implements the core n-shortest path algorithm.
  267. // The output is built reversed. See below for versions with more options and
  268. // *not reversed*.
  269. //
  270. // The output mutable FST contains the REVERSE of n'shortest paths in the input
  271. // FST; distance must contain the shortest distance from each state to a final
  272. // state in the input FST; delta is the convergence delta.
  273. //
  274. // The n-shortest paths are the n-lowest weight paths w.r.t. the natural
  275. // semiring order. The single path that can be read from the ith of at most n
  276. // transitions leaving the initial state of the input FST is the ith shortest
  277. // path. Disregarding the initial state and initial transitions, the
  278. // n-shortest paths, in fact, form a tree rooted at the single final state.
  279. //
  280. // The weights need to be left and right distributive (kSemiring) and have the
  281. // path (kPath) property.
  282. //
  283. // Arc weights must satisfy the property that the sum of the weights of one or
  284. // more paths from some state S to T is never Zero(). In particular, arc weights
  285. // are never Zero().
  286. //
  287. // For more information, see:
  288. //
  289. // Mohri, M, and Riley, M. 2002. An efficient algorithm for the n-best-strings
  290. // problem. In Proc. ICSLP.
  291. //
  292. // The algorithm relies on the shortest-distance algorithm. There are some
  293. // issues with the pseudo-code as written in the paper (viz., line 11).
  294. //
  295. // IMPLEMENTATION NOTE: The input FST can be a delayed FST and at any state in
  296. // its expansion the values of distance vector need only be defined at that time
  297. // for the states that are known to exist.
  298. template <class Arc, class RevArc>
  299. void NShortestPath(const Fst<RevArc> &ifst, MutableFst<Arc> *ofst,
  300. const std::vector<typename Arc::Weight> &distance,
  301. int32_t nshortest, float delta = kShortestDelta,
  302. typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
  303. typename Arc::StateId state_threshold = kNoStateId) {
  304. using StateId = typename Arc::StateId;
  305. using Weight = typename Arc::Weight;
  306. using Pair = std::pair<StateId, Weight>;
  307. static_assert((Weight::Properties() & kPath) == kPath,
  308. "Weight must have path property.");
  309. static_assert((Weight::Properties() & kSemiring) == kSemiring,
  310. "Weight must be distributive.");
  311. if (nshortest <= 0) return;
  312. ofst->DeleteStates();
  313. ofst->SetInputSymbols(ifst.InputSymbols());
  314. ofst->SetOutputSymbols(ifst.OutputSymbols());
  315. // Each state in ofst corresponds to a path with weight w from the initial
  316. // state of ifst to a state s in ifst, that can be characterized by a pair
  317. // (s, w). The vector pairs maps each state in ofst to the corresponding
  318. // pair maps states in ofst to the corresponding pair (s, w).
  319. std::vector<Pair> pairs;
  320. // The superfinal state is denoted by kNoStateId. The distance from the
  321. // superfinal state to the final state is semiring One, so
  322. // `distance[kNoStateId]` is not needed.
  323. const ShortestPathCompare<StateId, Weight> compare(pairs, distance,
  324. kNoStateId, delta);
  325. const NaturalLess<Weight> less;
  326. if (ifst.Start() == kNoStateId || distance.size() <= ifst.Start() ||
  327. distance[ifst.Start()] == Weight::Zero() ||
  328. less(weight_threshold, Weight::One()) || state_threshold == 0) {
  329. if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
  330. return;
  331. }
  332. ofst->SetStart(ofst->AddState());
  333. const auto final_state = ofst->AddState();
  334. ofst->SetFinal(final_state);
  335. while (pairs.size() <= final_state) {
  336. pairs.emplace_back(kNoStateId, Weight::Zero());
  337. }
  338. pairs[final_state] = std::make_pair(ifst.Start(), Weight::One());
  339. std::vector<StateId> heap;
  340. heap.push_back(final_state);
  341. const auto limit = Times(distance[ifst.Start()], weight_threshold);
  342. // r[s + 1], s state in fst, is the number of states in ofst which
  343. // corresponding pair contains s, i.e., it is number of paths computed so far
  344. // to s. Valid for s == kNoStateId (the superfinal state).
  345. std::vector<int> r;
  346. while (!heap.empty()) {
  347. std::pop_heap(heap.begin(), heap.end(), compare);
  348. const auto state = heap.back();
  349. const auto p = pairs[state];
  350. heap.pop_back();
  351. const auto d = (p.first == kNoStateId) ? Weight::One()
  352. : (p.first < distance.size()) ? distance[p.first]
  353. : Weight::Zero();
  354. if (less(limit, Times(d, p.second)) ||
  355. (state_threshold != kNoStateId &&
  356. ofst->NumStates() >= state_threshold)) {
  357. continue;
  358. }
  359. while (r.size() <= p.first + 1) r.push_back(0);
  360. ++r[p.first + 1];
  361. if (p.first == kNoStateId) ofst->AddArc(ofst->Start(), Arc(0, 0, state));
  362. if ((p.first == kNoStateId) && (r[p.first + 1] == nshortest)) break;
  363. if (r[p.first + 1] > nshortest) continue;
  364. if (p.first == kNoStateId) continue;
  365. for (ArcIterator<Fst<RevArc>> aiter(ifst, p.first); !aiter.Done();
  366. aiter.Next()) {
  367. const auto &rarc = aiter.Value();
  368. Arc arc(rarc.ilabel, rarc.olabel, rarc.weight.Reverse(), rarc.nextstate);
  369. const auto weight = Times(p.second, arc.weight);
  370. const auto next = ofst->AddState();
  371. pairs.emplace_back(arc.nextstate, weight);
  372. arc.nextstate = state;
  373. ofst->AddArc(next, std::move(arc));
  374. heap.push_back(next);
  375. std::push_heap(heap.begin(), heap.end(), compare);
  376. }
  377. const auto final_weight = ifst.Final(p.first).Reverse();
  378. if (final_weight != Weight::Zero()) {
  379. const auto weight = Times(p.second, final_weight);
  380. const auto next = ofst->AddState();
  381. pairs.emplace_back(kNoStateId, weight);
  382. ofst->AddArc(next, Arc(0, 0, final_weight, state));
  383. heap.push_back(next);
  384. std::push_heap(heap.begin(), heap.end(), compare);
  385. }
  386. }
  387. Connect(ofst);
  388. if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
  389. ofst->SetProperties(
  390. ShortestPathProperties(ofst->Properties(kFstProperties, false)),
  391. kFstProperties);
  392. }
  393. } // namespace internal
  394. // N-Shortest-path algorithm: this version allows finer control via the options
  395. // argument. See below for a simpler interface. The output mutable FST contains
  396. // the n-shortest paths in the input FST; the distance argument is used to
  397. // return the shortest distances from the source state to each state in the
  398. // input FST, and the options struct is used to specify the number of paths to
  399. // return, whether they need to have distinct input strings, the queue
  400. // discipline, the arc filter and the convergence delta.
  401. //
  402. // The n-shortest paths are the n-lowest weight paths w.r.t. the natural
  403. // semiring order. The single path that can be read from the ith of at most n
  404. // transitions leaving the initial state of the output FST is the ith shortest
  405. // path.
  406. // Disregarding the initial state and initial transitions, The n-shortest paths,
  407. // in fact, form a tree rooted at the single final state.
  408. //
  409. // The weights need to be right distributive and have the path (kPath) property.
  410. // They need to be left distributive as well for nshortest > 1.
  411. //
  412. // For more information, see:
  413. //
  414. // Mohri, M, and Riley, M. 2002. An efficient algorithm for the n-best-strings
  415. // problem. In Proc. ICSLP.
  416. //
  417. // The algorithm relies on the shortest-distance algorithm. There are some
  418. // issues with the pseudo-code as written in the paper (viz., line 11).
  419. template <class Arc, class Queue, class ArcFilter>
  420. void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
  421. std::vector<typename Arc::Weight> *distance,
  422. const ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
  423. using StateId = typename Arc::StateId;
  424. using Weight = typename Arc::Weight;
  425. using RevArc = ReverseArc<Arc>;
  426. static_assert(IsPath<Weight>::value,
  427. "ShortestPath: Weight needs to have the path property and "
  428. "be distributive");
  429. if (opts.nshortest == 1) {
  430. std::vector<std::pair<StateId, size_t>> parent;
  431. StateId f_parent;
  432. if (internal::SingleShortestPath(ifst, distance, opts, &f_parent,
  433. &parent)) {
  434. internal::SingleShortestPathBacktrace(ifst, ofst, parent, f_parent);
  435. } else {
  436. ofst->SetProperties(kError, kError);
  437. }
  438. return;
  439. }
  440. if (opts.nshortest <= 0) return;
  441. if (!opts.has_distance) {
  442. ShortestDistance(ifst, distance, opts);
  443. if (distance->size() == 1 && !(*distance)[0].Member()) {
  444. ofst->SetProperties(kError, kError);
  445. return;
  446. }
  447. }
  448. // Algorithm works on the reverse of 'fst'; 'distance' is the distance to the
  449. // final state in 'rfst', 'ofst' is built as the reverse of the tree of
  450. // n-shortest path in 'rfst'.
  451. VectorFst<RevArc> rfst;
  452. Reverse(ifst, &rfst);
  453. auto d = Weight::Zero();
  454. for (ArcIterator<VectorFst<RevArc>> aiter(rfst, 0); !aiter.Done();
  455. aiter.Next()) {
  456. const auto &arc = aiter.Value();
  457. const auto state = arc.nextstate - 1;
  458. if (state < distance->size()) {
  459. d = Plus(d, Times(arc.weight.Reverse(), (*distance)[state]));
  460. }
  461. }
  462. // TODO(kbg): Avoid this expensive vector operation.
  463. distance->insert(distance->begin(), d);
  464. if (!opts.unique) {
  465. internal::NShortestPath(rfst, ofst, *distance, opts.nshortest, opts.delta,
  466. opts.weight_threshold, opts.state_threshold);
  467. } else {
  468. std::vector<Weight> ddistance;
  469. const DeterminizeFstOptions<RevArc> dopts(opts.delta);
  470. const DeterminizeFst<RevArc> dfst(rfst, distance, &ddistance, dopts);
  471. internal::NShortestPath(dfst, ofst, ddistance, opts.nshortest, opts.delta,
  472. opts.weight_threshold, opts.state_threshold);
  473. }
  474. // TODO(kbg): Avoid this expensive vector operation.
  475. distance->erase(distance->begin());
  476. }
  477. // Shortest-path algorithm: simplified interface. See above for a version that
  478. // allows finer control. The output mutable FST contains the n-shortest paths
  479. // in the input FST. The queue discipline is automatically selected. When unique
  480. // is true, only paths with distinct input label sequences are returned.
  481. //
  482. // The n-shortest paths are the n-lowest weight paths w.r.t. the natural
  483. // semiring order. The single path that can be read from the ith of at most n
  484. // transitions leaving the initial state of the output FST is the ith best path.
  485. // The weights need to be right distributive and have the path (kPath) property.
  486. template <class Arc>
  487. void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
  488. int32_t nshortest = 1, bool unique = false,
  489. bool first_path = false,
  490. typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
  491. typename Arc::StateId state_threshold = kNoStateId,
  492. float delta = kShortestDelta) {
  493. using StateId = typename Arc::StateId;
  494. std::vector<typename Arc::Weight> distance;
  495. AnyArcFilter<Arc> arc_filter;
  496. AutoQueue<StateId> state_queue(ifst, &distance, arc_filter);
  497. const ShortestPathOptions<Arc, AutoQueue<StateId>, AnyArcFilter<Arc>> opts(
  498. &state_queue, arc_filter, nshortest, unique, false, delta, first_path,
  499. weight_threshold, state_threshold);
  500. ShortestPath(ifst, ofst, &distance, opts);
  501. }
  502. } // namespace fst
  503. #endif // FST_SHORTEST_PATH_H_