You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

1017 lines
32 KiB

// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Functions and classes for various FST state queues with a unified interface.
#ifndef FST_QUEUE_H_
#define FST_QUEUE_H_
#include <sys/types.h>
#include <algorithm>
#include <cstdint>
#include <memory>
#include <queue>
#include <type_traits>
#include <utility>
#include <vector>
#include <fst/log.h>
#include <fst/arcfilter.h>
#include <fst/cc-visitors.h>
#include <fst/dfs-visit.h>
#include <fst/fst.h>
#include <fst/heap.h>
#include <fst/properties.h>
#include <fst/topsort.h>
#include <fst/util.h>
#include <fst/weight.h>
namespace fst {
// The Queue interface is:
//
// template <class S>
// class Queue {
// public:
// using StateId = S;
//
// // Constructor: may need args (e.g., FST, comparator) for some queues.
// Queue(...);
//
// // Returns the head of the queue.
// StateId Head() const override;
//
// // Inserts a state.
// void Enqueue(StateId s) override;
//
// // Removes the head of the queue.
// void Dequeue() override;
//
// // Updates ordering of state s when weight changes, if necessary.
// void Update(StateId s) override;
//
// // Is the queue empty?
// bool Empty() const override;
//
// // Removes all states from the queue.
// void Clear() override;
// };
// Tags naming the discipline a queue implements. Each queue passes its tag to
// the QueueBase constructor and reports it through QueueBase::Type(). The
// numeric values are spelled out explicitly.
enum QueueType {
  TRIVIAL_QUEUE = 0,         // Queue holding a single state.
  FIFO_QUEUE = 1,            // First-in, first-out order.
  LIFO_QUEUE = 2,            // Last-in, first-out order.
  SHORTEST_FIRST_QUEUE = 3,  // Shortest-first order.
  TOP_ORDER_QUEUE = 4,       // Topological order.
  STATE_ORDER_QUEUE = 5,     // State ID order.
  SCC_QUEUE = 6,             // Meta-queue over the top-ordered component graph.
  AUTO_QUEUE = 7,            // Discipline selected automatically.
  OTHER_QUEUE = 8            // Anything else (used by PruneQueue, FilterQueue).
};
// Abstract base class, parameterized by the state ID type, from which all the
// queues considered by AutoQueue derive. It stores the discipline tag and an
// error flag, and declares the queue operations as pure virtual functions.
template <class S>
class QueueBase {
 public:
  using StateId = S;

  // Records the discipline tag; the error flag starts out cleared.
  explicit QueueBase(QueueType type) : queue_type_(type) {}

  virtual ~QueueBase() = default;

  // Non-virtual accessors shared by every queue.
  QueueType Type() const { return queue_type_; }

  bool Error() const { return error_; }

  void SetError(bool error) { error_ = error; }

  // Operations every concrete queue has to provide.

  // Returns the head of the queue.
  virtual StateId Head() const = 0;

  // Inserts a state.
  virtual void Enqueue(StateId) = 0;

  // Removes the head of the queue.
  virtual void Dequeue() = 0;

  // Updates the ordering of a state when its weight changes, if necessary.
  virtual void Update(StateId) = 0;

  // Is the queue empty?
  virtual bool Empty() const = 0;

  // Removes all states from the queue.
  virtual void Clear() = 0;

 private:
  QueueType queue_type_;
  bool error_ = false;
};
// Trivial queue discipline: the queue remembers only the most recently
// enqueued state, so at most one state may be enqueued at a time. It can be
// used for strongly connected components with only one state and no
// self-loops.
template <class S>
class TrivialQueue : public QueueBase<S> {
 public:
  using StateId = S;

  TrivialQueue() : QueueBase<StateId>(TRIVIAL_QUEUE) {}

  ~TrivialQueue() override = default;

  StateId Head() const final { return front_; }

  // Overwrites whatever state was held before.
  void Enqueue(StateId s) final { front_ = s; }

  void Dequeue() final { front_ = kNoStateId; }

  // Nothing to reorder with a single slot.
  void Update(StateId) final {}

  bool Empty() const final { return front_ == kNoStateId; }

  void Clear() final { front_ = kNoStateId; }

 private:
  // The single held state, or kNoStateId when the queue is empty.
  StateId front_ = kNoStateId;
};
// First-in, first-out queue discipline, backed by std::queue.
//
// This is not a final class.
template <class S>
class FifoQueue : public QueueBase<S> {
 public:
  using StateId = S;

  FifoQueue() : QueueBase<StateId>(FIFO_QUEUE) {}

  ~FifoQueue() override = default;

  // The oldest state still in the queue.
  StateId Head() const override { return queue_.front(); }

  void Enqueue(StateId s) override { queue_.push(s); }

  void Dequeue() override { queue_.pop(); }

  // FIFO order does not depend on weights, so there is nothing to do.
  void Update(StateId) override {}

  bool Empty() const override { return queue_.empty(); }

  // Replaces the contents with those of a freshly constructed, empty queue.
  void Clear() override { std::queue<StateId>().swap(queue_); }

