You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

338 lines
13 KiB

// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Functions implementing pruning.
#ifndef FST_PRUNE_H_
#define FST_PRUNE_H_
#include <cstddef>
#include <cstdlib>
#include <type_traits>
#include <utility>
#include <vector>
#include <fst/log.h>
#include <fst/arcfilter.h>
#include <fst/fst.h>
#include <fst/heap.h>
#include <fst/mutable-fst.h>
#include <fst/shortest-distance.h>
#include <fst/weight.h>
namespace fst {
namespace internal {
// Ordering functor for the heap used by Prune(). A state is ranked by
// Times(idistance[s], fdistance[s]); a state whose index lies beyond the end
// of either vector is treated as having distance Weight::Zero() there.
//
// Both vectors are held by reference: they must outlive this object, and any
// later change to their contents is seen by subsequent comparisons.
template <class StateId, class Weight>
class PruneCompare {
 public:
  PruneCompare(const std::vector<Weight> &idistance,
               const std::vector<Weight> &fdistance)
      : idistance_(idistance), fdistance_(fdistance) {}

  // Returns true iff the combined weight of x is NaturalLess than that of y.
  bool operator()(const StateId x, const StateId y) const {
    return less_(Combined(x), Combined(y));
  }

 private:
  // Bounds-checked read of one distance vector.
  static Weight Lookup(const std::vector<Weight> &distance, const StateId s) {
    if (s < distance.size()) return distance[s];
    return Weight::Zero();
  }

  // Product of the two distances recorded for state s.
  Weight Combined(const StateId s) const {
    return Times(Lookup(idistance_, s), Lookup(fdistance_, s));
  }

  const std::vector<Weight> &idistance_;
  const std::vector<Weight> &fdistance_;
  NaturalLess<Weight> less_;
};
} // namespace internal
// Options accepted by the Prune() overloads that take an options argument.
//
// Arc is the arc type of the FST being pruned. ArcFilter is a callable invoked
// on each arc; Prune() skips every arc for which it returns false.
template <class Arc, class ArcFilter>
struct PruneOptions {
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  // Every argument is optional; see the member comments below for meanings.
  explicit PruneOptions(const Weight &weight_threshold = Weight::Zero(),
                        StateId state_threshold = kNoStateId,
                        ArcFilter filter = ArcFilter(),
                        std::vector<Weight> *distance = nullptr,
                        float delta = kDelta, bool threshold_initial = false)
      // weight_threshold arrives by const reference, so it can only be copied
      // (a std::move here would silently degrade to the same copy). filter is
      // received by value and is therefore moved into place.
      : weight_threshold(weight_threshold),
        state_threshold(state_threshold),
        filter(std::move(filter)),
        distance(distance),
        delta(delta),
        threshold_initial(threshold_initial) {}

