|
|
// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Float weight set and associated semiring operation definitions.
#ifndef FST_FLOAT_WEIGHT_H_
#define FST_FLOAT_WEIGHT_H_
#include <algorithm>
#include <climits>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <ios>
#include <istream>
#include <limits>
#include <ostream>
#include <random>
#include <sstream>
#include <string>
#include <type_traits>
#include <fst/log.h>
#include <fst/util.h>
#include <fst/weight.h>
#include <fst/compat.h>
#include <string_view>
namespace fst {
namespace internal {

// Returns true exactly when `value` does not compare equal to itself, which
// for IEEE floating-point types means `value` is a NaN.
//
// TODO(wolfsonkin): Replace with `std::isnan` if and when that ends up
// constexpr. For context, see
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0533r6.pdf.
template <class T>
inline constexpr bool IsNan(T value) {
  return value != value;
}

}  // namespace internal
// Numeric limits class: the IEEE special values of floating-point type T.
template <class T>
class FloatLimits {
 public:
  // Positive infinity.
  static constexpr T PosInfinity() {
    return std::numeric_limits<T>::infinity();
  }

  // Negative infinity.
  static constexpr T NegInfinity() {
    return -std::numeric_limits<T>::infinity();
  }

  // A quiet NaN; the weight classes use it as their NoWeight() value.
  static constexpr T NumberBad() {
    return std::numeric_limits<T>::quiet_NaN();
  }
};
// Weight class to be templated on floating-point types. It holds a single
// value of type T; the semiring-specific weight classes derive from it.
template <class T = float>
class FloatWeightTpl {
 public:
  using ValueType = T;

  // Leaves the stored value uninitialized.
  FloatWeightTpl() noexcept = default;

  constexpr FloatWeightTpl(T f) : value_(f) {}  // NOLINT

  // Reads / writes the stored value via ReadType / WriteType.
  std::istream &Read(std::istream &strm) { return ReadType(strm, &value_); }

  std::ostream &Write(std::ostream &strm) const {
    return WriteType(strm, value_);
  }

  // Hashes the bit pattern of the stored value. The bytes are copied with
  // memcpy (like bit_cast) instead of being punned through a union, which
  // would be undefined behavior. Because sizeof(T) and sizeof(size_t) may
  // differ, only the smaller of the two sizes is copied into a
  // zero-initialized result. Compilers normally turn this into a single move.
  size_t Hash() const {
    constexpr size_t kNumBytes = std::min(sizeof(size_t), sizeof(T));
    size_t hash = 0;
    std::memcpy(&hash, &value_, kNumBytes);
    return hash;
  }

  constexpr const T &Value() const { return value_; }

 protected:
  void SetValue(const T &f) { value_ = f; }

  // Precision suffix for semiring type names. It is empty for 4-byte types,
  // the bit width for 1-, 2- and 8-byte types, and "unknown" otherwise.
  static constexpr std::string_view GetPrecisionString() {
    switch (sizeof(T)) {
      case 1:
        return "8";
      case 2:
        return "16";
      case 4:
        return "";
      case 8:
        return "64";
      default:
        return "unknown";
    }
  }

 private:
  T value_;
};

// Single-precision float weight.
using FloatWeight = FloatWeightTpl<float>;
// Two float weights are equal iff their stored values compare equal with
// the built-in floating-point ==. A NaN weight is therefore unequal even to
// itself.
template <class T>
constexpr bool operator==(const FloatWeightTpl<T> &w1,
                          const FloatWeightTpl<T> &w2) {
  // Preprocessor directives must start their own line. On the same line as
  // the opening brace, `#if` is not a directive and the `#endif` is unmatched.
#if (defined(__i386__) || defined(__x86_64__)) && !defined(__SSE2_MATH__)
  // With i387 instructions, excess precision on a weight in an 80-bit
  // register may cause it to compare unequal to that same weight when
  // stored to memory. This breaks =='s reflexivity, in turn breaking
  // NaturalLess.
#error "Please compile with -msse -mfpmath=sse, or equivalent."
#endif
  return w1.Value() == w2.Value();
}
// Non-template equality overloads for float and double.
//
// Template argument deduction ignores implicit conversions. With only the
// templated operator==, a comparison such as FloatWeightTpl<float> == float
// would fail to compile, because the FloatWeightTpl<float>(float) conversion
// is never considered. These plain functions do allow that conversion.
constexpr bool operator==(const FloatWeightTpl<float> &w1,
                          const FloatWeightTpl<float> &w2) {
  return operator==<float>(w1, w2);
}

constexpr bool operator==(const FloatWeightTpl<double> &w1,
                          const FloatWeightTpl<double> &w2) {
  return operator==<double>(w1, w2);
}

// Inequality is the negation of equality.
template <class T>
constexpr bool operator!=(const FloatWeightTpl<T> &w1,
                          const FloatWeightTpl<T> &w2) {
  return !(w1 == w2);
}

// Non-template inequality overloads, present for the same reason as the
// equality overloads above.
constexpr bool operator!=(const FloatWeightTpl<float> &w1,
                          const FloatWeightTpl<float> &w2) {
  return operator!=<float>(w1, w2);
}

constexpr bool operator!=(const FloatWeightTpl<double> &w1,
                          const FloatWeightTpl<double> &w2) {
  return operator!=<double>(w1, w2);
}
// True when neither value exceeds the other by more than `delta`, i.e.
// w1 <= w2 + delta and w2 <= w1 + delta. Any NaN operand yields false,
// since every comparison with NaN is false.
template <class T>
constexpr bool FloatApproxEqual(T w1, T w2, float delta = kDelta) {
  const bool w1_within_delta = w1 <= w2 + delta;
  const bool w2_within_delta = w2 <= w1 + delta;
  return w1_within_delta && w2_within_delta;
}

// Applies FloatApproxEqual to the stored values of two float weights.
template <class T>
constexpr bool ApproxEqual(const FloatWeightTpl<T> &w1,
                           const FloatWeightTpl<T> &w2, float delta = kDelta) {
  return FloatApproxEqual(w1.Value(), w2.Value(), delta);
}
// Writes a weight as text. Special values are spelled out: +inf is written as
// "Infinity", -inf as "-Infinity", and NaN as "BadNumber". Any other value is
// streamed as a plain T.
template <class T>
inline std::ostream &operator<<(std::ostream &strm,
                                const FloatWeightTpl<T> &w) {
  const T value = w.Value();
  if (value == FloatLimits<T>::PosInfinity()) return strm << "Infinity";
  if (value == FloatLimits<T>::NegInfinity()) return strm << "-Infinity";
  if (internal::IsNan(value)) return strm << "BadNumber";
  return strm << value;
}
template <class T> inline std::istream &operator>>(std::istream &strm, FloatWeightTpl<T> &w) { std::string s; strm >> s; if (s == "Infinity") { w = FloatWeightTpl<T>(FloatLimits<T>::PosInfinity()); } else if (s == "-Infinity") { w = FloatWeightTpl<T>(FloatLimits<T>::NegInfinity()); } else { char *p; T f = strtod(s.c_str(), &p); if (p < s.c_str() + s.size()) { strm.clear(std::ios::badbit); } else { w = FloatWeightTpl<T>(f); } } return strm; }
// Tropical semiring: (min, +, inf, 0).
template <class T>
class TropicalWeightTpl : public FloatWeightTpl<T> {
 public:
  using typename FloatWeightTpl<T>::ValueType;
  using FloatWeightTpl<T>::Value;
  using ReverseWeight = TropicalWeightTpl<T>;
  using Limits = FloatLimits<T>;

