You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

638 lines
19 KiB

// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Weights consisting of sets (of integral Labels) and
// associated semiring operation definitions using intersect
// and union.
#ifndef FST_SET_WEIGHT_H_
#define FST_SET_WEIGHT_H_
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <ios>
#include <istream>
#include <list>
#include <optional>
#include <ostream>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include <fst/log.h>
#include <fst/union-weight.h>
#include <fst/util.h>
#include <fst/weight.h>
#include <string_view>
namespace fst {
// Sentinel values stored as the first label of a SetWeight. Regular labels
// are positive, so these non-positive values cannot collide with them.
inline constexpr int kSetEmpty = 0;         // Label for the empty set.
inline constexpr int kSetUniv = -1;         // Label for the universal set.
inline constexpr int kSetBad = -2;          // Label for a non-set.
inline constexpr char kSetSeparator = '_';  // Label separator in sets.

// Determines whether to use (intersect, union) or (union, intersect)
// as (+, *) for the semiring. SET_INTERSECT_UNION_RESTRICT is a
// restricted version of (intersect, union) that requires summed
// arguments to be equal (or an error is signalled), useful for
// algorithms that require a unique labelled path weight. SET_BOOLEAN
// treats all non-Zero() elements as equivalent (with Zero() ==
// UnivSet()), useful for algorithms that don't really depend on the
// detailed sets.
enum SetType {
  SET_INTERSECT_UNION = 0,
  SET_UNION_INTERSECT = 1,
  SET_INTERSECT_UNION_RESTRICT = 2,
  SET_BOOLEAN = 3
};
template <class>
class SetWeightIterator;

// Set semiring of integral labels.
//
// Representation: `first_` holds the first label and `rest_` the remaining
// labels, which PushBack() requires to be positive and strictly increasing.
// When `first_` is one of the sentinels kSetEmpty, kSetUniv or kSetBad, the
// weight is the empty set, the universal set or a non-set (NoWeight())
// respectively. The SetType parameter S selects which Plus/Times/Divide
// overloads apply.
template <typename L, SetType S = SET_INTERSECT_UNION>
class SetWeight {
 public:
  using Label = L;
  using ReverseWeight = SetWeight<Label, S>;
  using Iterator = SetWeightIterator<SetWeight>;

  friend class SetWeightIterator<SetWeight>;

  // Allow type-converting copy and move constructors private access.
  template <typename L2, SetType S2>
  friend class SetWeight;

  SetWeight() = default;

  // Builds a set from the labels in [begin, end). Input should be positive,
  // sorted and unique; violations are reported by PushBack().
  template <typename InputIterator>
  SetWeight(const InputIterator begin, const InputIterator end) {
    for (auto iter = begin; iter != end; ++iter) PushBack(*iter);
  }

  // Input should be positive. (Non-positive value has
  // special internal meaning w.r.t. integral constants above.)
  explicit SetWeight(Label label) { PushBack(label); }

  // Copies the labels of a set weight with a different SetType.
  template <SetType S2>
  explicit SetWeight(const SetWeight<Label, S2> &w)
      : first_(w.first_), rest_(w.rest_) {}

  // Takes the labels of a set weight with a different SetType; `w` is left
  // empty.
  template <SetType S2>
  explicit SetWeight(SetWeight<Label, S2> &&w)
      : first_(w.first_), rest_(std::move(w.rest_)) {
    w.Clear();
  }

  // Copies the labels of a set weight with a different SetType.
  template <SetType S2>
  SetWeight &operator=(const SetWeight<Label, S2> &w) {
    first_ = w.first_;
    rest_ = w.rest_;
    return *this;
  }

  // Takes the labels of a set weight with a different SetType; `w` is left
  // empty.
  template <SetType S2>
  SetWeight &operator=(SetWeight<Label, S2> &&w) {
    first_ = w.first_;
    rest_ = std::move(w.rest_);
    w.Clear();
    return *this;
  }

  // Additive identity: EmptySet() when Plus is union (SET_UNION_INTERSECT),
  // UnivSet() otherwise.
  static const SetWeight &Zero() {
    return S == SET_UNION_INTERSECT ? EmptySet() : UnivSet();
  }

