You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

538 lines
18 KiB

// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Classes to factor weights in an FST.
#ifndef FST_FACTOR_WEIGHT_H_
#define FST_FACTOR_WEIGHT_H_
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <fst/log.h>
#include <fst/cache.h>
#include <fst/fst.h>
#include <fst/impl-to-fst.h>
#include <fst/properties.h>
#include <fst/string-weight.h>
#include <fst/union-weight.h>
#include <fst/weight.h>
#include <unordered_map>
namespace fst {
// Bit flags for FactorWeightOptions::mode. They may be OR-ed together to
// request factoring of final weights, arc weights, or both.
inline constexpr uint8_t kFactorFinalWeights = 0x01;
inline constexpr uint8_t kFactorArcWeights = 0x02;
template <class Arc>
struct FactorWeightOptions : CacheOptions {
using Label = typename Arc::Label;
float delta;
uint8_t mode; // Factor arc weights and/or final weights.
Label final_ilabel; // Input label of arc when factoring final weights.
Label final_olabel; // Output label of arc when factoring final weights.
bool increment_final_ilabel; // When factoring final w' results in > 1 arcs
bool increment_final_olabel; // at state, increment labels to make distinct?
explicit FactorWeightOptions(const CacheOptions &opts, float delta = kDelta,
uint8_t mode = kFactorArcWeights |
kFactorFinalWeights,
Label final_ilabel = 0, Label final_olabel = 0,
bool increment_final_ilabel = false,
bool increment_final_olabel = false)
: CacheOptions(opts),
delta(delta),
mode(mode),
final_ilabel(final_ilabel),
final_olabel(final_olabel),
increment_final_ilabel(increment_final_ilabel),
increment_final_olabel(increment_final_olabel) {}
explicit FactorWeightOptions(float delta = kDelta,
uint8_t mode = kFactorArcWeights |
kFactorFinalWeights,
Label final_ilabel = 0, Label final_olabel = 0,
bool increment_final_ilabel = false,
bool increment_final_olabel = false)
: delta(delta),
mode(mode),
final_ilabel(final_ilabel),
final_olabel(final_olabel),
increment_final_ilabel(increment_final_ilabel),
increment_final_olabel(increment_final_olabel) {}
};
// A factor iterator takes as argument a weight w and returns a sequence of
// pairs of weights (xi, yi) such that the sum of the products xi times yi is
// equal to w. If w is fully factored, the iterator should return nothing.
//
// template <class W>
// class FactorIterator {
// public:
// explicit FactorIterator(W w);
//
// bool Done() const;
//
// void Next();
//
// std::pair<W, W> Value() const;
//
// void Reset();
// };
// Trivial factor iterator: every weight is treated as already fully factored,
// so the iteration is always finished.
template <class W>
class IdentityFactor {
 public:
  // The weight is accepted for interface compatibility and otherwise ignored.
  explicit IdentityFactor(const W &weight) {}

  // Always true: there is never a factorization to report.
  bool Done() const { return true; }

  // No-op.
  void Next() {}

  // Returns the pair (One, One).
  std::pair<W, W> Value() const {
    return std::pair<W, W>(W::One(), W::One());
  }

  // No-op.
  void Reset() {}
};
// Factor iterator that splits a weight w into (One, w). Using it unfolds the
// FST as needed so that any two paths reaching the same state carry the same
// weight. It must be applied to arc weights only
// (FactorWeightOptions::mode == kFactorArcWeights).
template <class W>
class OneFactor {
 public:
  explicit OneFactor(const W &w) : weight_(w), done_(IsTrivial()) {}

  // True once the single factorization has been consumed, or immediately when
  // the weight equals One.
  bool Done() const { return done_; }

  // Consumes the single factorization.
  void Next() { done_ = true; }

  // Returns (One, w).
  std::pair<W, W> Value() const {
    return std::pair<W, W>(W::One(), weight_);
  }

  // Restores the state the iterator had right after construction.
  void Reset() { done_ = IsTrivial(); }

 private:
  // A weight equal to One needs no factoring.
  bool IsTrivial() const { return weight_ == W::One(); }

  W weight_;
  bool done_;
};
// Factor iterator over a StringWeight w: yields the single split w = 'ab'
// where 'a' is the first label and 'b' holds the remaining labels. Weights of
// size at most one are considered fully factored.
template <typename Label, StringType S = STRING_LEFT>
class StringFactor {
  using Weight = StringWeight<Label, S>;

