You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

580 lines
19 KiB

// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Functions and classes that implement epsilon-removal.
#ifndef FST_RMEPSILON_H_
#define FST_RMEPSILON_H_
#include <cstddef>
#include <cstdint>
#include <memory>
#include <stack>
#include <string>
#include <utility>
#include <vector>
#include <fst/log.h>
#include <fst/arc.h>
#include <fst/arcfilter.h>
#include <fst/cache.h>
#include <fst/cc-visitors.h>
#include <fst/connect.h>
#include <fst/dfs-visit.h>
#include <fst/factor-weight.h>
#include <fst/float-weight.h>
#include <fst/fst.h>
#include <fst/impl-to-fst.h>
#include <fst/invert.h>
#include <fst/mutable-fst.h>
#include <fst/properties.h>
#include <fst/prune.h>
#include <fst/queue.h>
#include <fst/shortest-distance.h>
#include <fst/topsort.h>
#include <fst/util.h>
#include <fst/weight.h>
#include <unordered_map>
namespace fst {
// Options for the in-place RmEpsilon(). The inherited shortest-distance
// options carry the queue used for the epsilon-closure computation (restricted
// to epsilon arcs) and the convergence delta; the fields below control the
// post-processing of the result.
template <class Arc, class Queue>
struct RmEpsilonOptions
    : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc>> {
 private:
  using Base = ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc>>;

 public:
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  // If true (and no pruning is requested), the result is Connect()ed.
  bool connect;
  // Pruning weight threshold; Weight::Zero() disables weight-based pruning.
  Weight weight_threshold;
  // Pruning state threshold; kNoStateId disables state-count pruning.
  StateId state_threshold;

  explicit RmEpsilonOptions(Queue *queue, float delta = kShortestDelta,
                            bool connect = true,
                            Weight weight_threshold = Weight::Zero(),
                            StateId state_threshold = kNoStateId)
      : Base(queue, EpsilonArcFilter<Arc>(), kNoStateId, delta),
        connect(connect),
        weight_threshold(std::move(weight_threshold)),
        state_threshold(state_threshold) {}
};
namespace internal {
// Scratch state for epsilon removal over one input FST. Expand(s) walks the
// epsilon-closure of a single state s and collects the non-epsilon arcs and
// the final weight that s acquires once the epsilon arcs of that closure are
// removed. The results are exposed through Arcs() and Final() and are reset
// by the next call to Expand().
template <class Arc, class Queue>
class RmEpsilonState {
 public:
  using Label = typename Arc::Label;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  // `distance` receives the shortest distances from the state being expanded
  // within its epsilon-closure; `opts` configures that computation.
  RmEpsilonState(const Fst<Arc> &fst, std::vector<Weight> *distance,
                 const RmEpsilonOptions<Arc, Queue> &opts)
      : fst_(fst), distance_(distance), sd_state_(fst_, distance, opts, true) {}

  RmEpsilonState(const RmEpsilonState &) = delete;
  RmEpsilonState &operator=(const RmEpsilonState &) = delete;

  // Computes the epsilon-removed arcs and final weight of state `s`.
  void Expand(StateId s);

  // Arcs of the most recently expanded state; callers may consume them.
  std::vector<Arc> &Arcs() { return arcs_; }

  // Final weight of the most recently expanded state.
  const Weight &Final() const { return final_weight_; }

  // True once the underlying shortest-distance computation has failed.
  bool Error() const { return sd_state_.Error(); }

 private:
  // Identifies an outgoing arc up to its weight.
  struct Element {
    Label ilabel;
    Label olabel;
    StateId nextstate;

    Element() = default;
    Element(Label ilabel, Label olabel, StateId nextstate)
        : ilabel(ilabel), olabel(olabel), nextstate(nextstate) {}
  };

  struct ElementHash {
    size_t operator()(const Element &element) const {
      static constexpr size_t kPrime0 = 7853;
      static constexpr size_t kPrime1 = 7867;
      const auto ilabel = static_cast<size_t>(element.ilabel);
      const auto olabel = static_cast<size_t>(element.olabel);
      const auto nextstate = static_cast<size_t>(element.nextstate);
      return nextstate + ilabel * kPrime0 + olabel * kPrime1;
    }
  };

