You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1183 lines
38 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Float weight set and associated semiring operation definitions.
  19. #ifndef FST_FLOAT_WEIGHT_H_
  20. #define FST_FLOAT_WEIGHT_H_
  21. #include <algorithm>
  22. #include <climits>
  23. #include <cmath>
  24. #include <cstddef>
  25. #include <cstdint>
  26. #include <cstdlib>
  27. #include <cstring>
  28. #include <ios>
  29. #include <istream>
  30. #include <limits>
  31. #include <ostream>
  32. #include <random>
  33. #include <sstream>
  34. #include <string>
  35. #include <type_traits>
  36. #include <fst/log.h>
  37. #include <fst/util.h>
  38. #include <fst/weight.h>
  39. #include <fst/compat.h>
  40. #include <string_view>
  41. namespace fst {
  namespace internal {

  // Constexpr-usable NaN test.
  //
  // TODO(wolfsonkin): Replace with `std::isnan` if and when that ends up
  // constexpr. For context, see
  // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0533r6.pdf.
  template <class T>
  inline constexpr bool IsNan(T value) {
    // A NaN is the only value that does not compare equal to itself.
    return !(value == value);
  }

  }  // namespace internal
  // Numeric limits class: names the special floating-point values of T that
  // the weight classes use.
  template <class T>
  class FloatLimits {
   public:
    // Positive infinity of T.
    static constexpr T PosInfinity() { return Native::infinity(); }

    // Negative infinity of T.
    static constexpr T NegInfinity() { return -Native::infinity(); }

    // A quiet NaN of T.
    static constexpr T NumberBad() { return Native::quiet_NaN(); }

   private:
    using Native = std::numeric_limits<T>;
  };
  // Weight class to be templated on floating-point types.
  //
  // Wraps a single value of type T and provides binary (de)serialization,
  // hashing, and read access to the value.
  template <class T = float>
  class FloatWeightTpl {
   public:
    using ValueType = T;

    // Defaulted; the constructor itself does not assign value_.
    FloatWeightTpl() noexcept = default;

    // Deliberately implicit so that a raw T converts to a weight.
    constexpr FloatWeightTpl(T f) : value_(f) {}  // NOLINT

    // Reads the value from `strm` via ReadType() and returns the stream.
    std::istream &Read(std::istream &strm) { return ReadType(strm, &value_); }

    // Writes the value to `strm` via WriteType() and returns the stream.
    std::ostream &Write(std::ostream &strm) const {
      return WriteType(strm, value_);
    }

    // Returns a hash derived from the bit pattern of the value.
    //
    // +0.0 and -0.0 compare equal but have different bit patterns, so zero is
    // canonicalized to +0.0 first; otherwise two weights that are equal under
    // operator== could hash differently.
    size_t Hash() const {
      const T canonical = (value_ == T(0)) ? T(0) : value_;
      size_t hash = 0;
      // Avoid using union, which would be undefined behavior.
      // Use memcpy, similar to bit_cast, but sizes may be different.
      // This should be optimized into a single move instruction by
      // any reasonable compiler.
      std::memcpy(&hash, &canonical,
                  std::min(sizeof(hash), sizeof(canonical)));
      return hash;
    }

    // Returns the wrapped floating-point value.
    constexpr const T &Value() const { return value_; }

   protected:
    void SetValue(const T &f) { value_ = f; }

    // Suffix keyed on sizeof(T): empty for 4 bytes, otherwise the bit width
    // ("8", "16", "64"), or "unknown" for any other size.
    static constexpr std::string_view GetPrecisionString() {
      switch (sizeof(T)) {
        case 4:
          return "";
        case 1:
          return "8";
        case 2:
          return "16";
        case 8:
          return "64";
        default:
          return "unknown";
      }
    }

   private:
    T value_;
  };

  // Single-precision float weight.
  using FloatWeight = FloatWeightTpl<float>;
  96. template <class T>
  97. constexpr bool operator==(const FloatWeightTpl<T> &w1,
  98. const FloatWeightTpl<T> &w2) {
  99. #if (defined(__i386__) || defined(__x86_64__)) && !defined(__SSE2_MATH__)
  100. // With i387 instructions, excess precision on a weight in an 80-bit
  101. // register may cause it to compare unequal to that same weight when
  102. // stored to memory. This breaks =='s reflexivity, in turn breaking
  103. // NaturalLess.
  104. #error "Please compile with -msse -mfpmath=sse, or equivalent."
  105. #endif
  106. return w1.Value() == w2.Value();
  107. }
  108. // These seemingly unnecessary overloads are actually needed to make
  109. // comparisons like FloatWeightTpl<float> == float compile. If only the
  110. // templated version exists, the FloatWeightTpl<float>(float) conversion
  111. // won't be found.
  112. constexpr bool operator==(const FloatWeightTpl<float> &w1,
  113. const FloatWeightTpl<float> &w2) {
  114. return operator==<float>(w1, w2);
  115. }
  116. constexpr bool operator==(const FloatWeightTpl<double> &w1,
  117. const FloatWeightTpl<double> &w2) {
  118. return operator==<double>(w1, w2);
  119. }
  120. template <class T>
  121. constexpr bool operator!=(const FloatWeightTpl<T> &w1,
  122. const FloatWeightTpl<T> &w2) {
  123. return !(w1 == w2);
  124. }
  125. constexpr bool operator!=(const FloatWeightTpl<float> &w1,
  126. const FloatWeightTpl<float> &w2) {
  127. return operator!=<float>(w1, w2);
  128. }
  129. constexpr bool operator!=(const FloatWeightTpl<double> &w1,
  130. const FloatWeightTpl<double> &w2) {
  131. return operator!=<double>(w1, w2);
  132. }
  133. template <class T>
  134. constexpr bool FloatApproxEqual(T w1, T w2, float delta = kDelta) {
  135. return w1 <= w2 + delta && w2 <= w1 + delta;
  136. }
  137. template <class T>
  138. constexpr bool ApproxEqual(const FloatWeightTpl<T> &w1,
  139. const FloatWeightTpl<T> &w2, float delta = kDelta) {
  140. return FloatApproxEqual(w1.Value(), w2.Value(), delta);
  141. }
  142. template <class T>
  143. inline std::ostream &operator<<(std::ostream &strm,
  144. const FloatWeightTpl<T> &w) {
  145. if (w.Value() == FloatLimits<T>::PosInfinity()) {
  146. return strm << "Infinity";
  147. } else if (w.Value() == FloatLimits<T>::NegInfinity()) {
  148. return strm << "-Infinity";
  149. } else if (internal::IsNan(w.Value())) {
  150. return strm << "BadNumber";
  151. } else {
  152. return strm << w.Value();
  153. }
  154. }
  155. template <class T>
  156. inline std::istream &operator>>(std::istream &strm, FloatWeightTpl<T> &w) {
  157. std::string s;
  158. strm >> s;
  159. if (s == "Infinity") {
  160. w = FloatWeightTpl<T>(FloatLimits<T>::PosInfinity());
  161. } else if (s == "-Infinity") {
  162. w = FloatWeightTpl<T>(FloatLimits<T>::NegInfinity());
  163. } else {
  164. char *p;
  165. T f = strtod(s.c_str(), &p);
  166. if (p < s.c_str() + s.size()) {
  167. strm.clear(std::ios::badbit);
  168. } else {
  169. w = FloatWeightTpl<T>(f);
  170. }
  171. }
  172. return strm;
  173. }
  174. // Tropical semiring: (min, +, inf, 0).
  175. template <class T>
  176. class TropicalWeightTpl : public FloatWeightTpl<T> {
  177. public:
  178. using typename FloatWeightTpl<T>::ValueType;
  179. using FloatWeightTpl<T>::Value;
  180. using ReverseWeight = TropicalWeightTpl<T>;
  181. using Limits = FloatLimits<T>;
  182. TropicalWeightTpl() noexcept : FloatWeightTpl<T>() {}
  183. constexpr TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {}
  184. static constexpr TropicalWeightTpl<T> Zero() { return Limits::PosInfinity(); }
  185. static constexpr TropicalWeightTpl<T> One() { return 0; }
  186. static constexpr TropicalWeightTpl<T> NoWeight() {
  187. return Limits::NumberBad();
  188. }
  189. static const std::string &Type() {
  190. static const std::string *const type = new std::string(
  191. fst::StrCat("tropical", FloatWeightTpl<T>::GetPrecisionString()));
  192. return *type;
  193. }
  194. constexpr bool Member() const {
  195. // All floating point values except for NaNs and negative infinity are valid
  196. // tropical weights.
  197. //
  198. // Testing membership of a given value can be done by simply checking that
  199. // it is strictly greater than negative infinity, which fails for negative
  200. // infinity itself but also for NaNs. This can usually be accomplished in a
  201. // single instruction (such as *UCOMI* on x86) without branching logic.
  202. //
  203. // An additional wrinkle involves constexpr correctness of floating point
  204. // comparisons against NaN. GCC is uneven when it comes to which expressions
  205. // it considers compile-time constants. In particular, current versions of
  206. // GCC do not always consider (nan < inf) to be a constant expression, but
  207. // do consider (inf < nan) to be a constant expression. (See
  208. // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=88173 and
  209. // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=88683 for details.) In order
  210. // to allow Member() to be a constexpr function accepted by GCC, we write
  211. // the comparison here as (-inf < v).
  212. return Limits::NegInfinity() < Value();
  213. }
  214. TropicalWeightTpl<T> Quantize(float delta = kDelta) const {
  215. if (!Member() || Value() == Limits::PosInfinity()) {
  216. return *this;
  217. } else {
  218. return TropicalWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
  219. }
  220. }
  221. constexpr TropicalWeightTpl<T> Reverse() const { return *this; }
  222. static constexpr uint64_t Properties() {
  223. return kLeftSemiring | kRightSemiring | kCommutative | kPath | kIdempotent;
  224. }
  225. };
  226. // Single precision tropical weight.
  227. using TropicalWeight = TropicalWeightTpl<float>;
  228. template <class T>
  229. constexpr TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1,
  230. const TropicalWeightTpl<T> &w2) {
  231. return (!w1.Member() || !w2.Member()) ? TropicalWeightTpl<T>::NoWeight()
  232. : w1.Value() < w2.Value() ? w1
  233. : w2;
  234. }
  235. // See comment at operator==(FloatWeightTpl<float>, FloatWeightTpl<float>)
  236. // for why these overloads are present.
  237. constexpr TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1,
  238. const TropicalWeightTpl<float> &w2) {
  239. return Plus<float>(w1, w2);
  240. }
  241. constexpr TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1,
  242. const TropicalWeightTpl<double> &w2) {
  243. return Plus<double>(w1, w2);
  244. }
  245. template <class T>
  246. constexpr TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1,
  247. const TropicalWeightTpl<T> &w2) {
  248. // The following is safe in the context of the Tropical (and Log) semiring
  249. // for all IEEE floating point values, including infinities and NaNs,
  250. // because:
  251. //
  252. // * If one or both of the floating point Values is NaN and hence not a
  253. // Member, the result of addition below is NaN, so the result is not a
  254. // Member. This supersedes all other cases, so we only consider non-NaN
  255. // values next.
  256. //
  257. // * If both Values are finite, there is no issue.
  258. //
  259. // * If one of the Values is infinite, or if both are infinities with the
  260. // same sign, the result of floating point addition is the same infinity,
  261. // so there is no issue.
  262. //
  263. // * If both of the Values are infinities with opposite signs, the result of
  264. // adding IEEE floating point -inf + inf is NaN and hence not a Member. But
  265. // since -inf was not a Member to begin with, returning a non-Member result
  266. // is fine as well.
  267. return TropicalWeightTpl<T>(w1.Value() + w2.Value());
  268. }
  269. constexpr TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1,
  270. const TropicalWeightTpl<float> &w2) {
  271. return Times<float>(w1, w2);
  272. }
  273. constexpr TropicalWeightTpl<double> Times(const TropicalWeightTpl<double> &w1,
  274. const TropicalWeightTpl<double> &w2) {
  275. return Times<double>(w1, w2);
  276. }
  277. template <class T>
  278. constexpr TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1,
  279. const TropicalWeightTpl<T> &w2,
  280. DivideType typ = DIVIDE_ANY) {
  281. // The following is safe in the context of the Tropical (and Log) semiring
  282. // for all IEEE floating point values, including infinities and NaNs,
  283. // because:
  284. //
  285. // * If one or both of the floating point Values is NaN and hence not a
  286. // Member, the result of subtraction below is NaN, so the result is not a
  287. // Member. This supersedes all other cases, so we only consider non-NaN
  288. // values next.
  289. //
  290. // * If both Values are finite, there is no issue.
  291. //
  292. // * If w2.Value() is -inf (and hence w2 is not a Member), the result of ?:
  293. // below is NoWeight, which is not a Member.
  294. //
  295. // Whereas in IEEE floating point semantics 0/inf == 0, this does not carry
  296. // over to this semiring (since TropicalWeight(-inf) would be the analogue
  297. // of floating point inf) and instead Divide(Zero(), TropicalWeight(-inf))
  298. // is NoWeight().
  299. //
  300. // * If w2.Value() is inf (and hence w2 is Zero), the resulting floating
  301. // point value is either NaN (if w1 is Zero or if w1.Value() is NaN) and
  302. // hence not a Member, or it is -inf and hence not a Member; either way,
  303. // division by Zero results in a non-Member result.
  304. using Weight = TropicalWeightTpl<T>;
  305. return w2.Member() ? Weight(w1.Value() - w2.Value()) : Weight::NoWeight();
  306. }
  307. constexpr TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1,
  308. const TropicalWeightTpl<float> &w2,
  309. DivideType typ = DIVIDE_ANY) {
  310. return Divide<float>(w1, w2, typ);
  311. }
  312. constexpr TropicalWeightTpl<double> Divide(const TropicalWeightTpl<double> &w1,
  313. const TropicalWeightTpl<double> &w2,
  314. DivideType typ = DIVIDE_ANY) {
  315. return Divide<double>(w1, w2, typ);
  316. }
  317. // Power(w, n) calculates the n-th power of w with respect to semiring Times.
  318. //
  319. // In the case of the Tropical (and Log) semiring, the exponent n is not
  320. // restricted to be an integer. It can be a floating point value, for example.
  321. //
  322. // In weight.h, a narrower and hence more broadly applicable version of
  323. // Power(w, n) is defined for arbitrary weight types and non-negative integer
  324. // exponents n (of type size_t) and implemented in terms of repeated
  325. // multiplication using Times.
  326. //
  327. // Without further provisions this means that, when an expression such as
  328. //
  329. // Power(TropicalWeightTpl<float>::One(), static_cast<size_t>(2))
  330. //
  331. // is specified, the overload of Power() is ambiguous. The template function
  332. // below could be instantiated as
  333. //
  334. // Power<float, size_t>(const TropicalWeightTpl<float> &, size_t)
  335. //
  336. // and the template function defined in weight.h (further specialized below)
  337. // could be instantiated as
  338. //
  339. // Power<TropicalWeightTpl<float>>(const TropicalWeightTpl<float> &, size_t)
  340. //
  341. // That would lead to two definitions with identical signatures, which results
  342. // in a compilation error. To avoid that, we hide the definition of Power<T, V>
  343. // when V is size_t, so only Power<W> is visible. Power<W> is further
  344. // specialized to Power<TropicalWeightTpl<...>>, and the overloaded definition
  345. // of Power<T, V> is made conditionally available only to that template
  346. // specialization.
  347. template <class T, class V, bool Enable = !std::is_same_v<V, size_t>,
  348. typename std::enable_if_t<Enable> * = nullptr>
  349. constexpr TropicalWeightTpl<T> Power(const TropicalWeightTpl<T> &w, V n) {
  350. using Weight = TropicalWeightTpl<T>;
  351. return (!w.Member() || internal::IsNan(n)) ? Weight::NoWeight()
  352. : (n == 0 || w == Weight::One()) ? Weight::One()
  353. : Weight(w.Value() * n);
  354. }
  355. // Specializes the library-wide template to use the above implementation; rules
  356. // of function template instantiation require this be a full instantiation.
  357. template <>
  358. constexpr TropicalWeightTpl<float> Power<TropicalWeightTpl<float>>(
  359. const TropicalWeightTpl<float> &weight, size_t n) {
  360. return Power<float, size_t, true>(weight, n);
  361. }
  362. template <>
  363. constexpr TropicalWeightTpl<double> Power<TropicalWeightTpl<double>>(
  364. const TropicalWeightTpl<double> &weight, size_t n) {
  365. return Power<double, size_t, true>(weight, n);
  366. }
  367. // Log semiring: (log(e^-x + e^-y), +, inf, 0).
  368. template <class T>
  369. class LogWeightTpl : public FloatWeightTpl<T> {
  370. public:
  371. using typename FloatWeightTpl<T>::ValueType;
  372. using FloatWeightTpl<T>::Value;
  373. using ReverseWeight = LogWeightTpl;
  374. using Limits = FloatLimits<T>;
  375. LogWeightTpl() noexcept : FloatWeightTpl<T>() {}
  376. constexpr LogWeightTpl(T f) : FloatWeightTpl<T>(f) {}
  377. static constexpr LogWeightTpl Zero() { return Limits::PosInfinity(); }
  378. static constexpr LogWeightTpl One() { return 0; }
  379. static constexpr LogWeightTpl NoWeight() { return Limits::NumberBad(); }
  380. static const std::string &Type() {
  381. static const std::string *const type = new std::string(
  382. fst::StrCat("log", FloatWeightTpl<T>::GetPrecisionString()));
  383. return *type;
  384. }
  385. constexpr bool Member() const {
  386. // The comments for TropicalWeightTpl<>::Member() apply here unchanged.
  387. return Limits::NegInfinity() < Value();
  388. }
  389. LogWeightTpl<T> Quantize(float delta = kDelta) const {
  390. if (!Member() || Value() == Limits::PosInfinity()) {
  391. return *this;
  392. } else {
  393. return LogWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
  394. }
  395. }
  396. constexpr LogWeightTpl<T> Reverse() const { return *this; }
  397. static constexpr uint64_t Properties() {
  398. return kLeftSemiring | kRightSemiring | kCommutative;
  399. }
  400. };
  401. // Single-precision log weight.
  402. using LogWeight = LogWeightTpl<float>;
  403. // Double-precision log weight.
  404. using Log64Weight = LogWeightTpl<double>;
  405. namespace internal {
  406. // -log(e^-x + e^-y) = x - LogPosExp(y - x), assuming y >= x.
  407. inline double LogPosExp(double x) {
  408. DCHECK(!(x < 0)); // NB: NaN values are allowed.
  409. return log1p(exp(-x));
  410. }
  411. // -log(e^-x - e^-y) = x - LogNegExp(y - x), assuming y >= x.
  412. inline double LogNegExp(double x) {
  413. DCHECK(!(x < 0)); // NB: NaN values are allowed.
  414. return log1p(-exp(-x));
  415. }
  416. // a +_log b = -log(e^-a + e^-b) = KahanLogSum(a, b, ...).
  417. // Kahan compensated summation provides an error bound that is
  418. // independent of the number of addends. Assumes b >= a;
  419. // c is the compensation.
  420. inline double KahanLogSum(double a, double b, double *c) {
  421. DCHECK_GE(b, a);
  422. double y = -LogPosExp(b - a) - *c;
  423. double t = a + y;
  424. *c = (t - a) - y;
  425. return t;
  426. }
  427. // a -_log b = -log(e^-a - e^-b) = KahanLogDiff(a, b, ...).
  428. // Kahan compensated summation provides an error bound that is
  429. // independent of the number of addends. Assumes b > a;
  430. // c is the compensation.
  431. inline double KahanLogDiff(double a, double b, double *c) {
  432. DCHECK_GT(b, a);
  433. double y = -LogNegExp(b - a) - *c;
  434. double t = a + y;
  435. *c = (t - a) - y;
  436. return t;
  437. }
  438. } // namespace internal
  439. template <class T>
  440. inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1,
  441. const LogWeightTpl<T> &w2) {
  442. using Limits = FloatLimits<T>;
  443. const T f1 = w1.Value();
  444. const T f2 = w2.Value();
  445. if (f1 == Limits::PosInfinity()) {
  446. return w2;
  447. } else if (f2 == Limits::PosInfinity()) {
  448. return w1;
  449. } else if (f1 > f2) {
  450. return LogWeightTpl<T>(f2 - internal::LogPosExp(f1 - f2));
  451. } else {
  452. return LogWeightTpl<T>(f1 - internal::LogPosExp(f2 - f1));
  453. }
  454. }
  455. inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1,
  456. const LogWeightTpl<float> &w2) {
  457. return Plus<float>(w1, w2);
  458. }
  459. inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1,
  460. const LogWeightTpl<double> &w2) {
  461. return Plus<double>(w1, w2);
  462. }
  463. // Returns NoWeight if w1 < w2 (w1.Value() > w2.Value()).
  464. template <class T>
  465. inline LogWeightTpl<T> Minus(const LogWeightTpl<T> &w1,
  466. const LogWeightTpl<T> &w2) {
  467. using Limits = FloatLimits<T>;
  468. const T f1 = w1.Value();
  469. const T f2 = w2.Value();
  470. if (f1 > f2) return LogWeightTpl<T>::NoWeight();
  471. if (f2 == Limits::PosInfinity()) return f1;
  472. const T d = f2 - f1;
  473. if (d == Limits::PosInfinity()) return f1;
  474. return f1 - internal::LogNegExp(d);
  475. }
  476. inline LogWeightTpl<float> Minus(const LogWeightTpl<float> &w1,
  477. const LogWeightTpl<float> &w2) {
  478. return Minus<float>(w1, w2);
  479. }
  480. inline LogWeightTpl<double> Minus(const LogWeightTpl<double> &w1,
  481. const LogWeightTpl<double> &w2) {
  482. return Minus<double>(w1, w2);
  483. }
  484. template <class T>
  485. constexpr LogWeightTpl<T> Times(const LogWeightTpl<T> &w1,
  486. const LogWeightTpl<T> &w2) {
  487. // The comments for Times(Tropical...) above apply here unchanged.
  488. return LogWeightTpl<T>(w1.Value() + w2.Value());
  489. }
  490. constexpr LogWeightTpl<float> Times(const LogWeightTpl<float> &w1,
  491. const LogWeightTpl<float> &w2) {
  492. return Times<float>(w1, w2);
  493. }
  494. constexpr LogWeightTpl<double> Times(const LogWeightTpl<double> &w1,
  495. const LogWeightTpl<double> &w2) {
  496. return Times<double>(w1, w2);
  497. }
  498. template <class T>
  499. constexpr LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1,
  500. const LogWeightTpl<T> &w2,
  501. DivideType typ = DIVIDE_ANY) {
  502. // The comments for Divide(Tropical...) above apply here unchanged.
  503. using Weight = LogWeightTpl<T>;
  504. return w2.Member() ? Weight(w1.Value() - w2.Value()) : Weight::NoWeight();
  505. }
  506. constexpr LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1,
  507. const LogWeightTpl<float> &w2,
  508. DivideType typ = DIVIDE_ANY) {
  509. return Divide<float>(w1, w2, typ);
  510. }
  511. constexpr LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1,
  512. const LogWeightTpl<double> &w2,
  513. DivideType typ = DIVIDE_ANY) {
  514. return Divide<double>(w1, w2, typ);
  515. }
  516. // The comments for Power<>(Tropical...) above apply here unchanged.
  517. template <class T, class V, bool Enable = !std::is_same_v<V, size_t>,
  518. typename std::enable_if_t<Enable> * = nullptr>
  519. constexpr LogWeightTpl<T> Power(const LogWeightTpl<T> &w, V n) {
  520. using Weight = LogWeightTpl<T>;
  521. return (!w.Member() || internal::IsNan(n)) ? Weight::NoWeight()
  522. : (n == 0 || w == Weight::One()) ? Weight::One()
  523. : Weight(w.Value() * n);
  524. }
  525. // Specializes the library-wide template to use the above implementation; rules
  526. // of function template instantiation require this be a full instantiation.
  527. template <>
  528. constexpr LogWeightTpl<float> Power<LogWeightTpl<float>>(
  529. const LogWeightTpl<float> &weight, size_t n) {
  530. return Power<float, size_t, true>(weight, n);
  531. }
  532. template <>
  533. constexpr LogWeightTpl<double> Power<LogWeightTpl<double>>(
  534. const LogWeightTpl<double> &weight, size_t n) {
  535. return Power<double, size_t, true>(weight, n);
  536. }
  537. // Specialization using the Kahan compensated summation.
  538. template <class T>
  539. class Adder<LogWeightTpl<T>> {
  540. public:
  541. using Weight = LogWeightTpl<T>;
  542. explicit Adder(Weight w = Weight::Zero()) : sum_(w.Value()), c_(0.0) {}
  543. Weight Add(const Weight &w) {
  544. using Limits = FloatLimits<T>;
  545. const T f = w.Value();
  546. if (f == Limits::PosInfinity()) {
  547. return Sum();
  548. } else if (sum_ == Limits::PosInfinity()) {
  549. sum_ = f;
  550. c_ = 0.0;
  551. } else if (f > sum_) {
  552. sum_ = internal::KahanLogSum(sum_, f, &c_);
  553. } else {
  554. sum_ = internal::KahanLogSum(f, sum_, &c_);
  555. }
  556. return Sum();
  557. }
  558. Weight Sum() const { return Weight(sum_); }
  559. void Reset(Weight w = Weight::Zero()) {
  560. sum_ = w.Value();
  561. c_ = 0.0;
  562. }
  563. private:
  564. double sum_;
  565. double c_; // Kahan compensation.
  566. };
  567. // Real semiring: (+, *, 0, 1).
  568. template <class T>
  569. class RealWeightTpl : public FloatWeightTpl<T> {
  570. public:
  571. using typename FloatWeightTpl<T>::ValueType;
  572. using FloatWeightTpl<T>::Value;
  573. using ReverseWeight = RealWeightTpl;
  574. using Limits = FloatLimits<T>;
  575. RealWeightTpl() noexcept : FloatWeightTpl<T>() {}
  576. constexpr RealWeightTpl(T f) : FloatWeightTpl<T>(f) {}
  577. static constexpr RealWeightTpl Zero() { return 0; }
  578. static constexpr RealWeightTpl One() { return 1; }
  579. static constexpr RealWeightTpl NoWeight() { return Limits::NumberBad(); }
  580. static const std::string &Type() {
  581. static const std::string *const type = new std::string(
  582. fst::StrCat("real", FloatWeightTpl<T>::GetPrecisionString()));
  583. return *type;
  584. }
  585. constexpr bool Member() const {
  586. // The comments for TropicalWeightTpl<>::Member() apply here unchanged.
  587. return Limits::NegInfinity() < Value();
  588. }
  589. RealWeightTpl<T> Quantize(float delta = kDelta) const {
  590. if (!Member() || Value() == Limits::PosInfinity()) {
  591. return *this;
  592. } else {
  593. return RealWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
  594. }
  595. }
  596. constexpr RealWeightTpl<T> Reverse() const { return *this; }
  597. static constexpr uint64_t Properties() {
  598. return kLeftSemiring | kRightSemiring | kCommutative;
  599. }
  600. };
  601. // Single-precision log weight.
  602. using RealWeight = RealWeightTpl<float>;
  603. // Double-precision log weight.
  604. using Real64Weight = RealWeightTpl<double>;
  namespace internal {

  // a + b = KahanRealSum(a, b, ...).
  // Kahan compensated summation provides an error bound that is
  // independent of the number of addends. c is the compensation: it is read
  // to correct the incoming addend and overwritten with the rounding error of
  // this step.
  inline double KahanRealSum(double a, double b, double *c) {
    const double y = b - *c;  // Addend corrected by the carried-over error.
    const double t = a + y;   // New sum; low-order bits of y may be lost here.
    *c = (t - a) - y;         // Recovers what was lost, for the next call.
    return t;
  }

  }  // namespace internal
  616. // The comments for Times(Tropical...) above apply here unchanged.
  617. template <class T>
  618. inline RealWeightTpl<T> Plus(const RealWeightTpl<T> &w1,
  619. const RealWeightTpl<T> &w2) {
  620. const T f1 = w1.Value();
  621. const T f2 = w2.Value();
  622. return RealWeightTpl<T>(f1 + f2);
  623. }
  624. inline RealWeightTpl<float> Plus(const RealWeightTpl<float> &w1,
  625. const RealWeightTpl<float> &w2) {
  626. return Plus<float>(w1, w2);
  627. }
  628. inline RealWeightTpl<double> Plus(const RealWeightTpl<double> &w1,
  629. const RealWeightTpl<double> &w2) {
  630. return Plus<double>(w1, w2);
  631. }
  632. template <class T>
  633. inline RealWeightTpl<T> Minus(const RealWeightTpl<T> &w1,
  634. const RealWeightTpl<T> &w2) {
  635. // The comments for Divide(Tropical...) above apply here unchanged.
  636. const T f1 = w1.Value();
  637. const T f2 = w2.Value();
  638. return RealWeightTpl<T>(f1 - f2);
  639. }
  640. inline RealWeightTpl<float> Minus(const RealWeightTpl<float> &w1,
  641. const RealWeightTpl<float> &w2) {
  642. return Minus<float>(w1, w2);
  643. }
  644. inline RealWeightTpl<double> Minus(const RealWeightTpl<double> &w1,
  645. const RealWeightTpl<double> &w2) {
  646. return Minus<double>(w1, w2);
  647. }
  648. // The comments for Times(Tropical...) above apply here similarly.
  649. template <class T>
  650. constexpr RealWeightTpl<T> Times(const RealWeightTpl<T> &w1,
  651. const RealWeightTpl<T> &w2) {
  652. return RealWeightTpl<T>(w1.Value() * w2.Value());
  653. }
  654. constexpr RealWeightTpl<float> Times(const RealWeightTpl<float> &w1,
  655. const RealWeightTpl<float> &w2) {
  656. return Times<float>(w1, w2);
  657. }
  658. constexpr RealWeightTpl<double> Times(const RealWeightTpl<double> &w1,
  659. const RealWeightTpl<double> &w2) {
  660. return Times<double>(w1, w2);
  661. }
  662. template <class T>
  663. constexpr RealWeightTpl<T> Divide(const RealWeightTpl<T> &w1,
  664. const RealWeightTpl<T> &w2,
  665. DivideType typ = DIVIDE_ANY) {
  666. using Weight = RealWeightTpl<T>;
  667. return w2.Member() ? Weight(w1.Value() / w2.Value()) : Weight::NoWeight();
  668. }
  669. constexpr RealWeightTpl<float> Divide(const RealWeightTpl<float> &w1,
  670. const RealWeightTpl<float> &w2,
  671. DivideType typ = DIVIDE_ANY) {
  672. return Divide<float>(w1, w2, typ);
  673. }
  674. constexpr RealWeightTpl<double> Divide(const RealWeightTpl<double> &w1,
  675. const RealWeightTpl<double> &w2,
  676. DivideType typ = DIVIDE_ANY) {
  677. return Divide<double>(w1, w2, typ);
  678. }
  679. // The comments for Power<>(Tropical...) above apply here unchanged.
  680. template <class T, class V, bool Enable = !std::is_same_v<V, size_t>,
  681. typename std::enable_if_t<Enable> * = nullptr>
  682. constexpr RealWeightTpl<T> Power(const RealWeightTpl<T> &w, V n) {
  683. using Weight = RealWeightTpl<T>;
  684. return (!w.Member() || internal::IsNan(n)) ? Weight::NoWeight()
  685. : (n == 0 || w == Weight::One()) ? Weight::One()
  686. : Weight(pow(w.Value(), n));
  687. }
  688. // Specializes the library-wide template to use the above implementation; rules
  689. // of function template instantiation require this be a full instantiation.
  690. template <>
  691. constexpr RealWeightTpl<float> Power<RealWeightTpl<float>>(
  692. const RealWeightTpl<float> &weight, size_t n) {
  693. return Power<float, size_t, true>(weight, n);
  694. }
  695. template <>
  696. constexpr RealWeightTpl<double> Power<RealWeightTpl<double>>(
  697. const RealWeightTpl<double> &weight, size_t n) {
  698. return Power<double, size_t, true>(weight, n);
  699. }
  700. // Specialization using the Kahan compensated summation.
  701. template <class T>
  702. class Adder<RealWeightTpl<T>> {
  703. public:
  704. using Weight = RealWeightTpl<T>;
  705. explicit Adder(Weight w = Weight::Zero()) : sum_(w.Value()), c_(0.0) {}
  706. Weight Add(const Weight &w) {
  707. using Limits = FloatLimits<T>;
  708. const T f = w.Value();
  709. if (f == Limits::PosInfinity()) {
  710. sum_ = f;
  711. } else if (sum_ == Limits::PosInfinity()) {
  712. return sum_;
  713. } else {
  714. sum_ = internal::KahanRealSum(sum_, f, &c_);
  715. }
  716. return Sum();
  717. }
  718. Weight Sum() const { return Weight(sum_); }
  719. void Reset(Weight w = Weight::Zero()) {
  720. sum_ = w.Value();
  721. c_ = 0.0;
  722. }
  723. private:
  724. double sum_;
  725. double c_; // Kahan compensation.
  726. };
  727. // MinMax semiring: (min, max, inf, -inf).
  728. template <class T>
  729. class MinMaxWeightTpl : public FloatWeightTpl<T> {
  730. public:
  731. using typename FloatWeightTpl<T>::ValueType;
  732. using FloatWeightTpl<T>::Value;
  733. using ReverseWeight = MinMaxWeightTpl<T>;
  734. using Limits = FloatLimits<T>;
  735. MinMaxWeightTpl() noexcept : FloatWeightTpl<T>() {}
  736. constexpr MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {} // NOLINT
  737. static constexpr MinMaxWeightTpl Zero() { return Limits::PosInfinity(); }
  738. static constexpr MinMaxWeightTpl One() { return Limits::NegInfinity(); }
  739. static constexpr MinMaxWeightTpl NoWeight() { return Limits::NumberBad(); }
  740. static const std::string &Type() {
  741. static const std::string *const type = new std::string(
  742. fst::StrCat("minmax", FloatWeightTpl<T>::GetPrecisionString()));
  743. return *type;
  744. }
  745. // Fails for IEEE NaN.
  746. constexpr bool Member() const { return !internal::IsNan(Value()); }
  747. MinMaxWeightTpl<T> Quantize(float delta = kDelta) const {
  748. // If one of infinities, or a NaN.
  749. if (!Member() || Value() == Limits::NegInfinity() ||
  750. Value() == Limits::PosInfinity()) {
  751. return *this;
  752. } else {
  753. return MinMaxWeightTpl<T>(std::floor(Value() / delta + 0.5F) * delta);
  754. }
  755. }
  756. constexpr MinMaxWeightTpl<T> Reverse() const { return *this; }
  757. static constexpr uint64_t Properties() {
  758. return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath;
  759. }
  760. };
  761. // Single-precision min-max weight.
  762. using MinMaxWeight = MinMaxWeightTpl<float>;
  763. // Min.
  764. template <class T>
  765. constexpr MinMaxWeightTpl<T> Plus(const MinMaxWeightTpl<T> &w1,
  766. const MinMaxWeightTpl<T> &w2) {
  767. return (!w1.Member() || !w2.Member()) ? MinMaxWeightTpl<T>::NoWeight()
  768. : w1.Value() < w2.Value() ? w1
  769. : w2;
  770. }
  771. constexpr MinMaxWeightTpl<float> Plus(const MinMaxWeightTpl<float> &w1,
  772. const MinMaxWeightTpl<float> &w2) {
  773. return Plus<float>(w1, w2);
  774. }
  775. constexpr MinMaxWeightTpl<double> Plus(const MinMaxWeightTpl<double> &w1,
  776. const MinMaxWeightTpl<double> &w2) {
  777. return Plus<double>(w1, w2);
  778. }
  779. // Max.
  780. template <class T>
  781. constexpr MinMaxWeightTpl<T> Times(const MinMaxWeightTpl<T> &w1,
  782. const MinMaxWeightTpl<T> &w2) {
  783. return (!w1.Member() || !w2.Member()) ? MinMaxWeightTpl<T>::NoWeight()
  784. : w1.Value() >= w2.Value() ? w1
  785. : w2;
  786. }
  787. constexpr MinMaxWeightTpl<float> Times(const MinMaxWeightTpl<float> &w1,
  788. const MinMaxWeightTpl<float> &w2) {
  789. return Times<float>(w1, w2);
  790. }
  791. constexpr MinMaxWeightTpl<double> Times(const MinMaxWeightTpl<double> &w1,
  792. const MinMaxWeightTpl<double> &w2) {
  793. return Times<double>(w1, w2);
  794. }
  795. // Defined only for special cases.
  796. template <class T>
  797. constexpr MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1,
  798. const MinMaxWeightTpl<T> &w2,
  799. DivideType typ = DIVIDE_ANY) {
  800. return w1.Value() >= w2.Value() ? w1 : MinMaxWeightTpl<T>::NoWeight();
  801. }
  802. constexpr MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1,
  803. const MinMaxWeightTpl<float> &w2,
  804. DivideType typ = DIVIDE_ANY) {
  805. return Divide<float>(w1, w2, typ);
  806. }
  807. constexpr MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1,
  808. const MinMaxWeightTpl<double> &w2,
  809. DivideType typ = DIVIDE_ANY) {
  810. return Divide<double>(w1, w2, typ);
  811. }
  812. // Converts to tropical.
  813. template <>
  814. struct WeightConvert<LogWeight, TropicalWeight> {
  815. constexpr TropicalWeight operator()(const LogWeight &w) const {
  816. return w.Value();
  817. }
  818. };
  819. template <>
  820. struct WeightConvert<Log64Weight, TropicalWeight> {
  821. constexpr TropicalWeight operator()(const Log64Weight &w) const {
  822. return w.Value();
  823. }
  824. };
  825. // Converts to log.
  826. template <>
  827. struct WeightConvert<TropicalWeight, LogWeight> {
  828. constexpr LogWeight operator()(const TropicalWeight &w) const {
  829. return w.Value();
  830. }
  831. };
  832. template <>
  833. struct WeightConvert<RealWeight, LogWeight> {
  834. LogWeight operator()(const RealWeight &w) const { return -log(w.Value()); }
  835. };
  836. template <>
  837. struct WeightConvert<Real64Weight, LogWeight> {
  838. LogWeight operator()(const Real64Weight &w) const { return -log(w.Value()); }
  839. };
  840. template <>
  841. struct WeightConvert<Log64Weight, LogWeight> {
  842. constexpr LogWeight operator()(const Log64Weight &w) const {
  843. return w.Value();
  844. }
  845. };
  846. // Converts to log64.
  847. template <>
  848. struct WeightConvert<TropicalWeight, Log64Weight> {
  849. constexpr Log64Weight operator()(const TropicalWeight &w) const {
  850. return w.Value();
  851. }
  852. };
  853. template <>
  854. struct WeightConvert<RealWeight, Log64Weight> {
  855. Log64Weight operator()(const RealWeight &w) const { return -log(w.Value()); }
  856. };
  857. template <>
  858. struct WeightConvert<Real64Weight, Log64Weight> {
  859. Log64Weight operator()(const Real64Weight &w) const {
  860. return -log(w.Value());
  861. }
  862. };
  863. template <>
  864. struct WeightConvert<LogWeight, Log64Weight> {
  865. constexpr Log64Weight operator()(const LogWeight &w) const {
  866. return w.Value();
  867. }
  868. };
  869. // Converts to real.
  870. template <>
  871. struct WeightConvert<LogWeight, RealWeight> {
  872. RealWeight operator()(const LogWeight &w) const { return exp(-w.Value()); }
  873. };
  874. template <>
  875. struct WeightConvert<Log64Weight, RealWeight> {
  876. RealWeight operator()(const Log64Weight &w) const { return exp(-w.Value()); }
  877. };
  878. template <>
  879. struct WeightConvert<Real64Weight, RealWeight> {
  880. constexpr RealWeight operator()(const Real64Weight &w) const {
  881. return w.Value();
  882. }
  883. };
  884. // Converts to real64
  885. template <>
  886. struct WeightConvert<LogWeight, Real64Weight> {
  887. Real64Weight operator()(const LogWeight &w) const { return exp(-w.Value()); }
  888. };
  889. template <>
  890. struct WeightConvert<Log64Weight, Real64Weight> {
  891. Real64Weight operator()(const Log64Weight &w) const {
  892. return exp(-w.Value());
  893. }
  894. };
  895. template <>
  896. struct WeightConvert<RealWeight, Real64Weight> {
  897. constexpr Real64Weight operator()(const RealWeight &w) const {
  898. return w.Value();
  899. }
  900. };
  901. // This function object returns random integers chosen from [0,
  902. // num_random_weights). The allow_zero argument determines whether Zero() and
  903. // zero divisors should be returned in the random weight generation. This is
  904. // intended primary for testing.
  905. template <class Weight>
  906. class FloatWeightGenerate {
  907. public:
  908. explicit FloatWeightGenerate(
  909. uint64_t seed = std::random_device()(), bool allow_zero = true,
  910. const size_t num_random_weights = kNumRandomWeights)
  911. : rand_(seed),
  912. allow_zero_(allow_zero),
  913. num_random_weights_(num_random_weights) {}
  914. Weight operator()() const {
  915. const int sample = std::uniform_int_distribution<>(
  916. 0, num_random_weights_ + allow_zero_ - 1)(rand_);
  917. if (allow_zero_ && sample == num_random_weights_) return Weight::Zero();
  918. return Weight(sample);
  919. }
  920. private:
  921. mutable std::mt19937_64 rand_;
  922. const bool allow_zero_;
  923. const size_t num_random_weights_;
  924. };
  925. template <class T>
  926. class WeightGenerate<TropicalWeightTpl<T>>
  927. : public FloatWeightGenerate<TropicalWeightTpl<T>> {
  928. public:
  929. using Weight = TropicalWeightTpl<T>;
  930. using Generate = FloatWeightGenerate<Weight>;
  931. explicit WeightGenerate(uint64_t seed = std::random_device()(),
  932. bool allow_zero = true,
  933. size_t num_random_weights = kNumRandomWeights)
  934. : Generate(seed, allow_zero, num_random_weights) {}
  935. Weight operator()() const { return Weight(Generate::operator()()); }
  936. };
  937. template <class T>
  938. class WeightGenerate<LogWeightTpl<T>>
  939. : public FloatWeightGenerate<LogWeightTpl<T>> {
  940. public:
  941. using Weight = LogWeightTpl<T>;
  942. using Generate = FloatWeightGenerate<Weight>;
  943. explicit WeightGenerate(uint64_t seed = std::random_device()(),
  944. bool allow_zero = true,
  945. size_t num_random_weights = kNumRandomWeights)
  946. : Generate(seed, allow_zero, num_random_weights) {}
  947. Weight operator()() const { return Weight(Generate::operator()()); }
  948. };
  949. template <class T>
  950. class WeightGenerate<RealWeightTpl<T>>
  951. : public FloatWeightGenerate<RealWeightTpl<T>> {
  952. public:
  953. using Weight = RealWeightTpl<T>;
  954. using Generate = FloatWeightGenerate<Weight>;
  955. explicit WeightGenerate(uint64_t seed = std::random_device()(),
  956. bool allow_zero = true,
  957. size_t num_random_weights = kNumRandomWeights)
  958. : Generate(seed, allow_zero, num_random_weights) {}
  959. Weight operator()() const { return Weight(Generate::operator()()); }
  960. };
  961. // This function object returns random integers chosen from [0,
  962. // num_random_weights). The boolean 'allow_zero' determines whether Zero() and
  963. // zero divisors should be returned in the random weight generation. This is
  964. // intended primary for testing.
  965. template <class T>
  966. class WeightGenerate<MinMaxWeightTpl<T>> {
  967. public:
  968. using Weight = MinMaxWeightTpl<T>;
  969. explicit WeightGenerate(uint64_t seed = std::random_device()(),
  970. bool allow_zero = true,
  971. size_t num_random_weights = kNumRandomWeights)
  972. : rand_(seed),
  973. allow_zero_(allow_zero),
  974. num_random_weights_(num_random_weights) {}
  975. Weight operator()() const {
  976. const int sample = std::uniform_int_distribution<>(
  977. -num_random_weights_, num_random_weights_ + allow_zero_)(rand_);
  978. if (allow_zero_ && sample == 0) {
  979. return Weight::Zero();
  980. } else if (sample == -num_random_weights_) {
  981. return Weight::One();
  982. } else {
  983. return Weight(sample);
  984. }
  985. }
  986. private:
  987. mutable std::mt19937_64 rand_;
  988. const bool allow_zero_;
  989. const size_t num_random_weights_;
  990. };
  991. } // namespace fst
  992. #endif // FST_FLOAT_WEIGHT_H_