 private:
  std::queue<StateId> queue_;
};
// Last-in, first-out queue discipline, backed by a std::vector used as a
// stack.
template <class S>
class LifoQueue : public QueueBase<S> {
 public:
  using StateId = S;

  LifoQueue() : QueueBase<StateId>(LIFO_QUEUE) {}

  ~LifoQueue() override = default;

  // The most recently enqueued state.
  StateId Head() const final { return states_.back(); }

  void Enqueue(StateId s) final { states_.push_back(s); }

  void Dequeue() final { states_.pop_back(); }

  // LIFO order does not depend on weights, so there is nothing to do.
  void Update(StateId) final {}

  bool Empty() const final { return states_.empty(); }

  void Clear() final { states_.clear(); }

 private:
  // Top of the stack is the back of the vector.
  std::vector<StateId> states_;
};
// Shortest-first queue discipline, parameterized by the state ID type and by a
// functor that compares two state IDs. When `update` is true the queue records,
// per state, the key returned by the heap on insertion, so that Update() can
// reposition a single state whose order has changed. When `update` is false no
// keys are tracked and Update() does nothing.
//
// This is not a final class.
template <typename S, typename Compare, bool update = true>
class ShortestFirstQueue : public QueueBase<S> {
 public:
  using StateId = S;

  explicit ShortestFirstQueue(Compare comp)
      : QueueBase<StateId>(SHORTEST_FIRST_QUEUE), heap_(comp) {}

  ~ShortestFirstQueue() override = default;

  StateId Head() const override { return heap_.Top(); }

  void Enqueue(StateId s) override {
    if constexpr (update) {
      // Extends the key table with unset slots until index s exists.
      while (static_cast<StateId>(key_.size()) <= s) {
        key_.push_back(kNoStateId);
      }
      key_[s] = heap_.Insert(s);
    } else {
      heap_.Insert(s);
    }
  }

  void Dequeue() override {
    if constexpr (update) {
      // The popped state no longer has a position in the heap.
      key_[heap_.Pop()] = kNoStateId;
    } else {
      heap_.Pop();
    }
  }

  // Repositions s within the heap; a state that currently has no recorded key
  // is enqueued instead.
  void Update(StateId s) override {
    if constexpr (update) {
      const bool untracked = s >= key_.size() || key_[s] == kNoStateId;
      if (untracked) {
        Enqueue(s);
      } else {
        heap_.Update(key_[s], s);
      }
    }
  }

  bool Empty() const override { return heap_.Empty(); }

  void Clear() override {
    heap_.Clear();
    if constexpr (update) key_.clear();
  }

  ssize_t Size() const { return heap_.Size(); }

  const Compare &GetCompare() const { return heap_.GetCompare(); }

 private:
  Heap<StateId, Compare> heap_;
  // Maps a state ID to its heap key, or kNoStateId if the state has none.
  // Only maintained when `update` is true.
  std::vector<ssize_t> key_;
};
namespace internal {
// Comparison function object between states: two states are compared by
// applying a weight comparison functor to their entries in a vector that maps
// states to weights.
template <typename StateId, typename Less>
class StateWeightCompare {
 public:
  using Weight = typename Less::Weight;

  // Neither argument is copied; both must outlive this object.
  StateWeightCompare(const std::vector<Weight> &weights, const Less &less)
      : weights_(weights), less_(less) {}

  bool operator()(const StateId s1, const StateId s2) const {
    const Weight &lhs = weights_[s1];
    const Weight &rhs = weights_[s2];
    return less_(lhs, rhs);
  }

 private:
  // Borrowed references.
  const std::vector<Weight> &weights_;
  const Less &less_;
};
// Placeholder comparison that is never meant to be constructed: doing so logs
// an error. It exists so that code needing a comparison type can name one (and
// pass a pointer to it) when that pointer is known to always be null.
template <class W>
struct ErrorLess {
  using Weight = W;

  ErrorLess() {
    FSTERROR() << "ErrorLess: instantiated for Weight " << Weight::Type();
  }

