|
|
// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Functions and classes to find shortest distance in an FST.
#ifndef FST_SHORTEST_DISTANCE_H_
#define FST_SHORTEST_DISTANCE_H_
#include <cstddef>
#include <optional>
#include <vector>

#include <fst/log.h>
#include <fst/arc.h>
#include <fst/arcfilter.h>
#include <fst/cache.h>
#include <fst/equal.h>
#include <fst/expanded-fst.h>
#include <fst/fst.h>
#include <fst/properties.h>
#include <fst/queue.h>
#include <fst/reverse.h>
#include <fst/util.h>
#include <fst/vector-fst.h>
#include <fst/weight.h>
namespace fst {
// Default convergence threshold (delta) used by the shortest-distance and
// shortest-path algorithms; chosen to be representable as a float.
inline constexpr float kShortestDelta = 1e-6F;
// Options bundle for the fine-grained ShortestDistance() interface.
template <class Arc, class Queue, class ArcFilter>
struct ShortestDistanceOptions {
  using StateId = typename Arc::StateId;

  // Stores the given settings verbatim. The queue is not owned by this struct.
  ShortestDistanceOptions(Queue *state_queue, ArcFilter arc_filter,
                          StateId source = kNoStateId,
                          float delta = kShortestDelta,
                          bool first_path = false)
      : state_queue(state_queue),
        arc_filter(arc_filter),
        source(source),
        delta(delta),
        first_path(first_path) {}

  // Queue discipline used; owned by caller.
  Queue *state_queue;

  // Arc filter (e.g., limit to only the epsilon graph).
  ArcFilter arc_filter;

  // Source state; if kNoStateId, the FST's initial state is used.
  StateId source;

  // Determines the degree of convergence required.
  float delta;

