|
|
// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// String weight set and associated semiring operation definitions.
#ifndef FST_STRING_WEIGHT_H_
#define FST_STRING_WEIGHT_H_
#include <cstddef>
#include <cstdint>
#include <ios>
#include <istream>
#include <list>
#include <optional>
#include <ostream>
#include <random>
#include <string>
#include <vector>
#include <fst/log.h>
#include <fst/product-weight.h>
#include <fst/union-weight.h>
#include <fst/util.h>
#include <fst/weight.h>
#include <string_view>
namespace fst {
// Reserved label values and the textual separator used by string weights.

// Label of the infinite string.
inline constexpr int kStringInfinity{-1};

// Label of a non-string.
inline constexpr int kStringBad{-2};

// Character written between consecutive labels in a string's text form.
inline constexpr char kStringSeparator{'_'};
// Selects the left or the right string semiring. STRING_RESTRICT is a
// 'restricted' variant that signals an error if proper prefixes/suffixes
// would otherwise be returned by Plus; it is useful with algorithms that
// require functional transducer input with the string semirings.
enum StringType {
  STRING_LEFT = 0,
  STRING_RIGHT = 1,
  STRING_RESTRICT = 2
};
// Maps a string type to the type used by the reversed weight: left and right
// are swapped, every other value maps to STRING_RESTRICT.
constexpr StringType ReverseStringType(StringType s) {
  switch (s) {
    case STRING_LEFT:
      return STRING_RIGHT;
    case STRING_RIGHT:
      return STRING_LEFT;
    default:
      return STRING_RESTRICT;
  }
}
// Forward declarations of the forward and reverse label iterators; the
// StringWeight class names both as friends and as its Iterator and
// ReverseIterator types.
template <class> class StringWeightIterator; template <class> class StringWeightReverseIterator;
// String semiring: (longest_common_prefix/suffix, ., Infinity, Epsilon)
//
// A weight is a sequence of labels. The first label is held in first_ (with 0
// meaning the string is empty) and all subsequent labels are held in rest_.
template <typename L, StringType S = STRING_LEFT>
class StringWeight {
 public:
  using Label = L;
  using ReverseWeight = StringWeight<Label, ReverseStringType(S)>;
  using Iterator = StringWeightIterator<StringWeight>;
  using ReverseIterator = StringWeightReverseIterator<StringWeight>;

  friend class StringWeightIterator<StringWeight>;
  friend class StringWeightReverseIterator<StringWeight>;

  // Constructs the empty string.
  StringWeight() = default;

  // Constructs a string by appending each element of [begin, end) in order.
  template <typename InputIt>
  StringWeight(const InputIt begin, const InputIt end) {
    auto cursor = begin;
    while (cursor != end) {
      PushBack(*cursor);
      ++cursor;
    }
  }

  // Constructs a string holding the single given label.
  explicit StringWeight(Label label) { PushBack(label); }

  // The string whose only label is kStringInfinity.
  static const StringWeight &Zero() {
    static const StringWeight *const zero =
        new StringWeight(Label(kStringInfinity));
    return *zero;
  }

  // The empty string.
  static const StringWeight &One() {
    static const StringWeight *const one = new StringWeight();
    return *one;
  }

  // The string whose only label is kStringBad.
  static const StringWeight &NoWeight() {
    static const StringWeight *const no_weight =
        new StringWeight(Label(kStringBad));
    return *no_weight;
  }

  // Name of this semiring, chosen by the string type S.
  static const std::string &Type() {
    static const std::string *const type = [] {
      const char *name = "restricted_string";
      if (S == STRING_LEFT) {
        name = "left_string";
      } else if (S == STRING_RIGHT) {
        name = "right_string";
      }
      return new std::string(name);
    }();
    return *type;
  }

  bool Member() const;

  std::istream &Read(std::istream &strm);

  std::ostream &Write(std::ostream &strm) const;

  size_t Hash() const;

  // Returns the weight unchanged; delta is not used.
  StringWeight Quantize(float delta = kDelta) const { return *this; }

  ReverseWeight Reverse() const;

  // Idempotent, plus the left and/or right semiring property matching S.
  static constexpr uint64_t Properties() {
    uint64_t props = kIdempotent;
    if (S == STRING_LEFT) {
      props |= kLeftSemiring;
    } else if (S == STRING_RIGHT) {
      props |= kRightSemiring;
    } else {
      // S == STRING_RESTRICT
      props |= kLeftSemiring | kRightSemiring;
    }
    return props;
  }