  struct ElementEqual {
    bool operator()(const Element &lhs, const Element &rhs) const {
      return lhs.ilabel == rhs.ilabel && lhs.olabel == rhs.olabel &&
             lhs.nextstate == rhs.nextstate;
    }
  };

  using ElementMap = std::unordered_map<Element, std::pair<StateId, size_t>,
                                        ElementHash, ElementEqual>;

  const Fst<Arc> &fst_;
  // Distances from the state being expanded within its epsilon-closure.
  std::vector<Weight> *distance_;
  // Shortest-distance computation state, restricted to epsilon arcs.
  internal::ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc>> sd_state_;
  // Maps an arc key to (expand id, position in arcs_). The position is only
  // meaningful when the stored expand id equals expand_id_, i.e. when the
  // entry was written during the current expansion. Stale entries are reused
  // in place, so the map never has to be cleared between expansions.
  ElementMap element_map_;
  EpsilonArcFilter<Arc> eps_filter_;
  // States still to visit in the epsilon-closure (used as a stack).
  std::stack<StateId, std::vector<StateId>> eps_queue_;
  // visited_[s] is true iff s was reached during the current expansion.
  std::vector<bool> visited_;
  // States marked in visited_, so the marks can be reset cheaply.
  std::vector<StateId> visited_states_;
  // Arcs of the state being expanded.
  std::vector<Arc> arcs_;
  // Final weight of the state being expanded.
  Weight final_weight_;
  // Incremented by every Expand() that runs to completion; tags map entries.
  StateId expand_id_ = 0;
};
// Expands `source`: walks its epsilon-closure, summing the distance-weighted
// final weights into final_weight_ and collecting every non-epsilon arc
// leaving the closure into arcs_. Arcs sharing (ilabel, olabel, nextstate)
// are merged by Plus()-ing their weights.
template <class Arc, class Queue>
void RmEpsilonState<Arc, Queue>::Expand(typename Arc::StateId source) {
  // Reset the per-expansion outputs.
  final_weight_ = Weight::Zero();
  arcs_.clear();
  // Shortest distances from `source` over epsilon arcs, written to distance_.
  sd_state_.ShortestDistance(source);
  if (sd_state_.Error()) return;
  // Grows visited_ so that index `s` is addressable.
  const auto grow_visited = [this](StateId s) {
    if (s >= static_cast<StateId>(visited_.size())) {
      visited_.resize(s + 1, false);
    }
  };
  // Walk the epsilon-closure of `source`, using eps_queue_ as a stack.
  eps_queue_.push(source);
  while (!eps_queue_.empty()) {
    const StateId state = eps_queue_.top();
    eps_queue_.pop();
    grow_visited(state);
    if (visited_[state]) continue;
    visited_[state] = true;
    visited_states_.push_back(state);
    for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
         aiter.Next()) {
      Arc arc = aiter.Value();
      arc.weight = Times((*distance_)[state], arc.weight);
      if (eps_filter_(arc)) {
        // Epsilon arc: follow it unless its target was already visited.
        grow_visited(arc.nextstate);
        if (!visited_[arc.nextstate]) eps_queue_.push(arc.nextstate);
        continue;
      }
      // Non-epsilon arc: merge it with an arc of equal labels and destination
      // already collected during this expansion, if there is one.
      auto [entry, inserted] = element_map_.emplace(
          Element(arc.ilabel, arc.olabel, arc.nextstate),
          std::make_pair(expand_id_, arcs_.size()));
      if (inserted) {
        arcs_.push_back(std::move(arc));
        continue;
      }
      auto &[owner_id, index] = entry->second;
      if (owner_id == expand_id_) {
        arcs_[index].weight = Plus(arcs_[index].weight, arc.weight);
      } else {
        // Stale entry left by an earlier expansion: claim it for this one.
        owner_id = expand_id_;
        index = arcs_.size();
        arcs_.push_back(std::move(arc));
      }
    }
    final_weight_ =
        Plus(final_weight_, Times((*distance_)[state], fst_.Final(state)));
  }
  // Clear only the marks set above so visited_ can be reused next time.
  for (const StateId s : visited_states_) visited_[s] = false;
  visited_states_.clear();
  ++expand_id_;
}
} // namespace internal
// Removes epsilon-transitions (arcs whose input and output labels are both
// epsilon) from a transducer. The result is an equivalent FST with no such
// transitions. This version modifies its input. It allows fine control via
// the options argument; see below for a simpler interface.
//
// The distance vector is used to hold the shortest distances during each
// epsilon-closure computation. The state queue discipline and convergence
// delta are taken from the options argument.
//
// When opts.connect is set, or pruning is requested (opts.weight_threshold !=
// Weight::Zero() or opts.state_threshold != kNoStateId), states that have no
// non-epsilon incoming arc and are not the start state are not expanded and
// have all their arcs deleted; they are presumably removed afterwards by
// Connect() or Prune(). Pruning requires a weight with the path property;
// otherwise the FST is flagged with kError.
template <class Arc, class Queue>
void RmEpsilon(MutableFst<Arc> *fst,
               std::vector<typename Arc::Weight> *distance,
               const RmEpsilonOptions<Arc, Queue> &opts) {
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;
  if (fst->Start() == kNoStateId) return;
  // Pruning is requested iff either threshold differs from its default.
  const bool prune = opts.weight_threshold != Weight::Zero() ||
                     opts.state_threshold != kNoStateId;
  // When the result is trimmed afterwards (by Connect() or Prune()), states
  // reachable only through epsilon arcs need not be expanded.
  const bool trim = opts.connect || prune;
  // No states are added or deleted until Connect()/Prune() at the very end.
  const StateId num_states = fst->NumStates();
  // noneps_in[s] will be set to true iff s admits a non-epsilon incoming
  // transition or is the start state.
  std::vector<bool> noneps_in(num_states, false);
  noneps_in[fst->Start()] = true;
  for (StateId s = 0; s < num_states; ++s) {
    for (ArcIterator<Fst<Arc>> aiter(*fst, s); !aiter.Done(); aiter.Next()) {
      const auto &arc = aiter.Value();
      if (arc.ilabel != 0 || arc.olabel != 0) noneps_in[arc.nextstate] = true;
    }
  }
  // States in (generic) topological order; they are processed from the back.
  // This is the identity order if the FST is already top-sorted, a
  // topological order of the epsilon subgraph if it is acyclic, and
  // otherwise the states grouped by epsilon-SCC id.
  std::vector<StateId> states;
  states.reserve(num_states);
  if (fst->Properties(kTopSorted, false) & kTopSorted) {
    for (StateId s = 0; s < num_states; ++s) states.push_back(s);
  } else if (fst->Properties(kAcyclic, false) & kAcyclic) {
    std::vector<StateId> order;
    bool acyclic = false;
    TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
    DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>());
    // Sanity check: should be acyclic if property bit is set.
    if (!acyclic) {
      FSTERROR() << "RmEpsilon: Inconsistent acyclic property bit";
      fst->SetProperties(kError, kError);
      return;
    }
    // order[s] is the position of state s; invert it so states[pos] == s.
    states.resize(order.size());
    for (size_t s = 0; s < order.size(); ++s) {
      states[order[s]] = static_cast<StateId>(s);
    }
  } else {
    // NOTE: SccVisitor appears to update *props with |=/&= rather than
    // assign it, so it must start from a defined value (as in Connect()).
    uint64_t props = 0;
    std::vector<StateId> scc;
    SccVisitor<Arc> scc_visitor(&scc, nullptr, nullptr, &props);
    DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>());
    // Thread the states of each SCC into a singly linked list: first[c] is
    // the head of SCC c and next[s] is the state following s in its SCC.
    std::vector<StateId> first(scc.size(), kNoStateId);
    std::vector<StateId> next(scc.size(), kNoStateId);
    for (size_t s = 0; s < scc.size(); ++s) {
      next[s] = first[scc[s]];
      first[scc[s]] = static_cast<StateId>(s);
    }
    for (const StateId head : first) {
      for (StateId s = head; s != kNoStateId; s = next[s]) states.push_back(s);
    }
  }
  internal::RmEpsilonState<Arc, Queue> rmeps_state(*fst, distance, opts);
  while (!states.empty()) {
    const StateId state = states.back();
    states.pop_back();
    if (trim && !noneps_in[state]) continue;
    rmeps_state.Expand(state);
    fst->SetFinal(state, rmeps_state.Final());
    fst->DeleteArcs(state);
    auto &arcs = rmeps_state.Arcs();
    fst->ReserveArcs(state, arcs.size());
    // The scratch arcs are discarded right after, so move them out.
    while (!arcs.empty()) {
      fst->AddArc(state, std::move(arcs.back()));
      arcs.pop_back();
    }
  }
  if (trim) {
    // The skipped states are now unreachable; drop all of their arcs.
    for (StateId s = 0; s < num_states; ++s) {
      if (!noneps_in[s]) fst->DeleteArcs(s);
    }
  }
  if (rmeps_state.Error()) fst->SetProperties(kError, kError);
  fst->SetProperties(
      RmEpsilonProperties(fst->Properties(kFstProperties, false)),
      kFstProperties);
  if (prune) {
    if constexpr (IsPath<Weight>::value) {
      Prune(fst, opts.weight_threshold, opts.state_threshold);
    } else {
      FSTERROR() << "RmEpsilon: Weight must have path property: "
                 << Weight::Type();
      fst->SetProperties(kError, kError);
    }
  } else if (opts.connect) {
    Connect(fst);
  }
}
// Removes epsilon-transitions (when both the input and output label
// are an epsilon) from a transducer. The result will be an equivalent
// FST that has no such epsilon transitions. This version modifies its
// input. It has a simplified interface; see above for a version that
// allows finer control.
//
// Complexity:
//
// - Time:
//
// Unweighted: O(v^2 + ve).
// Acyclic: O(v^2 + V e).
// Tropical semiring: O(v^2 log V + ve).
// General: exponential.
//
// - Space: O(vE)
//
// where v is the number of states visited and e is the number of arcs visited.
//
// For more information, see:
//
// Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization
// algorithms for weighted transducers. International Journal of Computer
// Science 13(1): 129-143.
template <class Arc>
void RmEpsilon(MutableFst<Arc> *fst, bool connect = true,
               typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
               typename Arc::StateId state_threshold = kNoStateId,
               float delta = kShortestDelta) {
  // An AutoQueue over the epsilon subgraph drives the closure computation;
  // both it and the distance scratch vector must outlive the call below.
  using Queue = AutoQueue<typename Arc::StateId>;
  std::vector<typename Arc::Weight> distance;
  Queue queue(*fst, &distance, EpsilonArcFilter<Arc>());
  const RmEpsilonOptions<Arc, Queue> opts(&queue, delta, connect,
                                          weight_threshold, state_threshold);
  RmEpsilon(fst, &distance, opts);
}
// Options for the delayed RmEpsilonFst: the cache settings plus the
// convergence delta used by the per-state shortest-distance computation.
struct RmEpsilonFstOptions : CacheOptions {
  float delta;