  TropicalWeightTpl() noexcept : FloatWeightTpl<T>() {}

  constexpr TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {}

  // Semiring zero: +inf.
  static constexpr TropicalWeightTpl<T> Zero() {
    return Limits::PosInfinity();
  }

  // Semiring one: 0.
  static constexpr TropicalWeightTpl<T> One() { return 0; }

  // Marker for an invalid weight: a quiet NaN.
  static constexpr TropicalWeightTpl<T> NoWeight() {
    return Limits::NumberBad();
  }

  // "tropical" followed by the precision suffix (e.g. "tropical64" for
  // double).
  static const std::string &Type() {
    static const std::string *const type = new std::string(
        fst::StrCat("tropical", FloatWeightTpl<T>::GetPrecisionString()));
    return *type;
  }

  // Valid tropical weights are all floating-point values except NaN and -inf.
  //
  // The single comparison (-inf < value) rejects both, because it is false
  // for -inf itself and false for every NaN. It can usually be compiled to
  // one branch-free instruction (such as UCOMI* on x86).
  //
  // The operand order matters for constexpr evaluation. Current GCC versions
  // do not always accept (nan < inf) as a constant expression, but do accept
  // (inf < nan); see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=88173 and
  // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=88683. Writing the test as
  // (-inf < v) keeps Member() usable as a constexpr function under GCC.
  constexpr bool Member() const { return Limits::NegInfinity() < Value(); }

  // Rounds the value to a multiple of `delta` via floor(v / delta + 0.5).
  // Non-members and Zero() (+inf) are returned unchanged.
  TropicalWeightTpl<T> Quantize(float delta = kDelta) const {
    if (!Member() || Value() == Limits::PosInfinity()) return *this;
    return TropicalWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
  }

  // Returns the weight unchanged.
  constexpr TropicalWeightTpl<T> Reverse() const { return *this; }

  static constexpr uint64_t Properties() {
    return kLeftSemiring | kRightSemiring | kCommutative | kPath | kIdempotent;
  }
};

// Single precision tropical weight.
using TropicalWeight = TropicalWeightTpl<float>;
// Tropical Plus is min over the stored values. If either operand is not a
// member, the result is NoWeight().
template <class T>
constexpr TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1,
                                    const TropicalWeightTpl<T> &w2) {
  if (w1.Member() && w2.Member()) {
    return w1.Value() < w2.Value() ? w1 : w2;
  }
  return TropicalWeightTpl<T>::NoWeight();
}

// Non-template overloads for float and double. Template argument deduction
// ignores implicit conversions, so without these, calls whose arguments need
// a conversion (e.g. from a raw float) would not compile.
constexpr TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1,
                                        const TropicalWeightTpl<float> &w2) {
  return Plus<float>(w1, w2);
}

constexpr TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1,
                                         const TropicalWeightTpl<double> &w2) {
  return Plus<double>(w1, w2);
}
// Tropical Times is plain floating-point addition of the stored values.
//
// No explicit membership test is needed. IEEE arithmetic already gives an
// acceptable result for every input, and the same argument holds for the Log
// semiring:
//
//  * A NaN operand (a non-member) makes the sum NaN, so the result is also a
//    non-member. This case takes precedence over the ones below.
//  * Two finite operands give a finite sum.
//  * One infinite operand, or two infinities of the same sign, give that same
//    infinity.
//  * -inf + inf is NaN, a non-member. That is acceptable because the -inf
//    operand was already a non-member.
template <class T>
constexpr TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1,
                                     const TropicalWeightTpl<T> &w2) {
  return TropicalWeightTpl<T>(w1.Value() + w2.Value());
}

// Non-template overloads so that implicit conversions to the weight type are
// considered for float and double.
constexpr TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1,
                                         const TropicalWeightTpl<float> &w2) {
  return Times<float>(w1, w2);
}

