You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

834 lines
26 KiB

// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// String weight set and associated semiring operation definitions.
#ifndef FST_STRING_WEIGHT_H_
#define FST_STRING_WEIGHT_H_
#include <cstddef>
#include <cstdint>
#include <ios>
#include <istream>
#include <list>
#include <optional>
#include <ostream>
#include <random>
#include <string>
#include <vector>
#include <fst/log.h>
#include <fst/product-weight.h>
#include <fst/union-weight.h>
#include <fst/util.h>
#include <fst/weight.h>
#include <string_view>
namespace fst {
// Reserved label values and the text separator used by StringWeight.
// kStringInfinity is the single label of StringWeight::Zero(), kStringBad is
// the single label of StringWeight::NoWeight(), and kStringSeparator sits
// between labels in the textual form read and written by the stream operators.
inline constexpr int kStringInfinity = -1; // Label for the infinite string.
inline constexpr int kStringBad = -2; // Label for a non-string.
inline constexpr char kStringSeparator = '_'; // Label separator in strings.
// Selects which string semiring a StringWeight implements: left, right, or
// the 'restricted' variant. The restricted variant signals an error where
// Plus would otherwise return a proper prefix/suffix, which is useful with
// algorithms that require functional transducer input.
enum StringType { STRING_LEFT = 0, STRING_RIGHT = 1, STRING_RESTRICT = 2 };

// Returns the string type of the reversed weight: left and right swap, and
// every other value maps to STRING_RESTRICT.
constexpr StringType ReverseStringType(StringType s) {
  switch (s) {
    case STRING_LEFT:
      return STRING_RIGHT;
    case STRING_RIGHT:
      return STRING_LEFT;
    default:
      return STRING_RESTRICT;
  }
}
// Forward declarations of the iterator templates; StringWeight names them as
// its Iterator/ReverseIterator types and befriends them before they are
// defined.
template <class>
class StringWeightIterator;
template <class>
class StringWeightReverseIterator;
// String semiring: (longest_common_prefix/suffix, ., Infinity, Epsilon)
//
// A weight is a sequence of labels. The first label is stored inline and the
// remaining ones in a list; a first label of 0 denotes the empty string, so 0
// cannot appear as the leading label of a non-empty string.
template <typename L, StringType S = STRING_LEFT>
class StringWeight {
 public:
  using Label = L;
  using ReverseWeight = StringWeight<Label, ReverseStringType(S)>;
  using Iterator = StringWeightIterator<StringWeight>;
  using ReverseIterator = StringWeightReverseIterator<StringWeight>;

  friend class StringWeightIterator<StringWeight>;
  friend class StringWeightReverseIterator<StringWeight>;

  // Constructs the empty string.
  StringWeight() = default;

  // Constructs a string by appending every label in [begin, end) in order.
  template <typename InputIt>
  StringWeight(const InputIt begin, const InputIt end) {
    InputIt pos = begin;
    while (pos != end) {
      PushBack(*pos);
      ++pos;
    }
  }

  // Constructs a one-label string.
  explicit StringWeight(Label label) { PushBack(label); }

  // The single-label string holding kStringInfinity.
  static const StringWeight &Zero() {
    static const StringWeight *const kZero =
        new StringWeight(Label(kStringInfinity));
    return *kZero;
  }

  // The empty string.
  static const StringWeight &One() {
    static const StringWeight *const kOne = new StringWeight();
    return *kOne;
  }

  // The single-label string holding kStringBad.
  static const StringWeight &NoWeight() {
    static const StringWeight *const kNoWeight =
        new StringWeight(Label(kStringBad));
    return *kNoWeight;
  }

  // Name of the semiring, selected by S.
  static const std::string &Type() {
    static const std::string *const type = [] {
      const char *name = "restricted_string";
      if (S == STRING_LEFT) {
        name = "left_string";
      } else if (S == STRING_RIGHT) {
        name = "right_string";
      }
      return new std::string(name);
    }();
    return *type;
  }

  bool Member() const;

  std::istream &Read(std::istream &strm);

  std::ostream &Write(std::ostream &strm) const;

  size_t Hash() const;

  // Strings have nothing to quantize; returns a copy unchanged.
  StringWeight Quantize(float delta = kDelta) const { return *this; }

  ReverseWeight Reverse() const;

  // Idempotent always; left/right semiring according to S, and both for the
  // restricted variant.
  static constexpr uint64_t Properties() {
    uint64_t sides = kLeftSemiring | kRightSemiring;  // S == STRING_RESTRICT
    if (S == STRING_LEFT) {
      sides = kLeftSemiring;
    } else if (S == STRING_RIGHT) {
      sides = kRightSemiring;
    }
    return kIdempotent | sides;
  }