  // Default cache settings with the given delta.
  explicit RmEpsilonFstOptions(float delta = kShortestDelta) : delta(delta) {}

  // Explicit cache settings with the given delta.
  explicit RmEpsilonFstOptions(const CacheOptions &opts,
                               float delta = kShortestDelta)
      : CacheOptions(opts), delta(delta) {}
};
namespace internal {
// Implementation of the delayed RmEpsilonFst. A state is epsilon-expanded the
// first time its arcs or final weight are requested, and the result is
// stored in the cache.
template <class Arc>
class RmEpsilonFstImpl : public CacheImpl<Arc> {
 public:
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;
  using Store = DefaultCacheStore<Arc>;
  using State = typename Store::State;

  using FstImpl<Arc>::Properties;
  using FstImpl<Arc>::SetType;
  using FstImpl<Arc>::SetProperties;
  using FstImpl<Arc>::SetInputSymbols;
  using FstImpl<Arc>::SetOutputSymbols;

  using CacheBaseImpl<CacheState<Arc>>::HasArcs;
  using CacheBaseImpl<CacheState<Arc>>::HasFinal;
  using CacheBaseImpl<CacheState<Arc>>::HasStart;
  using CacheBaseImpl<CacheState<Arc>>::PushArc;
  using CacheBaseImpl<CacheState<Arc>>::SetArcs;
  using CacheBaseImpl<CacheState<Arc>>::SetFinal;
  using CacheBaseImpl<CacheState<Arc>>::SetStart;