constexpr TropicalWeightTpl<double> Times(
    const TropicalWeightTpl<double> &w1, const TropicalWeightTpl<double> &w2) {
  return Times<double>(w1, w2);
}
// Tropical Divide subtracts the stored values (w1 - w2). It returns
// NoWeight() whenever w2 is not a member. The DivideType argument is unused.
//
// IEEE special values need no further handling, and the same holds for the
// Log semiring:
//
//  * A NaN in w1 makes the difference NaN (a non-member). A NaN in w2 is
//    caught by the membership test and yields NoWeight(). These cases take
//    precedence over the ones below.
//  * Two finite operands pose no problem.
//  * w2 == -inf is not a member, so the result is NoWeight(). In IEEE
//    arithmetic 0 / inf == 0, but that identity does not carry over:
//    TropicalWeight(-inf) is the analogue of a floating-point infinity, yet
//    Divide(Zero(), TropicalWeight(-inf)) is NoWeight().
//  * w2 == +inf (w2 is Zero()) yields NaN when w1 is Zero() or NaN, and -inf
//    otherwise. Both are non-members, so dividing by Zero() never produces a
//    member.
template <class T>
constexpr TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1,
                                      const TropicalWeightTpl<T> &w2,
                                      DivideType typ = DIVIDE_ANY) {
  using Weight = TropicalWeightTpl<T>;
  if (!w2.Member()) return Weight::NoWeight();
  return Weight(w1.Value() - w2.Value());
}

// Non-template overloads so that implicit conversions to the weight type are
// considered for float and double.
constexpr TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1,
                                          const TropicalWeightTpl<float> &w2,
                                          DivideType typ = DIVIDE_ANY) {
  return Divide<float>(w1, w2, typ);
}

constexpr TropicalWeightTpl<double> Divide(
    const TropicalWeightTpl<double> &w1, const TropicalWeightTpl<double> &w2,
    DivideType typ = DIVIDE_ANY) {
  return Divide<double>(w1, w2, typ);
}
// Power(w, n) calculates the n-th power of w with respect to semiring Times.
//
// In the case of the Tropical (and Log) semiring, the exponent n is not
// restricted to be an integer. It can be a floating point value, for example.
//
// In weight.h, a narrower and hence more broadly applicable version of
// Power(w, n) is defined for arbitrary weight types and non-negative integer
// exponents n (of type size_t) and implemented in terms of repeated
// multiplication using Times.
//
// Without further provisions this means that, when an expression such as
//
// Power(TropicalWeightTpl<float>::One(), static_cast<size_t>(2))
//
// is specified, the overload of Power() is ambiguous. The template function
// below could be instantiated as
//
// Power<float, size_t>(const TropicalWeightTpl<float> &, size_t)
//
// and the template function defined in weight.h (further specialized below)
// could be instantiated as
//
// Power<TropicalWeightTpl<float>>(const TropicalWeightTpl<float> &, size_t)
//
// That would lead to two definitions with identical signatures, which results
// in a compilation error. To avoid that, we hide the definition of Power<T, V>
// when V is size_t, so only Power<W> is visible. Power<W> is further
// specialized to Power<TropicalWeightTpl<...>>, and the overloaded definition
// of Power<T, V> is made conditionally available only to that template
// specialization.
template <class T, class V, bool Enable = !std::is_same_v<V, size_t>,
          typename std::enable_if_t<Enable> * = nullptr>
constexpr TropicalWeightTpl<T> Power(const TropicalWeightTpl<T> &w, V n) {
  using Weight = TropicalWeightTpl<T>;
  // Non-members and NaN exponents have no meaningful power.
  if (!w.Member() || internal::IsNan(n)) return Weight::NoWeight();
  // A zero exponent, or a base of One(), short-circuits to One().
  if (n == 0 || w == Weight::One()) return Weight::One();
  // Repeated Times is repeated addition, i.e. scaling by n.
  return Weight(w.Value() * n);
}

// Specializes the library-wide template to use the above implementation; rules
// of function template instantiation require this be a full instantiation.
template <>
constexpr TropicalWeightTpl<float> Power<TropicalWeightTpl<float>>(
    const TropicalWeightTpl<float> &weight, size_t n) {
  return Power<float, size_t, true>(weight, n);
}

template <>
constexpr TropicalWeightTpl<double> Power<TropicalWeightTpl<double>>(
    const TropicalWeightTpl<double> &weight, size_t n) {
  return Power<double, size_t, true>(weight, n);
}
// Log semiring: (log(e^-x + e^-y), +, inf, 0).
template <class T>
class LogWeightTpl : public FloatWeightTpl<T> {
 public:
  using typename FloatWeightTpl<T>::ValueType;
  using FloatWeightTpl<T>::Value;
  using ReverseWeight = LogWeightTpl;
  using Limits = FloatLimits<T>;

  LogWeightTpl() noexcept : FloatWeightTpl<T>() {}

  constexpr LogWeightTpl(T f) : FloatWeightTpl<T>(f) {}

  // Semiring zero: +inf.
  static constexpr LogWeightTpl Zero() { return Limits::PosInfinity(); }

  // Semiring one: 0.
  static constexpr LogWeightTpl One() { return 0; }

  // Marker for an invalid weight: a quiet NaN.
  static constexpr LogWeightTpl NoWeight() { return Limits::NumberBad(); }

  // "log" followed by the precision suffix (e.g. "log64" for double).
  static const std::string &Type() {
    static const std::string *const type = new std::string(
        fst::StrCat("log", FloatWeightTpl<T>::GetPrecisionString()));
    return *type;
  }

  // Members are all values except NaN and -inf. The single comparison
  // (-inf < value) is false for both. This operand order is the one GCC
  // accepts in constant expressions (see GCC bugs 88173 and 88683).
  constexpr bool Member() const { return Limits::NegInfinity() < Value(); }

  // Rounds the value to a multiple of `delta` via floor(v / delta + 0.5).
  // Non-members and Zero() (+inf) are returned unchanged.
  LogWeightTpl<T> Quantize(float delta = kDelta) const {
    if (!Member() || Value() == Limits::PosInfinity()) return *this;
    return LogWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
  }

  // Returns the weight unchanged.
  constexpr LogWeightTpl<T> Reverse() const { return *this; }

  static constexpr uint64_t Properties() {
    return kLeftSemiring | kRightSemiring | kCommutative;
  }
};

// Single-precision log weight.
using LogWeight = LogWeightTpl<float>;

