// Copyright 2005-2024 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the 'License');
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an 'AS IS' BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
// See www.openfst.org for extensive documentation on this weighted
|
|
// finite-state transducer library.
|
|
//
|
|
// Classes to factor weights in an FST.
|
|
|
|
#ifndef FST_FACTOR_WEIGHT_H_
|
|
#define FST_FACTOR_WEIGHT_H_
|
|
|
|
#include <algorithm>
|
|
#include <cstddef>
|
|
#include <cstdint>
|
|
#include <memory>
|
|
#include <string>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include <fst/log.h>
|
|
#include <fst/cache.h>
|
|
#include <fst/fst.h>
|
|
#include <fst/impl-to-fst.h>
|
|
#include <fst/properties.h>
|
|
#include <fst/string-weight.h>
|
|
#include <fst/union-weight.h>
|
|
#include <fst/weight.h>
|
|
#include <unordered_map>
|
|
|
|
namespace fst {
|
|
|
|
// Bit flags selecting which weights get factored; they may be OR-ed together
// and are stored in FactorWeightOptions::mode.
inline constexpr uint8_t kFactorFinalWeights = 1u << 0;  // Factor final weights.
inline constexpr uint8_t kFactorArcWeights = 1u << 1;    // Factor arc weights.
|
|
|
|
template <class Arc>
|
|
struct FactorWeightOptions : CacheOptions {
|
|
using Label = typename Arc::Label;
|
|
|
|
float delta;
|
|
uint8_t mode; // Factor arc weights and/or final weights.
|
|
Label final_ilabel; // Input label of arc when factoring final weights.
|
|
Label final_olabel; // Output label of arc when factoring final weights.
|
|
bool increment_final_ilabel; // When factoring final w' results in > 1 arcs
|
|
bool increment_final_olabel; // at state, increment labels to make distinct?
|
|
|
|
explicit FactorWeightOptions(const CacheOptions &opts, float delta = kDelta,
|
|
uint8_t mode = kFactorArcWeights |
|
|
kFactorFinalWeights,
|
|
Label final_ilabel = 0, Label final_olabel = 0,
|
|
bool increment_final_ilabel = false,
|
|
bool increment_final_olabel = false)
|
|
: CacheOptions(opts),
|
|
delta(delta),
|
|
mode(mode),
|
|
final_ilabel(final_ilabel),
|
|
final_olabel(final_olabel),
|
|
increment_final_ilabel(increment_final_ilabel),
|
|
increment_final_olabel(increment_final_olabel) {}
|
|
|
|
explicit FactorWeightOptions(float delta = kDelta,
|
|
uint8_t mode = kFactorArcWeights |
|
|
kFactorFinalWeights,
|
|
Label final_ilabel = 0, Label final_olabel = 0,
|
|
bool increment_final_ilabel = false,
|
|
bool increment_final_olabel = false)
|
|
: delta(delta),
|
|
mode(mode),
|
|
final_ilabel(final_ilabel),
|
|
final_olabel(final_olabel),
|
|
increment_final_ilabel(increment_final_ilabel),
|
|
increment_final_olabel(increment_final_olabel) {}
|
|
};
|
|
|
|
// A factor iterator takes as argument a weight w and returns a sequence of
// pairs of weights (xi, yi) such that the sum of the products xi times yi is
// equal to w. If w is fully factored, the iterator should return nothing.
//
// template <class W>
// class FactorIterator {
//  public:
//   explicit FactorIterator(W w);
//
//   bool Done() const;
//
//   void Next();
//
//   std::pair<W, W> Value() const;
//
//   void Reset();
// };
|
|
|
|
// Trivial factor iterator: every weight is treated as already fully factored,
// so the iteration is always empty.
template <class W>
class IdentityFactor {
 public:
  // The weight is ignored; nothing is ever factored out of it.
  explicit IdentityFactor(const W &) {}

  // Always exhausted.
  bool Done() const { return true; }

  // No-op: there is never a next factorization.
  void Next() {}

  // Returns the pair (One, One).
  std::pair<W, W> Value() const { return std::pair<W, W>(W::One(), W::One()); }

  // No-op: there is no iteration state to rewind.
  void Reset() {}
};
|
|
|
|
// Factor iterator that splits a weight w into (One, w). Used to unfold the
// FST as needed so that any two paths leading to the same state carry the
// same weight. Must only be applied to arc weights
// (FactorWeightOptions::mode == kFactorArcWeights).
template <class W>
class OneFactor {
 public:
  // A weight equal to One needs no factoring, so iteration starts exhausted.
  explicit OneFactor(const W &w) : weight_(w), done_(IsOne(w)) {}