  RmEpsilonFstImpl(const Fst<Arc> &fst, const RmEpsilonFstOptions &opts)
      : CacheImpl<Arc>(opts),
        fst_(fst.Copy()),
        delta_(opts.delta),
        rmeps_state_(*fst_, &distance_, MakeClosureOptions()) {
    SetType("rmepsilon");
    SetProperties(
        RmEpsilonProperties(fst.Properties(kFstProperties, false), true),
        kCopyProperties);
    SetInputSymbols(fst.InputSymbols());
    SetOutputSymbols(fst.OutputSymbols());
  }

  RmEpsilonFstImpl(const RmEpsilonFstImpl &impl)
      : CacheImpl<Arc>(impl),
        fst_(impl.fst_->Copy(true)),
        delta_(impl.delta_),
        rmeps_state_(*fst_, &distance_, MakeClosureOptions()) {
    SetType("rmepsilon");
    SetProperties(impl.Properties(), kCopyProperties);
    SetInputSymbols(impl.InputSymbols());
    SetOutputSymbols(impl.OutputSymbols());
  }

  StateId Start() {
    if (!HasStart()) SetStart(fst_->Start());
    return CacheImpl<Arc>::Start();
  }

  Weight Final(StateId s) {
    if (!HasFinal(s)) Expand(s);
    return CacheImpl<Arc>::Final(s);
  }