// Double-precision log weight.
using Log64Weight = LogWeightTpl<double>;
namespace internal {

// Returns log(1 + e^-x). Used for log-semiring addition:
// -log(e^-x + e^-y) = x - LogPosExp(y - x), assuming y >= x.
inline double LogPosExp(double x) {
  // Written as !(x < 0) rather than x >= 0 so that NaN passes the check.
  DCHECK(!(x < 0));
  return log1p(exp(-x));
}

// Returns log(1 - e^-x). Used for log-semiring subtraction:
// -log(e^-x - e^-y) = x - LogNegExp(y - x), assuming y >= x.
inline double LogNegExp(double x) {
  // Written as !(x < 0) rather than x >= 0 so that NaN passes the check.
  DCHECK(!(x < 0));
  return log1p(-exp(-x));
}

// Computes a +_log b = -log(e^-a + e^-b) with Kahan compensated summation,
// whose error bound does not depend on the number of addends. Requires
// b >= a. `c` holds the running compensation and is updated in place.
inline double KahanLogSum(double a, double b, double *c) {
  DCHECK_GE(b, a);
  const double increment = -LogPosExp(b - a) - *c;
  const double total = a + increment;
  *c = (total - a) - increment;
  return total;
}

// Computes a -_log b = -log(e^-a - e^-b) with Kahan compensated summation,
// whose error bound does not depend on the number of addends. Requires b > a.
// `c` holds the running compensation and is updated in place.
inline double KahanLogDiff(double a, double b, double *c) {
  DCHECK_GT(b, a);
  const double increment = -LogNegExp(b - a) - *c;
  const double total = a + increment;
  *c = (total - a) - increment;
  return total;
}

}  // namespace internal
// Log-semiring Plus: -log(e^-f1 + e^-f2), computed in a numerically stable
// way by factoring out the smaller stored value.
template <class T>
inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1,
                            const LogWeightTpl<T> &w2) {
  using Limits = FloatLimits<T>;
  const T f1 = w1.Value();
  const T f2 = w2.Value();
  // Zero() (+inf) is the identity for Plus.
  if (f1 == Limits::PosInfinity()) return w2;
  if (f2 == Limits::PosInfinity()) return w1;
  // Order the operands so that LogPosExp sees hi - lo >= 0. If either value
  // is NaN, the comparison is false and (lo, hi) stays (f1, f2).
  const bool first_is_larger = f1 > f2;
  const T lo = first_is_larger ? f2 : f1;
  const T hi = first_is_larger ? f1 : f2;
  return LogWeightTpl<T>(lo - internal::LogPosExp(hi - lo));
}

// Non-template overloads so that implicit conversions to the weight type are
// considered for float and double.
inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1,
                                const LogWeightTpl<float> &w2) {
  return Plus<float>(w1, w2);
}

inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1,
                                 const LogWeightTpl<double> &w2) {
  return Plus<double>(w1, w2);
}
// Log-semiring Minus: -log(e^-f1 - e^-f2). Returns NoWeight if w1 < w2
// (w1.Value() > w2.Value()).
template <class T>
inline LogWeightTpl<T> Minus(const LogWeightTpl<T> &w1,
                             const LogWeightTpl<T> &w2) {
  using Limits = FloatLimits<T>;
  const T f1 = w1.Value();
  const T f2 = w2.Value();
  if (f1 > f2) return LogWeightTpl<T>::NoWeight();
  const T gap = f2 - f1;
  // Subtracting Zero() (+inf), or anything infinitely far above f1, leaves
  // w1 unchanged.
  if (f2 == Limits::PosInfinity() || gap == Limits::PosInfinity()) return f1;
  return f1 - internal::LogNegExp(gap);
}

// Non-template overloads so that implicit conversions to the weight type are
// considered for float and double.
inline LogWeightTpl<float> Minus(const LogWeightTpl<float> &w1,
                                 const LogWeightTpl<float> &w2) {
  return Minus<float>(w1, w2);
}

inline LogWeightTpl<double> Minus(const LogWeightTpl<double> &w1,
                                  const LogWeightTpl<double> &w2) {
  return Minus<double>(w1, w2);
}
// Log-semiring Times is floating-point addition of the stored values. As in
// the tropical semiring, IEEE arithmetic already gives an acceptable result:
// a NaN operand, or -inf + inf, yields NaN (a non-member), and otherwise the
// sum is finite or a single infinity.
template <class T>
constexpr LogWeightTpl<T> Times(const LogWeightTpl<T> &w1,
                                const LogWeightTpl<T> &w2) {
  return LogWeightTpl<T>(w1.Value() + w2.Value());
}

// Non-template overloads so that implicit conversions to the weight type are
// considered for float and double.
constexpr LogWeightTpl<float> Times(const LogWeightTpl<float> &w1,
                                    const LogWeightTpl<float> &w2) {
  return Times<float>(w1, w2);
}

constexpr LogWeightTpl<double> Times(const LogWeightTpl<double> &w1,
                                     const LogWeightTpl<double> &w2) {
  return Times<double>(w1, w2);
}
// Log-semiring Divide subtracts the stored values (w1 - w2) and returns
// NoWeight() when w2 is not a member (NaN or -inf). Division by Zero()
// (+inf) yields NaN or -inf, which are both non-members. The DivideType
// argument is unused.
template <class T>
constexpr LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1,
                                 const LogWeightTpl<T> &w2,
                                 DivideType typ = DIVIDE_ANY) {
  using Weight = LogWeightTpl<T>;
  if (!w2.Member()) return Weight::NoWeight();
  return Weight(w1.Value() - w2.Value());
}

// Non-template overloads so that implicit conversions to the weight type are
// considered for float and double.
constexpr LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1,
                                     const LogWeightTpl<float> &w2,
                                     DivideType typ = DIVIDE_ANY) {
  return Divide<float>(w1, w2, typ);
}