  // Never reports one weight as less than another.
  bool operator()(const Weight &, const Weight &) const { return false; }
};
} // namespace internal
// Shortest-first queue discipline, templated on the StateId and Weight, that
// is specialized to use the weight's natural order for the comparison
// function. Requires Weight is idempotent (due to use of NaturalLess).
//
// The distance vector is borrowed and must outlive the queue.
template <typename S, typename Weight>
class NaturalShortestFirstQueue
    : public ShortestFirstQueue<
          S, internal::StateWeightCompare<S, NaturalLess<Weight>>> {
 public:
  using StateId = S;
  using Less = NaturalLess<Weight>;
  using Compare = internal::StateWeightCompare<StateId, Less>;

  // The comparator only stores a reference to its Less argument, so the Less
  // must live as long as the queue; it is therefore a member of this class
  // rather than a temporary. The base class is constructed before less_ is,
  // which is fine because only the reference is bound here; the comparator is
  // not invoked until states are enqueued.
  explicit NaturalShortestFirstQueue(const std::vector<Weight> &distance)
      : ShortestFirstQueue<StateId, Compare>(Compare(distance, less_)) {}

  ~NaturalShortestFirstQueue() override = default;

 private:
  // Owned here to keep the reference held by the comparator valid.
  const Less less_{};
};
// In a shortest path computation on a lattice-like FST, many old nonviable
// paths may be kept around as part of the search. Because the search always
// expands the lowest cost path next, that path may be a very old nonviable
// one rather than one expected to lead to a shortest path.
//
// For instance, suppose the current best path in an alignment has traversed
// 500 arcs with a cost of 10, while a bad path in the queue has traversed only
// 40 arcs, also with a cost of 10. The latter is very unlikely to lead to a
// reasonable alignment, so this queue can prune it from the search space.
//
// This queue relies on the caller using a shortest-first exploration order
// like this:
//   while (true) {
//     StateId head = queue.Head();
//     queue.Dequeue();
//     for (const auto& arc : GetArcs(fst, head)) {
//       queue.Enqueue(arc.nextstate);
//     }
//   }
// Under that assumption, an arc is presumed to connect the last Head() state
// to each subsequently enqueued state; this is how path steps are counted.
template <typename S, typename Weight>
class PruneNaturalShortestFirstQueue
    : public NaturalShortestFirstQueue<S, Weight> {
 public:
  using StateId = S;
  using Base = NaturalShortestFirstQueue<StateId, Weight>;

  PruneNaturalShortestFirstQueue(const std::vector<Weight> &distance,
                                 ssize_t arc_threshold, ssize_t state_limit = 0)
      : Base(distance),
        arc_threshold_(arc_threshold),
        state_limit_(state_limit) {}

  ~PruneNaturalShortestFirstQueue() override = default;

  // Returns the head and, as a side effect, records how many steps that state
  // is from the start and the largest such count seen so far.
  StateId Head() const override {
    const auto head = Base::Head();
    if (head < steps_.size()) {
      const ssize_t steps = steps_[head];
      head_steps_ = steps;
      max_head_steps_ = std::max(steps, max_head_steps_);
    }
    return head;
  }

  void Enqueue(StateId s) override {
    // One more arc than the state most recently returned by Head().
    const ssize_t state_steps = head_steps_ + 1;
    if (s >= steps_.size()) steps_.resize(s + 1, state_steps);
    // The number of arcs in the minimum cost path from Start to s.
    steps_[s] = state_steps;

    // When step thresholding alone has not kept the queue small, the
    // threshold is lowered in proportion to how far the queue has outgrown
    // the state limit, but never below zero.
    const ssize_t queue_size = Base::Size();
    ssize_t adjusted_threshold = arc_threshold_;
    if (state_limit_ > 0 && queue_size > state_limit_) {
      adjusted_threshold = std::max<ssize_t>(
          0, arc_threshold_ - (queue_size / state_limit_) - 1);
    }

    // A negative arc threshold keeps every state.
    const bool keep = arc_threshold_ < 0 ||
                      state_steps > (max_head_steps_ - adjusted_threshold);
    if (!keep) return;

    if (adjusted_threshold == 0 && state_limit_ > 0) {
      // The queue keeps growing without bound: follow any path that makes
      // progress and drop the rest.
      Base::Clear();
    }
    Base::Enqueue(s);
  }

 private:
  // A dense map from StateId to the number of arcs in the minimum weight
  // path from Start to this state.
  std::vector<ssize_t> steps_;
  // Only paths within this number of arcs (not weight!) of the longest path
  // are kept.
  const ssize_t arc_threshold_;
  // Once the queue holds more states than this (and this is positive), the
  // effective arc threshold is lowered to reduce the amount of work.
  const ssize_t state_limit_;
  // The following are mutable because Head() is const.
  // The number of arcs traversed in the minimum cost path from the start
  // state to the current Head() state.
  mutable ssize_t head_steps_ = 0;
  // The maximum number of arcs traversed by any low-cost path so far.
  mutable ssize_t max_head_steps_ = 0;
};
// Topological-order queue discipline, templated on the StateId. States leave
// the queue in topological order. The FST must be acyclic.
template <class S>
class TopOrderQueue : public QueueBase<S> {
 public:
  using StateId = S;

  // Computes the topological order itself by a depth-first visit of the FST.
  // The arc filter limits the transitions considered in that computation
  // (e.g., only the epsilon graph). If the visit reports a cycle, an error is
  // logged and the queue's error flag is set.
  template <class Arc, class ArcFilter>
  TopOrderQueue(const Fst<Arc> &fst, ArcFilter filter)
      : QueueBase<StateId>(TOP_ORDER_QUEUE) {
    bool acyclic;
    TopOrderVisitor<Arc> top_order_visitor(&order_, &acyclic);
    DfsVisit(fst, &top_order_visitor, filter);
    if (!acyclic) {
      FSTERROR() << "TopOrderQueue: FST is not acyclic";
      QueueBase<S>::SetError(true);
    }
    state_.resize(order_.size(), kNoStateId);
  }

  // Uses a topological order computed beforehand by the caller.
  explicit TopOrderQueue(const std::vector<StateId> &order)
      : QueueBase<StateId>(TOP_ORDER_QUEUE),
        order_(order),
        state_(order.size(), kNoStateId) {}

