You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

1133 lines
41 KiB

// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Functions and classes to determinize an FST.
#ifndef FST_DETERMINIZE_H_
#define FST_DETERMINIZE_H_
#include <algorithm>
#include <climits>
#include <cstddef>
#include <cstdint>
#include <forward_list>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <fst/log.h>
#include <fst/arc-map.h>
#include <fst/arc.h>
#include <fst/arcfilter.h>
#include <fst/bi-table.h>
#include <fst/cache.h>
#include <fst/const-fst.h>
#include <fst/factor-weight.h>
#include <fst/filter-state.h>
#include <fst/float-weight.h>
#include <fst/fst.h>
#include <fst/impl-to-fst.h>
#include <fst/lexicographic-weight.h>
#include <fst/mutable-fst.h>
#include <fst/pair-weight.h>
#include <fst/power-weight.h>
#include <fst/product-weight.h>
#include <fst/properties.h>
#include <fst/prune.h>
#include <fst/shortest-distance.h>
#include <fst/string-weight.h>
#include <fst/tuple-weight.h>
#include <fst/union-weight.h>
#include <fst/util.h>
#include <fst/weight.h>
namespace fst {
// Common divisors are used in determinization to compute transition weights.
// In the simplest case, it is the same as semiring Plus, but other choices
// permit more efficient determinization when the output contains strings.
// The default common divisor uses the semiring Plus.
// Forward declarations only: neither type is defined in the visible part of
// this header, and PairArc is not otherwise referenced in it.
namespace internal {
template <class Arc, class Relation>
class RelationDeterminizeFilter;
} // namespace internal
struct PairArc;
// Default common divisor: the semiring Plus of the two weights.
template <class W>
struct DefaultCommonDivisor {
  using Weight = W;

  // Returns Plus(lhs, rhs).
  Weight operator()(const Weight &lhs, const Weight &rhs) const {
    return Plus(lhs, rhs);
  }
};
// Common divisor for a (left) string semiring: yields either a single-label
// common prefix of the two strings or the empty string. Determinization of
// output strings uses it so that a transition outputs at most one label.
template <typename Label, StringType S>
struct LabelCommonDivisor {
  using Weight = StringWeight<Label, S>;

  Weight operator()(const Weight &w1, const Weight &w2) const {
    typename Weight::Iterator first(w1);
    typename Weight::Iterator second(w2);
    // Only left string semirings are supported.
    const bool left_semiring = (Weight::Properties() & kLeftSemiring) != 0;
    if (!left_semiring) {
      FSTERROR() << "LabelCommonDivisor: Weight needs to be left semiring";
      return Weight::NoWeight();
    }
    // An empty operand shares no label with anything.
    if (w1.Size() == 0 || w2.Size() == 0) return Weight::One();
    // When one operand is Zero, the first label of the other one is used.
    if (w1 == Weight::Zero()) return Weight(second.Value());
    if (w2 == Weight::Zero()) return Weight(first.Value());
    // Otherwise the first label is the divisor only if both strings share it.
    if (first.Value() == second.Value()) return Weight(first.Value());
    return Weight::One();
  }
};
// Common divisor for gallic weights: applies the label common divisor to the
// string component and CommonDivisor (by default the default common divisor)
// to the weight component.
template <class Label, class W, GallicType G,
          class CommonDivisor = DefaultCommonDivisor<W>>
class GallicCommonDivisor {
 public:
  using Weight = GallicWeight<Label, W, G>;

  Weight operator()(const Weight &w1, const Weight &w2) const {
    const auto string_part = label_common_divisor_(w1.Value1(), w2.Value1());
    const auto weight_part = weight_common_divisor_(w1.Value2(), w2.Value2());
    return Weight(string_part, weight_part);
  }

 private:
  LabelCommonDivisor<Label, GallicStringType(G)> label_common_divisor_;
  CommonDivisor weight_common_divisor_;
};
// Specialization for the general GALLIC weight, which is a union of
// GALLIC_RESTRICT weights: the restricted common divisor is folded over every
// member of both unions.
template <class Label, class W, class CommonDivisor>
class GallicCommonDivisor<Label, W, GALLIC, CommonDivisor> {
 public:
  using Weight = GallicWeight<Label, W, GALLIC>;
  using GRWeight = GallicWeight<Label, W, GALLIC_RESTRICT>;
  using Iterator =
      UnionWeightIterator<GRWeight, GallicUnionWeightOptions<Label, W>>;

  Weight operator()(const Weight &w1, const Weight &w2) const {
    GRWeight divisor = GRWeight::Zero();
    // Folds each member of a union weight into the running divisor.
    const auto fold = [this, &divisor](const Weight &operand) {
      for (Iterator iter(operand); !iter.Done(); iter.Next()) {
        divisor = common_divisor_(divisor, iter.Value());
      }
    };
    fold(w1);
    fold(w2);
    // A Zero result maps to the Zero of the general gallic weight.
    if (divisor == GRWeight::Zero()) return Weight::Zero();
    return Weight(divisor);
  }

 private:
  GallicCommonDivisor<Label, W, GALLIC_RESTRICT, CommonDivisor> common_divisor_;
};
namespace internal {
// One member of a weighted subset: an input state paired with its residual
// weight.
template <class Arc>
struct DeterminizeElement {
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  DeterminizeElement(StateId state, Weight residual)
      : state_id(state), weight(std::move(residual)) {}

  // Elements are equal when both the state and the residual weight match.
  bool operator==(const DeterminizeElement &other) const {
    if (!(state_id == other.state_id)) return false;
    return weight == other.weight;
  }

  bool operator!=(const DeterminizeElement &other) const {
    return !(*this == other);
  }

  // Ordering looks at the state ID only; the weight is ignored.
  bool operator<(const DeterminizeElement<Arc> &other) const {
    return state_id < other.state_id;
  }

  StateId state_id;  // Input state ID.
  Weight weight;     // Residual weight.
};
// A weighted subset together with the determinization filter state it is
// paired with.
template <typename A, typename FilterState>
struct DeterminizeStateTuple {
  using Arc = A;
  using Element = DeterminizeElement<Arc>;
  using Subset = std::forward_list<Element>;