constexpr LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1,
                                      const LogWeightTpl<double> &w2,
                                      DivideType typ = DIVIDE_ANY) {
  return Divide<double>(w1, w2, typ);
}
// Power(w, n) for log weights scales the stored value by n, which need not
// be an integer. The size_t exponent case is disabled through `Enable`, so it
// does not clash with the library-wide Power<W>. That case is re-exposed via
// the explicit specializations below.
template <class T, class V, bool Enable = !std::is_same_v<V, size_t>,
          typename std::enable_if_t<Enable> * = nullptr>
constexpr LogWeightTpl<T> Power(const LogWeightTpl<T> &w, V n) {
  using Weight = LogWeightTpl<T>;
  if (!w.Member() || internal::IsNan(n)) return Weight::NoWeight();
  if (n == 0 || w == Weight::One()) return Weight::One();
  return Weight(w.Value() * n);
}

// Specializes the library-wide template to use the above implementation; rules
// of function template instantiation require this be a full instantiation.
template <>
constexpr LogWeightTpl<float> Power<LogWeightTpl<float>>(
    const LogWeightTpl<float> &weight, size_t n) {
  return Power<float, size_t, true>(weight, n);
}

template <>
constexpr LogWeightTpl<double> Power<LogWeightTpl<double>>(
    const LogWeightTpl<double> &weight, size_t n) {
  return Power<double, size_t, true>(weight, n);
}
// Adder specialization for log weights. It accumulates a running log-semiring
// sum in double precision, using Kahan compensated summation
// (internal::KahanLogSum).
template <class T>
class Adder<LogWeightTpl<T>> {
 public:
  using Weight = LogWeightTpl<T>;

  explicit Adder(Weight w = Weight::Zero()) : sum_(w.Value()), c_(0.0) {}

  // Adds `w` to the running sum and returns the new sum.
  Weight Add(const Weight &w) {
    using Limits = FloatLimits<T>;
    const T f = w.Value();
    // Adding Zero() (+inf) leaves the sum unchanged.
    if (f != Limits::PosInfinity()) {
      if (sum_ == Limits::PosInfinity()) {
        // The running sum is Zero(): adopt f and discard any compensation.
        sum_ = f;
        c_ = 0.0;
      } else if (f > sum_) {
        // KahanLogSum requires its second argument to be the larger one.
        sum_ = internal::KahanLogSum(sum_, f, &c_);
      } else {
        sum_ = internal::KahanLogSum(f, sum_, &c_);
      }
    }
    return Sum();
  }

  Weight Sum() const { return Weight(sum_); }

  void Reset(Weight w = Weight::Zero()) {
    sum_ = w.Value();
    c_ = 0.0;
  }

 private:
  double sum_;
  double c_;  // Kahan compensation.
};
// Real semiring: (+, *, 0, 1).
template <class T>
class RealWeightTpl : public FloatWeightTpl<T> {
 public:
  using typename FloatWeightTpl<T>::ValueType;
  using FloatWeightTpl<T>::Value;
  using ReverseWeight = RealWeightTpl;
  using Limits = FloatLimits<T>;

  RealWeightTpl() noexcept : FloatWeightTpl<T>() {}

  constexpr RealWeightTpl(T f) : FloatWeightTpl<T>(f) {}

  // Semiring zero: 0.
  static constexpr RealWeightTpl Zero() { return 0; }

  // Semiring one: 1.
  static constexpr RealWeightTpl One() { return 1; }

  // Marker for an invalid weight: a quiet NaN.
  static constexpr RealWeightTpl NoWeight() { return Limits::NumberBad(); }

  // "real" followed by the precision suffix (e.g. "real64" for double).
  static const std::string &Type() {
    static const std::string *const type = new std::string(
        fst::StrCat("real", FloatWeightTpl<T>::GetPrecisionString()));
    return *type;
  }

  // Members are all values except NaN and -inf. The single comparison
  // (-inf < value) is false for both. This operand order is the one GCC
  // accepts in constant expressions (see GCC bugs 88173 and 88683).
  constexpr bool Member() const { return Limits::NegInfinity() < Value(); }

  // Rounds the value to a multiple of `delta` via floor(v / delta + 0.5).
  // Non-members and +inf are returned unchanged.
  RealWeightTpl<T> Quantize(float delta = kDelta) const {
    if (!Member() || Value() == Limits::PosInfinity()) return *this;
    return RealWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
  }

  // Returns the weight unchanged.
  constexpr RealWeightTpl<T> Reverse() const { return *this; }

  static constexpr uint64_t Properties() {
    return kLeftSemiring | kRightSemiring | kCommutative;
  }
};

// Single-precision real weight.
using RealWeight = RealWeightTpl<float>;

// Double-precision real weight.
using Real64Weight = RealWeightTpl<double>;
namespace internal {

// Returns a + b using Kahan compensated summation, whose error bound does not
// depend on the number of addends. `c` holds the running compensation: it is
// subtracted from `b` before adding, then updated in place with the rounding
// error of this addition.
inline double KahanRealSum(double a, double b, double *c) {
  const double corrected = b - *c;
  const double total = a + corrected;
  // (total - a) is what was actually added; its difference from `corrected`
  // is the low-order part lost to rounding.
  *c = (total - a) - corrected;
  return total;
}

}  // namespace internal
// Real-semiring Plus is ordinary floating-point addition of the stored
// values, with no membership checks. IEEE arithmetic propagates NaN, and
// -inf + inf also gives NaN, a non-member.
template <class T>
inline RealWeightTpl<T> Plus(const RealWeightTpl<T> &w1,
                             const RealWeightTpl<T> &w2) {
  return RealWeightTpl<T>(w1.Value() + w2.Value());
}

// Non-template overloads so that implicit conversions to the weight type are
// considered for float and double.
inline RealWeightTpl<float> Plus(const RealWeightTpl<float> &w1,
                                 const RealWeightTpl<float> &w2) {
  return Plus<float>(w1, w2);
}

