|
|
// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Class to map over/transform arcs e.g., change semirings or
// implement project/invert. Consider using when operation does
// not change the number of arcs (except possibly superfinal arcs).
#ifndef FST_ARC_MAP_H_
#define FST_ARC_MAP_H_
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>

#include <fst/log.h>
#include <fst/arc.h>
#include <fst/cache.h>
#include <fst/expanded-fst.h>
#include <fst/float-weight.h>
#include <fst/fst.h>
#include <fst/impl-to-fst.h>
#include <fst/mutable-fst.h>
#include <fst/properties.h>
#include <fst/string-weight.h>
#include <fst/symbol-table.h>
#include <fst/util.h>
#include <fst/weight.h>
namespace fst {
// Determines how final weights are mapped.
enum MapFinalAction {
  // Each final weight is mapped to a final weight. If the mapped result
  // cannot be represented as a final weight (it carries labels), an error
  // is raised.
  MAP_NO_SUPERFINAL = 0,

  // Each final weight is mapped to an arc into a superfinal state whenever
  // the result cannot be represented as a final weight. The superfinal
  // state is added only when it is actually needed.
  MAP_ALLOW_SUPERFINAL = 1,

  // Each final weight is mapped to an arc into the superfinal state, unless
  // the result is an unlabeled weight of Zero(). The superfinal state is
  // always added (provided the input FST is not empty).
  MAP_REQUIRE_SUPERFINAL = 2,
};
// Determines how symbol tables are mapped.
enum MapSymbolsAction {
  // The map clears the symbol table in the result.
  MAP_CLEAR_SYMBOLS = 0,

  // The map copies the symbol table from the input FST into the result.
  MAP_COPY_SYMBOLS = 1,

