|
|
// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Classes to factor weights in an FST.
#ifndef FST_FACTOR_WEIGHT_H_
#define FST_FACTOR_WEIGHT_H_
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <fst/log.h>
#include <fst/cache.h>
#include <fst/fst.h>
#include <fst/impl-to-fst.h>
#include <fst/properties.h>
#include <fst/string-weight.h>
#include <fst/union-weight.h>
#include <fst/weight.h>
#include <unordered_map>
namespace fst {
inline constexpr uint8_t kFactorFinalWeights = 0x01; inline constexpr uint8_t kFactorArcWeights = 0x02;
template <class Arc> struct FactorWeightOptions : CacheOptions { using Label = typename Arc::Label;
float delta; uint8_t mode; // Factor arc weights and/or final weights.
Label final_ilabel; // Input label of arc when factoring final weights.
Label final_olabel; // Output label of arc when factoring final weights.
bool increment_final_ilabel; // When factoring final w' results in > 1 arcs
bool increment_final_olabel; // at state, increment labels to make distinct?
explicit FactorWeightOptions(const CacheOptions &opts, float delta = kDelta, uint8_t mode = kFactorArcWeights | kFactorFinalWeights, Label final_ilabel = 0, Label final_olabel = 0, bool increment_final_ilabel = false, bool increment_final_olabel = false) : CacheOptions(opts), delta(delta), mode(mode), final_ilabel(final_ilabel), final_olabel(final_olabel), increment_final_ilabel(increment_final_ilabel), increment_final_olabel(increment_final_olabel) {}
explicit FactorWeightOptions(float delta = kDelta, uint8_t mode = kFactorArcWeights | kFactorFinalWeights, Label final_ilabel = 0, Label final_olabel = 0, bool increment_final_ilabel = false, bool increment_final_olabel = false) : delta(delta), mode(mode), final_ilabel(final_ilabel), final_olabel(final_olabel), increment_final_ilabel(increment_final_ilabel), increment_final_olabel(increment_final_olabel) {} };
// A factor iterator takes as argument a weight w and returns a sequence of
// pairs of weights (xi, yi) such that the sum of the products xi times yi is
// equal to w. If w is fully factored, the iterator should return nothing.
//
// template <class W>
// class FactorIterator {
// public:
// explicit FactorIterator(W w);
//
// bool Done() const;
//
// void Next();
//
// std::pair<W, W> Value() const;
//
// void Reset();
// }
// Factors trivially.
template <class W> class IdentityFactor { public: explicit IdentityFactor(const W &weight) {}
bool Done() const { return true; }
void Next() {}
std::pair<W, W> Value() const { return std::make_pair(W::One(), W::One()); }
void Reset() {} };
// Factor the Fst to unfold it as needed so that every two paths leading to the
// same state have the same weight. Requires applying only to arc weights
// (FactorWeightOptions::mode == kFactorArcWeights).
template <class W> class OneFactor { public: explicit OneFactor(const W &w) : weight_(w), done_(w == W::One()) {}
bool Done() const { return done_; }
void Next() { done_ = true; }
std::pair<W, W> Value() const { return std::make_pair(W::One(), weight_); }
void Reset() { done_ = weight_ == W::One(); }
private: W weight_; bool done_; };
// Factors a StringWeight w as 'ab' where 'a' is a label.
template <typename Label, StringType S = STRING_LEFT> class StringFactor { public: explicit StringFactor(const StringWeight<Label, S> &weight) : weight_(weight), done_(weight.Size() <= 1) {}
bool Done() const { return done_; }
void Next() { done_ = true; }
std::pair<StringWeight<Label, S>, StringWeight<Label, S>> Value() const { using Weight = StringWeight<Label, S>; typename Weight::Iterator siter(weight_); Weight w1(siter.Value()); Weight w2; for (siter.Next(); !siter.Done(); siter.Next()) w2.PushBack(siter.Value()); return std::make_pair(w1, w2); }
void Reset() { done_ = weight_.Size() <= 1; }
private: const StringWeight<Label, S> weight_; bool done_; };
// Factor a GallicWeight using StringFactor.
template <class Label, class W, GallicType G = GALLIC_LEFT> class GallicFactor { public: using GW = GallicWeight<Label, W, G>;
explicit GallicFactor(const GW &weight) : weight_(weight), done_(weight.Value1().Size() <= 1) {}
bool Done() const { return done_; }
void Next() { done_ = true; }
std::pair<GW, GW> Value() const { StringFactor<Label, GallicStringType(G)> siter(weight_.Value1()); GW w1(siter.Value().first, weight_.Value2()); GW w2(siter.Value().second, W::One()); return std::make_pair(w1, w2); }
void Reset() { done_ = weight_.Value1().Size() <= 1; }
private: const GW weight_; bool done_; };
// Specialization for the (general) GALLIC type GallicWeight.
template <class Label, class W> class GallicFactor<Label, W, GALLIC> { public: using GW = GallicWeight<Label, W, GALLIC>; using GRW = GallicWeight<Label, W, GALLIC_RESTRICT>;
explicit GallicFactor(const GW &weight) : iter_(weight), done_(weight.Size() == 0 || (weight.Size() == 1 && weight.Back().Value1().Size() <= 1)) {}
bool Done() const { return done_ || iter_.Done(); }
void Next() { iter_.Next(); }
void Reset() { iter_.Reset(); }
std::pair<GW, GW> Value() const { const auto weight = iter_.Value(); StringFactor<Label, GallicStringType(GALLIC_RESTRICT)> siter( weight.Value1()); GRW w1(siter.Value().first, weight.Value2()); GRW w2(siter.Value().second, W::One()); return std::make_pair(GW(w1), GW(w2)); }
private: UnionWeightIterator<GRW, GallicUnionWeightOptions<Label, W>> iter_; bool done_; };
namespace internal {
// Implementation class for FactorWeight
template <class Arc, class FactorIterator> class FactorWeightFstImpl : public CacheImpl<Arc> { public: using Label = typename Arc::Label; using StateId = typename Arc::StateId; using Weight = typename Arc::Weight;
using FstImpl<Arc>::SetType; using FstImpl<Arc>::SetProperties; using FstImpl<Arc>::SetInputSymbols; using FstImpl<Arc>::SetOutputSymbols;
using CacheBaseImpl<CacheState<Arc>>::EmplaceArc; using CacheBaseImpl<CacheState<Arc>>::HasArcs; using CacheBaseImpl<CacheState<Arc>>::HasFinal; using CacheBaseImpl<CacheState<Arc>>::HasStart; using CacheBaseImpl<CacheState<Arc>>::SetArcs; using CacheBaseImpl<CacheState<Arc>>::SetFinal; using CacheBaseImpl<CacheState<Arc>>::SetStart;
struct Element { Element() = default;
Element(StateId s, Weight weight_) : state(s), weight(std::move(weight_)) {}
StateId state; // Input state ID.
Weight weight; // Residual weight.
};
FactorWeightFstImpl(const Fst<Arc> &fst, const FactorWeightOptions<Arc> &opts) : CacheImpl<Arc>(opts), fst_(fst.Copy()), delta_(opts.delta), mode_(opts.mode), final_ilabel_(opts.final_ilabel), final_olabel_(opts.final_olabel), increment_final_ilabel_(opts.increment_final_ilabel), increment_final_olabel_(opts.increment_final_olabel) { SetType("factor_weight"); const auto props = fst.Properties(kFstProperties, false); SetProperties(FactorWeightProperties(props), kCopyProperties); SetInputSymbols(fst.InputSymbols()); SetOutputSymbols(fst.OutputSymbols()); if (mode_ == 0) { LOG(WARNING) << "FactorWeightFst: Factor mode is set to 0; " << "factoring neither arc weights nor final weights"; } }
FactorWeightFstImpl(const FactorWeightFstImpl<Arc, FactorIterator> &impl) : CacheImpl<Arc>(impl), fst_(impl.fst_->Copy(true)), delta_(impl.delta_), mode_(impl.mode_), final_ilabel_(impl.final_ilabel_), final_olabel_(impl.final_olabel_), increment_final_ilabel_(impl.increment_final_ilabel_), increment_final_olabel_(impl.increment_final_olabel_) { SetType("factor_weight"); SetProperties(impl.Properties(), kCopyProperties); SetInputSymbols(impl.InputSymbols()); SetOutputSymbols(impl.OutputSymbols()); }
StateId Start() { if (!HasStart()) { const auto s = fst_->Start(); if (s == kNoStateId) return kNoStateId; SetStart(FindState(Element(fst_->Start(), Weight::One()))); } return CacheImpl<Arc>::Start(); }
Weight Final(StateId s) { if (!HasFinal(s)) { const auto &element = elements_[s]; const auto weight = element.state == kNoStateId ? element.weight : Times(element.weight, fst_->Final(element.state)); FactorIterator siter(weight); if (!(mode_ & kFactorFinalWeights) || siter.Done()) { SetFinal(s, weight); } else { SetFinal(s, Weight::Zero()); } } return CacheImpl<Arc>::Final(s); }
size_t NumArcs(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl<Arc>::NumArcs(s); }
size_t NumInputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl<Arc>::NumInputEpsilons(s); }
size_t NumOutputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl<Arc>::NumOutputEpsilons(s); }
uint64_t Properties() const override { return Properties(kFstProperties); }
// Sets error if found, and returns other FST impl properties.
uint64_t Properties(uint64_t mask) const override { if ((mask & kError) && fst_->Properties(kError, false)) { SetProperties(kError, kError); } return FstImpl<Arc>::Properties(mask); }
void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) { if (!HasArcs(s)) Expand(s); CacheImpl<Arc>::InitArcIterator(s, data); }
// Finds state corresponding to an element, creating new state if element not
// found.
StateId FindState(const Element &element) { if (!(mode_ & kFactorArcWeights) && element.weight == Weight::One() && element.state != kNoStateId) { while (unfactored_.size() <= element.state) unfactored_.push_back(kNoStateId); if (unfactored_[element.state] == kNoStateId) { unfactored_[element.state] = elements_.size(); elements_.push_back(element); } return unfactored_[element.state]; } else { const auto insert_result = element_map_.emplace(element, elements_.size()); if (insert_result.second) { elements_.push_back(element); } return insert_result.first->second; } }
// Computes the outgoing transitions from a state, creating new destination
// states as needed.
void Expand(StateId s) { const auto element = elements_[s]; if (element.state != kNoStateId) { for (ArcIterator<Fst<Arc>> ait(*fst_, element.state); !ait.Done(); ait.Next()) { const auto &arc = ait.Value(); auto weight = Times(element.weight, arc.weight); FactorIterator fiter(weight); if (!(mode_ & kFactorArcWeights) || fiter.Done()) { const auto dest = FindState(Element(arc.nextstate, Weight::One())); EmplaceArc(s, arc.ilabel, arc.olabel, std::move(weight), dest); } else { for (; !fiter.Done(); fiter.Next()) { auto pair = fiter.Value(); const auto dest = FindState(Element(arc.nextstate, pair.second.Quantize(delta_))); EmplaceArc(s, arc.ilabel, arc.olabel, std::move(pair.first), dest); } } } } if ((mode_ & kFactorFinalWeights) && ((element.state == kNoStateId) || (fst_->Final(element.state) != Weight::Zero()))) { const auto weight = element.state == kNoStateId ? element.weight : Times(element.weight, fst_->Final(element.state)); auto ilabel = final_ilabel_; auto olabel = final_olabel_; for (FactorIterator fiter(weight); !fiter.Done(); fiter.Next()) { auto pair = fiter.Value(); const auto dest = FindState(Element(kNoStateId, pair.second.Quantize(delta_))); EmplaceArc(s, ilabel, olabel, std::move(pair.first), dest); if (increment_final_ilabel_) ++ilabel; if (increment_final_olabel_) ++olabel; } } SetArcs(s); }
private: // Equality function for Elements, assume weights have been quantized.
class ElementEqual { public: bool operator()(const Element &x, const Element &y) const { return x.state == y.state && x.weight == y.weight; } };
// Hash function for Elements to Fst states.
class ElementKey { public: size_t operator()(const Element &x) const { static constexpr auto prime = 7853; return static_cast<size_t>(x.state * prime + x.weight.Hash()); } };
using ElementMap = std::unordered_map<Element, StateId, ElementKey, ElementEqual>;
std::unique_ptr<const Fst<Arc>> fst_; float delta_; uint8_t mode_; // Factoring arc and/or final weights.
Label final_ilabel_; // ilabel of arc created when factoring final weights.
Label final_olabel_; // olabel of arc created when factoring final weights.
bool increment_final_ilabel_; // When factoring final weights results in
bool increment_final_olabel_; // mutiple arcs, increment labels?
std::vector<Element> elements_; // Mapping from FST state to Element.
ElementMap element_map_; // Mapping from Element to FST state.
// Mapping between old/new StateId for states that do not need to be factored
// when mode_ is 0 or kFactorFinalWeights.
std::vector<StateId> unfactored_; };
} // namespace internal
// FactorWeightFst takes as template parameter a FactorIterator as defined
// above. The result of weight factoring is a transducer equivalent to the
// input whose path weights have been factored according to the FactorIterator.
// States and transitions will be added as necessary. The algorithm is a
// generalization to arbitrary weights of the second step of the input
// epsilon-normalization algorithm.
//
// This class attaches interface to implementation and handles reference
// counting, delegating most methods to ImplToFst.
template <class A, class FactorIterator> class FactorWeightFst : public ImplToFst<internal::FactorWeightFstImpl<A, FactorIterator>> { public: using Arc = A; using StateId = typename Arc::StateId; using Weight = typename Arc::Weight;
using Store = DefaultCacheStore<Arc>; using State = typename Store::State; using Impl = internal::FactorWeightFstImpl<Arc, FactorIterator>;
friend class ArcIterator<FactorWeightFst<Arc, FactorIterator>>; friend class StateIterator<FactorWeightFst<Arc, FactorIterator>>;
explicit FactorWeightFst(const Fst<Arc> &fst) : ImplToFst<Impl>( std::make_shared<Impl>(fst, FactorWeightOptions<Arc>())) {}
FactorWeightFst(const Fst<Arc> &fst, const FactorWeightOptions<Arc> &opts) : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
// See Fst<>::Copy() for doc.
FactorWeightFst(const FactorWeightFst &fst, bool copy) : ImplToFst<Impl>(fst, copy) {}
// Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc.
FactorWeightFst *Copy(bool copy = false) const override { return new FactorWeightFst(*this, copy); }
inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override { GetMutableImpl()->InitArcIterator(s, data); }
private: using ImplToFst<Impl>::GetImpl; using ImplToFst<Impl>::GetMutableImpl;
FactorWeightFst &operator=(const FactorWeightFst &) = delete; };
// Specialization for FactorWeightFst.
template <class Arc, class FactorIterator> class StateIterator<FactorWeightFst<Arc, FactorIterator>> : public CacheStateIterator<FactorWeightFst<Arc, FactorIterator>> { public: explicit StateIterator(const FactorWeightFst<Arc, FactorIterator> &fst) : CacheStateIterator<FactorWeightFst<Arc, FactorIterator>>( fst, fst.GetMutableImpl()) {} };
// Specialization for FactorWeightFst.
template <class Arc, class FactorIterator> class ArcIterator<FactorWeightFst<Arc, FactorIterator>> : public CacheArcIterator<FactorWeightFst<Arc, FactorIterator>> { public: using StateId = typename Arc::StateId;
ArcIterator(const FactorWeightFst<Arc, FactorIterator> &fst, StateId s) : CacheArcIterator<FactorWeightFst<Arc, FactorIterator>>( fst.GetMutableImpl(), s) { if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); } };
template <class Arc, class FactorIterator> inline void FactorWeightFst<Arc, FactorIterator>::InitStateIterator( StateIteratorData<Arc> *data) const { data->base = std::make_unique<StateIterator<FactorWeightFst<Arc, FactorIterator>>>( *this); }
} // namespace fst
#endif // FST_FACTOR_WEIGHT_H_
|