You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

369 lines
14 KiB

// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Functions and classes to find shortest distance in an FST.
#ifndef FST_SHORTEST_DISTANCE_H_
#define FST_SHORTEST_DISTANCE_H_
#include <cstddef>
#include <optional>
#include <vector>

#include <fst/log.h>
#include <fst/arc.h>
#include <fst/arcfilter.h>
#include <fst/cache.h>
#include <fst/equal.h>
#include <fst/expanded-fst.h>
#include <fst/fst.h>
#include <fst/properties.h>
#include <fst/queue.h>
#include <fst/reverse.h>
#include <fst/util.h>
#include <fst/vector-fst.h>
#include <fst/weight.h>
namespace fst {
// Default convergence threshold for the shortest-distance and shortest-path
// algorithms; it is the default value of their `delta` arguments.
inline constexpr float kShortestDelta = 1e-6f;
// Bundle of settings for the fully configurable shortest-distance algorithm.
template <class Arc, class Queue, class ArcFilter>
struct ShortestDistanceOptions {
  using StateId = typename Arc::StateId;

  // Queue discipline used; owned by caller.
  Queue *state_queue;

  // Arc filter (e.g., limit to only epsilon graph).
  ArcFilter arc_filter;

  // If kNoStateId, use the FST's initial state.
  StateId source;

  // Determines the degree of convergence required.
  float delta;

  // For a semiring with the path property (o.w. undefined), compute the
  // shortest-distances along the first path to a final state found by the
  // algorithm. That path is the shortest-path only if the FST has a unique
  // final state (or all the final states have the same final weight), the
  // queue discipline is shortest-first and all the weights in the FST are
  // between One() and Zero() according to NaturalLess.
  bool first_path;

  // Stores every argument in the member of the same name.
  ShortestDistanceOptions(Queue *state_queue, ArcFilter arc_filter,
                          StateId source = kNoStateId,
                          float delta = kShortestDelta,
                          bool first_path = false)
      : state_queue(state_queue),
        arc_filter(arc_filter),
        source(source),
        delta(delta),
        first_path(first_path) {}
};
namespace internal {
// Working state of the shortest-distance algorithm.
//
// When `retain` is true, the per-state bookkeeping is kept (rather than
// cleared) across calls to ShortestDistance(source), which improves
// efficiency when the algorithm is run several times from different source
// states (e.g., in epsilon removal). Contrary to the usual conventions, the
// FST is held by reference and may not be freed before this object. The
// distance vector should not be modified by the user between calls. Error()
// returns true iff an error was encountered.
template <class Arc, class Queue, class ArcFilter,
          class WeightEqual = WeightApproxEqual>
class ShortestDistanceState {
 public:
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  // Captures the FST, the output vector and the relevant options. The output
  // vector is cleared, and when the FST can report its number of states the
  // per-state vectors are reserved to that size up front.
  ShortestDistanceState(
      const Fst<Arc> &fst, std::vector<Weight> *distance,
      const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts, bool retain)
      : fst_(fst),
        distance_(distance),
        state_queue_(opts.state_queue),
        arc_filter_(opts.arc_filter),
        weight_equal_(opts.delta),
        first_path_(opts.first_path),
        retain_(retain) {
    distance_->clear();
    const std::optional<StateId> known_states = fst.NumStatesIfKnown();
    if (!known_states.has_value()) return;
    const StateId capacity = *known_states;
    distance_->reserve(capacity);
    adder_.reserve(capacity);
    radder_.reserve(capacity);
    enqueued_.reserve(capacity);
  }

  // Runs the algorithm from `source`; defined out of line.
  void ShortestDistance(StateId source);

  // True iff an error was encountered.
  bool Error() const { return error_; }

 private:
  // Grows distance_, adder_, radder_ and enqueued_ in lockstep, one entry at
  // a time, until `index` is a valid position in distance_.
  void EnsureDistanceIndexIsValid(std::size_t index) {
    for (auto size = distance_->size(); size <= index; ++size) {
      distance_->push_back(Weight::Zero());
      adder_.emplace_back();
      radder_.emplace_back();
      enqueued_.push_back(false);
    }
    DCHECK_LT(index, distance_->size());
  }

  // Grows sources_ with kNoStateId entries until `index` is a valid position.
  void EnsureSourcesIndexIsValid(std::size_t index) {
    for (auto size = sources_.size(); size <= index; ++size) {
      sources_.push_back(kNoStateId);
    }
    DCHECK_LT(index, sources_.size());
  }