  // Starts with an empty subset and the filter's NoState().
  DeterminizeStateTuple() : filter_state(FilterState::NoState()) {}

  // Tuples are equal when the filter states and the subsets both compare
  // equal.
  bool operator==(const DeterminizeStateTuple &other) const {
    if (!(other.filter_state == filter_state)) return false;
    return other.subset == subset;
  }

  bool operator!=(const DeterminizeStateTuple &other) const {
    if (other.filter_state != filter_state) return true;
    return other.subset != subset;
  }

  Subset subset;
  FilterState filter_state;
};
// Proto-transition used while determinizing: a label, the accumulated arc
// weight and the destination state tuple being built for that label.
template <class StateTuple>
struct DeterminizeArc {
  using Arc = typename StateTuple::Arc;
  using Label = typename Arc::Label;
  using Weight = typename Arc::Weight;

  // Leaves the label at kNoLabel, the weight at Zero and no destination tuple.
  DeterminizeArc() = default;

  // Takes the label from the input label of `src_arc` and allocates a fresh
  // destination tuple; the weight stays at Zero.
  explicit DeterminizeArc(const Arc &src_arc)
      : label(src_arc.ilabel),
        dest_tuple(fst::make_unique_for_overwrite<StateTuple>()) {}

  // Arc label.
  Label label = kNoLabel;
  // Arc weight.
  Weight weight = Weight::Zero();
  // Destination subset and filter state.
  std::unique_ptr<StateTuple> dest_tuple;
};
} // namespace internal
// Determinization filters compute destination state tuples from the source
// tuple, a transition and a destination element, or from the analogous
// super-final transition information. A filter works on a map from labels to
// the corresponding destination state tuples and must define that map type as
// LabelMap.
//
// This default filter implements plain weighted determinization.
template <class Arc>
class DefaultDeterminizeFilter {
 public:
  using Label = typename Arc::Label;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  using FilterState = CharFilterState;
  using Element = internal::DeterminizeElement<Arc>;
  using StateTuple = internal::DeterminizeStateTuple<Arc, FilterState>;
  using LabelMap = std::map<Label, internal::DeterminizeArc<StateTuple>>;

  // Needed, e.g., to move into the gallic domain for transducers.
  template <class A>
  struct rebind {
    using Other = DefaultDeterminizeFilter<A>;
  };

  explicit DefaultDeterminizeFilter(const Fst<Arc> &fst) : fst_(fst.Copy()) {}

  // Needed, e.g., to move into the gallic domain for transducers. The passed
  // filter is not consulted; it is simply destroyed with the argument.
  template <class Filter>
  DefaultDeterminizeFilter(const Fst<Arc> &fst, std::unique_ptr<Filter> filter)
      : fst_(fst.Copy()) {}

  // Copy constructor; the FST can be passed if it has been deep-copied,
  // otherwise the FST held by `filter` is copied.
  DefaultDeterminizeFilter(const DefaultDeterminizeFilter &filter,
                           const Fst<Arc> *fst = nullptr)
      : fst_((fst ? fst : filter.fst_.get())->Copy()) {}

  FilterState Start() const { return FilterState(0); }

  // Nothing to do for the default filter.
  void SetState(StateId s, const StateTuple &tuple) {}

  // Filters a transition, possibly modifying the label map. Returns true if
  // the arc is added to the label map (always, for this filter).
  bool FilterArc(const Arc &arc, const Element &src_element,
                 Element &&dest_element, LabelMap *label_map) const {
    // One proto-arc per input label; a label of kNoLabel marks an entry that
    // has not been set up yet.
    auto &entry = (*label_map)[arc.ilabel];
    const bool needs_setup = entry.label == kNoLabel;
    if (needs_setup) {
      entry = internal::DeterminizeArc<StateTuple>(arc);
      entry.dest_tuple->filter_state = FilterState(0);
    }
    entry.dest_tuple->subset.push_front(std::move(dest_element));
    return true;
  }

  // Filters a super-final transition, returning the new final weight; the
  // default filter passes the weight through unchanged.
  Weight FilterFinal(Weight weight, const Element &element) { return weight; }

  // The default filter does not alter the properties.
  static uint64_t Properties(uint64_t props) { return props; }

