|
|
// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Union weight set and associated semiring operation definitions.
//
// TODO(riley): add in normalizer functor.
#ifndef FST_UNION_WEIGHT_H_
#define FST_UNION_WEIGHT_H_
#include <climits>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <istream>
#include <list>
#include <ostream>
#include <random>
#include <sstream>
#include <string>
#include <utility>
#include <fst/util.h>
#include <fst/weight.h>
namespace fst {
// Example UnionWeightOptions for UnionWeight template below. The Merge
// operation is used to collapse elements of the set and the Compare function
// to efficiently implement the merge. In the simplest case, merge would just
// apply with equality of set elements so the result is a set (and not a
// multiset). More generally, this can be used to maintain the multiplicity or
// other such weight associated with the set elements (cf. Gallic weights).
// template <class W>
// struct UnionWeightOptions {
// // Comparison function C is a total order on W that is monotonic w.r.t.
// // Times: for all a, b, c != Zero(): C(a, b) => C(ca, cb) and is
// // anti-monotonic w.r.t. Divide: C(a, b) => C(c/b, c/a).
// //
// // For all a, b: exactly one of C(a, b), C(b, a) or a ~ b must be true,
// // where ~ is an equivalence relation on W. Also we require a ~ b iff
// // a.Reverse() ~ b.Reverse().
// using Compare = NaturalLess<W>;
//
// // How to combine two weights if a ~ b as above. For all a, b: a ~ b =>
// // Merge(a, b) ~ a. Merge must define a semiring endomorphism from the
// // unmerged weight sets to the merged weight sets.
// struct Merge {
// W operator()(const W &w1, const W &w2) const { return w1; }
// };
//
// // For ReverseWeight.
// using ReverseOptions = UnionWeightOptions<ReverseWeight>;
// };
// Forward declarations of the union weight and its two element iterators.
template <class W, class O>
class UnionWeight;

template <class W, class O>
class UnionWeightIterator;

template <class W, class O>
class UnionWeightReverseIterator;

// Equality on union weights; declared here ahead of its definition later in
// this file.
template <class W, class O>
bool operator==(const UnionWeight<W, O> &, const UnionWeight<W, O> &);
// Semiring that uses Times() and One() from W and union and the empty set
// for Plus() and Zero(), respectively. Template argument O specifies the union
// weight options as above.
template <class W, class O>
class UnionWeight {
 public:
  using Weight = W;
  using Compare = typename O::Compare;
  using Merge = typename O::Merge;

  using ReverseWeight =
      UnionWeight<typename W::ReverseWeight, typename O::ReverseOptions>;

  friend class UnionWeightIterator<W, O>;
  friend class UnionWeightReverseIterator<W, O>;

  // Sets represented as first_ weight + rest_ weights. Uses first_ as
  // NoWeight() to indicate the union weight Zero() as the empty set. Uses
  // rest_ containing NoWeight() to indicate the union weight NoWeight().

  // Constructs the empty set, i.e. the union weight Zero().
  UnionWeight() : first_(W::NoWeight()) {}

  // Constructs the singleton set {weight}. When weight is not a member, the
  // result uses the same layout as NoWeight(): W::Zero() in first_ and
  // W::NoWeight() in rest_. Storing the non-member weight itself in first_
  // would make Size() report 0 (Size() keys on first_.Member()), so Member()
  // would return true and the object would compare equal to Zero().
  // NOTE(review): relies on W::Zero() being a member, as NoWeight() below
  // already does.
  explicit UnionWeight(W weight) : first_(std::move(weight)) {
    if (!first_.Member()) {
      first_ = W::Zero();
      rest_.push_back(W::NoWeight());
    }
  }

  // The empty set. The instance is allocated once on first use and is never
  // deleted.
  static const UnionWeight &Zero() {
    static const auto *const zero = new UnionWeight;
    return *zero;
  }

  // The singleton set {W::One()}. Allocated once on first use, never deleted.
  static const UnionWeight &One() {
    static const auto *const one = new UnionWeight(W::One());
    return *one;
  }

  // The non-member union weight, stored as first_ = W::Zero() and
  // rest_ = {W::NoWeight()}. Allocated once on first use, never deleted.
  static const UnionWeight &NoWeight() {
    static const auto *const no_weight =
        new UnionWeight(W::Zero(), W::NoWeight());
    return *no_weight;
  }

  // Type name: the underlying weight's type name followed by "_union".
  static const std::string &Type() {
    static const std::string *const type =
        new std::string(W::Type() + "_union");
    return *type;
  }

