// Copyright 2005-2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the 'License'); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an 'AS IS' BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Float weight set and associated semiring operation definitions. #ifndef FST_FLOAT_WEIGHT_H_ #define FST_FLOAT_WEIGHT_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace fst { namespace internal { // TODO(wolfsonkin): Replace with `std::isnan` if and when that ends up // constexpr. For context, see // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0533r6.pdf. template inline constexpr bool IsNan(T value) { return value != value; } } // namespace internal // Numeric limits class. template class FloatLimits { public: static constexpr T PosInfinity() { return std::numeric_limits::infinity(); } static constexpr T NegInfinity() { return -PosInfinity(); } static constexpr T NumberBad() { return std::numeric_limits::quiet_NaN(); } }; // Weight class to be templated on floating-points types. 
template <class T>
class FloatWeightTpl {
 public:
  using ValueType = T;

  // Default-initialization leaves the stored value indeterminate;
  // value-initialization (FloatWeightTpl<T>()) zero-initializes it.
  FloatWeightTpl() noexcept = default;

  constexpr FloatWeightTpl(T f) : value_(f) {}  // NOLINT

  // Binary (de)serialization of the raw value via ReadType/WriteType.
  std::istream &Read(std::istream &strm) { return ReadType(strm, &value_); }

  std::ostream &Write(std::ostream &strm) const {
    return WriteType(strm, value_);
  }

  // Hashes the bit pattern of the value (at most sizeof(size_t) bytes of it).
  size_t Hash() const {
    // +0.0 and -0.0 compare equal under operator==, so they must also hash
    // equally; map both to +0 before taking the bits (libstdc++'s
    // std::hash<float>/<double> does the same). NaNs keep their own bits.
    const T value = value_ == T(0) ? T(0) : value_;
    size_t hash = 0;
    // Avoid using union, which would be undefined behavior.
    // Use memcpy, similar to bit_cast, but sizes may be different.
    // This should be optimized into a single move instruction by
    // any reasonable compiler.
    std::memcpy(&hash, &value, std::min(sizeof(hash), sizeof(value)));
    return hash;
  }

  constexpr const T &Value() const { return value_; }

 protected:
  void SetValue(const T &f) { value_ = f; }

  // Suffix appended to semiring type names to encode the width of T:
  // "" for 4-byte T, "8"/"16"/"64" for 1/2/8-byte T, otherwise "unknown".
  static constexpr std::string_view GetPrecisionString() {
    return sizeof(T) == 4   ? ""
           : sizeof(T) == 1 ? "8"
           : sizeof(T) == 2 ? "16"
           : sizeof(T) == 8 ? "64"
                            : "unknown";
  }

 private:
  T value_;
};

// Single-precision float weight.
using FloatWeight = FloatWeightTpl<float>;

// Exact equality of the underlying values (so a NaN weight is unequal to
// every weight, including itself).
template <class T>
constexpr bool operator==(const FloatWeightTpl<T> &w1,
                          const FloatWeightTpl<T> &w2) {
#if (defined(__i386__) || defined(__x86_64__)) && !defined(__SSE2_MATH__)
  // With i387 instructions, excess precision on a weight in an 80-bit
  // register may cause it to compare unequal to that same weight when
  // stored to memory. This breaks =='s reflexivity, in turn breaking
  // NaturalLess.
#error "Please compile with -msse -mfpmath=sse, or equivalent."
#endif
  return w1.Value() == w2.Value();
}