  // The map itself does not touch the result's symbol table (it may still
  // be set by the mapper).
  MAP_NOOP_SYMBOLS = 2,
};
// The ArcMapper interface defines how arcs and final weights are mapped.
// This is useful for implementing operations that apply to each arc separately
// and do not change the number of arcs (except possibly superfinal arcs).
//
// template <class A, class B>
// class ArcMapper {
// public:
// using FromArc = A;
// using ToArc = B;
//
// // Maps an arc type FromArc to arc type ToArc.
// ToArc operator()(const FromArc &arc);
//
// // Specifies final action the mapper requires (see above).
// // The mapper will be passed final weights as arcs of the form
// // Arc(0, 0, weight, kNoStateId).
// MapFinalAction FinalAction() const;
//
// // Specifies input symbol table action the mapper requires (see above).
// MapSymbolsAction InputSymbolsAction() const;
//
// // Specifies output symbol table action the mapper requires (see above).
// MapSymbolsAction OutputSymbolsAction() const;
//
// // This specifies the known properties of an FST mapped by this mapper. It
// takes as argument the input FST's known properties.
// uint64_t Properties(uint64_t props) const;
// };
//
// The ArcMap functions and classes below will use the FinalAction()
// method of the mapper to determine how to treat final weights, e.g., whether
// to add a superfinal state. They will use the Properties() method to set the
// result FST properties.
//
// We include various map versions below. One dimension of variation is
// whether the mapping mutates its input, writes to a new result FST, or is an
// on-the-fly FST. Another dimension is how we pass the mapper. We allow passing
// the mapper by pointer for cases that we need to change the state of the
// user's mapper. This is the case with the EncodeMapper, which is reused
// during decoding. We also include map versions that pass the mapper by value
// or const reference when this suffices.
// Maps an arc type A using a mapper function object C, passed
// by pointer. This version modifies its Fst input.
//
// Passing the mapper by pointer lets the caller observe any state the mapper
// accumulates while mapping. Here only MAP_CLEAR_SYMBOLS touches the symbol
// tables (they are reset to null); the other symbol actions leave the FST's
// existing tables alone. An FST with no start state is otherwise unchanged.
// Each final weight is handed to the mapper as the pseudo-arc
// FromArc(0, 0, final_weight, kNoStateId), and the result is applied
// according to mapper->FinalAction():
//   MAP_NO_SUPERFINAL (and any unrecognized action): the mapped weight
//     becomes the new final weight; non-zero labels are reported via
//     FSTERROR and set kError.
//   MAP_ALLOW_SUPERFINAL: a labeled result becomes an arc to a superfinal
//     state created on first need; otherwise it becomes the final weight.
//   MAP_REQUIRE_SUPERFINAL: a superfinal state is added up front; every
//     result other than an unlabeled Zero() becomes an arc to it, and all
//     other states are made non-final.
// Finally, the properties are set to mapper->Properties() of the properties
// that were known before mapping.
template <class A, class C>
void ArcMap(MutableFst<A> *fst, C *mapper) {
  using FromArc = A;
  using ToArc = A;
  using Weight = typename FromArc::Weight;
  if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
    fst->SetInputSymbols(nullptr);
  }
  if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
    fst->SetOutputSymbols(nullptr);
  }
  if (fst->Start() == kNoStateId) return;
  // Captured before any mutation; passed through the mapper at the end.
  const auto props = fst->Properties(kFstProperties, false);
  const auto final_action = mapper->FinalAction();
  auto superfinal = kNoStateId;
  if (final_action == MAP_REQUIRE_SUPERFINAL) {
    superfinal = fst->AddState();
    fst->SetFinal(superfinal);
  }
  // NOTE(review): the superfinal state may be added while this iterator is
  // live. Whether the iterator then visits it depends on the MutableFst
  // implementation, which is why the superfinal cases below skip it
  // explicitly.
  for (StateIterator<MutableFst<FromArc>> siter(*fst); !siter.Done();
       siter.Next()) {
    const auto state = siter.Value();
    // Maps every ordinary arc of this state in place.
    for (MutableArcIterator<MutableFst<FromArc>> aiter(fst, state);
         !aiter.Done(); aiter.Next()) {
      const auto &arc = aiter.Value();
      aiter.SetValue((*mapper)(arc));
    }
    switch (final_action) {
      case MAP_NO_SUPERFINAL:
      default: {
        const FromArc arc(0, 0, fst->Final(state), kNoStateId);
        const auto final_arc = (*mapper)(arc);
        // Labels cannot be represented as a plain final weight.
        if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
          FSTERROR() << "ArcMap: Non-zero arc labels for superfinal arc";
          fst->SetProperties(kError, kError);
        }
        fst->SetFinal(state, final_arc.weight);
        break;
      }
      case MAP_ALLOW_SUPERFINAL: {
        if (state != superfinal) {
          const FromArc arc(0, 0, fst->Final(state), kNoStateId);
          auto final_arc = (*mapper)(arc);
          if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
            // Add a superfinal state if not already done.
            if (superfinal == kNoStateId) {
              superfinal = fst->AddState();
              fst->SetFinal(superfinal);
            }
            final_arc.nextstate = superfinal;
            fst->AddArc(state, std::move(final_arc));
            fst->SetFinal(state, Weight::Zero());
          } else {
            fst->SetFinal(state, final_arc.weight);
          }
        }
        break;
      }
      case MAP_REQUIRE_SUPERFINAL: {
        if (state != superfinal) {
          const FromArc arc(0, 0, fst->Final(state), kNoStateId);
          const auto final_arc = (*mapper)(arc);
          // Only an unlabeled Zero() result needs no arc to the superfinal.
          if (final_arc.ilabel != 0 || final_arc.olabel != 0 ||
              final_arc.weight != Weight::Zero()) {
            fst->AddArc(state, ToArc(final_arc.ilabel, final_arc.olabel,
                                     final_arc.weight, superfinal));
          }
          fst->SetFinal(state, Weight::Zero());
        }
        break;
      }
    }
  }
  fst->SetProperties(mapper->Properties(props), kFstProperties);
}
// In-place ArcMap() that takes the mapper by value. The pointer-taking
// overload is run on this local copy, so any state the mapper accumulates
// is discarded on return and the caller's object is never modified.
template <class A, class C>
void ArcMap(MutableFst<A> *fst, C mapper) {
  C *const local_mapper = &mapper;
  ArcMap(fst, local_mapper);
}
// Maps an arc type A to an arc type B using mapper function object C,
// passed by pointer. This version writes the mapped input FST to an
// output MutableFst.
template <class A, class B, class C> void ArcMap(const Fst<A> &ifst, MutableFst<B> *ofst, C *mapper) { using FromArc = A; using StateId = typename FromArc::StateId; ofst->DeleteStates(); if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS) { ofst->SetInputSymbols(ifst.InputSymbols()); } else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) { ofst->SetInputSymbols(nullptr); } if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS) { ofst->SetOutputSymbols(ifst.OutputSymbols()); } else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) { ofst->SetOutputSymbols(nullptr); } const auto iprops = ifst.Properties(kCopyProperties, false); if (ifst.Start() == kNoStateId) { if (iprops & kError) ofst->SetProperties(kError, kError); return; } const auto final_action = mapper->FinalAction(); if (std::optional<StateId> num_states = ifst.NumStatesIfKnown()) { ofst->ReserveStates(*num_states + (final_action == MAP_NO_SUPERFINAL ? 0 : 1)); } // Adds all states.
for (StateIterator<Fst<A>> siter(ifst); !siter.Done(); siter.Next()) { ofst->AddState(); } StateId superfinal = kNoStateId; if (final_action == MAP_REQUIRE_SUPERFINAL) { superfinal = ofst->AddState(); ofst->SetFinal(superfinal); } for (StateIterator<Fst<A>> siter(ifst); !siter.Done(); siter.Next()) { StateId s = siter.Value(); if (s == ifst.Start()) ofst->SetStart(s); ofst->ReserveArcs( s, ifst.NumArcs(s) + (final_action != MAP_NO_SUPERFINAL ? 1 : 0)); for (ArcIterator<Fst<A>> aiter(ifst, s); !aiter.Done(); aiter.Next()) { ofst->AddArc(s, (*mapper)(aiter.Value())); } switch (final_action) { case MAP_NO_SUPERFINAL: default: { B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId)); if (final_arc.ilabel != 0 || final_arc.olabel != 0) { FSTERROR() << "ArcMap: Non-zero arc labels for superfinal arc"; ofst->SetProperties(kError, kError); } ofst->SetFinal(s, final_arc.weight); break; } case MAP_ALLOW_SUPERFINAL: { B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId)); if (final_arc.ilabel != 0 || final_arc.olabel != 0) { // Add a superfinal state if not already done.
if (superfinal == kNoStateId) { superfinal = ofst->AddState(); ofst->SetFinal(superfinal); } final_arc.nextstate = superfinal; ofst->AddArc(s, std::move(final_arc)); ofst->SetFinal(s, B::Weight::Zero()); } else { ofst->SetFinal(s, final_arc.weight); } break; } case MAP_REQUIRE_SUPERFINAL: { B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId)); if (final_arc.ilabel != 0 || final_arc.olabel != 0 || final_arc.weight != B::Weight::Zero()) { ofst->AddArc(s, B(final_arc.ilabel, final_arc.olabel, final_arc.weight, superfinal)); } ofst->SetFinal(s, B::Weight::Zero()); break; } } } const auto oprops = ofst->Properties(kFstProperties, false); ofst->SetProperties(mapper->Properties(iprops) | oprops, kFstProperties); }
// Copying ArcMap() that takes the mapper by value. The pointer-taking
// overload runs on this local copy, so the caller's mapper object is never
// modified.
template <class A, class B, class C>
void ArcMap(const Fst<A> &ifst, MutableFst<B> *ofst, C mapper) {
  C *const local_mapper = &mapper;
  ArcMap(ifst, ofst, local_mapper);
}
struct ArcMapFstOptions : public CacheOptions { // ArcMapFst default caching behaviour is to do no caching. Most mappers are
// cheap and therefore we save memory by not doing caching.
ArcMapFstOptions() : CacheOptions(true, 0) {}
explicit ArcMapFstOptions(const CacheOptions &opts) : CacheOptions(opts) {} };
// Forward declaration: internal::ArcMapFstImpl below befriends
// StateIterator<ArcMapFst<A, B, C>> before ArcMapFst itself is defined.
template <class A, class B, class C> class ArcMapFst;
namespace internal {
// Implementation of delayed ArcMapFst.
//
// Start state, final weights and arcs are computed on demand from the
// private input copy `fst_` and stored through the CacheImpl<B> base. Output
// state ids equal input state ids, except that once a superfinal state
// exists (superfinal_ != kNoStateId), every id at or above superfinal_ is
// shifted by one to make room for it (see FindIState()/FindOState()). With
// MAP_REQUIRE_SUPERFINAL the superfinal state is state 0 (set in Init()).
// With MAP_ALLOW_SUPERFINAL, Expand() allocates it as the next unused output
// id the first time a mapped final weight needs a labeled arc.
template <class A, class B, class C>
class ArcMapFstImpl : public CacheImpl<B> {
 public:
  using Arc = B;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  using FstImpl<B>::SetType;
  using FstImpl<B>::SetProperties;
  using FstImpl<B>::SetInputSymbols;
  using FstImpl<B>::SetOutputSymbols;