  // For a semiring with the path property (otherwise undefined), compute the
  // shortest distances along the first path to a final state found by the
  // algorithm. That path is the shortest path only if
  //   * the FST has a unique final state (or all the final states have the
  //     same final weight),
  //   * the queue discipline is shortest-first, and
  //   * all the weights in the FST are between One() and Zero() according to
  //     NaturalLess.
  bool first_path;
};
namespace internal {
// Computation state of the shortest-distance algorithm. Reusable information
// is maintained across calls to member function ShortestDistance(source) when
// retain is true for improved efficiency when calling multiple times from
// different source states (e.g., in epsilon removal). Contrary to the usual
// conventions, fst may not be freed before this class. Vector distance
// should not be modified by the user between these calls. The Error() method
// returns true iff an error was encountered.
template <class Arc, class Queue, class ArcFilter, class WeightEqual = WeightApproxEqual> class ShortestDistanceState { public: using StateId = typename Arc::StateId; using Weight = typename Arc::Weight;
ShortestDistanceState( const Fst<Arc> &fst, std::vector<Weight> *distance, const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts, bool retain) : fst_(fst), distance_(distance), state_queue_(opts.state_queue), arc_filter_(opts.arc_filter), weight_equal_(opts.delta), first_path_(opts.first_path), retain_(retain), source_id_(0), error_(false) { distance_->clear(); if (std::optional<StateId> num_states = fst.NumStatesIfKnown()) { distance_->reserve(*num_states); adder_.reserve(*num_states); radder_.reserve(*num_states); enqueued_.reserve(*num_states); } }
void ShortestDistance(StateId source);
bool Error() const { return error_; }
private: void EnsureDistanceIndexIsValid(std::size_t index) { while (distance_->size() <= index) { distance_->push_back(Weight::Zero()); adder_.push_back(Adder<Weight>()); radder_.push_back(Adder<Weight>()); enqueued_.push_back(false); } DCHECK_LT(index, distance_->size()); }
void EnsureSourcesIndexIsValid(std::size_t index) { while (sources_.size() <= index) { sources_.push_back(kNoStateId); } DCHECK_LT(index, sources_.size()); }
const Fst<Arc> &fst_; std::vector<Weight> *distance_; Queue *state_queue_; ArcFilter arc_filter_; WeightEqual weight_equal_; // Determines when relaxation stops.
const bool first_path_; const bool retain_; // Retain and reuse information across calls.
std::vector<Adder<Weight>> adder_; // Sums distance_ accurately.
std::vector<Adder<Weight>> radder_; // Relaxation distance.
std::vector<bool> enqueued_; // Is state enqueued?
std::vector<StateId> sources_; // Source ID for ith state in distance_,
// (r)adder_, and enqueued_ if retained.
StateId source_id_; // Unique ID characterizing each call.
bool error_; };
// Computes the shortest distance from `source`; when `source` is kNoStateId
// the initial state of the FST is used instead. Sets the error flag (see
// Error()) and returns early if the weight type lacks a required property or
// a non-member weight is produced.
template <class Arc, class Queue, class ArcFilter, class WeightEqual>
void ShortestDistanceState<Arc, Queue, ArcFilter,
                           WeightEqual>::ShortestDistance(StateId source) {
  // Nothing to do for an FST without a start state, beyond propagating an
  // error already flagged on the FST.
  if (fst_.Start() == kNoStateId) {
    if (fst_.Properties(kError, false)) error_ = true;
    return;
  }
  if ((Weight::Properties() & kRightSemiring) == 0) {
    FSTERROR() << "ShortestDistance: Weight needs to be right distributive: "
               << Weight::Type();
    error_ = true;
    return;
  }
  if (first_path_ && (Weight::Properties() & kPath) == 0) {
    FSTERROR() << "ShortestDistance: The first_path option is disallowed when "
               << "Weight does not have the path property: " << Weight::Type();
    error_ = true;
    return;
  }

  state_queue_->Clear();
  if (!retain_) {
    // Without retention every call starts from empty per-state vectors.
    distance_->clear();
    adder_.clear();
    radder_.clear();
    enqueued_.clear();
  }

  // Seed the computation: the source is at distance One() and is enqueued.
  if (source == kNoStateId) source = fst_.Start();
  EnsureDistanceIndexIsValid(source);
  if (retain_) {
    EnsureSourcesIndexIsValid(source);
    sources_[source] = source_id_;
  }
  (*distance_)[source] = Weight::One();
  adder_[source].Reset(Weight::One());
  radder_[source].Reset(Weight::One());
  enqueued_[source] = true;
  state_queue_->Enqueue(source);

  while (!state_queue_->Empty()) {
    const auto current = state_queue_->Head();
    state_queue_->Dequeue();
    EnsureDistanceIndexIsValid(current);
    // In first-path mode, stop as soon as a final state is dequeued.
    if (first_path_ && (fst_.Final(current) != Weight::Zero())) break;
    enqueued_[current] = false;

    // Take the weight accumulated for this state since its last relaxation
    // and reset the accumulator before propagating it.
    const auto pending = radder_[current].Sum();
    radder_[current].Reset();

    for (ArcIterator<Fst<Arc>> aiter(fst_, current); !aiter.Done();
         aiter.Next()) {
      const auto &arc = aiter.Value();
      if (!arc_filter_(arc)) continue;
      const auto dest = arc.nextstate;
      EnsureDistanceIndexIsValid(dest);
      if (retain_) {
        EnsureSourcesIndexIsValid(dest);
        // Entries left over from a call with a different source ID are stale
        // and are reset before use.
        if (sources_[dest] != source_id_) {
          (*distance_)[dest] = Weight::Zero();
          adder_[dest].Reset();
          radder_[dest].Reset();
          enqueued_[dest] = false;
          sources_[dest] = source_id_;
        }
      }
      // References are taken only after the vectors have been grown above.
      auto &dest_distance = (*distance_)[dest];
      auto &dest_adder = adder_[dest];
      auto &dest_radder = radder_[dest];
      auto contribution = Times(pending, arc.weight);

      // Skip the arc when adding its contribution would not change the
      // destination distance according to weight_equal_.
      if (weight_equal_(dest_distance, Plus(dest_distance, contribution))) {
        continue;
      }
      dest_distance = dest_adder.Add(contribution);
      dest_radder.Add(contribution);
      if (!(dest_distance.Member() && dest_radder.Sum().Member())) {
        error_ = true;
        return;
      }
      if (enqueued_[dest]) {
        state_queue_->Update(dest);
      } else {
        state_queue_->Enqueue(dest);
        enqueued_[dest] = true;
      }
    }
  }

  // Advance the call ID so that retained entries from this call are
  // recognized as stale by the next one.
  ++source_id_;
  if (fst_.Properties(kError, false)) error_ = true;
}
} // namespace internal
// Shortest-distance algorithm: this version allows fine control
// via the options argument. See below for a simpler interface.
//
// This computes the shortest distance from the opts.source state to each
// visited state S and stores the value in the distance vector. An
// unvisited state S has distance Zero(), which will be stored in the
// distance vector if S is less than the maximum visited state. The state
// queue discipline, arc filter, and convergence delta are taken in the
// options argument. The distance vector will contain a unique element for
// which Member() is false if an error was encountered.
//
// The weights must must be right distributive and k-closed (i.e., 1 +
// x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
//
// Complexity:
//
// Depends on properties of the semiring and the queue discipline.
//
// For more information, see:
//
// Mohri, M. 2002. Semiring framework and algorithms for shortest-distance
// problems, Journal of Automata, Languages and
// Combinatorics 7(3): 321-350, 2002.
template <class Arc, class Queue, class ArcFilter> void ShortestDistance( const Fst<Arc> &fst, std::vector<typename Arc::Weight> *distance, const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) { internal::ShortestDistanceState<Arc, Queue, ArcFilter> sd_state(fst, distance, opts, false); sd_state.ShortestDistance(opts.source); if (sd_state.Error()) { distance->assign(1, Arc::Weight::NoWeight()); } }
// Shortest-distance algorithm: simplified interface. See above for a version
// that permits finer control.
//
// If reverse is false, this computes the shortest distance from the initial
// state to each state S and stores the value in the distance vector. If
// reverse is true, this computes the shortest distance from each state to the
// final states. An unvisited state S has distance Zero(), which will be stored
// in the distance vector if S is less than the maximum visited state. The
// state queue discipline is automatically-selected. The distance vector will
// contain a unique element for which Member() is false if an error was
// encountered.
//
// The weights must must be right (left) distributive if reverse is false (true)
// and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
//
// Arc weights must satisfy the property that the sum of the weights of one or
// more paths from some state S to T is never Zero(). In particular, arc weights
// are never Zero().
//
// Complexity:
//
// Depends on properties of the semiring and the queue discipline.
//
// For more information, see:
//
// Mohri, M. 2002. Semiring framework and algorithms for
// shortest-distance problems, Journal of Automata, Languages and
// Combinatorics 7(3): 321-350, 2002.
template <class Arc> void ShortestDistance(const Fst<Arc> &fst, std::vector<typename Arc::Weight> *distance, bool reverse = false, float delta = kShortestDelta) { using StateId = typename Arc::StateId; if (!reverse) { AnyArcFilter<Arc> arc_filter; AutoQueue<StateId> state_queue(fst, distance, arc_filter); const ShortestDistanceOptions<Arc, AutoQueue<StateId>, AnyArcFilter<Arc>> opts(&state_queue, arc_filter, kNoStateId, delta); ShortestDistance(fst, distance, opts); } else { using ReverseArc = ReverseArc<Arc>; using ReverseWeight = typename ReverseArc::Weight; AnyArcFilter<ReverseArc> rarc_filter; VectorFst<ReverseArc> rfst; Reverse(fst, &rfst); std::vector<ReverseWeight> rdistance; AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter); const ShortestDistanceOptions<ReverseArc, AutoQueue<StateId>, AnyArcFilter<ReverseArc>> ropts(&state_queue, rarc_filter, kNoStateId, delta); ShortestDistance(rfst, &rdistance, ropts); distance->clear(); if (rdistance.size() == 1 && !rdistance[0].Member()) { distance->assign(1, Arc::Weight::NoWeight()); return; } DCHECK_GE(rdistance.size(), 1); // reversing added one state
distance->reserve(rdistance.size() - 1); while (distance->size() < rdistance.size() - 1) { distance->push_back(rdistance[distance->size() + 1].Reverse()); } } }
// Return the sum of the weight of all successful paths in an FST, i.e., the
// shortest-distance from the initial state to the final states. Returns a
// weight such that Member() is false if an error was encountered.
template <class Arc> typename Arc::Weight ShortestDistance(const Fst<Arc> &fst, float delta = kShortestDelta) { using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; std::vector<Weight> distance; if (Weight::Properties() & kRightSemiring) { ShortestDistance(fst, &distance, false, delta); if (distance.size() == 1 && !distance[0].Member()) { return Arc::Weight::NoWeight(); } Adder<Weight> adder; // maintains cumulative sum accurately
for (StateId state = 0; state < distance.size(); ++state) { adder.Add(Times(distance[state], fst.Final(state))); } return adder.Sum(); } else { ShortestDistance(fst, &distance, true, delta); const auto state = fst.Start(); if (distance.size() == 1 && !distance[0].Member()) { return Arc::Weight::NoWeight(); } return state != kNoStateId && state < distance.size() ? distance[state] : Weight::Zero(); } }
} // namespace fst
#endif // FST_SHORTEST_DISTANCE_H_
|