inline RealWeightTpl<double> Plus(const RealWeightTpl<double> &w1,
                                  const RealWeightTpl<double> &w2) {
  return Plus<double>(w1, w2);
}
// Real-semiring Minus is ordinary floating-point subtraction of the stored
// values, with no membership checks. A NaN operand propagates to the result.
template <class T>
inline RealWeightTpl<T> Minus(const RealWeightTpl<T> &w1,
                              const RealWeightTpl<T> &w2) {
  return RealWeightTpl<T>(w1.Value() - w2.Value());
}

// Non-template overloads so that implicit conversions to the weight type are
// considered for float and double.
inline RealWeightTpl<float> Minus(const RealWeightTpl<float> &w1,
                                  const RealWeightTpl<float> &w2) {
  return Minus<float>(w1, w2);
}

inline RealWeightTpl<double> Minus(const RealWeightTpl<double> &w1,
                                   const RealWeightTpl<double> &w2) {
  return Minus<double>(w1, w2);
}
// Real-semiring Times is floating-point multiplication of the stored values.
// No membership checks are done; NaN operands propagate to the result.
template <class T>
constexpr RealWeightTpl<T> Times(const RealWeightTpl<T> &w1,
                                 const RealWeightTpl<T> &w2) {
  return RealWeightTpl<T>(w1.Value() * w2.Value());
}

// Non-template overloads so that implicit conversions to the weight type are
// considered for float and double.
constexpr RealWeightTpl<float> Times(const RealWeightTpl<float> &w1,
                                     const RealWeightTpl<float> &w2) {
  return Times<float>(w1, w2);
}

constexpr RealWeightTpl<double> Times(const RealWeightTpl<double> &w1,
                                      const RealWeightTpl<double> &w2) {
  return Times<double>(w1, w2);
}
// Real-semiring Divide is floating-point division w1 / w2. It returns
// NoWeight() when w2 is not a member (NaN or -inf). Division by 0, which is a
// member, follows IEEE semantics. The DivideType argument is unused.
template <class T>
constexpr RealWeightTpl<T> Divide(const RealWeightTpl<T> &w1,
                                  const RealWeightTpl<T> &w2,
                                  DivideType typ = DIVIDE_ANY) {
  using Weight = RealWeightTpl<T>;
  if (!w2.Member()) return Weight::NoWeight();
  return Weight(w1.Value() / w2.Value());
}

// Non-template overloads so that implicit conversions to the weight type are
// considered for float and double.
constexpr RealWeightTpl<float> Divide(const RealWeightTpl<float> &w1,
                                      const RealWeightTpl<float> &w2,
                                      DivideType typ = DIVIDE_ANY) {
  return Divide<float>(w1, w2, typ);
}

constexpr RealWeightTpl<double> Divide(const RealWeightTpl<double> &w1,
                                       const RealWeightTpl<double> &w2,
                                       DivideType typ = DIVIDE_ANY) {
  return Divide<double>(w1, w2, typ);
}
// Power(w, n) for real weights raises the stored value to the power n, which
// need not be an integer. Non-members and NaN exponents give NoWeight(). A
// zero exponent, or a base of One(), gives One().
//
// The size_t exponent case is disabled through `Enable`, so it does not clash
// with the library-wide Power<W>. That case is re-exposed via the explicit
// specializations below.
template <class T, class V, bool Enable = !std::is_same_v<V, size_t>,
          typename std::enable_if_t<Enable> * = nullptr>
constexpr RealWeightTpl<T> Power(const RealWeightTpl<T> &w, V n) {
  using Weight = RealWeightTpl<T>;
  if (!w.Member() || internal::IsNan(n)) return Weight::NoWeight();
  if (n == 0 || w == Weight::One()) return Weight::One();
  return Weight(pow(w.Value(), n));
}

// Specializes the library-wide template to use the above implementation; rules
// of function template instantiation require this be a full instantiation.
template <>
constexpr RealWeightTpl<float> Power<RealWeightTpl<float>>(
    const RealWeightTpl<float> &weight, size_t n) {
  return Power<float, size_t, true>(weight, n);
}

template <>
constexpr RealWeightTpl<double> Power<RealWeightTpl<double>>(
    const RealWeightTpl<double> &weight, size_t n) {
  return Power<double, size_t, true>(weight, n);
}
// Adder specialization for real weights. It accumulates the running sum in
// double precision, using Kahan compensated summation
// (internal::KahanRealSum).
template <class T>
class Adder<RealWeightTpl<T>> {
 public:
  using Weight = RealWeightTpl<T>;

  explicit Adder(Weight w = Weight::Zero()) : sum_(w.Value()), c_(0.0) {}

  // Adds `w` to the running sum and returns the new sum.
  Weight Add(const Weight &w) {
    using Limits = FloatLimits<T>;
    const T f = w.Value();
    if (f == Limits::PosInfinity()) {
      // A +inf addend makes the sum +inf.
      sum_ = f;
    } else if (sum_ != Limits::PosInfinity()) {
      sum_ = internal::KahanRealSum(sum_, f, &c_);
    }
    // Once the sum is +inf, other addends leave it unchanged.
    return Sum();
  }

  Weight Sum() const { return Weight(sum_); }

  void Reset(Weight w = Weight::Zero()) {
    sum_ = w.Value();
    c_ = 0.0;
  }

 private:
  double sum_;
  double c_;  // Kahan compensation.
};
// MinMax semiring: (min, max, inf, -inf).
template <class T>
class MinMaxWeightTpl : public FloatWeightTpl<T> {
 public:
  using typename FloatWeightTpl<T>::ValueType;
  using FloatWeightTpl<T>::Value;
  using ReverseWeight = MinMaxWeightTpl<T>;
  using Limits = FloatLimits<T>;

  MinMaxWeightTpl() noexcept : FloatWeightTpl<T>() {}

  constexpr MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {}  // NOLINT

  // Semiring zero: +inf.
  static constexpr MinMaxWeightTpl Zero() { return Limits::PosInfinity(); }

  // Semiring one: -inf.
  static constexpr MinMaxWeightTpl One() { return Limits::NegInfinity(); }

