You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

1322 lines
40 KiB

// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Class to map over/transform arcs, e.g., to change semirings or to
// implement project/invert. Consider using it when an operation does
// not change the number of arcs (except possibly superfinal arcs).
#ifndef FST_ARC_MAP_H_
#define FST_ARC_MAP_H_
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>

#include <fst/log.h>
#include <fst/arc.h>
#include <fst/cache.h>
#include <fst/expanded-fst.h>
#include <fst/float-weight.h>
#include <fst/fst.h>
#include <fst/impl-to-fst.h>
#include <fst/mutable-fst.h>
#include <fst/properties.h>
#include <fst/string-weight.h>
#include <fst/symbol-table.h>
#include <fst/util.h>
#include <fst/weight.h>
namespace fst {
// Policy for translating final weights into the mapped FST.
enum MapFinalAction {
  // Each final weight must map to a final weight; if the mapped result
  // carries non-epsilon labels, an error is raised instead.
  MAP_NO_SUPERFINAL = 0,
  // A final weight whose mapped result cannot be represented as a final
  // weight becomes an arc into a superfinal state; that state is created
  // only on demand.
  MAP_ALLOW_SUPERFINAL = 1,
  // Every final weight becomes an arc into a superfinal state, except when
  // the mapped result is a Zero() weight with epsilon labels. The superfinal
  // state is always created (unless the input FST is empty).
  MAP_REQUIRE_SUPERFINAL = 2,
};
// Policy for how a map treats the input/output symbol tables.
enum MapSymbolsAction {
  // The map clears the symbol table in the result.
  MAP_CLEAR_SYMBOLS = 0,
  // The map copies the symbol table of the input FST into the result.
  MAP_COPY_SYMBOLS = 1,
  // The map itself leaves the result's symbol table alone (the mapper may
  // still set it).
  MAP_NOOP_SYMBOLS = 2,
};
// The ArcMapper interface defines how arcs and final weights are mapped.
// This is useful for implementing operations that apply to each arc separately
// and do not change the number of arcs (except possibly superfinal arcs).
//
// template <class A, class B>
// class ArcMapper {
// public:
// using FromArc = A;
// using ToArc = B;
//
// // Maps an arc type FromArc to arc type ToArc.
// ToArc operator()(const FromArc &arc);
//
// // Specifies final action the mapper requires (see above).
// // The mapper will be passed final weights as arcs of the form
// // Arc(0, 0, weight, kNoStateId).
// MapFinalAction FinalAction() const;
//
// // Specifies input symbol table action the mapper requires (see above).
// MapSymbolsAction InputSymbolsAction() const;
//
// // Specifies output symbol table action the mapper requires (see above).
// MapSymbolsAction OutputSymbolsAction() const;
//
// // This specifies the known properties of an FST mapped by this mapper. It
// takes as argument the input FST's known properties.
// uint64_t Properties(uint64_t props) const;
// };
//
// The ArcMap functions and classes below will use the FinalAction()
// method of the mapper to determine how to treat final weights, e.g., whether
// to add a superfinal state. They will use the Properties() method to set the
// result FST properties.
//
// We include various map versions below. One dimension of variation is
// whether the mapping mutates its input, writes to a new result FST, or is an
// on-the-fly FST. Another dimension is how we pass the mapper. We allow passing
// the mapper by pointer for cases where we need to change the state of the
// user's mapper. This is the case with the EncodeMapper, which is reused
// during decoding. We also include map versions that pass the mapper by value
// or const reference when this suffices.
// Maps the arcs and final weights of `fst` in place with the mapper object C,
// passed by pointer so that any state the mapper accumulates stays visible to
// the caller. Depending on mapper->FinalAction(), a superfinal state may be
// added; the resulting properties are taken from mapper->Properties().
template <class A, class C>
void ArcMap(MutableFst<A> *fst, C *mapper) {
  using FromArc = A;
  using ToArc = A;
  using StateId = typename FromArc::StateId;
  using Weight = typename FromArc::Weight;
  if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
    fst->SetInputSymbols(nullptr);
  }
  if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
    fst->SetOutputSymbols(nullptr);
  }
  // An FST without a start state has nothing else to map.
  if (fst->Start() == kNoStateId) return;
  const auto props = fst->Properties(kFstProperties, false);
  const auto final_action = mapper->FinalAction();
  StateId superfinal = kNoStateId;
  if (final_action == MAP_REQUIRE_SUPERFINAL) {
    superfinal = fst->AddState();
    fst->SetFinal(superfinal);
  }
  // Final weights are handed to the mapper as arcs with no destination state.
  const auto map_final = [fst, mapper](StateId s) {
    return (*mapper)(FromArc(0, 0, fst->Final(s), kNoStateId));
  };
  for (StateIterator<MutableFst<FromArc>> siter(*fst); !siter.Done();
       siter.Next()) {
    const StateId state = siter.Value();
    // Rewrite every arc leaving `state`.
    for (MutableArcIterator<MutableFst<FromArc>> aiter(fst, state);
         !aiter.Done(); aiter.Next()) {
      aiter.SetValue((*mapper)(aiter.Value()));
    }
    if (final_action == MAP_ALLOW_SUPERFINAL) {
      // The superfinal state (if already created) keeps its One() weight.
      if (state == superfinal) continue;
      auto final_arc = map_final(state);
      if (final_arc.ilabel == 0 && final_arc.olabel == 0) {
        fst->SetFinal(state, final_arc.weight);
        continue;
      }
      // Labels cannot sit on a final weight: route them through a superfinal
      // state, creating it on first use.
      if (superfinal == kNoStateId) {
        superfinal = fst->AddState();
        fst->SetFinal(superfinal);
      }
      final_arc.nextstate = superfinal;
      fst->AddArc(state, std::move(final_arc));
      fst->SetFinal(state, Weight::Zero());
    } else if (final_action == MAP_REQUIRE_SUPERFINAL) {
      if (state == superfinal) continue;
      const auto final_arc = map_final(state);
      // Only a labeled or non-Zero() result needs an arc to the superfinal.
      if (final_arc.ilabel != 0 || final_arc.olabel != 0 ||
          final_arc.weight != Weight::Zero()) {
        fst->AddArc(state, ToArc(final_arc.ilabel, final_arc.olabel,
                                 final_arc.weight, superfinal));
      }
      fst->SetFinal(state, Weight::Zero());
    } else {
      // MAP_NO_SUPERFINAL, and the fallback for any other value.
      const auto final_arc = map_final(state);
      if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
        FSTERROR() << "ArcMap: Non-zero arc labels for superfinal arc";
        fst->SetProperties(kError, kError);
      }
      fst->SetFinal(state, final_arc.weight);
    }
  }
  fst->SetProperties(mapper->Properties(props), kFstProperties);
}
// In-place ArcMap taking the mapper by value. The mapper is a local copy, so
// any state it accumulates while mapping is discarded on return.
template <class A, class C>
void ArcMap(MutableFst<A> *fst, C mapper) {
  return ArcMap(fst, &mapper);
}
// Maps an arc type A to an arc type B using mapper function object C,
// passed by pointer. This version writes the mapped input FST to an
// output MutableFst.
template <class A, class B, class C>
void ArcMap(const Fst<A> &ifst, MutableFst<B> *ofst, C *mapper) {
using FromArc = A;
using StateId = typename FromArc::StateId;
ofst->DeleteStates();
if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS) {
ofst->SetInputSymbols(ifst.InputSymbols());
} else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
ofst->SetInputSymbols(nullptr);
}
if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS) {
ofst->SetOutputSymbols(ifst.OutputSymbols());
} else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
ofst->SetOutputSymbols(nullptr);
}
const auto iprops = ifst.Properties(kCopyProperties, false);
if (ifst.Start() == kNoStateId) {
if (iprops & kError) ofst->SetProperties(kError, kError);
return;
}
const auto final_action = mapper->FinalAction();
if (std::optional<StateId> num_states = ifst.NumStatesIfKnown()) {
ofst->ReserveStates(*num_states +
(final_action == MAP_NO_SUPERFINAL ? 0 : 1));
}
// Adds all states.
for (StateIterator<Fst<A>> siter(ifst); !siter.Done(); siter.Next()) {
ofst->AddState();
}
StateId superfinal = kNoStateId;
if (final_action == MAP_REQUIRE_SUPERFINAL) {
superfinal = ofst->AddState();
ofst->SetFinal(superfinal);
}
for (StateIterator<Fst<A>> siter(ifst); !siter.Done(); siter.Next()) {
StateId s = siter.Value();
if (s == ifst.Start()) ofst->SetStart(s);
ofst->ReserveArcs(
s, ifst.NumArcs(s) + (final_action != MAP_NO_SUPERFINAL ? 1 : 0));
for (ArcIterator<Fst<A>> aiter(ifst, s); !aiter.Done(); aiter.Next()) {
ofst->AddArc(s, (*mapper)(aiter.Value()));
}
switch (final_action) {
case MAP_NO_SUPERFINAL:
default: {
B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId));
if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
FSTERROR() << "ArcMap: Non-zero arc labels for superfinal arc";
ofst->SetProperties(kError, kError);
}
ofst->SetFinal(s, final_arc.weight);
break;
}
case MAP_ALLOW_SUPERFINAL: {
B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId));
if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
// Add a superfinal state if not already done.
if (superfinal == kNoStateId) {
superfinal = ofst->AddState();
ofst->SetFinal(superfinal);
}
final_arc.nextstate = superfinal;
ofst->AddArc(s, std::move(final_arc));
ofst->SetFinal(s, B::Weight::Zero());
} else {
ofst->SetFinal(s, final_arc.weight);
}
break;
}
case MAP_REQUIRE_SUPERFINAL: {
B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId));
if (final_arc.ilabel != 0 || final_arc.olabel != 0 ||
final_arc.weight != B::Weight::Zero()) {
ofst->AddArc(s, B(final_arc.ilabel, final_arc.olabel,
final_arc.weight, superfinal));
}
ofst->SetFinal(s, B::Weight::Zero());
break;
}
}
}
const auto oprops = ofst->Properties(kFstProperties, false);
ofst->SetProperties(mapper->Properties(iprops) | oprops, kFstProperties);
}
// Copying ArcMap (ifst -> ofst) taking the mapper by value. The mapper is a
// local copy, so any state it accumulates while mapping is discarded.
template <class A, class B, class C>
void ArcMap(const Fst<A> &ifst, MutableFst<B> *ofst, C mapper) {
  return ArcMap(ifst, ofst, &mapper);
}
// Cache options for ArcMapFst. The default constructor configures no
// caching: most mappers are cheap, so recomputing arcs is preferred over
// spending memory on a cache.
struct ArcMapFstOptions : public CacheOptions {
  ArcMapFstOptions() : CacheOptions(/*gc=*/true, /*gc_limit=*/0) {}