  using CacheImpl<B>::EmplaceArc;
  using CacheImpl<B>::HasArcs;
  using CacheImpl<B>::HasFinal;
  using CacheImpl<B>::HasStart;
  using CacheImpl<B>::PushArc;
  using CacheImpl<B>::SetArcs;
  using CacheImpl<B>::SetFinal;
  using CacheImpl<B>::SetStart;

  // The state iterator reads fst_, mapper_ and final_action_ directly.
  friend class StateIterator<ArcMapFst<A, B, C>>;

  // Maps a copy of `fst` using a heap copy of `mapper` that this impl owns
  // and deletes in its destructor.
  ArcMapFstImpl(const Fst<A> &fst, const C &mapper,
                const ArcMapFstOptions &opts)
      : CacheImpl<B>(opts),
        fst_(fst.Copy()),
        mapper_(new C(mapper)),
        own_mapper_(true),
        superfinal_(kNoStateId),
        nstates_(0) {
    Init();
  }

  // Maps a copy of `fst` using the caller's `mapper` in place. The mapper is
  // NOT owned (never deleted here), so the caller must keep it alive for as
  // long as this impl uses it.
  ArcMapFstImpl(const Fst<A> &fst, C *mapper, const ArcMapFstOptions &opts)
      : CacheImpl<B>(opts),
        fst_(fst.Copy()),
        mapper_(mapper),
        own_mapper_(false),
        superfinal_(kNoStateId),
        nstates_(0) {
    Init();
  }

  // Copy constructor. Copies the input FST with Copy(true) and always clones
  // the mapper, so the copy owns its mapper even if `impl` did not.
  // superfinal_ and nstates_ are reset and Init() re-derives superfinal_.
  ArcMapFstImpl(const ArcMapFstImpl<A, B, C> &impl)
      : CacheImpl<B>(impl),
        fst_(impl.fst_->Copy(true)),
        mapper_(new C(*impl.mapper_)),
        own_mapper_(true),
        superfinal_(kNoStateId),
        nstates_(0) {
    Init();
  }

  ~ArcMapFstImpl() override {
    if (own_mapper_) delete mapper_;
  }

  // Output start state: the input start state translated into an output id.
  StateId Start() {
    if (!HasStart()) SetStart(FindOState(fst_->Start()));
    return CacheImpl<B>::Start();
  }

  // Computes and caches the final weight of output state `s`. The input
  // final weight is passed to the mapper as A(0, 0, weight, kNoStateId).
  Weight Final(StateId s) {
    if (!HasFinal(s)) {
      switch (final_action_) {
        case MAP_NO_SUPERFINAL:
        default: {
          const auto final_arc =
              (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId));
          // Labels cannot be represented as a plain final weight.
          if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
            FSTERROR() << "ArcMapFst: Non-zero arc labels for superfinal arc";
            SetProperties(kError, kError);
          }
          SetFinal(s, final_arc.weight);
          break;
        }
        case MAP_ALLOW_SUPERFINAL: {
          if (s == superfinal_) {
            SetFinal(s);
          } else {
            const auto final_arc =
                (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId));
            // A labeled result becomes an arc to the superfinal state (added
            // in Expand()), so the state itself is then non-final.
            if (final_arc.ilabel == 0 && final_arc.olabel == 0) {
              SetFinal(s, final_arc.weight);
            } else {
              SetFinal(s, Weight::Zero());
            }
          }
          break;
        }
        case MAP_REQUIRE_SUPERFINAL: {
          // Only the superfinal state is final; the others reach it by arcs.
          SetFinal(s, s == superfinal_ ? Weight::One() : Weight::Zero());
          break;
        }
      }
    }
    return CacheImpl<B>::Final(s);
  }

  size_t NumArcs(StateId s) {
    if (!HasArcs(s)) Expand(s);
    return CacheImpl<B>::NumArcs(s);
  }

  size_t NumInputEpsilons(StateId s) {
    if (!HasArcs(s)) Expand(s);
    return CacheImpl<B>::NumInputEpsilons(s);
  }

  size_t NumOutputEpsilons(StateId s) {
    if (!HasArcs(s)) Expand(s);
    return CacheImpl<B>::NumOutputEpsilons(s);
  }

  uint64_t Properties() const override { return Properties(kFstProperties); }

  // Sets error if found, and returns other FST impl properties.
  // An error is found if the input FST reports kError or the mapper's
  // Properties(0) includes kError.
  uint64_t Properties(uint64_t mask) const override {
    if ((mask & kError) && (fst_->Properties(kError, false) ||
                            (mapper_->Properties(0) & kError))) {
      SetProperties(kError, kError);
    }
    return FstImpl<Arc>::Properties(mask);
  }

  void InitArcIterator(StateId s, ArcIteratorData<B> *data) {
    if (!HasArcs(s)) Expand(s);
    CacheImpl<B>::InitArcIterator(s, data);
  }

  // Computes and caches the arcs leaving output state `s`. These are the
  // mapped input arcs, plus at most one arc to the superfinal state for the
  // state's final weight when the final action calls for it.
  void Expand(StateId s) {
    // Add exiting arcs.
    if (s == superfinal_) {
      SetArcs(s);
      return;
    }
    for (ArcIterator<Fst<A>> aiter(*fst_, FindIState(s)); !aiter.Done();
         aiter.Next()) {
      auto aarc = aiter.Value();
      // The destination is translated to an output id before mapping.
      aarc.nextstate = FindOState(aarc.nextstate);
      PushArc(s, (*mapper_)(aarc));
    }

    // Check for superfinal arcs.
    // A superfinal arc is only considered when the final weight is not yet
    // cached, or is cached as Zero().
    if (!HasFinal(s) || Final(s) == Weight::Zero()) {
      switch (final_action_) {
        case MAP_NO_SUPERFINAL:
        default:
          break;
        case MAP_ALLOW_SUPERFINAL: {
          auto final_arc =
              (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId));
          if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
            // Lazily allocate the superfinal state as the next unused id.
            if (superfinal_ == kNoStateId) superfinal_ = nstates_++;
            final_arc.nextstate = superfinal_;
            PushArc(s, std::move(final_arc));
          }
          break;
        }
        case MAP_REQUIRE_SUPERFINAL: {
          const auto final_arc =
              (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId));
          // Only an unlabeled Zero() result needs no arc to the superfinal.
          if (final_arc.ilabel != 0 || final_arc.olabel != 0 ||
              final_arc.weight != B::Weight::Zero()) {
            EmplaceArc(s, final_arc.ilabel, final_arc.olabel, final_arc.weight,
                       superfinal_);
          }
          break;
        }
      }
    }
    SetArcs(s);
  }

 private:
  // Sets the type to "map" and applies the mapper's symbol-table actions. It
  // then fixes final_action_ and the initial properties. An input with no
  // start state is treated as MAP_NO_SUPERFINAL with kNullProperties.
  void Init() {
    SetType("map");
    if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS) {
      SetInputSymbols(fst_->InputSymbols());
    } else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
      SetInputSymbols(nullptr);
    }
    if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS) {
      SetOutputSymbols(fst_->OutputSymbols());
    } else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
      SetOutputSymbols(nullptr);
    }
    if (fst_->Start() == kNoStateId) {
      final_action_ = MAP_NO_SUPERFINAL;
      SetProperties(kNullProperties);
    } else {
      final_action_ = mapper_->FinalAction();
      uint64_t props = fst_->Properties(kCopyProperties, false);
      SetProperties(mapper_->Properties(props));
      // The required superfinal state takes output id 0.
      if (final_action_ == MAP_REQUIRE_SUPERFINAL) superfinal_ = 0;
    }
  }

  // Maps from output state to input state.
  StateId FindIState(StateId s) {
    if (superfinal_ == kNoStateId || s < superfinal_) {
      return s;
    } else {
      return s - 1;
    }
  }

  // Maps from input state to output state, and records in nstates_ one past
  // the largest output id handed out so far.
  StateId FindOState(StateId is) {
    auto os = is;
    if (!(superfinal_ == kNoStateId || is < superfinal_)) ++os;
    if (os >= nstates_) nstates_ = os + 1;
    return os;
  }

  std::unique_ptr<const Fst<A>> fst_;  // Private copy of the input FST.
  C *mapper_;                          // Deleted in the dtor iff own_mapper_.
  const bool own_mapper_;
  MapFinalAction final_action_;
  StateId superfinal_;  // kNoStateId while no superfinal state exists.
  StateId nstates_;     // One past the largest output state id handed out.
};
} // namespace internal
// Delayed (on-the-fly) FST that maps arcs of type A to arcs of type B with a
// mapper of type C. Arcs and final weights are computed on demand by
// internal::ArcMapFstImpl and cached according to ArcMapFstOptions.
template <class A, class B, class C>
class ArcMapFst : public ImplToFst<internal::ArcMapFstImpl<A, B, C>> {
 public:
  using Arc = B;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  using Store = DefaultCacheStore<B>;
  using State = typename Store::State;
  using Impl = internal::ArcMapFstImpl<A, B, C>;