// These seemingly unnecessary overloads are actually needed to make
// comparisons like FloatWeightTpl<float> == float compile. If only the
// templated version exists, the FloatWeightTpl<float>(float) conversion
// won't be found.
constexpr bool operator==(const FloatWeightTpl &w1, const FloatWeightTpl &w2) { return operator==(w1, w2); } constexpr bool operator==(const FloatWeightTpl &w1, const FloatWeightTpl &w2) { return operator==(w1, w2); } template constexpr bool operator!=(const FloatWeightTpl &w1, const FloatWeightTpl &w2) { return !(w1 == w2); } constexpr bool operator!=(const FloatWeightTpl &w1, const FloatWeightTpl &w2) { return operator!=(w1, w2); } constexpr bool operator!=(const FloatWeightTpl &w1, const FloatWeightTpl &w2) { return operator!=(w1, w2); } template constexpr bool FloatApproxEqual(T w1, T w2, float delta = kDelta) { return w1 <= w2 + delta && w2 <= w1 + delta; } template constexpr bool ApproxEqual(const FloatWeightTpl &w1, const FloatWeightTpl &w2, float delta = kDelta) { return FloatApproxEqual(w1.Value(), w2.Value(), delta); } template inline std::ostream &operator<<(std::ostream &strm, const FloatWeightTpl &w) { if (w.Value() == FloatLimits::PosInfinity()) { return strm << "Infinity"; } else if (w.Value() == FloatLimits::NegInfinity()) { return strm << "-Infinity"; } else if (internal::IsNan(w.Value())) { return strm << "BadNumber"; } else { return strm << w.Value(); } } template inline std::istream &operator>>(std::istream &strm, FloatWeightTpl &w) { std::string s; strm >> s; if (s == "Infinity") { w = FloatWeightTpl(FloatLimits::PosInfinity()); } else if (s == "-Infinity") { w = FloatWeightTpl(FloatLimits::NegInfinity()); } else { char *p; T f = strtod(s.c_str(), &p); if (p < s.c_str() + s.size()) { strm.clear(std::ios::badbit); } else { w = FloatWeightTpl(f); } } return strm; } // Tropical semiring: (min, +, inf, 0). 
template class TropicalWeightTpl : public FloatWeightTpl { public: using typename FloatWeightTpl::ValueType; using FloatWeightTpl::Value; using ReverseWeight = TropicalWeightTpl; using Limits = FloatLimits; TropicalWeightTpl() noexcept : FloatWeightTpl() {} constexpr TropicalWeightTpl(T f) : FloatWeightTpl(f) {} static constexpr TropicalWeightTpl Zero() { return Limits::PosInfinity(); } static constexpr TropicalWeightTpl One() { return 0; } static constexpr TropicalWeightTpl NoWeight() { return Limits::NumberBad(); } static const std::string &Type() { static const std::string *const type = new std::string( fst::StrCat("tropical", FloatWeightTpl::GetPrecisionString())); return *type; } constexpr bool Member() const { // All floating point values except for NaNs and negative infinity are valid // tropical weights. // // Testing membership of a given value can be done by simply checking that // it is strictly greater than negative infinity, which fails for negative // infinity itself but also for NaNs. This can usually be accomplished in a // single instruction (such as *UCOMI* on x86) without branching logic. // // An additional wrinkle involves constexpr correctness of floating point // comparisons against NaN. GCC is uneven when it comes to which expressions // it considers compile-time constants. In particular, current versions of // GCC do not always consider (nan < inf) to be a constant expression, but // do consider (inf < nan) to be a constant expression. (See // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=88173 and // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=88683 for details.) In order // to allow Member() to be a constexpr function accepted by GCC, we write // the comparison here as (-inf < v). 
return Limits::NegInfinity() < Value(); } TropicalWeightTpl Quantize(float delta = kDelta) const { if (!Member() || Value() == Limits::PosInfinity()) { return *this; } else { return TropicalWeightTpl(std::floor(Value() / delta + 0.5F) * delta); } } constexpr TropicalWeightTpl Reverse() const { return *this; } static constexpr uint64_t Properties() { return kLeftSemiring | kRightSemiring | kCommutative | kPath | kIdempotent; } }; // Single precision tropical weight. using TropicalWeight = TropicalWeightTpl; template constexpr TropicalWeightTpl Plus(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2) { return (!w1.Member() || !w2.Member()) ? TropicalWeightTpl::NoWeight() : w1.Value() < w2.Value() ? w1 : w2; } // See comment at operator==(FloatWeightTpl, FloatWeightTpl) // for why these overloads are present. constexpr TropicalWeightTpl Plus(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2) { return Plus(w1, w2); } constexpr TropicalWeightTpl Plus(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2) { return Plus(w1, w2); } template constexpr TropicalWeightTpl Times(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2) { // The following is safe in the context of the Tropical (and Log) semiring // for all IEEE floating point values, including infinities and NaNs, // because: // // * If one or both of the floating point Values is NaN and hence not a // Member, the result of addition below is NaN, so the result is not a // Member. This supersedes all other cases, so we only consider non-NaN // values next. // // * If both Values are finite, there is no issue. // // * If one of the Values is infinite, or if both are infinities with the // same sign, the result of floating point addition is the same infinity, // so there is no issue. // // * If both of the Values are infinities with opposite signs, the result of // adding IEEE floating point -inf + inf is NaN and hence not a Member. 
But // since -inf was not a Member to begin with, returning a non-Member result // is fine as well. return TropicalWeightTpl(w1.Value() + w2.Value()); } constexpr TropicalWeightTpl Times(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2) { return Times(w1, w2); } constexpr TropicalWeightTpl Times(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2) { return Times(w1, w2); } template constexpr TropicalWeightTpl Divide(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2, DivideType typ = DIVIDE_ANY) { // The following is safe in the context of the Tropical (and Log) semiring // for all IEEE floating point values, including infinities and NaNs, // because: // // * If one or both of the floating point Values is NaN and hence not a // Member, the result of subtraction below is NaN, so the result is not a // Member. This supersedes all other cases, so we only consider non-NaN // values next. // // * If both Values are finite, there is no issue. // // * If w2.Value() is -inf (and hence w2 is not a Member), the result of ?: // below is NoWeight, which is not a Member. // // Whereas in IEEE floating point semantics 0/inf == 0, this does not carry // over to this semiring (since TropicalWeight(-inf) would be the analogue // of floating point inf) and instead Divide(Zero(), TropicalWeight(-inf)) // is NoWeight(). // // * If w2.Value() is inf (and hence w2 is Zero), the resulting floating // point value is either NaN (if w1 is Zero or if w1.Value() is NaN) and // hence not a Member, or it is -inf and hence not a Member; either way, // division by Zero results in a non-Member result. using Weight = TropicalWeightTpl; return w2.Member() ? 
Weight(w1.Value() - w2.Value()) : Weight::NoWeight(); } constexpr TropicalWeightTpl Divide(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2, DivideType typ = DIVIDE_ANY) { return Divide(w1, w2, typ); } constexpr TropicalWeightTpl Divide(const TropicalWeightTpl &w1, const TropicalWeightTpl &w2, DivideType typ = DIVIDE_ANY) { return Divide(w1, w2, typ); } // Power(w, n) calculates the n-th power of w with respect to semiring Times. // // In the case of the Tropical (and Log) semiring, the exponent n is not // restricted to be an integer. It can be a floating point value, for example. // // In weight.h, a narrower and hence more broadly applicable version of // Power(w, n) is defined for arbitrary weight types and non-negative integer // exponents n (of type size_t) and implemented in terms of repeated // multiplication using Times. // // Without further provisions this means that, when an expression such as // // Power(TropicalWeightTpl::One(), static_cast(2)) // // is specified, the overload of Power() is ambiguous. The template function // below could be instantiated as // // Power(const TropicalWeightTpl &, size_t) // // and the template function defined in weight.h (further specialized below) // could be instantiated as // // Power>(const TropicalWeightTpl &, size_t) // // That would lead to two definitions with identical signatures, which results // in a compilation error. To avoid that, we hide the definition of Power // when V is size_t, so only Power is visible. Power is further // specialized to Power>, and the overloaded definition // of Power is made conditionally available only to that template // specialization. template , typename std::enable_if_t * = nullptr> constexpr TropicalWeightTpl Power(const TropicalWeightTpl &w, V n) { using Weight = TropicalWeightTpl; return (!w.Member() || internal::IsNan(n)) ? Weight::NoWeight() : (n == 0 || w == Weight::One()) ? 
Weight::One() : Weight(w.Value() * n); } // Specializes the library-wide template to use the above implementation; rules // of function template instantiation require this be a full instantiation. template <> constexpr TropicalWeightTpl Power>( const TropicalWeightTpl &weight, size_t n) { return Power(weight, n); } template <> constexpr TropicalWeightTpl Power>( const TropicalWeightTpl &weight, size_t n) { return Power(weight, n); } // Log semiring: (log(e^-x + e^-y), +, inf, 0). template class LogWeightTpl : public FloatWeightTpl { public: using typename FloatWeightTpl::ValueType; using FloatWeightTpl::Value; using ReverseWeight = LogWeightTpl; using Limits = FloatLimits; LogWeightTpl() noexcept : FloatWeightTpl() {} constexpr LogWeightTpl(T f) : FloatWeightTpl(f) {} static constexpr LogWeightTpl Zero() { return Limits::PosInfinity(); } static constexpr LogWeightTpl One() { return 0; } static constexpr LogWeightTpl NoWeight() { return Limits::NumberBad(); } static const std::string &Type() { static const std::string *const type = new std::string( fst::StrCat("log", FloatWeightTpl::GetPrecisionString())); return *type; } constexpr bool Member() const { // The comments for TropicalWeightTpl<>::Member() apply here unchanged. return Limits::NegInfinity() < Value(); } LogWeightTpl Quantize(float delta = kDelta) const { if (!Member() || Value() == Limits::PosInfinity()) { return *this; } else { return LogWeightTpl(std::floor(Value() / delta + 0.5F) * delta); } } constexpr LogWeightTpl Reverse() const { return *this; } static constexpr uint64_t Properties() { return kLeftSemiring | kRightSemiring | kCommutative; } }; // Single-precision log weight. using LogWeight = LogWeightTpl; // Double-precision log weight. using Log64Weight = LogWeightTpl; namespace internal { // -log(e^-x + e^-y) = x - LogPosExp(y - x), assuming y >= x. inline double LogPosExp(double x) { DCHECK(!(x < 0)); // NB: NaN values are allowed. 
return log1p(exp(-x)); } // -log(e^-x - e^-y) = x - LogNegExp(y - x), assuming y >= x. inline double LogNegExp(double x) { DCHECK(!(x < 0)); // NB: NaN values are allowed. return log1p(-exp(-x)); } // a +_log b = -log(e^-a + e^-b) = KahanLogSum(a, b, ...). // Kahan compensated summation provides an error bound that is // independent of the number of addends. Assumes b >= a; // c is the compensation. inline double KahanLogSum(double a, double b, double *c) { DCHECK_GE(b, a); double y = -LogPosExp(b - a) - *c; double t = a + y; *c = (t - a) - y; return t; } // a -_log b = -log(e^-a - e^-b) = KahanLogDiff(a, b, ...). // Kahan compensated summation provides an error bound that is // independent of the number of addends. Assumes b > a; // c is the compensation. inline double KahanLogDiff(double a, double b, double *c) { DCHECK_GT(b, a); double y = -LogNegExp(b - a) - *c; double t = a + y; *c = (t - a) - y; return t; } } // namespace internal template inline LogWeightTpl Plus(const LogWeightTpl &w1, const LogWeightTpl &w2) { using Limits = FloatLimits; const T f1 = w1.Value(); const T f2 = w2.Value(); if (f1 == Limits::PosInfinity()) { return w2; } else if (f2 == Limits::PosInfinity()) { return w1; } else if (f1 > f2) { return LogWeightTpl(f2 - internal::LogPosExp(f1 - f2)); } else { return LogWeightTpl(f1 - internal::LogPosExp(f2 - f1)); } } inline LogWeightTpl Plus(const LogWeightTpl &w1, const LogWeightTpl &w2) { return Plus(w1, w2); } inline LogWeightTpl Plus(const LogWeightTpl &w1, const LogWeightTpl &w2) { return Plus(w1, w2); } // Returns NoWeight if w1 < w2 (w1.Value() > w2.Value()). 
template inline LogWeightTpl Minus(const LogWeightTpl &w1, const LogWeightTpl &w2) { using Limits = FloatLimits; const T f1 = w1.Value(); const T f2 = w2.Value(); if (f1 > f2) return LogWeightTpl::NoWeight(); if (f2 == Limits::PosInfinity()) return f1; const T d = f2 - f1; if (d == Limits::PosInfinity()) return f1; return f1 - internal::LogNegExp(d); } inline LogWeightTpl Minus(const LogWeightTpl &w1, const LogWeightTpl &w2) { return Minus(w1, w2); } inline LogWeightTpl Minus(const LogWeightTpl &w1, const LogWeightTpl &w2) { return Minus(w1, w2); } template constexpr LogWeightTpl Times(const LogWeightTpl &w1, const LogWeightTpl &w2) { // The comments for Times(Tropical...) above apply here unchanged. return LogWeightTpl(w1.Value() + w2.Value()); } constexpr LogWeightTpl Times(const LogWeightTpl &w1, const LogWeightTpl &w2) { return Times(w1, w2); } constexpr LogWeightTpl Times(const LogWeightTpl &w1, const LogWeightTpl &w2) { return Times(w1, w2); } template constexpr LogWeightTpl Divide(const LogWeightTpl &w1, const LogWeightTpl &w2, DivideType typ = DIVIDE_ANY) { // The comments for Divide(Tropical...) above apply here unchanged. using Weight = LogWeightTpl; return w2.Member() ? Weight(w1.Value() - w2.Value()) : Weight::NoWeight(); } constexpr LogWeightTpl Divide(const LogWeightTpl &w1, const LogWeightTpl &w2, DivideType typ = DIVIDE_ANY) { return Divide(w1, w2, typ); } constexpr LogWeightTpl Divide(const LogWeightTpl &w1, const LogWeightTpl &w2, DivideType typ = DIVIDE_ANY) { return Divide(w1, w2, typ); } // The comments for Power<>(Tropical...) above apply here unchanged. template , typename std::enable_if_t * = nullptr> constexpr LogWeightTpl Power(const LogWeightTpl &w, V n) { using Weight = LogWeightTpl; return (!w.Member() || internal::IsNan(n)) ? Weight::NoWeight() : (n == 0 || w == Weight::One()) ? 
Weight::One() : Weight(w.Value() * n); } // Specializes the library-wide template to use the above implementation; rules // of function template instantiation require this be a full instantiation. template <> constexpr LogWeightTpl Power>( const LogWeightTpl &weight, size_t n) { return Power(weight, n); } template <> constexpr LogWeightTpl Power>( const LogWeightTpl &weight, size_t n) { return Power(weight, n); } // Specialization using the Kahan compensated summation. template class Adder> { public: using Weight = LogWeightTpl; explicit Adder(Weight w = Weight::Zero()) : sum_(w.Value()), c_(0.0) {} Weight Add(const Weight &w) { using Limits = FloatLimits; const T f = w.Value(); if (f == Limits::PosInfinity()) { return Sum(); } else if (sum_ == Limits::PosInfinity()) { sum_ = f; c_ = 0.0; } else if (f > sum_) { sum_ = internal::KahanLogSum(sum_, f, &c_); } else { sum_ = internal::KahanLogSum(f, sum_, &c_); } return Sum(); } Weight Sum() const { return Weight(sum_); } void Reset(Weight w = Weight::Zero()) { sum_ = w.Value(); c_ = 0.0; } private: double sum_; double c_; // Kahan compensation. }; // Real semiring: (+, *, 0, 1). template class RealWeightTpl : public FloatWeightTpl { public: using typename FloatWeightTpl::ValueType; using FloatWeightTpl::Value; using ReverseWeight = RealWeightTpl; using Limits = FloatLimits; RealWeightTpl() noexcept : FloatWeightTpl() {} constexpr RealWeightTpl(T f) : FloatWeightTpl(f) {} static constexpr RealWeightTpl Zero() { return 0; } static constexpr RealWeightTpl One() { return 1; } static constexpr RealWeightTpl NoWeight() { return Limits::NumberBad(); } static const std::string &Type() { static const std::string *const type = new std::string( fst::StrCat("real", FloatWeightTpl::GetPrecisionString())); return *type; } constexpr bool Member() const { // The comments for TropicalWeightTpl<>::Member() apply here unchanged. 
return Limits::NegInfinity() < Value(); } RealWeightTpl Quantize(float delta = kDelta) const { if (!Member() || Value() == Limits::PosInfinity()) { return *this; } else { return RealWeightTpl(std::floor(Value() / delta + 0.5F) * delta); } } constexpr RealWeightTpl Reverse() const { return *this; } static constexpr uint64_t Properties() { return kLeftSemiring | kRightSemiring | kCommutative; } }; // Single-precision log weight. using RealWeight = RealWeightTpl; // Double-precision log weight. using Real64Weight = RealWeightTpl; namespace internal { // a + b = KahanRealSum(a, b, ...). // Kahan compensated summation provides an error bound that is // independent of the number of addends. c is the compensation. inline double KahanRealSum(double a, double b, double *c) { double y = b - *c; double t = a + y; *c = (t - a) - y; return t; } }; // namespace internal // The comments for Times(Tropical...) above apply here unchanged. template inline RealWeightTpl Plus(const RealWeightTpl &w1, const RealWeightTpl &w2) { const T f1 = w1.Value(); const T f2 = w2.Value(); return RealWeightTpl(f1 + f2); } inline RealWeightTpl Plus(const RealWeightTpl &w1, const RealWeightTpl &w2) { return Plus(w1, w2); } inline RealWeightTpl Plus(const RealWeightTpl &w1, const RealWeightTpl &w2) { return Plus(w1, w2); } template inline RealWeightTpl Minus(const RealWeightTpl &w1, const RealWeightTpl &w2) { // The comments for Divide(Tropical...) above apply here unchanged. const T f1 = w1.Value(); const T f2 = w2.Value(); return RealWeightTpl(f1 - f2); } inline RealWeightTpl Minus(const RealWeightTpl &w1, const RealWeightTpl &w2) { return Minus(w1, w2); } inline RealWeightTpl Minus(const RealWeightTpl &w1, const RealWeightTpl &w2) { return Minus(w1, w2); } // The comments for Times(Tropical...) above apply here similarly. 
// Real Times: floating-point multiplication of the values.
template <class T>
constexpr RealWeightTpl<T> Times(const RealWeightTpl<T> &w1,
                                 const RealWeightTpl<T> &w2) {
  return RealWeightTpl<T>(w1.Value() * w2.Value());
}