  // These operations combined with the StringWeightIterator and
  // StringWeightReverseIterator provide the access and mutation of the string
  // internal elements.

  // Resets to the empty string.
  void Clear() {
    rest_.clear();
    first_ = 0;
  }

  // Number of labels in the string.
  size_t Size() const {
    if (!first_) return 0;
    return rest_.size() + 1;
  }

  // Prepends a label; the previous first label (if any) moves into rest_.
  void PushFront(Label label) {
    if (first_) rest_.push_front(first_);
    first_ = label;
  }

  // Appends a label; it becomes the first label when the string is empty.
  void PushBack(Label label) {
    if (first_) {
      rest_.push_back(label);
    } else {
      first_ = label;
    }
  }

 private:
  Label first_ = 0;        // First label in string (0 if empty).
  std::list<Label> rest_;  // Remaining labels in string.
};
// Traverses string in forward direction (first label to last). The iterator
// stores references to the weight's members, so the weight must outlive it.
template <class StringWeight_>
class StringWeightIterator {
 public:
  using Weight = StringWeight_;
  using Label = typename Weight::Label;

  explicit StringWeightIterator(const Weight &w)
      : first_(w.first_), rest_(w.rest_), init_(true), iter_(rest_.begin()) {}

  // While positioned on the first label, the traversal is done only if the
  // string is empty (first label 0); afterwards, when rest_ is exhausted.
  bool Done() const { return init_ ? first_ == 0 : iter_ == rest_.end(); }

  // Label at the current position.
  const Label &Value() const {
    if (init_) return first_;
    return *iter_;
  }

  // Advances one label. The first step only leaves the initial state, since
  // iter_ already points at the start of rest_.
  void Next() {
    if (!init_) {
      ++iter_;
      return;
    }
    init_ = false;
  }

  // Returns to the first label.
  void Reset() {
    iter_ = rest_.begin();
    init_ = true;
  }

 private:
  const Label &first_;
  const decltype(Weight::rest_) &rest_;
  bool init_;  // In the initialized state?
  typename decltype(Weight::rest_)::const_iterator iter_;
};
// Traverses string in backward direction (last label to first). The iterator
// stores references to the weight's members, so the weight must outlive it.
template <class StringWeight_>
class StringWeightReverseIterator {
 public:
  using Weight = StringWeight_;
  using Label = typename Weight::Label;

  explicit StringWeightReverseIterator(const Weight &w)
      : first_(w.first_),
        rest_(w.rest_),
        fin_(first_ == Label()),
        iter_(rest_.rbegin()) {}

  // True once every label, including the first, has been visited. Always true
  // for the empty string.
  bool Done() const { return fin_; }

  // Label at the current position: labels of rest_ from back to front, then
  // the first label.
  const Label &Value() const { return iter_ == rest_.rend() ? first_ : *iter_; }

  // Advances one label; stepping past the first label ends the traversal.
  void Next() {
    if (iter_ == rest_.rend()) {
      fin_ = true;
    } else {
      ++iter_;
    }
  }

  // Returns to the last label. An empty string (first label equal to Label())
  // stays finished, exactly as after construction; clearing fin_
  // unconditionally would make an empty string yield one spurious label.
  void Reset() {
    fin_ = first_ == Label();
    iter_ = rest_.rbegin();
  }