  size_t NumArcs(StateId s) {
    ExpandIfNeeded(s);
    return CacheImpl<Arc>::NumArcs(s);
  }

  size_t NumInputEpsilons(StateId s) {
    ExpandIfNeeded(s);
    return CacheImpl<Arc>::NumInputEpsilons(s);
  }

  size_t NumOutputEpsilons(StateId s) {
    ExpandIfNeeded(s);
    return CacheImpl<Arc>::NumOutputEpsilons(s);
  }

  uint64_t Properties() const override { return Properties(kFstProperties); }

  // Sets the kError bit if the input FST or the closure computation has
  // failed (only when kError is requested), then returns the properties.
  uint64_t Properties(uint64_t mask) const override {
    if (mask & kError) {
      if (fst_->Properties(kError, false) || rmeps_state_.Error()) {
        SetProperties(kError, kError);
      }
    }
    return FstImpl<Arc>::Properties(mask);
  }

  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
    ExpandIfNeeded(s);
    CacheImpl<Arc>::InitArcIterator(s, data);
  }

  // Computes the final weight and arcs of state s and stores them in the
  // cache.
  void Expand(StateId s) {
    rmeps_state_.Expand(s);
    SetFinal(s, rmeps_state_.Final());
    // Move the computed arcs into the cache last-to-first, then leave the
    // scratch vector empty for the next expansion.
    auto &arcs = rmeps_state_.Arcs();
    for (auto it = arcs.rbegin(); it != arcs.rend(); ++it) {
      PushArc(s, std::move(*it));
    }
    arcs.clear();
    SetArcs(s);
  }

 private:
  // Expands s unless its arcs are already cached.
  void ExpandIfNeeded(StateId s) {
    if (!HasArcs(s)) Expand(s);
  }

  // Closure options shared by both constructors: the FIFO queue_, delta_,
  // and connect disabled. Only called once queue_ and delta_ are initialized.
  RmEpsilonOptions<Arc, FifoQueue<StateId>> MakeClosureOptions() {
    return RmEpsilonOptions<Arc, FifoQueue<StateId>>(&queue_, delta_, false);
  }