  // Marker for an invalid weight: a quiet NaN.
  static constexpr MinMaxWeightTpl NoWeight() { return Limits::NumberBad(); }

  // "minmax" followed by the precision suffix (e.g. "minmax64" for double).
  static const std::string &Type() {
    static const std::string *const type = new std::string(
        fst::StrCat("minmax", FloatWeightTpl<T>::GetPrecisionString()));
    return *type;
  }

  // Every value except an IEEE NaN is a member, including both infinities.
  constexpr bool Member() const { return !internal::IsNan(Value()); }

  // Rounds the value to a multiple of `delta` via floor(v / delta + 0.5).
  // NaN and both infinities are returned unchanged.
  MinMaxWeightTpl<T> Quantize(float delta = kDelta) const {
    const bool is_special = !Member() || Value() == Limits::NegInfinity() ||
                            Value() == Limits::PosInfinity();
    if (is_special) return *this;
    return MinMaxWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
  }

  // Returns the weight unchanged.
  constexpr MinMaxWeightTpl<T> Reverse() const { return *this; }

  static constexpr uint64_t Properties() {
    return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath;
  }
};

// Single-precision min-max weight.
using MinMaxWeight = MinMaxWeightTpl<float>;
// MinMax Plus is min over the stored values. If either operand is NaN (a
// non-member), the result is NoWeight().
template <class T>
constexpr MinMaxWeightTpl<T> Plus(const MinMaxWeightTpl<T> &w1,
                                  const MinMaxWeightTpl<T> &w2) {
  if (w1.Member() && w2.Member()) {
    return w1.Value() < w2.Value() ? w1 : w2;
  }
  return MinMaxWeightTpl<T>::NoWeight();
}

// Non-template overloads so that implicit conversions to the weight type are
// considered for float and double.
constexpr MinMaxWeightTpl<float> Plus(const MinMaxWeightTpl<float> &w1,
                                      const MinMaxWeightTpl<float> &w2) {
  return Plus<float>(w1, w2);
}

constexpr MinMaxWeightTpl<double> Plus(const MinMaxWeightTpl<double> &w1,
                                       const MinMaxWeightTpl<double> &w2) {
  return Plus<double>(w1, w2);
}
// MinMax Times is max over the stored values; ties return w1. If either
// operand is NaN (a non-member), the result is NoWeight().
template <class T>
constexpr MinMaxWeightTpl<T> Times(const MinMaxWeightTpl<T> &w1,
                                   const MinMaxWeightTpl<T> &w2) {
  if (w1.Member() && w2.Member()) {
    return w1.Value() >= w2.Value() ? w1 : w2;
  }
  return MinMaxWeightTpl<T>::NoWeight();
}

// Non-template overloads so that implicit conversions to the weight type are
// considered for float and double.
constexpr MinMaxWeightTpl<float> Times(const MinMaxWeightTpl<float> &w1,
                                       const MinMaxWeightTpl<float> &w2) {
  return Times<float>(w1, w2);
}

constexpr MinMaxWeightTpl<double> Times(const MinMaxWeightTpl<double> &w1,
                                        const MinMaxWeightTpl<double> &w2) {
  return Times<double>(w1, w2);
}
// MinMax Divide is defined only for special cases. When w1 >= w2 it returns
// w1, since max(w1, w2) == w1. Otherwise, including when either value is
// NaN, it returns NoWeight(). The DivideType argument is unused.
template <class T>
constexpr MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1,
                                    const MinMaxWeightTpl<T> &w2,
                                    DivideType typ = DIVIDE_ANY) {
  if (w1.Value() >= w2.Value()) return w1;
  return MinMaxWeightTpl<T>::NoWeight();
}

// Non-template overloads so that implicit conversions to the weight type are
// considered for float and double.
constexpr MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1,
                                        const MinMaxWeightTpl<float> &w2,
                                        DivideType typ = DIVIDE_ANY) {
  return Divide<float>(w1, w2, typ);
}

constexpr MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1,
                                         const MinMaxWeightTpl<double> &w2,
                                         DivideType typ = DIVIDE_ANY) {
  return Divide<double>(w1, w2, typ);
}
// Converts to tropical: the stored value is copied unchanged.
template <>
struct WeightConvert<LogWeight, TropicalWeight> {
  constexpr TropicalWeight operator()(const LogWeight &w) const {
    return w.Value();
  }
};

template <>
struct WeightConvert<Log64Weight, TropicalWeight> {
  constexpr TropicalWeight operator()(const Log64Weight &w) const {
    return w.Value();
  }
};

// Converts to log. Tropical and log64 values are copied unchanged; real
// values r become -log(r).
template <>
struct WeightConvert<TropicalWeight, LogWeight> {
  constexpr LogWeight operator()(const TropicalWeight &w) const {
    return w.Value();
  }
};

template <>
struct WeightConvert<RealWeight, LogWeight> {
  LogWeight operator()(const RealWeight &w) const { return -log(w.Value()); }
};

template <>
struct WeightConvert<Real64Weight, LogWeight> {
  LogWeight operator()(const Real64Weight &w) const {
    return -log(w.Value());
  }
};

template <>
struct WeightConvert<Log64Weight, LogWeight> {
  constexpr LogWeight operator()(const Log64Weight &w) const {
    return w.Value();
  }
};

// Converts to log64. Tropical and log values are copied unchanged; real
// values r become -log(r).
template <>
struct WeightConvert<TropicalWeight, Log64Weight> {
  constexpr Log64Weight operator()(const TropicalWeight &w) const {
    return w.Value();
  }
};

template <>
struct WeightConvert<RealWeight, Log64Weight> {
  Log64Weight operator()(const RealWeight &w) const {
    return -log(w.Value());
  }
};

template <>
struct WeightConvert<Real64Weight, Log64Weight> {
  Log64Weight operator()(const Real64Weight &w) const {
    return -log(w.Value());
  }
};