  // Uses the caller-supplied cache options unchanged.
  explicit ArcMapFstOptions(const CacheOptions &opts) : CacheOptions(opts) {}
};
// Forward declaration: ArcMapFstImpl below befriends
// StateIterator<ArcMapFst<A, B, C>> before ArcMapFst itself is defined.
template <class A, class B, class C>
class ArcMapFst;
namespace internal {
// Implementation of delayed ArcMapFst.
//
// Output state ids mirror input state ids. When a superfinal state exists it
// occupies output id superfinal_, and input states whose ids are >=
// superfinal_ are shifted up by one (see FindIState() / FindOState()). Under
// MAP_REQUIRE_SUPERFINAL the superfinal state is output state 0 (set in
// Init()). Under MAP_ALLOW_SUPERFINAL it is allocated lazily in Expand() as
// the next output id not yet handed out (nstates_).
template <class A, class B, class C>
class ArcMapFstImpl : public CacheImpl<B> {
public:
using Arc = B;
using StateId = typename Arc::StateId;
using Weight = typename Arc::Weight;
using FstImpl<B>::SetType;
using FstImpl<B>::SetProperties;
using FstImpl<B>::SetInputSymbols;
using FstImpl<B>::SetOutputSymbols;
using CacheImpl<B>::EmplaceArc;
using CacheImpl<B>::HasArcs;
using CacheImpl<B>::HasFinal;
using CacheImpl<B>::HasStart;
using CacheImpl<B>::PushArc;
using CacheImpl<B>::SetArcs;
using CacheImpl<B>::SetFinal;
using CacheImpl<B>::SetStart;
friend class StateIterator<ArcMapFst<A, B, C>>;
// Copies `mapper`; the copy is owned (and deleted) by this impl.
ArcMapFstImpl(const Fst<A> &fst, const C &mapper,
const ArcMapFstOptions &opts)
: CacheImpl<B>(opts),
fst_(fst.Copy()),
mapper_(new C(mapper)),
own_mapper_(true),
superfinal_(kNoStateId),
nstates_(0) {
Init();
}
// Borrows `mapper` without taking ownership; the caller must keep it alive
// for as long as this impl may still expand states or compute final weights.
ArcMapFstImpl(const Fst<A> &fst, C *mapper, const ArcMapFstOptions &opts)
: CacheImpl<B>(opts),
fst_(fst.Copy()),
mapper_(mapper),
own_mapper_(false),
superfinal_(kNoStateId),
nstates_(0) {
Init();
}
// Copy constructor: takes a "safe" copy of the input FST (Copy(true)) and an
// owned copy of the mapper, resets the superfinal bookkeeping and re-runs
// Init().
ArcMapFstImpl(const ArcMapFstImpl<A, B, C> &impl)
: CacheImpl<B>(impl),
fst_(impl.fst_->Copy(true)),
mapper_(new C(*impl.mapper_)),
own_mapper_(true),
superfinal_(kNoStateId),
nstates_(0) {
Init();
}
// Deletes the mapper only when this impl owns it.
~ArcMapFstImpl() override {
if (own_mapper_) delete mapper_;
}
// Lazily caches the output id of the input FST's start state.
StateId Start() {
if (!HasStart()) SetStart(FindOState(fst_->Start()));
return CacheImpl<B>::Start();
}
// Computes (once) and returns the final weight of output state `s`.
Weight Final(StateId s) {
if (!HasFinal(s)) {
switch (final_action_) {
case MAP_NO_SUPERFINAL:
default: {
// Mapped final weights must stay label-free; anything else is an error.
const auto final_arc =
(*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId));
if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
FSTERROR() << "ArcMapFst: Non-zero arc labels for superfinal arc";
SetProperties(kError, kError);
}
SetFinal(s, final_arc.weight);
break;
}
case MAP_ALLOW_SUPERFINAL: {
if (s == superfinal_) {
SetFinal(s);
} else {
// A labeled result becomes Zero() here; Expand() adds the
// corresponding arc to the superfinal state.
const auto final_arc =
(*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId));
if (final_arc.ilabel == 0 && final_arc.olabel == 0) {
SetFinal(s, final_arc.weight);
} else {
SetFinal(s, Weight::Zero());
}
}
break;
}
case MAP_REQUIRE_SUPERFINAL: {
// Only the superfinal state is final.
SetFinal(s, s == superfinal_ ? Weight::One() : Weight::Zero());
break;
}
}
}
return CacheImpl<B>::Final(s);
}
// The following accessors expand state `s` on first use.
size_t NumArcs(StateId s) {
if (!HasArcs(s)) Expand(s);
return CacheImpl<B>::NumArcs(s);
}
size_t NumInputEpsilons(StateId s) {
if (!HasArcs(s)) Expand(s);
return CacheImpl<B>::NumInputEpsilons(s);
}
size_t NumOutputEpsilons(StateId s) {
if (!HasArcs(s)) Expand(s);
return CacheImpl<B>::NumOutputEpsilons(s);
}
uint64_t Properties() const override { return Properties(kFstProperties); }
// Sets error if found, and returns other FST impl properties.
// An error is reported when either the input FST or the mapper (through
// mapper_->Properties(0)) signals kError.
uint64_t Properties(uint64_t mask) const override {
if ((mask & kError) && (fst_->Properties(kError, false) ||
(mapper_->Properties(0) & kError))) {
SetProperties(kError, kError);
}
return FstImpl<Arc>::Properties(mask);
}
void InitArcIterator(StateId s, ArcIteratorData<B> *data) {
if (!HasArcs(s)) Expand(s);
CacheImpl<B>::InitArcIterator(s, data);
}
// Maps the arcs leaving output state `s` into the cache and, when needed,
// appends the arc to the superfinal state.
void Expand(StateId s) {
// Add exiting arcs.
// The superfinal state has no outgoing arcs.
if (s == superfinal_) {
SetArcs(s);
return;
}
for (ArcIterator<Fst<A>> aiter(*fst_, FindIState(s)); !aiter.Done();
aiter.Next()) {
auto aarc = aiter.Value();
// Translate the destination to an output id before mapping.
aarc.nextstate = FindOState(aarc.nextstate);
PushArc(s, (*mapper_)(aarc));
}
// Check for superfinal arcs.
// Skipped when a non-Zero() final weight is already cached for `s`; note
// that calling Final(s) here computes and caches it if it is not cached.
if (!HasFinal(s) || Final(s) == Weight::Zero()) {
switch (final_action_) {
case MAP_NO_SUPERFINAL:
default:
break;
case MAP_ALLOW_SUPERFINAL: {
auto final_arc =
(*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId));
if (final_arc.ilabel != 0 || final_arc.olabel != 0) {
// Allocate the superfinal state at the first output id not yet used.
if (superfinal_ == kNoStateId) superfinal_ = nstates_++;
final_arc.nextstate = superfinal_;
PushArc(s, std::move(final_arc));
}
break;
}
case MAP_REQUIRE_SUPERFINAL: {
const auto final_arc =
(*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId));
if (final_arc.ilabel != 0 || final_arc.olabel != 0 ||
final_arc.weight != B::Weight::Zero()) {
EmplaceArc(s, final_arc.ilabel, final_arc.olabel, final_arc.weight,
superfinal_);
}
break;
}
}
}
SetArcs(s);
}
private:
// Sets the type name, symbol tables, effective final action and properties.
// For an input FST with no start state, the final action is forced to
// MAP_NO_SUPERFINAL so that no superfinal state is introduced.
void Init() {
SetType("map");
if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS) {
SetInputSymbols(fst_->InputSymbols());
} else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
SetInputSymbols(nullptr);
}
if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS) {
SetOutputSymbols(fst_->OutputSymbols());
} else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) {
SetOutputSymbols(nullptr);
}
if (fst_->Start() == kNoStateId) {
final_action_ = MAP_NO_SUPERFINAL;
SetProperties(kNullProperties);
} else {
final_action_ = mapper_->FinalAction();
uint64_t props = fst_->Properties(kCopyProperties, false);
SetProperties(mapper_->Properties(props));
// The required superfinal state takes output id 0.
if (final_action_ == MAP_REQUIRE_SUPERFINAL) superfinal_ = 0;
}
}
// Maps from output state to input state.
// Not meaningful for s == superfinal_; callers check for it beforehand.
StateId FindIState(StateId s) {
if (superfinal_ == kNoStateId || s < superfinal_) {
return s;
} else {
return s - 1;
}
}
// Maps from input state to output state.
// Also grows nstates_ so that it stays one past the largest output id seen.
StateId FindOState(StateId is) {
auto os = is;
if (!(superfinal_ == kNoStateId || is < superfinal_)) ++os;
if (os >= nstates_) nstates_ = os + 1;
return os;
}
// Private copy of the input FST.
std::unique_ptr<const Fst<A>> fst_;
// Mapper; deleted in the destructor only if own_mapper_ is true.
C *mapper_;
const bool own_mapper_;
// Effective final action (MAP_NO_SUPERFINAL for an empty input FST).
MapFinalAction final_action_;
// Output id of the superfinal state, or kNoStateId if there is none (yet).
StateId superfinal_;
// One past the largest output state id handed out so far.
StateId nstates_;
};
} // namespace internal
// Delayed (on-the-fly) FST that maps an FST over arc type A to one over arc
// type B with the mapper C. Arcs and final weights are computed on demand.
template <class A, class B, class C>
class ArcMapFst : public ImplToFst<internal::ArcMapFstImpl<A, B, C>> {
 public:
  using Arc = B;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;
  using Store = DefaultCacheStore<B>;
  using State = typename Store::State;
  using Impl = internal::ArcMapFstImpl<A, B, C>;