 private:
  const Label &first_;
  const decltype(Weight::rest_) &rest_;
  bool fin_;  // In the final state?
  typename decltype(Weight::rest_)::const_reverse_iterator iter_;
};
// StringWeight member functions follow that require
// StringWeightIterator or StringWeightReverseIterator.
// Reads a string weight in binary form: an int32 label count followed by that
// many labels. Any previous contents are discarded first. If the stream
// reports failure after reading the count or a label, reading stops there and
// the stream is returned in its failed state, so no indeterminate count or
// label is ever used.
template <typename Label, StringType S>
inline std::istream &StringWeight<Label, S>::Read(std::istream &strm) {
  Clear();
  int32_t size = 0;  // Defined value even if the count cannot be read.
  ReadType(strm, &size);
  if (!strm) return strm;
  for (int32_t i = 0; i < size; ++i) {
    Label label = Label();
    ReadType(strm, &label);
    if (!strm) break;  // Truncated input: do not append a bogus label.
    PushBack(label);
  }
  return strm;
}
// Writes the weight in binary form: the label count as an int32, followed by
// each label in forward order.
template <typename Label, StringType S>
inline std::ostream &StringWeight<Label, S>::Write(std::ostream &strm) const {
  const int32_t count = Size();
  WriteType(strm, count);
  Iterator iter(*this);
  while (!iter.Done()) {
    WriteType(strm, iter.Value());
    iter.Next();
  }
  return strm;
}
// A weight is a member of the set unless its first label is kStringBad.
template <typename Label, StringType S>
inline bool StringWeight<Label, S>::Member() const {
  // A fresh forward iterator is positioned on the first label, so inspecting
  // first_ directly is equivalent.
  return first_ != Label(kStringBad);
}
// Returns the reversed string as a weight of the reverse string type. Walking
// forward while pushing each label to the front reverses the order.
template <typename Label, StringType S>
inline typename StringWeight<Label, S>::ReverseWeight
StringWeight<Label, S>::Reverse() const {
  ReverseWeight reversed;
  Iterator iter(*this);
  while (!iter.Done()) {
    reversed.PushFront(iter.Value());
    iter.Next();
  }
  return reversed;
}
// Hashes the label sequence: for each label in forward order, the running
// value is XOR-ed with (itself shifted left by one, XOR the label).
template <typename Label, StringType S>
inline size_t StringWeight<Label, S>::Hash() const {
  size_t hash = 0;
  Iterator iter(*this);
  while (!iter.Done()) {
    hash ^= ((hash << 1) ^ iter.Value());
    iter.Next();
  }
  return hash;
}
// Two string weights are equal when they have the same length and the same
// label at every position.
template <typename Label, StringType S>
inline bool operator==(const StringWeight<Label, S> &w1,
                       const StringWeight<Label, S> &w2) {
  using Iterator = typename StringWeight<Label, S>::Iterator;
  if (w1.Size() != w2.Size()) return false;
  // Sizes match, so both iterators finish together.
  Iterator lhs(w1);
  Iterator rhs(w2);
  while (!lhs.Done()) {
    if (lhs.Value() != rhs.Value()) return false;
    lhs.Next();
    rhs.Next();
  }
  return true;
}
// Negation of operator== for string weights.
template <typename Label, StringType S>
inline bool operator!=(const StringWeight<Label, S> &w1,
                       const StringWeight<Label, S> &w2) {
  const bool equal = (w1 == w2);
  return !equal;
}
// Approximate equality for string weights is exact equality; delta is not
// consulted.
template <typename Label, StringType S>
inline bool ApproxEqual(const StringWeight<Label, S> &w1,
                        const StringWeight<Label, S> &w2,
                        float delta = kDelta) {
  const bool identical = (w1 == w2);
  return identical;
}
// Prints a string weight as text: "Epsilon" for the empty string, "Infinity"
// or "BadString" when the first label is the corresponding reserved value,
// and otherwise the labels joined by kStringSeparator.
template <typename Label, StringType S>
inline std::ostream &operator<<(std::ostream &strm,
                                const StringWeight<Label, S> &weight) {
  typename StringWeight<Label, S>::Iterator iter(weight);
  if (iter.Done()) return strm << "Epsilon";
  if (iter.Value() == Label(kStringInfinity)) return strm << "Infinity";
  if (iter.Value() == Label(kStringBad)) return strm << "BadString";
  // First label has no leading separator; every later one does.
  strm << iter.Value();
  for (iter.Next(); !iter.Done(); iter.Next()) {
    strm << kStringSeparator;
    strm << iter.Value();
  }
  return strm;
}
// Parses a string weight from one whitespace-delimited token: "Infinity"
// yields Zero(), "Epsilon" yields One(), and anything else is split on
// kStringSeparator and each piece parsed as an integer label. On the first
// piece that fails to parse, the stream state is set to badbit and parsing
// stops; labels parsed before that point remain in the weight.
template <typename Label, StringType S>
inline std::istream &operator>>(std::istream &strm,
                                StringWeight<Label, S> &weight) {
  using Weight = StringWeight<Label, S>;
  std::string token;
  strm >> token;
  if (token == "Infinity") {
    weight = Weight::Zero();
    return strm;
  }
  if (token == "Epsilon") {
    weight = Weight::One();
    return strm;
  }
  weight.Clear();
  for (std::string_view piece : StrSplit(token, kStringSeparator)) {
    const auto parsed = ParseInt64(piece);
    if (!parsed.has_value()) {
      strm.clear(std::ios::badbit);
      break;
    }
    weight.PushBack(*parsed);
  }
  return strm;
}
// Plus for the restricted string semiring (the default overload). Zero() is
// the identity; otherwise both arguments must be equal, and unequal arguments
// log an error and produce NoWeight(). The restriction is used (e.g., in
// determinization) to ensure the input is functional.
template <typename Label, StringType S>
inline StringWeight<Label, S> Plus(const StringWeight<Label, S> &w1,
                                   const StringWeight<Label, S> &w2) {
  using Weight = StringWeight<Label, S>;
  if (!(w1.Member() && w2.Member())) return Weight::NoWeight();
  if (w1 == Weight::Zero()) return w2;
  if (w2 == Weight::Zero()) return w1;
  if (w1 == w2) return w1;
  FSTERROR() << "StringWeight::Plus: Unequal arguments "
             << "(non-functional FST?)"
             << " w1 = " << w1 << " w2 = " << w2;
  return Weight::NoWeight();
}
// Plus for the left string semiring: the longest common prefix of the two
// strings, with Zero() as the identity and NoWeight() for non-member input.
template <typename Label>
inline StringWeight<Label, STRING_LEFT> Plus(
    const StringWeight<Label, STRING_LEFT> &w1,
    const StringWeight<Label, STRING_LEFT> &w2) {
  using Weight = StringWeight<Label, STRING_LEFT>;
  if (!(w1.Member() && w2.Member())) return Weight::NoWeight();
  if (w1 == Weight::Zero()) return w2;
  if (w2 == Weight::Zero()) return w1;
  Weight prefix;
  typename Weight::Iterator lhs(w1);
  typename Weight::Iterator rhs(w2);
  while (!lhs.Done() && !rhs.Done()) {
    if (lhs.Value() != rhs.Value()) break;  // First mismatch ends the prefix.
    prefix.PushBack(lhs.Value());
    lhs.Next();
    rhs.Next();
  }
  return prefix;
}
// Plus for the right string semiring: the longest common suffix of the two
// strings, with Zero() as the identity and NoWeight() for non-member input.
template <typename Label>
inline StringWeight<Label, STRING_RIGHT> Plus(
    const StringWeight<Label, STRING_RIGHT> &w1,
    const StringWeight<Label, STRING_RIGHT> &w2) {
  using Weight = StringWeight<Label, STRING_RIGHT>;
  if (!(w1.Member() && w2.Member())) return Weight::NoWeight();
  if (w1 == Weight::Zero()) return w2;
  if (w2 == Weight::Zero()) return w1;
  Weight suffix;
  typename Weight::ReverseIterator lhs(w1);
  typename Weight::ReverseIterator rhs(w2);
  // Walk both strings from the back; PushFront keeps the labels in order.
  while (!lhs.Done() && !rhs.Done()) {
    if (lhs.Value() != rhs.Value()) break;
    suffix.PushFront(lhs.Value());
    lhs.Next();
    rhs.Next();
  }
  return suffix;
}
// Times for string weights: concatenation of w1 followed by w2. Zero() is
// absorbing and non-member input produces NoWeight().
template <typename Label, StringType S>
inline StringWeight<Label, S> Times(const StringWeight<Label, S> &w1,
                                    const StringWeight<Label, S> &w2) {
  using Weight = StringWeight<Label, S>;
  if (!(w1.Member() && w2.Member())) return Weight::NoWeight();
  if (w1 == Weight::Zero() || w2 == Weight::Zero()) return Weight::Zero();
  Weight concatenated(w1);
  typename Weight::Iterator tail(w2);
  while (!tail.Done()) {
    concatenated.PushBack(tail.Value());
    tail.Next();
  }
  return concatenated;
}
// Left division in a left string semiring: drops the first |w2| labels of w1
// and returns what remains. The labels of w2 are not compared against w1;
// only its length is used. Dividing by Zero() yields the bad string, and
// Zero() divided by anything else is Zero().
template <typename Label, StringType S>
inline StringWeight<Label, S> DivideLeft(const StringWeight<Label, S> &w1,
                                         const StringWeight<Label, S> &w2) {
  using Weight = StringWeight<Label, S>;
  if (!(w1.Member() && w2.Member())) return Weight::NoWeight();
  if (w2 == Weight::Zero()) return Weight(Label(kStringBad));
  if (w1 == Weight::Zero()) return Weight::Zero();
  typename Weight::Iterator iter(w1);
  // Skip over the prefix being divided out.
  for (size_t skipped = 0; !iter.Done() && skipped < w2.Size(); ++skipped) {
    iter.Next();
  }
  Weight quotient;
  while (!iter.Done()) {
    quotient.PushBack(iter.Value());
    iter.Next();
  }
  return quotient;
}
// Right division in a right string semiring: drops the last |w2| labels of w1
// and returns what remains. The labels of w2 are not compared against w1;
// only its length is used. Dividing by Zero() yields the bad string, and
// Zero() divided by anything else is Zero().
template <typename Label, StringType S>
inline StringWeight<Label, S> DivideRight(const StringWeight<Label, S> &w1,
                                          const StringWeight<Label, S> &w2) {
  using Weight = StringWeight<Label, S>;
  if (!(w1.Member() && w2.Member())) return Weight::NoWeight();
  if (w2 == Weight::Zero()) return Weight(Label(kStringBad));
  if (w1 == Weight::Zero()) return Weight::Zero();
  typename Weight::ReverseIterator iter(w1);
  // Skip over the suffix being divided out.
  for (size_t skipped = 0; !iter.Done() && skipped < w2.Size(); ++skipped) {
    iter.Next();
  }
  Weight quotient;
  // Remaining labels arrive back to front; PushFront restores their order.
  while (!iter.Done()) {
    quotient.PushFront(iter.Value());
    iter.Next();
  }
  return quotient;
}
// Divide for the restricted string semiring (the default overload). The
// caller must ask for left or right division explicitly; any other divide
// type logs an error and produces NoWeight().
template <typename Label, StringType S>
inline StringWeight<Label, S> Divide(const StringWeight<Label, S> &w1,
                                     const StringWeight<Label, S> &w2,
                                     DivideType divide_type) {
  using Weight = StringWeight<Label, S>;
  switch (divide_type) {
    case DIVIDE_LEFT:
      return DivideLeft(w1, w2);
    case DIVIDE_RIGHT:
      return DivideRight(w1, w2);
    default:
      break;
  }
  FSTERROR() << "StringWeight::Divide: "
             << "Only explicit left or right division is defined "
             << "for the " << Weight::Type() << " semiring";
  return Weight::NoWeight();
}
// Divide for the left string semiring: only DIVIDE_LEFT is accepted; any
// other divide type logs an error and produces NoWeight().
template <typename Label>
inline StringWeight<Label, STRING_LEFT> Divide(
    const StringWeight<Label, STRING_LEFT> &w1,
    const StringWeight<Label, STRING_LEFT> &w2, DivideType divide_type) {
  if (divide_type == DIVIDE_LEFT) return DivideLeft(w1, w2);
  FSTERROR() << "StringWeight::Divide: Only left division is defined "
             << "for the left string semiring";
  return StringWeight<Label, STRING_LEFT>::NoWeight();
}
// Divide for the right string semiring: only DIVIDE_RIGHT is accepted; any
// other divide type logs an error and produces NoWeight().
template <typename Label>
inline StringWeight<Label, STRING_RIGHT> Divide(
    const StringWeight<Label, STRING_RIGHT> &w1,
    const StringWeight<Label, STRING_RIGHT> &w2, DivideType divide_type) {
  if (divide_type == DIVIDE_RIGHT) return DivideRight(w1, w2);
  FSTERROR() << "StringWeight::Divide: Only right division is defined "
             << "for the right string semiring";
  return StringWeight<Label, STRING_RIGHT>::NoWeight();
}
// This function object generates StringWeights that are random integer strings
// from {1, ..., alphabet_size}^{0, ..., max_string_length} U { Zero }. Zero is
// only produced when allow_zero is set. This is intended primarily for
// testing.
template <class Label, StringType S>
class WeightGenerate<StringWeight<Label, S>> {
 public:
  using Weight = StringWeight<Label, S>;