constexpr RealWeightTpl<float> Times(const RealWeightTpl<float> &w1,
                                     const RealWeightTpl<float> &w2) {
  return Times<float>(w1, w2);
}

constexpr RealWeightTpl<double> Times(const RealWeightTpl<double> &w1,
                                      const RealWeightTpl<double> &w2) {
  return Times<double>(w1, w2);
}

// Real Divide: floating-point division of the values, or NoWeight() if w2 is
// not a Member.
//
// NOTE: Zero() (0) is a Member, so dividing by it follows IEEE rules: the
// result is NaN for 0 / 0 and +/-inf otherwise (only -inf and NaN are
// non-Members), rather than NoWeight().
template <class T>
constexpr RealWeightTpl<T> Divide(const RealWeightTpl<T> &w1,
                                  const RealWeightTpl<T> &w2,
                                  DivideType typ = DIVIDE_ANY) {
  using Weight = RealWeightTpl<T>;
  return w2.Member() ? Weight(w1.Value() / w2.Value()) : Weight::NoWeight();
}

constexpr RealWeightTpl<float> Divide(const RealWeightTpl<float> &w1,
                                      const RealWeightTpl<float> &w2,
                                      DivideType typ = DIVIDE_ANY) {
  return Divide<float>(w1, w2, typ);
}