  friend class ArcIterator<ArcMapFst<A, B, C>>;
  friend class StateIterator<ArcMapFst<A, B, C>>;

  // Stores a private copy of `mapper`.
  explicit ArcMapFst(const Fst<A> &fst, const C &mapper = C(),
                     const ArcMapFstOptions &opts = ArcMapFstOptions())
      : ImplToFst<Impl>(std::make_shared<Impl>(fst, mapper, opts)) {}

  // Uses `mapper` in place without copying it, so any state it accumulates
  // is visible to the caller. `mapper` must outlive this FST.
  ArcMapFst(const Fst<A> &fst, C *mapper,
            const ArcMapFstOptions &opts = ArcMapFstOptions())
      : ImplToFst<Impl>(std::make_shared<Impl>(fst, mapper, opts)) {}

  // Copy constructor; see Fst<>::Copy() for the meaning of `safe`.
  ArcMapFst(const ArcMapFst &fst, bool safe = false)
      : ImplToFst<Impl>(fst, safe) {}

  // Returns a heap-allocated copy; see Fst<>::Copy() for the meaning of `safe`.
  ArcMapFst *Copy(bool safe = false) const override {
    return new ArcMapFst<A, B, C>(*this, safe);
  }

  inline void InitStateIterator(StateIteratorData<B> *data) const override;