 private:
  std::unique_ptr<Fst<Arc>> fst_;
};
// Determinization state table interface:
//
// template <class Arc, class FilterState>
// class DeterminizeStateTable {
// public:
// using StateId = typename Arc::StateId;
// using StateTuple = internal::DeterminizeStateTuple<Arc, FilterState>;
//
// // Required sub-class. This is needed (e.g.) to go into the gallic domain.
// template <class B, class G>
// struct rebind {
// using Other = DeterminizeStateTable<B, G>;
// };
//
// // Required constructor.
// DeterminizeStateTable();
//
// // Required copy constructor that does not copy state.
// DeterminizeStateTable(const DeterminizeStateTable<Arc, FilterState>
// &table);
//
// // Looks up state ID by state tuple; if it doesn't exist, then adds it.
// // FindState takes ownership of the state tuple argument so that it
// // doesn't have to copy it if it creates a new state.
// StateId FindState(std::unique_ptr<StateTuple> tuple);
//
// // Looks up state tuple by ID.
// const StateTuple *Tuple(StateId id) const;
// };
// The default determinization state table based on the compact hash bi-table.
template <class Arc, class FilterState>
class DefaultDeterminizeStateTable {
public:
using Label = typename Arc::Label;
using StateId = typename Arc::StateId;
using Weight = typename Arc::Weight;
using StateTuple = internal::DeterminizeStateTuple<Arc, FilterState>;
using Element = typename StateTuple::Element;
using Subset = typename StateTuple::Subset;
template <class B, class G>
struct rebind {
using Other = DefaultDeterminizeStateTable<B, G>;
};
explicit DefaultDeterminizeStateTable(size_t table_size = 0)
: table_size_(table_size), tuples_(table_size_) {}
DefaultDeterminizeStateTable(const DefaultDeterminizeStateTable &table)
: table_size_(table.table_size_), tuples_(table_size_) {}
~DefaultDeterminizeStateTable() {
for (StateId s = 0; s < tuples_.Size(); ++s) delete tuples_.FindEntry(s);
}
// Finds the state corresponding to a state tuple. Only creates a new state if
// the tuple is not found. FindState takes ownership of the tuple argument so
// that it doesn't have to copy it if it creates a new state.
StateId FindState(std::unique_ptr<StateTuple> tuple) {
StateTuple *raw_tuple = tuple.release();
const StateId ns = tuples_.Size();
// TODO(wolfsonkin): Make CompactHashBiTable support move semantics so we
// can store a `std::unique_ptr` in `tuples_`.
const auto s = tuples_.FindId(raw_tuple);
if (s != ns) delete raw_tuple; // Tuple found.
return s;
}
const StateTuple *Tuple(StateId s) { return tuples_.FindEntry(s); }
private:
// Comparison object for StateTuples.
class StateTupleEqual {
public:
bool operator()(const StateTuple *tuple1, const StateTuple *tuple2) const {
return *tuple1 == *tuple2;
}
};
// Hash function for StateTuples.
class StateTupleKey {
public:
size_t operator()(const StateTuple *tuple) const {
size_t h = tuple->filter_state.Hash();
for (auto &element : tuple->subset) {
const size_t h1 = element.state_id;
static constexpr auto lshift = 5;
static constexpr auto rshift = CHAR_BIT * sizeof(size_t) - 5;
h ^= h << 1 ^ h1 << lshift ^ h1 >> rshift ^ element.weight.Hash();
}
return h;
}
};
size_t table_size_;
CompactHashBiTable<StateId, StateTuple *, StateTupleKey, StateTupleEqual,
HS_STL>
tuples_;
DefaultDeterminizeStateTable &operator=(
const DefaultDeterminizeStateTable &) = delete;
};
// Determinization type.
enum DeterminizeType {
  // Input transducer is known to be functional (error if it is not).
  DETERMINIZE_FUNCTIONAL = 0,
  // Input transducer is not known to be functional.
  DETERMINIZE_NONFUNCTIONAL = 1,
  // Input transducer is not known to be functional, but only the min of the
  // ambiguous outputs is kept.
  DETERMINIZE_DISAMBIGUATE = 2
};
// Options for finite-state transducer determinization, templated on the arc
// type, the common divisor, the determinization filter and the state table.
// DeterminizeFst takes ownership of the determinization filter and the state
// table, if provided.
template <class Arc,
          class CommonDivisor = DefaultCommonDivisor<typename Arc::Weight>,
          class Filter = DefaultDeterminizeFilter<Arc>,
          class StateTable =
              DefaultDeterminizeStateTable<Arc, typename Filter::FilterState>>
struct DeterminizeFstOptions : public CacheOptions {
  using Label = typename Arc::Label;

  // Quantization delta for subset weights.
  float delta;
  // Label used for residual final output when producing subsequential
  // transducers.
  Label subsequential_label;
  // Determinization type.
  DeterminizeType type;
  // When creating several subsequential arcs at a given state, make their
  // label distinct by incrementing.
  bool increment_subsequential_label;
  // Determinization filter; DeterminizeFst takes ownership.
  Filter *filter;
  // Determinization state table; DeterminizeFst takes ownership.
  StateTable *state_table;

  // Builds the options on top of explicitly supplied cache options.
  explicit DeterminizeFstOptions(const CacheOptions &opts,
                                 float quantization_delta = kDelta,
                                 Label subseq_label = 0,
                                 DeterminizeType det_type = DETERMINIZE_FUNCTIONAL,
                                 bool increment_label = false,
                                 Filter *det_filter = nullptr,
                                 StateTable *table = nullptr)
      : CacheOptions(opts),
        delta(quantization_delta),
        subsequential_label(subseq_label),
        type(det_type),
        increment_subsequential_label(increment_label),
        filter(det_filter),
        state_table(table) {}

  // Builds the options with default-constructed cache options.
  explicit DeterminizeFstOptions(float quantization_delta = kDelta,
                                 Label subseq_label = 0,
                                 DeterminizeType det_type = DETERMINIZE_FUNCTIONAL,
                                 bool increment_label = false,
                                 Filter *det_filter = nullptr,
                                 StateTable *table = nullptr)
      : delta(quantization_delta),
        subsequential_label(subseq_label),
        type(det_type),
        increment_subsequential_label(increment_label),
        filter(det_filter),
        state_table(table) {}
};
namespace internal {
// Base implementation of the delayed DeterminizeFst, shared by the acceptor
// and transducer determinization variants. It owns a copy of the input FST
// and computes start state, final weights and arcs on demand, storing them in
// the cache.
template <class Arc>
class DeterminizeFstImplBase : public CacheImpl<Arc> {
 public:
  using Label = typename Arc::Label;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  using Store = DefaultCacheStore<Arc>;
  using State = typename Store::State;

  using FstImpl<Arc>::SetType;
  using FstImpl<Arc>::SetProperties;
  using FstImpl<Arc>::Properties;
  using FstImpl<Arc>::SetInputSymbols;
  using FstImpl<Arc>::SetOutputSymbols;

  using CacheBaseImpl<CacheState<Arc>>::HasStart;
  using CacheBaseImpl<CacheState<Arc>>::HasFinal;
  using CacheBaseImpl<CacheState<Arc>>::HasArcs;
  using CacheBaseImpl<CacheState<Arc>>::SetFinal;
  using CacheBaseImpl<CacheState<Arc>>::SetStart;