  std::unique_ptr<const Fst<Arc>> fst_;
  float delta_;
  std::vector<Weight> distance_;
  FifoQueue<StateId> queue_;
  internal::RmEpsilonState<Arc, FifoQueue<StateId>> rmeps_state_;
};
} // namespace internal
// Removes epsilon-transitions (when both the input and output label are an
// epsilon) from a transducer. The result will be an equivalent FST that has no
// such epsilon transitions. This version is a delayed FST: each state is
// epsilon-expanded on demand.
//
// Complexity:
//
// - Time:
// Unweighted: O(v^2 + ve).
// General: exponential.
//
// - Space: O(vE)
//
// where v is the number of states visited and e is the number of arcs visited.
// Constant time to visit an input state or arc is assumed and exclusive of
// caching.
//
// For more information, see:
//
// Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization
// algorithms for weighted transducers. International Journal of Computer
// Science 13(1): 129-143.
//
// This class attaches interface to implementation and handles
// reference counting, delegating most methods to ImplToFst.
template <class A>
class RmEpsilonFst : public ImplToFst<internal::RmEpsilonFstImpl<A>> {
 public:
  using Arc = A;
  using StateId = typename Arc::StateId;
  using Store = DefaultCacheStore<Arc>;
  using State = typename Store::State;
  using Impl = internal::RmEpsilonFstImpl<Arc>;

  friend class ArcIterator<RmEpsilonFst<Arc>>;
  friend class StateIterator<RmEpsilonFst<Arc>>;

  // Delayed epsilon-removed view of `fst` with default options.
  explicit RmEpsilonFst(const Fst<Arc> &fst)
      : RmEpsilonFst(fst, RmEpsilonFstOptions()) {}

  // Delayed epsilon-removed view of `fst` with the given cache options and
  // convergence delta.
  RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts)
      : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}

  // See Fst<>::Copy() for doc.
  RmEpsilonFst(const RmEpsilonFst &fst, bool safe = false)
      : ImplToFst<Impl>(fst, safe) {}

  RmEpsilonFst &operator=(const RmEpsilonFst &) = delete;

  // Get a copy of this RmEpsilonFst. See Fst<>::Copy() for further doc.
  RmEpsilonFst *Copy(bool safe = false) const override {
    return new RmEpsilonFst(*this, safe);
  }

  inline void InitStateIterator(StateIteratorData<Arc> *data) const override;

  // Expands state s on demand before handing out its cached arcs.
  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
    GetMutableImpl()->InitArcIterator(s, data);
  }

 private:
  using ImplToFst<Impl>::GetImpl;
  using ImplToFst<Impl>::GetMutableImpl;
};
// StateIterator specialization for RmEpsilonFst, backed by the FST's cache.
template <class Arc>
class StateIterator<RmEpsilonFst<Arc>>
    : public CacheStateIterator<RmEpsilonFst<Arc>> {
  using Base = CacheStateIterator<RmEpsilonFst<Arc>>;

 public:
  explicit StateIterator(const RmEpsilonFst<Arc> &fst)
      : Base(fst, fst.GetMutableImpl()) {}
};
// ArcIterator specialization for RmEpsilonFst: iterates over the cached arcs
// of a state, expanding that state first if needed.
template <class Arc>
class ArcIterator<RmEpsilonFst<Arc>>
    : public CacheArcIterator<RmEpsilonFst<Arc>> {
  using Base = CacheArcIterator<RmEpsilonFst<Arc>>;

 public:
  using StateId = typename Arc::StateId;

  ArcIterator(const RmEpsilonFst<Arc> &fst, StateId s)
      : Base(fst.GetMutableImpl(), s) {
    // Expand s on demand if its arcs are not cached yet.
    if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
  }
};
// Installs a cache-backed StateIterator for this FST into `data`.
template <class Arc>
inline void RmEpsilonFst<Arc>::InitStateIterator(
    StateIteratorData<Arc> *data) const {
  auto iterator = std::make_unique<StateIterator<RmEpsilonFst<Arc>>>(*this);
  data->base = std::move(iterator);
}
// Convenience alias: the delayed epsilon-removal FST instantiated on StdArc.
using StdRmEpsilonFst = RmEpsilonFst<StdArc>;
} // namespace fst
#endif // FST_RMEPSILON_H_