  // Expands state `s` if needed, then exposes its cached arcs.
  void InitArcIterator(StateId s, ArcIteratorData<B> *data) const override {
    Impl *const impl = GetMutableImpl();
    impl->InitArcIterator(s, data);
  }

 protected:
  using ImplToFst<Impl>::GetImpl;
  using ImplToFst<Impl>::GetMutableImpl;

 private:
  ArcMapFst &operator=(const ArcMapFst &) = delete;
};
// Specialization for ArcMapFst: walks the input FST's states and, when the
// map needs a superfinal state, yields one extra state id at the end.
//
// This may be derived from.
template <class A, class B, class C>
class StateIterator<ArcMapFst<A, B, C>> : public StateIteratorBase<B> {
 public:
  using StateId = typename B::StateId;

  explicit StateIterator(const ArcMapFst<A, B, C> &fst)
      : impl_(fst.GetImpl()),
        siter_(*impl_->fst_),
        s_(0),
        superfinal_(impl_->final_action_ == MAP_REQUIRE_SUPERFINAL) {
    CheckSuperfinal();
  }

  bool Done() const final { return siter_.Done() && !superfinal_; }

  StateId Value() const final { return s_; }

  void Next() final {
    ++s_;
    if (!siter_.Done()) {
      siter_.Next();
      CheckSuperfinal();
    } else if (superfinal_) {
      superfinal_ = false;
    }
  }

  void Reset() final {
    s_ = 0;
    siter_.Reset();
    superfinal_ = impl_->final_action_ == MAP_REQUIRE_SUPERFINAL;
    CheckSuperfinal();
  }

 private:
  // Under MAP_ALLOW_SUPERFINAL, records that a superfinal state is needed as
  // soon as the current input state's mapped final weight carries labels.
  void CheckSuperfinal() {
    if (impl_->final_action_ != MAP_ALLOW_SUPERFINAL || superfinal_) return;
    if (siter_.Done()) return;
    // Query the input FST with the input state id reported by siter_. s_ is
    // an output-state counter, which matches only when input ids are dense.
    const auto final_arc = (*impl_->mapper_)(
        A(0, 0, impl_->fst_->Final(siter_.Value()), kNoStateId));
    if (final_arc.ilabel != 0 || final_arc.olabel != 0) superfinal_ = true;
  }