  // Keeps only the left-semiring, right-semiring, commutative and idempotent
  // bits of W's properties.
  static constexpr uint64_t Properties() {
    return W::Properties() &
           (kLeftSemiring | kRightSemiring | kCommutative | kIdempotent);
  }

  bool Member() const;

  std::istream &Read(std::istream &strm);

  std::ostream &Write(std::ostream &strm) const;

  size_t Hash() const;

  UnionWeight Quantize(float delta = kDelta) const;

  ReverseWeight Reverse() const;

  // These operations combined with the UnionWeightIterator and
  // UnionWeightReverseIterator provide the access and mutation of the union
  // weight internal elements.

  // Common initializer among constructors; clears existing UnionWeight.
  void Clear() {
    first_ = W::NoWeight();
    rest_.clear();
  }

  // Number of stored elements; 0 whenever first_ is not a member.
  size_t Size() const { return first_.Member() ? rest_.size() + 1 : 0; }

  // Last stored element: first_ when rest_ is empty, otherwise rest_.back().
  const W &Back() const { return rest_.empty() ? first_ : rest_.back(); }

  // When srt is true, assumes elements added sorted w.r.t Compare and merging
  // of weights performed as needed. Otherwise, just ensures first_ is the
  // least element wrt Compare.
  void PushBack(W weight, bool srt);

  // Sorts the elements of the set. Assumes that first_, if present, is the
  // least element.
  void Sort() { rest_.sort(comp_); }

 private:
  // Mutable access to the last stored element (see the const overload).
  W &Back() {
    if (rest_.empty()) {
      return first_;
    } else {
      return rest_.back();
    }
  }

  // Builds the two-element layout directly: w1 in first_ and w2 as the only
  // entry of rest_. Used by NoWeight().
  UnionWeight(W w1, W w2) : first_(std::move(w1)), rest_(1, std::move(w2)) {}

  W first_;            // First weight in set.
  std::list<W> rest_;  // Remaining weights in set.
  Compare comp_;
  Merge merge_;
};
// Appends weight to the set.
//
// * A non-member weight is always appended to rest_.
//   NOTE(review): if first_ is not a member at that point it stays that way,
//   so Size() keeps reporting 0 -- confirm this is intended.
// * The first member weight pushed onto an empty set becomes first_.
// * With srt set, weight is compared against the current last element: it is
//   appended when Compare orders the last element before it, otherwise it is
//   merged into the last element.
// * Without srt, weight is appended when Compare orders first_ before it;
//   otherwise the old first_ is appended and weight becomes the new first_.
template <class W, class O>
void UnionWeight<W, O>::PushBack(W weight, bool srt) {
  if (!weight.Member()) {
    rest_.push_back(std::move(weight));
    return;
  }
  if (!first_.Member()) {
    first_ = std::move(weight);
    return;
  }
  if (srt) {
    W &last = Back();
    if (comp_(last, weight)) {
      rest_.push_back(std::move(weight));
    } else {
      last = merge_(last, std::move(weight));
    }
    return;
  }
  if (comp_(first_, weight)) {
    rest_.push_back(std::move(weight));
  } else {
    rest_.push_back(first_);
    first_ = std::move(weight);
  }
}
// Traverses union weight in the forward direction: first_ is visited first,
// followed by the elements of rest_ from front to back. The iterator keeps
// references into the union weight it was built from.
template <class W, class O>
class UnionWeightIterator {
 public:
  explicit UnionWeightIterator(const UnionWeight<W, O> &weight)
      : first_(weight.first_),
        rest_(weight.rest_),
        init_(true),
        it_(rest_.begin()) {}

  // While still positioned on first_, iteration is done iff first_ is not a
  // member; afterwards it is done once rest_ is exhausted.
  bool Done() const {
    if (init_) return !first_.Member();
    return it_ == rest_.end();
  }

  // Current element: first_ in the initial state, else the rest_ element.
  const W &Value() const {
    if (init_) return first_;
    return *it_;
  }

  // Leaves the initial state on the first call; later calls advance in rest_.
  void Next() {
    if (!init_) {
      ++it_;
      return;
    }
    init_ = false;
  }

  // Returns to the initial state, positioned on first_.
  void Reset() {
    it_ = rest_.begin();
    init_ = true;
  }