 public:
  explicit StringFactor(const StringWeight<Label, S> &weight)
      : weight_(weight), done_(weight.Size() <= 1) {}

  // True once the single split has been consumed, or immediately when the
  // weight has at most one label.
  bool Done() const { return done_; }

  // Consumes the single split.
  void Next() { done_ = true; }

  // Returns (first label, remaining labels).
  std::pair<StringWeight<Label, S>, StringWeight<Label, S>> Value() const {
    typename Weight::Iterator labels(weight_);
    Weight head(labels.Value());
    Weight tail;
    labels.Next();
    while (!labels.Done()) {
      tail.PushBack(labels.Value());
      labels.Next();
    }
    return std::make_pair(head, tail);
  }

  // Restores the state the iterator had right after construction.
  void Reset() { done_ = weight_.Size() <= 1; }

 private:
  const StringWeight<Label, S> weight_;
  bool done_;
};
// Factor a GallicWeight using StringFactor: the string component is split into
// its first label and the remainder, the first half keeps the original second
// component (Value2) and the second half gets W::One().
template <class Label, class W, GallicType G = GALLIC_LEFT>
class GallicFactor {
 public:
  using GW = GallicWeight<Label, W, G>;

  explicit GallicFactor(const GW &weight)
      : weight_(weight), done_(weight.Value1().Size() <= 1) {}

  // True once the single split has been consumed, or immediately when the
  // string component has at most one label.
  bool Done() const { return done_; }

  // Consumes the single split.
  void Next() { done_ = true; }

  // Returns ((first label, Value2), (remaining labels, One)).
  std::pair<GW, GW> Value() const {
    StringFactor<Label, GallicStringType(G)> siter(weight_.Value1());
    // Factor the string exactly once: every StringFactor::Value() call walks
    // the whole string and rebuilds both halves.
    auto halves = siter.Value();
    GW w1(std::move(halves.first), weight_.Value2());
    GW w2(std::move(halves.second), W::One());
    return std::make_pair(w1, w2);
  }

  // Restores the state the iterator had right after construction.
  void Reset() { done_ = weight_.Value1().Size() <= 1; }

 private:
  const GW weight_;
  bool done_;
};
// Specialization for the (general) GALLIC type GallicWeight.
template <class Label, class W>
class GallicFactor<Label, W, GALLIC> {
public:
using GW = GallicWeight<Label, W, GALLIC>;
using GRW = GallicWeight<Label, W, GALLIC_RESTRICT>;
explicit GallicFactor(const GW &weight)
: iter_(weight),
done_(weight.Size() == 0 ||
(weight.Size() == 1 && weight.Back().Value1().Size() <= 1)) {}
bool Done() const { return done_ || iter_.Done(); }
void Next() { iter_.Next(); }
void Reset() { iter_.Reset(); }
std::pair<GW, GW> Value() const {
const auto weight = iter_.Value();
StringFactor<Label, GallicStringType(GALLIC_RESTRICT)> siter(
weight.Value1());
GRW w1(siter.Value().first, weight.Value2());
GRW w2(siter.Value().second, W::One());
return std::make_pair(GW(w1), GW(w2));
}
private:
UnionWeightIterator<GRW, GallicUnionWeightOptions<Label, W>> iter_;
bool done_;
};
namespace internal {
// Implementation class for FactorWeightFst. Each output state corresponds to
// an Element, i.e. a pair (input state, residual weight); states and arcs are
// computed on demand and stored through CacheImpl.
template <class Arc, class FactorIterator>
class FactorWeightFstImpl : public CacheImpl<Arc> {
 public:
  using Label = typename Arc::Label;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  using FstImpl<Arc>::SetType;
  using FstImpl<Arc>::SetProperties;
  using FstImpl<Arc>::SetInputSymbols;
  using FstImpl<Arc>::SetOutputSymbols;

  using CacheBaseImpl<CacheState<Arc>>::EmplaceArc;
  using CacheBaseImpl<CacheState<Arc>>::HasArcs;
  using CacheBaseImpl<CacheState<Arc>>::HasFinal;
  using CacheBaseImpl<CacheState<Arc>>::HasStart;
  using CacheBaseImpl<CacheState<Arc>>::SetArcs;
  using CacheBaseImpl<CacheState<Arc>>::SetFinal;
  using CacheBaseImpl<CacheState<Arc>>::SetStart;