  // Pruning weight threshold.
  Weight weight_threshold;
  // Pruning state threshold; kNoStateId disables the state limit.
  StateId state_threshold;
  // Arc filter.
  ArcFilter filter;
  // If non-null, passes in pre-computed shortest distance to final states.
  // The vector is not owned and must stay alive while Prune() runs.
  const std::vector<Weight> *distance;
  // Determines the degree of convergence required when computing shortest
  // distances.
  float delta;
  // Determines if the shortest path weight is left (true) or right
  // (false) multiplied by the threshold to get the limit for
  // keeping a state or arc (matters if the semiring is not
  // commutative).
  bool threshold_initial;
};
// Pruning algorithm: this version modifies its input and it takes an options
// class as an argument. After pruning the FST contains states and arcs that
// belong to a successful path in the FST whose weight is no more than the
// weight of the shortest path Times() the provided weight threshold. When the
// state threshold is not kNoStateId, the output FST is further restricted to
// have no more than the number of states in opts.state_threshold. Weights must
// have the path property. The weight of any cycle needs to be bounded; i.e.,
//
// Plus(weight, Weight::One()) == Weight::One()
template <class Arc, class ArcFilter>
void Prune(MutableFst<Arc> *fst, const PruneOptions<Arc, ArcFilter> &opts =
                                     PruneOptions<Arc, ArcFilter>()) {
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;
  static_assert(IsPath<Weight>::value, "Weight must have path property.");
  using StateHeap = Heap<StateId, internal::PruneCompare<StateId, Weight>>;
  auto ns = fst->NumStates();
  if (ns < 1) return;
  const StateId start = fst->Start();
  // With a zero state budget or no start state nothing can be kept. Testing
  // this explicitly (rather than relying on kNoStateId wrapping around in the
  // size comparison below) also avoids computing shortest distances for
  // nothing.
  if ((opts.state_threshold == 0) || (start == kNoStateId)) {
    fst->DeleteStates();
    return;
  }
  // idistance[s]: best weight found so far from the start state to s.
  std::vector<Weight> idistance(ns, Weight::Zero());
  // fdistance: distance from each state to the final states, either supplied
  // by the caller or computed here.
  std::vector<Weight> tmp;
  if (!opts.distance) {
    tmp.reserve(ns);
    ShortestDistance(*fst, &tmp, true, opts.delta);
  }
  const auto *fdistance = opts.distance ? opts.distance : &tmp;
  // The start state has no recorded distance, or a Zero() one: nothing can be
  // kept.
  if ((fdistance->size() <= static_cast<size_t>(start)) ||
      ((*fdistance)[start] == Weight::Zero())) {
    fst->DeleteStates();
    return;
  }
  // The comparator reads idistance and *fdistance by reference, so updates to
  // idistance made below are reflected in the heap ordering.
  internal::PruneCompare<StateId, Weight> compare(idistance, *fdistance);
  StateHeap heap(compare);
  std::vector<bool> visited(ns, false);
  // Heap key of every state currently in the heap, kNoKey otherwise.
  std::vector<size_t> enqueued(ns, StateHeap::kNoKey);
  // dead[0] is a fresh sink: arcs over the limit are redirected to it so they
  // go away when it is deleted at the end (relies on DeleteStates() dropping
  // arcs that enter deleted states).
  std::vector<StateId> dead;
  dead.push_back(fst->AddState());
  NaturalLess<Weight> less;
  auto s = start;
  const auto limit = opts.threshold_initial
                         ? Times(opts.weight_threshold, (*fdistance)[s])
                         : Times((*fdistance)[s], opts.weight_threshold);
  StateId num_visited = 0;
  // Seed the search with the start state unless even the best path already
  // exceeds the limit.
  if (!less(limit, (*fdistance)[s])) {
    idistance[s] = Weight::One();
    enqueued[s] = heap.Insert(s);
    ++num_visited;
  }
  while (!heap.Empty()) {
    s = heap.Top();
    heap.Pop();
    enqueued[s] = StateHeap::kNoKey;
    visited[s] = true;
    // Ending at s would exceed the limit: s stops being final.
    if (less(limit, Times(idistance[s], fst->Final(s)))) {
      fst->SetFinal(s, Weight::Zero());
    }
    for (MutableArcIterator<MutableFst<Arc>> aiter(fst, s); !aiter.Done();
         aiter.Next()) {
      auto arc = aiter.Value();  // Copy intended.
      if (!opts.filter(arc)) continue;
      // Weight of reaching s, taking this arc, then continuing to a final
      // state.
      const auto weight =
          Times(Times(idistance[s], arc.weight),
                arc.nextstate < fdistance->size() ? (*fdistance)[arc.nextstate]
                                                  : Weight::Zero());
      if (less(limit, weight)) {
        arc.nextstate = dead[0];
        aiter.SetValue(arc);
        continue;
      }
      // Relax the distance to the destination before touching the heap.
      if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) {
        idistance[arc.nextstate] = Times(idistance[s], arc.weight);
      }
      if (visited[arc.nextstate]) continue;
      // State budget exhausted: neither enqueue nor re-prioritize.
      if ((opts.state_threshold != kNoStateId) &&
          (num_visited >= opts.state_threshold)) {
        continue;
      }
      if (enqueued[arc.nextstate] == StateHeap::kNoKey) {
        enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
        ++num_visited;
      } else {
        heap.Update(enqueued[arc.nextstate], arc.nextstate);
      }
    }
  }
  // Every original state that was never popped from the heap is removed,
  // together with the sink.
  for (StateId i = 0; i < ns; ++i) {
    if (!visited[i]) dead.push_back(i);
  }
  fst->DeleteStates(dead);
}
// Pruning algorithm: this version modifies its input and takes the
// pruning threshold as an argument. It deletes states and arcs in the
// FST that do not belong to a successful path whose weight is no more
// than the weight of the shortest path Times() the provided weight
// threshold. When the state threshold is not kNoStateId, the output
// FST is further restricted to have no more than the number of states
// given by state_threshold. Weights must have the path property. The
// weight of any cycle needs to be bounded; i.e.,
//
// Plus(weight, Weight::One()) == Weight::One()
template <class Arc>
void Prune(MutableFst<Arc> *fst, typename Arc::Weight weight_threshold,
           typename Arc::StateId state_threshold = kNoStateId,
           float delta = kDelta) {
  // Forward to the options-based overload: every arc is accepted, no
  // pre-computed distance vector is supplied, and threshold_initial keeps its
  // default value.
  using Filter = AnyArcFilter<Arc>;
  using Options = PruneOptions<Arc, Filter>;
  Prune(fst,
        Options(weight_threshold, state_threshold, Filter(), nullptr, delta));
}
// Pruning algorithm: this version writes the pruned input FST to an
// output MutableFst and it takes an options class as an argument. The
// output FST contains states and arcs that belong to a successful
// path in the input FST whose weight is no more than the weight of the
// shortest path Times() the provided weight threshold. When the state
// threshold is not kNoStateId, the output FST is further restricted
// to have no more than the number of states in
// opts.state_threshold. Weights must have the path property. The weight
// of any cycle needs to be bounded; i.e.,
//
// Plus(weight, Weight::One()) == Weight::One()
template <class Arc, class ArcFilter>
void Prune(
    const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
    const PruneOptions<Arc, ArcFilter> &opts = PruneOptions<Arc, ArcFilter>()) {
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;
  static_assert(IsPath<Weight>::value, "Weight must have path property.");
  using StateHeap = Heap<StateId, internal::PruneCompare<StateId, Weight>>;

  // The output starts empty and takes over the input's symbol tables.
  ofst->DeleteStates();
  ofst->SetInputSymbols(ifst.InputSymbols());
  ofst->SetOutputSymbols(ifst.OutputSymbols());
  if (ifst.Start() == kNoStateId) return;

  // A threshold below One() or a zero state budget keeps nothing.
  NaturalLess<Weight> less;
  const bool keeps_nothing =
      less(opts.weight_threshold, Weight::One()) || (opts.state_threshold == 0);
  if (keeps_nothing) return;

  // Distances from the start state, grown on demand as states are reached.
  std::vector<Weight> idistance;
  // Distances to the final states: the caller's vector if one was supplied,
  // otherwise computed here.
  std::vector<Weight> computed_fdistance;
  if (!opts.distance) {
    ShortestDistance(ifst, &computed_fdistance, true, opts.delta);
  }
  const auto *fdistance = opts.distance ? opts.distance : &computed_fdistance;

  auto state = ifst.Start();
  if ((fdistance->size() <= state) ||
      ((*fdistance)[state] == Weight::Zero())) {
    return;
  }
  const auto &fdist = *fdistance;

  // The comparator keeps references to both distance vectors, so the heap
  // order follows the updates made to idistance below.
  internal::PruneCompare<StateId, Weight> compare(idistance, fdist);
  StateHeap heap(compare);
  std::vector<StateId> ostate;   // Input state -> output state, or kNoStateId.
  std::vector<size_t> heap_key;  // Heap key while enqueued, else kNoKey.
  std::vector<bool> done;        // True once a state has been expanded.

  // Helpers extending the per-state tables until `upto` is a valid index.
  const auto grow_ostate = [&ostate](StateId upto) {
    while (ostate.size() <= upto) ostate.push_back(kNoStateId);
  };
  const auto grow_idistance = [&idistance](StateId upto) {
    while (idistance.size() <= upto) idistance.push_back(Weight::Zero());
  };
  const auto grow_queue_tables = [&heap_key, &done](StateId upto) {
    while (heap_key.size() <= upto) {
      heap_key.push_back(StateHeap::kNoKey);
      done.push_back(false);
    }
  };

  const auto limit = opts.threshold_initial
                         ? Times(opts.weight_threshold, fdist[state])
                         : Times(fdist[state], opts.weight_threshold);

  // Seed the search with the start state.
  grow_ostate(state);
  ostate[state] = ofst->AddState();
  ofst->SetStart(ostate[state]);
  grow_idistance(state);
  idistance[state] = Weight::One();
  grow_queue_tables(state);
  heap_key[state] = heap.Insert(state);

  while (!heap.Empty()) {
    state = heap.Top();
    heap.Pop();
    heap_key[state] = StateHeap::kNoKey;
    done[state] = true;

    // Keep the final weight only if ending here stays within the limit.
    const auto final_weight = ifst.Final(state);
    if (!less(limit, Times(idistance[state], final_weight))) {
      ofst->SetFinal(ostate[state], final_weight);
    }

    for (ArcIterator<Fst<Arc>> aiter(ifst, state); !aiter.Done();
         aiter.Next()) {
      const auto &arc = aiter.Value();
      if (!opts.filter(arc)) continue;

      const auto next = arc.nextstate;
      // Weight of reaching `next` via this arc, and of continuing from there
      // to a final state.
      const auto through = Times(idistance[state], arc.weight);
      const auto total =
          Times(through, next < fdist.size() ? fdist[next] : Weight::Zero());
      if (less(limit, total)) continue;

      // Once the output holds the allowed number of states no further arcs
      // are copied.
      const bool at_capacity = (opts.state_threshold != kNoStateId) &&
                               (ofst->NumStates() >= opts.state_threshold);
      if (at_capacity) continue;

      // Relax the distance to the destination.
      grow_idistance(next);
      if (less(through, idistance[next])) idistance[next] = through;

      // Copy the arc, creating the destination's output state on first use.
      grow_ostate(next);
      if (ostate[next] == kNoStateId) ostate[next] = ofst->AddState();
      ofst->AddArc(ostate[state],
                   Arc(arc.ilabel, arc.olabel, arc.weight, ostate[next]));

      // Schedule the destination, or re-prioritize it if already waiting.
      grow_queue_tables(next);
      if (done[next]) continue;
      if (heap_key[next] == StateHeap::kNoKey) {
        heap_key[next] = heap.Insert(next);
      } else {
        heap.Update(heap_key[next], next);
      }
    }
  }
}
// Pruning algorithm: this version writes the pruned input FST to an
// output MutableFst and simply takes the pruning threshold as an
// argument. The output FST contains states and arcs that belong to a
// successful path in the input FST whose weight is no more than the
// weight of the shortest path Times() the provided weight
// threshold. When the state threshold is not kNoStateId, the output
// FST is further restricted to have no more than the number of states
// given by state_threshold. Weights must have the path property. The
// weight of any cycle needs to be bounded; i.e.,
//
// Plus(weight, Weight::One()) == Weight::One()
template <class Arc>
void Prune(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
typename Arc::Weight weight_threshold,
typename Arc::StateId state_threshold = kNoStateId,
float delta = kDelta) {
const PruneOptions<Arc, AnyArcFilter<Arc>> opts(
weight_threshold, state_threshold, AnyArcFilter<Arc>(), nullptr, delta);
Prune(ifst, ofst, opts);
}
} // namespace fst
#endif // FST_PRUNE_H_