  // The operations below, combined with StringWeightIterator and
  // StringWeightReverseIterator, provide access to and mutation of the
  // string's internal elements.

  // Clears the existing StringWeight, leaving the empty string.
  void Clear() {
    rest_.clear();
    first_ = 0;
  }

  // Number of labels in the string.
  size_t Size() const {
    if (!first_) return 0;
    return rest_.size() + 1;
  }

  // Prepends a label; the previous first label (if any) moves into rest_.
  void PushFront(Label label) {
    if (first_) {
      rest_.push_front(first_);
    }
    first_ = label;
  }

  // Appends a label; it becomes the first label if the string is empty.
  void PushBack(Label label) {
    if (first_) {
      rest_.push_back(label);
    } else {
      first_ = label;
    }
  }

 private:
  Label first_ = 0;        // First label in string (0 if empty).
  std::list<Label> rest_;  // Remaining labels in string.
};
// Traverses the labels of a string weight in forward direction: first the
// weight's first_ label, then each element of its rest_ container.
template <class StringWeight_>
class StringWeightIterator {
 public:
  using Weight = StringWeight_;
  using Label = typename Weight::Label;

  // Binds to (does not copy) the labels of w.
  explicit StringWeightIterator(const Weight &w)
      : first_(w.first_), rest_(w.rest_), init_(true), iter_(rest_.begin()) {}

  // True once every label has been visited (immediately, if first_ is 0).
  bool Done() const { return init_ ? first_ == 0 : iter_ == rest_.end(); }

  // Label at the current position.
  const Label &Value() const {
    if (init_) return first_;
    return *iter_;
  }

  // Advances one label; the first step only leaves the initial state.
  void Next() {
    if (!init_) {
      ++iter_;
      return;
    }
    init_ = false;
  }

  // Returns to the first label.
  void Reset() {
    iter_ = rest_.begin();
    init_ = true;
  }

 private:
  using Rest = decltype(Weight::rest_);

  const Label &first_;
  const Rest &rest_;
  bool init_;  // In the initialized state?
  typename Rest::const_iterator iter_;
};
// Traverses the labels of a string weight in backward direction: first the
// elements of the weight's rest_ container from last to first, then its
// first_ label.
template <class StringWeight_>
class StringWeightReverseIterator {
 public:
  using Weight = StringWeight_;
  using Label = typename Weight::Label;

  // Binds to (does not copy) the labels of w. The iterator starts out done
  // when first_ equals Label(), i.e. when the string is empty.
  explicit StringWeightReverseIterator(const Weight &w)
      : first_(w.first_),
        rest_(w.rest_),
        fin_(first_ == Label()),
        iter_(rest_.rbegin()) {}

  // True once every label has been visited.
  bool Done() const { return fin_; }

  // Label at the current position; first_ once rest_ has been exhausted.
  const Label &Value() const { return iter_ == rest_.rend() ? first_ : *iter_; }

  // Advances one label; stepping past first_ enters the final state.
  void Next() {
    if (iter_ == rest_.rend()) {
      fin_ = true;
    } else {
      ++iter_;
    }
  }

  // Returns to the last label. The final-state flag is recomputed exactly as
  // in the constructor so that resetting an iterator over an empty string
  // leaves it done rather than exposing the 0 placeholder as a label.
  void Reset() {
    fin_ = first_ == Label();
    iter_ = rest_.rbegin();
  }