  // Copies the input FST and derives the output properties and symbol tables
  // from it and from the options.
  template <class CommonDivisor, class Filter, class StateTable>
  DeterminizeFstImplBase(
      const Fst<Arc> &fst,
      const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable> &opts)
      : CacheImpl<Arc>(opts), fst_(fst.Copy()) {
    SetType("determinize");
    const auto iprops = fst.Properties(kFstProperties, false);
    const bool has_subsequential_label = opts.subsequential_label != 0;
    // Only the non-functional type consults the increment option; every other
    // type passes true.
    const bool third_arg = opts.type != DETERMINIZE_NONFUNCTIONAL ||
                           opts.increment_subsequential_label;
    const auto dprops =
        DeterminizeProperties(iprops, has_subsequential_label, third_arg);
    SetProperties(Filter::Properties(dprops), kCopyProperties);
    SetInputSymbols(fst.InputSymbols());
    SetOutputSymbols(fst.OutputSymbols());
  }

  // Copy constructor; makes a safe copy (Copy(true)) of the input FST.
  DeterminizeFstImplBase(const DeterminizeFstImplBase &impl)
      : CacheImpl<Arc>(impl), fst_(impl.fst_->Copy(true)) {
    SetType("determinize");
    SetProperties(impl.Properties(), kCopyProperties);
    SetInputSymbols(impl.InputSymbols());
    SetOutputSymbols(impl.OutputSymbols());
  }

  virtual DeterminizeFstImplBase *Copy() const = 0;

  // Returns the start state, computing and caching it on first use.
  StateId Start() {
    if (HasStart()) return CacheImpl<Arc>::Start();
    const auto start = ComputeStart();
    if (start != kNoStateId) SetStart(start);
    return CacheImpl<Arc>::Start();
  }

  // Returns the final weight of `s`, computing and caching it on first use.
  Weight Final(StateId s) {
    if (HasFinal(s)) return CacheImpl<Arc>::Final(s);
    SetFinal(s, ComputeFinal(s));
    return CacheImpl<Arc>::Final(s);
  }

  virtual void Expand(StateId s) = 0;

  // The accessors below expand `s` first if its arcs are not yet cached.
  size_t NumArcs(StateId s) {
    if (!HasArcs(s)) {
      Expand(s);
    }
    return CacheImpl<Arc>::NumArcs(s);
  }

  size_t NumInputEpsilons(StateId s) {
    if (!HasArcs(s)) {
      Expand(s);
    }
    return CacheImpl<Arc>::NumInputEpsilons(s);
  }

  size_t NumOutputEpsilons(StateId s) {
    if (!HasArcs(s)) {
      Expand(s);
    }
    return CacheImpl<Arc>::NumOutputEpsilons(s);
  }

  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
    if (!HasArcs(s)) {
      Expand(s);
    }
    CacheImpl<Arc>::InitArcIterator(s, data);
  }

  virtual StateId ComputeStart() = 0;

  virtual Weight ComputeFinal(StateId s) = 0;

  // The copy of the input FST held by this implementation.
  const Fst<Arc> &GetFst() const { return *fst_; }

 private:
  std::unique_ptr<const Fst<Arc>> fst_;  // Input FST.
};
// Implementation of delayed determinization for weighted acceptors.
//
// Each output state corresponds to a state tuple (a weighted subset of input
// states plus a filter state) stored in the state table. When both `in_dist`
// and `out_dist` are supplied, the distance to the final states is also
// recorded for every newly created output state.
template <class Arc, class CommonDivisor, class Filter, class StateTable>
class DeterminizeFsaImpl : public DeterminizeFstImplBase<Arc> {
 public:
  using Label = typename Arc::Label;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  using FilterState = typename Filter::FilterState;
  using StateTuple = internal::DeterminizeStateTuple<Arc, FilterState>;
  using Element = typename StateTuple::Element;
  using Subset = typename StateTuple::Subset;
  using LabelMap = typename Filter::LabelMap;

  using FstImpl<Arc>::SetProperties;
  using DeterminizeFstImplBase<Arc>::GetFst;
  using DeterminizeFstImplBase<Arc>::SetArcs;

  // Takes ownership of opts.filter and opts.state_table when they are set and
  // creates defaults otherwise. `in_dist` and `out_dist` are borrowed and may
  // be null; `out_dist` is cleared when given. Sets kError if the input is not
  // an acceptor or the weight is not a left semiring.
  DeterminizeFsaImpl(
      const Fst<Arc> &fst, const std::vector<Weight> *in_dist,
      std::vector<Weight> *out_dist,
      const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable> &opts)
      : DeterminizeFstImplBase<Arc>(fst, opts),
        delta_(opts.delta),
        in_dist_(in_dist),
        out_dist_(out_dist),
        filter_(opts.filter ? opts.filter : new Filter(fst)),
        state_table_(opts.state_table ? opts.state_table : new StateTable()) {
    if (!fst.Properties(kAcceptor, true)) {
      FSTERROR() << "DeterminizeFst: Argument not an acceptor";
      SetProperties(kError, kError);
    }
    if (!(Weight::Properties() & kLeftSemiring)) {
      FSTERROR() << "DeterminizeFst: Weight must be left distributive: "
                 << Weight::Type();
      SetProperties(kError, kError);
    }
    if (out_dist_) out_dist_->clear();
  }

  // Copy constructor. The distance vectors are not carried over; copying an
  // implementation that has an out_dist vector is flagged as an error.
  DeterminizeFsaImpl(const DeterminizeFsaImpl &impl)
      : DeterminizeFstImplBase<Arc>(impl),
        delta_(impl.delta_),
        in_dist_(nullptr),
        out_dist_(nullptr),
        filter_(new Filter(*impl.filter_, &GetFst())),
        state_table_(new StateTable(*impl.state_table_)) {
    if (impl.out_dist_) {
      FSTERROR() << "DeterminizeFsaImpl: Cannot copy with out_dist vector";
      SetProperties(kError, kError);
    }
  }

  DeterminizeFsaImpl *Copy() const override {
    return new DeterminizeFsaImpl(*this);
  }

  uint64_t Properties() const override { return Properties(kFstProperties); }

  // Sets error if found, and returns other FST impl properties.
  uint64_t Properties(uint64_t mask) const override {
    if ((mask & kError) && (GetFst().Properties(kError, false))) {
      SetProperties(kError, kError);
    }
    return FstImpl<Arc>::Properties(mask);
  }