  friend class ArcIterator<ArcMapFst<A, B, C>>;
  friend class StateIterator<ArcMapFst<A, B, C>>;

  // Maps `fst` using a private copy of `mapper`.
  explicit ArcMapFst(const Fst<A> &fst, const C &mapper = C(),
                     const ArcMapFstOptions &opts = ArcMapFstOptions())
      : ImplToFst<Impl>(std::make_shared<Impl>(fst, mapper, opts)) {}

  // Maps `fst` using the caller's `mapper` in place (neither copied nor
  // owned). It must outlive this FST and any copies that share its impl.
  ArcMapFst(const Fst<A> &fst, C *mapper,
            const ArcMapFstOptions &opts = ArcMapFstOptions())
      : ImplToFst<Impl>(std::make_shared<Impl>(fst, mapper, opts)) {}

  // See Fst<>::Copy() for doc.
  ArcMapFst(const ArcMapFst &fst, bool safe = false)
      : ImplToFst<Impl>(fst, safe) {}

  // Returns a heap-allocated copy of this ArcMapFst. See Fst<>::Copy().
  ArcMapFst *Copy(bool safe = false) const override {
    auto *copy = new ArcMapFst(*this, safe);
    return copy;
  }

  inline void InitStateIterator(StateIteratorData<B> *data) const override;

  void InitArcIterator(StateId s, ArcIteratorData<B> *data) const override {
    auto *impl = GetMutableImpl();
    impl->InitArcIterator(s, data);
  }

 protected:
  using ImplToFst<Impl>::GetImpl;
  using ImplToFst<Impl>::GetMutableImpl;

 private:
  ArcMapFst &operator=(const ArcMapFst &) = delete;
};
// Specialization for ArcMapFst.
//
// This may be derived from.
//
// Yields output ids 0, 1, 2, ... by walking the input FST's states with
// `siter_`. If a superfinal state exists, one extra id is yielded after the
// input states are exhausted.
template <class A, class B, class C>
class StateIterator<ArcMapFst<A, B, C>> : public StateIteratorBase<B> {
 public:
  using StateId = typename B::StateId;

  explicit StateIterator(const ArcMapFst<A, B, C> &fst)
      : impl_(fst.GetImpl()),
        siter_(*impl_->fst_),
        s_(0),
        superfinal_(impl_->final_action_ == MAP_REQUIRE_SUPERFINAL) {
    CheckSuperfinal();
  }

  // Done once the input states are exhausted and no superfinal id is still
  // pending.
  bool Done() const final { return siter_.Done() && !superfinal_; }

  StateId Value() const final { return s_; }

  // After the input iterator is exhausted, a pending superfinal state
  // consumes exactly one more step.
  void Next() final {
    ++s_;
    if (!siter_.Done()) {
      siter_.Next();
      CheckSuperfinal();
    } else if (superfinal_) {
      superfinal_ = false;
    }
  }

  void Reset() final {
    s_ = 0;
    siter_.Reset();
    superfinal_ = impl_->final_action_ == MAP_REQUIRE_SUPERFINAL;
    CheckSuperfinal();
  }