constexpr RealWeightTpl<double> Divide(const RealWeightTpl<double> &w1,
                                       const RealWeightTpl<double> &w2,
                                       DivideType typ = DIVIDE_ANY) {
  return Divide<double>(w1, w2, typ);
}

// The comments for Power<>(Tropical...) above apply here unchanged.
// Here the n-th power with respect to Times is pow(w.Value(), n).
template <class T, class V, bool Enable = !std::is_same_v<V, size_t>,
          typename std::enable_if_t<Enable> * = nullptr>
constexpr RealWeightTpl<T> Power(const RealWeightTpl<T> &w, V n) {
  using Weight = RealWeightTpl<T>;
  return (!w.Member() || internal::IsNan(n)) ? Weight::NoWeight()
         : (n == 0 || w == Weight::One())    ? Weight::One()
                                             : Weight(pow(w.Value(), n));
}

// Specializes the library-wide template to use the above implementation; rules
// of function template instantiation require this be a full instantiation.
template <>
constexpr RealWeightTpl<float> Power<RealWeightTpl<float>>(
    const RealWeightTpl<float> &weight, size_t n) {
  return Power<float, size_t, true>(weight, n);
}

template <>
constexpr RealWeightTpl<double> Power<RealWeightTpl<double>>(
    const RealWeightTpl<double> &weight, size_t n) {
  return Power<double, size_t, true>(weight, n);
}