 private:
  const Label &first_;
  const decltype(Weight::rest_) &rest_;
  bool fin_;  // In the final state?
  typename decltype(Weight::rest_)::const_reverse_iterator iter_;
};
// StringWeight member functions follow that require
// StringWeightIterator or StringWeightReverseIterator.
// Replaces this weight with one read from strm: a 32-bit label count followed
// by that many labels, each read with ReadType. Returns strm.
//
// The count and each label are value-initialized and the stream state is
// checked after every read, so a failed or truncated stream ends the loop
// instead of iterating over an indeterminate count or appending labels that
// were never read. Callers detect the failure through the returned stream.
template <typename Label, StringType S>
inline std::istream &StringWeight<Label, S>::Read(std::istream &strm) {
  Clear();
  int32_t size = 0;
  ReadType(strm, &size);
  for (int32_t i = 0; i < size && strm; ++i) {
    Label label = Label();
    ReadType(strm, &label);
    if (!strm) break;  // Do not append a label that was not actually read.
    PushBack(label);
  }
  return strm;
}
// Writes this weight to strm as a 32-bit label count followed by each label
// in forward order, all via WriteType. Returns strm.
template <typename Label, StringType S>
inline std::ostream &StringWeight<Label, S>::Write(std::ostream &strm) const {
  const int32_t num_labels = Size();
  WriteType(strm, num_labels);
  Iterator iter(*this);
  while (!iter.Done()) {
    WriteType(strm, iter.Value());
    iter.Next();
  }
  return strm;
}
// Returns true unless the first label is kStringBad. Only the first label is
// inspected.
template <typename Label, StringType S>
inline bool StringWeight<Label, S>::Member() const {
  // A freshly constructed forward iterator yields first_, so test it directly.
  return first_ != Label(kStringBad);
}
// Returns the weight, in the reverse semiring, built by pushing this string's
// labels (visited in forward order) onto the front of an empty string.
template <typename Label, StringType S>
inline typename StringWeight<Label, S>::ReverseWeight
StringWeight<Label, S>::Reverse() const {
  ReverseWeight reversed;
  Iterator iter(*this);
  while (!iter.Done()) {
    reversed.PushFront(iter.Value());
    iter.Next();
  }
  return reversed;
}
// Hashes the labels in forward order: each step XORs the running value with
// its own left shift by one and with the next label. The empty string
// hashes to 0.
template <typename Label, StringType S>
inline size_t StringWeight<Label, S>::Hash() const {
  size_t hash = 0;
  Iterator iter(*this);
  while (!iter.Done()) {
    hash ^= (hash << 1) ^ iter.Value();
    iter.Next();
  }
  return hash;
}
// Two string weights are equal when they have the same size and the same
// labels at every position.
template <typename Label, StringType S>
inline bool operator==(const StringWeight<Label, S> &w1,
                       const StringWeight<Label, S> &w2) {
  using Iterator = typename StringWeight<Label, S>::Iterator;
  if (w1.Size() != w2.Size()) return false;
  Iterator lhs(w1);
  Iterator rhs(w2);
  // Sizes match, so lhs running out means rhs has run out as well.
  while (!lhs.Done()) {
    if (lhs.Value() != rhs.Value()) return false;
    lhs.Next();
    rhs.Next();
  }
  return true;
}
// Negation of operator== for string weights.
template <typename Label, StringType S>
inline bool operator!=(const StringWeight<Label, S> &w1,
                       const StringWeight<Label, S> &w2) {
  const bool equal = (w1 == w2);
  return !equal;
}
// Approximate equality for string weights is exact equality; delta is not
// used.
template <typename Label, StringType S>
inline bool ApproxEqual(const StringWeight<Label, S> &w1,
                        const StringWeight<Label, S> &w2,
                        float delta = kDelta) {
  const bool equal = (w1 == w2);
  return equal;
}
// Prints a string weight: "Epsilon" for the empty string, "Infinity" or
// "BadString" when the first label is kStringInfinity or kStringBad, and
// otherwise the labels joined by kStringSeparator.
template <typename Label, StringType S>
inline std::ostream &operator<<(std::ostream &strm,
                                const StringWeight<Label, S> &weight) {
  typename StringWeight<Label, S>::Iterator iter(weight);
  if (iter.Done()) return strm << "Epsilon";
  if (iter.Value() == Label(kStringInfinity)) return strm << "Infinity";
  if (iter.Value() == Label(kStringBad)) return strm << "BadString";
  // At least one label is present: emit it, then separator + label pairs.
  strm << iter.Value();
  for (iter.Next(); !iter.Done(); iter.Next()) {
    strm << kStringSeparator;
    strm << iter.Value();
  }
  return strm;
}
template <typename Label, StringType S> inline std::istream &operator>>(std::istream &strm, StringWeight<Label, S> &weight) { std::string str; strm >> str; using Weight = StringWeight<Label, S>; if (str == "Infinity") { weight = Weight::Zero(); } else if (str == "Epsilon") { weight = Weight::One(); } else { weight.Clear(); for (std::string_view sv : StrSplit(str, kStringSeparator)) { auto maybe_label = ParseInt64(sv); if (!maybe_label.has_value()) { strm.clear(std::ios::badbit); break; } weight.PushBack(*maybe_label); } } return strm; }
// Default Plus, used for the restricted string semiring. String equality is
// required (for non-Zero() input): unequal arguments log an error and yield
// NoWeight(). The restriction is used (e.g., in determinization) to ensure
// the input is functional.
template <typename Label, StringType S>
inline StringWeight<Label, S> Plus(const StringWeight<Label, S> &w1,
                                   const StringWeight<Label, S> &w2) {
  using Weight = StringWeight<Label, S>;
  if (!(w1.Member() && w2.Member())) return Weight::NoWeight();
  // Zero() is the identity of Plus.
  if (w1 == Weight::Zero()) return w2;
  if (w2 == Weight::Zero()) return w1;
  if (w1 == w2) return w1;
  FSTERROR() << "StringWeight::Plus: Unequal arguments "
             << "(non-functional FST?)"
             << " w1 = " << w1 << " w2 = " << w2;
  return Weight::NoWeight();
}
// Plus for the left string semiring: the longest common prefix of the two
// arguments. Zero() is the identity; a non-member argument yields NoWeight().
template <typename Label>
inline StringWeight<Label, STRING_LEFT> Plus(
    const StringWeight<Label, STRING_LEFT> &w1,
    const StringWeight<Label, STRING_LEFT> &w2) {
  using Weight = StringWeight<Label, STRING_LEFT>;
  if (!(w1.Member() && w2.Member())) return Weight::NoWeight();
  if (w1 == Weight::Zero()) return w2;
  if (w2 == Weight::Zero()) return w1;
  Weight prefix;
  typename Weight::Iterator lhs(w1);
  typename Weight::Iterator rhs(w2);
  // Walk both strings from the front while their labels agree.
  while (!lhs.Done() && !rhs.Done() && lhs.Value() == rhs.Value()) {
    prefix.PushBack(lhs.Value());
    lhs.Next();
    rhs.Next();
  }
  return prefix;
}
// Plus for the right string semiring: the longest common suffix of the two
// arguments. Zero() is the identity; a non-member argument yields NoWeight().
template <typename Label>
inline StringWeight<Label, STRING_RIGHT> Plus(
    const StringWeight<Label, STRING_RIGHT> &w1,
    const StringWeight<Label, STRING_RIGHT> &w2) {
  using Weight = StringWeight<Label, STRING_RIGHT>;
  if (!(w1.Member() && w2.Member())) return Weight::NoWeight();
  if (w1 == Weight::Zero()) return w2;
  if (w2 == Weight::Zero()) return w1;
  Weight suffix;
  typename Weight::ReverseIterator lhs(w1);
  typename Weight::ReverseIterator rhs(w2);
  // Walk both strings from the back while their labels agree.
  while (!lhs.Done() && !rhs.Done() && lhs.Value() == rhs.Value()) {
    suffix.PushFront(lhs.Value());
    lhs.Next();
    rhs.Next();
  }
  return suffix;
}
// Times for string weights: concatenation of w1 and w2. A non-member argument
// yields NoWeight(); otherwise a Zero() argument yields Zero().
template <typename Label, StringType S>
inline StringWeight<Label, S> Times(const StringWeight<Label, S> &w1,
                                    const StringWeight<Label, S> &w2) {
  using Weight = StringWeight<Label, S>;
  if (!(w1.Member() && w2.Member())) return Weight::NoWeight();
  const Weight &zero = Weight::Zero();
  if (w1 == zero || w2 == zero) return zero;
  Weight product(w1);
  typename Weight::Iterator iter(w2);
  while (!iter.Done()) {
    product.PushBack(iter.Value());
    iter.Next();
  }
  return product;
}
// Left division in a left string semiring: drops the first w2.Size() labels
// of w1 and returns what remains. The dropped labels are not compared with
// w2. Dividing by Zero() yields a bad string; dividing Zero() yields Zero().
template <typename Label, StringType S>
inline StringWeight<Label, S> DivideLeft(const StringWeight<Label, S> &w1,
                                         const StringWeight<Label, S> &w2) {
  using Weight = StringWeight<Label, S>;
  if (!(w1.Member() && w2.Member())) return Weight::NoWeight();
  if (w2 == Weight::Zero()) return Weight(Label(kStringBad));
  if (w1 == Weight::Zero()) return Weight::Zero();
  typename Weight::Iterator iter(w1);
  // Skip the leading labels covered by the divisor.
  const size_t divisor_size = w2.Size();
  for (size_t skipped = 0; skipped < divisor_size && !iter.Done(); ++skipped) {
    iter.Next();
  }
  Weight result;
  while (!iter.Done()) {
    result.PushBack(iter.Value());
    iter.Next();
  }
  return result;
}
// Right division in a right string semiring: drops the last w2.Size() labels
// of w1 and returns what remains. The dropped labels are not compared with
// w2. Dividing by Zero() yields a bad string; dividing Zero() yields Zero().
template <typename Label, StringType S>
inline StringWeight<Label, S> DivideRight(const StringWeight<Label, S> &w1,
                                          const StringWeight<Label, S> &w2) {
  using Weight = StringWeight<Label, S>;
  if (!(w1.Member() && w2.Member())) return Weight::NoWeight();
  if (w2 == Weight::Zero()) return Weight(Label(kStringBad));
  if (w1 == Weight::Zero()) return Weight::Zero();
  typename Weight::ReverseIterator iter(w1);
  // Skip the trailing labels covered by the divisor.
  const size_t divisor_size = w2.Size();
  for (size_t skipped = 0; skipped < divisor_size && !iter.Done(); ++skipped) {
    iter.Next();
  }
  Weight result;
  while (!iter.Done()) {
    result.PushFront(iter.Value());
    iter.Next();
  }
  return result;
}
// Default Divide, used for the restricted string semiring: dispatches to
// DivideLeft or DivideRight as requested. Any other divide type logs an error
// and yields NoWeight().
template <typename Label, StringType S>
inline StringWeight<Label, S> Divide(const StringWeight<Label, S> &w1,
                                     const StringWeight<Label, S> &w2,
                                     DivideType divide_type) {
  using Weight = StringWeight<Label, S>;
  switch (divide_type) {
    case DIVIDE_LEFT:
      return DivideLeft(w1, w2);
    case DIVIDE_RIGHT:
      return DivideRight(w1, w2);
    default:
      break;
  }
  FSTERROR() << "StringWeight::Divide: "
             << "Only explicit left or right division is defined "
             << "for the " << Weight::Type() << " semiring";
  return Weight::NoWeight();
}
// Divide for the left string semiring: only DIVIDE_LEFT is accepted; any
// other divide type logs an error and yields NoWeight().
template <typename Label>
inline StringWeight<Label, STRING_LEFT> Divide(
    const StringWeight<Label, STRING_LEFT> &w1,
    const StringWeight<Label, STRING_LEFT> &w2, DivideType divide_type) {
  if (divide_type == DIVIDE_LEFT) return DivideLeft(w1, w2);
  FSTERROR() << "StringWeight::Divide: Only left division is defined "
             << "for the left string semiring";
  return StringWeight<Label, STRING_LEFT>::NoWeight();
}
// Divide for the right string semiring: only DIVIDE_RIGHT is accepted; any
// other divide type logs an error and yields NoWeight().
template <typename Label>
inline StringWeight<Label, STRING_RIGHT> Divide(
    const StringWeight<Label, STRING_RIGHT> &w1,
    const StringWeight<Label, STRING_RIGHT> &w2, DivideType divide_type) {
  if (divide_type == DIVIDE_RIGHT) return DivideRight(w1, w2);
  FSTERROR() << "StringWeight::Divide: Only right division is defined "
             << "for the right string semiring";
  return StringWeight<Label, STRING_RIGHT>::NoWeight();
}
// This function object generates StringWeights that are random integer strings
// from {1, ..., alphabet_size}^{0, ..., max_string_length} U { Zero }. This is
// intended primarily for testing.
template <class Label, StringType S>
class WeightGenerate<StringWeight<Label, S>> {
 public:
  using Weight = StringWeight<Label, S>;

