You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

369 lines
14 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Functions and classes to find shortest distance in an FST.
  19. #ifndef FST_SHORTEST_DISTANCE_H_
  20. #define FST_SHORTEST_DISTANCE_H_
#include <cstddef>
#include <optional>
#include <vector>

#include <fst/log.h>
#include <fst/arc.h>
#include <fst/arcfilter.h>
#include <fst/cache.h>
#include <fst/equal.h>
#include <fst/expanded-fst.h>
#include <fst/fst.h>
#include <fst/properties.h>
#include <fst/queue.h>
#include <fst/reverse.h>
#include <fst/util.h>
#include <fst/vector-fst.h>
#include <fst/weight.h>
  36. namespace fst {
  // Default convergence delta used by the shortest-distance and shortest-path
  // algorithms in this header (see the `delta` defaults below).
  inline constexpr float kShortestDelta = 1e-6F;
  39. template <class Arc, class Queue, class ArcFilter>
  40. struct ShortestDistanceOptions {
  41. using StateId = typename Arc::StateId;
  42. Queue *state_queue; // Queue discipline used; owned by caller.
  43. ArcFilter arc_filter; // Arc filter (e.g., limit to only epsilon graph).
  44. StateId source; // If kNoStateId, use the FST's initial state.
  45. float delta; // Determines the degree of convergence required
  46. bool first_path; // For a semiring with the path property (o.w.
  47. // undefined), compute the shortest-distances along
  48. // along the first path to a final state found
  49. // by the algorithm. That path is the shortest-path
  50. // only if the FST has a unique final state (or all
  51. // the final states have the same final weight), the
  52. // queue discipline is shortest-first and all the
  53. // weights in the FST are between One() and Zero()
  54. // according to NaturalLess.
  55. ShortestDistanceOptions(Queue *state_queue, ArcFilter arc_filter,
  56. StateId source = kNoStateId,
  57. float delta = kShortestDelta, bool first_path = false)
  58. : state_queue(state_queue),
  59. arc_filter(arc_filter),
  60. source(source),
  61. delta(delta),
  62. first_path(first_path) {}
  63. };
  64. namespace internal {
  65. // Computation state of the shortest-distance algorithm. Reusable information
  66. // is maintained across calls to member function ShortestDistance(source) when
  67. // retain is true for improved efficiency when calling multiple times from
  68. // different source states (e.g., in epsilon removal). Contrary to the usual
  69. // conventions, fst may not be freed before this class. Vector distance
  70. // should not be modified by the user between these calls. The Error() method
  71. // returns true iff an error was encountered.
  72. template <class Arc, class Queue, class ArcFilter,
  73. class WeightEqual = WeightApproxEqual>
  74. class ShortestDistanceState {
  75. public:
  76. using StateId = typename Arc::StateId;
  77. using Weight = typename Arc::Weight;
  78. ShortestDistanceState(
  79. const Fst<Arc> &fst, std::vector<Weight> *distance,
  80. const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts, bool retain)
  81. : fst_(fst),
  82. distance_(distance),
  83. state_queue_(opts.state_queue),
  84. arc_filter_(opts.arc_filter),
  85. weight_equal_(opts.delta),
  86. first_path_(opts.first_path),
  87. retain_(retain),
  88. source_id_(0),
  89. error_(false) {
  90. distance_->clear();
  91. if (std::optional<StateId> num_states = fst.NumStatesIfKnown()) {
  92. distance_->reserve(*num_states);
  93. adder_.reserve(*num_states);
  94. radder_.reserve(*num_states);
  95. enqueued_.reserve(*num_states);
  96. }
  97. }
  98. void ShortestDistance(StateId source);
  99. bool Error() const { return error_; }
  100. private:
  101. void EnsureDistanceIndexIsValid(std::size_t index) {
  102. while (distance_->size() <= index) {
  103. distance_->push_back(Weight::Zero());
  104. adder_.push_back(Adder<Weight>());
  105. radder_.push_back(Adder<Weight>());
  106. enqueued_.push_back(false);
  107. }
  108. DCHECK_LT(index, distance_->size());
  109. }
  110. void EnsureSourcesIndexIsValid(std::size_t index) {
  111. while (sources_.size() <= index) {
  112. sources_.push_back(kNoStateId);
  113. }
  114. DCHECK_LT(index, sources_.size());
  115. }
  116. const Fst<Arc> &fst_;
  117. std::vector<Weight> *distance_;
  118. Queue *state_queue_;
  119. ArcFilter arc_filter_;
  120. WeightEqual weight_equal_; // Determines when relaxation stops.
  121. const bool first_path_;
  122. const bool retain_; // Retain and reuse information across calls.
  123. std::vector<Adder<Weight>> adder_; // Sums distance_ accurately.
  124. std::vector<Adder<Weight>> radder_; // Relaxation distance.
  125. std::vector<bool> enqueued_; // Is state enqueued?
  126. std::vector<StateId> sources_; // Source ID for ith state in distance_,
  127. // (r)adder_, and enqueued_ if retained.
  128. StateId source_id_; // Unique ID characterizing each call.
  129. bool error_;
  130. };
  131. // Compute the shortest distance; if source is kNoStateId, uses the initial
  132. // state of the FST.
  133. template <class Arc, class Queue, class ArcFilter, class WeightEqual>
  134. void ShortestDistanceState<Arc, Queue, ArcFilter,
  135. WeightEqual>::ShortestDistance(StateId source) {
  136. if (fst_.Start() == kNoStateId) {
  137. if (fst_.Properties(kError, false)) error_ = true;
  138. return;
  139. }
  140. if (!(Weight::Properties() & kRightSemiring)) {
  141. FSTERROR() << "ShortestDistance: Weight needs to be right distributive: "
  142. << Weight::Type();
  143. error_ = true;
  144. return;
  145. }
  146. if (first_path_ && !(Weight::Properties() & kPath)) {
  147. FSTERROR() << "ShortestDistance: The first_path option is disallowed when "
  148. << "Weight does not have the path property: " << Weight::Type();
  149. error_ = true;
  150. return;
  151. }
  152. state_queue_->Clear();
  153. if (!retain_) {
  154. distance_->clear();
  155. adder_.clear();
  156. radder_.clear();
  157. enqueued_.clear();
  158. }
  159. if (source == kNoStateId) source = fst_.Start();
  160. EnsureDistanceIndexIsValid(source);
  161. if (retain_) {
  162. EnsureSourcesIndexIsValid(source);
  163. sources_[source] = source_id_;
  164. }
  165. (*distance_)[source] = Weight::One();
  166. adder_[source].Reset(Weight::One());
  167. radder_[source].Reset(Weight::One());
  168. enqueued_[source] = true;
  169. state_queue_->Enqueue(source);
  170. while (!state_queue_->Empty()) {
  171. const auto state = state_queue_->Head();
  172. state_queue_->Dequeue();
  173. EnsureDistanceIndexIsValid(state);
  174. if (first_path_ && (fst_.Final(state) != Weight::Zero())) break;
  175. enqueued_[state] = false;
  176. const auto r = radder_[state].Sum();
  177. radder_[state].Reset();
  178. for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
  179. aiter.Next()) {
  180. const auto &arc = aiter.Value();
  181. const auto nextstate = arc.nextstate;
  182. if (!arc_filter_(arc)) continue;
  183. EnsureDistanceIndexIsValid(nextstate);
  184. if (retain_) {
  185. EnsureSourcesIndexIsValid(nextstate);
  186. if (sources_[nextstate] != source_id_) {
  187. (*distance_)[nextstate] = Weight::Zero();
  188. adder_[nextstate].Reset();
  189. radder_[nextstate].Reset();
  190. enqueued_[nextstate] = false;
  191. sources_[nextstate] = source_id_;
  192. }
  193. }
  194. auto &nd = (*distance_)[nextstate];
  195. auto &na = adder_[nextstate];
  196. auto &nr = radder_[nextstate];
  197. auto weight = Times(r, arc.weight);
  198. if (!weight_equal_(nd, Plus(nd, weight))) {
  199. nd = na.Add(weight);
  200. nr.Add(weight);
  201. if (!nd.Member() || !nr.Sum().Member()) {
  202. error_ = true;
  203. return;
  204. }
  205. if (!enqueued_[nextstate]) {
  206. state_queue_->Enqueue(nextstate);
  207. enqueued_[nextstate] = true;
  208. } else {
  209. state_queue_->Update(nextstate);
  210. }
  211. }
  212. }
  213. }
  214. ++source_id_;
  215. if (fst_.Properties(kError, false)) error_ = true;
  216. }
  217. } // namespace internal
  218. // Shortest-distance algorithm: this version allows fine control
  219. // via the options argument. See below for a simpler interface.
  220. //
  221. // This computes the shortest distance from the opts.source state to each
  222. // visited state S and stores the value in the distance vector. An
  223. // unvisited state S has distance Zero(), which will be stored in the
  224. // distance vector if S is less than the maximum visited state. The state
  225. // queue discipline, arc filter, and convergence delta are taken in the
  226. // options argument. The distance vector will contain a unique element for
  227. // which Member() is false if an error was encountered.
  228. //
  // The weights must be right distributive and k-closed (i.e., 1 +
  230. // x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
  231. //
  232. // Complexity:
  233. //
  234. // Depends on properties of the semiring and the queue discipline.
  235. //
  236. // For more information, see:
  237. //
  238. // Mohri, M. 2002. Semiring framework and algorithms for shortest-distance
  239. // problems, Journal of Automata, Languages and
  240. // Combinatorics 7(3): 321-350, 2002.
  241. template <class Arc, class Queue, class ArcFilter>
  242. void ShortestDistance(
  243. const Fst<Arc> &fst, std::vector<typename Arc::Weight> *distance,
  244. const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) {
  245. internal::ShortestDistanceState<Arc, Queue, ArcFilter> sd_state(fst, distance,
  246. opts, false);
  247. sd_state.ShortestDistance(opts.source);
  248. if (sd_state.Error()) {
  249. distance->assign(1, Arc::Weight::NoWeight());
  250. }
  251. }
  252. // Shortest-distance algorithm: simplified interface. See above for a version
  253. // that permits finer control.
  254. //
  255. // If reverse is false, this computes the shortest distance from the initial
  256. // state to each state S and stores the value in the distance vector. If
  257. // reverse is true, this computes the shortest distance from each state to the
  258. // final states. An unvisited state S has distance Zero(), which will be stored
  259. // in the distance vector if S is less than the maximum visited state. The
  260. // state queue discipline is automatically-selected. The distance vector will
  261. // contain a unique element for which Member() is false if an error was
  262. // encountered.
  263. //
  // The weights must be right (left) distributive if reverse is false (true)
  265. // and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
  266. //
  267. // Arc weights must satisfy the property that the sum of the weights of one or
  268. // more paths from some state S to T is never Zero(). In particular, arc weights
  269. // are never Zero().
  270. //
  271. // Complexity:
  272. //
  273. // Depends on properties of the semiring and the queue discipline.
  274. //
  275. // For more information, see:
  276. //
  277. // Mohri, M. 2002. Semiring framework and algorithms for
  278. // shortest-distance problems, Journal of Automata, Languages and
  279. // Combinatorics 7(3): 321-350, 2002.
  280. template <class Arc>
  281. void ShortestDistance(const Fst<Arc> &fst,
  282. std::vector<typename Arc::Weight> *distance,
  283. bool reverse = false, float delta = kShortestDelta) {
  284. using StateId = typename Arc::StateId;
  285. if (!reverse) {
  286. AnyArcFilter<Arc> arc_filter;
  287. AutoQueue<StateId> state_queue(fst, distance, arc_filter);
  288. const ShortestDistanceOptions<Arc, AutoQueue<StateId>, AnyArcFilter<Arc>>
  289. opts(&state_queue, arc_filter, kNoStateId, delta);
  290. ShortestDistance(fst, distance, opts);
  291. } else {
  292. using ReverseArc = ReverseArc<Arc>;
  293. using ReverseWeight = typename ReverseArc::Weight;
  294. AnyArcFilter<ReverseArc> rarc_filter;
  295. VectorFst<ReverseArc> rfst;
  296. Reverse(fst, &rfst);
  297. std::vector<ReverseWeight> rdistance;
  298. AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter);
  299. const ShortestDistanceOptions<ReverseArc, AutoQueue<StateId>,
  300. AnyArcFilter<ReverseArc>>
  301. ropts(&state_queue, rarc_filter, kNoStateId, delta);
  302. ShortestDistance(rfst, &rdistance, ropts);
  303. distance->clear();
  304. if (rdistance.size() == 1 && !rdistance[0].Member()) {
  305. distance->assign(1, Arc::Weight::NoWeight());
  306. return;
  307. }
  308. DCHECK_GE(rdistance.size(), 1); // reversing added one state
  309. distance->reserve(rdistance.size() - 1);
  310. while (distance->size() < rdistance.size() - 1) {
  311. distance->push_back(rdistance[distance->size() + 1].Reverse());
  312. }
  313. }
  314. }
  315. // Return the sum of the weight of all successful paths in an FST, i.e., the
  316. // shortest-distance from the initial state to the final states. Returns a
  317. // weight such that Member() is false if an error was encountered.
  318. template <class Arc>
  319. typename Arc::Weight ShortestDistance(const Fst<Arc> &fst,
  320. float delta = kShortestDelta) {
  321. using StateId = typename Arc::StateId;
  322. using Weight = typename Arc::Weight;
  323. std::vector<Weight> distance;
  324. if (Weight::Properties() & kRightSemiring) {
  325. ShortestDistance(fst, &distance, false, delta);
  326. if (distance.size() == 1 && !distance[0].Member()) {
  327. return Arc::Weight::NoWeight();
  328. }
  329. Adder<Weight> adder; // maintains cumulative sum accurately
  330. for (StateId state = 0; state < distance.size(); ++state) {
  331. adder.Add(Times(distance[state], fst.Final(state)));
  332. }
  333. return adder.Sum();
  334. } else {
  335. ShortestDistance(fst, &distance, true, delta);
  336. const auto state = fst.Start();
  337. if (distance.size() == 1 && !distance[0].Member()) {
  338. return Arc::Weight::NoWeight();
  339. }
  340. return state != kNoStateId && state < distance.size() ? distance[state]
  341. : Weight::Zero();
  342. }
  343. }
  344. } // namespace fst
  345. #endif // FST_SHORTEST_DISTANCE_H_