  // Identity of an output state. A state with state == kNoStateId has no
  // counterpart in the input FST; it only carries a residual weight.
  struct Element {
    Element() = default;

    Element(StateId s, Weight weight_) : state(s), weight(std::move(weight_)) {}

    StateId state;  // Input state ID.
    Weight weight;  // Residual weight.
  };

  // Builds the implementation over a copy of `fst` using `opts`. Logs a
  // warning when the mode requests no factoring at all.
  FactorWeightFstImpl(const Fst<Arc> &fst, const FactorWeightOptions<Arc> &opts)
      : CacheImpl<Arc>(opts),
        fst_(fst.Copy()),
        delta_(opts.delta),
        mode_(opts.mode),
        final_ilabel_(opts.final_ilabel),
        final_olabel_(opts.final_olabel),
        increment_final_ilabel_(opts.increment_final_ilabel),
        increment_final_olabel_(opts.increment_final_olabel) {
    SetType("factor_weight");
    const auto props = fst.Properties(kFstProperties, false);
    SetProperties(FactorWeightProperties(props), kCopyProperties);
    SetInputSymbols(fst.InputSymbols());
    SetOutputSymbols(fst.OutputSymbols());
    if (mode_ == 0) {
      LOG(WARNING) << "FactorWeightFst: Factor mode is set to 0; "
                   << "factoring neither arc weights nor final weights";
    }
  }

  // Copy constructor. The input FST is copied with Copy(true); the element
  // tables are not copied and start out empty.
  FactorWeightFstImpl(const FactorWeightFstImpl<Arc, FactorIterator> &impl)
      : CacheImpl<Arc>(impl),
        fst_(impl.fst_->Copy(true)),
        delta_(impl.delta_),
        mode_(impl.mode_),
        final_ilabel_(impl.final_ilabel_),
        final_olabel_(impl.final_olabel_),
        increment_final_ilabel_(impl.increment_final_ilabel_),
        increment_final_olabel_(impl.increment_final_olabel_) {
    SetType("factor_weight");
    SetProperties(impl.Properties(), kCopyProperties);
    SetInputSymbols(impl.InputSymbols());
    SetOutputSymbols(impl.OutputSymbols());
  }

  // Returns the start state, creating it on first use. Returns kNoStateId when
  // the input FST has no start state.
  StateId Start() {
    if (!HasStart()) {
      const auto s = fst_->Start();
      if (s == kNoStateId) return kNoStateId;
      SetStart(FindState(Element(s, Weight::One())));
    }
    return CacheImpl<Arc>::Start();
  }

  // Returns the final weight of state s, computing it on first use. The state
  // keeps its (residual times input final) weight unless final weights are
  // being factored and that weight still has factors, in which case the state
  // is made non-final (the factors are emitted as arcs by Expand).
  Weight Final(StateId s) {
    if (!HasFinal(s)) {
      const auto &element = elements_[s];
      const auto weight =
          element.state == kNoStateId
              ? element.weight
              : Times(element.weight, fst_->Final(element.state));
      FactorIterator siter(weight);
      if (!(mode_ & kFactorFinalWeights) || siter.Done()) {
        SetFinal(s, weight);
      } else {
        SetFinal(s, Weight::Zero());
      }
    }
    return CacheImpl<Arc>::Final(s);
  }

  // Number of arcs leaving s; expands s first if needed.
  size_t NumArcs(StateId s) {
    if (!HasArcs(s)) Expand(s);
    return CacheImpl<Arc>::NumArcs(s);
  }

  // Number of input-epsilon arcs leaving s; expands s first if needed.
  size_t NumInputEpsilons(StateId s) {
    if (!HasArcs(s)) Expand(s);
    return CacheImpl<Arc>::NumInputEpsilons(s);
  }

  // Number of output-epsilon arcs leaving s; expands s first if needed.
  size_t NumOutputEpsilons(StateId s) {
    if (!HasArcs(s)) Expand(s);
    return CacheImpl<Arc>::NumOutputEpsilons(s);
  }

  uint64_t Properties() const override { return Properties(kFstProperties); }

  // Sets error if found, and returns other FST impl properties.
  uint64_t Properties(uint64_t mask) const override {
    if ((mask & kError) && fst_->Properties(kError, false)) {
      SetProperties(kError, kError);
    }
    return FstImpl<Arc>::Properties(mask);
  }