  // Multiplicative identity: UnivSet() when Times is intersection
  // (SET_UNION_INTERSECT), EmptySet() otherwise.
  static const SetWeight &One() {
    return S == SET_UNION_INTERSECT ? UnivSet() : EmptySet();
  }

  // The non-set weight; its first label is kSetBad.
  static const SetWeight &NoWeight() {
    static const auto *const no_weight = new SetWeight(Label(kSetBad));
    return *no_weight;
  }

  static const std::string &Type() {
    static const std::string *const type =
        new std::string(S == SET_UNION_INTERSECT
                            ? "union_intersect_set"
                            : (S == SET_INTERSECT_UNION
                                   ? "intersect_union_set"
                                   : (S == SET_INTERSECT_UNION_RESTRICT
                                          ? "restricted_set_intersect_union"
                                          : "boolean_set")));
    return *type;
  }

  bool Member() const;

  std::istream &Read(std::istream &strm);

  std::ostream &Write(std::ostream &strm) const;

  size_t Hash() const;

  // Sets are exact, so quantization is the identity.
  SetWeight Quantize(float delta = kDelta) const { return *this; }

  ReverseWeight Reverse() const;

  static constexpr uint64_t Properties() {
    return kIdempotent | kLeftSemiring | kRightSemiring | kCommutative;
  }

  // These operations combined with the SetWeightIterator
  // provide the access and mutation of the set internal elements.

  // The empty set.
  static const SetWeight &EmptySet() {
    static const auto *const empty = new SetWeight(Label(kSetEmpty));
    return *empty;
  }

  // The universal set.
  static const SetWeight &UnivSet() {
    static const auto *const univ = new SetWeight(Label(kSetUniv));
    return *univ;
  }

  // Clear existing SetWeight.
  void Clear() {
    first_ = kSetEmpty;
    rest_.clear();
  }

  // Number of stored labels: 0 for the empty set, otherwise one for the
  // first label (sentinels included) plus the size of rest_.
  size_t Size() const { return first_ == kSetEmpty ? 0 : rest_.size() + 1; }

  // Returns the last stored label, or the first label when rest_ is empty
  // (kSetEmpty for the empty set). Only reads state, hence const.
  Label Back() const {
    if (rest_.empty()) {
      return first_;
    } else {
      return rest_.back();
    }
  }

  // Caller must add in sort order and be unique (or error signalled).
  // Input should also be positive. Non-positive value (for the first
  // push) has special internal meaning w.r.t. integral constants above.
  // On error, kSetBad is inserted before the offending label.
  void PushBack(Label label) {
    if (first_ == kSetEmpty) {
      first_ = label;
    } else {
      if (label <= Back() || label <= 0) {
        FSTERROR() << "SetWeight: labels must be positive, added"
                   << " in sort order and be unique.";
        rest_.push_back(Label(kSetBad));
      }
      rest_.push_back(label);
    }
  }

 private:
  Label first_ = kSetEmpty;  // First label in set (kSetEmpty if empty).
  std::list<Label> rest_;    // Remaining labels in set.
};
// Traverses set in forward direction.
//
// Yields the weight's first label and then each label of its `rest_` list.
// The empty set (first label kSetEmpty) yields nothing. The iterator holds
// references into the weight it was constructed from, so that weight must
// outlive the iterator.
template <class SetWeight_>
class SetWeightIterator {
 public:
  using Weight = SetWeight_;
  using Label = typename Weight::Label;

  explicit SetWeightIterator(const Weight &w)
      : head_(w.first_), tail_(w.rest_), on_head_(true), pos_(tail_.begin()) {}

  bool Done() const {
    return on_head_ ? head_ == kSetEmpty : pos_ == tail_.end();
  }

  const Label &Value() const {
    if (on_head_) return head_;
    return *pos_;
  }

  void Next() {
    if (!on_head_) {
      ++pos_;
      return;
    }
    on_head_ = false;
  }

  void Reset() {
    pos_ = tail_.begin();
    on_head_ = true;
  }

 private:
  using Tail = decltype(Weight::rest_);