  ~TopOrderQueue() override = default;

  StateId Head() const final { return state_[front_]; }

  void Enqueue(StateId s) final {
    const StateId pos = order_[s];
    if (front_ > back_) {
      // Queue was empty: the occupied range collapses to this position.
      front_ = pos;
      back_ = pos;
    } else if (pos > back_) {
      back_ = pos;
    } else if (pos < front_) {
      front_ = pos;
    }
    state_[pos] = s;
  }

  void Dequeue() final {
    state_[front_] = kNoStateId;
    // Advances to the next occupied position, or one past back_ if none.
    while ((front_ <= back_) && (state_[front_] == kNoStateId)) {
      ++front_;
    }
  }

  // Ordering is fixed by the topological order, so there is nothing to do.
  void Update(StateId) final {}

  bool Empty() const final { return front_ > back_; }

  void Clear() final {
    for (StateId pos = front_; pos <= back_; ++pos) {
      state_[pos] = kNoStateId;
    }
    back_ = kNoStateId;
    front_ = 0;
  }

 private:
  // Lowest and highest order positions that may be occupied; the queue is
  // empty when front_ > back_.
  StateId front_ = 0;
  StateId back_ = kNoStateId;
  // Maps a state ID to its position in the topological order.
  std::vector<StateId> order_;
  // Maps an order position to the state enqueued there, or kNoStateId.
  std::vector<StateId> state_;
};
// State order queue discipline, templated on the StateId. States leave the
// queue in increasing order of state ID.
template <class S>
class StateOrderQueue : public QueueBase<S> {
 public:
  using StateId = S;

  StateOrderQueue() : QueueBase<StateId>(STATE_ORDER_QUEUE) {}

  ~StateOrderQueue() override = default;

  StateId Head() const final { return front_; }

  void Enqueue(StateId s) final {
    // Makes sure the membership table has a slot for s.
    while (enqueued_.size() <= s) {
      enqueued_.push_back(false);
    }
    enqueued_[s] = true;

    if (front_ > back_) {
      // Queue was empty: the occupied range collapses to this state.
      front_ = s;
      back_ = s;
    } else if (s > back_) {
      back_ = s;
    } else if (s < front_) {
      front_ = s;
    }
  }

  void Dequeue() final {
    enqueued_[front_] = false;
    // Advances to the next enqueued state, or one past back_ if none.
    while ((front_ <= back_) && !enqueued_[front_]) {
      ++front_;
    }
  }

  // Ordering is fixed by the state IDs, so there is nothing to do.
  void Update(StateId) final {}

  bool Empty() const final { return front_ > back_; }

  void Clear() final {
    for (StateId s = front_; s <= back_; ++s) {
      enqueued_[s] = false;
    }
    front_ = 0;
    back_ = kNoStateId;
  }

 private:
  // Lowest and highest state IDs that may be enqueued; the queue is empty
  // when front_ > back_.
  StateId front_ = 0;
  StateId back_ = kNoStateId;
  // enqueued_[s] is true iff state s is currently in the queue.
  std::vector<bool> enqueued_;
};
// SCC topological-order meta-queue discipline, templated on the StateId and a
// queue used inside each SCC. It visits the SCCs of an FST in topological
// order. Its constructor is passed the queues to use within an SCC.
template <class S, class Queue>
class SccQueue : public QueueBase<S> {
 public:
  using StateId = S;

  // Takes a vector giving the SCC number of each state and a vector giving
  // the queue to use for each SCC number. A null queue entry means that SCC
  // is handled by a built-in single-state slot. Both arguments are borrowed,
  // not copied.
  SccQueue(const std::vector<StateId> &scc,
           std::vector<std::unique_ptr<Queue>> *queue)
      : QueueBase<StateId>(SCC_QUEUE), queue_(queue), scc_(scc) {}

  ~SccQueue() override = default;

  StateId Head() const final {
    // Skips over SCCs that currently hold no state.
    while ((front_ <= back_) && SccEmpty(front_)) {
      ++front_;
    }
    const auto &inner = (*queue_)[front_];
    if (inner) return inner->Head();
    return trivial_queue_[front_];
  }

  void Enqueue(StateId s) final {
    const StateId component = scc_[s];
    if (front_ > back_) {
      // Queue was empty: the occupied range collapses to this SCC.
      front_ = component;
      back_ = component;
    } else if (component > back_) {
      back_ = component;
    } else if (component < front_) {
      front_ = component;
    }
    const auto &inner = (*queue_)[component];
    if (inner) {
      inner->Enqueue(s);
    } else {
      // Grows the table of single-state slots on demand.
      while (trivial_queue_.size() <= component) {
        trivial_queue_.push_back(kNoStateId);
      }
      trivial_queue_[component] = s;
    }
  }

  void Dequeue() final {
    const auto &inner = (*queue_)[front_];
    if (inner) {
      inner->Dequeue();
    } else if (front_ < trivial_queue_.size()) {
      trivial_queue_[front_] = kNoStateId;
    }
  }

