|
|
// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Weights consisting of sets (of integral Labels) and
// associated semiring operation definitions using intersect
// and union.
#ifndef FST_SET_WEIGHT_H_
#define FST_SET_WEIGHT_H_
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <ios>
#include <istream>
#include <list>
#include <optional>
#include <ostream>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include <fst/log.h>
#include <fst/union-weight.h>
#include <fst/util.h>
#include <fst/weight.h>
#include <string_view>
namespace fst {
inline constexpr int kSetEmpty = 0; // Label for the empty set.
inline constexpr int kSetUniv = -1; // Label for the universal set.
inline constexpr int kSetBad = -2; // Label for a non-set.
inline constexpr char kSetSeparator = '_'; // Label separator in sets.
// Determines whether to use (intersect, union) or (union, intersect)
// as (+, *) for the semiring. SET_INTERSECT_UNION_RESTRICTED is a
// restricted version of (intersect, union) that requires summed
// arguments to be equal (or an error is signalled), useful for
// algorithms that require a unique labelled path weight. SET_BOOLEAN
// treats all non-Zero() elements as equivalent (with Zero() ==
// UnivSet()), useful for algorithms that don't really depend on the
// detailed sets.
enum SetType { SET_INTERSECT_UNION = 0, SET_UNION_INTERSECT = 1, SET_INTERSECT_UNION_RESTRICT = 2, SET_BOOLEAN = 3 };
template <class> class SetWeightIterator;
// Set semiring of integral labels.
template <typename L, SetType S = SET_INTERSECT_UNION> class SetWeight { public: using Label = L; using ReverseWeight = SetWeight<Label, S>; using Iterator = SetWeightIterator<SetWeight>; friend class SetWeightIterator<SetWeight>; // Allow type-converting copy and move constructors private access.
template <typename L2, SetType S2> friend class SetWeight;
SetWeight() = default;
// Input should be positive, sorted and unique.
template <typename Iterator> SetWeight(const Iterator begin, const Iterator end) { for (auto iter = begin; iter != end; ++iter) PushBack(*iter); }
// Input should be positive. (Non-positive value has
// special internal meaning w.r.t. integral constants above.)
explicit SetWeight(Label label) { PushBack(label); }
template <SetType S2> explicit SetWeight(const SetWeight<Label, S2> &w) : first_(w.first_), rest_(w.rest_) {}
template <SetType S2> explicit SetWeight(SetWeight<Label, S2> &&w) : first_(w.first_), rest_(std::move(w.rest_)) { w.Clear(); }
template <SetType S2> SetWeight &operator=(const SetWeight<Label, S2> &w) { first_ = w.first_; rest_ = w.rest_; return *this; }
template <SetType S2> SetWeight &operator=(SetWeight<Label, S2> &&w) { first_ = w.first_; rest_ = std::move(w.rest_); w.Clear(); return *this; }
static const SetWeight &Zero() { return S == SET_UNION_INTERSECT ? EmptySet() : UnivSet(); }
static const SetWeight &One() { return S == SET_UNION_INTERSECT ? UnivSet() : EmptySet(); }
static const SetWeight &NoWeight() { static const auto *const no_weight = new SetWeight(Label(kSetBad)); return *no_weight; }
static const std::string &Type() { static const std::string *const type = new std::string(S == SET_UNION_INTERSECT ? "union_intersect_set" : (S == SET_INTERSECT_UNION ? "intersect_union_set" : (S == SET_INTERSECT_UNION_RESTRICT ? "restricted_set_intersect_union" : "boolean_set"))); return *type; }
bool Member() const;
std::istream &Read(std::istream &strm);
std::ostream &Write(std::ostream &strm) const;
size_t Hash() const;
SetWeight Quantize(float delta = kDelta) const { return *this; }
ReverseWeight Reverse() const;
static constexpr uint64_t Properties() { return kIdempotent | kLeftSemiring | kRightSemiring | kCommutative; }
// These operations combined with the SetWeightIterator
// provide the access and mutation of the set internal elements.
// The empty set.
static const SetWeight &EmptySet() { static const auto *const empty = new SetWeight(Label(kSetEmpty)); return *empty; }
// The univeral set.
static const SetWeight &UnivSet() { static const auto *const univ = new SetWeight(Label(kSetUniv)); return *univ; }
// Clear existing SetWeight.
void Clear() { first_ = kSetEmpty; rest_.clear(); }
size_t Size() const { return first_ == kSetEmpty ? 0 : rest_.size() + 1; }
Label Back() { if (rest_.empty()) { return first_; } else { return rest_.back(); } }
// Caller must add in sort order and be unique (or error signalled).
// Input should also be positive. Non-positive value (for the first
// push) has special internal meaning w.r.t. integral constants above.
void PushBack(Label label) { if (first_ == kSetEmpty) { first_ = label; } else { if (label <= Back() || label <= 0) { FSTERROR() << "SetWeight: labels must be positive, added" << " in sort order and be unique."; rest_.push_back(Label(kSetBad)); } rest_.push_back(label); } }
private: Label first_ = kSetEmpty; // First label in set (kSetEmpty if empty).
std::list<Label> rest_; // Remaining labels in set.
};
// Traverses set in forward direction.
template <class SetWeight_> class SetWeightIterator { public: using Weight = SetWeight_; using Label = typename Weight::Label;
explicit SetWeightIterator(const Weight &w) : first_(w.first_), rest_(w.rest_), init_(true), iter_(rest_.begin()) {}
bool Done() const { if (init_) { return first_ == kSetEmpty; } else { return iter_ == rest_.end(); } }
const Label &Value() const { return init_ ? first_ : *iter_; }
void Next() { if (init_) { init_ = false; } else { ++iter_; } }
void Reset() { init_ = true; iter_ = rest_.begin(); }
private: const Label &first_; const decltype(Weight::rest_) &rest_; bool init_; // In the initialized state?
typename decltype(Weight::rest_)::const_iterator iter_; };
// SetWeight member functions follow that require SetWeightIterator
template <typename Label, SetType S> inline std::istream &SetWeight<Label, S>::Read(std::istream &strm) { Clear(); int32_t size; ReadType(strm, &size); for (int32_t i = 0; i < size; ++i) { Label label; ReadType(strm, &label); PushBack(label); } return strm; }
template <typename Label, SetType S> inline std::ostream &SetWeight<Label, S>::Write(std::ostream &strm) const { const int32_t size = Size(); WriteType(strm, size); for (Iterator iter(*this); !iter.Done(); iter.Next()) { WriteType(strm, iter.Value()); } return strm; }
template <typename Label, SetType S> inline bool SetWeight<Label, S>::Member() const { Iterator iter(*this); return iter.Value() != Label(kSetBad); }
template <typename Label, SetType S> inline typename SetWeight<Label, S>::ReverseWeight SetWeight<Label, S>::Reverse() const { return *this; }
template <typename Label, SetType S> inline size_t SetWeight<Label, S>::Hash() const { using Weight = SetWeight<Label, S>; if (S == SET_BOOLEAN) { return *this == Weight::Zero() ? 0 : 1; } else { size_t h = 0; for (Iterator iter(*this); !iter.Done(); iter.Next()) { h ^= h << 1 ^ iter.Value(); } return h; } }
// Default ==
template <typename Label, SetType S> inline bool operator==(const SetWeight<Label, S> &w1, const SetWeight<Label, S> &w2) { if (w1.Size() != w2.Size()) return false; using Iterator = typename SetWeight<Label, S>::Iterator; Iterator iter1(w1); Iterator iter2(w2); for (; !iter1.Done(); iter1.Next(), iter2.Next()) { if (iter1.Value() != iter2.Value()) return false; } return true; }
// Boolean ==
template <typename Label> inline bool operator==(const SetWeight<Label, SET_BOOLEAN> &w1, const SetWeight<Label, SET_BOOLEAN> &w2) { // x == kSetEmpty if x \nin {kUnivSet, kSetBad}
if (!w1.Member() || !w2.Member()) return false; using Iterator = typename SetWeight<Label, SET_BOOLEAN>::Iterator; Iterator iter1(w1); Iterator iter2(w2); Label label1 = iter1.Done() ? kSetEmpty : iter1.Value(); Label label2 = iter2.Done() ? kSetEmpty : iter2.Value(); if (label1 == kSetUniv) return label2 == kSetUniv; if (label2 == kSetUniv) return label1 == kSetUniv; return true; }
template <typename Label, SetType S> inline bool operator!=(const SetWeight<Label, S> &w1, const SetWeight<Label, S> &w2) { return !(w1 == w2); }
template <typename Label, SetType S> inline bool ApproxEqual(const SetWeight<Label, S> &w1, const SetWeight<Label, S> &w2, float delta = kDelta) { return w1 == w2; }
template <typename Label, SetType S> inline std::ostream &operator<<(std::ostream &strm, const SetWeight<Label, S> &weight) { typename SetWeight<Label, S>::Iterator iter(weight); if (iter.Done()) { return strm << "EmptySet"; } else if (iter.Value() == Label(kSetUniv)) { return strm << "UnivSet"; } else if (iter.Value() == Label(kSetBad)) { return strm << "BadSet"; } else { for (size_t i = 0; !iter.Done(); ++i, iter.Next()) { if (i > 0) strm << kSetSeparator; strm << iter.Value(); } } return strm; }
template <typename Label, SetType S> inline std::istream &operator>>(std::istream &strm, SetWeight<Label, S> &weight) { std::string str; strm >> str; using Weight = SetWeight<Label, S>; if (str == "EmptySet") { weight = Weight(Label(kSetEmpty)); } else if (str == "UnivSet") { weight = Weight(Label(kSetUniv)); } else { weight.Clear(); for (std::string_view sv : StrSplit(str, kSetSeparator)) { auto maybe_label = ParseInt64(sv); if (!maybe_label.has_value()) { strm.clear(std::ios::badbit); break; } weight.PushBack(*maybe_label); } } return strm; }
template <typename Label, SetType S> inline SetWeight<Label, S> Union(const SetWeight<Label, S> &w1, const SetWeight<Label, S> &w2) { using Weight = SetWeight<Label, S>; using Iterator = typename SetWeight<Label, S>::Iterator; if (!w1.Member() || !w2.Member()) return Weight::NoWeight(); if (w1 == Weight::EmptySet()) return w2; if (w2 == Weight::EmptySet()) return w1; if (w1 == Weight::UnivSet()) return w1; if (w2 == Weight::UnivSet()) return w2; Iterator it1(w1); Iterator it2(w2); Weight result; while (!it1.Done() && !it2.Done()) { const auto v1 = it1.Value(); const auto v2 = it2.Value(); if (v1 < v2) { result.PushBack(v1); it1.Next(); } else if (v1 > v2) { result.PushBack(v2); it2.Next(); } else { result.PushBack(v1); it1.Next(); it2.Next(); } } for (; !it1.Done(); it1.Next()) result.PushBack(it1.Value()); for (; !it2.Done(); it2.Next()) result.PushBack(it2.Value()); return result; }
template <typename Label, SetType S> inline SetWeight<Label, S> Intersect(const SetWeight<Label, S> &w1, const SetWeight<Label, S> &w2) { using Weight = SetWeight<Label, S>; using Iterator = typename SetWeight<Label, S>::Iterator; if (!w1.Member() || !w2.Member()) return Weight::NoWeight(); if (w1 == Weight::EmptySet()) return w1; if (w2 == Weight::EmptySet()) return w2; if (w1 == Weight::UnivSet()) return w2; if (w2 == Weight::UnivSet()) return w1; Iterator it1(w1); Iterator it2(w2); Weight result; while (!it1.Done() && !it2.Done()) { const auto v1 = it1.Value(); const auto v2 = it2.Value(); if (v1 < v2) { it1.Next(); } else if (v1 > v2) { it2.Next(); } else { result.PushBack(v1); it1.Next(); it2.Next(); } } return result; }
template <typename Label, SetType S> inline SetWeight<Label, S> Difference(const SetWeight<Label, S> &w1, const SetWeight<Label, S> &w2) { using Weight = SetWeight<Label, S>; using Iterator = typename SetWeight<Label, S>::Iterator; if (!w1.Member() || !w2.Member()) return Weight::NoWeight(); if (w1 == Weight::EmptySet()) return w1; if (w2 == Weight::EmptySet()) return w1; if (w2 == Weight::UnivSet()) return Weight::EmptySet(); Iterator it1(w1); Iterator it2(w2); Weight result; while (!it1.Done() && !it2.Done()) { const auto v1 = it1.Value(); const auto v2 = it2.Value(); if (v1 < v2) { result.PushBack(v1); it1.Next(); } else if (v1 > v2) { it2.Next(); } else { it1.Next(); it2.Next(); } } for (; !it1.Done(); it1.Next()) result.PushBack(it1.Value()); return result; }
// Default: Plus = Intersect.
template <typename Label, SetType S> inline SetWeight<Label, S> Plus(const SetWeight<Label, S> &w1, const SetWeight<Label, S> &w2) { return Intersect(w1, w2); }
// Plus = Union.
template <typename Label> inline SetWeight<Label, SET_UNION_INTERSECT> Plus( const SetWeight<Label, SET_UNION_INTERSECT> &w1, const SetWeight<Label, SET_UNION_INTERSECT> &w2) { return Union(w1, w2); }
// Plus = Set equality is required (for non-Zero() input). The
// restriction is useful (e.g., in determinization) to ensure the input
// has a unique labelled path weight.
template <typename Label> inline SetWeight<Label, SET_INTERSECT_UNION_RESTRICT> Plus( const SetWeight<Label, SET_INTERSECT_UNION_RESTRICT> &w1, const SetWeight<Label, SET_INTERSECT_UNION_RESTRICT> &w2) { using Weight = SetWeight<Label, SET_INTERSECT_UNION_RESTRICT>; if (!w1.Member() || !w2.Member()) return Weight::NoWeight(); if (w1 == Weight::Zero()) return w2; if (w2 == Weight::Zero()) return w1; if (w1 != w2) { FSTERROR() << "SetWeight::Plus: Unequal arguments " << "(non-unique labelled path weights?)" << " w1 = " << w1 << " w2 = " << w2; return Weight::NoWeight(); } return w1; }
// Plus = Or.
template <typename Label> inline SetWeight<Label, SET_BOOLEAN> Plus( const SetWeight<Label, SET_BOOLEAN> &w1, const SetWeight<Label, SET_BOOLEAN> &w2) { using Weight = SetWeight<Label, SET_BOOLEAN>; if (!w1.Member() || !w2.Member()) return Weight::NoWeight(); if (w1 == Weight::One()) return w1; if (w2 == Weight::One()) return w2; return Weight::Zero(); }
// Default: Times = Union.
template <typename Label, SetType S> inline SetWeight<Label, S> Times(const SetWeight<Label, S> &w1, const SetWeight<Label, S> &w2) { return Union(w1, w2); }
// Times = Intersect.
template <typename Label> inline SetWeight<Label, SET_UNION_INTERSECT> Times( const SetWeight<Label, SET_UNION_INTERSECT> &w1, const SetWeight<Label, SET_UNION_INTERSECT> &w2) { return Intersect(w1, w2); }
// Times = And.
template <typename Label> inline SetWeight<Label, SET_BOOLEAN> Times( const SetWeight<Label, SET_BOOLEAN> &w1, const SetWeight<Label, SET_BOOLEAN> &w2) { using Weight = SetWeight<Label, SET_BOOLEAN>; if (!w1.Member() || !w2.Member()) return Weight::NoWeight(); if (w1 == Weight::One()) return w2; return w1; }
// Divide = Difference.
template <typename Label, SetType S> inline SetWeight<Label, S> Divide(const SetWeight<Label, S> &w1, const SetWeight<Label, S> &w2, DivideType divide_type = DIVIDE_ANY) { return Difference(w1, w2); }
// Divide = dividend (or the universal set if the
// dividend == divisor).
template <typename Label> inline SetWeight<Label, SET_UNION_INTERSECT> Divide( const SetWeight<Label, SET_UNION_INTERSECT> &w1, const SetWeight<Label, SET_UNION_INTERSECT> &w2, DivideType divide_type = DIVIDE_ANY) { using Weight = SetWeight<Label, SET_UNION_INTERSECT>; if (!w1.Member() || !w2.Member()) return Weight::NoWeight(); if (w1 == w2) return Weight::UnivSet(); return w1; }
// Divide = Or Not.
template <typename Label> inline SetWeight<Label, SET_BOOLEAN> Divide( const SetWeight<Label, SET_BOOLEAN> &w1, const SetWeight<Label, SET_BOOLEAN> &w2, DivideType divide_type = DIVIDE_ANY) { using Weight = SetWeight<Label, SET_BOOLEAN>; if (!w1.Member() || !w2.Member()) return Weight::NoWeight(); if (w1 == Weight::One()) return w1; if (w2 == Weight::Zero()) return Weight::One(); return Weight::Zero(); }
// Converts between different set types.
template <typename Label, SetType S1, SetType S2> struct WeightConvert<SetWeight<Label, S1>, SetWeight<Label, S2>> { SetWeight<Label, S2> operator()(const SetWeight<Label, S1> &w1) const { using Iterator = SetWeightIterator<SetWeight<Label, S1>>; SetWeight<Label, S2> w2; for (Iterator iter(w1); !iter.Done(); iter.Next()) w2.PushBack(iter.Value()); return w2; } };
// This function object generates SetWeights that are random integer sets
// from {1, ... , alphabet_size}^{0, max_set_length} U { Zero }. This is
// intended primarily for testing.
template <class Label, SetType S> class WeightGenerate<SetWeight<Label, S>> { public: using Weight = SetWeight<Label, S>;
explicit WeightGenerate(uint64_t seed = std::random_device()(), bool allow_zero = true, size_t alphabet_size = kNumRandomWeights, size_t max_set_length = kNumRandomWeights) : allow_zero_(allow_zero), alphabet_size_(alphabet_size), max_set_length_(max_set_length) {}
Weight operator()() const { const int n = std::uniform_int_distribution<>( 0, max_set_length_ + allow_zero_ - 1)(rand_); if (allow_zero_ && n == max_set_length_) return Weight::Zero(); std::vector<Label> labels; labels.reserve(n); for (int i = 0; i < n; ++i) { labels.push_back( std::uniform_int_distribution<>(0, alphabet_size_)(rand_)); } std::sort(labels.begin(), labels.end()); const auto labels_end = std::unique(labels.begin(), labels.end()); labels.resize(labels_end - labels.begin()); return Weight(labels.begin(), labels.end()); }
private: mutable std::mt19937_64 rand_; const bool allow_zero_; const size_t alphabet_size_; const size_t max_set_length_; };
} // namespace fst
#endif // FST_SET_WEIGHT_H_
|