// Copyright 2005-2024 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the 'License');
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an 'AS IS' BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
// See www.openfst.org for extensive documentation on this weighted
|
|
// finite-state transducer library.
|
|
//
|
|
// Functions and classes for various FST state queues with a unified interface.
|
|
|
|
#ifndef FST_QUEUE_H_
|
|
#define FST_QUEUE_H_
|
|
|
|
#include <sys/types.h>
|
|
|
|
#include <algorithm>
|
|
#include <cstdint>
|
|
#include <memory>
|
|
#include <queue>
|
|
#include <type_traits>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include <fst/log.h>
|
|
#include <fst/arcfilter.h>
|
|
#include <fst/cc-visitors.h>
|
|
#include <fst/dfs-visit.h>
|
|
#include <fst/fst.h>
|
|
#include <fst/heap.h>
|
|
#include <fst/properties.h>
|
|
#include <fst/topsort.h>
|
|
#include <fst/util.h>
|
|
#include <fst/weight.h>
|
|
|
|
namespace fst {
|
|
|
|
// The Queue interface is:
|
|
//
|
|
// template <class S>
|
|
// class Queue {
|
|
// public:
|
|
// using StateId = S;
|
|
//
|
|
// // Constructor: may need args (e.g., FST, comparator) for some queues.
|
|
// Queue(...);
|
|
//
|
|
// // Returns the head of the queue.
|
|
// StateId Head() const override;
|
|
//
|
|
// // Inserts a state.
|
|
// void Enqueue(StateId s) override;
|
|
//
|
|
// // Removes the head of the queue.
|
|
// void Dequeue() override;
|
|
//
|
|
// // Updates ordering of state s when weight changes, if necessary.
|
|
// void Update(StateId s) override;
|
|
//
|
|
// // Is the queue empty?
|
|
// bool Empty() const override;
|
|
//
|
|
// // Removes all states from the queue.
|
|
// void Clear() override;
|
|
// };
|
|
|
|
// Identifies the discipline implemented by a state queue. The numeric values
// are spelled out explicitly.
enum QueueType {
  // Single state queue.
  TRIVIAL_QUEUE = 0,
  // First-in, first-out queue.
  FIFO_QUEUE = 1,
  // Last-in, first-out queue.
  LIFO_QUEUE = 2,
  // Shortest-first queue.
  SHORTEST_FIRST_QUEUE = 3,
  // Topologically-ordered queue.
  TOP_ORDER_QUEUE = 4,
  // State ID-ordered queue.
  STATE_ORDER_QUEUE = 5,
  // Component graph top-ordered meta-queue.
  SCC_QUEUE = 6,
  // Auto-selected queue.
  AUTO_QUEUE = 7,
  // Any queue that is none of the above.
  OTHER_QUEUE = 8
};
|
|
|
|
// QueueBase, templated on the StateId, is a virtual base class shared by all
|
|
// queues considered by AutoQueue.
|
|
template <class S>
|
|
class QueueBase {
|
|
public:
|
|
using StateId = S;
|
|
|
|
virtual ~QueueBase() = default;
|
|
|
|
// Concrete implementation.
|
|
|
|
explicit QueueBase(QueueType type) : queue_type_(type), error_(false) {}
|
|
|
|
void SetError(bool error) { error_ = error; }
|
|
|
|
bool Error() const { return error_; }
|
|
|
|
QueueType Type() const { return queue_type_; }
|
|
|
|
// Virtual interface.
|
|
|
|
virtual StateId Head() const = 0;
|
|
virtual void Enqueue(StateId) = 0;
|
|
virtual void Dequeue() = 0;
|
|
virtual void Update(StateId) = 0;
|
|
virtual bool Empty() const = 0;
|
|
virtual void Clear() = 0;
|
|
|
|
private:
|
|
QueueType queue_type_;
|
|
bool error_;
|
|
};
|
|
|
|
// Trivial queue discipline; one may enqueue at most one state at a time. It
|
|
// can be used for strongly connected components with only one state and no
|
|
// self-loops.
|
|
template <class S>
|
|
class TrivialQueue : public QueueBase<S> {
|
|
public:
|
|
using StateId = S;
|
|
|
|
TrivialQueue() : QueueBase<StateId>(TRIVIAL_QUEUE), front_(kNoStateId) {}
|
|
|
|
~TrivialQueue() override = default;
|
|
|
|
StateId Head() const final { return front_; }
|
|
|
|
void Enqueue(StateId s) final { front_ = s; }
|
|
|
|
void Dequeue() final { front_ = kNoStateId; }
|
|
|
|
void Update(StateId) final {}
|
|
|
|
bool Empty() const final { return front_ == kNoStateId; }
|
|
|
|
void Clear() final { front_ = kNoStateId; }
|
|
|
|
private:
|
|
StateId front_;
|
|
};
|
|
|
|
// First-in, first-out queue discipline.
|
|
//
|
|
// This is not a final class.
|
|
template <class S>
|
|
class FifoQueue : public QueueBase<S> {
|
|
public:
|
|
using StateId = S;
|
|
|
|
FifoQueue() : QueueBase<StateId>(FIFO_QUEUE) {}
|
|
|
|
~FifoQueue() override = default;
|
|
|
|
StateId Head() const override { return queue_.front(); }
|
|
|
|
void Enqueue(StateId s) override { queue_.push(s); }
|
|
|
|
void Dequeue() override { queue_.pop(); }
|
|
|
|
void Update(StateId) override {}
|
|
|
|
bool Empty() const override { return queue_.empty(); }
|
|
|
|
void Clear() override { queue_ = std::queue<StateId>(); }
|
|
|
|
private:
|
|
std::queue<StateId> queue_;
|
|
};
|
|
|
|
// Last-in, first-out queue discipline.
|
|
template <class S>
|
|
class LifoQueue : public QueueBase<S> {
|
|
public:
|
|
using StateId = S;
|
|
|
|
LifoQueue() : QueueBase<StateId>(LIFO_QUEUE) {}
|
|
|
|
~LifoQueue() override = default;
|
|
|
|
StateId Head() const final { return stack_.back(); }
|
|
|
|
void Enqueue(StateId s) final { stack_.push_back(s); }
|
|
|
|
void Dequeue() final { stack_.pop_back(); }
|
|
|
|
void Update(StateId) final {}
|
|
|
|
bool Empty() const final { return stack_.empty(); }
|
|
|
|
void Clear() final { stack_.clear(); }
|
|
|
|
private:
|
|
std::vector<StateId> stack_;
|
|
};
|
|
|
|
// Shortest-first queue discipline, templated on the StateId and as well as a
|
|
// comparison functor used to compare two StateIds. If a (single) state's order
|
|
// changes, it can be reordered in the queue with a call to Update(). If update
|
|
// is false, call to Update() does not reorder the queue.
|
|
//
|
|
// This is not a final class.
|
|
template <typename S, typename Compare, bool update = true>
|
|
class ShortestFirstQueue : public QueueBase<S> {
|
|
public:
|
|
using StateId = S;
|
|
|
|
explicit ShortestFirstQueue(Compare comp)
|
|
: QueueBase<StateId>(SHORTEST_FIRST_QUEUE), heap_(comp) {}
|
|
|
|
~ShortestFirstQueue() override = default;
|
|
|
|
StateId Head() const override { return heap_.Top(); }
|
|
|
|
void Enqueue(StateId s) override {
|
|
if (update) {
|
|
for (StateId i = key_.size(); i <= s; ++i) key_.push_back(kNoStateId);
|
|
key_[s] = heap_.Insert(s);
|
|
} else {
|
|
heap_.Insert(s);
|
|
}
|
|
}
|
|
|
|
void Dequeue() override {
|
|
if (update) {
|
|
key_[heap_.Pop()] = kNoStateId;
|
|
} else {
|
|
heap_.Pop();
|
|
}
|
|
}
|
|
|
|
void Update(StateId s) override {
|
|
if (!update) return;
|
|
if (s >= key_.size() || key_[s] == kNoStateId) {
|
|
Enqueue(s);
|
|
} else {
|
|
heap_.Update(key_[s], s);
|
|
}
|
|
}
|
|
|
|
bool Empty() const override { return heap_.Empty(); }
|
|
|
|
void Clear() override {
|
|
heap_.Clear();
|
|
if (update) key_.clear();
|
|
}
|
|
|
|
ssize_t Size() const { return heap_.Size(); }
|
|
|
|
const Compare &GetCompare() const { return heap_.GetCompare(); }
|
|
|
|
private:
|
|
Heap<StateId, Compare> heap_;
|
|
std::vector<ssize_t> key_;
|
|
};
|
|
|
|
namespace internal {
|
|
|
|
// Given a vector that maps from states to weights, and a comparison functor
// for weights, this class defines a comparison function object between states:
// s1 precedes s2 iff less(weights[s1], weights[s2]).
//
// The weight vector is borrowed and must outlive this object. The comparison
// functor, in contrast, is copied: callers hand in temporaries and objects
// with a shorter lifetime than the comparator, so keeping only a reference to
// it would leave the comparator with a dangling reference.
template <typename StateId, typename Less>
class StateWeightCompare {
 public:
  using Weight = typename Less::Weight;