 private:
  // For MAP_ALLOW_SUPERFINAL only: sets superfinal_ as soon as some state's
  // mapped final weight carries labels, which means a superfinal state
  // exists.
  // NOTE(review): the input FST's Final() is queried with the counter s_
  // rather than siter_.Value(). This assumes input ids are visited as the
  // contiguous sequence 0, 1, ... — confirm for non-dense input FSTs.
  void CheckSuperfinal() {
    if (impl_->final_action_ != MAP_ALLOW_SUPERFINAL || superfinal_) return;
    if (!siter_.Done()) {
      const auto final_arc =
          (*impl_->mapper_)(A(0, 0, impl_->fst_->Final(s_), kNoStateId));
      if (final_arc.ilabel != 0 || final_arc.olabel != 0) superfinal_ = true;
    }
  }

  const internal::ArcMapFstImpl<A, B, C> *impl_;
  StateIterator<Fst<A>> siter_;  // Iterator over the input FST's states.
  StateId s_;                    // Current output state id.
  bool superfinal_;  // True if there is a superfinal state and not done.
};
// Specialization for ArcMapFst: a cache-backed arc iterator. The state is
// expanded first if its arcs are not cached yet.
template <class A, class B, class C>
class ArcIterator<ArcMapFst<A, B, C>>
    : public CacheArcIterator<ArcMapFst<A, B, C>> {
 public:
  using StateId = typename A::StateId;

  ArcIterator(const ArcMapFst<A, B, C> &fst, StateId s)
      : CacheArcIterator<ArcMapFst<A, B, C>>(fst.GetMutableImpl(), s) {
    const auto *impl = fst.GetImpl();
    if (impl->HasArcs(s)) return;
    fst.GetMutableImpl()->Expand(s);
  }
};
// Installs the ArcMapFst-specific StateIterator into `data`.
template <class A, class B, class C>
inline void ArcMapFst<A, B, C>::InitStateIterator(
    StateIteratorData<B> *data) const {
  auto siter = std::make_unique<StateIterator<ArcMapFst<A, B, C>>>(*this);
  data->base = std::move(siter);
}
// CTAD deduction guides
// This allows constructing ArcMapFsts without specifying all the types.
template <class ArcMapper> ArcMapFst(const Fst<typename ArcMapper::FromArc> &, const ArcMapper &) -> ArcMapFst<typename ArcMapper::FromArc, typename ArcMapper::ToArc, ArcMapper>;
// As above, but using the ArcMapFst(..., ArcMapper *) constructor.
template <class ArcMapper> ArcMapFst(const Fst<typename ArcMapper::FromArc> &, ArcMapper *) -> ArcMapFst<typename ArcMapper::FromArc, typename ArcMapper::ToArc, ArcMapper>;
// Utility Mappers.

// Mapper that passes every arc and final weight through unchanged. Both
// symbol tables are copied and the input properties are kept as they are.
template <class A>
class IdentityArcMapper {
 public:
  using FromArc = A;
  using ToArc = A;

  constexpr ToArc operator()(const FromArc &arc) const {
    return arc;
  }

  constexpr MapFinalAction FinalAction() const {
    return MAP_NO_SUPERFINAL;
  }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr uint64_t Properties(uint64_t props) const {
    return props;
  }
};
// Mapper that relabels every input label as epsilon (0). Output labels,
// weights and destinations are kept. Input symbols are cleared because they
// no longer describe the labels.
template <class A>
class InputEpsilonMapper {
 public:
  using FromArc = A;
  using ToArc = A;

  constexpr ToArc operator()(const FromArc &arc) const {
    return ToArc(/*ilabel=*/0, arc.olabel, arc.weight, arc.nextstate);
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_CLEAR_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  // Keeps the kSetArcProperties subset of the input properties. Because
  // every input label is now epsilon, kIEpsilons and kILabelSorted also hold.
  constexpr uint64_t Properties(uint64_t props) const {
    constexpr uint64_t kAdded = kIEpsilons | kILabelSorted;
    return kAdded | (props & kSetArcProperties);
  }
};
// Mapper that relabels every output label as epsilon (0). Input labels,
// weights and destinations are kept. Output symbols are cleared because
// they no longer describe the labels.
template <class A>
class OutputEpsilonMapper {
 public:
  using FromArc = A;
  using ToArc = A;

  constexpr ToArc operator()(const FromArc &arc) const {
    return ToArc(arc.ilabel, /*olabel=*/0, arc.weight, arc.nextstate);
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_CLEAR_SYMBOLS;
  }

  // Keeps the kSetArcProperties subset of the input properties. Because
  // every output label is now epsilon, kOEpsilons and kOLabelSorted also
  // hold.
  constexpr uint64_t Properties(uint64_t props) const {
    constexpr uint64_t kAdded = kOEpsilons | kOLabelSorted;
    return kAdded | (props & kSetArcProperties);
  }
};
// Mapper that returns its input with every final state redirected, through
// an arc, to a single super-final state (MAP_REQUIRE_SUPERFINAL).
template <class A>
class SuperFinalMapper {
 public:
  using FromArc = A;
  using ToArc = A;
  using Label = typename FromArc::Label;
  using Weight = typename FromArc::Weight;

  // `final_label` is placed on both sides of each arc into the super-final
  // state (epsilon, 0, by default).
  constexpr explicit SuperFinalMapper(Label final_label = 0)
      : final_label_(final_label) {}

  ToArc operator()(const FromArc &arc) const {
    const bool is_superfinal_arc =
        arc.nextstate == kNoStateId && arc.weight != Weight::Zero();
    if (!is_superfinal_arc) return arc;
    // Super-final arc: tag it with the configured label.
    return ToArc(final_label_, final_label_, arc.weight, kNoStateId);
  }

  constexpr MapFinalAction FinalAction() const {
    return MAP_REQUIRE_SUPERFINAL;
  }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  uint64_t Properties(uint64_t props) const {
    uint64_t kept = props & kAddSuperFinalProperties;
    if (final_label_ != 0) {
      // The new arcs carry a real label, so only label-invariant properties
      // can be kept.
      kept &= kILabelInvariantProperties & kOLabelInvariantProperties;
    }
    return kept;
  }