  // Forwarded to the state's SCC queue; single-state slots need no update.
  void Update(StateId s) final {
    const auto &inner = (*queue_)[scc_[s]];
    if (inner) inner->Update(s);
  }

  bool Empty() const final {
    // The queue for SCC number back_ is not empty unless back_ == front_.
    if (front_ < back_) return false;
    if (front_ > back_) return true;
    return SccEmpty(front_);
  }

  void Clear() final {
    for (StateId i = front_; i <= back_; ++i) {
      const auto &inner = (*queue_)[i];
      if (inner) {
        inner->Clear();
      } else if (i < trivial_queue_.size()) {
        trivial_queue_[i] = kNoStateId;
      }
    }
    front_ = 0;
    back_ = kNoStateId;
  }

 private:
  // Returns true if the given SCC currently holds no state, consulting either
  // its queue or, when that queue is null, its single-state slot.
  bool SccEmpty(StateId component) const {
    const auto &inner = (*queue_)[component];
    if (inner) return inner->Empty();
    return (component >= trivial_queue_.size()) ||
           (trivial_queue_[component] == kNoStateId);
  }

  // Per-SCC queues (not owned); null entries use trivial_queue_ instead.
  std::vector<std::unique_ptr<Queue>> *queue_;
  // Maps a state ID to its SCC number (borrowed reference).
  const std::vector<StateId> &scc_;
  // Lowest and highest SCC numbers that may hold states. front_ is mutable
  // because Head() advances it.
  mutable StateId front_ = 0;
  StateId back_ = kNoStateId;
  // Single-state slots, indexed by SCC number, for SCCs with a null queue.
  std::vector<StateId> trivial_queue_;
};
// Automatic queue discipline. It selects a queue discipline for a given FST
// based on its properties, and forwards every queue operation to the queue it
// selected.
template <class S>
class AutoQueue : public QueueBase<S> {
 public:
  using StateId = S;

  // This constructor takes a state distance vector that, if non-null and if
  // the Weight type has the path property, will entertain the shortest-first
  // queue using the natural order w.r.t. the distance. The arc filter limits
  // the transitions considered. The distance vector is borrowed and must
  // outlive the queue.
  //
  // Selection order:
  //   1. top-sorted FST or no start state -> state-order queue;
  //   2. known acyclic                    -> top-order queue;
  //   3. known unweighted and idempotent  -> LIFO queue;
  //   4. otherwise the FST is decomposed into SCCs and either a LIFO queue,
  //      a top-order queue, or an SCC meta-queue with one queue per SCC is
  //      used, depending on what SccQueueType() reports.
  template <class Arc, class ArcFilter>
  AutoQueue(const Fst<Arc> &fst,
            const std::vector<typename Arc::Weight> *distance, ArcFilter filter)
      : QueueBase<StateId>(AUTO_QUEUE) {
    using Weight = typename Arc::Weight;
    // We need to have variables of type Less and Compare, so we use
    // ErrorLess if the type NaturalLess<Weight> cannot be instantiated due
    // to lack of path property.
    using Less = std::conditional_t<IsPath<Weight>::value, NaturalLess<Weight>,
                                    internal::ErrorLess<Weight>>;
    using Compare = internal::StateWeightCompare<StateId, Less>;
    // First checks if the FST is known to have these properties.
    const auto props =
        fst.Properties(kAcyclic | kCyclic | kTopSorted | kUnweighted, false);
    if ((props & kTopSorted) || fst.Start() == kNoStateId) {
      queue_ = std::make_unique<StateOrderQueue<StateId>>();
      VLOG(2) << "AutoQueue: using state-order discipline";
    } else if (props & kAcyclic) {
      queue_ = std::make_unique<TopOrderQueue<StateId>>(fst, filter);
      VLOG(2) << "AutoQueue: using top-order discipline";
    } else if ((props & kUnweighted) && IsIdempotent<Weight>::value) {
      queue_ = std::make_unique<LifoQueue<StateId>>();
      VLOG(2) << "AutoQueue: using LIFO discipline";
    } else {
      uint64_t properties;
      // Decomposes into strongly-connected components.
      SccVisitor<Arc> scc_visitor(&scc_, nullptr, nullptr, &properties);
      DfsVisit(fst, &scc_visitor, filter);
      auto nscc = *std::max_element(scc_.begin(), scc_.end()) + 1;
      std::vector<QueueType> queue_types(nscc);
      // The comparator stores only a reference to *less, and copies of the
      // comparator end up inside the per-SCC shortest-first queues. The Less
      // object is therefore held by a shared_ptr so that ownership can be
      // handed to a member below and outlive this constructor.
      std::shared_ptr<Less> less;
      std::unique_ptr<Compare> comp;
      if constexpr (IsPath<Weight>::value) {
        if (distance) {
          less = std::make_shared<Less>();
          comp = std::make_unique<Compare>(*distance, *less);
        }
      }
      // Finds the queue type to use per SCC.
      bool unweighted;
      bool all_trivial;
      SccQueueType(fst, scc_, &queue_types, filter, less.get(), &all_trivial,
                   &unweighted);
      // If unweighted and semiring is idempotent, uses LIFO queue.
      if (unweighted) {
        queue_ = std::make_unique<LifoQueue<StateId>>();
        VLOG(2) << "AutoQueue: using LIFO discipline";
        return;
      }
      // If all the SCC are trivial, the FST is acyclic and the scc number gives
      // the topological order.
      if (all_trivial) {
        queue_ = std::make_unique<TopOrderQueue<StateId>>(scc_);
        VLOG(2) << "AutoQueue: using top-order discipline";
        return;
      }
      VLOG(2) << "AutoQueue: using SCC meta-discipline";
      // Keeps the Less object alive for as long as the queues built below.
      less_owner_ = std::move(less);
      queues_.resize(nscc);
      for (StateId i = 0; i < nscc; ++i) {
        switch (queue_types[i]) {
          case TRIVIAL_QUEUE:
            // A null entry tells SccQueue to use its built-in trivial slot.
            queues_[i].reset();
            VLOG(3) << "AutoQueue: SCC #" << i << ": using trivial discipline";
            break;
          case SHORTEST_FIRST_QUEUE:
            // The IsPath test is not needed for correctness. It just saves
            // instantiating a ShortestFirstQueue that can never be called.
            if constexpr (IsPath<Weight>::value) {
              queues_[i] =
                  std::make_unique<ShortestFirstQueue<StateId, Compare, false>>(
                      *comp);
              VLOG(3) << "AutoQueue: SCC #" << i
                      << ": using shortest-first discipline";
            } else {
              // SccQueueType should ensure this can never happen.
              FSTERROR() << "Got SHORTEST_FIRST_QUEUE for non-Path Weight "
                         << Weight::Type();
              queues_[i].reset();
            }
            break;
          case LIFO_QUEUE:
            queues_[i] = std::make_unique<LifoQueue<StateId>>();
            VLOG(3) << "AutoQueue: SCC #" << i << ": using LIFO discipline";
            break;
          case FIFO_QUEUE:
          default:
            queues_[i] = std::make_unique<FifoQueue<StateId>>();
            VLOG(3) << "AutoQueue: SCC #" << i << ": using FIFO discipline";
            break;
        }
      }
      queue_ = std::make_unique<SccQueue<StateId, QueueBase<StateId>>>(
          scc_, &queues_);
    }
  }