  const internal::ArcMapFstImpl<A, B, C> *impl_;
  StateIterator<Fst<A>> siter_;
  StateId s_;
  bool superfinal_;  // True if there is a superfinal state and not done.
};
// Specialization for ArcMapFst: a cache arc iterator that expands the state
// on demand before its arcs are read.
template <class A, class B, class C>
class ArcIterator<ArcMapFst<A, B, C>>
    : public CacheArcIterator<ArcMapFst<A, B, C>> {
 public:
  using StateId = typename A::StateId;

  ArcIterator(const ArcMapFst<A, B, C> &fst, StateId s)
      : CacheArcIterator<ArcMapFst<A, B, C>>(fst.GetMutableImpl(), s) {
    // Nothing to do when the arcs of `s` are already cached.
    if (fst.GetImpl()->HasArcs(s)) return;
    fst.GetMutableImpl()->Expand(s);
  }
};
// Hands out the ArcMapFst-specific state iterator through the generic
// StateIteratorData interface.
template <class A, class B, class C>
inline void ArcMapFst<A, B, C>::InitStateIterator(
    StateIteratorData<B> *data) const {
  using Iterator = StateIterator<ArcMapFst<A, B, C>>;
  data->base = std::make_unique<Iterator>(*this);
}
// CTAD deduction guides
// These let callers write ArcMapFst(fst, mapper) without spelling out the
// types: A and B are taken from ArcMapper::FromArc and ArcMapper::ToArc.
// This guide covers the constructor that copies the mapper
// (const ArcMapper &).
template <class ArcMapper>
ArcMapFst(const Fst<typename ArcMapper::FromArc> &, const ArcMapper &)
-> ArcMapFst<typename ArcMapper::FromArc, typename ArcMapper::ToArc,
ArcMapper>;
// As above, but using the ArcMapFst(..., ArcMapper *) constructor, which
// uses the caller's mapper object instead of copying it.
template <class ArcMapper>
ArcMapFst(const Fst<typename ArcMapper::FromArc> &, ArcMapper *)
-> ArcMapFst<typename ArcMapper::FromArc, typename ArcMapper::ToArc,
ArcMapper>;
// Utility Mappers.

// Mapper that passes every arc and final weight through unchanged, copies
// both symbol tables and reports the input properties as-is.
template <class A>
class IdentityArcMapper {
 public:
  using FromArc = A;
  using ToArc = A;

  constexpr ToArc operator()(const FromArc &arc) const { return arc; }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  // Nothing changes, so neither do the properties.
  constexpr uint64_t Properties(uint64_t props) const { return props; }
};
// Mapper that replaces every input label with epsilon (0). The input symbol
// table is cleared; the output symbol table is kept.
template <class A>
class InputEpsilonMapper {
 public:
  using FromArc = A;
  using ToArc = A;

  constexpr ToArc operator()(const FromArc &arc) const {
    return ToArc(/*ilabel=*/0, arc.olabel, arc.weight, arc.nextstate);
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_CLEAR_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  // Keeps only the kSetArcProperties subset of the input properties, then
  // asserts the all-epsilon input side, which is trivially input-sorted.
  constexpr uint64_t Properties(uint64_t props) const {
    return kIEpsilons | kILabelSorted | (props & kSetArcProperties);
  }
};
// Mapper that replaces every output label with epsilon (0). The output symbol
// table is cleared; the input symbol table is kept.
template <class A>
class OutputEpsilonMapper {
 public:
  using FromArc = A;
  using ToArc = A;

  constexpr ToArc operator()(const FromArc &arc) const {
    return ToArc(arc.ilabel, /*olabel=*/0, arc.weight, arc.nextstate);
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_CLEAR_SYMBOLS;
  }

  // Keeps only the kSetArcProperties subset of the input properties, then
  // asserts the all-epsilon output side, which is trivially output-sorted.
  constexpr uint64_t Properties(uint64_t props) const {
    return kOEpsilons | kOLabelSorted | (props & kSetArcProperties);
  }
};
// Mapper that leaves ordinary arcs unchanged and redirects every final state
// to a single superfinal state. The connecting arcs carry `final_label` on
// both sides (epsilon by default).
template <class A>
class SuperFinalMapper {
 public:
  using FromArc = A;
  using ToArc = A;
  using Label = typename FromArc::Label;
  using Weight = typename FromArc::Weight;

  // `final_label` labels both sides of each arc into the superfinal state.
  constexpr explicit SuperFinalMapper(Label final_label = 0)
      : final_label_(final_label) {}

  ToArc operator()(const FromArc &arc) const {
    // Only a non-Zero() final weight (an arc with no destination) is relabeled.
    const bool nonzero_final =
        arc.nextstate == kNoStateId && arc.weight != Weight::Zero();
    return nonzero_final
               ? ToArc(final_label_, final_label_, arc.weight, kNoStateId)
               : arc;
  }

  constexpr MapFinalAction FinalAction() const {
    return MAP_REQUIRE_SUPERFINAL;
  }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  uint64_t Properties(uint64_t props) const {
    uint64_t kept = kAddSuperFinalProperties;
    // A non-epsilon final label also invalidates label-dependent properties.
    if (final_label_ != 0) {
      kept &= kILabelInvariantProperties & kOLabelInvariantProperties;
    }
    return props & kept;
  }

 private:
  const Label final_label_;
};
// Mapper that keeps labels and next states and converts each weight with
// `Converter`. Unless a converter is given explicitly, the default
// WeightConvert<FromWeight, ToWeight> must have a specialization for the two
// weight types.
template <class A, class B,
          class C = WeightConvert<typename A::Weight, typename B::Weight>>
class WeightConvertMapper {
 public:
  using FromArc = A;
  using ToArc = B;
  using Converter = C;
  using FromWeight = typename FromArc::Weight;
  using ToWeight = typename ToArc::Weight;