 private:
  const W &first_;
  const std::list<W> &rest_;
  bool init_;  // in the initialized state?
  typename std::list<W>::const_iterator it_;
};
// Traverses union weight in backward direction: the elements of rest_ are
// visited from back to front, followed by first_. The iterator keeps
// references into the union weight it was built from.
template <typename L, class O>
class UnionWeightReverseIterator {
 public:
  explicit UnionWeightReverseIterator(const UnionWeight<L, O> &weight)
      : first_(weight.first_),
        rest_(weight.rest_),
        fin_(!first_.Member()),
        it_(rest_.rbegin()) {}

  // True once the traversal has moved past first_ (or immediately, when
  // first_ is not a member).
  bool Done() const { return fin_; }

  // Current element: the rest_ element under the cursor, or first_ once
  // rest_ has been exhausted.
  const L &Value() const {
    if (it_ != rest_.rend()) return *it_;
    return first_;
  }

  // Advances within rest_; once rest_ is exhausted the next call marks the
  // traversal as finished.
  void Next() {
    if (it_ != rest_.rend()) {
      ++it_;
      return;
    }
    fin_ = true;
  }

  // Restarts the traversal from the back of rest_.
  void Reset() {
    it_ = rest_.rbegin();
    fin_ = !first_.Member();
  }