  const Fst<Arc> &fst_;
  std::vector<Weight> *distance_;
  Queue *state_queue_;
  ArcFilter arc_filter_;
  // Determines when relaxation stops.
  WeightEqual weight_equal_;
  const bool first_path_;
  // Retain and reuse information across calls.
  const bool retain_;
  // Sums distance_ accurately.
  std::vector<Adder<Weight>> adder_;
  // Relaxation distance.
  std::vector<Adder<Weight>> radder_;
  // Is state enqueued?
  std::vector<bool> enqueued_;
  // Source ID for ith state in distance_, (r)adder_, and enqueued_ if
  // retained.
  std::vector<StateId> sources_;
  // Unique ID characterizing each call.
  StateId source_id_ = 0;
  bool error_ = false;
};
// Computes the shortest distance from `source` to the states reachable from
// it, writing the results into *distance_. If source is kNoStateId, uses the
// initial state of the FST. On failure error_ is set; the function returns
// early for an unsupported weight type or a non-member weight.
template <class Arc, class Queue, class ArcFilter, class WeightEqual>
void ShortestDistanceState<Arc, Queue, ArcFilter,
WeightEqual>::ShortestDistance(StateId source) {
// An FST without a start state leaves nothing to compute; only report an
// error if the FST itself carries the kError property.
if (fst_.Start() == kNoStateId) {
if (fst_.Properties(kError, false)) error_ = true;
return;
}
// The weight type must be a right semiring.
if (!(Weight::Properties() & kRightSemiring)) {
FSTERROR() << "ShortestDistance: Weight needs to be right distributive: "
<< Weight::Type();
error_ = true;
return;
}
// first_path is only allowed for weights with the path property.
if (first_path_ && !(Weight::Properties() & kPath)) {
FSTERROR() << "ShortestDistance: The first_path option is disallowed when "
<< "Weight does not have the path property: " << Weight::Type();
error_ = true;
return;
}
// Start from an empty queue; without retain, also drop all per-state data
// left over from a previous call.
state_queue_->Clear();
if (!retain_) {
distance_->clear();
adder_.clear();
radder_.clear();
enqueued_.clear();
}
if (source == kNoStateId) source = fst_.Start();
EnsureDistanceIndexIsValid(source);
// With retain, stamp the source with this call's ID so that its entries
// count as current rather than stale.
if (retain_) {
EnsureSourcesIndexIsValid(source);
sources_[source] = source_id_;
}
// The source is at distance One() from itself and is the first state
// enqueued.
(*distance_)[source] = Weight::One();
adder_[source].Reset(Weight::One());
radder_[source].Reset(Weight::One());
enqueued_[source] = true;
state_queue_->Enqueue(source);
while (!state_queue_->Empty()) {
const auto state = state_queue_->Head();
state_queue_->Dequeue();
EnsureDistanceIndexIsValid(state);
// With first_path, stop as soon as a dequeued state is final (its final
// weight differs from Zero()).
if (first_path_ && (fst_.Final(state) != Weight::Zero())) break;
enqueued_[state] = false;
// r is the weight accumulated in radder_ for this state; the accumulator is
// reset right after it is read, so it is not propagated a second time.
const auto r = radder_[state].Sum();
radder_[state].Reset();
for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
aiter.Next()) {
const auto &arc = aiter.Value();
const auto nextstate = arc.nextstate;
// Skip arcs rejected by the arc filter.
if (!arc_filter_(arc)) continue;
EnsureDistanceIndexIsValid(nextstate);
// With retain, entries stamped with a different source ID belong to an
// earlier call; reset them lazily on first touch in this call.
if (retain_) {
EnsureSourcesIndexIsValid(nextstate);
if (sources_[nextstate] != source_id_) {
(*distance_)[nextstate] = Weight::Zero();
adder_[nextstate].Reset();
radder_[nextstate].Reset();
enqueued_[nextstate] = false;
sources_[nextstate] = source_id_;
}
}
// These references stay valid because no vector is resized before the end
// of this iteration.
auto &nd = (*distance_)[nextstate];
auto &na = adder_[nextstate];
auto &nr = radder_[nextstate];
auto weight = Times(r, arc.weight);
// Relax only if adding the new weight changes the distance according to
// weight_equal_ (constructed from the delta option).
if (!weight_equal_(nd, Plus(nd, weight))) {
nd = na.Add(weight);
nr.Add(weight);
// A non-member weight means the computation has failed.
if (!nd.Member() || !nr.Sum().Member()) {
error_ = true;
return;
}
// Enqueue the state if it is not already queued; otherwise notify the
// queue through Update().
if (!enqueued_[nextstate]) {
state_queue_->Enqueue(nextstate);
enqueued_[nextstate] = true;
} else {
state_queue_->Update(nextstate);
}
}
}
}
// Use a fresh ID for the next call so that entries stamped during this call
// are recognized as stale when retained.
++source_id_;
if (fst_.Properties(kError, false)) error_ = true;
}
} // namespace internal
// Shortest-distance algorithm: this version allows fine control
// via the options argument. See below for a simpler interface.
//
// This computes the shortest distance from the opts.source state to each
// visited state S and stores the value in the distance vector. An
// unvisited state S has distance Zero(), which will be stored in the
// distance vector if S is less than the maximum visited state. The state
// queue discipline, arc filter, and convergence delta are taken in the
// options argument. The distance vector will contain a unique element for
// which Member() is false if an error was encountered.
//
// The weights must must be right distributive and k-closed (i.e., 1 +
// x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
//
// Complexity:
//
// Depends on properties of the semiring and the queue discipline.
//
// For more information, see:
//
// Mohri, M. 2002. Semiring framework and algorithms for shortest-distance
// problems, Journal of Automata, Languages and
// Combinatorics 7(3): 321-350, 2002.
template <class Arc, class Queue, class ArcFilter>
void ShortestDistance(
const Fst<Arc> &fst, std::vector<typename Arc::Weight> *distance,
const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) {
internal::ShortestDistanceState<Arc, Queue, ArcFilter> sd_state(fst, distance,
opts, false);
sd_state.ShortestDistance(opts.source);
if (sd_state.Error()) {
distance->assign(1, Arc::Weight::NoWeight());
}
}
// Shortest-distance algorithm: simplified interface. See above for a version
// that permits finer control.
//
// If reverse is false, this computes the shortest distance from the initial
// state to each state S and stores the value in the distance vector. If
// reverse is true, this computes the shortest distance from each state to the
// final states. An unvisited state S has distance Zero(), which will be stored
// in the distance vector if S is less than the maximum visited state. The
// state queue discipline is automatically-selected. The distance vector will
// contain a unique element for which Member() is false if an error was
// encountered.
//
// The weights must must be right (left) distributive if reverse is false (true)
// and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
//
// Arc weights must satisfy the property that the sum of the weights of one or
// more paths from some state S to T is never Zero(). In particular, arc weights
// are never Zero().
//
// Complexity:
//
// Depends on properties of the semiring and the queue discipline.
//
// For more information, see:
//
// Mohri, M. 2002. Semiring framework and algorithms for
// shortest-distance problems, Journal of Automata, Languages and
// Combinatorics 7(3): 321-350, 2002.
template <class Arc>
void ShortestDistance(const Fst<Arc> &fst,
std::vector<typename Arc::Weight> *distance,
bool reverse = false, float delta = kShortestDelta) {
using StateId = typename Arc::StateId;
if (!reverse) {
AnyArcFilter<Arc> arc_filter;
AutoQueue<StateId> state_queue(fst, distance, arc_filter);
const ShortestDistanceOptions<Arc, AutoQueue<StateId>, AnyArcFilter<Arc>>
opts(&state_queue, arc_filter, kNoStateId, delta);
ShortestDistance(fst, distance, opts);
} else {
using ReverseArc = ReverseArc<Arc>;
using ReverseWeight = typename ReverseArc::Weight;
AnyArcFilter<ReverseArc> rarc_filter;
VectorFst<ReverseArc> rfst;
Reverse(fst, &rfst);
std::vector<ReverseWeight> rdistance;
AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter);
const ShortestDistanceOptions<ReverseArc, AutoQueue<StateId>,
AnyArcFilter<ReverseArc>>
ropts(&state_queue, rarc_filter, kNoStateId, delta);
ShortestDistance(rfst, &rdistance, ropts);
distance->clear();
if (rdistance.size() == 1 && !rdistance[0].Member()) {
distance->assign(1, Arc::Weight::NoWeight());
return;
}
DCHECK_GE(rdistance.size(), 1); // reversing added one state
distance->reserve(rdistance.size() - 1);
while (distance->size() < rdistance.size() - 1) {
distance->push_back(rdistance[distance->size() + 1].Reverse());
}
}
}
// Returns the sum of the weights of all successful paths in an FST, i.e., the
// shortest-distance from the initial state to the final states. Returns a
// weight such that Member() is false if an error was encountered.
//
// For a right semiring, the forward distances are computed and each one is
// multiplied by the final weight of its state before being summed. Otherwise
// the reverse distances are computed and the value at the start state is
// returned, or Zero() if there is no start state or it lies outside the
// distance vector.
template <class Arc>
typename Arc::Weight ShortestDistance(const Fst<Arc> &fst,
                                      float delta = kShortestDelta) {
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;
  std::vector<Weight> distance;
  if (Weight::Properties() & kRightSemiring) {
    ShortestDistance(fst, &distance, false, delta);
    // A single non-member element is the error marker.
    if (distance.size() == 1 && !distance[0].Member()) {
      return Weight::NoWeight();
    }
    Adder<Weight> adder;  // Maintains the cumulative sum accurately.
    // Index with an unsigned type to avoid comparing a signed StateId against
    // the unsigned vector size.
    for (std::size_t i = 0; i < distance.size(); ++i) {
      adder.Add(Times(distance[i], fst.Final(static_cast<StateId>(i))));
    }
    return adder.Sum();
  } else {
    ShortestDistance(fst, &distance, true, delta);
    const auto state = fst.Start();
    // A single non-member element is the error marker.
    if (distance.size() == 1 && !distance[0].Member()) {
      return Weight::NoWeight();
    }
    // The explicit cast keeps the comparison unsigned without relying on an
    // implicit signed-to-unsigned conversion.
    const bool in_range = state != kNoStateId &&
                          static_cast<std::size_t>(state) < distance.size();
    return in_range ? distance[state] : Weight::Zero();
  }
}
} // namespace fst
#endif // FST_SHORTEST_DISTANCE_H_