// Specialization using the Kahan compensated summation.
template class Adder> { public: using Weight = RealWeightTpl; explicit Adder(Weight w = Weight::Zero()) : sum_(w.Value()), c_(0.0) {} Weight Add(const Weight &w) { using Limits = FloatLimits; const T f = w.Value(); if (f == Limits::PosInfinity()) { sum_ = f; } else if (sum_ == Limits::PosInfinity()) { return sum_; } else { sum_ = internal::KahanRealSum(sum_, f, &c_); } return Sum(); } Weight Sum() const { return Weight(sum_); } void Reset(Weight w = Weight::Zero()) { sum_ = w.Value(); c_ = 0.0; } private: double sum_; double c_; // Kahan compensation. }; // MinMax semiring: (min, max, inf, -inf). template class MinMaxWeightTpl : public FloatWeightTpl { public: using typename FloatWeightTpl::ValueType; using FloatWeightTpl::Value; using ReverseWeight = MinMaxWeightTpl; using Limits = FloatLimits; MinMaxWeightTpl() noexcept : FloatWeightTpl() {} constexpr MinMaxWeightTpl(T f) : FloatWeightTpl(f) {} // NOLINT static constexpr MinMaxWeightTpl Zero() { return Limits::PosInfinity(); } static constexpr MinMaxWeightTpl One() { return Limits::NegInfinity(); } static constexpr MinMaxWeightTpl NoWeight() { return Limits::NumberBad(); } static const std::string &Type() { static const std::string *const type = new std::string( fst::StrCat("minmax", FloatWeightTpl::GetPrecisionString())); return *type; } // Fails for IEEE NaN. constexpr bool Member() const { return !internal::IsNan(Value()); } MinMaxWeightTpl Quantize(float delta = kDelta) const { // If one of infinities, or a NaN. if (!Member() || Value() == Limits::NegInfinity() || Value() == Limits::PosInfinity()) { return *this; } else { return MinMaxWeightTpl(std::floor(Value() / delta + 0.5F) * delta); } } constexpr MinMaxWeightTpl Reverse() const { return *this; } static constexpr uint64_t Properties() { return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath; } }; // Single-precision min-max weight. using MinMaxWeight = MinMaxWeightTpl; // Min. 
template constexpr MinMaxWeightTpl Plus(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { return (!w1.Member() || !w2.Member()) ? MinMaxWeightTpl::NoWeight() : w1.Value() < w2.Value() ? w1 : w2; } constexpr MinMaxWeightTpl Plus(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { return Plus(w1, w2); } constexpr MinMaxWeightTpl Plus(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { return Plus(w1, w2); } // Max. template constexpr MinMaxWeightTpl Times(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { return (!w1.Member() || !w2.Member()) ? MinMaxWeightTpl::NoWeight() : w1.Value() >= w2.Value() ? w1 : w2; } constexpr MinMaxWeightTpl Times(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { return Times(w1, w2); } constexpr MinMaxWeightTpl Times(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2) { return Times(w1, w2); } // Defined only for special cases. template constexpr MinMaxWeightTpl Divide(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2, DivideType typ = DIVIDE_ANY) { return w1.Value() >= w2.Value() ? w1 : MinMaxWeightTpl::NoWeight(); } constexpr MinMaxWeightTpl Divide(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2, DivideType typ = DIVIDE_ANY) { return Divide(w1, w2, typ); } constexpr MinMaxWeightTpl Divide(const MinMaxWeightTpl &w1, const MinMaxWeightTpl &w2, DivideType typ = DIVIDE_ANY) { return Divide(w1, w2, typ); } // Converts to tropical. template <> struct WeightConvert { constexpr TropicalWeight operator()(const LogWeight &w) const { return w.Value(); } }; template <> struct WeightConvert { constexpr TropicalWeight operator()(const Log64Weight &w) const { return w.Value(); } }; // Converts to log. 
template <> struct WeightConvert { constexpr LogWeight operator()(const TropicalWeight &w) const { return w.Value(); } }; template <> struct WeightConvert { LogWeight operator()(const RealWeight &w) const { return -log(w.Value()); } }; template <> struct WeightConvert { LogWeight operator()(const Real64Weight &w) const { return -log(w.Value()); } }; template <> struct WeightConvert { constexpr LogWeight operator()(const Log64Weight &w) const { return w.Value(); } }; // Converts to log64. template <> struct WeightConvert { constexpr Log64Weight operator()(const TropicalWeight &w) const { return w.Value(); } }; template <> struct WeightConvert { Log64Weight operator()(const RealWeight &w) const { return -log(w.Value()); } }; template <> struct WeightConvert { Log64Weight operator()(const Real64Weight &w) const { return -log(w.Value()); } }; template <> struct WeightConvert { constexpr Log64Weight operator()(const LogWeight &w) const { return w.Value(); } }; // Converts to real. template <> struct WeightConvert { RealWeight operator()(const LogWeight &w) const { return exp(-w.Value()); } }; template <> struct WeightConvert { RealWeight operator()(const Log64Weight &w) const { return exp(-w.Value()); } }; template <> struct WeightConvert { constexpr RealWeight operator()(const Real64Weight &w) const { return w.Value(); } }; // Converts to real64 template <> struct WeightConvert { Real64Weight operator()(const LogWeight &w) const { return exp(-w.Value()); } }; template <> struct WeightConvert { Real64Weight operator()(const Log64Weight &w) const { return exp(-w.Value()); } }; template <> struct WeightConvert { constexpr Real64Weight operator()(const RealWeight &w) const { return w.Value(); } }; // This function object returns random integers chosen from [0, // num_random_weights). The allow_zero argument determines whether Zero() and // zero divisors should be returned in the random weight generation. This is // intended primary for testing. 
template class FloatWeightGenerate { public: explicit FloatWeightGenerate( uint64_t seed = std::random_device()(), bool allow_zero = true, const size_t num_random_weights = kNumRandomWeights) : rand_(seed), allow_zero_(allow_zero), num_random_weights_(num_random_weights) {} Weight operator()() const { const int sample = std::uniform_int_distribution<>( 0, num_random_weights_ + allow_zero_ - 1)(rand_); if (allow_zero_ && sample == num_random_weights_) return Weight::Zero(); return Weight(sample); } private: mutable std::mt19937_64 rand_; const bool allow_zero_; const size_t num_random_weights_; }; template class WeightGenerate> : public FloatWeightGenerate> { public: using Weight = TropicalWeightTpl; using Generate = FloatWeightGenerate; explicit WeightGenerate(uint64_t seed = std::random_device()(), bool allow_zero = true, size_t num_random_weights = kNumRandomWeights) : Generate(seed, allow_zero, num_random_weights) {} Weight operator()() const { return Weight(Generate::operator()()); } }; template class WeightGenerate> : public FloatWeightGenerate> { public: using Weight = LogWeightTpl; using Generate = FloatWeightGenerate; explicit WeightGenerate(uint64_t seed = std::random_device()(), bool allow_zero = true, size_t num_random_weights = kNumRandomWeights) : Generate(seed, allow_zero, num_random_weights) {} Weight operator()() const { return Weight(Generate::operator()()); } }; template class WeightGenerate> : public FloatWeightGenerate> { public: using Weight = RealWeightTpl; using Generate = FloatWeightGenerate; explicit WeightGenerate(uint64_t seed = std::random_device()(), bool allow_zero = true, size_t num_random_weights = kNumRandomWeights) : Generate(seed, allow_zero, num_random_weights) {} Weight operator()() const { return Weight(Generate::operator()()); } }; // This function object returns random integers chosen from [0, // num_random_weights). 
The boolean 'allow_zero' determines whether Zero() and // zero divisors should be returned in the random weight generation. This is // intended primary for testing. template class WeightGenerate> { public: using Weight = MinMaxWeightTpl; explicit WeightGenerate(uint64_t seed = std::random_device()(), bool allow_zero = true, size_t num_random_weights = kNumRandomWeights) : rand_(seed), allow_zero_(allow_zero), num_random_weights_(num_random_weights) {} Weight operator()() const { const int sample = std::uniform_int_distribution<>( -num_random_weights_, num_random_weights_ + allow_zero_)(rand_); if (allow_zero_ && sample == 0) { return Weight::Zero(); } else if (sample == -num_random_weights_) { return Weight::One(); } else { return Weight(sample); } } private: mutable std::mt19937_64 rand_; const bool allow_zero_; const size_t num_random_weights_; }; } // namespace fst #endif // FST_FLOAT_WEIGHT_H_