You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

640 lines
19 KiB

// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Class to map over/transform states e.g., sort transitions.
//
// Consider using when operation does not change the number of states.
#ifndef FST_STATE_MAP_H_
#define FST_STATE_MAP_H_
#include <sys/types.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <fst/log.h>
#include <fst/arc-map.h>
#include <fst/arc.h>
#include <fst/cache.h>
#include <fst/expanded-fst.h>
#include <fst/float-weight.h>
#include <fst/fst.h>
#include <fst/impl-to-fst.h>
#include <fst/mutable-fst.h>
#include <fst/properties.h>
namespace fst {
// StateMapper Interface. The class determines how states are mapped; useful for
// implementing operations that do not change the number of states.
//
// class StateMapper {
// public:
// using FromArc = ...;
// using ToArc = ...;
//
// // Typical constructor.
// StateMapper(const Fst<FromArc> &fst);
//
// // Required copy constructor that allows updating FST argument;
// // pass only if relevant and changed.
// StateMapper(const StateMapper &mapper, const Fst<FromArc> *fst = 0);
//
// // Specifies initial state of result.
// ToArc::StateId Start() const;
// // Specifies state's final weight in result.
// ToArc::Weight Final(ToArc::StateId state) const;
//
// // These methods iterate through a state's arcs in result.
//
// // Specifies state to iterate over.
// void SetState(ToArc::StateId state);
//
// // End of arcs?
// bool Done() const;
//
// // Current arc.
// const ToArc &Value() const;
//
// // Advances to next arc (when !Done)
// void Next();
//
// // Specifies input symbol table action the mapper requires (see above).
// MapSymbolsAction InputSymbolsAction() const;
//
// // Specifies output symbol table action the mapper requires (see above).
// MapSymbolsAction OutputSymbolsAction() const;
//
// // This specifies the known properties of an FST mapped by this
// // mapper. It takes as argument the input FST's known properties.
// uint64_t Properties(uint64_t props) const;
// };
//
// We include a various state map versions below. One dimension of variation is
// whether the mapping mutates its input, writes to a new result FST, or is an
// on-the-fly Fst. Another dimension is how we pass the mapper. We allow passing
// the mapper by pointer for cases that we need to change the state of the
// user's mapper. We also include map versions that pass the mapper by value or
// const reference when this suffices.
// Maps an arc type A using a mapper function object C, passed by pointer. This
// version modifies the input FST.
template <class A, class C>
void StateMap(MutableFst<A> *fst, C *mapper) {
if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
fst->SetInputSymbols(nullptr);
}
if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
fst->SetOutputSymbols(nullptr);
}
if (fst->Start() == kNoStateId) return;
const auto props = fst->Properties(kFstProperties, false);
fst->SetStart(mapper->Start());
for (StateIterator<Fst<A>> siter(*fst); !siter.Done(); siter.Next()) {
const auto state = siter.Value();
mapper->SetState(state);
fst->DeleteArcs(state);
for (; !mapper->Done(); mapper->Next()) {
fst->AddArc(state, mapper->Value());
}
fst->SetFinal(state, mapper->Final(state));
}
fst->SetProperties(mapper->Properties(props), kFstProperties);
}
// Maps an arc type A using a mapper function object C, passed by value.
// This version modifies the input FST.
template <class A, class C>
void StateMap(MutableFst<A> *fst, C mapper) {
StateMap(fst, &mapper);
}
// Maps an arc type A to an arc type B using mapper functor C, passed by
// pointer. This version writes to an output FST.
template <class A, class B, class C>
void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C *mapper) {
ofst->DeleteStates();
if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS) {
ofst->SetInputSymbols(ifst.InputSymbols());
} else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
ofst->SetInputSymbols(nullptr);
}
if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS) {
ofst->SetOutputSymbols(ifst.OutputSymbols());
} else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
ofst->SetOutputSymbols(nullptr);
}
const auto iprops = ifst.Properties(kCopyProperties, false);
if (ifst.Start() == kNoStateId) {
if (iprops & kError) ofst->SetProperties(kError, kError);
return;
}
// Adds all states.
if (std::optional<typename A::StateId> num_states = ifst.NumStatesIfKnown()) {
ofst->ReserveStates(*num_states);
}
for (StateIterator<Fst<A>> siter(ifst); !siter.Done(); siter.Next()) {
ofst->AddState();
}
ofst->SetStart(mapper->Start());
for (StateIterator<Fst<A>> siter(ifst); !siter.Done(); siter.Next()) {
const auto state = siter.Value();
mapper->SetState(state);
for (; !mapper->Done(); mapper->Next()) {
ofst->AddArc(state, mapper->Value());
}
ofst->SetFinal(state, mapper->Final(state));
}
const auto oprops = ofst->Properties(kFstProperties, false);
ofst->SetProperties(mapper->Properties(iprops) | oprops, kFstProperties);
}
// Maps an arc type A to an arc type B using mapper functor object C, passed by
// value. This version writes to an output FST.
template <class A, class B, class C>
void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C mapper) {
StateMap(ifst, ofst, &mapper);
}
using StateMapFstOptions = CacheOptions;
template <class A, class B, class C>
class StateMapFst;
// Facade around StateIteratorBase<A> inheriting from StateIteratorBase<B>.
template <class A, class B>
class StateMapStateIteratorBase : public StateIteratorBase<B> {
public:
using Arc = B;
using StateId = typename Arc::StateId;
explicit StateMapStateIteratorBase(std::unique_ptr<StateIteratorBase<A>> base)
: base_(std::move(base)) {}
bool Done() const final { return base_->Done(); }
StateId Value() const final { return base_->Value(); }
void Next() final { base_->Next(); }
void Reset() final { base_->Reset(); }
private:
std::unique_ptr<StateIteratorBase<A>> base_;
StateMapStateIteratorBase() = delete;
};
namespace internal {
// Implementation of delayed StateMapFst.
template <class A, class B, class C>
class StateMapFstImpl : public CacheImpl<B> {
public:
using Arc = B;
using StateId = typename Arc::StateId;
using Weight = typename Arc::Weight;
using FstImpl<B>::SetType;
using FstImpl<B>::SetProperties;
using FstImpl<B>::SetInputSymbols;
using FstImpl<B>::SetOutputSymbols;
using CacheImpl<B>::PushArc;
using CacheImpl<B>::HasArcs;
using CacheImpl<B>::HasFinal;
using CacheImpl<B>::HasStart;
using CacheImpl<B>::SetArcs;
using CacheImpl<B>::SetFinal;
using CacheImpl<B>::SetStart;
friend class StateIterator<StateMapFst<A, B, C>>;
StateMapFstImpl(const Fst<A> &fst, const C &mapper,
const StateMapFstOptions &opts)
: CacheImpl<B>(opts),
fst_(fst.Copy()),
mapper_(new C(mapper, fst_.get())),
own_mapper_(true) {
Init();
}
StateMapFstImpl(const Fst<A> &fst, C *mapper, const StateMapFstOptions &opts)
: CacheImpl<B>(opts),
fst_(fst.Copy()),
mapper_(mapper),
own_mapper_(false) {
Init();
}
StateMapFstImpl(const StateMapFstImpl<A, B, C> &impl)
: CacheImpl<B>(impl),
fst_(impl.fst_->Copy(true)),
mapper_(new C(*impl.mapper_, fst_.get())),
own_mapper_(true) {
Init();
}
~StateMapFstImpl() override {
if (own_mapper_) delete mapper_;
}
StateId Start() {
if (!HasStart()) SetStart(mapper_->Start());
return CacheImpl<B>::Start();
}
Weight Final(StateId state) {
if (!HasFinal(state)) SetFinal(state, mapper_->Final(state));
return CacheImpl<B>::Final(state);
}
size_t NumArcs(StateId state) {
if (!HasArcs(state)) Expand(state);
return CacheImpl<B>::NumArcs(state);
}
size_t NumInputEpsilons(StateId state) {
if (!HasArcs(state)) Expand(state);
return CacheImpl<B>::NumInputEpsilons(state);
}
size_t NumOutputEpsilons(StateId state) {
if (!HasArcs(state)) Expand(state);
return CacheImpl<B>::NumOutputEpsilons(state);
}
void InitStateIterator(StateIteratorData<B> *datb) const {
StateIteratorData<A> data;
fst_->InitStateIterator(&data);
datb->base = data.base ? std::make_unique<StateMapStateIteratorBase<A, B>>(
std::move(data.base))
: nullptr;
datb->nstates = data.nstates;
}
void InitArcIterator(StateId state, ArcIteratorData<B> *data) {
if (!HasArcs(state)) Expand(state);
CacheImpl<B>::InitArcIterator(state, data);
}
uint64_t Properties() const override { return Properties(kFstProperties); }
uint64_t Properties(uint64_t mask) const override {
if ((mask & kError) && (fst_->Properties(kError, false) ||
(mapper_->Properties(0) & kError))) {
SetProperties(kError, kError);
}
return FstImpl<Arc>::Properties(mask);
}
void Expand(StateId state) {
// Adds exiting arcs.
for (mapper_->SetState(state); !mapper_->Done(); mapper_->Next()) {
PushArc(state, mapper_->Value());
}
SetArcs(state);
}
const Fst<A> *GetFst() const { return fst_.get(); }
private:
void Init() {
SetType("statemap");
if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS) {
SetInputSymbols(fst_->InputSymbols());
} else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
SetInputSymbols(nullptr);
}
if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS) {
SetOutputSymbols(fst_->OutputSymbols());
} else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
SetOutputSymbols(nullptr);
}
const auto props = fst_->Properties(kCopyProperties, false);
SetProperties(mapper_->Properties(props));
}
std::unique_ptr<const Fst<A>> fst_;
C *mapper_;
bool own_mapper_;
};
} // namespace internal
// Maps an arc type A to an arc type B using Mapper function object
// C. This version is a delayed FST.
template <class A, class B, class C>
class StateMapFst : public ImplToFst<internal::StateMapFstImpl<A, B, C>> {
public:
friend class ArcIterator<StateMapFst<A, B, C>>;
using Arc = B;
using StateId = typename Arc::StateId;
using Weight = typename Arc::Weight;
using Store = DefaultCacheStore<Arc>;
using State = typename Store::State;
using Impl = internal::StateMapFstImpl<A, B, C>;
StateMapFst(const Fst<A> &fst, const C &mapper,
const StateMapFstOptions &opts)
: ImplToFst<Impl>(std::make_shared<Impl>(fst, mapper, opts)) {}
StateMapFst(const Fst<A> &fst, C *mapper, const StateMapFstOptions &opts)
: ImplToFst<Impl>(std::make_shared<Impl>(fst, mapper, opts)) {}
StateMapFst(const Fst<A> &fst, const C &mapper)
: ImplToFst<Impl>(
std::make_shared<Impl>(fst, mapper, StateMapFstOptions())) {}
StateMapFst(const Fst<A> &fst, C *mapper)
: ImplToFst<Impl>(
std::make_shared<Impl>(fst, mapper, StateMapFstOptions())) {}
// See Fst<>::Copy() for doc.
StateMapFst(const StateMapFst &fst, bool safe = false)
: ImplToFst<Impl>(fst, safe) {}
// Get a copy of this StateMapFst. See Fst<>::Copy() for further doc.
StateMapFst *Copy(bool safe = false) const override {
return new StateMapFst(*this, safe);
}
void InitStateIterator(StateIteratorData<B> *data) const override {
GetImpl()->InitStateIterator(data);
}
void InitArcIterator(StateId state, ArcIteratorData<B> *data) const override {
GetMutableImpl()->InitArcIterator(state, data);
}
protected:
using ImplToFst<Impl>::GetImpl;
using ImplToFst<Impl>::GetMutableImpl;
private:
StateMapFst &operator=(const StateMapFst &) = delete;
};
// Specialization for StateMapFst.
template <class A, class B, class C>
class ArcIterator<StateMapFst<A, B, C>>
: public CacheArcIterator<StateMapFst<A, B, C>> {
public:
using StateId = typename A::StateId;
ArcIterator(const StateMapFst<A, B, C> &fst, StateId state)
: CacheArcIterator<StateMapFst<A, B, C>>(fst.GetMutableImpl(), state) {
if (!fst.GetImpl()->HasArcs(state)) fst.GetMutableImpl()->Expand(state);
}
};
// Utility mappers.
// Mapper that returns its input.
template <class Arc>
class IdentityStateMapper {
public:
using FromArc = Arc;
using ToArc = Arc;
using StateId = typename Arc::StateId;
using Weight = typename Arc::Weight;
explicit IdentityStateMapper(const Fst<Arc> &fst) : fst_(fst) {}
// Allows updating FST argument; pass only if changed.
IdentityStateMapper(const IdentityStateMapper<Arc> &mapper,
const Fst<Arc> *fst = nullptr)
: fst_(fst ? *fst : mapper.fst_) {}
StateId Start() const { return fst_.Start(); }
Weight Final(StateId state) const { return fst_.Final(state); }
void SetState(StateId state) {
aiter_ = std::make_unique<ArcIterator<Fst<Arc>>>(fst_, state);
}
bool Done() const { return aiter_->Done(); }
const Arc &Value() const { return aiter_->Value(); }
void Next() { aiter_->Next(); }
constexpr MapSymbolsAction InputSymbolsAction() const {
return MAP_COPY_SYMBOLS;
}
constexpr MapSymbolsAction OutputSymbolsAction() const {
return MAP_COPY_SYMBOLS;
}
uint64_t Properties(uint64_t props) const { return props; }
private:
const Fst<Arc> &fst_;
std::unique_ptr<ArcIterator<Fst<Arc>>> aiter_;
};
template <class Arc>
class ArcSumMapper {
public:
using FromArc = Arc;
using ToArc = Arc;
using StateId = typename Arc::StateId;
using Weight = typename Arc::Weight;
explicit ArcSumMapper(const Fst<Arc> &fst) : fst_(fst), i_(0) {}
// Allows updating FST argument; pass only if changed.
ArcSumMapper(const ArcSumMapper<Arc> &mapper, const Fst<Arc> *fst = nullptr)
: fst_(fst ? *fst : mapper.fst_), i_(0) {}
StateId Start() const { return fst_.Start(); }
Weight Final(StateId state) const { return fst_.Final(state); }
void SetState(StateId state) {
i_ = 0;
arcs_.clear();
arcs_.reserve(fst_.NumArcs(state));
for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
aiter.Next()) {
arcs_.push_back(aiter.Value());
}
// First sorts the exiting arcs by input label, output label and destination
// state and then sums weights of arcs with the same input label, output
// label, and destination state.
std::sort(arcs_.begin(), arcs_.end(), comp_);
size_t narcs = 0;
for (const auto &arc : arcs_) {
if (narcs > 0 && equal_(arc, arcs_[narcs - 1])) {
arcs_[narcs - 1].weight = Plus(arcs_[narcs - 1].weight, arc.weight);
} else {
arcs_[narcs] = arc;
++narcs;
}
}
arcs_.resize(narcs);
}
bool Done() const { return i_ >= arcs_.size(); }
const Arc &Value() const { return arcs_[i_]; }
void Next() { ++i_; }
constexpr MapSymbolsAction InputSymbolsAction() const {
return MAP_COPY_SYMBOLS;
}
constexpr MapSymbolsAction OutputSymbolsAction() const {
return MAP_COPY_SYMBOLS;
}
uint64_t Properties(uint64_t props) const {
return props & kArcSortProperties & kDeleteArcsProperties &
kWeightInvariantProperties;
}
private:
struct Compare {
bool operator()(const Arc &x, const Arc &y) const {
if (x.ilabel < y.ilabel) return true;
if (x.ilabel > y.ilabel) return false;
if (x.olabel < y.olabel) return true;
if (x.olabel > y.olabel) return false;
if (x.nextstate < y.nextstate) return true;
if (x.nextstate > y.nextstate) return false;
return false;
}
};
struct Equal {
bool operator()(const Arc &x, const Arc &y) const {
return (x.ilabel == y.ilabel && x.olabel == y.olabel &&
x.nextstate == y.nextstate);
}
};
const Fst<Arc> &fst_;
Compare comp_;
Equal equal_;
std::vector<Arc> arcs_;
ssize_t i_; // Current arc position.
ArcSumMapper &operator=(const ArcSumMapper &) = delete;
};
template <class Arc>
class ArcUniqueMapper {
public:
using FromArc = Arc;
using ToArc = Arc;
using StateId = typename Arc::StateId;
using Weight = typename Arc::Weight;
explicit ArcUniqueMapper(const Fst<Arc> &fst) : fst_(fst), i_(0) {}
// Allows updating FST argument; pass only if changed.
ArcUniqueMapper(const ArcUniqueMapper<Arc> &mapper,
const Fst<Arc> *fst = nullptr)
: fst_(fst ? *fst : mapper.fst_), i_(0) {}
StateId Start() const { return fst_.Start(); }
Weight Final(StateId state) const { return fst_.Final(state); }
void SetState(StateId state) {
i_ = 0;
arcs_.clear();
arcs_.reserve(fst_.NumArcs(state));
for (ArcIterator<Fst<Arc>> aiter(fst_, state); !aiter.Done();
aiter.Next()) {
arcs_.push_back(aiter.Value());
}
// First sorts the exiting arcs by input label, output label and destination
// state and then uniques identical arcs.
std::sort(arcs_.begin(), arcs_.end(), comp_);
arcs_.erase(std::unique(arcs_.begin(), arcs_.end(), equal_), arcs_.end());
}
bool Done() const { return i_ >= arcs_.size(); }
const Arc &Value() const { return arcs_[i_]; }
void Next() { ++i_; }
constexpr MapSymbolsAction InputSymbolsAction() const {
return MAP_COPY_SYMBOLS;
}
constexpr MapSymbolsAction OutputSymbolsAction() const {
return MAP_COPY_SYMBOLS;
}
uint64_t Properties(uint64_t props) const {
return props & kArcSortProperties & kDeleteArcsProperties;
}
private:
struct Compare {
bool operator()(const Arc &x, const Arc &y) const {
if (x.ilabel < y.ilabel) return true;
if (x.ilabel > y.ilabel) return false;
if (x.olabel < y.olabel) return true;
if (x.olabel > y.olabel) return false;
if (x.nextstate < y.nextstate) return true;
if (x.nextstate > y.nextstate) return false;
return false;
}
};
struct Equal {
bool operator()(const Arc &x, const Arc &y) const {
return (x.ilabel == y.ilabel && x.olabel == y.olabel &&
x.nextstate == y.nextstate && x.weight == y.weight);
}
};
const Fst<Arc> &fst_;
Compare comp_;
Equal equal_;
std::vector<Arc> arcs_;
size_t i_; // Current arc position.
ArcUniqueMapper &operator=(const ArcUniqueMapper &) = delete;
};
// Useful aliases when using StdArc.
using StdArcSumMapper = ArcSumMapper<StdArc>;
using StdArcUniqueMapper = ArcUniqueMapper<StdArc>;
} // namespace fst
#endif // FST_STATE_MAP_H_