  ~AutoQueue() override = default;

  // All operations are forwarded to the selected queue.
  StateId Head() const final { return queue_->Head(); }

  void Enqueue(StateId s) final { queue_->Enqueue(s); }

  void Dequeue() final { queue_->Dequeue(); }

  void Update(StateId s) final { queue_->Update(s); }

  bool Empty() const final { return queue_->Empty(); }

  void Clear() final { queue_->Clear(); }

 private:
  // Determines the queue type to use for each SCC; defined out of line.
  template <class Arc, class ArcFilter, class Less>
  static void SccQueueType(const Fst<Arc> &fst, const std::vector<StateId> &scc,
                           std::vector<QueueType> *queue_types,
                           ArcFilter filter, Less *less, bool *all_trivial,
                           bool *unweighted);

  // Type-erased owner of the weight comparison referenced by the comparators
  // inside the per-SCC shortest-first queues; null when none was created.
  // Declared first so that it is destroyed after the queues.
  std::shared_ptr<const void> less_owner_;
  // The selected queue.
  std::unique_ptr<QueueBase<StateId>> queue_;
  // Per-SCC queues handed to the SCC meta-queue (unused otherwise).
  std::vector<std::unique_ptr<QueueBase<StateId>>> queues_;
  // SCC number per state, filled in by the SCC decomposition.
  std::vector<StateId> scc_;
};
// Examines the states in an FST's strongly connected components and determines
// which type of queue to use per SCC. Stores result as a vector of QueueTypes
// which is assumed to have length equal to the number of SCCs. An arc filter
// is used to limit the transitions considered (e.g., only the epsilon graph).
// The argument all_trivial is set to true if every queue is the trivial queue.
// The argument unweighted is set to true if the semiring is idempotent and all
// the arc weights are equal to Zero() or One().
//
// The less argument may be null, in which case any SCC containing an
// intra-component arc gets a FIFO queue.
template <class S>
template <class Arc, class ArcFilter, class Less>
void AutoQueue<S>::SccQueueType(const Fst<Arc> &fst,
                                const std::vector<StateId> &scc,
                                std::vector<QueueType> *queue_type,
                                ArcFilter filter, Less *less,
                                bool *all_trivial, bool *unweighted) {
  using Weight = typename Arc::Weight;
  // True if the weight rules out the "unweighted" treatment: the semiring is
  // not idempotent, or the weight is neither Zero() nor One().
  const auto is_weighted = [](const Weight &w) {
    return !IsIdempotent<Weight>::value ||
           (w != Weight::Zero() && w != Weight::One());
  };
  *all_trivial = true;
  *unweighted = true;
  // Every SCC starts out trivial; arcs inside an SCC upgrade it below.
  std::fill(queue_type->begin(), queue_type->end(), TRIVIAL_QUEUE);
  for (StateIterator<Fst<Arc>> sit(fst); !sit.Done(); sit.Next()) {
    const auto state = sit.Value();
    for (ArcIterator<Fst<Arc>> ait(fst, state); !ait.Done(); ait.Next()) {
      const auto &arc = ait.Value();
      if (!filter(arc)) continue;
      // Only arcs that stay inside one SCC affect that SCC's queue type.
      if (scc[state] == scc[arc.nextstate]) {
        auto &type = (*queue_type)[scc[state]];
        if constexpr (!IsPath<Weight>::value) {
          type = FIFO_QUEUE;
        } else if (!less || (*less)(arc.weight, Weight::One())) {
          type = FIFO_QUEUE;
        } else if ((type == TRIVIAL_QUEUE) || (type == LIFO_QUEUE)) {
          if (is_weighted(arc.weight)) {
            type = SHORTEST_FIRST_QUEUE;
          } else {
            type = LIFO_QUEUE;
          }
        }
        if (type != TRIVIAL_QUEUE) *all_trivial = false;
      }
      if (is_weighted(arc.weight)) *unweighted = false;
    }
  }
}
// An A* estimate is a function object that maps from a state ID to an
// estimate of the shortest distance to the final states.
//
// This is the trivial estimate: it returns Weight::One() for every state,
// yielding a queue which behaves the same as in Dijkstra's algorithm.
template <typename StateId, typename Weight>
struct TrivialAStarEstimate {
  // The state argument is ignored.
  constexpr Weight operator()(StateId) const {
    return Weight::One();
  }
};
// A non-trivial A* estimate that looks the estimated future cost of a state
// up in a vector. States with no entry in the vector get Weight::Zero(). The
// vector is borrowed and must outlive the estimate.
template <typename StateId, typename Weight>
class NaturalAStarEstimate {
 public:
  NaturalAStarEstimate(const std::vector<Weight> &beta) : beta_(beta) {}