  explicit WeightGenerate(uint64_t seed = std::random_device()(),
                          bool allow_zero = true,
                          size_t alphabet_size = kNumRandomWeights,
                          size_t max_string_length = kNumRandomWeights)
      : rand_(seed),
        allow_zero_(allow_zero),
        alphabet_size_(alphabet_size),
        max_string_length_(max_string_length) {}

  // Draws one weight. The length is uniform over [0, max_string_length]; when
  // allow_zero is set there is one additional, equally likely outcome that
  // yields Zero(). That outcome is the value just past max_string_length, so
  // strings of the maximum length stay reachable and no string ever exceeds
  // the configured bound.
  Weight operator()() const {
    const int max_length = static_cast<int>(max_string_length_);
    const int n = std::uniform_int_distribution<>(
        0, max_length + (allow_zero_ ? 1 : 0))(rand_);
    if (n > max_length) return Weight::Zero();  // Only when allow_zero_.
    std::vector<Label> labels;
    labels.reserve(n);
    for (int i = 0; i < n; ++i) {
      // Labels start at 1 because 0 marks the empty string.
      labels.push_back(
          std::uniform_int_distribution<>(1, alphabet_size_)(rand_));
    }
    return Weight(labels.begin(), labels.end());
  }

 private:
  mutable std::mt19937_64 rand_;  // Mutable so operator() can stay const.
  const bool allow_zero_;
  const size_t alphabet_size_;
  const size_t max_string_length_;
};
// Determines whether to use left, right, or (general) gallic semiring. Includes
// a restricted version that signals an error if proper string prefixes or
// suffixes would otherwise be returned by string Plus. This is useful with
// algorithms that require functional transducer input. Also includes min
// version that changes the Plus to keep only the lowest W weight string.
// GALLIC (the general type) is implemented as a union of GALLIC_RESTRICT
// weights by a dedicated GallicWeight specialization.
enum GallicType {
GALLIC_LEFT = 0,
GALLIC_RIGHT = 1,
GALLIC_RESTRICT = 2,
GALLIC_MIN = 3,
GALLIC = 4
};
// Returns the string semiring used for the string component of a gallic
// weight: left and right map to their string counterparts; every other
// gallic type uses the restricted string semiring.
constexpr StringType GallicStringType(GallicType g) {
  switch (g) {
    case GALLIC_LEFT:
      return STRING_LEFT;
    case GALLIC_RIGHT:
      return STRING_RIGHT;
    default:
      return STRING_RESTRICT;
  }
}
// Returns the gallic type of the reversed weight: left and right swap,
// restricted and min map to themselves, and anything else maps to GALLIC.
constexpr GallicType ReverseGallicType(GallicType g) {
  switch (g) {
    case GALLIC_LEFT:
      return GALLIC_RIGHT;
    case GALLIC_RIGHT:
      return GALLIC_LEFT;
    case GALLIC_RESTRICT:
      return GALLIC_RESTRICT;
    case GALLIC_MIN:
      return GALLIC_MIN;
    default:
      return GALLIC;
  }
}
// Product of a string weight and an arbitrary weight. The gallic type G
// selects the string semiring used for the first component.
template <class Label, class W, GallicType G = GALLIC_LEFT>
struct GallicWeight
    : public ProductWeight<StringWeight<Label, GallicStringType(G)>, W> {
  using ReverseWeight =
      GallicWeight<Label, typename W::ReverseWeight, ReverseGallicType(G)>;
  using SW = StringWeight<Label, GallicStringType(G)>;

  using ProductWeight<SW, W>::Properties;

  GallicWeight() = default;

  // Builds the weight from its string and W components.
  GallicWeight(SW w1, W w2) : ProductWeight<SW, W>(w1, w2) {}

  // Forwards the text and the optional nread out-parameter to the
  // ProductWeight constructor taking the same arguments.
  explicit GallicWeight(std::string_view s, int *nread = nullptr)
      : ProductWeight<SW, W>(s, nread) {}

  // Wraps an existing product weight.
  explicit GallicWeight(const ProductWeight<SW, W> &w)
      : ProductWeight<SW, W>(w) {}

  static const GallicWeight &Zero() {
    static const GallicWeight kZero(ProductWeight<SW, W>::Zero());
    return kZero;
  }

  static const GallicWeight &One() {
    static const GallicWeight kOne(ProductWeight<SW, W>::One());
    return kOne;
  }

  static const GallicWeight &NoWeight() {
    static const GallicWeight kNoWeight(ProductWeight<SW, W>::NoWeight());
    return kNoWeight;
  }

  // Name of the semiring, selected by G.
  static const std::string &Type() {
    static const std::string *const type = [] {
      const char *name = "gallic";
      if (G == GALLIC_LEFT) {
        name = "left_gallic";
      } else if (G == GALLIC_RIGHT) {
        name = "right_gallic";
      } else if (G == GALLIC_RESTRICT) {
        name = "restricted_gallic";
      } else if (G == GALLIC_MIN) {
        name = "min_gallic";
      }
      return new std::string(name);
    }();
    return *type;
  }

  // Quantizes via the underlying product weight.
  GallicWeight Quantize(float delta = kDelta) const {
    return GallicWeight(ProductWeight<SW, W>::Quantize(delta));
  }

  // Reverses via the underlying product weight.
  ReverseWeight Reverse() const {
    return ReverseWeight(ProductWeight<SW, W>::Reverse());
  }
};
// Default gallic plus: applies Plus to the string components and to the W
// components independently.
template <class Label, class W, GallicType G>
inline GallicWeight<Label, W, G> Plus(const GallicWeight<Label, W, G> &w,
                                      const GallicWeight<Label, W, G> &v) {
  using Weight = GallicWeight<Label, W, G>;
  return Weight(Plus(w.Value1(), v.Value1()), Plus(w.Value2(), v.Value2()));
}
// Min gallic plus: returns w1 when its W component is NaturalLess than that
// of w2, and w2 otherwise.
template <class Label, class W>
inline GallicWeight<Label, W, GALLIC_MIN> Plus(
    const GallicWeight<Label, W, GALLIC_MIN> &w1,
    const GallicWeight<Label, W, GALLIC_MIN> &w2) {
  static const NaturalLess<W> less;
  if (less(w1.Value2(), w2.Value2())) return w1;
  return w2;
}
// Gallic times: applies Times to the string components and to the W
// components independently.
template <class Label, class W, GallicType G>
inline GallicWeight<Label, W, G> Times(const GallicWeight<Label, W, G> &w,
                                       const GallicWeight<Label, W, G> &v) {
  using Weight = GallicWeight<Label, W, G>;
  return Weight(Times(w.Value1(), v.Value1()), Times(w.Value2(), v.Value2()));
}
// Gallic divide: applies Divide with the same divide type to the string
// components and to the W components independently.
template <class Label, class W, GallicType G>
inline GallicWeight<Label, W, G> Divide(const GallicWeight<Label, W, G> &w,
                                        const GallicWeight<Label, W, G> &v,
                                        DivideType divide_type = DIVIDE_ANY) {
  using Weight = GallicWeight<Label, W, G>;
  return Weight(Divide(w.Value1(), v.Value1(), divide_type),
                Divide(w.Value2(), v.Value2(), divide_type));
}
// This function object generates gallic weights by calling an underlying
// product weight generator. This is intended primarily for testing.
// NOTE(review): seed and allow_zero are passed only to the generate_ member;
// the base-class subobject is not initialized from them.
template <class Label, class W, GallicType G>
class WeightGenerate<GallicWeight<Label, W, G>>
    : public WeightGenerate<
          ProductWeight<StringWeight<Label, GallicStringType(G)>, W>> {
 public:
  using Weight = GallicWeight<Label, W, G>;
  using Generate = WeightGenerate<
      ProductWeight<StringWeight<Label, GallicStringType(G)>, W>>;

  explicit WeightGenerate(uint64_t seed = std::random_device()(),
                          bool allow_zero = true)
      : generate_(seed, allow_zero) {}

  // Draws a product weight and wraps it as a gallic weight.
  Weight operator()() const {
    const auto product = generate_();
    return Weight(product);
  }

 private:
  const Generate generate_;
};
// Union weight options for (general) GALLIC type.
template <class Label, class W>
struct GallicUnionWeightOptions {
using ReverseOptions = GallicUnionWeightOptions<Label, W>;
using GW = GallicWeight<Label, W, GALLIC_RESTRICT>;
using SW = StringWeight<Label, GallicStringType(GALLIC_RESTRICT)>;
using SI = StringWeightIterator<SW>;
// Military order.
struct Compare {
bool operator()(const GW &w1, const GW &w2) const {
const SW &s1 = w1.Value1();
const SW &s2 = w2.Value1();
if (s1.Size() < s2.Size()) return true;
if (s1.Size() > s2.Size()) return false;
SI iter1(s1);
SI iter2(s2);
while (!iter1.Done()) {
const auto l1 = iter1.Value();
const auto l2 = iter2.Value();
if (l1 < l2) return true;
if (l1 > l2) return false;
iter1.Next();
iter2.Next();
}
return false;
}
};
// Adds W weights when string part equal.
struct Merge {
GW operator()(const GW &w1, const GW &w2) const {
return GW(w1.Value1(), Plus(w1.Value2(), w2.Value2()));
}
};
};
// Specialization for the (general) GALLIC type: a union of GALLIC_RESTRICT
// weights, ordered and merged according to GallicUnionWeightOptions.
template <class Label, class W>
struct GallicWeight<Label, W, GALLIC>
    : public UnionWeight<GallicWeight<Label, W, GALLIC_RESTRICT>,
                         GallicUnionWeightOptions<Label, W>> {
  using GW = GallicWeight<Label, W, GALLIC_RESTRICT>;
  using SW = StringWeight<Label, GallicStringType(GALLIC_RESTRICT)>;
  using SI = StringWeightIterator<SW>;
  using UW = UnionWeight<GW, GallicUnionWeightOptions<Label, W>>;
  using UI = UnionWeightIterator<GW, GallicUnionWeightOptions<Label, W>>;
  using ReverseWeight = GallicWeight<Label, W, GALLIC>;

  using UW::Properties;

  GallicWeight() = default;

  // Converting constructor from the underlying union weight; intentionally
  // implicit so union-weight results convert back to GALLIC weights.
  // NOLINTNEXTLINE(google-explicit-constructor)
  GallicWeight(const UW &weight) : UW(weight) {}

  // Singleton constructors: create a GALLIC weight containing a single
  // GALLIC_RESTRICT weight. Takes as argument (1) a GALLIC_RESTRICT weight or
  // (2) the two components of a GALLIC_RESTRICT weight.
  explicit GallicWeight(const GW &weight) : UW(weight) {}

  GallicWeight(SW w1, W w2) : UW(GW(w1, w2)) {}

  // Forwards the text and the optional nread out-parameter to the UnionWeight
  // constructor taking the same arguments.
  explicit GallicWeight(std::string_view str, int *nread = nullptr)
      : UW(str, nread) {}

  static const GallicWeight &Zero() {
    static const GallicWeight kZero(UW::Zero());
    return kZero;
  }

  static const GallicWeight &One() {
    static const GallicWeight kOne(UW::One());
    return kOne;
  }

  static const GallicWeight &NoWeight() {
    static const GallicWeight kNoWeight(UW::NoWeight());
    return kNoWeight;
  }

  static const std::string &Type() {
    static const std::string *const kType = new std::string("gallic");
    return *kType;
  }

  // Quantizes via the underlying union weight.
  GallicWeight Quantize(float delta = kDelta) const {
    return UW::Quantize(delta);
  }

  // Reverses via the underlying union weight.
  ReverseWeight Reverse() const { return UW::Reverse(); }
};
// (General) gallic plus: delegates to Plus on the underlying union weights.
template <class Label, class W>
inline GallicWeight<Label, W, GALLIC> Plus(
    const GallicWeight<Label, W, GALLIC> &w1,
    const GallicWeight<Label, W, GALLIC> &w2) {
  using GW = GallicWeight<Label, W, GALLIC_RESTRICT>;
  using UW = UnionWeight<GW, GallicUnionWeightOptions<Label, W>>;
  // Convert to the base type so the call resolves to the union-weight Plus
  // rather than back to this overload.
  const UW lhs = static_cast<UW>(w1);
  const UW rhs = static_cast<UW>(w2);
  return Plus(lhs, rhs);
}
// (General) gallic times: delegates to Times on the underlying union weights.
template <class Label, class W>
inline GallicWeight<Label, W, GALLIC> Times(
    const GallicWeight<Label, W, GALLIC> &w1,
    const GallicWeight<Label, W, GALLIC> &w2) {
  using GW = GallicWeight<Label, W, GALLIC_RESTRICT>;
  using UW = UnionWeight<GW, GallicUnionWeightOptions<Label, W>>;
  // Convert to the base type so the call resolves to the union-weight Times
  // rather than back to this overload.
  const UW lhs = static_cast<UW>(w1);
  const UW rhs = static_cast<UW>(w2);
  return Times(lhs, rhs);
}
// (General) gallic divide: delegates to Divide on the underlying union
// weights, passing the divide type through.
template <class Label, class W>
inline GallicWeight<Label, W, GALLIC> Divide(
    const GallicWeight<Label, W, GALLIC> &w1,
    const GallicWeight<Label, W, GALLIC> &w2,
    DivideType divide_type = DIVIDE_ANY) {
  using GW = GallicWeight<Label, W, GALLIC_RESTRICT>;
  using UW = UnionWeight<GW, GallicUnionWeightOptions<Label, W>>;
  // Convert to the base type so the call resolves to the union-weight Divide
  // rather than back to this overload.
  const UW lhs = static_cast<UW>(w1);
  const UW rhs = static_cast<UW>(w2);
  return Divide(lhs, rhs, divide_type);
}
// This function object generates gallic weights by calling an underlying
// union weight generator. This is intended primarily for testing.
// NOTE(review): seed and allow_zero are passed only to the generate_ member;
// the base-class subobject is not initialized from them.
template <class Label, class W>
class WeightGenerate<GallicWeight<Label, W, GALLIC>>
    : public WeightGenerate<UnionWeight<GallicWeight<Label, W, GALLIC_RESTRICT>,
                                        GallicUnionWeightOptions<Label, W>>> {
 public:
  using Weight = GallicWeight<Label, W, GALLIC>;
  using Generate =
      WeightGenerate<UnionWeight<GallicWeight<Label, W, GALLIC_RESTRICT>,
                                 GallicUnionWeightOptions<Label, W>>>;

  explicit WeightGenerate(uint64_t seed = std::random_device()(),
                          bool allow_zero = true)
      : generate_(seed, allow_zero) {}

  // Draws a union weight and wraps it as a GALLIC weight.
  Weight operator()() const {
    const auto drawn = generate_();
    return Weight(drawn);
  }

 private:
  const Generate generate_;
};
} // namespace fst
#endif // FST_STRING_WEIGHT_H_