  // seed: seed for the internal random engine.
  // allow_zero: whether Zero() may be generated.
  // alphabet_size: largest label value that may be generated.
  // max_string_length: largest string length that may be generated.
  explicit WeightGenerate(uint64_t seed = std::random_device()(),
                          bool allow_zero = true,
                          size_t alphabet_size = kNumRandomWeights,
                          size_t max_string_length = kNumRandomWeights)
      : rand_(seed),
        allow_zero_(allow_zero),
        alphabet_size_(alphabet_size),
        max_string_length_(max_string_length) {}

  // Draws a length n uniformly from [0, max_string_length], widening the
  // range by one extra value when Zero() is allowed. Only that extra value
  // (n > max_string_length) maps to Zero(); every other draw yields a string
  // of exactly n labels, so the longest string has max_string_length labels.
  Weight operator()() const {
    const int n = std::uniform_int_distribution<>(
        0, max_string_length_ + allow_zero_)(rand_);
    if (allow_zero_ && static_cast<size_t>(n) > max_string_length_) {
      return Weight::Zero();
    }
    std::vector<Label> labels;
    labels.reserve(n);
    for (int i = 0; i < n; ++i) {
      labels.push_back(
          std::uniform_int_distribution<>(1, alphabet_size_)(rand_));
    }
    return Weight(labels.begin(), labels.end());
  }