  // The start state is the singleton subset holding the input start state
  // with weight One; returns kNoStateId if the input has no start state.
  StateId ComputeStart() override {
    const auto s = GetFst().Start();
    if (s == kNoStateId) return kNoStateId;
    auto tuple = fst::make_unique_for_overwrite<StateTuple>();
    tuple->subset.emplace_front(s, Weight::One());
    tuple->filter_state = filter_->Start();
    return FindState(std::move(tuple));
  }

  // Accumulates residual weight times input final weight over the subset of
  // `s`, passing the running total through the filter after each element.
  // Sets kError if the total ever stops being a member of the weight set.
  Weight ComputeFinal(StateId s) override {
    const auto *tuple = state_table_->Tuple(s);
    filter_->SetState(s, *tuple);
    auto final_weight = Weight::Zero();
    for (const auto &element : tuple->subset) {
      final_weight =
          Plus(final_weight,
               Times(element.weight, GetFst().Final(element.state_id)));
      final_weight = filter_->FilterFinal(final_weight, element);
      if (!final_weight.Member()) SetProperties(kError, kError);
    }
    return final_weight;
  }

  // Looks up (or creates) the output state for `tuple`, handing ownership of
  // the tuple to the state table. If distances are being tracked and the
  // state has no entry in `out_dist_` yet, its distance is appended.
  StateId FindState(std::unique_ptr<StateTuple> tuple) {
    const auto s = state_table_->FindState(std::move(tuple));
    // Distances need both vectors; `out_dist_` may legitimately be null.
    if (in_dist_ && out_dist_ &&
        out_dist_->size() <= static_cast<size_t>(s)) {
      // The subset is read back from the table rather than through a
      // reference taken before the hand-off: the table is free to delete a
      // tuple that duplicates an existing state.
      out_dist_->push_back(ComputeDistance(state_table_->Tuple(s)->subset));
    }
    return s;
  }

  // Computes distance from a state to the final states in the DFA given the
  // distances in the NFA. States beyond the end of `in_dist_` contribute
  // Zero.
  Weight ComputeDistance(const Subset &subset) {
    auto outd = Weight::Zero();
    for (const auto &element : subset) {
      const auto ind =
          (element.state_id < in_dist_->size() ? (*in_dist_)[element.state_id]
                                               : Weight::Zero());
      outd = Plus(outd, Times(element.weight, ind));
    }
    return outd;
  }

  // Computes the outgoing transitions from a state, creating new destination
  // states as needed.
  void Expand(StateId s) override {
    LabelMap label_map;
    GetLabelMap(s, &label_map);
    for (auto &[unused_label, arc] : label_map) {
      AddArc(s, std::move(arc));
    }
    SetArcs(s);
  }

 private:
  using DetArc = internal::DeterminizeArc<StateTuple>;

  // Constructs proto-determinization transition, including destination subset,
  // per label.
  void GetLabelMap(StateId s, LabelMap *label_map) {
    const auto *src_tuple = state_table_->Tuple(s);
    filter_->SetState(s, *src_tuple);
    for (const auto &src_element : src_tuple->subset) {
      for (ArcIterator<Fst<Arc>> aiter(GetFst(), src_element.state_id);
           !aiter.Done(); aiter.Next()) {
        const auto &arc = aiter.Value();
        Element dest_element(arc.nextstate,
                             Times(src_element.weight, arc.weight));
        filter_->FilterArc(arc, src_element, std::move(dest_element),
                           label_map);
      }
    }
    for (auto &[unused_label, arc] : *label_map) {
      NormArc(&arc);
    }
  }

  // Sorts subsets and removes duplicate elements, normalizing transition and
  // subset weights.
  void NormArc(DetArc *det_arc) {
    auto &dest_subset = det_arc->dest_tuple->subset;
    dest_subset.sort();
    // `piter` trails `diter` by one kept element so that a duplicate can be
    // removed with erase_after.
    auto piter = dest_subset.begin();
    for (auto diter = dest_subset.begin(); diter != dest_subset.end();) {
      auto &dest_element = *diter;
      auto &prev_element = *piter;
      // Computes arc weight.
      det_arc->weight = common_divisor_(det_arc->weight, dest_element.weight);
      if (piter != diter && dest_element.state_id == prev_element.state_id) {
        // Found duplicate state: sums state weight and deletes duplicate.
        prev_element.weight = Plus(prev_element.weight, dest_element.weight);
        if (!prev_element.weight.Member()) SetProperties(kError, kError);
        // Advance before erasing so that `diter` stays valid.
        ++diter;
        dest_subset.erase_after(piter);
      } else {
        piter = diter;
        ++diter;
      }
    }
    // Divides out label weight from destination subset elements, quantizing to
    // ensure comparisons are effective.
    for (auto &dest_element : dest_subset) {
      dest_element.weight =
          Divide(dest_element.weight, det_arc->weight, DIVIDE_LEFT);
      dest_element.weight = dest_element.weight.Quantize(delta_);
    }
  }

  // Adds an arc from state S to the destination state associated with state
  // tuple in det_arc as created by GetLabelMap.
  void AddArc(StateId s, DetArc &&det_arc) {
    CacheImpl<Arc>::EmplaceArc(s, det_arc.label, det_arc.label,
                               std::move(det_arc.weight),
                               FindState(std::move(det_arc.dest_tuple)));
  }

  float delta_;                         // Quantization delta for weights.
  const std::vector<Weight> *in_dist_;  // Distance to final NFA states.
  std::vector<Weight> *out_dist_;       // Distance to final DFA states.