 private:
  const L &first_;
  const std::list<L> &rest_;
  bool fin_;  // in the final state?
  typename std::list<L>::const_reverse_iterator it_;
};
// UnionWeight member functions follow that require UnionWeightIterator.
// Replaces the contents of this union weight with one read from istrm: an
// int32 element count followed by that many element weights, each appended
// with PushBack(..., true). Returns istrm so callers can test its state.
//
// The count starts at 0 so that a failed read of the count never leaves an
// indeterminate loop bound, and reading stops as soon as the stream reports
// failure so that a truncated or corrupt input cannot push unread weights.
template <class W, class O>
inline std::istream &UnionWeight<W, O>::Read(std::istream &istrm) {
  Clear();
  int32_t size = 0;
  ReadType(istrm, &size);
  for (int32_t i = 0; i < size; ++i) {
    W weight;
    ReadType(istrm, &weight);
    if (!istrm) break;  // Do not store an element that was not actually read.
    PushBack(std::move(weight), true);
  }
  return istrm;
}
// Writes this union weight to ostrm: Size() as an int32, followed by every
// element in forward iteration order. Returns ostrm.
template <class W, class O>
inline std::ostream &UnionWeight<W, O>::Write(std::ostream &ostrm) const {
  const int32_t num_elements = Size();
  WriteType(ostrm, num_elements);
  UnionWeightIterator<W, O> elem(*this);
  while (!elem.Done()) {
    WriteType(ostrm, elem.Value());
    elem.Next();
  }
  return ostrm;
}
// Returns true when Size() is at most 1; otherwise returns true only if every
// element visited by the forward iterator is itself a member.
template <class W, class O>
inline bool UnionWeight<W, O>::Member() const {
  if (Size() > 1) {
    UnionWeightIterator<W, O> elem(*this);
    while (!elem.Done()) {
      if (!elem.Value().Member()) return false;
      elem.Next();
    }
  }
  return true;
}
// Returns a new union weight built by quantizing every element with delta and
// appending the results in order with PushBack(..., true).
template <class W, class O>
inline UnionWeight<W, O> UnionWeight<W, O>::Quantize(float delta) const {
  UnionWeight quantized;
  UnionWeightIterator<W, O> elem(*this);
  while (!elem.Done()) {
    quantized.PushBack(elem.Value().Quantize(delta), true);
    elem.Next();
  }
  return quantized;
}
// Returns the reverse union weight: every element is reversed and appended
// with PushBack(..., false), after which the result is sorted.
template <class W, class O>
inline typename UnionWeight<W, O>::ReverseWeight UnionWeight<W, O>::Reverse()
    const {
  ReverseWeight reversed;
  UnionWeightIterator<W, O> elem(*this);
  while (!elem.Done()) {
    reversed.PushBack(elem.Value().Reverse(), false);
    elem.Next();
  }
  reversed.Sort();
  return reversed;
}
// Combines the element hashes in forward iteration order. At each step the
// running value is rotated left by lshift bits and XORed with the element's
// hash. The empty set hashes to 0.
template <class W, class O>
inline size_t UnionWeight<W, O>::Hash() const {
  static constexpr int lshift = 5;
  static constexpr int rshift = CHAR_BIT * sizeof(size_t) - lshift;
  size_t h = 0;
  UnionWeightIterator<W, O> elem(*this);
  while (!elem.Done()) {
    const size_t rotated = (h << lshift) ^ (h >> rshift);
    h = rotated ^ elem.Value().Hash();
    elem.Next();
  }
  return h;
}
// Element-wise equality: true iff both union weights have the same Size() and
// their elements compare equal pairwise in forward iteration order.
// Requires union weight has been canonicalized.
template <class W, class O>
inline bool operator==(const UnionWeight<W, O> &w1,
                       const UnionWeight<W, O> &w2) {
  if (w1.Size() != w2.Size()) return false;
  UnionWeightIterator<W, O> lhs(w1);
  UnionWeightIterator<W, O> rhs(w2);
  while (!lhs.Done()) {
    if (lhs.Value() != rhs.Value()) return false;
    lhs.Next();
    rhs.Next();
  }
  return true;
}
// Negation of operator== on union weights.
// Requires union weight has been canonicalized.
template <class W, class O>
inline bool operator!=(const UnionWeight<W, O> &w1,
                       const UnionWeight<W, O> &w2) {
  const bool equal = (w1 == w2);
  return !equal;
}
// Requires union weight has been canonicalized.
template <class W, class O> inline bool ApproxEqual(const UnionWeight<W, O> &w1, const UnionWeight<W, O> &w2, float delta = kDelta) { if (w1.Size() != w2.Size()) return false; UnionWeightIterator<W, O> it1(w1); UnionWeightIterator<W, O> it2(w2); for (; !it1.Done(); it1.Next(), it2.Next()) { if (!ApproxEqual(it1.Value(), it2.Value(), delta)) return false; } return true; }
template <class W, class O> inline std::ostream &operator<<(std::ostream &ostrm, const UnionWeight<W, O> &weight) { UnionWeightIterator<W, O> it(weight); if (it.Done()) { return ostrm << "EmptySet"; } else if (!weight.Member()) { return ostrm << "BadSet"; } else { CompositeWeightWriter writer(ostrm); writer.WriteBegin(); for (; !it.Done(); it.Next()) writer.WriteElement(it.Value()); writer.WriteEnd(); } return ostrm; }
template <class W, class O> inline std::istream &operator>>(std::istream &istrm, UnionWeight<W, O> &weight) { std::string s; istrm >> s; if (s == "EmptySet") { weight = UnionWeight<W, O>::Zero(); } else if (s == "BadSet") { weight = UnionWeight<W, O>::NoWeight(); } else { weight = UnionWeight<W, O>::Zero(); std::istringstream sstrm(s); CompositeWeightReader reader(sstrm); reader.ReadBegin(); bool more = true; while (more) { W v; more = reader.ReadElement(&v); weight.PushBack(v, true); } reader.ReadEnd(); } return istrm; }
// Union of two union weights. Returns NoWeight() if either operand is not a
// member, and returns the other operand when one of them equals Zero().
// Otherwise the two element sequences are merged in the order given by
// O::Compare, with every element appended via PushBack(..., true) so that
// equivalent neighbours are combined by the Merge functor.
template <class W, class O>
inline UnionWeight<W, O> Plus(const UnionWeight<W, O> &w1,
                              const UnionWeight<W, O> &w2) {
  if (!w1.Member() || !w2.Member()) return UnionWeight<W, O>::NoWeight();
  if (w1 == UnionWeight<W, O>::Zero()) return w2;
  if (w2 == UnionWeight<W, O>::Zero()) return w1;
  UnionWeightIterator<W, O> it1(w1);
  UnionWeightIterator<W, O> it2(w2);
  UnionWeight<W, O> sum;
  typename O::Compare comp;
  while (!it1.Done() && !it2.Done()) {
    // Bind by reference: the elements live in w1/w2, which are not modified
    // here, and PushBack takes its own copy. This avoids an extra copy of
    // both weights on every merge step.
    const W &v1 = it1.Value();
    const W &v2 = it2.Value();
    if (comp(v1, v2)) {
      sum.PushBack(v1, true);
      it1.Next();
    } else {
      sum.PushBack(v2, true);
      it2.Next();
    }
  }
  // At most one of the two inputs still has elements left.
  for (; !it1.Done(); it1.Next()) sum.PushBack(it1.Value(), true);
  for (; !it2.Done(); it2.Next()) sum.PushBack(it2.Value(), true);
  return sum;
}
// Product of two union weights. Returns NoWeight() if either operand is not a
// member and Zero() if either operand equals Zero(). Otherwise, for each
// element of w1 it forms the set of its products with every element of w2 and
// accumulates those sets with Plus().
template <class W, class O>
inline UnionWeight<W, O> Times(const UnionWeight<W, O> &w1,
                               const UnionWeight<W, O> &w2) {
  if (!w1.Member() || !w2.Member()) return UnionWeight<W, O>::NoWeight();
  const bool has_zero_operand = w1 == UnionWeight<W, O>::Zero() ||
                                w2 == UnionWeight<W, O>::Zero();
  if (has_zero_operand) return UnionWeight<W, O>::Zero();

  UnionWeight<W, O> product;
  for (UnionWeightIterator<W, O> outer(w1); !outer.Done(); outer.Next()) {
    UnionWeight<W, O> partial;
    for (UnionWeightIterator<W, O> inner(w2); !inner.Done(); inner.Next()) {
      partial.PushBack(Times(outer.Value(), inner.Value()), true);
    }
    product = Plus(product, partial);
  }
  return product;
}
// Quotient of two union weights. Returns NoWeight() if either operand is not
// a member and Zero() if either operand equals Zero(). Division is only
// carried out when at least one operand holds exactly one element; when both
// hold more than one, the result is NoWeight().
//
// The divisor is walked with the reverse iterator. With a single-element
// dividend this visits the divisor's elements from last to first.
// NOTE(review): presumably this keeps the quotients in Compare order, given
// the anti-monotonicity w.r.t. Divide required of Compare -- confirm.
template <class W, class O>
inline UnionWeight<W, O> Divide(const UnionWeight<W, O> &w1,
                                const UnionWeight<W, O> &w2, DivideType typ) {
  if (!w1.Member() || !w2.Member()) return UnionWeight<W, O>::NoWeight();
  const bool has_zero_operand = w1 == UnionWeight<W, O>::Zero() ||
                                w2 == UnionWeight<W, O>::Zero();
  if (has_zero_operand) return UnionWeight<W, O>::Zero();

  const bool single_dividend = w1.Size() == 1;
  const bool single_divisor = w2.Size() == 1;
  if (!single_dividend && !single_divisor) {
    return UnionWeight<W, O>::NoWeight();
  }

  UnionWeightIterator<W, O> num(w1);
  UnionWeightReverseIterator<W, O> den(w2);
  UnionWeight<W, O> quot;
  if (single_dividend) {
    // One dividend element divided by each divisor element.
    while (!den.Done()) {
      quot.PushBack(Divide(num.Value(), den.Value(), typ), true);
      den.Next();
    }
  } else {
    // Each dividend element divided by the single divisor element.
    while (!num.Done()) {
      quot.PushBack(Divide(num.Value(), den.Value(), typ), true);
      num.Next();
    }
  }
  return quot;
}
// This function object generates weights over the union of weights for the
// underlying generators for the template weight types. This is intended
// primarily for testing.
template <class W, class O>
class WeightGenerate<UnionWeight<W, O>> {
 public:
  using Weight = UnionWeight<W, O>;
  using Generate = WeightGenerate<W>;