  // weights: borrowed state-to-weight map, indexed by state ID.
  // less: weight comparison functor; a copy is stored.
  StateWeightCompare(const std::vector<Weight> &weights, const Less &less)
      : weights_(weights), less_(less) {}

  // Returns true iff the weight of s1 compares less than the weight of s2.
  // Both IDs must be valid indices into the weight vector.
  bool operator()(const StateId s1, const StateId s2) const {
    return less_(weights_[s1], weights_[s2]);
  }

 private:
  // Borrowed reference.
  const std::vector<Weight> &weights_;
  // Owned copy; const keeps the class non-assignable.
  const Less less_;
};
|
|
|
|
// Comparison that can never be instantiated. Useful only to pass a pointer to
|
|
// this to a function that needs a comparison when it is known that the pointer
|
|
// will always be null.
|
|
template <class W>
|
|
struct ErrorLess {
|
|
using Weight = W;
|
|
ErrorLess() {
|
|
FSTERROR() << "ErrorLess: instantiated for Weight " << Weight::Type();
|
|
}
|
|
bool operator()(const Weight &, const Weight &) const { return false; }
|
|
};
|
|
|
|
} // namespace internal
|
|
|
|
// Shortest-first queue discipline, templated on the StateId and Weight, is
|
|
// specialized to use the weight's natural order for the comparison function.
|
|
// Requires Weight is idempotent (due to use of NaturalLess).
|
|
template <typename S, typename Weight>
|
|
class NaturalShortestFirstQueue
|
|
: public ShortestFirstQueue<
|
|
S, internal::StateWeightCompare<S, NaturalLess<Weight>>> {
|
|
public:
|
|
using StateId = S;
|
|
using Less = NaturalLess<Weight>;
|
|
using Compare = internal::StateWeightCompare<StateId, Less>;
|
|
|
|
explicit NaturalShortestFirstQueue(const std::vector<Weight> &distance)
|
|
: ShortestFirstQueue<StateId, Compare>(Compare(distance, Less())) {}
|
|
|
|
~NaturalShortestFirstQueue() override = default;
|
|
};
|
|
|
|
// In a shortest path computation on a lattice-like FST, we may keep many old
|
|
// nonviable paths as a part of the search. Since the search process always
|
|
// expands the lowest cost path next, that lowest cost path may be a very old
|
|
// nonviable path instead of one we expect to lead to a shortest path.
|
|
//
|
|
// For instance, suppose that the current best path in an alignment has
|
|
// traversed 500 arcs with a cost of 10. We may also have a bad path in
|
|
// the queue that has traversed only 40 arcs but also has a cost of 10.
|
|
// This path is very unlikely to lead to a reasonable alignment, so this queue
|
|
// can prune it from the search space.
|
|
//
|
|
// This queue relies on the caller using a shortest-first exploration order
|
|
// like this:
|
|
// while (true) {
|
|
// StateId head = queue.Head();
|
|
// queue.Dequeue();
|
|
// for (const auto& arc : GetArcs(fst, head)) {
|
|
// queue.Enqueue(arc.nextstate);
|
|
// }
|
|
// }
|
|
// We use this assumption to guess that there is an arc between Head and the
|
|
// Enqueued state; this is how the number of path steps is measured.
|
|
template <typename S, typename Weight>
|
|
class PruneNaturalShortestFirstQueue
|
|
: public NaturalShortestFirstQueue<S, Weight> {
|
|
public:
|
|
using StateId = S;
|
|
using Base = NaturalShortestFirstQueue<StateId, Weight>;
|
|
|
|
PruneNaturalShortestFirstQueue(const std::vector<Weight> &distance,
|
|
ssize_t arc_threshold, ssize_t state_limit = 0)
|
|
: Base(distance),
|
|
arc_threshold_(arc_threshold),
|
|
state_limit_(state_limit),
|
|
head_steps_(0),
|
|
max_head_steps_(0) {}
|
|
|
|
~PruneNaturalShortestFirstQueue() override = default;
|
|
|
|
StateId Head() const override {
|
|
const auto head = Base::Head();
|
|
// Stores the number of steps from the start of the graph to this state
|
|
// along the shortest-weight path.
|
|
if (head < steps_.size()) {
|
|
max_head_steps_ = std::max(steps_[head], max_head_steps_);
|
|
head_steps_ = steps_[head];
|
|
}
|
|
return head;
|
|
}
|
|
|
|
void Enqueue(StateId s) override {
|
|
// We assume that there is an arc between the Head() state and this
|
|
// Enqueued state.
|
|
const ssize_t state_steps = head_steps_ + 1;
|
|
if (s >= steps_.size()) {
|
|
steps_.resize(s + 1, state_steps);
|
|
}
|
|
// This is the number of arcs in the minimum cost path from Start to s.
|
|
steps_[s] = state_steps;
|
|
|
|
// Adjust the threshold in cases where path step thresholding wasn't
|
|
// enough to keep the queue small.
|
|
ssize_t adjusted_threshold = arc_threshold_;
|
|
if (Base::Size() > state_limit_ && state_limit_ > 0) {
|
|
adjusted_threshold = std::max<ssize_t>(
|
|
0, arc_threshold_ - (Base::Size() / state_limit_) - 1);
|
|
}
|
|
|
|
if (state_steps > (max_head_steps_ - adjusted_threshold) ||
|
|
arc_threshold_ < 0) {
|
|
if (adjusted_threshold == 0 && state_limit_ > 0) {
|
|
// If the queue is continuing to grow without bound, we follow any
|
|
// path that makes progress and clear the rest.
|
|
Base::Clear();
|
|
}
|
|
Base::Enqueue(s);
|
|
}
|
|
}
|
|
|
|
private:
|
|
// A dense map from StateId to the number of arcs in the minimum weight
|
|
// path from Start to this state.
|
|
std::vector<ssize_t> steps_;
|
|
// We only keep paths that are within this number of arcs (not weight!)
|
|
// of the longest path.
|
|
const ssize_t arc_threshold_;
|
|
// If the size of the queue climbs above this number, we increase the
|
|
// threshold to reduce the amount of work we have to do.
|
|
const ssize_t state_limit_;
|
|
|
|
// The following are mutable because Head() is const.
|
|
// The number of arcs traversed in the minimum cost path from the start
|
|
// state to the current Head() state.
|
|
mutable ssize_t head_steps_;
|
|
// The maximum number of arcs traversed by any low-cost path so far.
|
|
mutable ssize_t max_head_steps_;
|
|
};
|
|
|
|
// Topological-order queue discipline, templated on the StateId. States are
|
|
// ordered in the queue topologically. The FST must be acyclic.
|
|
template <class S>
|
|
class TopOrderQueue : public QueueBase<S> {
|
|
public:
|
|
using StateId = S;
|
|
|
|
// This constructor computes the topological order. It accepts an arc filter
|
|
// to limit the transitions considered in that computation (e.g., only the
|
|
// epsilon graph).
|
|
template <class Arc, class ArcFilter>
|
|
TopOrderQueue(const Fst<Arc> &fst, ArcFilter filter)
|
|
: QueueBase<StateId>(TOP_ORDER_QUEUE),
|
|
front_(0),
|
|
back_(kNoStateId),
|
|
order_(0),
|
|
state_(0) {
|
|
bool acyclic;
|
|
TopOrderVisitor<Arc> top_order_visitor(&order_, &acyclic);
|
|
DfsVisit(fst, &top_order_visitor, filter);
|
|
if (!acyclic) {
|
|
FSTERROR() << "TopOrderQueue: FST is not acyclic";
|
|
QueueBase<S>::SetError(true);
|
|
}
|
|
state_.resize(order_.size(), kNoStateId);
|
|
}
|
|
|
|
// This constructor is passed the pre-computed topological order.
|
|
explicit TopOrderQueue(const std::vector<StateId> &order)
|
|
: QueueBase<StateId>(TOP_ORDER_QUEUE),
|
|
front_(0),
|
|
back_(kNoStateId),
|
|
order_(order),
|
|
state_(order.size(), kNoStateId) {}
|
|
|
|
~TopOrderQueue() override = default;
|
|
|
|
StateId Head() const final { return state_[front_]; }
|
|
|
|
void Enqueue(StateId s) final {
|
|
if (front_ > back_) {
|
|
front_ = back_ = order_[s];
|
|
} else if (order_[s] > back_) {
|
|
back_ = order_[s];
|
|
} else if (order_[s] < front_) {
|
|
front_ = order_[s];
|
|
}
|
|
state_[order_[s]] = s;
|
|
}
|
|
|
|
void Dequeue() final {
|
|
state_[front_] = kNoStateId;
|
|
while ((front_ <= back_) && (state_[front_] == kNoStateId)) ++front_;
|
|
}
|
|
|
|
void Update(StateId) final {}
|
|
|
|
bool Empty() const final { return front_ > back_; }
|
|
|
|
void Clear() final {
|
|
for (StateId s = front_; s <= back_; ++s) state_[s] = kNoStateId;
|
|
back_ = kNoStateId;
|
|
front_ = 0;
|
|
}
|
|
|
|
private:
|
|
StateId front_;
|
|
StateId back_;
|
|
std::vector<StateId> order_;
|
|
std::vector<StateId> state_;
|
|
};
|
|
|
|
// State order queue discipline, templated on the StateId. States are ordered in
|
|
// the queue by state ID.
|
|
template <class S>
|
|
class StateOrderQueue : public QueueBase<S> {
|
|
public:
|
|
using StateId = S;
|
|
|
|
StateOrderQueue()
|
|
: QueueBase<StateId>(STATE_ORDER_QUEUE), front_(0), back_(kNoStateId) {}
|
|
|
|
~StateOrderQueue() override = default;
|
|
|
|
StateId Head() const final { return front_; }
|
|
|
|
void Enqueue(StateId s) final {
|
|
if (front_ > back_) {
|
|
front_ = back_ = s;
|
|
} else if (s > back_) {
|
|
back_ = s;
|
|
} else if (s < front_) {
|
|
front_ = s;
|
|
}
|
|
while (enqueued_.size() <= s) enqueued_.push_back(false);
|
|
enqueued_[s] = true;
|
|
}
|
|
|
|
void Dequeue() final {
|
|
enqueued_[front_] = false;
|
|
while ((front_ <= back_) && (enqueued_[front_] == false)) ++front_;
|
|
}
|
|
|
|
void Update(StateId) final {}
|
|
|
|
bool Empty() const final { return front_ > back_; }
|
|
|
|
void Clear() final {
|
|
for (StateId i = front_; i <= back_; ++i) enqueued_[i] = false;
|
|
front_ = 0;
|
|
back_ = kNoStateId;
|
|
}
|
|
|
|
private:
|
|
StateId front_;
|
|
StateId back_;
|
|
std::vector<bool> enqueued_;
|
|
};
|
|
|
|
// SCC topological-order meta-queue discipline, templated on the StateId and a
|
|
// queue used inside each SCC. It visits the SCCs of an FST in topological
|
|
// order. Its constructor is passed the queues to to use within an SCC.
|
|
template <class S, class Queue>
|
|
class SccQueue : public QueueBase<S> {
|
|
public:
|
|
using StateId = S;
|
|
|
|
// Constructor takes a vector specifying the SCC number per state and a
|
|
// vector giving the queue to use per SCC number.
|
|
SccQueue(const std::vector<StateId> &scc,
|
|
std::vector<std::unique_ptr<Queue>> *queue)
|
|
: QueueBase<StateId>(SCC_QUEUE),
|
|
queue_(queue),
|
|
scc_(scc),
|
|
front_(0),
|
|
back_(kNoStateId) {}
|
|
|
|
~SccQueue() override = default;
|
|
|
|
StateId Head() const final {
|
|
while ((front_ <= back_) &&
|
|
(((*queue_)[front_] && (*queue_)[front_]->Empty()) ||
|
|
(((*queue_)[front_] == nullptr) &&
|
|
((front_ >= trivial_queue_.size()) ||
|
|
(trivial_queue_[front_] == kNoStateId))))) {
|
|
++front_;
|
|
}
|
|
if ((*queue_)[front_]) {
|
|
return (*queue_)[front_]->Head();
|
|
} else {
|
|
return trivial_queue_[front_];
|
|
}
|
|
}
|
|
|
|
void Enqueue(StateId s) final {
|
|
if (front_ > back_) {
|
|
front_ = back_ = scc_[s];
|
|
} else if (scc_[s] > back_) {
|
|
back_ = scc_[s];
|
|
} else if (scc_[s] < front_) {
|
|
front_ = scc_[s];
|
|
}
|
|
if ((*queue_)[scc_[s]]) {
|
|
(*queue_)[scc_[s]]->Enqueue(s);
|
|
} else {
|
|
while (trivial_queue_.size() <= scc_[s]) {
|
|
trivial_queue_.push_back(kNoStateId);
|
|
}
|
|
trivial_queue_[scc_[s]] = s;
|
|
}
|
|
}
|
|
|
|
void Dequeue() final {
|
|
if ((*queue_)[front_]) {
|
|
(*queue_)[front_]->Dequeue();
|
|
} else if (front_ < trivial_queue_.size()) {
|
|
trivial_queue_[front_] = kNoStateId;
|
|
}
|
|
}
|
|
|
|
void Update(StateId s) final {
|
|
if ((*queue_)[scc_[s]]) (*queue_)[scc_[s]]->Update(s);
|
|
}
|
|
|
|
bool Empty() const final {
|
|
// Queues SCC number back_ is not empty unless back_ == front_.
|
|
if (front_ < back_) {
|
|
return false;
|
|
} else if (front_ > back_) {
|
|
return true;
|
|
} else if ((*queue_)[front_]) {
|
|
return (*queue_)[front_]->Empty();
|
|
} else {
|
|
return (front_ >= trivial_queue_.size()) ||
|
|
(trivial_queue_[front_] == kNoStateId);
|
|
}
|
|
}
|
|
|
|
void Clear() final {
|
|
for (StateId i = front_; i <= back_; ++i) {
|
|
if ((*queue_)[i]) {
|
|
(*queue_)[i]->Clear();
|
|
} else if (i < trivial_queue_.size()) {
|
|
trivial_queue_[i] = kNoStateId;
|
|
}
|
|
}
|
|
front_ = 0;
|
|
back_ = kNoStateId;
|
|
}
|
|
|
|
private:
|
|
std::vector<std::unique_ptr<Queue>> *queue_;
|
|
const std::vector<StateId> &scc_;
|
|
mutable StateId front_;
|
|
StateId back_;
|
|
std::vector<StateId> trivial_queue_;
|
|
};
|
|
|
|
// Automatic queue discipline. It selects a queue discipline for a given FST
|
|
// based on its properties.
|
|
template <class S>
|
|
class AutoQueue : public QueueBase<S> {
|
|
public:
|
|
using StateId = S;
|
|
|
|
// This constructor takes a state distance vector that, if non-null and if
|
|
// the Weight type has the path property, will entertain the shortest-first
|
|
// queue using the natural order w.r.t to the distance.
|
|
template <class Arc, class ArcFilter>
|
|
AutoQueue(const Fst<Arc> &fst,
|
|
const std::vector<typename Arc::Weight> *distance, ArcFilter filter)
|
|
: QueueBase<StateId>(AUTO_QUEUE) {
|
|
using Weight = typename Arc::Weight;
|
|
// We need to have variables of type Less and Compare, so we use
|
|
// ErrorLess if the type NaturalLess<Weight> cannot be instantiated due
|
|
// to lack of path property.
|
|
using Less = std::conditional_t<IsPath<Weight>::value, NaturalLess<Weight>,
|
|
internal::ErrorLess<Weight>>;
|
|
using Compare = internal::StateWeightCompare<StateId, Less>;
|
|
// First checks if the FST is known to have these properties.
|
|
const auto props =
|
|
fst.Properties(kAcyclic | kCyclic | kTopSorted | kUnweighted, false);
|
|
if ((props & kTopSorted) || fst.Start() == kNoStateId) {
|
|
queue_ = std::make_unique<StateOrderQueue<StateId>>();
|
|
VLOG(2) << "AutoQueue: using state-order discipline";
|
|
} else if (props & kAcyclic) {
|
|
queue_ = std::make_unique<TopOrderQueue<StateId>>(fst, filter);
|
|
VLOG(2) << "AutoQueue: using top-order discipline";
|
|
} else if ((props & kUnweighted) && IsIdempotent<Weight>::value) {
|
|
queue_ = std::make_unique<LifoQueue<StateId>>();
|
|
VLOG(2) << "AutoQueue: using LIFO discipline";
|
|
} else {
|
|
uint64_t properties;
|
|
// Decomposes into strongly-connected components.
|
|
SccVisitor<Arc> scc_visitor(&scc_, nullptr, nullptr, &properties);
|
|
DfsVisit(fst, &scc_visitor, filter);
|
|
auto nscc = *std::max_element(scc_.begin(), scc_.end()) + 1;
|
|
std::vector<QueueType> queue_types(nscc);
|
|
std::unique_ptr<Less> less;
|
|
std::unique_ptr<Compare> comp;
|
|
if constexpr (IsPath<Weight>::value) {
|
|
if (distance) {
|
|
less = std::make_unique<Less>();
|
|
comp = std::make_unique<Compare>(*distance, *less);
|
|
}
|
|
}
|
|
// Finds the queue type to use per SCC.
|
|
bool unweighted;
|
|
bool all_trivial;
|
|
SccQueueType(fst, scc_, &queue_types, filter, less.get(), &all_trivial,
|
|
&unweighted);
|
|
// If unweighted and semiring is idempotent, uses LIFO queue.
|
|
if (unweighted) {
|
|
queue_ = std::make_unique<LifoQueue<StateId>>();
|
|
VLOG(2) << "AutoQueue: using LIFO discipline";
|
|
return;
|
|
}
|
|
// If all the SCC are trivial, the FST is acyclic and the scc number gives
|
|
// the topological order.
|
|
if (all_trivial) {
|
|
queue_ = std::make_unique<TopOrderQueue<StateId>>(scc_);
|
|
VLOG(2) << "AutoQueue: using top-order discipline";
|
|
return;
|
|
}
|
|
VLOG(2) << "AutoQueue: using SCC meta-discipline";
|
|
queues_.resize(nscc);
|
|
for (StateId i = 0; i < nscc; ++i) {
|
|
switch (queue_types[i]) {
|
|
case TRIVIAL_QUEUE:
|
|
queues_[i].reset();
|
|
VLOG(3) << "AutoQueue: SCC #" << i << ": using trivial discipline";
|
|
break;
|
|
case SHORTEST_FIRST_QUEUE:
|
|
// The IsPath test is not needed for correctness. It just saves
|
|
// instantiating a ShortestFirstQueue that can never be called.
|
|
if constexpr (IsPath<Weight>::value) {
|
|
queues_[i] =
|
|
std::make_unique<ShortestFirstQueue<StateId, Compare, false>>(
|
|
*comp);
|
|
VLOG(3) << "AutoQueue: SCC #" << i
|
|
<< ": using shortest-first discipline";
|
|
} else {
|
|
// SccQueueType should ensure this can never happen.
|
|
FSTERROR() << "Got SHORTEST_FIRST_QUEUE for non-Path Weight "
|
|
<< Weight::Type();
|
|
queues_[i].reset();
|
|
}
|
|
break;
|
|
case LIFO_QUEUE:
|
|
queues_[i] = std::make_unique<LifoQueue<StateId>>();
|
|
VLOG(3) << "AutoQueue: SCC #" << i << ": using LIFO discipline";
|
|
break;
|
|
case FIFO_QUEUE:
|
|
default:
|
|
queues_[i] = std::make_unique<FifoQueue<StateId>>();
|
|
VLOG(3) << "AutoQueue: SCC #" << i << ": using FIFO discipine";
|
|
break;
|
|
}
|
|
}
|
|
queue_ = std::make_unique<SccQueue<StateId, QueueBase<StateId>>>(
|
|
scc_, &queues_);
|
|
}
|
|
}
|
|
|
|
~AutoQueue() override = default;
|
|
|
|
StateId Head() const final { return queue_->Head(); }
|
|
|
|
void Enqueue(StateId s) final { queue_->Enqueue(s); }
|
|
|
|
void Dequeue() final { queue_->Dequeue(); }
|
|
|
|
void Update(StateId s) final { queue_->Update(s); }
|
|
|
|
bool Empty() const final { return queue_->Empty(); }
|
|
|
|
void Clear() final { queue_->Clear(); }
|
|
|
|
private:
|
|
template <class Arc, class ArcFilter, class Less>
|
|
static void SccQueueType(const Fst<Arc> &fst, const std::vector<StateId> &scc,
|
|
std::vector<QueueType> *queue_types,
|
|
ArcFilter filter, Less *less, bool *all_trivial,
|
|
bool *unweighted);
|
|
|
|
std::unique_ptr<QueueBase<StateId>> queue_;
|
|
std::vector<std::unique_ptr<QueueBase<StateId>>> queues_;
|
|
std::vector<StateId> scc_;
|
|
};
|
|
|
|
// Examines the states in an FST's strongly connected components and determines
|
|
// which type of queue to use per SCC. Stores result as a vector of QueueTypes
|
|
// which is assumed to have length equal to the number of SCCs. An arc filter
|
|
// is used to limit the transitions considered (e.g., only the epsilon graph).
|
|
// The argument all_trivial is set to true if every queue is the trivial queue.
|
|
// The argument unweighted is set to true if the semiring is idempotent and all
|
|
// the arc weights are equal to Zero() or One().
|
|
template <class StateId>
|
|
template <class Arc, class ArcFilter, class Less>
|
|
void AutoQueue<StateId>::SccQueueType(const Fst<Arc> &fst,
|
|
const std::vector<StateId> &scc,
|
|
std::vector<QueueType> *queue_type,
|
|
ArcFilter filter, Less *less,
|
|
bool *all_trivial, bool *unweighted) {
|
|
using StateId = typename Arc::StateId;
|
|
using Weight = typename Arc::Weight;
|
|
*all_trivial = true;
|
|
*unweighted = true;
|
|
for (StateId i = 0; i < queue_type->size(); ++i) {
|
|
(*queue_type)[i] = TRIVIAL_QUEUE;
|
|
}
|
|
for (StateIterator<Fst<Arc>> sit(fst); !sit.Done(); sit.Next()) {
|
|
const auto state = sit.Value();
|
|
for (ArcIterator<Fst<Arc>> ait(fst, state); !ait.Done(); ait.Next()) {
|
|
const auto &arc = ait.Value();
|
|
if (!filter(arc)) continue;
|
|
if (scc[state] == scc[arc.nextstate]) {
|
|
auto &type = (*queue_type)[scc[state]];
|
|
if constexpr (!IsPath<Weight>::value) {
|
|
type = FIFO_QUEUE;
|
|
} else if (!less || (*less)(arc.weight, Weight::One())) {
|
|
type = FIFO_QUEUE;
|
|
} else if ((type == TRIVIAL_QUEUE) || (type == LIFO_QUEUE)) {
|
|
if (!IsIdempotent<Weight>::value ||
|
|
(arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
|
|
type = SHORTEST_FIRST_QUEUE;
|
|
} else {
|
|
type = LIFO_QUEUE;
|
|
}
|
|
}
|
|
if (type != TRIVIAL_QUEUE) *all_trivial = false;
|
|
}
|
|
if (!IsIdempotent<Weight>::value ||
|
|
(arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
|
|
*unweighted = false;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// An A* estimate is a function object that maps from a state ID to an
|
|
// estimate of the shortest distance to the final states.
|
|
|
|
// A trivial A* estimate: every state is estimated at Weight::One(), yielding a
// queue which behaves the same as in Dijkstra's algorithm.
template <typename StateId, typename Weight>
struct TrivialAStarEstimate {
  // The state argument is ignored.
  constexpr Weight operator()(StateId /*state*/) const {
    return Weight::One();
  }
};
|
|
|
|
// A non-trivial A* estimate using a vector of the estimated future costs,
// indexed by state ID. The vector is borrowed and must outlive this object.
template <typename StateId, typename Weight>
class NaturalAStarEstimate {
 public:
  NaturalAStarEstimate(const std::vector<Weight> &beta) : beta_(beta) {}

  // Returns the estimate stored for s, or Weight::Zero() for any state that
  // has no entry in the vector.
  const Weight &operator()(StateId s) const {
    if (s < beta_.size()) {
      return beta_[s];
    }
    return kZero;
  }

 private:
  // Fallback value; a static member so that a reference to it can be returned.
  static constexpr Weight kZero = Weight::Zero();

  const std::vector<Weight> &beta_;
};
|
|
|
|
// Given a vector that maps from states to weights representing the shortest
// distance from the initial state, a comparison function object between
// weights, and an estimate of the shortest distance to the final states, this
// class defines a comparison function object between states: a state is
// ranked by Times(distance to it, estimate from it).
//
// All three constructor arguments are kept by reference and must outlive this
// object.
template <typename S, typename Less, typename Estimate>
class AStarWeightCompare {
 public:
  using StateId = S;
  using Weight = typename Less::Weight;

  AStarWeightCompare(const std::vector<Weight> &weights, const Less &less,
                     const Estimate &estimate)
      : weights_(weights), less_(less), estimate_(estimate) {}

  // Returns true iff the combined weight of s1 compares less than that of s2.
  bool operator()(StateId s1, StateId s2) const {
    const auto total1 = Times(weights_[s1], estimate_(s1));
    const auto total2 = Times(weights_[s2], estimate_(s2));
    return less_(total1, total2);
  }

  // The estimate functor this comparator was built with.
  const Estimate &GetEstimate() const { return estimate_; }

 private:
  // Borrowed references.
  const std::vector<Weight> &weights_;
  const Less &less_;
  const Estimate &estimate_;
};
|
|
|
|
// A* queue discipline templated on StateId, Weight, and Estimate.
|
|
template <typename S, typename Weight, typename Estimate>
|
|
class NaturalAStarQueue
|
|
: public ShortestFirstQueue<
|
|
S, AStarWeightCompare<S, NaturalLess<Weight>, Estimate>> {
|
|
public:
|
|
using StateId = S;
|
|
using Compare = AStarWeightCompare<StateId, NaturalLess<Weight>, Estimate>;
|
|
|
|
NaturalAStarQueue(const std::vector<Weight> &distance,
|
|
const Estimate &estimate)
|
|
: ShortestFirstQueue<StateId, Compare>(
|
|
Compare(distance, less_, estimate)) {}
|
|
|
|
~NaturalAStarQueue() override = default;
|
|
|
|
private:
|
|
// This is non-static because the constructor for non-idempotent weights will
|
|
// result in an error.
|
|
const NaturalLess<Weight> less_{};
|
|
};
|
|
|
|
// A state equivalence class is a function object that maps from a state ID to
// an equivalence class (state) ID. The trivial equivalence class puts every
// state in a class of its own by mapping a state ID to itself.
template <typename StateId>
struct TrivialStateEquivClass {
  // Identity mapping.
  StateId operator()(StateId s) const {
    return s;
  }
};
|
|
|
|
// Distance-based pruning queue discipline: Enqueues a state only when its
|
|
// shortest distance (so far), as specified by distance, is less than (as
|
|
// specified by comp) the shortest distance Times() the threshold to any state
|
|
// in the same equivalence class, as specified by the functor class_func. The
|
|
// underlying queue discipline is specified by queue.
|
|
//
|
|
// This is not a final class.
|
|
template <typename Queue, typename Less, typename ClassFnc>
|
|
class PruneQueue : public QueueBase<typename Queue::StateId> {
|
|
public:
|
|
using StateId = typename Queue::StateId;
|
|
using Weight = typename Less::Weight;
|
|
|
|
PruneQueue(const std::vector<Weight> &distance, std::unique_ptr<Queue> queue,
|
|
const Less &less, const ClassFnc &class_fnc, Weight threshold)
|
|
: QueueBase<StateId>(OTHER_QUEUE),
|
|
distance_(distance),
|
|
queue_(std::move(queue)),
|
|
less_(less),
|
|
class_fnc_(class_fnc),
|
|
threshold_(std::move(threshold)) {}
|
|
|
|
~PruneQueue() override = default;
|
|
|
|
StateId Head() const override { return queue_->Head(); }
|
|
|
|
void Enqueue(StateId s) override {
|
|
const auto c = class_fnc_(s);
|
|
if (c >= class_distance_.size()) {
|
|
class_distance_.resize(c + 1, Weight::Zero());
|
|
}
|
|
if (less_(distance_[s], class_distance_[c])) {
|
|
class_distance_[c] = distance_[s];
|
|
}
|
|
// Enqueues only if below threshold limit.
|
|
const auto limit = Times(class_distance_[c], threshold_);
|
|
if (less_(distance_[s], limit)) queue_->Enqueue(s);
|
|
}
|
|
|
|
void Dequeue() override { queue_->Dequeue(); }
|
|
|
|
void Update(StateId s) override {
|
|
const auto c = class_fnc_(s);
|
|
if (less_(distance_[s], class_distance_[c])) {
|
|
class_distance_[c] = distance_[s];
|
|
}
|
|
queue_->Update(s);
|
|
}
|
|
|
|
bool Empty() const override { return queue_->Empty(); }
|
|
|
|
void Clear() override { queue_->Clear(); }
|
|
|
|
private:
|
|
const std::vector<Weight> &distance_; // Shortest distance to state.
|
|
std::unique_ptr<Queue> queue_;
|
|
const Less &less_; // Borrowed reference.
|
|
const ClassFnc &class_fnc_; // Equivalence class functor.
|
|
Weight threshold_; // Pruning weight threshold.
|
|
std::vector<Weight> class_distance_; // Shortest distance to class.
|
|
};
|
|
|
|
// Pruning queue discipline (see above) using the weight's natural order for the
|
|
// comparison function. The ownership of the queue argument is given to this
|
|
// class.
|
|
template <typename Queue, typename Weight, typename ClassFnc>
|
|
class NaturalPruneQueue final
|
|
: public PruneQueue<Queue, NaturalLess<Weight>, ClassFnc> {
|
|
public:
|
|
using StateId = typename Queue::StateId;
|
|
|
|
NaturalPruneQueue(const std::vector<Weight> &distance,
|
|
std::unique_ptr<Queue> queue, const ClassFnc &class_fnc,
|
|
Weight threshold)
|
|
: PruneQueue<Queue, NaturalLess<Weight>, ClassFnc>(
|
|
distance, std::move(queue), NaturalLess<Weight>(), class_fnc,
|
|
threshold) {}
|
|
|
|
~NaturalPruneQueue() override = default;
|
|
};
|
|
|
|
// Filter-based pruning queue discipline: enqueues a state only if allowed by
|
|
// the filter, specified by the state filter functor argument. The underlying
|
|
// queue discipline is specified by the queue argument.
|
|
template <typename Queue, typename Filter>
|
|
class FilterQueue : public QueueBase<typename Queue::StateId> {
|
|
public:
|
|
using StateId = typename Queue::StateId;
|
|
|
|
FilterQueue(std::unique_ptr<Queue> queue, const Filter &filter)
|
|
: QueueBase<StateId>(OTHER_QUEUE),
|
|
queue_(std::move(queue)),
|
|
filter_(filter) {}
|
|
|
|
~FilterQueue() override = default;
|
|
|
|
StateId Head() const final { return queue_->Head(); }
|
|
|
|
// Enqueues only if allowed by state filter.
|
|
void Enqueue(StateId s) final {
|
|
if (filter_(s)) queue_->Enqueue(s);
|
|
}
|
|
|
|
void Dequeue() final { queue_->Dequeue(); }
|
|
|
|
void Update(StateId s) final {}
|
|
|
|
bool Empty() const final { return queue_->Empty(); }
|
|
|
|
void Clear() final { queue_->Clear(); }
|
|
|
|
private:
|
|
std::unique_ptr<Queue> queue_;
|
|
const Filter &filter_;
|
|
};
|
|
|
|
} // namespace fst
|
|
|
|
#endif // FST_QUEUE_H_
|