  static const CommonDivisor common_divisor_;
  std::unique_ptr<Filter> filter_;
  std::unique_ptr<StateTable> state_table_;
};

template <class Arc, class CommonDivisor, class Filter, class StateTable>
const CommonDivisor DeterminizeFsaImpl<Arc, CommonDivisor, Filter,
                                       StateTable>::common_divisor_{};
// Implementation of delayed determinization for transducers. Transducer
// determinization is implemented by mapping the input to the Gallic semiring as
// an acceptor whose weights contain the output strings and using acceptor
// determinization above to determinize that acceptor.
//
// All results are read from `from_fst_`, the pipeline built by Init(). If
// construction fails before Init() runs, `from_fst_` stays null and the
// implementation behaves as an empty FST with kError set.
template <class Arc, GallicType G, class CommonDivisor, class Filter,
          class StateTable>
class DeterminizeFstImpl : public DeterminizeFstImplBase<Arc> {
 public:
  using Label = typename Arc::Label;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  using ToMapper = ToGallicMapper<Arc, G>;
  using ToArc = typename ToMapper::ToArc;
  using ToFst = ArcMapFst<Arc, ToArc, ToMapper>;
  using FromMapper = FromGallicMapper<Arc, G>;
  using FromFst = ArcMapFst<ToArc, Arc, FromMapper>;

  using ToCommonDivisor = GallicCommonDivisor<Label, Weight, G, CommonDivisor>;
  using ToFilter = typename Filter::template rebind<ToArc>::Other;
  using ToFilterState = typename ToFilter::FilterState;
  using ToStateTable =
      typename StateTable::template rebind<ToArc, ToFilterState>::Other;
  using FactorIterator = GallicFactor<Label, Weight, G>;

  using FstImpl<Arc>::SetProperties;
  using DeterminizeFstImplBase<Arc>::GetFst;
  using CacheBaseImpl<CacheState<Arc>>::GetCacheGc;
  using CacheBaseImpl<CacheState<Arc>>::GetCacheLimit;

  // Takes ownership of opts.filter. A state table cannot be used with
  // transducer input: passing one sets kError and skips initialization.
  DeterminizeFstImpl(
      const Fst<Arc> &fst,
      const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable> &opts)
      : DeterminizeFstImplBase<Arc>(fst, opts),
        delta_(opts.delta),
        subsequential_label_(opts.subsequential_label),
        increment_subsequential_label_(opts.increment_subsequential_label) {
    if (opts.state_table) {
      FSTERROR() << "DeterminizeFst: "
                 << "A state table can not be passed with transducer input";
      SetProperties(kError, kError);
      // The options hand ownership of both objects to DeterminizeFst, so they
      // are released here instead of being leaked on the error path.
      const std::unique_ptr<StateTable> unused_table(opts.state_table);
      const std::unique_ptr<Filter> unused_filter(opts.filter);
      return;
    }
    // Takes ownership of filter.
    Init(GetFst(), fst::WrapUnique(opts.filter));
  }

  // Copy constructor; rebuilds the pipeline over the copied input FST without
  // a filter.
  DeterminizeFstImpl(const DeterminizeFstImpl &impl)
      : DeterminizeFstImplBase<Arc>(impl),
        delta_(impl.delta_),
        subsequential_label_(impl.subsequential_label_),
        increment_subsequential_label_(impl.increment_subsequential_label_) {
    Init(GetFst(), nullptr);
  }

  DeterminizeFstImpl *Copy() const override {
    return new DeterminizeFstImpl(*this);
  }

  uint64_t Properties() const override { return Properties(kFstProperties); }

  // Sets error if found, and returns other FST impl properties.
  uint64_t Properties(uint64_t mask) const override {
    // `from_fst_` is null when construction failed before Init(); kError has
    // already been set in that case, so there is nothing further to query.
    if ((mask & kError) &&
        (GetFst().Properties(kError, false) ||
         (from_fst_ && from_fst_->Properties(kError, false)))) {
      SetProperties(kError, kError);
    }
    return FstImpl<Arc>::Properties(mask);
  }

  // Returns kNoStateId when the pipeline was never initialized.
  StateId ComputeStart() override {
    if (!from_fst_) return kNoStateId;
    return from_fst_->Start();
  }

  // Returns Zero when the pipeline was never initialized.
  Weight ComputeFinal(StateId s) override {
    if (!from_fst_) return Weight::Zero();
    return from_fst_->Final(s);
  }

  // Copies the arcs of `s` from the pipeline into the cache; an uninitialized
  // pipeline contributes no arcs.
  void Expand(StateId s) override {
    if (from_fst_) {
      for (ArcIterator<FromFst> aiter(*from_fst_, s); !aiter.Done();
           aiter.Next()) {
        CacheImpl<Arc>::PushArc(s, aiter.Value());
      }
    }
    CacheImpl<Arc>::SetArcs(s);
  }

 private:
  // Initialization of transducer determinization implementation, which is
  // defined after DeterminizeFst since it calls it.
  void Init(const Fst<Arc> &fst, std::unique_ptr<Filter> filter);