  // Initializes an arc iterator over state s; expands s first if needed.
  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
    if (!HasArcs(s)) Expand(s);
    CacheImpl<Arc>::InitArcIterator(s, data);
  }

  // Finds state corresponding to an element, creating new state if element not
  // found.
  StateId FindState(const Element &element) {
    if (!(mode_ & kFactorArcWeights) && element.weight == Weight::One() &&
        element.state != kNoStateId) {
      // Arc weights are not factored and there is no residual weight: the
      // state is looked up directly by input state ID.
      const auto index = static_cast<size_t>(element.state);
      if (unfactored_.size() <= index) {
        unfactored_.resize(index + 1, kNoStateId);
      }
      if (unfactored_[index] == kNoStateId) {
        unfactored_[index] = static_cast<StateId>(elements_.size());
        elements_.push_back(element);
      }
      return unfactored_[index];
    }
    const auto insert_result = element_map_.emplace(element, elements_.size());
    if (insert_result.second) elements_.push_back(element);
    return insert_result.first->second;
  }

  // Computes the outgoing transitions from a state, creating new destination
  // states as needed.
  void Expand(StateId s) {
    // Deliberately a copy: FindState() may append to elements_, which would
    // invalidate a reference into the vector.
    const auto element = elements_[s];
    if (element.state != kNoStateId) {
      for (ArcIterator<Fst<Arc>> ait(*fst_, element.state); !ait.Done();
           ait.Next()) {
        const auto &arc = ait.Value();
        auto weight = Times(element.weight, arc.weight);
        FactorIterator fiter(weight);
        if (!(mode_ & kFactorArcWeights) || fiter.Done()) {
          // Nothing to factor: keep the full weight on the arc.
          const auto dest = FindState(Element(arc.nextstate, Weight::One()));
          EmplaceArc(s, arc.ilabel, arc.olabel, std::move(weight), dest);
        } else {
          // One arc per factorization: the first factor goes on the arc, the
          // quantized second factor becomes the destination's residual.
          for (; !fiter.Done(); fiter.Next()) {
            auto pair = fiter.Value();
            const auto dest =
                FindState(Element(arc.nextstate, pair.second.Quantize(delta_)));
            EmplaceArc(s, arc.ilabel, arc.olabel, std::move(pair.first), dest);
          }
        }
      }
    }
    if ((mode_ & kFactorFinalWeights) &&
        ((element.state == kNoStateId) ||
         (fst_->Final(element.state) != Weight::Zero()))) {
      const auto weight =
          element.state == kNoStateId
              ? element.weight
              : Times(element.weight, fst_->Final(element.state));
      auto ilabel = final_ilabel_;
      auto olabel = final_olabel_;
      // Factors of the final weight become arcs to states that have no input
      // counterpart (state == kNoStateId).
      for (FactorIterator fiter(weight); !fiter.Done(); fiter.Next()) {
        auto pair = fiter.Value();
        const auto dest =
            FindState(Element(kNoStateId, pair.second.Quantize(delta_)));
        EmplaceArc(s, ilabel, olabel, std::move(pair.first), dest);
        if (increment_final_ilabel_) ++ilabel;
        if (increment_final_olabel_) ++olabel;
      }
    }
    SetArcs(s);
  }

 private:
  // Equality function for Elements, assume weights have been quantized.
  class ElementEqual {
   public:
    bool operator()(const Element &x, const Element &y) const {
      return x.state == y.state && x.weight == y.weight;
    }
  };

  // Hash function for Elements to Fst states.
  class ElementKey {
   public:
    size_t operator()(const Element &x) const {
      static constexpr size_t kPrime = 7853;
      // The state ID is converted to size_t before multiplying so that the
      // arithmetic wraps (well defined) instead of overflowing a signed type
      // for large state IDs.
      return static_cast<size_t>(x.state) * kPrime + x.weight.Hash();
    }
  };

  using ElementMap =
      std::unordered_map<Element, StateId, ElementKey, ElementEqual>;