template <>
struct WeightConvert<LogWeight, Log64Weight> {
  constexpr Log64Weight operator()(const LogWeight &w) const {
    return w.Value();
  }
};

// Converts to real. Log values v become exp(-v); real64 values are copied.
template <>
struct WeightConvert<LogWeight, RealWeight> {
  RealWeight operator()(const LogWeight &w) const { return exp(-w.Value()); }
};

template <>
struct WeightConvert<Log64Weight, RealWeight> {
  RealWeight operator()(const Log64Weight &w) const {
    return exp(-w.Value());
  }
};

template <>
struct WeightConvert<Real64Weight, RealWeight> {
  constexpr RealWeight operator()(const Real64Weight &w) const {
    return w.Value();
  }
};

// Converts to real64. Log values v become exp(-v); real values are copied.
template <>
struct WeightConvert<LogWeight, Real64Weight> {
  Real64Weight operator()(const LogWeight &w) const {
    return exp(-w.Value());
  }
};

template <>
struct WeightConvert<Log64Weight, Real64Weight> {
  Real64Weight operator()(const Log64Weight &w) const {
    return exp(-w.Value());
  }
};

template <>
struct WeightConvert<RealWeight, Real64Weight> {
  constexpr Real64Weight operator()(const RealWeight &w) const {
    return w.Value();
  }
};
// This function object returns weights whose values are random integers
// chosen from [0, num_random_weights). The allow_zero argument determines
// whether Zero() and zero divisors should also be returned in the random
// weight generation. This is intended primarily for testing.
template <class Weight>
class FloatWeightGenerate {
 public:
  explicit FloatWeightGenerate(
      uint64_t seed = std::random_device()(), bool allow_zero = true,
      const size_t num_random_weights = kNumRandomWeights)
      : rand_(seed),
        allow_zero_(allow_zero),
        num_random_weights_(num_random_weights) {}

  Weight operator()() const {
    // With allow_zero_ the range gains one extra outcome, equal to
    // num_random_weights_, which is reported as Zero().
    std::uniform_int_distribution<> dist(
        0, num_random_weights_ + allow_zero_ - 1);
    const int sample = dist(rand_);
    const bool is_zero_outcome = allow_zero_ && sample == num_random_weights_;
    return is_zero_outcome ? Weight::Zero() : Weight(sample);
  }

 private:
  // Mutable so that the const operator() can advance the engine.
  mutable std::mt19937_64 rand_;
  const bool allow_zero_;
  const size_t num_random_weights_;
};
// Random weight generators for the tropical, log and real semirings. Each one
// delegates to FloatWeightGenerate and forwards its constructor arguments.
template <class T>
class WeightGenerate<TropicalWeightTpl<T>>
    : public FloatWeightGenerate<TropicalWeightTpl<T>> {
 public:
  using Weight = TropicalWeightTpl<T>;
  using Generate = FloatWeightGenerate<Weight>;

  explicit WeightGenerate(uint64_t seed = std::random_device()(),
                          bool allow_zero = true,
                          size_t num_random_weights = kNumRandomWeights)
      : Generate(seed, allow_zero, num_random_weights) {}

  Weight operator()() const { return Weight(Generate::operator()()); }
};

template <class T>
class WeightGenerate<LogWeightTpl<T>>
    : public FloatWeightGenerate<LogWeightTpl<T>> {
 public:
  using Weight = LogWeightTpl<T>;
  using Generate = FloatWeightGenerate<Weight>;

  explicit WeightGenerate(uint64_t seed = std::random_device()(),
                          bool allow_zero = true,
                          size_t num_random_weights = kNumRandomWeights)
      : Generate(seed, allow_zero, num_random_weights) {}

  Weight operator()() const { return Weight(Generate::operator()()); }
};

template <class T>
class WeightGenerate<RealWeightTpl<T>>
    : public FloatWeightGenerate<RealWeightTpl<T>> {
 public:
  using Weight = RealWeightTpl<T>;
  using Generate = FloatWeightGenerate<Weight>;

  explicit WeightGenerate(uint64_t seed = std::random_device()(),
                          bool allow_zero = true,
                          size_t num_random_weights = kNumRandomWeights)
      : Generate(seed, allow_zero, num_random_weights) {}

  Weight operator()() const { return Weight(Generate::operator()()); }
};
// This function object draws an integer uniformly from
// [-num_random_weights, num_random_weights]; the upper bound is extended by
// one when 'allow_zero' is true. The lowest value maps to One(). When
// 'allow_zero' is true, 0 maps to Zero(). Every other value becomes the
// weight's value. This is intended primarily for testing.
template <class T>
class WeightGenerate<MinMaxWeightTpl<T>> {
 public:
  using Weight = MinMaxWeightTpl<T>;

  explicit WeightGenerate(uint64_t seed = std::random_device()(),
                          bool allow_zero = true,
                          size_t num_random_weights = kNumRandomWeights)
      : rand_(seed),
        allow_zero_(allow_zero),
        num_random_weights_(num_random_weights) {}

  // Samples from [-n, n] (or [-n, n + 1] when allow_zero_), where
  // n = num_random_weights_. -n maps to One(); 0 maps to Zero() when
  // allow_zero_; any other sample becomes the weight's value.
  Weight operator()() const {
    // Do the range arithmetic in signed int. Negating the unsigned size_t
    // directly wraps around, and only produces -n after an
    // implementation-defined narrowing back to int.
    const int n = static_cast<int>(num_random_weights_);
    const int sample =
        std::uniform_int_distribution<>(-n, n + (allow_zero_ ? 1 : 0))(rand_);
    if (allow_zero_ && sample == 0) {
      return Weight::Zero();
    } else if (sample == -n) {
      return Weight::One();
    } else {
      return Weight(sample);
    }
  }

 private:
  // Mutable so that the const operator() can advance the engine.
  mutable std::mt19937_64 rand_;
  const bool allow_zero_;
  const size_t num_random_weights_;
};
} // namespace fst
#endif // FST_FLOAT_WEIGHT_H_
|