  float delta_;
  Label subsequential_label_;
  bool increment_subsequential_label_;
  // Determinized result mapped back to the input arc type; null only when the
  // constructor bailed out before Init().
  std::unique_ptr<FromFst> from_fst_;
};
} // namespace internal
// Determinizes a weighted transducer. This version is a delayed
// FST. The result will be an equivalent FST that has the property
// that no state has two transitions with the same input label.
// For this algorithm, epsilon transitions are treated as regular
// symbols (cf. RmEpsilon).
//
// The transducer must be functional. The weights must be (weakly) left
// divisible (valid for TropicalWeight and LogWeight for instance) and be
// zero-sum-free if for all a, b: (Plus(a, b) == 0) => a = b = 0.
//
// Complexity:
//
// Determinizable: exponential (polynomial in the size of the output).
// Non-determinizable: does not terminate.
//
// The determinizable automata include all unweighted and all acyclic input.
//
// For more information, see:
//
// Mohri, M. 1997. Finite-state transducers in language and speech processing.
// Computational Linguistics 23(2): 269-311.
//
// This class attaches interface to implementation and handles reference
// counting, delegating most methods to ImplToFst.
template <class A>
class DeterminizeFst : public ImplToFst<internal::DeterminizeFstImplBase<A>> {
public:
using Arc = A;
using Label = typename Arc::Label;
using StateId = typename Arc::StateId;
using Weight = typename Arc::Weight;
using Store = DefaultCacheStore<Arc>;
using State = typename Store::State;
using Impl = internal::DeterminizeFstImplBase<Arc>;
friend class ArcIterator<DeterminizeFst<Arc>>;
friend class StateIterator<DeterminizeFst<Arc>>;
template <class B, GallicType G, class CommonDivisor, class Filter,
class StateTable>
friend class DeterminizeFstImpl;
explicit DeterminizeFst(const Fst<A> &fst)
: ImplToFst<Impl>(CreateImpl(fst)) {}
template <class CommonDivisor, class Filter, class StateTable>
explicit DeterminizeFst(
const Fst<Arc> &fst,
const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>
&opts =
DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>())
: ImplToFst<Impl>(CreateImpl(fst, opts)) {}
// This acceptor-only version additionally computes the distance to final
// states in the output if provided with those distances for the input; this
// is useful for e.g., computing the k-shortest unique paths.
template <class CommonDivisor, class Filter, class StateTable>
DeterminizeFst(
const Fst<Arc> &fst, const std::vector<Weight> *in_dist,
std::vector<Weight> *out_dist,
const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>
&opts =
DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>())
: ImplToFst<Impl>(
std::make_shared<internal::DeterminizeFsaImpl<Arc, CommonDivisor,
Filter, StateTable>>(
fst, in_dist, out_dist, opts)) {
if (!fst.Properties(kAcceptor, true)) {
FSTERROR() << "DeterminizeFst: "
<< "Distance to final states computed for acceptors only";
GetMutableImpl()->SetProperties(kError, kError);
}
}
// See Fst<>::Copy() for doc.
DeterminizeFst(const DeterminizeFst &fst, bool safe = false)
: ImplToFst<Impl>(safe ? std::shared_ptr<Impl>(fst.GetImpl()->Copy())
: fst.GetSharedImpl()) {}
// Get a copy of this DeterminizeFst. See Fst<>::Copy() for further doc.
DeterminizeFst *Copy(bool safe = false) const override {
return new DeterminizeFst(*this, safe);
}
inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
GetMutableImpl()->InitArcIterator(s, data);
}
private:
using ImplToFst<Impl>::GetImpl;
using ImplToFst<Impl>::GetMutableImpl;
static std::shared_ptr<Impl> CreateImpl(const Fst<Arc> &fst) {
using D = DefaultCommonDivisor<Weight>;
using F = DefaultDeterminizeFilter<Arc>;
using T = DefaultDeterminizeStateTable<Arc, typename F::FilterState>;
const DeterminizeFstOptions<Arc, D, F, T> opts;
return CreateImpl(fst, opts);
}
template <class CommonDivisor, class Filter, class StateTable>
static std::shared_ptr<Impl> CreateImpl(
const Fst<Arc> &fst,
const DeterminizeFstOptions<Arc, CommonDivisor, Filter, StateTable>
&opts) {
if (fst.Properties(kAcceptor, true)) {
// Calls implementation for acceptors.
return std::make_shared<
internal::DeterminizeFsaImpl<Arc, CommonDivisor, Filter, StateTable>>(
fst, nullptr, nullptr, opts);
} else if (opts.type == DETERMINIZE_DISAMBIGUATE) {
if constexpr (IsPath<Weight>::value) {
// Calls disambiguating implementation for non-functional transducers.
return std::make_shared<internal::DeterminizeFstImpl<
Arc, GALLIC_MIN, CommonDivisor, Filter, StateTable>>(fst, opts);
} else {
FSTERROR() << "DeterminizeFst: Weight needs to have the path "
<< "property to disambiguate output: " << Weight::Type();
// Return an error Impl.
const ConstFst<Arc> empty_fst;
auto rv = std::make_shared<internal::DeterminizeFstImpl<
Arc, GALLIC, CommonDivisor, Filter, StateTable>>(empty_fst, opts);
rv->SetProperties(kError, kError);
return rv;
}
} else if (opts.type == DETERMINIZE_FUNCTIONAL) {
// Calls implementation for functional transducers.
return std::make_shared<internal::DeterminizeFstImpl<
Arc, GALLIC_RESTRICT, CommonDivisor, Filter, StateTable>>(fst, opts);
} else { // opts.type == DETERMINIZE_NONFUNCTIONAL
// Calls implementation for non functional transducers;
return std::make_shared<internal::DeterminizeFstImpl<
Arc, GALLIC, CommonDivisor, Filter, StateTable>>(fst, opts);
}
}
// Copy assignment is disabled.
DeterminizeFst &operator=(const DeterminizeFst &) = delete;
};
namespace internal {
// Initialization of the transducer determinization implementation. It is
// defined after DeterminizeFst because it instantiates DeterminizeFst.
//
// The input transducer is mapped to an acceptor, the acceptor is determinized,
// the result is passed through FactorWeightFst, and the factored FST is mapped
// back to a transducer that is stored in from_fst_.
template <class A, GallicType G, class D, class F, class T>
void DeterminizeFstImpl<A, G, D, F, T>::Init(const Fst<A> &fst,
                                             std::unique_ptr<F> filter) {
  // Maps the transducer to an acceptor.
  const ToFst acceptor(fst);

  // Wraps the supplied filter, if any, for use on the acceptor side.
  // NOTE(review): the raw pointer is handed to the options below; presumably
  // ownership passes to the determinization machinery -- confirm.
  ToFilter *acceptor_filter = nullptr;
  if (filter) {
    acceptor_filter = new ToFilter(acceptor, std::move(filter));
  }

  // This recursive use of DeterminizeFst terminates since it goes through a
  // different (non-recursive) constructor.
  const CacheOptions cache_opts(GetCacheGc(), GetCacheLimit());
  using AcceptorOptions =
      DeterminizeFstOptions<ToArc, ToCommonDivisor, ToFilter, ToStateTable>;
  const AcceptorOptions acceptor_opts(cache_opts, delta_, 0,
                                      DETERMINIZE_FUNCTIONAL, false,
                                      acceptor_filter);

  // The acceptor-only constructor is used to avoid template recursion.
  const DeterminizeFst<ToArc> determinized(acceptor, nullptr, nullptr,
                                           acceptor_opts);

  // Factors the weights and maps the result back to a transducer.
  const FactorWeightOptions<ToArc> factor_opts(
      CacheOptions(true, 0), delta_, kFactorFinalWeights, subsequential_label_,
      subsequential_label_, increment_subsequential_label_,
      increment_subsequential_label_);
  const FactorWeightFst<ToArc, FactorIterator> factored(determinized,
                                                        factor_opts);
  from_fst_ =
      std::make_unique<FromFst>(factored, FromMapper(subsequential_label_));
}
} // namespace internal
// StateIterator specialization for DeterminizeFst. It passes the FST and its
// mutable implementation on to CacheStateIterator.
template <class Arc>
class StateIterator<DeterminizeFst<Arc>>
    : public CacheStateIterator<DeterminizeFst<Arc>> {
  using Base = CacheStateIterator<DeterminizeFst<Arc>>;

 public:
  explicit StateIterator(const DeterminizeFst<Arc> &fst)
      : Base(fst, fst.GetMutableImpl()) {}
};
// ArcIterator specialization for DeterminizeFst. On construction it asks the
// implementation to expand state `s` unless the implementation reports that
// the state already has its arcs.
template <class Arc>
class ArcIterator<DeterminizeFst<Arc>>
    : public CacheArcIterator<DeterminizeFst<Arc>> {
  using Base = CacheArcIterator<DeterminizeFst<Arc>>;

 public:
  using StateId = typename Arc::StateId;

  ArcIterator(const DeterminizeFst<Arc> &fst, StateId s)
      : Base(fst.GetMutableImpl(), s) {
    const bool has_arcs = fst.GetImpl()->HasArcs(s);
    if (!has_arcs) {
      fst.GetMutableImpl()->Expand(s);
    }
  }
};
// Installs a newly created StateIterator over this FST as `data->base`.
template <class Arc>
inline void DeterminizeFst<Arc>::InitStateIterator(
    StateIteratorData<Arc> *data) const {
  using Iterator = StateIterator<DeterminizeFst<Arc>>;
  data->base = std::make_unique<Iterator>(*this);
}
// Useful alias when using StdArc.
using StdDeterminizeFst = DeterminizeFst<StdArc>;
// Options accepted by the Determinize() function: quantization delta, pruning
// thresholds, subsequential-label handling and the determinization type.
template <class Arc>
struct DeterminizeOptions {
using Label = typename Arc::Label;
using StateId = typename Arc::StateId;
using Weight = typename Arc::Weight;
float delta; // Quantization delta for subset weights.
Weight weight_threshold; // Pruning weight threshold.
StateId state_threshold; // Pruning state threshold.
Label subsequential_label; // Label used for residual final output.
// Determinization type; the constructor defaults it to
// DETERMINIZE_FUNCTIONAL.
DeterminizeType type;
bool increment_subsequential_label; // When creating several subsequential
// arcs at a given state, make their
// label distinct by incrementation?
// Every argument has a default. The threshold defaults, Weight::Zero() and
// kNoStateId, are exactly the values Determinize() compares against to decide
// that no pruning was requested.
explicit DeterminizeOptions(float delta = kDelta,
Weight weight_threshold = Weight::Zero(),
StateId state_threshold = kNoStateId,
Label subsequential_label = 0,
DeterminizeType type = DETERMINIZE_FUNCTIONAL,
bool increment_subsequential_label = false)
: delta(delta),
weight_threshold(std::move(weight_threshold)),
state_threshold(state_threshold),
subsequential_label(subsequential_label),
type(type),
increment_subsequential_label(increment_subsequential_label) {}
};
// Determinizes a weighted transducer. This version writes the
// determinized Fst to an output MutableFst. The result will be an
// equivalent FST that has the property that no state has two
// transitions with the same input label. For this algorithm, epsilon
// transitions are treated as regular symbols (cf. RmEpsilon).
//
// The transducer must be functional. The weights must be (weakly)
// left divisible (valid for TropicalWeight and LogWeight).
//
// Complexity:
//
// Determinizable: exponential (polynomial in the size of the output)
// Non-determinizable: does not terminate
//
// The determinizable automata include all unweighted and all acyclic input.
template <class Arc>
void Determinize(
const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
const DeterminizeOptions<Arc> &opts = DeterminizeOptions<Arc>()) {
using Weight = typename Arc::Weight;
DeterminizeFstOptions<Arc> nopts;
nopts.delta = opts.delta;
nopts.subsequential_label = opts.subsequential_label;
nopts.type = opts.type;
nopts.increment_subsequential_label = opts.increment_subsequential_label;
nopts.gc_limit = 0; // Caches only the last state for fastest copy.
if (opts.weight_threshold != Weight::Zero() ||
opts.state_threshold != kNoStateId) {
if constexpr (IsPath<Weight>::value) {
if (ifst.Properties(kAcceptor, false)) {
std::vector<Weight> idistance;
std::vector<Weight> odistance;
ShortestDistance(ifst, &idistance, true);
DeterminizeFst<Arc> dfst(ifst, &idistance, &odistance, nopts);
PruneOptions<Arc, AnyArcFilter<Arc>> popts(
opts.weight_threshold, opts.state_threshold, AnyArcFilter<Arc>(),
&odistance);
Prune(dfst, ofst, popts);
} else {
*ofst = DeterminizeFst<Arc>(ifst, nopts);
Prune(ofst, opts.weight_threshold, opts.state_threshold);
}
} else {
FSTERROR() << "Determinize: Weight needs to have the path "
<< "property to use pruning options: " << Weight::Type();
ofst->SetProperties(kError, kError);
}
} else {
*ofst = DeterminizeFst<Arc>(ifst, nopts);
}
}
} // namespace fst
#endif // FST_DETERMINIZE_H_