 private:
  const Label final_label_;
};
// Mapper that converts each weight from A's weight type to B's weight type,
// leaving labels and nextstate untouched. By default the conversion is done
// by a WeightConvert specialization, which must exist for the pair of weight
// types; a different converter may be supplied instead.
template <class A, class B,
          class C = WeightConvert<typename A::Weight, typename B::Weight>>
class WeightConvertMapper {
 public:
  using FromArc = A;
  using ToArc = B;
  using Converter = C;
  using FromWeight = typename FromArc::Weight;
  using ToWeight = typename ToArc::Weight;

  constexpr explicit WeightConvertMapper(const Converter &c = Converter())
      : convert_weight_(c) {}

  constexpr ToArc operator()(const FromArc &arc) const {
    return ToArc(arc.ilabel, arc.olabel, convert_weight_(arc.weight),
                 arc.nextstate);
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  // Only weights change representation, so all properties are kept.
  constexpr uint64_t Properties(uint64_t props) const { return props; }

 private:
  const Converter convert_weight_;
};
// Ready-made WeightConvertMappers between the standard float arc types.

// Same-precision conversions between StdArc and LogArc. Consider the more
// efficient Cast method for these instead.
using StdToLogMapper = WeightConvertMapper<StdArc, LogArc>;
using LogToStdMapper = WeightConvertMapper<LogArc, StdArc>;

// Precision-changing conversions into Log64Arc.
using StdToLog64Mapper = WeightConvertMapper<StdArc, Log64Arc>;
using LogToLog64Mapper = WeightConvertMapper<LogArc, Log64Arc>;

// Precision-changing conversions out of Log64Arc.
using Log64ToStdMapper = WeightConvertMapper<Log64Arc, StdArc>;
using Log64ToLogMapper = WeightConvertMapper<Log64Arc, LogArc>;
// Mapper from A to GallicArc<A>.
//
// The output label moves into the string component of the Gallic weight and
// the input label is used on both sides of the new arc. Final weights become
// Gallic weights with an empty string; Zero() stays Zero().
template <class A, GallicType G = GALLIC_LEFT>
class ToGallicMapper {
 public:
  using FromArc = A;
  using ToArc = GallicArc<A, G>;

  using SW = StringWeight<typename A::Label, GallicStringType(G)>;
  using AW = typename FromArc::Weight;
  using GW = typename ToArc::Weight;

  ToArc operator()(const FromArc &arc) const {
    if (arc.nextstate == kNoStateId) {
      // Super-final arc.
      if (arc.weight != AW::Zero()) {
        return ToArc(0, 0, GW(SW::One(), arc.weight), kNoStateId);
      }
      // Super-non-final arc.
      return ToArc(0, 0, GW::Zero(), kNoStateId);
    }
    // Epsilon output label: the string part is empty.
    if (arc.olabel == 0) {
      return ToArc(arc.ilabel, arc.ilabel, GW(SW::One(), arc.weight),
                   arc.nextstate);
    }
    // Regular output label: the string part is that single label.
    return ToArc(arc.ilabel, arc.ilabel, GW(SW(arc.olabel), arc.weight),
                 arc.nextstate);
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_CLEAR_SYMBOLS;
  }

  // Both labels become the input label (an input projection); weights
  // change type, so only weight-invariant properties survive.
  uint64_t Properties(uint64_t props) const {
    const uint64_t projected = ProjectProperties(props, true);
    return projected & kWeightInvariantProperties;
  }
};
// Mapper from GallicArc<A> to A.
//
// Each Gallic weight must carry a string of at most one label. That label
// becomes the output label and the remaining component becomes the weight.
// Unrepresentable weights are reported with FSTERROR and recorded so that
// Properties() includes kError. A final weight whose string is non-empty
// becomes an arc to a superfinal state (MAP_ALLOW_SUPERFINAL) whose input
// label is `superfinal_label`.
template <class A, GallicType G = GALLIC_LEFT>
class FromGallicMapper {
 public:
  using FromArc = GallicArc<A, G>;
  using ToArc = A;

  using Label = typename ToArc::Label;
  using AW = typename ToArc::Weight;
  using GW = typename FromArc::Weight;

  explicit FromGallicMapper(Label superfinal_label = 0)
      : superfinal_label_(superfinal_label), error_(false) {}

  ToArc operator()(const FromArc &arc) const {
    const bool is_final_arc = arc.nextstate == kNoStateId;
    // 'Super-non-final' arc.
    if (is_final_arc && arc.weight == GW::Zero()) {
      return A(arc.ilabel, 0, AW::Zero(), kNoStateId);
    }
    Label label = kNoLabel;
    AW weight = AW::Zero();
    const bool representable =
        Extract(arc.weight, &weight, &label) && arc.ilabel == arc.olabel;
    if (!representable) {
      FSTERROR() << "FromGallicMapper: Unrepresentable weight: " << arc.weight
                 << " for arc with ilabel = " << arc.ilabel
                 << ", olabel = " << arc.olabel
                 << ", nextstate = " << arc.nextstate;
      error_ = true;
    }
    // A labeled final weight needs the configured superfinal input label.
    const bool use_superfinal_label =
        arc.ilabel == 0 && label != 0 && is_final_arc;
    return ToArc(use_superfinal_label ? superfinal_label_ : arc.ilabel, label,
                 weight, arc.nextstate);
  }

  constexpr MapFinalAction FinalAction() const { return MAP_ALLOW_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_CLEAR_SYMBOLS;
  }

  uint64_t Properties(uint64_t inprops) const {
    const uint64_t kept = inprops & kOLabelInvariantProperties &
                          kWeightInvariantProperties &
                          kAddSuperFinalProperties;
    return error_ ? (kept | kError) : kept;
  }

 private:
  // Splits a (non-union) Gallic weight into a single label and a weight.
  // Returns false if the string part has more than one label, or is the
  // infinity/bad string.
  template <GallicType GT>
  static bool Extract(const GallicWeight<Label, AW, GT> &gallic_weight,
                      typename A::Weight *weight, typename A::Label *label) {
    using StrW = StringWeight<Label, GallicStringType(GT)>;
    const StrW &str = gallic_weight.Value1();
    const AW &value = gallic_weight.Value2();
    typename StrW::Iterator str_iter(str);
    const Label l = str.Size() == 1 ? str_iter.Value() : 0;
    if (str.Size() > 1 || l == kStringInfinity || l == kStringBad) {
      return false;
    }
    *label = l;
    *weight = value;
    return true;
  }