  // seed initializes both the local random engine and the element generator.
  // allow_zero enables returning Weight::Zero(); num_random_weights sets the
  // upper end of the sampling range used in operator().
  // NOTE(review): the second argument passed to the element generator is
  // presumably its own allow_zero flag -- confirm.
  explicit WeightGenerate(uint64_t seed = std::random_device()(),
                          bool allow_zero = true,
                          size_t num_random_weights = kNumRandomWeights)
      : rand_(seed),
        allow_zero_(allow_zero),
        num_random_weights_(num_random_weights),
        generate_(seed, false) {}

  // Draws one union weight: Zero() when allowed and the sample lands on the
  // top of the range; otherwise, with probability .5 each, either a single
  // generated element or the Plus() of two generated elements.
  Weight operator()() const {
    std::uniform_int_distribution<> pick(
        0, num_random_weights_ + allow_zero_ - 1);
    const int sample = pick(rand_);
    if (allow_zero_ && sample == num_random_weights_) return Weight::Zero();

    std::bernoulli_distribution coin(.5);
    if (coin(rand_)) return Weight(generate_());
    return Plus(Weight(generate_()), Weight(generate_()));
  }

 private:
  mutable std::mt19937_64 rand_;
  const bool allow_zero_;
  const size_t num_random_weights_;
  const Generate generate_;
};
} // namespace fst
#endif // FST_UNION_WEIGHT_H_
|