  const Weight &operator()(StateId s) const {
    if (s < beta_.size()) return beta_[s];
    return kZero;
  }

 private:
  // Returned by reference for states beyond the end of beta_.
  static constexpr Weight kZero = Weight::Zero();

  const std::vector<Weight> &beta_;
};
// Comparison function object between states for A* search. It is given a
// vector mapping states to their shortest distance from the initial state, a
// comparison function object between weights, and an estimate of the shortest
// distance to the final states. A state's priority is the Times() of its
// distance and its estimate; states are compared by comparing priorities.
template <typename S, typename Less, typename Estimate>
class AStarWeightCompare {
 public:
  using StateId = S;
  using Weight = typename Less::Weight;

  // None of the arguments is copied; all must outlive this object.
  AStarWeightCompare(const std::vector<Weight> &weights, const Less &less,
                     const Estimate &estimate)
      : weights_(weights), less_(less), estimate_(estimate) {}

  bool operator()(StateId s1, StateId s2) const {
    const auto priority = [this](StateId s) {
      return Times(weights_[s], estimate_(s));
    };
    const auto lhs = priority(s1);
    const auto rhs = priority(s2);
    return less_(lhs, rhs);
  }

  const Estimate &GetEstimate() const { return estimate_; }

 private:
  // Borrowed references.
  const std::vector<Weight> &weights_;
  const Less &less_;
  const Estimate &estimate_;
};
// A* queue discipline templated on StateId, Weight, and Estimate: a
// shortest-first queue whose comparator combines each state's distance with
// its estimate and orders the results by the weight's natural order. The
// distance vector and the estimate are borrowed and must outlive the queue.
template <typename S, typename Weight, typename Estimate>
class NaturalAStarQueue
    : public ShortestFirstQueue<
          S, AStarWeightCompare<S, NaturalLess<Weight>, Estimate>> {
 public:
  using StateId = S;
  using Compare = AStarWeightCompare<StateId, NaturalLess<Weight>, Estimate>;

  // The comparator keeps a reference to less_, which is a member of this
  // object and therefore lives as long as the queue does.
  NaturalAStarQueue(const std::vector<Weight> &distance,
                    const Estimate &estimate)
      : Parent(Compare(distance, less_, estimate)) {}

  ~NaturalAStarQueue() override = default;

 private:
  using Parent = ShortestFirstQueue<StateId, Compare>;

  // This is non-static because the constructor for non-idempotent weights will
  // result in an error.
  const NaturalLess<Weight> less_{};
};
// A state equivalence class is a function object that maps from a state ID to
// an equivalence class (state) ID. This is the trivial one: every state is
// its own class, so the state ID is returned unchanged.
template <typename StateId>
struct TrivialStateEquivClass {
  StateId operator()(StateId s) const {
    return s;
  }
};
// Distance-based pruning queue discipline: Enqueues a state only when its
// shortest distance (so far), as specified by distance, is less than (as
// specified by comp) the shortest distance Times() the threshold to any state
// in the same equivalence class, as specified by the functor class_func. The
// underlying queue discipline is specified by queue.
//
// The distance vector, the comparison, and the class functor are borrowed and
// must outlive this queue; the underlying queue is owned.
//
// This is not a final class.
template <typename Queue, typename Less, typename ClassFnc>
class PruneQueue : public QueueBase<typename Queue::StateId> {
 public:
  using StateId = typename Queue::StateId;
  using Weight = typename Less::Weight;

