You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1017 lines
32 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Functions and classes for various FST state queues with a unified interface.
  19. #ifndef FST_QUEUE_H_
  20. #define FST_QUEUE_H_
  21. #include <sys/types.h>
  22. #include <algorithm>
  23. #include <cstdint>
  24. #include <memory>
  25. #include <queue>
  26. #include <type_traits>
  27. #include <utility>
  28. #include <vector>
  29. #include <fst/log.h>
  30. #include <fst/arcfilter.h>
  31. #include <fst/cc-visitors.h>
  32. #include <fst/dfs-visit.h>
  33. #include <fst/fst.h>
  34. #include <fst/heap.h>
  35. #include <fst/properties.h>
  36. #include <fst/topsort.h>
  37. #include <fst/util.h>
  38. #include <fst/weight.h>
  39. namespace fst {
  40. // The Queue interface is:
  41. //
  42. // template <class S>
  43. // class Queue {
  44. // public:
  45. // using StateId = S;
  46. //
  47. // // Constructor: may need args (e.g., FST, comparator) for some queues.
// Queue(...);
  49. //
  50. // // Returns the head of the queue.
  51. // StateId Head() const override;
  52. //
  53. // // Inserts a state.
  54. // void Enqueue(StateId s) override;
  55. //
  56. // // Removes the head of the queue.
  57. // void Dequeue() override;
  58. //
  59. // // Updates ordering of state s when weight changes, if necessary.
  60. // void Update(StateId s) override;
  61. //
  62. // // Is the queue empty?
  63. // bool Empty() const override;
  64. //
  65. // // Removes all states from the queue.
  66. // void Clear() override;
  67. // };
  68. // State queue types.
  69. enum QueueType {
  70. TRIVIAL_QUEUE = 0, // Single state queue.
  71. FIFO_QUEUE = 1, // First-in, first-out queue.
  72. LIFO_QUEUE = 2, // Last-in, first-out queue.
  73. SHORTEST_FIRST_QUEUE = 3, // Shortest-first queue.
  74. TOP_ORDER_QUEUE = 4, // Topologically-ordered queue.
  75. STATE_ORDER_QUEUE = 5, // State ID-ordered queue.
  76. SCC_QUEUE = 6, // Component graph top-ordered meta-queue.
  77. AUTO_QUEUE = 7, // Auto-selected queue.
  78. OTHER_QUEUE = 8
  79. };
  80. // QueueBase, templated on the StateId, is a virtual base class shared by all
  81. // queues considered by AutoQueue.
  82. template <class S>
  83. class QueueBase {
  84. public:
  85. using StateId = S;
  86. virtual ~QueueBase() = default;
  87. // Concrete implementation.
  88. explicit QueueBase(QueueType type) : queue_type_(type), error_(false) {}
  89. void SetError(bool error) { error_ = error; }
  90. bool Error() const { return error_; }
  91. QueueType Type() const { return queue_type_; }
  92. // Virtual interface.
  93. virtual StateId Head() const = 0;
  94. virtual void Enqueue(StateId) = 0;
  95. virtual void Dequeue() = 0;
  96. virtual void Update(StateId) = 0;
  97. virtual bool Empty() const = 0;
  98. virtual void Clear() = 0;
  99. private:
  100. QueueType queue_type_;
  101. bool error_;
  102. };
  103. // Trivial queue discipline; one may enqueue at most one state at a time. It
  104. // can be used for strongly connected components with only one state and no
  105. // self-loops.
  106. template <class S>
  107. class TrivialQueue : public QueueBase<S> {
  108. public:
  109. using StateId = S;
  110. TrivialQueue() : QueueBase<StateId>(TRIVIAL_QUEUE), front_(kNoStateId) {}
  111. ~TrivialQueue() override = default;
  112. StateId Head() const final { return front_; }
  113. void Enqueue(StateId s) final { front_ = s; }
  114. void Dequeue() final { front_ = kNoStateId; }
  115. void Update(StateId) final {}
  116. bool Empty() const final { return front_ == kNoStateId; }
  117. void Clear() final { front_ = kNoStateId; }
  118. private:
  119. StateId front_;
  120. };
  121. // First-in, first-out queue discipline.
  122. //
  123. // This is not a final class.
  124. template <class S>
  125. class FifoQueue : public QueueBase<S> {
  126. public:
  127. using StateId = S;
  128. FifoQueue() : QueueBase<StateId>(FIFO_QUEUE) {}
  129. ~FifoQueue() override = default;
  130. StateId Head() const override { return queue_.front(); }
  131. void Enqueue(StateId s) override { queue_.push(s); }
  132. void Dequeue() override { queue_.pop(); }
  133. void Update(StateId) override {}
  134. bool Empty() const override { return queue_.empty(); }
  135. void Clear() override { queue_ = std::queue<StateId>(); }
  136. private:
  137. std::queue<StateId> queue_;
  138. };
  139. // Last-in, first-out queue discipline.
  140. template <class S>
  141. class LifoQueue : public QueueBase<S> {
  142. public:
  143. using StateId = S;
  144. LifoQueue() : QueueBase<StateId>(LIFO_QUEUE) {}
  145. ~LifoQueue() override = default;
  146. StateId Head() const final { return stack_.back(); }
  147. void Enqueue(StateId s) final { stack_.push_back(s); }
  148. void Dequeue() final { stack_.pop_back(); }
  149. void Update(StateId) final {}
  150. bool Empty() const final { return stack_.empty(); }
  151. void Clear() final { stack_.clear(); }
  152. private:
  153. std::vector<StateId> stack_;
  154. };
  155. // Shortest-first queue discipline, templated on the StateId and as well as a
  156. // comparison functor used to compare two StateIds. If a (single) state's order
  157. // changes, it can be reordered in the queue with a call to Update(). If update
  158. // is false, call to Update() does not reorder the queue.
  159. //
  160. // This is not a final class.
  161. template <typename S, typename Compare, bool update = true>
  162. class ShortestFirstQueue : public QueueBase<S> {
  163. public:
  164. using StateId = S;
  165. explicit ShortestFirstQueue(Compare comp)
  166. : QueueBase<StateId>(SHORTEST_FIRST_QUEUE), heap_(comp) {}
  167. ~ShortestFirstQueue() override = default;
  168. StateId Head() const override { return heap_.Top(); }
  169. void Enqueue(StateId s) override {
  170. if (update) {
  171. for (StateId i = key_.size(); i <= s; ++i) key_.push_back(kNoStateId);
  172. key_[s] = heap_.Insert(s);
  173. } else {
  174. heap_.Insert(s);
  175. }
  176. }
  177. void Dequeue() override {
  178. if (update) {
  179. key_[heap_.Pop()] = kNoStateId;
  180. } else {
  181. heap_.Pop();
  182. }
  183. }
  184. void Update(StateId s) override {
  185. if (!update) return;
  186. if (s >= key_.size() || key_[s] == kNoStateId) {
  187. Enqueue(s);
  188. } else {
  189. heap_.Update(key_[s], s);
  190. }
  191. }
  192. bool Empty() const override { return heap_.Empty(); }
  193. void Clear() override {
  194. heap_.Clear();
  195. if (update) key_.clear();
  196. }
  197. ssize_t Size() const { return heap_.Size(); }
  198. const Compare &GetCompare() const { return heap_.GetCompare(); }
  199. private:
  200. Heap<StateId, Compare> heap_;
  201. std::vector<ssize_t> key_;
  202. };
  203. namespace internal {
  204. // Given a vector that maps from states to weights, and a comparison functor
  205. // for weights, this class defines a comparison function object between states.
  206. template <typename StateId, typename Less>
  207. class StateWeightCompare {
  208. public:
  209. using Weight = typename Less::Weight;
  210. StateWeightCompare(const std::vector<Weight> &weights, const Less &less)
  211. : weights_(weights), less_(less) {}
  212. bool operator()(const StateId s1, const StateId s2) const {
  213. return less_(weights_[s1], weights_[s2]);
  214. }
  215. private:
  216. // Borrowed references.
  217. const std::vector<Weight> &weights_;
  218. const Less &less_;
  219. };
  220. // Comparison that can never be instantiated. Useful only to pass a pointer to
  221. // this to a function that needs a comparison when it is known that the pointer
  222. // will always be null.
  223. template <class W>
  224. struct ErrorLess {
  225. using Weight = W;
  226. ErrorLess() {
  227. FSTERROR() << "ErrorLess: instantiated for Weight " << Weight::Type();
  228. }
  229. bool operator()(const Weight &, const Weight &) const { return false; }
  230. };
  231. } // namespace internal
  232. // Shortest-first queue discipline, templated on the StateId and Weight, is
  233. // specialized to use the weight's natural order for the comparison function.
  234. // Requires Weight is idempotent (due to use of NaturalLess).
  235. template <typename S, typename Weight>
  236. class NaturalShortestFirstQueue
  237. : public ShortestFirstQueue<
  238. S, internal::StateWeightCompare<S, NaturalLess<Weight>>> {
  239. public:
  240. using StateId = S;
  241. using Less = NaturalLess<Weight>;
  242. using Compare = internal::StateWeightCompare<StateId, Less>;
  243. explicit NaturalShortestFirstQueue(const std::vector<Weight> &distance)
  244. : ShortestFirstQueue<StateId, Compare>(Compare(distance, Less())) {}
  245. ~NaturalShortestFirstQueue() override = default;
  246. };
  247. // In a shortest path computation on a lattice-like FST, we may keep many old
  248. // nonviable paths as a part of the search. Since the search process always
  249. // expands the lowest cost path next, that lowest cost path may be a very old
  250. // nonviable path instead of one we expect to lead to a shortest path.
  251. //
  252. // For instance, suppose that the current best path in an alignment has
  253. // traversed 500 arcs with a cost of 10. We may also have a bad path in
  254. // the queue that has traversed only 40 arcs but also has a cost of 10.
  255. // This path is very unlikely to lead to a reasonable alignment, so this queue
  256. // can prune it from the search space.
  257. //
  258. // This queue relies on the caller using a shortest-first exploration order
  259. // like this:
  260. // while (true) {
  261. // StateId head = queue.Head();
  262. // queue.Dequeue();
  263. // for (const auto& arc : GetArcs(fst, head)) {
  264. // queue.Enqueue(arc.nextstate);
  265. // }
  266. // }
  267. // We use this assumption to guess that there is an arc between Head and the
  268. // Enqueued state; this is how the number of path steps is measured.
  269. template <typename S, typename Weight>
  270. class PruneNaturalShortestFirstQueue
  271. : public NaturalShortestFirstQueue<S, Weight> {
  272. public:
  273. using StateId = S;
  274. using Base = NaturalShortestFirstQueue<StateId, Weight>;
  275. PruneNaturalShortestFirstQueue(const std::vector<Weight> &distance,
  276. ssize_t arc_threshold, ssize_t state_limit = 0)
  277. : Base(distance),
  278. arc_threshold_(arc_threshold),
  279. state_limit_(state_limit),
  280. head_steps_(0),
  281. max_head_steps_(0) {}
  282. ~PruneNaturalShortestFirstQueue() override = default;
  283. StateId Head() const override {
  284. const auto head = Base::Head();
  285. // Stores the number of steps from the start of the graph to this state
  286. // along the shortest-weight path.
  287. if (head < steps_.size()) {
  288. max_head_steps_ = std::max(steps_[head], max_head_steps_);
  289. head_steps_ = steps_[head];
  290. }
  291. return head;
  292. }
  293. void Enqueue(StateId s) override {
  294. // We assume that there is an arc between the Head() state and this
  295. // Enqueued state.
  296. const ssize_t state_steps = head_steps_ + 1;
  297. if (s >= steps_.size()) {
  298. steps_.resize(s + 1, state_steps);
  299. }
  300. // This is the number of arcs in the minimum cost path from Start to s.
  301. steps_[s] = state_steps;
  302. // Adjust the threshold in cases where path step thresholding wasn't
  303. // enough to keep the queue small.
  304. ssize_t adjusted_threshold = arc_threshold_;
  305. if (Base::Size() > state_limit_ && state_limit_ > 0) {
  306. adjusted_threshold = std::max<ssize_t>(
  307. 0, arc_threshold_ - (Base::Size() / state_limit_) - 1);
  308. }
  309. if (state_steps > (max_head_steps_ - adjusted_threshold) ||
  310. arc_threshold_ < 0) {
  311. if (adjusted_threshold == 0 && state_limit_ > 0) {
  312. // If the queue is continuing to grow without bound, we follow any
  313. // path that makes progress and clear the rest.
  314. Base::Clear();
  315. }
  316. Base::Enqueue(s);
  317. }
  318. }
  319. private:
  320. // A dense map from StateId to the number of arcs in the minimum weight
  321. // path from Start to this state.
  322. std::vector<ssize_t> steps_;
  323. // We only keep paths that are within this number of arcs (not weight!)
  324. // of the longest path.
  325. const ssize_t arc_threshold_;
  326. // If the size of the queue climbs above this number, we increase the
  327. // threshold to reduce the amount of work we have to do.
  328. const ssize_t state_limit_;
  329. // The following are mutable because Head() is const.
  330. // The number of arcs traversed in the minimum cost path from the start
  331. // state to the current Head() state.
  332. mutable ssize_t head_steps_;
  333. // The maximum number of arcs traversed by any low-cost path so far.
  334. mutable ssize_t max_head_steps_;
  335. };
  336. // Topological-order queue discipline, templated on the StateId. States are
  337. // ordered in the queue topologically. The FST must be acyclic.
  338. template <class S>
  339. class TopOrderQueue : public QueueBase<S> {
  340. public:
  341. using StateId = S;
  342. // This constructor computes the topological order. It accepts an arc filter
  343. // to limit the transitions considered in that computation (e.g., only the
  344. // epsilon graph).
  345. template <class Arc, class ArcFilter>
  346. TopOrderQueue(const Fst<Arc> &fst, ArcFilter filter)
  347. : QueueBase<StateId>(TOP_ORDER_QUEUE),
  348. front_(0),
  349. back_(kNoStateId),
  350. order_(0),
  351. state_(0) {
  352. bool acyclic;
  353. TopOrderVisitor<Arc> top_order_visitor(&order_, &acyclic);
  354. DfsVisit(fst, &top_order_visitor, filter);
  355. if (!acyclic) {
  356. FSTERROR() << "TopOrderQueue: FST is not acyclic";
  357. QueueBase<S>::SetError(true);
  358. }
  359. state_.resize(order_.size(), kNoStateId);
  360. }
  361. // This constructor is passed the pre-computed topological order.
  362. explicit TopOrderQueue(const std::vector<StateId> &order)
  363. : QueueBase<StateId>(TOP_ORDER_QUEUE),
  364. front_(0),
  365. back_(kNoStateId),
  366. order_(order),
  367. state_(order.size(), kNoStateId) {}
  368. ~TopOrderQueue() override = default;
  369. StateId Head() const final { return state_[front_]; }
  370. void Enqueue(StateId s) final {
  371. if (front_ > back_) {
  372. front_ = back_ = order_[s];
  373. } else if (order_[s] > back_) {
  374. back_ = order_[s];
  375. } else if (order_[s] < front_) {
  376. front_ = order_[s];
  377. }
  378. state_[order_[s]] = s;
  379. }
  380. void Dequeue() final {
  381. state_[front_] = kNoStateId;
  382. while ((front_ <= back_) && (state_[front_] == kNoStateId)) ++front_;
  383. }
  384. void Update(StateId) final {}
  385. bool Empty() const final { return front_ > back_; }
  386. void Clear() final {
  387. for (StateId s = front_; s <= back_; ++s) state_[s] = kNoStateId;
  388. back_ = kNoStateId;
  389. front_ = 0;
  390. }
  391. private:
  392. StateId front_;
  393. StateId back_;
  394. std::vector<StateId> order_;
  395. std::vector<StateId> state_;
  396. };
  397. // State order queue discipline, templated on the StateId. States are ordered in
  398. // the queue by state ID.
  399. template <class S>
  400. class StateOrderQueue : public QueueBase<S> {
  401. public:
  402. using StateId = S;
  403. StateOrderQueue()
  404. : QueueBase<StateId>(STATE_ORDER_QUEUE), front_(0), back_(kNoStateId) {}
  405. ~StateOrderQueue() override = default;
  406. StateId Head() const final { return front_; }
  407. void Enqueue(StateId s) final {
  408. if (front_ > back_) {
  409. front_ = back_ = s;
  410. } else if (s > back_) {
  411. back_ = s;
  412. } else if (s < front_) {
  413. front_ = s;
  414. }
  415. while (enqueued_.size() <= s) enqueued_.push_back(false);
  416. enqueued_[s] = true;
  417. }
  418. void Dequeue() final {
  419. enqueued_[front_] = false;
  420. while ((front_ <= back_) && (enqueued_[front_] == false)) ++front_;
  421. }
  422. void Update(StateId) final {}
  423. bool Empty() const final { return front_ > back_; }
  424. void Clear() final {
  425. for (StateId i = front_; i <= back_; ++i) enqueued_[i] = false;
  426. front_ = 0;
  427. back_ = kNoStateId;
  428. }
  429. private:
  430. StateId front_;
  431. StateId back_;
  432. std::vector<bool> enqueued_;
  433. };
  434. // SCC topological-order meta-queue discipline, templated on the StateId and a
  435. // queue used inside each SCC. It visits the SCCs of an FST in topological
  436. // order. Its constructor is passed the queues to to use within an SCC.
  437. template <class S, class Queue>
  438. class SccQueue : public QueueBase<S> {
  439. public:
  440. using StateId = S;
  441. // Constructor takes a vector specifying the SCC number per state and a
  442. // vector giving the queue to use per SCC number.
  443. SccQueue(const std::vector<StateId> &scc,
  444. std::vector<std::unique_ptr<Queue>> *queue)
  445. : QueueBase<StateId>(SCC_QUEUE),
  446. queue_(queue),
  447. scc_(scc),
  448. front_(0),
  449. back_(kNoStateId) {}
  450. ~SccQueue() override = default;
  451. StateId Head() const final {
  452. while ((front_ <= back_) &&
  453. (((*queue_)[front_] && (*queue_)[front_]->Empty()) ||
  454. (((*queue_)[front_] == nullptr) &&
  455. ((front_ >= trivial_queue_.size()) ||
  456. (trivial_queue_[front_] == kNoStateId))))) {
  457. ++front_;
  458. }
  459. if ((*queue_)[front_]) {
  460. return (*queue_)[front_]->Head();
  461. } else {
  462. return trivial_queue_[front_];
  463. }
  464. }
  465. void Enqueue(StateId s) final {
  466. if (front_ > back_) {
  467. front_ = back_ = scc_[s];
  468. } else if (scc_[s] > back_) {
  469. back_ = scc_[s];
  470. } else if (scc_[s] < front_) {
  471. front_ = scc_[s];
  472. }
  473. if ((*queue_)[scc_[s]]) {
  474. (*queue_)[scc_[s]]->Enqueue(s);
  475. } else {
  476. while (trivial_queue_.size() <= scc_[s]) {
  477. trivial_queue_.push_back(kNoStateId);
  478. }
  479. trivial_queue_[scc_[s]] = s;
  480. }
  481. }
  482. void Dequeue() final {
  483. if ((*queue_)[front_]) {
  484. (*queue_)[front_]->Dequeue();
  485. } else if (front_ < trivial_queue_.size()) {
  486. trivial_queue_[front_] = kNoStateId;
  487. }
  488. }
  489. void Update(StateId s) final {
  490. if ((*queue_)[scc_[s]]) (*queue_)[scc_[s]]->Update(s);
  491. }
  492. bool Empty() const final {
  493. // Queues SCC number back_ is not empty unless back_ == front_.
  494. if (front_ < back_) {
  495. return false;
  496. } else if (front_ > back_) {
  497. return true;
  498. } else if ((*queue_)[front_]) {
  499. return (*queue_)[front_]->Empty();
  500. } else {
  501. return (front_ >= trivial_queue_.size()) ||
  502. (trivial_queue_[front_] == kNoStateId);
  503. }
  504. }
  505. void Clear() final {
  506. for (StateId i = front_; i <= back_; ++i) {
  507. if ((*queue_)[i]) {
  508. (*queue_)[i]->Clear();
  509. } else if (i < trivial_queue_.size()) {
  510. trivial_queue_[i] = kNoStateId;
  511. }
  512. }
  513. front_ = 0;
  514. back_ = kNoStateId;
  515. }
  516. private:
  517. std::vector<std::unique_ptr<Queue>> *queue_;
  518. const std::vector<StateId> &scc_;
  519. mutable StateId front_;
  520. StateId back_;
  521. std::vector<StateId> trivial_queue_;
  522. };
  523. // Automatic queue discipline. It selects a queue discipline for a given FST
  524. // based on its properties.
  525. template <class S>
  526. class AutoQueue : public QueueBase<S> {
  527. public:
  528. using StateId = S;
  529. // This constructor takes a state distance vector that, if non-null and if
  530. // the Weight type has the path property, will entertain the shortest-first
  531. // queue using the natural order w.r.t to the distance.
  532. template <class Arc, class ArcFilter>
  533. AutoQueue(const Fst<Arc> &fst,
  534. const std::vector<typename Arc::Weight> *distance, ArcFilter filter)
  535. : QueueBase<StateId>(AUTO_QUEUE) {
  536. using Weight = typename Arc::Weight;
  537. // We need to have variables of type Less and Compare, so we use
  538. // ErrorLess if the type NaturalLess<Weight> cannot be instantiated due
  539. // to lack of path property.
  540. using Less = std::conditional_t<IsPath<Weight>::value, NaturalLess<Weight>,
  541. internal::ErrorLess<Weight>>;
  542. using Compare = internal::StateWeightCompare<StateId, Less>;
  543. // First checks if the FST is known to have these properties.
  544. const auto props =
  545. fst.Properties(kAcyclic | kCyclic | kTopSorted | kUnweighted, false);
  546. if ((props & kTopSorted) || fst.Start() == kNoStateId) {
  547. queue_ = std::make_unique<StateOrderQueue<StateId>>();
  548. VLOG(2) << "AutoQueue: using state-order discipline";
  549. } else if (props & kAcyclic) {
  550. queue_ = std::make_unique<TopOrderQueue<StateId>>(fst, filter);
  551. VLOG(2) << "AutoQueue: using top-order discipline";
  552. } else if ((props & kUnweighted) && IsIdempotent<Weight>::value) {
  553. queue_ = std::make_unique<LifoQueue<StateId>>();
  554. VLOG(2) << "AutoQueue: using LIFO discipline";
  555. } else {
  556. uint64_t properties;
  557. // Decomposes into strongly-connected components.
  558. SccVisitor<Arc> scc_visitor(&scc_, nullptr, nullptr, &properties);
  559. DfsVisit(fst, &scc_visitor, filter);
  560. auto nscc = *std::max_element(scc_.begin(), scc_.end()) + 1;
  561. std::vector<QueueType> queue_types(nscc);
  562. std::unique_ptr<Less> less;
  563. std::unique_ptr<Compare> comp;
  564. if constexpr (IsPath<Weight>::value) {
  565. if (distance) {
  566. less = std::make_unique<Less>();
  567. comp = std::make_unique<Compare>(*distance, *less);
  568. }
  569. }
  570. // Finds the queue type to use per SCC.
  571. bool unweighted;
  572. bool all_trivial;
  573. SccQueueType(fst, scc_, &queue_types, filter, less.get(), &all_trivial,
  574. &unweighted);
  575. // If unweighted and semiring is idempotent, uses LIFO queue.
  576. if (unweighted) {
  577. queue_ = std::make_unique<LifoQueue<StateId>>();
  578. VLOG(2) << "AutoQueue: using LIFO discipline";
  579. return;
  580. }
  581. // If all the SCC are trivial, the FST is acyclic and the scc number gives
  582. // the topological order.
  583. if (all_trivial) {
  584. queue_ = std::make_unique<TopOrderQueue<StateId>>(scc_);
  585. VLOG(2) << "AutoQueue: using top-order discipline";
  586. return;
  587. }
  588. VLOG(2) << "AutoQueue: using SCC meta-discipline";
  589. queues_.resize(nscc);
  590. for (StateId i = 0; i < nscc; ++i) {
  591. switch (queue_types[i]) {
  592. case TRIVIAL_QUEUE:
  593. queues_[i].reset();
  594. VLOG(3) << "AutoQueue: SCC #" << i << ": using trivial discipline";
  595. break;
  596. case SHORTEST_FIRST_QUEUE:
  597. // The IsPath test is not needed for correctness. It just saves
  598. // instantiating a ShortestFirstQueue that can never be called.
  599. if constexpr (IsPath<Weight>::value) {
  600. queues_[i] =
  601. std::make_unique<ShortestFirstQueue<StateId, Compare, false>>(
  602. *comp);
  603. VLOG(3) << "AutoQueue: SCC #" << i
  604. << ": using shortest-first discipline";
  605. } else {
  606. // SccQueueType should ensure this can never happen.
  607. FSTERROR() << "Got SHORTEST_FIRST_QUEUE for non-Path Weight "
  608. << Weight::Type();
  609. queues_[i].reset();
  610. }
  611. break;
  612. case LIFO_QUEUE:
  613. queues_[i] = std::make_unique<LifoQueue<StateId>>();
  614. VLOG(3) << "AutoQueue: SCC #" << i << ": using LIFO discipline";
  615. break;
  616. case FIFO_QUEUE:
  617. default:
  618. queues_[i] = std::make_unique<FifoQueue<StateId>>();
  619. VLOG(3) << "AutoQueue: SCC #" << i << ": using FIFO discipine";
  620. break;
  621. }
  622. }
  623. queue_ = std::make_unique<SccQueue<StateId, QueueBase<StateId>>>(
  624. scc_, &queues_);
  625. }
  626. }
  627. ~AutoQueue() override = default;
  628. StateId Head() const final { return queue_->Head(); }
  629. void Enqueue(StateId s) final { queue_->Enqueue(s); }
  630. void Dequeue() final { queue_->Dequeue(); }
  631. void Update(StateId s) final { queue_->Update(s); }
  632. bool Empty() const final { return queue_->Empty(); }
  633. void Clear() final { queue_->Clear(); }
  634. private:
  635. template <class Arc, class ArcFilter, class Less>
  636. static void SccQueueType(const Fst<Arc> &fst, const std::vector<StateId> &scc,
  637. std::vector<QueueType> *queue_types,
  638. ArcFilter filter, Less *less, bool *all_trivial,
  639. bool *unweighted);
  640. std::unique_ptr<QueueBase<StateId>> queue_;
  641. std::vector<std::unique_ptr<QueueBase<StateId>>> queues_;
  642. std::vector<StateId> scc_;
  643. };
  644. // Examines the states in an FST's strongly connected components and determines
  645. // which type of queue to use per SCC. Stores result as a vector of QueueTypes
  646. // which is assumed to have length equal to the number of SCCs. An arc filter
  647. // is used to limit the transitions considered (e.g., only the epsilon graph).
  648. // The argument all_trivial is set to true if every queue is the trivial queue.
  649. // The argument unweighted is set to true if the semiring is idempotent and all
  650. // the arc weights are equal to Zero() or One().
  651. template <class StateId>
  652. template <class Arc, class ArcFilter, class Less>
  653. void AutoQueue<StateId>::SccQueueType(const Fst<Arc> &fst,
  654. const std::vector<StateId> &scc,
  655. std::vector<QueueType> *queue_type,
  656. ArcFilter filter, Less *less,
  657. bool *all_trivial, bool *unweighted) {
  658. using StateId = typename Arc::StateId;
  659. using Weight = typename Arc::Weight;
  660. *all_trivial = true;
  661. *unweighted = true;
  662. for (StateId i = 0; i < queue_type->size(); ++i) {
  663. (*queue_type)[i] = TRIVIAL_QUEUE;
  664. }
  665. for (StateIterator<Fst<Arc>> sit(fst); !sit.Done(); sit.Next()) {
  666. const auto state = sit.Value();
  667. for (ArcIterator<Fst<Arc>> ait(fst, state); !ait.Done(); ait.Next()) {
  668. const auto &arc = ait.Value();
  669. if (!filter(arc)) continue;
  670. if (scc[state] == scc[arc.nextstate]) {
  671. auto &type = (*queue_type)[scc[state]];
  672. if constexpr (!IsPath<Weight>::value) {
  673. type = FIFO_QUEUE;
  674. } else if (!less || (*less)(arc.weight, Weight::One())) {
  675. type = FIFO_QUEUE;
  676. } else if ((type == TRIVIAL_QUEUE) || (type == LIFO_QUEUE)) {
  677. if (!IsIdempotent<Weight>::value ||
  678. (arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
  679. type = SHORTEST_FIRST_QUEUE;
  680. } else {
  681. type = LIFO_QUEUE;
  682. }
  683. }
  684. if (type != TRIVIAL_QUEUE) *all_trivial = false;
  685. }
  686. if (!IsIdempotent<Weight>::value ||
  687. (arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
  688. *unweighted = false;
  689. }
  690. }
  691. }
  692. }
  693. // An A* estimate is a function object that maps from a state ID to an
  694. // estimate of the shortest distance to the final states.
  695. // A trivial A* estimate, yielding a queue which behaves the same in Dijkstra's
  696. // algorithm.
  697. template <typename StateId, typename Weight>
  698. struct TrivialAStarEstimate {
  699. constexpr Weight operator()(StateId) const { return Weight::One(); }
  700. };
  701. // A non-trivial A* estimate using a vector of the estimated future costs.
  702. template <typename StateId, typename Weight>
  703. class NaturalAStarEstimate {
  704. public:
  705. NaturalAStarEstimate(const std::vector<Weight> &beta) : beta_(beta) {}
  706. const Weight &operator()(StateId s) const {
  707. return (s < beta_.size()) ? beta_[s] : kZero;
  708. }
  709. private:
  710. static constexpr Weight kZero = Weight::Zero();
  711. const std::vector<Weight> &beta_;
  712. };
  713. // Given a vector that maps from states to weights representing the shortest
  714. // distance from the initial state, a comparison function object between
  715. // weights, and an estimate of the shortest distance to the final states, this
  716. // class defines a comparison function object between states.
  717. template <typename S, typename Less, typename Estimate>
  718. class AStarWeightCompare {
  719. public:
  720. using StateId = S;
  721. using Weight = typename Less::Weight;
  722. AStarWeightCompare(const std::vector<Weight> &weights, const Less &less,
  723. const Estimate &estimate)
  724. : weights_(weights), less_(less), estimate_(estimate) {}
  725. bool operator()(StateId s1, StateId s2) const {
  726. const auto w1 = Times(weights_[s1], estimate_(s1));
  727. const auto w2 = Times(weights_[s2], estimate_(s2));
  728. return less_(w1, w2);
  729. }
  730. const Estimate &GetEstimate() const { return estimate_; }
  731. private:
  732. const std::vector<Weight> &weights_;
  733. const Less &less_;
  734. const Estimate &estimate_;
  735. };
  736. // A* queue discipline templated on StateId, Weight, and Estimate.
  737. template <typename S, typename Weight, typename Estimate>
  738. class NaturalAStarQueue
  739. : public ShortestFirstQueue<
  740. S, AStarWeightCompare<S, NaturalLess<Weight>, Estimate>> {
  741. public:
  742. using StateId = S;
  743. using Compare = AStarWeightCompare<StateId, NaturalLess<Weight>, Estimate>;
  744. NaturalAStarQueue(const std::vector<Weight> &distance,
  745. const Estimate &estimate)
  746. : ShortestFirstQueue<StateId, Compare>(
  747. Compare(distance, less_, estimate)) {}
  748. ~NaturalAStarQueue() override = default;
  749. private:
  750. // This is non-static because the constructor for non-idempotent weights will
  751. // result in an error.
  752. const NaturalLess<Weight> less_{};
  753. };
  754. // A state equivalence class is a function object that maps from a state ID to
  755. // an equivalence class (state) ID. The trivial equivalence class maps a state
  756. // ID to itself.
  757. template <typename StateId>
  758. struct TrivialStateEquivClass {
  759. StateId operator()(StateId s) const { return s; }
  760. };
  761. // Distance-based pruning queue discipline: Enqueues a state only when its
  762. // shortest distance (so far), as specified by distance, is less than (as
  763. // specified by comp) the shortest distance Times() the threshold to any state
  764. // in the same equivalence class, as specified by the functor class_func. The
  765. // underlying queue discipline is specified by queue.
  766. //
  767. // This is not a final class.
  768. template <typename Queue, typename Less, typename ClassFnc>
  769. class PruneQueue : public QueueBase<typename Queue::StateId> {
  770. public:
  771. using StateId = typename Queue::StateId;
  772. using Weight = typename Less::Weight;
  773. PruneQueue(const std::vector<Weight> &distance, std::unique_ptr<Queue> queue,
  774. const Less &less, const ClassFnc &class_fnc, Weight threshold)
  775. : QueueBase<StateId>(OTHER_QUEUE),
  776. distance_(distance),
  777. queue_(std::move(queue)),
  778. less_(less),
  779. class_fnc_(class_fnc),
  780. threshold_(std::move(threshold)) {}
  781. ~PruneQueue() override = default;
  782. StateId Head() const override { return queue_->Head(); }
  783. void Enqueue(StateId s) override {
  784. const auto c = class_fnc_(s);
  785. if (c >= class_distance_.size()) {
  786. class_distance_.resize(c + 1, Weight::Zero());
  787. }
  788. if (less_(distance_[s], class_distance_[c])) {
  789. class_distance_[c] = distance_[s];
  790. }
  791. // Enqueues only if below threshold limit.
  792. const auto limit = Times(class_distance_[c], threshold_);
  793. if (less_(distance_[s], limit)) queue_->Enqueue(s);
  794. }
  795. void Dequeue() override { queue_->Dequeue(); }
  796. void Update(StateId s) override {
  797. const auto c = class_fnc_(s);
  798. if (less_(distance_[s], class_distance_[c])) {
  799. class_distance_[c] = distance_[s];
  800. }
  801. queue_->Update(s);
  802. }
  803. bool Empty() const override { return queue_->Empty(); }
  804. void Clear() override { queue_->Clear(); }
  805. private:
  806. const std::vector<Weight> &distance_; // Shortest distance to state.
  807. std::unique_ptr<Queue> queue_;
  808. const Less &less_; // Borrowed reference.
  809. const ClassFnc &class_fnc_; // Equivalence class functor.
  810. Weight threshold_; // Pruning weight threshold.
  811. std::vector<Weight> class_distance_; // Shortest distance to class.
  812. };
  813. // Pruning queue discipline (see above) using the weight's natural order for the
  814. // comparison function. The ownership of the queue argument is given to this
  815. // class.
  816. template <typename Queue, typename Weight, typename ClassFnc>
  817. class NaturalPruneQueue final
  818. : public PruneQueue<Queue, NaturalLess<Weight>, ClassFnc> {
  819. public:
  820. using StateId = typename Queue::StateId;
  821. NaturalPruneQueue(const std::vector<Weight> &distance,
  822. std::unique_ptr<Queue> queue, const ClassFnc &class_fnc,
  823. Weight threshold)
  824. : PruneQueue<Queue, NaturalLess<Weight>, ClassFnc>(
  825. distance, std::move(queue), NaturalLess<Weight>(), class_fnc,
  826. threshold) {}
  827. ~NaturalPruneQueue() override = default;
  828. };
  829. // Filter-based pruning queue discipline: enqueues a state only if allowed by
  830. // the filter, specified by the state filter functor argument. The underlying
  831. // queue discipline is specified by the queue argument.
  832. template <typename Queue, typename Filter>
  833. class FilterQueue : public QueueBase<typename Queue::StateId> {
  834. public:
  835. using StateId = typename Queue::StateId;
  836. FilterQueue(std::unique_ptr<Queue> queue, const Filter &filter)
  837. : QueueBase<StateId>(OTHER_QUEUE),
  838. queue_(std::move(queue)),
  839. filter_(filter) {}
  840. ~FilterQueue() override = default;
  841. StateId Head() const final { return queue_->Head(); }
  842. // Enqueues only if allowed by state filter.
  843. void Enqueue(StateId s) final {
  844. if (filter_(s)) queue_->Enqueue(s);
  845. }
  846. void Dequeue() final { queue_->Dequeue(); }
  847. void Update(StateId s) final {}
  848. bool Empty() const final { return queue_->Empty(); }
  849. void Clear() final { queue_->Clear(); }
  850. private:
  851. std::unique_ptr<Queue> queue_;
  852. const Filter &filter_;
  853. };
  854. } // namespace fst
  855. #endif // FST_QUEUE_H_