|
// Copyright 2005-2024 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the 'License');
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an 'AS IS' BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
// See www.openfst.org for extensive documentation on this weighted
|
|
// finite-state transducer library.
|
|
//
|
|
// Functions and classes to find shortest distance in an FST.
|
|
|
|
#ifndef FST_SHORTEST_DISTANCE_H_
|
|
#define FST_SHORTEST_DISTANCE_H_
|
|
|
|
#include <cstddef>
|
|
#include <vector>
|
|
|
|
#include <fst/log.h>
|
|
#include <fst/arc.h>
|
|
#include <fst/arcfilter.h>
|
|
#include <fst/cache.h>
|
|
#include <fst/equal.h>
|
|
#include <fst/expanded-fst.h>
|
|
#include <fst/fst.h>
|
|
#include <fst/properties.h>
|
|
#include <fst/queue.h>
|
|
#include <fst/reverse.h>
|
|
#include <fst/util.h>
|
|
#include <fst/vector-fst.h>
|
|
#include <fst/weight.h>
|
|
|
|
namespace fst {
|
|
|
|
// A representable float for shortest distance and shortest path algorithms.
|
|
// Default convergence threshold: this is the default value of the `delta`
// argument of ShortestDistanceOptions and of the ShortestDistance() overloads
// declared in this file.
inline constexpr float kShortestDelta = 1e-6;
|
|
|
|
template <class Arc, class Queue, class ArcFilter>
|
|
struct ShortestDistanceOptions {
|
|
using StateId = typename Arc::StateId;
|
|
|
|
Queue *state_queue; // Queue discipline used; owned by caller.
|
|
ArcFilter arc_filter; // Arc filter (e.g., limit to only epsilon graph).
|
|
StateId source; // If kNoStateId, use the FST's initial state.
|
|
float delta; // Determines the degree of convergence required
|
|
bool first_path; // For a semiring with the path property (o.w.
|
|
// undefined), compute the shortest-distances along
|
|
// along the first path to a final state found
|
|
// by the algorithm. That path is the shortest-path
|
|
// only if the FST has a unique final state (or all
|
|
// the final states have the same final weight), the
|
|
// queue discipline is shortest-first and all the
|
|
// weights in the FST are between One() and Zero()
|
|
// according to NaturalLess.
|
|
|
|
ShortestDistanceOptions(Queue *state_queue, ArcFilter arc_filter,
|
|
StateId source = kNoStateId,
|
|
float delta = kShortestDelta, bool first_path = false)
|
|
: state_queue(state_queue),
|
|
arc_filter(arc_filter),
|
|
source(source),
|
|
delta(delta),
|
|
first_path(first_path) {}
|
|
};
|
|
|
|
namespace internal {
|
|
|
|
// Computation state of the shortest-distance algorithm. Reusable information
|
|
// is maintained across calls to member function ShortestDistance(source) when
|
|
// retain is true for improved efficiency when calling multiple times from
|
|
// different source states (e.g., in epsilon removal). Contrary to the usual
|
|
// conventions, fst may not be freed before this class. Vector distance
|
|
// should not be modified by the user between these calls. The Error() method
|
|
// returns true iff an error was encountered.
|
|
template <class Arc, class Queue, class ArcFilter,
|
|
class WeightEqual = WeightApproxEqual>
|
|
class ShortestDistanceState {
|
|
public:
|
|
using StateId = typename Arc::StateId;
|
|
using Weight = typename Arc::Weight;
|
|
|
|
ShortestDistanceState(
|
|
const Fst<Arc> &fst, std::vector<Weight> *distance,
|
|
const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts, bool retain)
|
|
: fst_(fst),
|
|
distance_(distance),
|
|
state_queue_(opts.state_queue),
|
|
arc_filter_(opts.arc_filter),
|
|
weight_equal_(opts.delta),
|
|
first_path_(opts.first_path),
|
|
retain_(retain),
|
|
source_id_(0),
|
|
error_(false) {
|
|
distance_->clear();
|
|
if (std::optional<StateId> num_states = fst.NumStatesIfKnown()) {
|
|
distance_->reserve(*num_states);
|
|
adder_.reserve(*num_states);
|
|
radder_.reserve(*num_states);
|
|
enqueued_.reserve(*num_states);
|
|
}
|
|
}
|
|
|
|
void ShortestDistance(StateId source);
|
|
|
|
bool Error() const { return error_; }
|
|
|
|
private:
|
|
void EnsureDistanceIndexIsValid(std::size_t index) {
|
|
while (distance_->size() <= index) {
|
|
distance_->push_back(Weight::Zero());
|
|
adder_.push_back(Adder<Weight>());
|
|
radder_.push_back(Adder<Weight>());
|
|
enqueued_.push_back(false);
|
|
}
|
|
DCHECK_LT(index, distance_->size());
|
|
}
|
|
|
|
void EnsureSourcesIndexIsValid(std::size_t index) {
|
|
while (sources_.size() <= index) {
|
|
sources_.push_back(kNoStateId);
|
|
}
|
|
DCHECK_LT(index, sources_.size());
|
|
}
|
|
|
|
const Fst<Arc> &fst_;
|
|
std::vector<Weight> *distance_;
|
|
Queue *state_queue_;
|
|
ArcFilter arc_filter_;
|
|
WeightEqual weight_equal_; // Determines when relaxation stops.
|
|
const bool first_path_;
|
|
const bool retain_; // Retain and reuse information across calls.
|
|
|
|
std::vector<Adder<Weight>> adder_; // Sums distance_ accurately.
|
|
std::vector<Adder<Weight>> radder_; // Relaxation distance.
|
|
std::vector<bool> enqueued_; // Is state enqueued?
|
|
std::vector<StateId> sources_; // Source ID for ith state in distance_,
|
|
// (r)adder_, and enqueued_ if retained.
|
|
StateId source_id_; // Unique ID characterizing each call.
|
|
bool error_;
|
|
};
|
|
|
|
// Compute the shortest distance; if source is kNoStateId, uses the initial
|
|
// state of the FST.
|
|
template <class Arc, class Queue, class ArcFilter, class WeightEqual>
|
|
void ShortestDistanceState<Arc, Queue, ArcFilter,
|
|
WeightEqual>::ShortestDistance(StateId source) {
|
|
if (fst_.Start() == kNoStateId) {
|
|
if (fst_.Properties(kError, false)) error_ = true;
|
|
return;
|
|
}
|
|
if (!(Weight::Properties() & kRightSemiring)) {
|
|
FSTERROR() << "ShortestDistance: Weight needs to be right distributive: "
|
|
<< Weight::Type();
|
|
error_ = true;
|
|
return;
|
|
}
|
|
if (first_path_ && !(Weight::Properties() & kPath)) {
|
|
FSTERROR() << "ShortestDistance: The first_path option is disallowed when "
|
|
<< "Weight does not have the path property: " << Weight::Type();
|
|
error_ = true;
|
|
return;
|
|
}
|
|
state_queue_->Clear();
|
|
if (!retain_) {
|
|
distance_->clear();
|
|
adder_.clear();
|
|
radder_.clear();
|
|
enqueued_.clear();
|
|
}
|
|
if (source == kNoStateId) source = fst_.Start();
|
|
EnsureDistanceIndexIsValid(source);
|
|
if (retain_) {
|
|
EnsureSourcesIndexIsValid(source);
|
|
sources_[source] = source_id_;
|
|
}
|
|
(*distance_)[source] = Weight::One();
|
|
adder_[source].Reset(Weight::One());
|
|
radder_[source].Reset(Weight::One());
|
|
enqueued_[source] = true;
|
|
state_queue_->Enqueue(source);
|
|
while (!state_queue_->Empty()) {
|
|
const auto state = state_queue_->Head();
|
|
state_queue_->Dequeue();
|
|
EnsureDistanceIndexIsValid(state);
|
|
if (first_path_ && (fst_.Final(state) != Weight::Zero())) break;
|
|
enqueued_[state] = false;
|
|
const auto r = radder_[state].Sum();
|
|
radder_[state].Reset();
|
|
for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
|
|
aiter.Next()) {
|
|
const auto &arc = aiter.Value();
|
|
const auto nextstate = arc.nextstate;
|
|
if (!arc_filter_(arc)) continue;
|
|
EnsureDistanceIndexIsValid(nextstate);
|
|
if (retain_) {
|
|
EnsureSourcesIndexIsValid(nextstate);
|
|
if (sources_[nextstate] != source_id_) {
|
|
(*distance_)[nextstate] = Weight::Zero();
|
|
adder_[nextstate].Reset();
|
|
radder_[nextstate].Reset();
|
|
enqueued_[nextstate] = false;
|
|
sources_[nextstate] = source_id_;
|
|
}
|
|
}
|
|
auto &nd = (*distance_)[nextstate];
|
|
auto &na = adder_[nextstate];
|
|
auto &nr = radder_[nextstate];
|
|
auto weight = Times(r, arc.weight);
|
|
if (!weight_equal_(nd, Plus(nd, weight))) {
|
|
nd = na.Add(weight);
|
|
nr.Add(weight);
|
|
if (!nd.Member() || !nr.Sum().Member()) {
|
|
error_ = true;
|
|
return;
|
|
}
|
|
if (!enqueued_[nextstate]) {
|
|
state_queue_->Enqueue(nextstate);
|
|
enqueued_[nextstate] = true;
|
|
} else {
|
|
state_queue_->Update(nextstate);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
++source_id_;
|
|
if (fst_.Properties(kError, false)) error_ = true;
|
|
}
|
|
|
|
} // namespace internal
|
|
|
|
// Shortest-distance algorithm: this version allows fine control
|
|
// via the options argument. See below for a simpler interface.
|
|
//
|
|
// This computes the shortest distance from the opts.source state to each
|
|
// visited state S and stores the value in the distance vector. An
|
|
// unvisited state S has distance Zero(), which will be stored in the
|
|
// distance vector if S is less than the maximum visited state. The state
|
|
// queue discipline, arc filter, and convergence delta are taken in the
|
|
// options argument. The distance vector will contain a unique element for
|
|
// which Member() is false if an error was encountered.
|
|
//
|
|
// The weights must must be right distributive and k-closed (i.e., 1 +
|
|
// x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
|
|
//
|
|
// Complexity:
|
|
//
|
|
// Depends on properties of the semiring and the queue discipline.
|
|
//
|
|
// For more information, see:
|
|
//
|
|
// Mohri, M. 2002. Semiring framework and algorithms for shortest-distance
|
|
// problems, Journal of Automata, Languages and
|
|
// Combinatorics 7(3): 321-350, 2002.
|
|
template <class Arc, class Queue, class ArcFilter>
|
|
void ShortestDistance(
|
|
const Fst<Arc> &fst, std::vector<typename Arc::Weight> *distance,
|
|
const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) {
|
|
internal::ShortestDistanceState<Arc, Queue, ArcFilter> sd_state(fst, distance,
|
|
opts, false);
|
|
sd_state.ShortestDistance(opts.source);
|
|
if (sd_state.Error()) {
|
|
distance->assign(1, Arc::Weight::NoWeight());
|
|
}
|
|
}
|
|
|
|
// Shortest-distance algorithm: simplified interface. See above for a version
|
|
// that permits finer control.
|
|
//
|
|
// If reverse is false, this computes the shortest distance from the initial
|
|
// state to each state S and stores the value in the distance vector. If
|
|
// reverse is true, this computes the shortest distance from each state to the
|
|
// final states. An unvisited state S has distance Zero(), which will be stored
|
|
// in the distance vector if S is less than the maximum visited state. The
|
|
// state queue discipline is automatically-selected. The distance vector will
|
|
// contain a unique element for which Member() is false if an error was
|
|
// encountered.
|
|
//
|
|
// The weights must must be right (left) distributive if reverse is false (true)
|
|
// and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
|
|
//
|
|
// Arc weights must satisfy the property that the sum of the weights of one or
|
|
// more paths from some state S to T is never Zero(). In particular, arc weights
|
|
// are never Zero().
|
|
//
|
|
// Complexity:
|
|
//
|
|
// Depends on properties of the semiring and the queue discipline.
|
|
//
|
|
// For more information, see:
|
|
//
|
|
// Mohri, M. 2002. Semiring framework and algorithms for
|
|
// shortest-distance problems, Journal of Automata, Languages and
|
|
// Combinatorics 7(3): 321-350, 2002.
|
|
template <class Arc>
|
|
void ShortestDistance(const Fst<Arc> &fst,
|
|
std::vector<typename Arc::Weight> *distance,
|
|
bool reverse = false, float delta = kShortestDelta) {
|
|
using StateId = typename Arc::StateId;
|
|
if (!reverse) {
|
|
AnyArcFilter<Arc> arc_filter;
|
|
AutoQueue<StateId> state_queue(fst, distance, arc_filter);
|
|
const ShortestDistanceOptions<Arc, AutoQueue<StateId>, AnyArcFilter<Arc>>
|
|
opts(&state_queue, arc_filter, kNoStateId, delta);
|
|
ShortestDistance(fst, distance, opts);
|
|
} else {
|
|
using ReverseArc = ReverseArc<Arc>;
|
|
using ReverseWeight = typename ReverseArc::Weight;
|
|
AnyArcFilter<ReverseArc> rarc_filter;
|
|
VectorFst<ReverseArc> rfst;
|
|
Reverse(fst, &rfst);
|
|
std::vector<ReverseWeight> rdistance;
|
|
AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter);
|
|
const ShortestDistanceOptions<ReverseArc, AutoQueue<StateId>,
|
|
AnyArcFilter<ReverseArc>>
|
|
ropts(&state_queue, rarc_filter, kNoStateId, delta);
|
|
ShortestDistance(rfst, &rdistance, ropts);
|
|
distance->clear();
|
|
if (rdistance.size() == 1 && !rdistance[0].Member()) {
|
|
distance->assign(1, Arc::Weight::NoWeight());
|
|
return;
|
|
}
|
|
DCHECK_GE(rdistance.size(), 1); // reversing added one state
|
|
distance->reserve(rdistance.size() - 1);
|
|
while (distance->size() < rdistance.size() - 1) {
|
|
distance->push_back(rdistance[distance->size() + 1].Reverse());
|
|
}
|
|
}
|
|
}
|
|
|
|
// Return the sum of the weight of all successful paths in an FST, i.e., the
|
|
// shortest-distance from the initial state to the final states. Returns a
|
|
// weight such that Member() is false if an error was encountered.
|
|
template <class Arc>
|
|
typename Arc::Weight ShortestDistance(const Fst<Arc> &fst,
|
|
float delta = kShortestDelta) {
|
|
using StateId = typename Arc::StateId;
|
|
using Weight = typename Arc::Weight;
|
|
std::vector<Weight> distance;
|
|
if (Weight::Properties() & kRightSemiring) {
|
|
ShortestDistance(fst, &distance, false, delta);
|
|
if (distance.size() == 1 && !distance[0].Member()) {
|
|
return Arc::Weight::NoWeight();
|
|
}
|
|
Adder<Weight> adder; // maintains cumulative sum accurately
|
|
for (StateId state = 0; state < distance.size(); ++state) {
|
|
adder.Add(Times(distance[state], fst.Final(state)));
|
|
}
|
|
return adder.Sum();
|
|
} else {
|
|
ShortestDistance(fst, &distance, true, delta);
|
|
const auto state = fst.Start();
|
|
if (distance.size() == 1 && !distance[0].Member()) {
|
|
return Arc::Weight::NoWeight();
|
|
}
|
|
return state != kNoStateId && state < distance.size() ? distance[state]
|
|
: Weight::Zero();
|
|
}
|
|
}
|
|
|
|
} // namespace fst
|
|
|
|
#endif // FST_SHORTEST_DISTANCE_H_
|