// Copyright 2005-2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the 'License'); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an 'AS IS' BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Functions and classes for various FST state queues with a unified interface. #ifndef FST_QUEUE_H_ #define FST_QUEUE_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace fst { // The Queue interface is: // // template // class Queue { // public: // using StateId = S; // // // Constructor: may need args (e.g., FST, comparator) for some queues. // Queue(...) override; // // // Returns the head of the queue. // StateId Head() const override; // // // Inserts a state. // void Enqueue(StateId s) override; // // // Removes the head of the queue. // void Dequeue() override; // // // Updates ordering of state s when weight changes, if necessary. // void Update(StateId s) override; // // // Is the queue empty? // bool Empty() const override; // // // Removes all states from the queue. // void Clear() override; // }; // State queue types. enum QueueType { TRIVIAL_QUEUE = 0, // Single state queue. FIFO_QUEUE = 1, // First-in, first-out queue. LIFO_QUEUE = 2, // Last-in, first-out queue. SHORTEST_FIRST_QUEUE = 3, // Shortest-first queue. TOP_ORDER_QUEUE = 4, // Topologically-ordered queue. STATE_ORDER_QUEUE = 5, // State ID-ordered queue. SCC_QUEUE = 6, // Component graph top-ordered meta-queue. AUTO_QUEUE = 7, // Auto-selected queue. OTHER_QUEUE = 8 }; // QueueBase, templated on the StateId, is a virtual base class shared by all // queues considered by AutoQueue. template class QueueBase { public: using StateId = S; virtual ~QueueBase() = default; // Concrete implementation. explicit QueueBase(QueueType type) : queue_type_(type), error_(false) {} void SetError(bool error) { error_ = error; } bool Error() const { return error_; } QueueType Type() const { return queue_type_; } // Virtual interface. virtual StateId Head() const = 0; virtual void Enqueue(StateId) = 0; virtual void Dequeue() = 0; virtual void Update(StateId) = 0; virtual bool Empty() const = 0; virtual void Clear() = 0; private: QueueType queue_type_; bool error_; }; // Trivial queue discipline; one may enqueue at most one state at a time. It // can be used for strongly connected components with only one state and no // self-loops. template class TrivialQueue : public QueueBase { public: using StateId = S; TrivialQueue() : QueueBase(TRIVIAL_QUEUE), front_(kNoStateId) {} ~TrivialQueue() override = default; StateId Head() const final { return front_; } void Enqueue(StateId s) final { front_ = s; } void Dequeue() final { front_ = kNoStateId; } void Update(StateId) final {} bool Empty() const final { return front_ == kNoStateId; } void Clear() final { front_ = kNoStateId; } private: StateId front_; }; // First-in, first-out queue discipline. // // This is not a final class. template class FifoQueue : public QueueBase { public: using StateId = S; FifoQueue() : QueueBase(FIFO_QUEUE) {} ~FifoQueue() override = default; StateId Head() const override { return queue_.front(); } void Enqueue(StateId s) override { queue_.push(s); } void Dequeue() override { queue_.pop(); } void Update(StateId) override {} bool Empty() const override { return queue_.empty(); } void Clear() override { queue_ = std::queue(); } private: std::queue queue_; }; // Last-in, first-out queue discipline. template class LifoQueue : public QueueBase { public: using StateId = S; LifoQueue() : QueueBase(LIFO_QUEUE) {} ~LifoQueue() override = default; StateId Head() const final { return stack_.back(); } void Enqueue(StateId s) final { stack_.push_back(s); } void Dequeue() final { stack_.pop_back(); } void Update(StateId) final {} bool Empty() const final { return stack_.empty(); } void Clear() final { stack_.clear(); } private: std::vector stack_; }; // Shortest-first queue discipline, templated on the StateId and as well as a // comparison functor used to compare two StateIds. If a (single) state's order // changes, it can be reordered in the queue with a call to Update(). If update // is false, call to Update() does not reorder the queue. // // This is not a final class. template class ShortestFirstQueue : public QueueBase { public: using StateId = S; explicit ShortestFirstQueue(Compare comp) : QueueBase(SHORTEST_FIRST_QUEUE), heap_(comp) {} ~ShortestFirstQueue() override = default; StateId Head() const override { return heap_.Top(); } void Enqueue(StateId s) override { if (update) { for (StateId i = key_.size(); i <= s; ++i) key_.push_back(kNoStateId); key_[s] = heap_.Insert(s); } else { heap_.Insert(s); } } void Dequeue() override { if (update) { key_[heap_.Pop()] = kNoStateId; } else { heap_.Pop(); } } void Update(StateId s) override { if (!update) return; if (s >= key_.size() || key_[s] == kNoStateId) { Enqueue(s); } else { heap_.Update(key_[s], s); } } bool Empty() const override { return heap_.Empty(); } void Clear() override { heap_.Clear(); if (update) key_.clear(); } ssize_t Size() const { return heap_.Size(); } const Compare &GetCompare() const { return heap_.GetCompare(); } private: Heap heap_; std::vector key_; }; namespace internal { // Given a vector that maps from states to weights, and a comparison functor // for weights, this class defines a comparison function object between states. template class StateWeightCompare { public: using Weight = typename Less::Weight; StateWeightCompare(const std::vector &weights, const Less &less) : weights_(weights), less_(less) {} bool operator()(const StateId s1, const StateId s2) const { return less_(weights_[s1], weights_[s2]); } private: // Borrowed references. const std::vector &weights_; const Less &less_; }; // Comparison that can never be instantiated. Useful only to pass a pointer to // this to a function that needs a comparison when it is known that the pointer // will always be null. template struct ErrorLess { using Weight = W; ErrorLess() { FSTERROR() << "ErrorLess: instantiated for Weight " << Weight::Type(); } bool operator()(const Weight &, const Weight &) const { return false; } }; } // namespace internal // Shortest-first queue discipline, templated on the StateId and Weight, is // specialized to use the weight's natural order for the comparison function. // Requires Weight is idempotent (due to use of NaturalLess). template class NaturalShortestFirstQueue : public ShortestFirstQueue< S, internal::StateWeightCompare>> { public: using StateId = S; using Less = NaturalLess; using Compare = internal::StateWeightCompare; explicit NaturalShortestFirstQueue(const std::vector &distance) : ShortestFirstQueue(Compare(distance, Less())) {} ~NaturalShortestFirstQueue() override = default; }; // In a shortest path computation on a lattice-like FST, we may keep many old // nonviable paths as a part of the search. Since the search process always // expands the lowest cost path next, that lowest cost path may be a very old // nonviable path instead of one we expect to lead to a shortest path. // // For instance, suppose that the current best path in an alignment has // traversed 500 arcs with a cost of 10. We may also have a bad path in // the queue that has traversed only 40 arcs but also has a cost of 10. // This path is very unlikely to lead to a reasonable alignment, so this queue // can prune it from the search space. // // This queue relies on the caller using a shortest-first exploration order // like this: // while (true) { // StateId head = queue.Head(); // queue.Dequeue(); // for (const auto& arc : GetArcs(fst, head)) { // queue.Enqueue(arc.nextstate); // } // } // We use this assumption to guess that there is an arc between Head and the // Enqueued state; this is how the number of path steps is measured. template class PruneNaturalShortestFirstQueue : public NaturalShortestFirstQueue { public: using StateId = S; using Base = NaturalShortestFirstQueue; PruneNaturalShortestFirstQueue(const std::vector &distance, ssize_t arc_threshold, ssize_t state_limit = 0) : Base(distance), arc_threshold_(arc_threshold), state_limit_(state_limit), head_steps_(0), max_head_steps_(0) {} ~PruneNaturalShortestFirstQueue() override = default; StateId Head() const override { const auto head = Base::Head(); // Stores the number of steps from the start of the graph to this state // along the shortest-weight path. if (head < steps_.size()) { max_head_steps_ = std::max(steps_[head], max_head_steps_); head_steps_ = steps_[head]; } return head; } void Enqueue(StateId s) override { // We assume that there is an arc between the Head() state and this // Enqueued state. const ssize_t state_steps = head_steps_ + 1; if (s >= steps_.size()) { steps_.resize(s + 1, state_steps); } // This is the number of arcs in the minimum cost path from Start to s. steps_[s] = state_steps; // Adjust the threshold in cases where path step thresholding wasn't // enough to keep the queue small. ssize_t adjusted_threshold = arc_threshold_; if (Base::Size() > state_limit_ && state_limit_ > 0) { adjusted_threshold = std::max( 0, arc_threshold_ - (Base::Size() / state_limit_) - 1); } if (state_steps > (max_head_steps_ - adjusted_threshold) || arc_threshold_ < 0) { if (adjusted_threshold == 0 && state_limit_ > 0) { // If the queue is continuing to grow without bound, we follow any // path that makes progress and clear the rest. Base::Clear(); } Base::Enqueue(s); } } private: // A dense map from StateId to the number of arcs in the minimum weight // path from Start to this state. std::vector steps_; // We only keep paths that are within this number of arcs (not weight!) // of the longest path. const ssize_t arc_threshold_; // If the size of the queue climbs above this number, we increase the // threshold to reduce the amount of work we have to do. const ssize_t state_limit_; // The following are mutable because Head() is const. // The number of arcs traversed in the minimum cost path from the start // state to the current Head() state. mutable ssize_t head_steps_; // The maximum number of arcs traversed by any low-cost path so far. mutable ssize_t max_head_steps_; }; // Topological-order queue discipline, templated on the StateId. States are // ordered in the queue topologically. The FST must be acyclic. template class TopOrderQueue : public QueueBase { public: using StateId = S; // This constructor computes the topological order. It accepts an arc filter // to limit the transitions considered in that computation (e.g., only the // epsilon graph). template TopOrderQueue(const Fst &fst, ArcFilter filter) : QueueBase(TOP_ORDER_QUEUE), front_(0), back_(kNoStateId), order_(0), state_(0) { bool acyclic; TopOrderVisitor top_order_visitor(&order_, &acyclic); DfsVisit(fst, &top_order_visitor, filter); if (!acyclic) { FSTERROR() << "TopOrderQueue: FST is not acyclic"; QueueBase::SetError(true); } state_.resize(order_.size(), kNoStateId); } // This constructor is passed the pre-computed topological order. explicit TopOrderQueue(const std::vector &order) : QueueBase(TOP_ORDER_QUEUE), front_(0), back_(kNoStateId), order_(order), state_(order.size(), kNoStateId) {} ~TopOrderQueue() override = default; StateId Head() const final { return state_[front_]; } void Enqueue(StateId s) final { if (front_ > back_) { front_ = back_ = order_[s]; } else if (order_[s] > back_) { back_ = order_[s]; } else if (order_[s] < front_) { front_ = order_[s]; } state_[order_[s]] = s; } void Dequeue() final { state_[front_] = kNoStateId; while ((front_ <= back_) && (state_[front_] == kNoStateId)) ++front_; } void Update(StateId) final {} bool Empty() const final { return front_ > back_; } void Clear() final { for (StateId s = front_; s <= back_; ++s) state_[s] = kNoStateId; back_ = kNoStateId; front_ = 0; } private: StateId front_; StateId back_; std::vector order_; std::vector state_; }; // State order queue discipline, templated on the StateId. States are ordered in // the queue by state ID. template class StateOrderQueue : public QueueBase { public: using StateId = S; StateOrderQueue() : QueueBase(STATE_ORDER_QUEUE), front_(0), back_(kNoStateId) {} ~StateOrderQueue() override = default; StateId Head() const final { return front_; } void Enqueue(StateId s) final { if (front_ > back_) { front_ = back_ = s; } else if (s > back_) { back_ = s; } else if (s < front_) { front_ = s; } while (enqueued_.size() <= s) enqueued_.push_back(false); enqueued_[s] = true; } void Dequeue() final { enqueued_[front_] = false; while ((front_ <= back_) && (enqueued_[front_] == false)) ++front_; } void Update(StateId) final {} bool Empty() const final { return front_ > back_; } void Clear() final { for (StateId i = front_; i <= back_; ++i) enqueued_[i] = false; front_ = 0; back_ = kNoStateId; } private: StateId front_; StateId back_; std::vector enqueued_; }; // SCC topological-order meta-queue discipline, templated on the StateId and a // queue used inside each SCC. It visits the SCCs of an FST in topological // order. Its constructor is passed the queues to to use within an SCC. template class SccQueue : public QueueBase { public: using StateId = S; // Constructor takes a vector specifying the SCC number per state and a // vector giving the queue to use per SCC number. SccQueue(const std::vector &scc, std::vector> *queue) : QueueBase(SCC_QUEUE), queue_(queue), scc_(scc), front_(0), back_(kNoStateId) {} ~SccQueue() override = default; StateId Head() const final { while ((front_ <= back_) && (((*queue_)[front_] && (*queue_)[front_]->Empty()) || (((*queue_)[front_] == nullptr) && ((front_ >= trivial_queue_.size()) || (trivial_queue_[front_] == kNoStateId))))) { ++front_; } if ((*queue_)[front_]) { return (*queue_)[front_]->Head(); } else { return trivial_queue_[front_]; } } void Enqueue(StateId s) final { if (front_ > back_) { front_ = back_ = scc_[s]; } else if (scc_[s] > back_) { back_ = scc_[s]; } else if (scc_[s] < front_) { front_ = scc_[s]; } if ((*queue_)[scc_[s]]) { (*queue_)[scc_[s]]->Enqueue(s); } else { while (trivial_queue_.size() <= scc_[s]) { trivial_queue_.push_back(kNoStateId); } trivial_queue_[scc_[s]] = s; } } void Dequeue() final { if ((*queue_)[front_]) { (*queue_)[front_]->Dequeue(); } else if (front_ < trivial_queue_.size()) { trivial_queue_[front_] = kNoStateId; } } void Update(StateId s) final { if ((*queue_)[scc_[s]]) (*queue_)[scc_[s]]->Update(s); } bool Empty() const final { // Queues SCC number back_ is not empty unless back_ == front_. if (front_ < back_) { return false; } else if (front_ > back_) { return true; } else if ((*queue_)[front_]) { return (*queue_)[front_]->Empty(); } else { return (front_ >= trivial_queue_.size()) || (trivial_queue_[front_] == kNoStateId); } } void Clear() final { for (StateId i = front_; i <= back_; ++i) { if ((*queue_)[i]) { (*queue_)[i]->Clear(); } else if (i < trivial_queue_.size()) { trivial_queue_[i] = kNoStateId; } } front_ = 0; back_ = kNoStateId; } private: std::vector> *queue_; const std::vector &scc_; mutable StateId front_; StateId back_; std::vector trivial_queue_; }; // Automatic queue discipline. It selects a queue discipline for a given FST // based on its properties. template class AutoQueue : public QueueBase { public: using StateId = S; // This constructor takes a state distance vector that, if non-null and if // the Weight type has the path property, will entertain the shortest-first // queue using the natural order w.r.t to the distance. template AutoQueue(const Fst &fst, const std::vector *distance, ArcFilter filter) : QueueBase(AUTO_QUEUE) { using Weight = typename Arc::Weight; // We need to have variables of type Less and Compare, so we use // ErrorLess if the type NaturalLess cannot be instantiated due // to lack of path property. using Less = std::conditional_t::value, NaturalLess, internal::ErrorLess>; using Compare = internal::StateWeightCompare; // First checks if the FST is known to have these properties. const auto props = fst.Properties(kAcyclic | kCyclic | kTopSorted | kUnweighted, false); if ((props & kTopSorted) || fst.Start() == kNoStateId) { queue_ = std::make_unique>(); VLOG(2) << "AutoQueue: using state-order discipline"; } else if (props & kAcyclic) { queue_ = std::make_unique>(fst, filter); VLOG(2) << "AutoQueue: using top-order discipline"; } else if ((props & kUnweighted) && IsIdempotent::value) { queue_ = std::make_unique>(); VLOG(2) << "AutoQueue: using LIFO discipline"; } else { uint64_t properties; // Decomposes into strongly-connected components. SccVisitor scc_visitor(&scc_, nullptr, nullptr, &properties); DfsVisit(fst, &scc_visitor, filter); auto nscc = *std::max_element(scc_.begin(), scc_.end()) + 1; std::vector queue_types(nscc); std::unique_ptr less; std::unique_ptr comp; if constexpr (IsPath::value) { if (distance) { less = std::make_unique(); comp = std::make_unique(*distance, *less); } } // Finds the queue type to use per SCC. bool unweighted; bool all_trivial; SccQueueType(fst, scc_, &queue_types, filter, less.get(), &all_trivial, &unweighted); // If unweighted and semiring is idempotent, uses LIFO queue. if (unweighted) { queue_ = std::make_unique>(); VLOG(2) << "AutoQueue: using LIFO discipline"; return; } // If all the SCC are trivial, the FST is acyclic and the scc number gives // the topological order. if (all_trivial) { queue_ = std::make_unique>(scc_); VLOG(2) << "AutoQueue: using top-order discipline"; return; } VLOG(2) << "AutoQueue: using SCC meta-discipline"; queues_.resize(nscc); for (StateId i = 0; i < nscc; ++i) { switch (queue_types[i]) { case TRIVIAL_QUEUE: queues_[i].reset(); VLOG(3) << "AutoQueue: SCC #" << i << ": using trivial discipline"; break; case SHORTEST_FIRST_QUEUE: // The IsPath test is not needed for correctness. It just saves // instantiating a ShortestFirstQueue that can never be called. if constexpr (IsPath::value) { queues_[i] = std::make_unique>( *comp); VLOG(3) << "AutoQueue: SCC #" << i << ": using shortest-first discipline"; } else { // SccQueueType should ensure this can never happen. FSTERROR() << "Got SHORTEST_FIRST_QUEUE for non-Path Weight " << Weight::Type(); queues_[i].reset(); } break; case LIFO_QUEUE: queues_[i] = std::make_unique>(); VLOG(3) << "AutoQueue: SCC #" << i << ": using LIFO discipline"; break; case FIFO_QUEUE: default: queues_[i] = std::make_unique>(); VLOG(3) << "AutoQueue: SCC #" << i << ": using FIFO discipine"; break; } } queue_ = std::make_unique>>( scc_, &queues_); } } ~AutoQueue() override = default; StateId Head() const final { return queue_->Head(); } void Enqueue(StateId s) final { queue_->Enqueue(s); } void Dequeue() final { queue_->Dequeue(); } void Update(StateId s) final { queue_->Update(s); } bool Empty() const final { return queue_->Empty(); } void Clear() final { queue_->Clear(); } private: template static void SccQueueType(const Fst &fst, const std::vector &scc, std::vector *queue_types, ArcFilter filter, Less *less, bool *all_trivial, bool *unweighted); std::unique_ptr> queue_; std::vector>> queues_; std::vector scc_; }; // Examines the states in an FST's strongly connected components and determines // which type of queue to use per SCC. Stores result as a vector of QueueTypes // which is assumed to have length equal to the number of SCCs. An arc filter // is used to limit the transitions considered (e.g., only the epsilon graph). // The argument all_trivial is set to true if every queue is the trivial queue. // The argument unweighted is set to true if the semiring is idempotent and all // the arc weights are equal to Zero() or One(). template template void AutoQueue::SccQueueType(const Fst &fst, const std::vector &scc, std::vector *queue_type, ArcFilter filter, Less *less, bool *all_trivial, bool *unweighted) { using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; *all_trivial = true; *unweighted = true; for (StateId i = 0; i < queue_type->size(); ++i) { (*queue_type)[i] = TRIVIAL_QUEUE; } for (StateIterator> sit(fst); !sit.Done(); sit.Next()) { const auto state = sit.Value(); for (ArcIterator> ait(fst, state); !ait.Done(); ait.Next()) { const auto &arc = ait.Value(); if (!filter(arc)) continue; if (scc[state] == scc[arc.nextstate]) { auto &type = (*queue_type)[scc[state]]; if constexpr (!IsPath::value) { type = FIFO_QUEUE; } else if (!less || (*less)(arc.weight, Weight::One())) { type = FIFO_QUEUE; } else if ((type == TRIVIAL_QUEUE) || (type == LIFO_QUEUE)) { if (!IsIdempotent::value || (arc.weight != Weight::Zero() && arc.weight != Weight::One())) { type = SHORTEST_FIRST_QUEUE; } else { type = LIFO_QUEUE; } } if (type != TRIVIAL_QUEUE) *all_trivial = false; } if (!IsIdempotent::value || (arc.weight != Weight::Zero() && arc.weight != Weight::One())) { *unweighted = false; } } } } // An A* estimate is a function object that maps from a state ID to an // estimate of the shortest distance to the final states. // A trivial A* estimate, yielding a queue which behaves the same in Dijkstra's // algorithm. template struct TrivialAStarEstimate { constexpr Weight operator()(StateId) const { return Weight::One(); } }; // A non-trivial A* estimate using a vector of the estimated future costs. template class NaturalAStarEstimate { public: NaturalAStarEstimate(const std::vector &beta) : beta_(beta) {} const Weight &operator()(StateId s) const { return (s < beta_.size()) ? beta_[s] : kZero; } private: static constexpr Weight kZero = Weight::Zero(); const std::vector &beta_; }; // Given a vector that maps from states to weights representing the shortest // distance from the initial state, a comparison function object between // weights, and an estimate of the shortest distance to the final states, this // class defines a comparison function object between states. template class AStarWeightCompare { public: using StateId = S; using Weight = typename Less::Weight; AStarWeightCompare(const std::vector &weights, const Less &less, const Estimate &estimate) : weights_(weights), less_(less), estimate_(estimate) {} bool operator()(StateId s1, StateId s2) const { const auto w1 = Times(weights_[s1], estimate_(s1)); const auto w2 = Times(weights_[s2], estimate_(s2)); return less_(w1, w2); } const Estimate &GetEstimate() const { return estimate_; } private: const std::vector &weights_; const Less &less_; const Estimate &estimate_; }; // A* queue discipline templated on StateId, Weight, and Estimate. template class NaturalAStarQueue : public ShortestFirstQueue< S, AStarWeightCompare, Estimate>> { public: using StateId = S; using Compare = AStarWeightCompare, Estimate>; NaturalAStarQueue(const std::vector &distance, const Estimate &estimate) : ShortestFirstQueue( Compare(distance, less_, estimate)) {} ~NaturalAStarQueue() override = default; private: // This is non-static because the constructor for non-idempotent weights will // result in an error. const NaturalLess less_{}; }; // A state equivalence class is a function object that maps from a state ID to // an equivalence class (state) ID. The trivial equivalence class maps a state // ID to itself. template struct TrivialStateEquivClass { StateId operator()(StateId s) const { return s; } }; // Distance-based pruning queue discipline: Enqueues a state only when its // shortest distance (so far), as specified by distance, is less than (as // specified by comp) the shortest distance Times() the threshold to any state // in the same equivalence class, as specified by the functor class_func. The // underlying queue discipline is specified by queue. // // This is not a final class. template class PruneQueue : public QueueBase { public: using StateId = typename Queue::StateId; using Weight = typename Less::Weight; PruneQueue(const std::vector &distance, std::unique_ptr queue, const Less &less, const ClassFnc &class_fnc, Weight threshold) : QueueBase(OTHER_QUEUE), distance_(distance), queue_(std::move(queue)), less_(less), class_fnc_(class_fnc), threshold_(std::move(threshold)) {} ~PruneQueue() override = default; StateId Head() const override { return queue_->Head(); } void Enqueue(StateId s) override { const auto c = class_fnc_(s); if (c >= class_distance_.size()) { class_distance_.resize(c + 1, Weight::Zero()); } if (less_(distance_[s], class_distance_[c])) { class_distance_[c] = distance_[s]; } // Enqueues only if below threshold limit. const auto limit = Times(class_distance_[c], threshold_); if (less_(distance_[s], limit)) queue_->Enqueue(s); } void Dequeue() override { queue_->Dequeue(); } void Update(StateId s) override { const auto c = class_fnc_(s); if (less_(distance_[s], class_distance_[c])) { class_distance_[c] = distance_[s]; } queue_->Update(s); } bool Empty() const override { return queue_->Empty(); } void Clear() override { queue_->Clear(); } private: const std::vector &distance_; // Shortest distance to state. std::unique_ptr queue_; const Less &less_; // Borrowed reference. const ClassFnc &class_fnc_; // Equivalence class functor. Weight threshold_; // Pruning weight threshold. std::vector class_distance_; // Shortest distance to class. }; // Pruning queue discipline (see above) using the weight's natural order for the // comparison function. The ownership of the queue argument is given to this // class. template class NaturalPruneQueue final : public PruneQueue, ClassFnc> { public: using StateId = typename Queue::StateId; NaturalPruneQueue(const std::vector &distance, std::unique_ptr queue, const ClassFnc &class_fnc, Weight threshold) : PruneQueue, ClassFnc>( distance, std::move(queue), NaturalLess(), class_fnc, threshold) {} ~NaturalPruneQueue() override = default; }; // Filter-based pruning queue discipline: enqueues a state only if allowed by // the filter, specified by the state filter functor argument. The underlying // queue discipline is specified by the queue argument. template class FilterQueue : public QueueBase { public: using StateId = typename Queue::StateId; FilterQueue(std::unique_ptr queue, const Filter &filter) : QueueBase(OTHER_QUEUE), queue_(std::move(queue)), filter_(filter) {} ~FilterQueue() override = default; StateId Head() const final { return queue_->Head(); } // Enqueues only if allowed by state filter. void Enqueue(StateId s) final { if (filter_(s)) queue_->Enqueue(s); } void Dequeue() final { queue_->Dequeue(); } void Update(StateId s) final {} bool Empty() const final { return queue_->Empty(); } void Clear() final { queue_->Clear(); } private: std::unique_ptr queue_; const Filter &filter_; }; } // namespace fst #endif // FST_QUEUE_H_