  constexpr explicit WeightConvertMapper(const Converter &c = Converter())
      : convert_weight_(c) {}

  constexpr ToArc operator()(const FromArc &arc) const {
    return ToArc(arc.ilabel, arc.olabel,
                 /*weight=*/convert_weight_(arc.weight), arc.nextstate);
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  // Labels and topology are untouched, so the properties carry over.
  constexpr uint64_t Properties(uint64_t props) const { return props; }

 private:
  const Converter convert_weight_;
};
// Non-precision-changing weight conversions (the source and target weights
// have the same precision); consider using the more efficient Cast method
// instead.
using StdToLogMapper = WeightConvertMapper<StdArc, LogArc>;
using LogToStdMapper = WeightConvertMapper<LogArc, StdArc>;
// Precision-changing weight conversions, to or from Log64Arc.
using StdToLog64Mapper = WeightConvertMapper<StdArc, Log64Arc>;
using LogToLog64Mapper = WeightConvertMapper<LogArc, Log64Arc>;
using Log64ToStdMapper = WeightConvertMapper<Log64Arc, StdArc>;
using Log64ToLogMapper = WeightConvertMapper<Log64Arc, LogArc>;
// Mapper from A to GallicArc<A, G>. The output label moves into the string
// component of the Gallic weight, and the resulting arc repeats the input
// label on both sides.
template <class A, GallicType G = GALLIC_LEFT>
class ToGallicMapper {
 public:
  using FromArc = A;
  using ToArc = GallicArc<A, G>;
  using SW = StringWeight<typename A::Label, GallicStringType(G)>;
  using AW = typename FromArc::Weight;
  using GW = typename ToArc::Weight;

  ToArc operator()(const FromArc &arc) const {
    if (arc.nextstate == kNoStateId) {
      // Final weight: a non-Zero() weight pairs with the empty string, and
      // Zero() maps to the Gallic Zero().
      if (arc.weight != AW::Zero()) {
        return ToArc(0, 0, GW(SW::One(), arc.weight), kNoStateId);
      }
      return ToArc(0, 0, GW::Zero(), kNoStateId);
    }
    // Regular arc: an epsilon output label becomes the empty string.
    const SW output = arc.olabel == 0 ? SW::One() : SW(arc.olabel);
    return ToArc(arc.ilabel, arc.ilabel, GW(output, arc.weight),
                 arc.nextstate);
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_CLEAR_SYMBOLS;
  }

  // The result is projected onto the input side and its weights change.
  uint64_t Properties(uint64_t props) const {
    return ProjectProperties(props, true) & kWeightInvariantProperties;
  }
};
// Mapper from GallicArc<A, G> back to A. The string part of each Gallic
// weight must hold at most one label, which becomes the output label. A final
// weight whose string is a label is returned as a labeled final arc
// (MAP_ALLOW_SUPERFINAL) whose input label is `superfinal_label`.
// Unrepresentable weights are reported with FSTERROR and the kError property.
template <class A, GallicType G = GALLIC_LEFT>
class FromGallicMapper {
 public:
  using FromArc = GallicArc<A, G>;
  using ToArc = A;
  using Label = typename ToArc::Label;
  using AW = typename ToArc::Weight;
  using GW = typename FromArc::Weight;

  explicit FromGallicMapper(Label superfinal_label = 0)
      : superfinal_label_(superfinal_label) {}

  ToArc operator()(const FromArc &arc) const {
    // A Zero() final weight ('super-non-final' arc) stays Zero().
    if (arc.nextstate == kNoStateId && arc.weight == GW::Zero()) {
      return A(arc.ilabel, 0, AW::Zero(), kNoStateId);
    }
    Label label = kNoLabel;
    AW weight = AW::Zero();
    const bool representable = Extract(arc.weight, &weight, &label);
    if (!representable || arc.ilabel != arc.olabel) {
      FSTERROR() << "FromGallicMapper: Unrepresentable weight: " << arc.weight
                 << " for arc with ilabel = " << arc.ilabel
                 << ", olabel = " << arc.olabel
                 << ", nextstate = " << arc.nextstate;
      error_ = true;
    }
    // A labeled final weight arriving with an epsilon input label gets
    // superfinal_label_ as its input label.
    const bool labeled_final =
        arc.nextstate == kNoStateId && arc.ilabel == 0 && label != 0;
    return ToArc(labeled_final ? superfinal_label_ : arc.ilabel, label, weight,
                 arc.nextstate);
  }

  constexpr MapFinalAction FinalAction() const { return MAP_ALLOW_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_CLEAR_SYMBOLS;
  }

  uint64_t Properties(uint64_t inprops) const {
    const uint64_t outprops = inprops & kOLabelInvariantProperties &
                              kWeightInvariantProperties &
                              kAddSuperFinalProperties;
    return error_ ? (outprops | kError) : outprops;
  }

 private:
  // Splits a non-union Gallic weight into (label, weight). Fails when the
  // string part has more than one label or is the infinity/bad string.
  template <GallicType GT>
  static bool Extract(const GallicWeight<Label, AW, GT> &gallic_weight,
                      typename A::Weight *weight, typename A::Label *label) {
    using StrW = StringWeight<Label, GallicStringType(GT)>;
    const StrW &string_part = gallic_weight.Value1();
    if (string_part.Size() > 1) return false;
    Label first = 0;
    if (string_part.Size() == 1) {
      typename StrW::Iterator it(string_part);
      first = it.Value();
    }
    if (first == kStringInfinity || first == kStringBad) return false;
    *label = first;
    *weight = gallic_weight.Value2();
    return true;
  }

  // Union (GALLIC) variant: an empty union is Zero(); a single element is
  // handled like GALLIC_RESTRICT; more than one element is unrepresentable.
  static bool Extract(const GallicWeight<Label, AW, GALLIC> &gallic_weight,
                      typename A::Weight *weight, typename A::Label *label) {
    switch (gallic_weight.Size()) {
      case 0:
        *label = 0;
        *weight = A::Weight::Zero();
        return true;
      case 1:
        return Extract<GALLIC_RESTRICT>(gallic_weight.Back(), weight, label);
      default:
        return false;
    }
  }