  // True once the single factorization has been consumed (or was never
  // needed).
  bool Done() const { return done_; }

  // There is at most one factorization, so advancing always finishes.
  void Next() { done_ = true; }

  // Returns (One, w): everything is pushed into the residual.
  std::pair<W, W> Value() const { return std::pair<W, W>(W::One(), weight_); }

  // Rewinds to the state the iterator had right after construction.
  void Reset() { done_ = IsOne(weight_); }

 private:
  static bool IsOne(const W &w) { return w == W::One(); }

  W weight_;
  bool done_;
};
|
|
|
|
// Factors a StringWeight w as 'ab' where 'a' is a label.
|
|
template <typename Label, StringType S = STRING_LEFT>
|
|
class StringFactor {
|
|
public:
|
|
explicit StringFactor(const StringWeight<Label, S> &weight)
|
|
: weight_(weight), done_(weight.Size() <= 1) {}
|
|
|
|
bool Done() const { return done_; }
|
|
|
|
void Next() { done_ = true; }
|
|
|
|
std::pair<StringWeight<Label, S>, StringWeight<Label, S>> Value() const {
|
|
using Weight = StringWeight<Label, S>;
|
|
typename Weight::Iterator siter(weight_);
|
|
Weight w1(siter.Value());
|
|
Weight w2;
|
|
for (siter.Next(); !siter.Done(); siter.Next()) w2.PushBack(siter.Value());
|
|
return std::make_pair(w1, w2);
|
|
}
|
|
|
|
void Reset() { done_ = weight_.Size() <= 1; }
|
|
|
|
private:
|
|
const StringWeight<Label, S> weight_;
|
|
bool done_;
|
|
};
|
|
|
|
// Factor a GallicWeight using StringFactor.
|
|
template <class Label, class W, GallicType G = GALLIC_LEFT>
|
|
class GallicFactor {
|
|
public:
|
|
using GW = GallicWeight<Label, W, G>;
|
|
|
|
explicit GallicFactor(const GW &weight)
|
|
: weight_(weight), done_(weight.Value1().Size() <= 1) {}
|
|
|
|
bool Done() const { return done_; }
|
|
|
|
void Next() { done_ = true; }
|
|
|
|
std::pair<GW, GW> Value() const {
|
|
StringFactor<Label, GallicStringType(G)> siter(weight_.Value1());
|
|
GW w1(siter.Value().first, weight_.Value2());
|
|
GW w2(siter.Value().second, W::One());
|
|
return std::make_pair(w1, w2);
|
|
}
|
|
|
|
void Reset() { done_ = weight_.Value1().Size() <= 1; }
|
|
|
|
private:
|
|
const GW weight_;
|
|
bool done_;
|
|
};
|
|
|
|
// Specialization for the (general) GALLIC type GallicWeight.
|
|
template <class Label, class W>
|
|
class GallicFactor<Label, W, GALLIC> {
|
|
public:
|
|
using GW = GallicWeight<Label, W, GALLIC>;
|
|
using GRW = GallicWeight<Label, W, GALLIC_RESTRICT>;
|
|
|
|
explicit GallicFactor(const GW &weight)
|
|
: iter_(weight),
|
|
done_(weight.Size() == 0 ||
|
|
(weight.Size() == 1 && weight.Back().Value1().Size() <= 1)) {}
|
|
|
|
bool Done() const { return done_ || iter_.Done(); }
|
|
|
|
void Next() { iter_.Next(); }
|
|
|
|
void Reset() { iter_.Reset(); }
|
|
|
|
std::pair<GW, GW> Value() const {
|
|
const auto weight = iter_.Value();
|
|
StringFactor<Label, GallicStringType(GALLIC_RESTRICT)> siter(
|
|
weight.Value1());
|
|
GRW w1(siter.Value().first, weight.Value2());
|
|
GRW w2(siter.Value().second, W::One());
|
|
return std::make_pair(GW(w1), GW(w2));
|
|
}
|
|
|
|
private:
|
|
UnionWeightIterator<GRW, GallicUnionWeightOptions<Label, W>> iter_;
|
|
bool done_;
|
|
};
|
|
|
|
namespace internal {
|
|
|
|
// Implementation class for FactorWeight
|
|
template <class Arc, class FactorIterator>
|
|
class FactorWeightFstImpl : public CacheImpl<Arc> {
|
|
public:
|
|
using Label = typename Arc::Label;
|
|
using StateId = typename Arc::StateId;
|
|
using Weight = typename Arc::Weight;
|
|
|
|
using FstImpl<Arc>::SetType;
|
|
using FstImpl<Arc>::SetProperties;
|
|
using FstImpl<Arc>::SetInputSymbols;
|
|
using FstImpl<Arc>::SetOutputSymbols;
|
|
|
|
using CacheBaseImpl<CacheState<Arc>>::EmplaceArc;
|
|
using CacheBaseImpl<CacheState<Arc>>::HasArcs;
|
|
using CacheBaseImpl<CacheState<Arc>>::HasFinal;
|
|
using CacheBaseImpl<CacheState<Arc>>::HasStart;
|
|
using CacheBaseImpl<CacheState<Arc>>::SetArcs;
|
|
using CacheBaseImpl<CacheState<Arc>>::SetFinal;
|
|
using CacheBaseImpl<CacheState<Arc>>::SetStart;
|
|
|
|
struct Element {
|
|
Element() = default;
|
|
|
|
Element(StateId s, Weight weight_) : state(s), weight(std::move(weight_)) {}
|
|
|
|
StateId state; // Input state ID.
|
|
Weight weight; // Residual weight.
|
|
};
|
|
|
|
FactorWeightFstImpl(const Fst<Arc> &fst, const FactorWeightOptions<Arc> &opts)
|
|
: CacheImpl<Arc>(opts),
|
|
fst_(fst.Copy()),
|
|
delta_(opts.delta),
|
|
mode_(opts.mode),
|
|
final_ilabel_(opts.final_ilabel),
|
|
final_olabel_(opts.final_olabel),
|
|
increment_final_ilabel_(opts.increment_final_ilabel),
|
|
increment_final_olabel_(opts.increment_final_olabel) {
|
|
SetType("factor_weight");
|
|
const auto props = fst.Properties(kFstProperties, false);
|
|
SetProperties(FactorWeightProperties(props), kCopyProperties);
|
|
SetInputSymbols(fst.InputSymbols());
|
|
SetOutputSymbols(fst.OutputSymbols());
|
|
if (mode_ == 0) {
|
|
LOG(WARNING) << "FactorWeightFst: Factor mode is set to 0; "
|
|
<< "factoring neither arc weights nor final weights";
|
|
}
|
|
}
|
|
|
|
FactorWeightFstImpl(const FactorWeightFstImpl<Arc, FactorIterator> &impl)
|
|
: CacheImpl<Arc>(impl),
|
|
fst_(impl.fst_->Copy(true)),
|
|
delta_(impl.delta_),
|
|
mode_(impl.mode_),
|
|
final_ilabel_(impl.final_ilabel_),
|
|
final_olabel_(impl.final_olabel_),
|
|
increment_final_ilabel_(impl.increment_final_ilabel_),
|
|
increment_final_olabel_(impl.increment_final_olabel_) {
|
|
SetType("factor_weight");
|
|
SetProperties(impl.Properties(), kCopyProperties);
|
|
SetInputSymbols(impl.InputSymbols());
|
|
SetOutputSymbols(impl.OutputSymbols());
|
|
}
|
|
|
|
StateId Start() {
|
|
if (!HasStart()) {
|
|
const auto s = fst_->Start();
|
|
if (s == kNoStateId) return kNoStateId;
|
|
SetStart(FindState(Element(fst_->Start(), Weight::One())));
|
|
}
|
|
return CacheImpl<Arc>::Start();
|
|
}
|
|
|
|
Weight Final(StateId s) {
|
|
if (!HasFinal(s)) {
|
|
const auto &element = elements_[s];
|
|
const auto weight =
|
|
element.state == kNoStateId
|
|
? element.weight
|
|
: Times(element.weight, fst_->Final(element.state));
|
|
FactorIterator siter(weight);
|
|
if (!(mode_ & kFactorFinalWeights) || siter.Done()) {
|
|
SetFinal(s, weight);
|
|
} else {
|
|
SetFinal(s, Weight::Zero());
|
|
}
|
|
}
|
|
return CacheImpl<Arc>::Final(s);
|
|
}
|
|
|
|
size_t NumArcs(StateId s) {
|
|
if (!HasArcs(s)) Expand(s);
|
|
return CacheImpl<Arc>::NumArcs(s);
|
|
}
|
|
|
|
size_t NumInputEpsilons(StateId s) {
|
|
if (!HasArcs(s)) Expand(s);
|
|
return CacheImpl<Arc>::NumInputEpsilons(s);
|
|
}
|
|
|
|
size_t NumOutputEpsilons(StateId s) {
|
|
if (!HasArcs(s)) Expand(s);
|
|
return CacheImpl<Arc>::NumOutputEpsilons(s);
|
|
}
|
|
|
|
uint64_t Properties() const override { return Properties(kFstProperties); }
|
|
|
|
// Sets error if found, and returns other FST impl properties.
|
|
uint64_t Properties(uint64_t mask) const override {
|
|
if ((mask & kError) && fst_->Properties(kError, false)) {
|
|
SetProperties(kError, kError);
|
|
}
|
|
return FstImpl<Arc>::Properties(mask);
|
|
}
|
|
|
|
void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
|
|
if (!HasArcs(s)) Expand(s);
|
|
CacheImpl<Arc>::InitArcIterator(s, data);
|
|
}
|
|
|
|
// Finds state corresponding to an element, creating new state if element not
|
|
// found.
|
|
StateId FindState(const Element &element) {
|
|
if (!(mode_ & kFactorArcWeights) && element.weight == Weight::One() &&
|
|
element.state != kNoStateId) {
|
|
while (unfactored_.size() <= element.state)
|
|
unfactored_.push_back(kNoStateId);
|
|
if (unfactored_[element.state] == kNoStateId) {
|
|
unfactored_[element.state] = elements_.size();
|
|
elements_.push_back(element);
|
|
}
|
|
return unfactored_[element.state];
|
|
} else {
|
|
const auto insert_result =
|
|
element_map_.emplace(element, elements_.size());
|
|
if (insert_result.second) {
|
|
elements_.push_back(element);
|
|
}
|
|
return insert_result.first->second;
|
|
}
|
|
}
|
|
|
|
// Computes the outgoing transitions from a state, creating new destination
|
|
// states as needed.
|
|
void Expand(StateId s) {
|
|
const auto element = elements_[s];
|
|
if (element.state != kNoStateId) {
|
|
for (ArcIterator<Fst<Arc>> ait(*fst_, element.state); !ait.Done();
|
|
ait.Next()) {
|
|
const auto &arc = ait.Value();
|
|
auto weight = Times(element.weight, arc.weight);
|
|
FactorIterator fiter(weight);
|
|
if (!(mode_ & kFactorArcWeights) || fiter.Done()) {
|
|
const auto dest = FindState(Element(arc.nextstate, Weight::One()));
|
|
EmplaceArc(s, arc.ilabel, arc.olabel, std::move(weight), dest);
|
|
} else {
|
|
for (; !fiter.Done(); fiter.Next()) {
|
|
auto pair = fiter.Value();
|
|
const auto dest =
|
|
FindState(Element(arc.nextstate, pair.second.Quantize(delta_)));
|
|
EmplaceArc(s, arc.ilabel, arc.olabel, std::move(pair.first), dest);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if ((mode_ & kFactorFinalWeights) &&
|
|
((element.state == kNoStateId) ||
|
|
(fst_->Final(element.state) != Weight::Zero()))) {
|
|
const auto weight =
|
|
element.state == kNoStateId
|
|
? element.weight
|
|
: Times(element.weight, fst_->Final(element.state));
|
|
auto ilabel = final_ilabel_;
|
|
auto olabel = final_olabel_;
|
|
for (FactorIterator fiter(weight); !fiter.Done(); fiter.Next()) {
|
|
auto pair = fiter.Value();
|
|
const auto dest =
|
|
FindState(Element(kNoStateId, pair.second.Quantize(delta_)));
|
|
EmplaceArc(s, ilabel, olabel, std::move(pair.first), dest);
|
|
if (increment_final_ilabel_) ++ilabel;
|
|
if (increment_final_olabel_) ++olabel;
|
|
}
|
|
}
|
|
SetArcs(s);
|
|
}
|
|
|
|
private:
|
|
// Equality function for Elements, assume weights have been quantized.
|
|
class ElementEqual {
|
|
public:
|
|
bool operator()(const Element &x, const Element &y) const {
|
|
return x.state == y.state && x.weight == y.weight;
|
|
}
|
|
};
|
|
|
|
// Hash function for Elements to Fst states.
|
|
class ElementKey {
|
|
public:
|
|
size_t operator()(const Element &x) const {
|
|
static constexpr auto prime = 7853;
|
|
return static_cast<size_t>(x.state * prime + x.weight.Hash());
|
|
}
|
|
};
|
|
|
|
using ElementMap =
|
|
std::unordered_map<Element, StateId, ElementKey, ElementEqual>;
|
|
|
|
std::unique_ptr<const Fst<Arc>> fst_;
|
|
float delta_;
|
|
uint8_t mode_; // Factoring arc and/or final weights.
|
|
Label final_ilabel_; // ilabel of arc created when factoring final weights.
|
|
Label final_olabel_; // olabel of arc created when factoring final weights.
|
|
bool increment_final_ilabel_; // When factoring final weights results in
|
|
bool increment_final_olabel_; // mutiple arcs, increment labels?
|
|
std::vector<Element> elements_; // Mapping from FST state to Element.
|
|
ElementMap element_map_; // Mapping from Element to FST state.
|
|
// Mapping between old/new StateId for states that do not need to be factored
|
|
// when mode_ is 0 or kFactorFinalWeights.
|
|
std::vector<StateId> unfactored_;
|
|
};
|
|
|
|
} // namespace internal
|
|
|
|
// FactorWeightFst takes as template parameter a FactorIterator as defined
|
|
// above. The result of weight factoring is a transducer equivalent to the
|
|
// input whose path weights have been factored according to the FactorIterator.
|
|
// States and transitions will be added as necessary. The algorithm is a
|
|
// generalization to arbitrary weights of the second step of the input
|
|
// epsilon-normalization algorithm.
|
|
//
|
|
// This class attaches interface to implementation and handles reference
|
|
// counting, delegating most methods to ImplToFst.
|
|
template <class A, class FactorIterator>
|
|
class FactorWeightFst
|
|
: public ImplToFst<internal::FactorWeightFstImpl<A, FactorIterator>> {
|
|
public:
|
|
using Arc = A;
|
|
using StateId = typename Arc::StateId;
|
|
using Weight = typename Arc::Weight;
|
|
|
|
using Store = DefaultCacheStore<Arc>;
|
|
using State = typename Store::State;
|
|
using Impl = internal::FactorWeightFstImpl<Arc, FactorIterator>;
|
|
|
|
friend class ArcIterator<FactorWeightFst<Arc, FactorIterator>>;
|
|
friend class StateIterator<FactorWeightFst<Arc, FactorIterator>>;
|
|
|
|
explicit FactorWeightFst(const Fst<Arc> &fst)
|
|
: ImplToFst<Impl>(
|
|
std::make_shared<Impl>(fst, FactorWeightOptions<Arc>())) {}
|
|
|
|
FactorWeightFst(const Fst<Arc> &fst, const FactorWeightOptions<Arc> &opts)
|
|
: ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
|
|
|
|
// See Fst<>::Copy() for doc.
|
|
FactorWeightFst(const FactorWeightFst &fst, bool copy)
|
|
: ImplToFst<Impl>(fst, copy) {}
|
|
|
|
// Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc.
|
|
FactorWeightFst *Copy(bool copy = false) const override {
|
|
return new FactorWeightFst(*this, copy);
|
|
}
|
|
|
|
inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
|
|
|
|
void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
|
|
GetMutableImpl()->InitArcIterator(s, data);
|
|
}
|
|
|
|
private:
|
|
using ImplToFst<Impl>::GetImpl;
|
|
using ImplToFst<Impl>::GetMutableImpl;
|
|
|
|
FactorWeightFst &operator=(const FactorWeightFst &) = delete;
|
|
};
|
|
|
|
// Specialization for FactorWeightFst.
|
|
template <class Arc, class FactorIterator>
|
|
class StateIterator<FactorWeightFst<Arc, FactorIterator>>
|
|
: public CacheStateIterator<FactorWeightFst<Arc, FactorIterator>> {
|
|
public:
|
|
explicit StateIterator(const FactorWeightFst<Arc, FactorIterator> &fst)
|
|
: CacheStateIterator<FactorWeightFst<Arc, FactorIterator>>(
|
|
fst, fst.GetMutableImpl()) {}
|
|
};
|
|
|
|
// Specialization for FactorWeightFst.
|
|
template <class Arc, class FactorIterator>
|
|
class ArcIterator<FactorWeightFst<Arc, FactorIterator>>
|
|
: public CacheArcIterator<FactorWeightFst<Arc, FactorIterator>> {
|
|
public:
|
|
using StateId = typename Arc::StateId;
|
|
|
|
ArcIterator(const FactorWeightFst<Arc, FactorIterator> &fst, StateId s)
|
|
: CacheArcIterator<FactorWeightFst<Arc, FactorIterator>>(
|
|
fst.GetMutableImpl(), s) {
|
|
if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
|
|
}
|
|
};
|
|
|
|
template <class Arc, class FactorIterator>
|
|
inline void FactorWeightFst<Arc, FactorIterator>::InitStateIterator(
|
|
StateIteratorData<Arc> *data) const {
|
|
data->base =
|
|
std::make_unique<StateIterator<FactorWeightFst<Arc, FactorIterator>>>(
|
|
*this);
|
|
}
|
|
|
|
} // namespace fst
|
|
|
|
#endif // FST_FACTOR_WEIGHT_H_
|