  std::unique_ptr<const Fst<Arc>> fst_;
  float delta_;
  uint8_t mode_;        // Factoring arc and/or final weights.
  Label final_ilabel_;  // ilabel of arc created when factoring final weights.
  Label final_olabel_;  // olabel of arc created when factoring final weights.
  bool increment_final_ilabel_;    // When factoring final weights results in
  bool increment_final_olabel_;    // multiple arcs, increment labels?
  std::vector<Element> elements_;  // Mapping from FST state to Element.
  ElementMap element_map_;         // Mapping from Element to FST state.
  // Mapping between old/new StateId for states that do not need to be factored
  // when mode_ is 0 or kFactorFinalWeights.
  std::vector<StateId> unfactored_;
};
} // namespace internal
// FactorWeightFst is parameterized by a FactorIterator as described above.
// Weight factoring produces a transducer equivalent to the input in which the
// path weights have been factored according to the FactorIterator; states and
// transitions are added as necessary. The algorithm generalizes to arbitrary
// weights the second step of the input epsilon-normalization algorithm.
//
// This class attaches interface to implementation and handles reference
// counting, delegating most methods to ImplToFst.
template <class A, class FactorIterator>
class FactorWeightFst
    : public ImplToFst<internal::FactorWeightFstImpl<A, FactorIterator>> {
 public:
  using Arc = A;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;
  using Store = DefaultCacheStore<Arc>;
  using State = typename Store::State;
  using Impl = internal::FactorWeightFstImpl<Arc, FactorIterator>;

  // The iterator specializations need access to the implementation.
  friend class ArcIterator<FactorWeightFst>;
  friend class StateIterator<FactorWeightFst>;

  // Factors `fst` using default options.
  explicit FactorWeightFst(const Fst<Arc> &fst)
      : FactorWeightFst(fst, FactorWeightOptions<Arc>()) {}

  // Factors `fst` using the supplied options.
  FactorWeightFst(const Fst<Arc> &fst, const FactorWeightOptions<Arc> &opts)
      : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}

  // See Fst<>::Copy() for doc.
  FactorWeightFst(const FactorWeightFst &fst, bool copy)
      : ImplToFst<Impl>(fst, copy) {}

  // Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc.
  FactorWeightFst *Copy(bool copy = false) const override {
    return new FactorWeightFst(*this, copy);
  }

  inline void InitStateIterator(StateIteratorData<Arc> *data) const override;

  // Forwards to the implementation.
  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
    auto *impl = GetMutableImpl();
    impl->InitArcIterator(s, data);
  }

 private:
  using ImplToFst<Impl>::GetImpl;
  using ImplToFst<Impl>::GetMutableImpl;

  // Assignment is disabled.
  FactorWeightFst &operator=(const FactorWeightFst &) = delete;
};
// StateIterator specialization for FactorWeightFst; all of the work is done by
// CacheStateIterator on the FST's implementation.
template <class Arc, class FactorIterator>
class StateIterator<FactorWeightFst<Arc, FactorIterator>>
    : public CacheStateIterator<FactorWeightFst<Arc, FactorIterator>> {
  using FST = FactorWeightFst<Arc, FactorIterator>;
  using Base = CacheStateIterator<FST>;

 public:
  explicit StateIterator(const FactorWeightFst<Arc, FactorIterator> &fst)
      : Base(fst, fst.GetMutableImpl()) {}
};
// ArcIterator specialization for FactorWeightFst. Makes sure the arcs of the
// requested state have been computed before iteration starts.
template <class Arc, class FactorIterator>
class ArcIterator<FactorWeightFst<Arc, FactorIterator>>
    : public CacheArcIterator<FactorWeightFst<Arc, FactorIterator>> {
  using FST = FactorWeightFst<Arc, FactorIterator>;
  using Base = CacheArcIterator<FST>;

 public:
  using StateId = typename Arc::StateId;

  ArcIterator(const FactorWeightFst<Arc, FactorIterator> &fst, StateId s)
      : Base(fst.GetMutableImpl(), s) {
    // Nothing to do when the state has already been expanded.
    if (fst.GetImpl()->HasArcs(s)) return;
    fst.GetMutableImpl()->Expand(s);
  }
};
// Installs a StateIterator specialized for this FST type as the iterator base.
template <class Arc, class FactorIterator>
inline void FactorWeightFst<Arc, FactorIterator>::InitStateIterator(
    StateIteratorData<Arc> *data) const {
  using Iterator = StateIterator<FactorWeightFst<Arc, FactorIterator>>;
  data->base = std::make_unique<Iterator>(*this);
}
} // namespace fst
#endif // FST_FACTOR_WEIGHT_H_