  const Label superfinal_label_;
  mutable bool error_ = false;
};
// Mapper from GallicArc<A, G> to A that replaces each distinct non-empty
// output string with a fresh output label (1, 2, ...). As a side effect it
// builds, in the FST passed to its constructor, a transducer that maps each
// fresh label back to its label string. When that FST has output symbols, the
// fresh labels also get input symbols named after the joined strings.
template <class A, GallicType G = GALLIC_LEFT>
class GallicToNewSymbolsMapper {
 public:
  using FromArc = GallicArc<A, G>;
  using ToArc = A;
  using Label = typename ToArc::Label;
  using StateId = typename ToArc::StateId;
  using AW = typename ToArc::Weight;
  using GW = typename FromArc::Weight;
  using SW = StringWeight<Label, GallicStringType(G)>;

  // Clears `fst` and reinitializes it as a single state that is both start and
  // final. `fst` is written to while mapping, so it must outlive the mapper.
  explicit GallicToNewSymbolsMapper(MutableFst<ToArc> *fst)
      : fst_(fst),
        lmax_(0),
        osymbols_(fst->OutputSymbols()),
        isymbols_(nullptr),
        error_(false) {
    fst_->DeleteStates();
    state_ = fst_->AddState();
    fst_->SetStart(state_);
    fst_->SetFinal(state_);
    if (osymbols_) {
      // SetInputSymbols() stores its own copy of the table (the live table is
      // fetched back below via MutableInputSymbols()), so pass a local table;
      // passing `new SymbolTable(...)` leaked that allocation.
      const SymbolTable isymbols(osymbols_->Name() + "_from_gallic");
      fst_->SetInputSymbols(&isymbols);
      isymbols_ = fst_->MutableInputSymbols();
      const int64_t zero = 0;
      isymbols_->AddSymbol(osymbols_->Find(zero), 0);
    } else {
      fst_->SetInputSymbols(nullptr);
    }
  }

  ToArc operator()(const FromArc &arc) {
    // Super-non-final arc.
    if (arc.nextstate == kNoStateId && arc.weight == GW::Zero()) {
      return ToArc(arc.ilabel, 0, AW::Zero(), kNoStateId);
    }
    const SW &w1 = arc.weight.Value1();
    const AW &w2 = arc.weight.Value2();
    Label l;
    if (w1.Size() == 0) {
      l = 0;
    } else if (auto [it, inserted] = map_.emplace(w1, kNoLabel); !inserted) {
      l = it->second;
    } else {
      l = ++lmax_;
      it->second = l;
      // Spell the string out as a chain of arcs from state_ back to state_.
      // The first arc carries the new label as input, the rest are epsilon.
      StringWeightIterator<SW> iter1(w1);
      std::string symbol;
      StateId p = state_;
      for (size_t i = 0; i < w1.Size(); ++i, iter1.Next()) {
        const StateId n = i + 1 == w1.Size() ? state_ : fst_->AddState();
        fst_->AddArc(p, ToArc(i ? 0 : l, iter1.Value(), n));
        if (isymbols_) {
          if (i) symbol += "_";
          symbol += osymbols_->Find(iter1.Value());
        }
        p = n;
      }
      if (isymbols_) isymbols_->AddSymbol(symbol, l);
    }
    if (l == kStringInfinity || l == kStringBad || arc.ilabel != arc.olabel) {
      FSTERROR() << "GallicToNewSymbolsMapper: Unrepresentable weight: " << l;
      error_ = true;
    }
    return ToArc(arc.ilabel, l, w2, arc.nextstate);
  }

  constexpr MapFinalAction FinalAction() const { return MAP_ALLOW_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_CLEAR_SYMBOLS;
  }

  uint64_t Properties(uint64_t inprops) const {
    uint64_t outprops = inprops & kOLabelInvariantProperties &
                        kWeightInvariantProperties & kAddSuperFinalProperties;
    if (error_) outprops |= kError;
    return outprops;
  }

 private:
  // Hashes string weights for map_.
  class StringKey {
   public:
    size_t operator()(const SW &x) const { return x.Hash(); }
  };

  using Map = std::unordered_map<SW, Label, StringKey>;

  MutableFst<ToArc> *fst_;
  Map map_;         // Output string -> fresh label.
  Label lmax_;      // Largest fresh label handed out so far.
  StateId state_;   // The start/final state of *fst_.
  const SymbolTable *osymbols_;
  SymbolTable *isymbols_;
  mutable bool error_;
};
// TODO(kbg): Add common base class for those mappers which do nothing except
// mutate their weights.

// Mapper that Plus()-es a constant onto every non-Zero() weight. Zero()
// weights (e.g. those of non-final states) pass through unchanged.
template <class A>
class PlusMapper {
 public:
  using FromArc = A;
  using ToArc = A;
  using Weight = typename FromArc::Weight;

  constexpr explicit PlusMapper(Weight weight) : weight_(std::move(weight)) {}

  ToArc operator()(const FromArc &arc) const {
    return arc.weight == Weight::Zero()
               ? arc
               : ToArc(arc.ilabel, arc.olabel, Plus(arc.weight, weight_),
                       arc.nextstate);
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  // Only weights change.
  constexpr uint64_t Properties(uint64_t props) const {
    return props & kWeightInvariantProperties;
  }

 private:
  const Weight weight_;
};
// Mapper that right-multiplies every non-Zero() weight by a constant, i.e.
// Times(weight, constant). Zero() weights (e.g. those of non-final states)
// pass through unchanged.
template <class A>
class TimesMapper {
 public:
  using FromArc = A;
  using ToArc = A;
  using Weight = typename FromArc::Weight;

  constexpr explicit TimesMapper(Weight weight) : weight_(std::move(weight)) {}