  const Label &head_;                  // The weight's first label.
  const Tail &tail_;                   // The weight's remaining labels.
  bool on_head_;                       // Still positioned on head_?
  typename Tail::const_iterator pos_;  // Position in tail_ after head_.
};
// SetWeight member functions that require SetWeightIterator follow.
// Reads a set weight in the binary layout produced by Write(): an int32_t
// label count followed by that many labels, each re-added via PushBack().
// If the stream fails, reading stops: only labels actually read are kept and
// the stream's failure state is left for the caller to inspect.
template <typename Label, SetType S>
inline std::istream &SetWeight<Label, S>::Read(std::istream &strm) {
  Clear();
  // Initialized so that a failed read cannot leave the loop bound
  // indeterminate (ReadType does not write on a failed stream).
  int32_t size = 0;
  ReadType(strm, &size);
  for (int32_t i = 0; strm && i < size; ++i) {
    Label label{};
    ReadType(strm, &label);
    // Never add a label that was not successfully read.
    if (!strm) break;
    PushBack(label);
  }
  return strm;
}
// Writes the weight as an int32_t label count followed by every label
// (including a sentinel first label such as kSetUniv).
template <typename Label, SetType S>
inline std::ostream &SetWeight<Label, S>::Write(std::ostream &strm) const {
  WriteType(strm, static_cast<int32_t>(Size()));
  Iterator iter(*this);
  while (!iter.Done()) {
    WriteType(strm, iter.Value());
    iter.Next();
  }
  return strm;
}
// Returns true iff the weight is a valid set, i.e. the kSetBad marker appears
// nowhere in it. NoWeight() stores kSetBad as the first label, while
// PushBack() appends kSetBad to rest_ when a label is non-positive or out of
// order, so both the first label and rest_ must be checked.
template <typename Label, SetType S>
inline bool SetWeight<Label, S>::Member() const {
  const Label bad = Label(kSetBad);
  if (first_ == bad) return false;
  return std::find(rest_.begin(), rest_.end(), bad) == rest_.end();
}
// The reverse weight has the same type and the same labels: the weight is
// returned unchanged.
template <typename Label, SetType S>
inline typename SetWeight<Label, S>::ReverseWeight
SetWeight<Label, S>::Reverse() const {
  return ReverseWeight(*this);
}
// Hashes the labels in order. Boolean weights hash only by whether they equal
// Zero(), matching the coarse Boolean operator==.
template <typename Label, SetType S>
inline size_t SetWeight<Label, S>::Hash() const {
  if (S == SET_BOOLEAN) return *this == Zero() ? 0 : 1;
  size_t hash = 0;
  for (Iterator it(*this); !it.Done(); it.Next()) {
    hash = hash ^ (hash << 1) ^ it.Value();
  }
  return hash;
}
// Default ==: two sets are equal when they hold the same labels in the same
// order (sentinel first labels included).
template <typename Label, SetType S>
inline bool operator==(const SetWeight<Label, S> &w1,
                       const SetWeight<Label, S> &w2) {
  if (w1.Size() != w2.Size()) return false;
  typename SetWeight<Label, S>::Iterator lhs(w1);
  typename SetWeight<Label, S>::Iterator rhs(w2);
  while (!lhs.Done()) {
    if (lhs.Value() != rhs.Value()) return false;
    lhs.Next();
    rhs.Next();
  }
  return true;
}
// Boolean ==: non-sets are never equal; otherwise two weights are equal when
// both or neither are the universal set (every other set counts as empty).
template <typename Label>
inline bool operator==(const SetWeight<Label, SET_BOOLEAN> &w1,
                       const SetWeight<Label, SET_BOOLEAN> &w2) {
  if (!(w1.Member() && w2.Member())) return false;
  using Iterator = typename SetWeight<Label, SET_BOOLEAN>::Iterator;
  const auto is_univ = [](const SetWeight<Label, SET_BOOLEAN> &w) -> bool {
    const Iterator iter(w);
    return !iter.Done() && iter.Value() == Label(kSetUniv);
  };
  return is_univ(w1) == is_univ(w2);
}
// Negation of the operator== overload selected for these operands.
template <typename Label, SetType S>
inline bool operator!=(const SetWeight<Label, S> &w1,
                       const SetWeight<Label, S> &w2) {
  const bool equal = (w1 == w2);
  return !equal;
}
template <typename Label, SetType S>
inline bool ApproxEqual(const SetWeight<Label, S> &w1,
const SetWeight<Label, S> &w2, float delta = kDelta) {
return w1 == w2;
}
// Prints "EmptySet", "UnivSet", "BadSet" (when the first label is kSetBad), or
// the labels joined by kSetSeparator.
template <typename Label, SetType S>
inline std::ostream &operator<<(std::ostream &strm,
                                const SetWeight<Label, S> &weight) {
  typename SetWeight<Label, S>::Iterator iter(weight);
  if (iter.Done()) return strm << "EmptySet";
  const Label head = iter.Value();
  if (head == Label(kSetUniv)) return strm << "UnivSet";
  if (head == Label(kSetBad)) return strm << "BadSet";
  strm << head;
  for (iter.Next(); !iter.Done(); iter.Next()) {
    strm << kSetSeparator << iter.Value();
  }
  return strm;
}
// Parses one whitespace-delimited token: "EmptySet", "UnivSet", or integer
// labels separated by kSetSeparator. An unparsable label sets the stream to
// badbit and stops parsing, leaving the labels read so far in `weight`.
template <typename Label, SetType S>
inline std::istream &operator>>(std::istream &strm,
                                SetWeight<Label, S> &weight) {
  using Weight = SetWeight<Label, S>;
  std::string token;
  strm >> token;
  if (token == "EmptySet") {
    weight = Weight(Label(kSetEmpty));
    return strm;
  }
  if (token == "UnivSet") {
    weight = Weight(Label(kSetUniv));
    return strm;
  }
  weight.Clear();
  for (std::string_view piece : StrSplit(token, kSetSeparator)) {
    const auto parsed = ParseInt64(piece);
    if (!parsed.has_value()) {
      strm.clear(std::ios::badbit);
      return strm;
    }
    weight.PushBack(*parsed);
  }
  return strm;
}
// Returns the union of two sets. Non-sets yield NoWeight(); empty and
// universal operands are short-circuited; otherwise the two sorted label
// sequences are merged.
template <typename Label, SetType S>
inline SetWeight<Label, S> Union(const SetWeight<Label, S> &w1,
                                 const SetWeight<Label, S> &w2) {
  using Weight = SetWeight<Label, S>;
  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  if (w1 == Weight::EmptySet()) return w2;
  if (w2 == Weight::EmptySet()) return w1;
  if (w1 == Weight::UnivSet()) return w1;
  if (w2 == Weight::UnivSet()) return w2;
  typename Weight::Iterator lhs(w1);
  typename Weight::Iterator rhs(w2);
  Weight result;
  while (!lhs.Done() && !rhs.Done()) {
    const auto a = lhs.Value();
    const auto b = rhs.Value();
    // Emit the smaller head once, then advance every side that held it.
    result.PushBack(a < b ? a : b);
    if (a <= b) lhs.Next();
    if (b <= a) rhs.Next();
  }
  // At most one side still has labels; append them as they are.
  for (; !lhs.Done(); lhs.Next()) result.PushBack(lhs.Value());
  for (; !rhs.Done(); rhs.Next()) result.PushBack(rhs.Value());
  return result;
}
// Returns the intersection of two sets. Non-sets yield NoWeight(); empty and
// universal operands are short-circuited; otherwise labels present in both
// sorted sequences are kept.
template <typename Label, SetType S>
inline SetWeight<Label, S> Intersect(const SetWeight<Label, S> &w1,
                                     const SetWeight<Label, S> &w2) {
  using Weight = SetWeight<Label, S>;
  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  if (w1 == Weight::EmptySet()) return w1;
  if (w2 == Weight::EmptySet()) return w2;
  if (w1 == Weight::UnivSet()) return w2;
  if (w2 == Weight::UnivSet()) return w1;
  typename Weight::Iterator lhs(w1);
  typename Weight::Iterator rhs(w2);
  Weight result;
  while (!lhs.Done() && !rhs.Done()) {
    const auto a = lhs.Value();
    const auto b = rhs.Value();
    // Keep common labels; advance whichever side holds the smaller head.
    if (a == b) result.PushBack(a);
    if (a <= b) lhs.Next();
    if (b <= a) rhs.Next();
  }
  return result;
}
// Returns w1 minus w2. Non-sets yield NoWeight(); an empty operand returns w1,
// a universal w2 returns EmptySet(); otherwise labels of w1 absent from w2 are
// kept.
template <typename Label, SetType S>
inline SetWeight<Label, S> Difference(const SetWeight<Label, S> &w1,
                                      const SetWeight<Label, S> &w2) {
  using Weight = SetWeight<Label, S>;
  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  if (w1 == Weight::EmptySet() || w2 == Weight::EmptySet()) return w1;
  if (w2 == Weight::UnivSet()) return Weight::EmptySet();
  typename Weight::Iterator lhs(w1);
  typename Weight::Iterator rhs(w2);
  Weight result;
  while (!lhs.Done() && !rhs.Done()) {
    const auto a = lhs.Value();
    const auto b = rhs.Value();
    // A label of w1 smaller than w2's head cannot occur in w2.
    if (a < b) result.PushBack(a);
    if (a <= b) lhs.Next();
    if (b <= a) rhs.Next();
  }
  // Whatever remains of w1 is larger than every label of w2.
  for (; !lhs.Done(); lhs.Next()) result.PushBack(lhs.Value());
  return result;
}
// Default: Plus = Intersect.
template <typename Label, SetType S>
inline SetWeight<Label, S> Plus(const SetWeight<Label, S> &w1,
                                const SetWeight<Label, S> &w2) {
  SetWeight<Label, S> sum = Intersect(w1, w2);
  return sum;
}
// Plus = Union (SET_UNION_INTERSECT).
template <typename Label>
inline SetWeight<Label, SET_UNION_INTERSECT> Plus(
    const SetWeight<Label, SET_UNION_INTERSECT> &w1,
    const SetWeight<Label, SET_UNION_INTERSECT> &w2) {
  SetWeight<Label, SET_UNION_INTERSECT> sum = Union(w1, w2);
  return sum;
}
// Plus = Set equality is required (for non-Zero() input). The
// restriction is useful (e.g., in determinization) to ensure the input
// has a unique labelled path weight. Unequal non-Zero() arguments are
// reported via FSTERROR and yield NoWeight().
template <typename Label>
inline SetWeight<Label, SET_INTERSECT_UNION_RESTRICT> Plus(
    const SetWeight<Label, SET_INTERSECT_UNION_RESTRICT> &w1,
    const SetWeight<Label, SET_INTERSECT_UNION_RESTRICT> &w2) {
  using Weight = SetWeight<Label, SET_INTERSECT_UNION_RESTRICT>;
  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  if (w1 == Weight::Zero()) return w2;
  if (w2 == Weight::Zero()) return w1;
  if (w1 == w2) return w1;
  FSTERROR() << "SetWeight::Plus: Unequal arguments "
             << "(non-unique labelled path weights?)"
             << " w1 = " << w1 << " w2 = " << w2;
  return Weight::NoWeight();
}
// Plus = Or: an operand equal to One() is returned as is; otherwise Zero().
template <typename Label>
inline SetWeight<Label, SET_BOOLEAN> Plus(
    const SetWeight<Label, SET_BOOLEAN> &w1,
    const SetWeight<Label, SET_BOOLEAN> &w2) {
  using Weight = SetWeight<Label, SET_BOOLEAN>;
  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  if (w1 == Weight::One()) return w1;
  return w2 == Weight::One() ? w2 : Weight::Zero();
}
// Default: Times = Union.
template <typename Label, SetType S>
inline SetWeight<Label, S> Times(const SetWeight<Label, S> &w1,
                                 const SetWeight<Label, S> &w2) {
  SetWeight<Label, S> product = Union(w1, w2);
  return product;
}
// Times = Intersect (SET_UNION_INTERSECT).
template <typename Label>
inline SetWeight<Label, SET_UNION_INTERSECT> Times(
    const SetWeight<Label, SET_UNION_INTERSECT> &w1,
    const SetWeight<Label, SET_UNION_INTERSECT> &w2) {
  SetWeight<Label, SET_UNION_INTERSECT> product = Intersect(w1, w2);
  return product;
}
// Times = And: when w1 equals One() the result is w2, otherwise w1.
template <typename Label>
inline SetWeight<Label, SET_BOOLEAN> Times(
    const SetWeight<Label, SET_BOOLEAN> &w1,
    const SetWeight<Label, SET_BOOLEAN> &w2) {
  using Weight = SetWeight<Label, SET_BOOLEAN>;
  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  return w1 == Weight::One() ? w2 : w1;
}
// Divide = Difference; `divide_type` is not used.
template <typename Label, SetType S>
inline SetWeight<Label, S> Divide(const SetWeight<Label, S> &w1,
                                  const SetWeight<Label, S> &w2,
                                  DivideType divide_type = DIVIDE_ANY) {
  SetWeight<Label, S> quotient = Difference(w1, w2);
  return quotient;
}
// Divide = dividend (or the universal set if the
// dividend == divisor). `divide_type` is not used.
template <typename Label>
inline SetWeight<Label, SET_UNION_INTERSECT> Divide(
    const SetWeight<Label, SET_UNION_INTERSECT> &w1,
    const SetWeight<Label, SET_UNION_INTERSECT> &w2,
    DivideType divide_type = DIVIDE_ANY) {
  using Weight = SetWeight<Label, SET_UNION_INTERSECT>;
  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  return w1 == w2 ? Weight::UnivSet() : w1;
}
// Divide = Or Not: w1 is returned when it equals One(); otherwise the result
// is One() when w2 equals Zero() and Zero() otherwise.
template <typename Label>
inline SetWeight<Label, SET_BOOLEAN> Divide(
    const SetWeight<Label, SET_BOOLEAN> &w1,
    const SetWeight<Label, SET_BOOLEAN> &w2,
    DivideType divide_type = DIVIDE_ANY) {
  using Weight = SetWeight<Label, SET_BOOLEAN>;
  if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  if (w1 == Weight::One()) return w1;
  return w2 == Weight::Zero() ? Weight::One() : Weight::Zero();
}
// Converts between different set types by re-adding each label of the source
// weight (including any sentinel first label) with PushBack().
template <typename Label, SetType S1, SetType S2>
struct WeightConvert<SetWeight<Label, S1>, SetWeight<Label, S2>> {
  SetWeight<Label, S2> operator()(const SetWeight<Label, S1> &w1) const {
    SetWeight<Label, S2> w2;
    SetWeightIterator<SetWeight<Label, S1>> iter(w1);
    while (!iter.Done()) {
      w2.PushBack(iter.Value());
      iter.Next();
    }
    return w2;
  }
};
// This function object generates SetWeights that are random integer sets
// over {1, ..., alphabet_size} with fewer than max_set_length labels, plus
// Zero() when allow_zero is true. This is intended primarily for testing.
template <class Label, SetType S>
class WeightGenerate<SetWeight<Label, S>> {
 public:
  using Weight = SetWeight<Label, S>;