 private:
  mutable std::mt19937_64 rand_;  // Mutable so operator() can stay const.
  const bool allow_zero_;
  const size_t alphabet_size_;
  const size_t max_string_length_;
};
// Selects the left, right, or (general) gallic semiring. GALLIC_RESTRICT is a
// restricted version that signals an error if proper string prefixes or
// suffixes would otherwise be returned by string Plus; this is useful with
// algorithms that require functional transducer input. GALLIC_MIN is a min
// version that changes the Plus to keep only the lowest W weight string.
enum GallicType {
  GALLIC_LEFT = 0,
  GALLIC_RIGHT = 1,
  GALLIC_RESTRICT = 2,
  GALLIC_MIN = 3,
  GALLIC = 4
};
// Maps a gallic type to the string type of its string component: left and
// right map to the matching string type, everything else to STRING_RESTRICT.
constexpr StringType GallicStringType(GallicType g) {
  switch (g) {
    case GALLIC_LEFT:
      return STRING_LEFT;
    case GALLIC_RIGHT:
      return STRING_RIGHT;
    default:
      return STRING_RESTRICT;
  }
}
// Maps a gallic type to the type used by the reversed weight: left and right
// are swapped, restricted and min map to themselves, and any other value maps
// to GALLIC.
constexpr GallicType ReverseGallicType(GallicType g) {
  switch (g) {
    case GALLIC_LEFT:
      return GALLIC_RIGHT;
    case GALLIC_RIGHT:
      return GALLIC_LEFT;
    case GALLIC_RESTRICT:
      return GALLIC_RESTRICT;
    case GALLIC_MIN:
      return GALLIC_MIN;
    default:
      return GALLIC;
  }
}
// Product of a string weight and an arbitrary weight.
template <class Label, class W, GallicType G = GALLIC_LEFT>
struct GallicWeight
    : public ProductWeight<StringWeight<Label, GallicStringType(G)>, W> {
  using SW = StringWeight<Label, GallicStringType(G)>;
  using ReverseWeight =
      GallicWeight<Label, typename W::ReverseWeight, ReverseGallicType(G)>;

  using ProductWeight<SW, W>::Properties;

  GallicWeight() = default;

  // Builds the weight from its string and W components.
  GallicWeight(SW w1, W w2) : ProductWeight<SW, W>(w1, w2) {}

  // Forwards s and nread to the ProductWeight constructor.
  explicit GallicWeight(std::string_view s, int *nread = nullptr)
      : ProductWeight<SW, W>(s, nread) {}

  // Wraps an existing product weight.
  explicit GallicWeight(const ProductWeight<SW, W> &w)
      : ProductWeight<SW, W>(w) {}

  static const GallicWeight &Zero() {
    static const GallicWeight zero(ProductWeight<SW, W>::Zero());
    return zero;
  }

  static const GallicWeight &One() {
    static const GallicWeight one(ProductWeight<SW, W>::One());
    return one;
  }

  static const GallicWeight &NoWeight() {
    static const GallicWeight no_weight(ProductWeight<SW, W>::NoWeight());
    return no_weight;
  }

  // Name of this semiring, chosen by the gallic type G.
  static const std::string &Type() {
    static const std::string *const type = [] {
      const char *name = "gallic";
      if (G == GALLIC_LEFT) {
        name = "left_gallic";
      } else if (G == GALLIC_RIGHT) {
        name = "right_gallic";
      } else if (G == GALLIC_RESTRICT) {
        name = "restricted_gallic";
      } else if (G == GALLIC_MIN) {
        name = "min_gallic";
      }
      return new std::string(name);
    }();
    return *type;
  }

  // Quantizes through the underlying product weight.
  GallicWeight Quantize(float delta = kDelta) const {
    const auto quantized = ProductWeight<SW, W>::Quantize(delta);
    return GallicWeight(quantized);
  }

  // Reverses through the underlying product weight.
  ReverseWeight Reverse() const {
    const auto reversed = ProductWeight<SW, W>::Reverse();
    return ReverseWeight(reversed);
  }
};
// Default gallic Plus: applies Plus to the string components and to the W
// components independently.
template <class Label, class W, GallicType G>
inline GallicWeight<Label, W, G> Plus(const GallicWeight<Label, W, G> &w,
                                      const GallicWeight<Label, W, G> &v) {
  using Weight = GallicWeight<Label, W, G>;
  const auto &lhs_string = w.Value1();
  const auto &rhs_string = v.Value1();
  const auto &lhs_weight = w.Value2();
  const auto &rhs_weight = v.Value2();
  return Weight(Plus(lhs_string, rhs_string), Plus(lhs_weight, rhs_weight));
}
// Min gallic Plus: returns w1 when its W component is NaturalLess than that
// of w2, and w2 otherwise.
template <class Label, class W>
inline GallicWeight<Label, W, GALLIC_MIN> Plus(
    const GallicWeight<Label, W, GALLIC_MIN> &w1,
    const GallicWeight<Label, W, GALLIC_MIN> &w2) {
  static const NaturalLess<W> less;
  if (less(w1.Value2(), w2.Value2())) return w1;
  return w2;
}
// Gallic Times: applies Times to the string components and to the W
// components independently.
template <class Label, class W, GallicType G>
inline GallicWeight<Label, W, G> Times(const GallicWeight<Label, W, G> &w,
                                       const GallicWeight<Label, W, G> &v) {
  using Weight = GallicWeight<Label, W, G>;
  const auto &lhs_string = w.Value1();
  const auto &rhs_string = v.Value1();
  const auto &lhs_weight = w.Value2();
  const auto &rhs_weight = v.Value2();
  return Weight(Times(lhs_string, rhs_string), Times(lhs_weight, rhs_weight));
}
// Gallic Divide: applies Divide with the same divide type to the string
// components and to the W components independently.
template <class Label, class W, GallicType G>
inline GallicWeight<Label, W, G> Divide(const GallicWeight<Label, W, G> &w,
                                        const GallicWeight<Label, W, G> &v,
                                        DivideType divide_type = DIVIDE_ANY) {
  using Weight = GallicWeight<Label, W, G>;
  const auto &lhs_string = w.Value1();
  const auto &rhs_string = v.Value1();
  const auto &lhs_weight = w.Value2();
  const auto &rhs_weight = v.Value2();
  return Weight(Divide(lhs_string, rhs_string, divide_type),
                Divide(lhs_weight, rhs_weight, divide_type));
}
// This function object generates gallic weights by calling an underlying
// product weight generator. This is intended primarily for testing.
template <class Label, class W, GallicType G>
class WeightGenerate<GallicWeight<Label, W, G>>
    : public WeightGenerate<
          ProductWeight<StringWeight<Label, GallicStringType(G)>, W>> {
 public:
  using Weight = GallicWeight<Label, W, G>;
  using Generate = WeightGenerate<
      ProductWeight<StringWeight<Label, GallicStringType(G)>, W>>;

  // seed and allow_zero are passed to the wrapped product weight generator.
  explicit WeightGenerate(uint64_t seed = std::random_device()(),
                          bool allow_zero = true)
      : generate_(seed, allow_zero) {}

  // Draws a product weight and wraps it as a gallic weight.
  Weight operator()() const {
    const auto product = generate_();
    return Weight(product);
  }

 private:
  const Generate generate_;
};
// Union weight options for (general) GALLIC type.
template <class Label, class W> struct GallicUnionWeightOptions { using ReverseOptions = GallicUnionWeightOptions<Label, W>; using GW = GallicWeight<Label, W, GALLIC_RESTRICT>; using SW = StringWeight<Label, GallicStringType(GALLIC_RESTRICT)>; using SI = StringWeightIterator<SW>;
// Military order.
struct Compare { bool operator()(const GW &w1, const GW &w2) const { const SW &s1 = w1.Value1(); const SW &s2 = w2.Value1(); if (s1.Size() < s2.Size()) return true; if (s1.Size() > s2.Size()) return false; SI iter1(s1); SI iter2(s2); while (!iter1.Done()) { const auto l1 = iter1.Value(); const auto l2 = iter2.Value(); if (l1 < l2) return true; if (l1 > l2) return false; iter1.Next(); iter2.Next(); } return false; } };
// Adds W weights when string part equal.
struct Merge { GW operator()(const GW &w1, const GW &w2) const { return GW(w1.Value1(), Plus(w1.Value2(), w2.Value2())); } }; };
// Specialization for the (general) GALLIC type, implemented as a union
// weight over GALLIC_RESTRICT weights.
template <class Label, class W>
struct GallicWeight<Label, W, GALLIC>
    : public UnionWeight<GallicWeight<Label, W, GALLIC_RESTRICT>,
                         GallicUnionWeightOptions<Label, W>> {
  using GW = GallicWeight<Label, W, GALLIC_RESTRICT>;
  using SW = StringWeight<Label, GallicStringType(GALLIC_RESTRICT)>;
  using SI = StringWeightIterator<SW>;
  using UW = UnionWeight<GW, GallicUnionWeightOptions<Label, W>>;
  using UI = UnionWeightIterator<GW, GallicUnionWeightOptions<Label, W>>;
  using ReverseWeight = GallicWeight<Label, W, GALLIC>;

  using UW::Properties;

  GallicWeight() = default;

  // Copy constructor from the underlying union weight; deliberately implicit.
  // NOLINTNEXTLINE(google-explicit-constructor)
  GallicWeight(const UW &weight) : UW(weight) {}

  // Singleton constructors: create a GALLIC weight containing a single
  // GALLIC_RESTRICT weight. Takes as argument (1) a GALLIC_RESTRICT weight or
  // (2) the two components of a GALLIC_RESTRICT weight.
  explicit GallicWeight(const GW &weight) : UW(weight) {}

  GallicWeight(SW w1, W w2) : UW(GW(w1, w2)) {}

  // Forwards str and nread to the union weight constructor.
  explicit GallicWeight(std::string_view str, int *nread = nullptr)
      : UW(str, nread) {}

  static const GallicWeight &Zero() {
    static const GallicWeight zero(UW::Zero());
    return zero;
  }

  static const GallicWeight &One() {
    static const GallicWeight one(UW::One());
    return one;
  }

  static const GallicWeight &NoWeight() {
    static const GallicWeight no_weight(UW::NoWeight());
    return no_weight;
  }

  static const std::string &Type() {
    static const std::string *const name = new std::string("gallic");
    return *name;
  }

  // Quantizes through the underlying union weight.
  GallicWeight Quantize(float delta = kDelta) const {
    return UW::Quantize(delta);
  }

  // Reverses through the underlying union weight.
  ReverseWeight Reverse() const { return UW::Reverse(); }
};
// (General) gallic plus: copies both arguments to the underlying union
// weight type and applies its Plus.
template <class Label, class W>
inline GallicWeight<Label, W, GALLIC> Plus(
    const GallicWeight<Label, W, GALLIC> &w1,
    const GallicWeight<Label, W, GALLIC> &w2) {
  using GW = GallicWeight<Label, W, GALLIC_RESTRICT>;
  using UW = UnionWeight<GW, GallicUnionWeightOptions<Label, W>>;
  const UW lhs(w1);
  const UW rhs(w2);
  return Plus(lhs, rhs);
}
// (General) gallic times: copies both arguments to the underlying union
// weight type and applies its Times.
template <class Label, class W>
inline GallicWeight<Label, W, GALLIC> Times(
    const GallicWeight<Label, W, GALLIC> &w1,
    const GallicWeight<Label, W, GALLIC> &w2) {
  using GW = GallicWeight<Label, W, GALLIC_RESTRICT>;
  using UW = UnionWeight<GW, GallicUnionWeightOptions<Label, W>>;
  const UW lhs(w1);
  const UW rhs(w2);
  return Times(lhs, rhs);
}
// (General) gallic divide: copies both arguments to the underlying union
// weight type and applies its Divide with the given divide type.
template <class Label, class W>
inline GallicWeight<Label, W, GALLIC> Divide(
    const GallicWeight<Label, W, GALLIC> &w1,
    const GallicWeight<Label, W, GALLIC> &w2,
    DivideType divide_type = DIVIDE_ANY) {
  using GW = GallicWeight<Label, W, GALLIC_RESTRICT>;
  using UW = UnionWeight<GW, GallicUnionWeightOptions<Label, W>>;
  const UW lhs(w1);
  const UW rhs(w2);
  return Divide(lhs, rhs, divide_type);
}
// This function object generates gallic weights by calling an underlying
// union weight generator. This is intended primarily for testing.
template <class Label, class W>
class WeightGenerate<GallicWeight<Label, W, GALLIC>>
    : public WeightGenerate<
          UnionWeight<GallicWeight<Label, W, GALLIC_RESTRICT>,
                      GallicUnionWeightOptions<Label, W>>> {
 public:
  using Weight = GallicWeight<Label, W, GALLIC>;
  using Generate =
      WeightGenerate<UnionWeight<GallicWeight<Label, W, GALLIC_RESTRICT>,
                                 GallicUnionWeightOptions<Label, W>>>;

  // seed and allow_zero are passed to the wrapped union weight generator.
  explicit WeightGenerate(uint64_t seed = std::random_device()(),
                          bool allow_zero = true)
      : generate_(seed, allow_zero) {}

  // Draws a union weight and wraps it as a (general) gallic weight.
  Weight operator()() const {
    const auto union_weight = generate_();
    return Weight(union_weight);
  }

 private:
  const Generate generate_;
};
} // namespace fst
#endif // FST_STRING_WEIGHT_H_
|