  PruneQueue(const std::vector<Weight> &distance, std::unique_ptr<Queue> queue,
             const Less &less, const ClassFnc &class_fnc, Weight threshold)
      : QueueBase<StateId>(OTHER_QUEUE),
        distance_(distance),
        queue_(std::move(queue)),
        less_(less),
        class_fnc_(class_fnc),
        threshold_(std::move(threshold)) {}

  ~PruneQueue() override = default;

  StateId Head() const override { return queue_->Head(); }

  // Records s's distance against its class, then enqueues s only if that
  // distance is below the class's best distance Times() the threshold.
  void Enqueue(StateId s) override {
    const Weight &class_best = RelaxClassDistance(s);
    // Enqueues only if below threshold limit.
    const auto limit = Times(class_best, threshold_);
    if (less_(distance_[s], limit)) queue_->Enqueue(s);
  }

  void Dequeue() override { queue_->Dequeue(); }

  // Records s's (possibly improved) distance against its class and forwards
  // the update to the underlying queue. Safe to call for a state whose class
  // has not been seen before: the class table is grown first.
  void Update(StateId s) override {
    RelaxClassDistance(s);
    queue_->Update(s);
  }

  bool Empty() const override { return queue_->Empty(); }

  // Clears the underlying queue; recorded class distances are kept.
  void Clear() override { queue_->Clear(); }

 private:
  // Ensures s's equivalence class has an entry (initially Weight::Zero()),
  // lowers that entry to distance_[s] if the latter is smaller, and returns
  // the entry.
  const Weight &RelaxClassDistance(StateId s) {
    const auto c = class_fnc_(s);
    if (c >= class_distance_.size()) {
      class_distance_.resize(c + 1, Weight::Zero());
    }
    Weight &class_best = class_distance_[c];
    if (less_(distance_[s], class_best)) class_best = distance_[s];
    return class_best;
  }

  const std::vector<Weight> &distance_;  // Shortest distance to state.
  std::unique_ptr<Queue> queue_;         // Underlying queue (owned).
  const Less &less_;                     // Borrowed reference.
  const ClassFnc &class_fnc_;            // Equivalence class functor.
  Weight threshold_;                     // Pruning weight threshold.
  std::vector<Weight> class_distance_;   // Shortest distance to class.
};
// Pruning queue discipline (see above) using the weight's natural order for the
// comparison function. The ownership of the queue argument is given to this
// class. The distance vector and the class functor are borrowed and must
// outlive this queue.
template <typename Queue, typename Weight, typename ClassFnc>
class NaturalPruneQueue final
    : public PruneQueue<Queue, NaturalLess<Weight>, ClassFnc> {
 public:
  using StateId = typename Queue::StateId;

  // PruneQueue stores only a reference to the comparison it is given, so the
  // comparison must live as long as the queue; it is therefore a member of
  // this class rather than a temporary. The base class is constructed before
  // natural_less_ is, which is fine because only the reference is bound here;
  // the comparison is not invoked until states are enqueued or updated.
  NaturalPruneQueue(const std::vector<Weight> &distance,
                    std::unique_ptr<Queue> queue, const ClassFnc &class_fnc,
                    Weight threshold)
      : PruneQueue<Queue, NaturalLess<Weight>, ClassFnc>(
            distance, std::move(queue), natural_less_, class_fnc,
            std::move(threshold)) {}

  ~NaturalPruneQueue() override = default;

 private:
  // Owned here to keep the reference held by the base class valid.
  const NaturalLess<Weight> natural_less_{};
};
// Filter-based pruning queue discipline: enqueues a state only if allowed by
// the filter, specified by the state filter functor argument. The underlying
// queue discipline is specified by the queue argument, whose ownership is
// taken. The filter is borrowed and must outlive this queue.
template <typename Queue, typename Filter>
class FilterQueue : public QueueBase<typename Queue::StateId> {
 public:
  using StateId = typename Queue::StateId;

  FilterQueue(std::unique_ptr<Queue> queue, const Filter &filter)
      : QueueBase<StateId>(OTHER_QUEUE),
        queue_(std::move(queue)),
        filter_(filter) {}

  ~FilterQueue() override = default;

  StateId Head() const final { return queue_->Head(); }

  // Enqueues only if allowed by state filter.
  void Enqueue(StateId s) final {
    if (!filter_(s)) return;
    queue_->Enqueue(s);
  }

  void Dequeue() final { queue_->Dequeue(); }

  // Does nothing.
  // NOTE(review): unlike PruneQueue, this does not forward the update to the
  // underlying queue -- confirm this is intended for queues that reorder on
  // Update().
  void Update(StateId) final {}

  bool Empty() const final { return queue_->Empty(); }

  void Clear() final { queue_->Clear(); }

 private:
  std::unique_ptr<Queue> queue_;  // Underlying queue (owned).
  const Filter &filter_;          // Borrowed reference.
};
} // namespace fst
#endif // FST_QUEUE_H_