  // `seed` initializes the pseudo-random engine, so generators constructed
  // with the same arguments produce the same sequence of weights.
  explicit WeightGenerate(uint64_t seed = std::random_device()(),
                          bool allow_zero = true,
                          size_t alphabet_size = kNumRandomWeights,
                          size_t max_set_length = kNumRandomWeights)
      : rand_(seed),
        allow_zero_(allow_zero),
        alphabet_size_(alphabet_size),
        max_set_length_(max_set_length) {}

  Weight operator()() const {
    // n in [0, max_set_length_ - 1] is the number of label draws; when zero
    // is allowed, the extra value n == max_set_length_ selects Zero().
    const int n = std::uniform_int_distribution<>(
        0, max_set_length_ + allow_zero_ - 1)(rand_);
    if (allow_zero_ && static_cast<size_t>(n) == max_set_length_) {
      return Weight::Zero();
    }
    std::vector<Label> labels;
    labels.reserve(n);
    for (int i = 0; i < n; ++i) {
      // A drawn 0 equals kSetEmpty: after sorting it can only come first,
      // and PushBack() lets the next label replace it, so it never appears
      // in the generated set.
      labels.push_back(
          std::uniform_int_distribution<>(0, alphabet_size_)(rand_));
    }
    // SetWeight requires sorted, unique labels.
    std::sort(labels.begin(), labels.end());
    labels.erase(std::unique(labels.begin(), labels.end()), labels.end());
    return Weight(labels.begin(), labels.end());
  }

 private:
  mutable std::mt19937_64 rand_;
  const bool allow_zero_;
  const size_t alphabet_size_;
  const size_t max_set_length_;
};
} // namespace fst
#endif // FST_SET_WEIGHT_H_