  // Union (GALLIC) weights are representable only with at most one member.
  // An empty union means a Zero() weight with an epsilon label.
  static bool Extract(const GallicWeight<Label, AW, GALLIC> &gallic_weight,
                      typename A::Weight *weight, typename A::Label *label) {
    switch (gallic_weight.Size()) {
      case 0:
        *label = 0;
        *weight = A::Weight::Zero();
        return true;
      case 1:
        return Extract<GALLIC_RESTRICT>(gallic_weight.Back(), weight, label);
      default:
        return false;
    }
  }

  const Label superfinal_label_;
  mutable bool error_;
};
// Mapper from GallicArc<A> to A that replaces each non-empty output string in
// the Gallic weights with a single freshly allocated label.
//
// The FST passed to the constructor is cleared and rebuilt as a decoder. It
// has one start/final state `state_`. For each new label l standing for the
// output string s1 s2 ... sk, a chain of arcs l:s1, 0:s2, ..., 0:sk leads
// from `state_` back to `state_`. If the decoder has an output symbol table,
// a matching input symbol table is built in which l is named "s1_s2_..._sk".
// Unrepresentable weights are reported with FSTERROR and recorded so that
// Properties() includes kError.
template <class A, GallicType G = GALLIC_LEFT>
class GallicToNewSymbolsMapper {
 public:
  using FromArc = GallicArc<A, G>;
  using ToArc = A;

  using Label = typename ToArc::Label;
  using StateId = typename ToArc::StateId;
  using AW = typename ToArc::Weight;
  using GW = typename FromArc::Weight;
  using SW = StringWeight<Label, GallicStringType(G)>;

  // `fst` receives the decoder described above. Its existing states are
  // deleted; its output symbol table, if any, is used to name new labels.
  explicit GallicToNewSymbolsMapper(MutableFst<ToArc> *fst)
      : fst_(fst),
        lmax_(0),
        osymbols_(fst->OutputSymbols()),
        isymbols_(nullptr),
        error_(false) {
    fst_->DeleteStates();
    state_ = fst_->AddState();
    fst_->SetStart(state_);
    fst_->SetFinal(state_);
    if (osymbols_) {
      // SetInputSymbols() stores its own copy of the table. Passing a local
      // avoids leaking a heap-allocated SymbolTable, as the previous
      // `new SymbolTable(name)` did.
      const SymbolTable isymbols(osymbols_->Name() + "_from_gallic");
      fst_->SetInputSymbols(&isymbols);
      isymbols_ = fst_->MutableInputSymbols();
      const int64_t zero = 0;
      isymbols_->AddSymbol(osymbols_->Find(zero), 0);
    } else {
      fst_->SetInputSymbols(nullptr);
    }
  }

  ToArc operator()(const FromArc &arc) {
    // Super-non-final arc.
    if (arc.nextstate == kNoStateId && arc.weight == GW::Zero()) {
      return ToArc(arc.ilabel, 0, AW::Zero(), kNoStateId);
    }
    const SW &w1 = arc.weight.Value1();
    const AW &w2 = arc.weight.Value2();
    Label l;
    if (w1.Size() == 0) {
      l = 0;
    } else if (auto [it, inserted] = map_.emplace(w1, kNoLabel); !inserted) {
      // String seen before: reuse its label.
      l = it->second;
    } else {
      // New string: allocate the next label and add its decoder path.
      l = ++lmax_;
      it->second = l;
      AddDecoderPath(w1, l);
    }
    if (l == kStringInfinity || l == kStringBad || arc.ilabel != arc.olabel) {
      FSTERROR() << "GallicToNewSymbolMapper: Unrepresentable weight: " << l;
      error_ = true;
    }
    return ToArc(arc.ilabel, l, w2, arc.nextstate);
  }

  constexpr MapFinalAction FinalAction() const { return MAP_ALLOW_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_CLEAR_SYMBOLS;
  }

  uint64_t Properties(uint64_t inprops) const {
    uint64_t outprops = inprops & kOLabelInvariantProperties &
                        kWeightInvariantProperties & kAddSuperFinalProperties;
    if (error_) outprops |= kError;
    return outprops;
  }

 private:
  // Adds to the decoder a path from state_ back to state_ that reads `label`
  // on its first arc (epsilon on later arcs) and writes the symbols of `str`
  // one per arc. If an input symbol table is being built, also registers the
  // "_"-joined symbol names as the name of `label`.
  void AddDecoderPath(const SW &str, Label label) {
    StringWeightIterator<SW> iter(str);
    std::string name;
    StateId prev = state_;
    for (size_t i = 0; i < str.Size(); ++i, iter.Next()) {
      const StateId next = i + 1 == str.Size() ? state_ : fst_->AddState();
      fst_->AddArc(prev, ToArc(i ? 0 : label, iter.Value(), next));
      if (isymbols_) {
        if (i) name += "_";
        name += osymbols_->Find(iter.Value());
      }
      prev = next;
    }
    if (isymbols_) isymbols_->AddSymbol(name, label);
  }

  class StringKey {
   public:
    size_t operator()(const SW &x) const { return x.Hash(); }
  };

  using Map = std::unordered_map<SW, Label, StringKey>;

  MutableFst<ToArc> *fst_;     // Decoder FST being built; not owned.
  Map map_;                    // Output string -> allocated label.
  Label lmax_;                 // Largest label allocated so far.
  StateId state_;              // The decoder's single start/final state.
  const SymbolTable *osymbols_;  // Decoder's output symbols; may be null.
  SymbolTable *isymbols_;      // Decoder's input symbols; null if unused.
  mutable bool error_;
};
// TODO(kbg): Add common base class for those mappers which do nothing except
// mutate their weights.

// Mapper that Plus()-combines a fixed weight into every non-Zero() weight,
// for arcs and final weights alike. Zero() weights pass through unchanged.
template <class A>
class PlusMapper {
 public:
  using FromArc = A;
  using ToArc = A;
  using Weight = typename FromArc::Weight;

  constexpr explicit PlusMapper(Weight weight) : weight_(std::move(weight)) {}

  ToArc operator()(const FromArc &arc) const {
    return arc.weight == Weight::Zero()
               ? arc
               : ToArc(arc.ilabel, arc.olabel, Plus(arc.weight, weight_),
                       arc.nextstate);
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr uint64_t Properties(uint64_t props) const {
    return kWeightInvariantProperties & props;
  }

 private:
  const Weight weight_;
};
// Mapper that right-multiplies (Times(w, weight)) every non-Zero() weight by a
// fixed weight. Zero() weights pass through unchanged.
template <class A>
class TimesMapper {
 public:
  using FromArc = A;
  using ToArc = A;
  using Weight = typename FromArc::Weight;