  ToArc operator()(const FromArc &arc) const {
    return arc.weight == Weight::Zero()
               ? arc
               : ToArc(arc.ilabel, arc.olabel, Times(arc.weight, weight_),
                       arc.nextstate);
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  // Only weights change.
  constexpr uint64_t Properties(uint64_t props) const {
    return props & kWeightInvariantProperties;
  }

 private:
  const Weight weight_;
};
// Mapper that raises every arc weight to a fixed exponent via Power().
// The exponent is held as a double. If the weight type provides a
// floating-point Power() overload, that overload is chosen. Otherwise the
// double (53 bits of integer precision) is implicitly converted to size_t,
// and the generic iterated-multiplication Power() is used.
// Unlike PlusMapper/TimesMapper, Zero() weights are not special-cased.
// NOTE(review): in the size_t fallback, a negative or out-of-range exponent
// makes the double -> size_t conversion undefined, and a fractional one is
// truncated. Such exponents appear safe only for weights with a
// floating-point Power() — confirm for the weight types in use.
template <class A>
class PowerMapper {
 public:
  using FromArc = A;
  using ToArc = A;
  using Weight = typename FromArc::Weight;

  // `power` is the exponent applied to each arc weight.
  explicit PowerMapper(double power) : power_(power) {}

  ToArc operator()(const FromArc &arc) const {
    ToArc mapped(arc.ilabel, arc.olabel, Power(arc.weight, power_),
                 arc.nextstate);
    return mapped;
  }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  // Only properties that do not depend on weights are guaranteed to survive.
  constexpr uint64_t Properties(uint64_t props) const {
    return props & kWeightInvariantProperties;
  }

 private:
  const double power_;  // Exponent passed to Power().
};
// Mapper that replaces each non-Zero() arc weight w with its reciprocal,
// computed as Divide(Weight::One(), w). Arcs weighted Zero(), which has no
// inverse, are returned unchanged. Final weights must stay final weights
// (MAP_NO_SUPERFINAL), and both symbol tables are copied.
template <class A>
class InvertWeightMapper {
 public:
  using FromArc = A;
  using ToArc = A;
  using Weight = typename FromArc::Weight;

  ToArc operator()(const FromArc &arc) const {
    return arc.weight == Weight::Zero()
               ? arc
               : ToArc(arc.ilabel, arc.olabel,
                       Divide(Weight::One(), arc.weight), arc.nextstate);
  }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  // Only properties that do not depend on weights are guaranteed to survive.
  constexpr uint64_t Properties(uint64_t props) const {
    return props & kWeightInvariantProperties;
  }
};
// Mapper that strips weights. Every arc weight that is not FromWeight::Zero()
// becomes ToWeight::One(), and Zero() weights become ToWeight::Zero(). Labels
// and destination states are kept, and the arc type may change from A to B.
template <class A, class B = A>
class RmWeightMapper {
 public:
  using FromArc = A;
  using ToArc = B;
  using FromWeight = typename FromArc::Weight;
  using ToWeight = typename ToArc::Weight;

  ToArc operator()(const FromArc &arc) const {
    // Only the zero / non-zero distinction of the source weight survives.
    const bool nonzero = arc.weight != FromWeight::Zero();
    return ToArc(arc.ilabel, arc.olabel,
                 nonzero ? ToWeight::One() : ToWeight::Zero(), arc.nextstate);
  }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  // Weight-independent properties are kept. Since every resulting weight is
  // One() or Zero(), the result is also flagged kUnweighted.
  constexpr uint64_t Properties(uint64_t props) const {
    return kUnweighted | (props & kWeightInvariantProperties);
  }
};
// Mapper that quantizes every arc weight with Weight::Quantize(delta).
// Labels and destination states are kept. When B differs from A, the
// quantized source weight is passed straight to ToArc's constructor.
// NOTE(review): presumably it must convert to ToArc::Weight — confirm.
template <class A, class B = A>
class QuantizeMapper {
 public:
  using FromArc = A;
  using ToArc = B;
  using FromWeight = typename FromArc::Weight;
  using ToWeight = typename ToArc::Weight;

  // Quantizes with the default delta, kDelta.
  QuantizeMapper() : QuantizeMapper(kDelta) {}

  // Quantizes with the caller-supplied delta `d`.
  explicit QuantizeMapper(float d) : delta_(d) {}

  ToArc operator()(const FromArc &arc) const {
    return ToArc(arc.ilabel, arc.olabel, arc.weight.Quantize(delta_),
                 arc.nextstate);
  }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  // Only properties that do not depend on weights are guaranteed to survive.
  constexpr uint64_t Properties(uint64_t props) const {
    return props & kWeightInvariantProperties;
  }

 private:
  const float delta_;  // Quantization step handed to Quantize().
};
// Mapper that converts arcs of type A into arcs of type B by reversing each
// weight with Weight::Reverse(). Labels and destination states are copied
// as-is. The static_asserts below require:
//
//   B::Weight  == A::Weight::ReverseWeight
//   B::Label   == A::Label
//   B::StateId == A::StateId
template <class A, class B>
class ReverseWeightMapper {
 public:
  using FromArc = A;
  using ToArc = B;

  static_assert(std::is_same_v<typename ToArc::Weight,
                               typename FromArc::Weight::ReverseWeight>,
                "ToArc::Weight must be FromArc::Weight::ReverseWeight");
  static_assert(std::is_same_v<typename ToArc::Label, typename FromArc::Label>,
                "ToArc::Label must be FromArc::Label");
  static_assert(std::is_same_v<typename ToArc::StateId,
                               typename FromArc::StateId>,
                "ToArc::StateId must be FromArc::StateId");

  constexpr ToArc operator()(const FromArc &arc) const {
    return ToArc(arc.ilabel, arc.olabel, arc.weight.Reverse(),
                 arc.nextstate);
  }

  constexpr MapSymbolsAction InputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapSymbolsAction OutputSymbolsAction() const {
    return MAP_COPY_SYMBOLS;
  }

  constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  // All input properties are reported unchanged.
  constexpr uint64_t Properties(uint64_t props) const { return props; }
};
} // namespace fst
#endif // FST_ARC_MAP_H_