  constexpr explicit TimesMapper(Weight weight) : weight_(std::move(weight)) {}

  ToArc operator()(const FromArc &arc) const {
    return arc.weight == Weight::Zero()
               ? arc
               : ToArc(arc.ilabel, arc.olabel, Times(arc.weight, weight_),
                       arc.nextstate);
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr uint64_t Properties(uint64_t props) const {
    return kWeightInvariantProperties & props;
  }

 private:
  const Weight weight_;
};
// Mapper to take all weights to a constant power. The power argument is stored
// as a double, so if there is a floating-point power implementation for this
// weight type, it will take precedence. Otherwise, the power argument's 53 bits
// of integer precision will be implicitly converted to a size_t and the default
// power implementation (iterated multiplication) will be used instead.
//
// Unlike PlusMapper and TimesMapper, Zero()-weighted arcs are not special-cased
// here: every arc weight, including Zero(), is passed through Power().
template <class A>
class PowerMapper {
 public:
  using FromArc = A;
  using ToArc = A;
  using Weight = typename FromArc::Weight;

  // constexpr to match PlusMapper/TimesMapper. The only state is a double, so
  // a PowerMapper can be constructed in a constant expression.
  constexpr explicit PowerMapper(double power) : power_(power) {}

  // Keeps labels and next state; replaces the weight with Power(weight, power).
  ToArc operator()(const FromArc &arc) const {
    return ToArc(arc.ilabel, arc.olabel, Power(arc.weight, power_),
                 arc.nextstate);
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  // Only weights are altered, so only weight-invariant properties survive.
  constexpr uint64_t Properties(uint64_t props) const {
    return props & kWeightInvariantProperties;
  }

 private:
  const double power_;
};
// Mapper that replaces each non-Zero() arc weight w with Divide(One(), w),
// i.e. its reciprocal under the semiring's default division. Arcs whose weight
// equals Weight::Zero() are returned unchanged. The mapper is stateless.
template <class A>
class InvertWeightMapper {
 public:
  using FromArc = A;
  using ToArc = A;
  using Weight = typename FromArc::Weight;

  // Zero()-weighted arcs pass through as-is; others get One() / weight.
  ToArc operator()(const FromArc &arc) const {
    return arc.weight == Weight::Zero()
               ? arc
               : ToArc(arc.ilabel, arc.olabel,
                       Divide(Weight::One(), arc.weight), arc.nextstate);
  }

  // Only weights are altered, so only weight-invariant properties survive.
  constexpr uint64_t Properties(uint64_t props) const {
    return props & kWeightInvariantProperties;
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }
};
// Mapper that strips weights: every arc whose FromArc weight differs from
// Zero() gets ToWeight::One(), and every Zero()-weighted arc gets
// ToWeight::Zero(). Labels and next states are preserved. Because the output
// weights are only One() or Zero(), kUnweighted is added to the properties.
template <class A, class B = A>
class RmWeightMapper {
 public:
  using FromArc = A;
  using ToArc = B;
  using FromWeight = typename FromArc::Weight;
  using ToWeight = typename ToArc::Weight;

  ToArc operator()(const FromArc &arc) const {
    if (arc.weight != FromWeight::Zero()) {
      return ToArc(arc.ilabel, arc.olabel, ToWeight::One(), arc.nextstate);
    }
    return ToArc(arc.ilabel, arc.olabel, ToWeight::Zero(), arc.nextstate);
  }

  // Weight-invariant properties are kept and the result is marked unweighted.
  constexpr uint64_t Properties(uint64_t props) const {
    return (props & kWeightInvariantProperties) | kUnweighted;
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }
};
// Mapper that quantizes every arc weight with FromWeight::Quantize(delta),
// keeping labels and next state. The quantization step defaults to kDelta.
// When B differs from A, the value returned by Quantize() must be usable as a
// ToArc weight.
template <class A, class B = A>
class QuantizeMapper {
 public:
  using FromArc = A;
  using ToArc = B;
  using FromWeight = typename FromArc::Weight;
  using ToWeight = typename ToArc::Weight;

  // Uses the library default step, kDelta. Delegates to the float constructor
  // so there is a single initialization path.
  constexpr QuantizeMapper() : QuantizeMapper(kDelta) {}

  // Uses a caller-chosen quantization step. constexpr to match the other
  // weight mappers (e.g. PlusMapper/TimesMapper); the only state is a float.
  constexpr explicit QuantizeMapper(float delta) : delta_(delta) {}

  ToArc operator()(const FromArc &arc) const {
    return ToArc(arc.ilabel, arc.olabel, arc.weight.Quantize(delta_),
                 arc.nextstate);
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  // Only weights are altered, so only weight-invariant properties survive.
  constexpr uint64_t Properties(uint64_t props) const {
    return props & kWeightInvariantProperties;
  }

 private:
  const float delta_;
};
// Mapper from arc type A to arc type B, valid only when
//
//   B::Weight  == A::Weight::ReverseWeight
//   B::Label   == A::Label
//   B::StateId == A::StateId
//
// These requirements are enforced by the static_asserts below. Each arc keeps
// its labels and next state; its weight is replaced by arc.weight.Reverse().
// All input properties are reported unchanged.
template <class A, class B>
class ReverseWeightMapper {
 public:
  using FromArc = A;
  using ToArc = B;

  static_assert(std::is_same_v<typename ToArc::Weight,
                               typename FromArc::Weight::ReverseWeight>,
                "ToArc::Weight must be FromArc::Weight::ReverseWeight");
  static_assert(std::is_same_v<typename ToArc::Label, typename FromArc::Label>,
                "ToArc::Label must be FromArc::Label");
  static_assert(
      std::is_same_v<typename ToArc::StateId, typename FromArc::StateId>,
      "ToArc::StateId must be FromArc::StateId");

  constexpr ToArc operator()(const FromArc &arc) const {
    return ToArc(arc.ilabel, arc.olabel, arc.weight.Reverse(), arc.nextstate);
  }

  // Reversal is reported as not affecting any property bit.
  constexpr uint64_t Properties(uint64_t props) const { return props; }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }
};
} // namespace fst
#endif // FST_ARC_MAP_H_
|