You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

352 lines
10 KiB

  1. // base/kaldi-math.h
  2. // Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Yanmin Qian;
  3. // Jan Silovsky; Saarland University
  4. //
  5. // See ../../COPYING for clarification regarding multiple authors
  6. //
  7. // Licensed under the Apache License, Version 2.0 (the "License");
  8. // you may not use this file except in compliance with the License.
  9. // You may obtain a copy of the License at
  10. //
  11. // http://www.apache.org/licenses/LICENSE-2.0
  12. //
  13. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  15. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  16. // MERCHANTABLITY OR NON-INFRINGEMENT.
  17. // See the Apache 2 License for the specific language governing permissions and
  18. // limitations under the License.
  19. #ifndef KALDI_BASE_KALDI_MATH_H_
  20. #define KALDI_BASE_KALDI_MATH_H_ 1
  21. #ifdef _MSC_VER
  22. #include <float.h>
  23. #endif
  24. #include <cmath>
  25. #include <limits>
  26. #include <vector>
  27. #include "base/kaldi-common.h"
  28. #include "base/kaldi-types.h"
  // Fallback definitions of numeric constants. Each macro is defined here only
  // when no earlier header has already provided one of the same name.
  // Machine epsilon of double (2^-52).
  #ifndef DBL_EPSILON
  #define DBL_EPSILON 2.2204460492503131e-16
  #endif
  // Machine epsilon of float (2^-23); note the 'f' suffix, so this is a float.
  #ifndef FLT_EPSILON
  #define FLT_EPSILON 1.19209290e-7f
  #endif
  // pi
  #ifndef M_PI
  #define M_PI 3.1415926535897932384626433832795
  #endif
  // sqrt(2)
  #ifndef M_SQRT2
  #define M_SQRT2 1.4142135623730950488016887
  #endif
  // 2 * pi
  #ifndef M_2PI
  #define M_2PI 6.283185307179586476925286766559005
  #endif
  // 1 / sqrt(2)
  #ifndef M_SQRT1_2
  #define M_SQRT1_2 0.7071067811865475244008443621048490
  #endif
  // Natural log of (2 * pi)
  #ifndef M_LOG_2PI
  #define M_LOG_2PI 1.8378770664093454835606594728112
  #endif
  // Natural log of 2
  #ifndef M_LN2
  #define M_LN2 0.693147180559945309417232121458
  #endif
  // Natural log of 10
  #ifndef M_LN10
  #define M_LN10 2.302585092994045684017991454684
  #endif
  // Floating-point classification helpers; each one forwards to the function
  // of the same name in namespace std.
  #define KALDI_ISNAN std::isnan
  #define KALDI_ISINF std::isinf
  #define KALDI_ISFINITE(x) std::isfinite(x)
  // Squares its argument. The argument is expanded twice in the replacement
  // text, so an expression with side effects would be evaluated twice.
  // Defined only if KALDI_SQR has not been defined already.
  #if !defined(KALDI_SQR)
  #define KALDI_SQR(x) ((x) * (x))
  #endif
  62. namespace kaldi {
  /// Returns e raised to the power x, in double precision. This overload is
  /// the same on every supported compiler.
  inline double Exp(double x) { return exp(x); }

  /// Returns e raised to the power x, in single precision. Which routine is
  /// used depends on the compiler:
  ///  - non-Microsoft compilers and MSVC 1900 or newer use expf(), unless
  ///    KALDI_NO_EXPF is defined, in which case the double routine is used and
  ///    the result narrowed;
  ///  - older Microsoft compilers use expf(), except the case noted below.
  #if !defined(_MSC_VER) || (_MSC_VER >= 1900)
  #ifdef KALDI_NO_EXPF
  inline float Exp(float x) { return exp(static_cast<double>(x)); }
  #else
  inline float Exp(float x) { return expf(x); }
  #endif  // KALDI_NO_EXPF
  #elif !defined(__INTEL_COMPILER) && _MSC_VER == 1800 && defined(_M_X64)
  // Microsoft CL v18.0 buggy 64-bit implementation of
  // expf() incorrectly returns -inf for exp(-inf).
  inline float Exp(float x) { return exp(static_cast<double>(x)); }
  #else
  inline float Exp(float x) { return expf(x); }
  #endif  // compiler selection for Exp(float)
  /// Natural logarithm of a double-precision value; forwards to log().
  inline double Log(double x) {
    return log(x);
  }
  /// Natural logarithm of a single-precision value; forwards to logf().
  inline float Log(float x) {
    return logf(x);
  }
  /// Returns log(1 + x), computed so that it stays accurate when x is close to
  /// zero. Non-Microsoft compilers and MSVC 1700 or newer forward to the C
  /// library's log1p() / log1pf(); older Microsoft compilers use the fallback
  /// below.
  #if !defined(_MSC_VER) || (_MSC_VER >= 1700)
  inline double Log1p(double x) { return log1p(x); }
  inline float Log1p(float x) { return log1pf(x); }
  #else
  inline double Log1p(double x) {
    const double cutoff = 1.0e-08;
    // The two-term series x - x^2/2 is only valid when |x| is tiny, so the
    // magnitude of x is tested rather than its signed value: a plain
    // "x < cutoff" would also route every negative x (e.g. -0.5) through the
    // series and return a wrong result.
    if (x < cutoff && x > -cutoff)
      return x - 0.5 * x * x;
    else
      return Log(1.0 + x);
  }
  inline float Log1p(float x) {
    const float cutoff = 1.0e-07;
    // Same magnitude test as the double overload.
    if (x < cutoff && x > -cutoff)
      return x - 0.5 * x * x;
    else
      // 1.0 is a double literal, so the logarithm is taken in double precision
      // and the result narrowed to float.
      return static_cast<float>(Log(1.0 + x));
  }
  #endif
  // Natural logs of the machine epsilons. LogAdd() compares the difference
  // between its smaller and larger argument against these and, when the
  // difference is below the threshold, returns the larger argument unchanged.
  static const double kMinLogDiffDouble = Log(DBL_EPSILON); // negative!
  static const float kMinLogDiffFloat = Log(FLT_EPSILON); // negative!
  // -infinity, used as the log-domain representation of zero for each
  // floating-point type.
  const float kLogZeroFloat = -std::numeric_limits<float>::infinity();
  const double kLogZeroDouble = -std::numeric_limits<double>::infinity();
  const BaseFloat kLogZeroBaseFloat = -std::numeric_limits<BaseFloat>::infinity();
  // Returns a random integer between 0 and RAND_MAX, inclusive.
  // "state" is optional and defaults to NULL. The definition is not in this
  // header; NOTE(review): confirm there how a NULL state is handled.
  int Rand(struct RandomState* state = NULL);
  // State for thread-safe random number generator. A pointer to one of these
  // is accepted as the optional last argument of the random-number functions
  // declared in this header.
  struct RandomState {
    // Constructor; defined outside this header. NOTE(review): presumably it
    // initializes "seed" -- confirm against the definition.
    RandomState();
    // Seed value held by this state object.
    unsigned seed;
  };
  // Returns a random integer between first and last inclusive.
  // "state" is optional and defaults to NULL.
  int32 RandInt(int32 first, int32 last, struct RandomState* state = NULL);
  // Returns true with probability "prob", with 0 <= prob <= 1 [we check this].
  // Internally calls Rand(). This function is carefully implemented so
  // that it should work even if prob is very small.
  bool WithProb(BaseFloat prob, struct RandomState* state = NULL);
  121. /// Returns a random number strictly between 0 and 1.
  122. inline float RandUniform(struct RandomState* state = NULL) {
  123. return static_cast<float>((Rand(state) + 1.0) / (RAND_MAX + 2.0));
  124. }
  /// Returns a random float computed from two RandUniform() draws u1 and u2 as
  /// sqrt(-2 * log(u1)) * cos(2 * pi * u2), i.e. the Box-Muller form.
  /// Note: the two RandUniform() calls are the operands of one multiplication,
  /// so the language does not specify which of them is evaluated first.
  inline float RandGauss(struct RandomState* state = NULL) {
    return static_cast<float>(sqrtf(-2 * Log(RandUniform(state))) *
                              cosf(2 * M_PI * RandUniform(state)));
  }
  // Returns poisson-distributed random number. Uses Knuth's algorithm.
  // Take care: this takes time proportional
  // to lambda. Faster algorithms exist but are more complex.
  // "state" is optional and defaults to NULL.
  int32 RandPoisson(float lambda, struct RandomState* state = NULL);
  // Returns a pair of gaussian random numbers. Uses Box-Muller transform.
  // Declared for float and for double; NOTE(review): presumably the two
  // samples are written through "a" and "b" -- confirm against the definition.
  void RandGauss2(float* a, float* b, RandomState* state = NULL);
  void RandGauss2(double* a, double* b, RandomState* state = NULL);
  136. // Also see Vector<float,double>::RandCategorical().
  137. // This is a randomized pruning mechanism that preserves expectations,
  138. // that we typically use to prune posteriors.
  139. template <class Float>
  140. inline Float RandPrune(Float post, BaseFloat prune_thresh,
  141. struct RandomState* state = NULL) {
  142. KALDI_ASSERT(prune_thresh >= 0.0);
  143. if (post == 0.0 || std::abs(post) >= prune_thresh) return post;
  144. return (post >= 0 ? 1.0 : -1.0) *
  145. (RandUniform(state) <= fabs(post) / prune_thresh ? prune_thresh : 0.0);
  146. }
  147. // returns log(exp(x) + exp(y)).
  148. inline double LogAdd(double x, double y) {
  149. double diff;
  150. if (x < y) {
  151. diff = x - y;
  152. x = y;
  153. } else {
  154. diff = y - x;
  155. }
  156. // diff is negative. x is now the larger one.
  157. if (diff >= kMinLogDiffDouble) {
  158. double res;
  159. res = x + Log1p(Exp(diff));
  160. return res;
  161. } else {
  162. return x; // return the larger one.
  163. }
  164. }
  165. // returns log(exp(x) + exp(y)).
  166. inline float LogAdd(float x, float y) {
  167. float diff;
  168. if (x < y) {
  169. diff = x - y;
  170. x = y;
  171. } else {
  172. diff = y - x;
  173. }
  174. // diff is negative. x is now the larger one.
  175. if (diff >= kMinLogDiffFloat) {
  176. float res;
  177. res = x + Log1p(Exp(diff));
  178. return res;
  179. } else {
  180. return x; // return the larger one.
  181. }
  182. }
  183. // returns log(exp(x) - exp(y)).
  184. inline double LogSub(double x, double y) {
  185. if (y >= x) { // Throws exception if y>=x.
  186. if (y == x)
  187. return kLogZeroDouble;
  188. else
  189. KALDI_ERR << "Cannot subtract a larger from a smaller number.";
  190. }
  191. double diff = y - x; // Will be negative.
  192. double res = x + Log(1.0 - Exp(diff));
  193. // res might be NAN if diff ~0.0, and 1.0-exp(diff) == 0 to machine precision
  194. if (KALDI_ISNAN(res)) return kLogZeroDouble;
  195. return res;
  196. }
  197. // returns log(exp(x) - exp(y)).
  198. inline float LogSub(float x, float y) {
  199. if (y >= x) { // Throws exception if y>=x.
  200. if (y == x)
  201. return kLogZeroDouble;
  202. else
  203. KALDI_ERR << "Cannot subtract a larger from a smaller number.";
  204. }
  205. float diff = y - x; // Will be negative.
  206. float res = x + Log(1.0f - Exp(diff));
  207. // res might be NAN if diff ~0.0, and 1.0-exp(diff) == 0 to machine precision
  208. if (KALDI_ISNAN(res)) return kLogZeroFloat;
  209. return res;
  210. }
  /// Returns true if abs(a - b) <= relative_tolerance * (abs(a) + abs(b)).
  /// Exactly equal inputs (including two infinities of the same sign) always
  /// compare equal; an infinite or NaN difference never does.
  static inline bool ApproxEqual(float a, float b,
                                 float relative_tolerance = 0.001) {
    if (a == b) return true;  // also covers equal infinities
    const float gap = std::abs(a - b);
    // gap is non-negative here, so "not finite" means +inf or NaN.
    if (!std::isfinite(gap)) return false;
    const float scale = std::abs(a) + std::abs(b);
    return gap <= relative_tolerance * scale;
  }
  /// assert abs(a - b) <= relative_tolerance * (abs(a)+abs(b))
  /// The comparison itself is delegated to ApproxEqual().
  static inline void AssertEqual(float a, float b,
                                 float relative_tolerance = 0.001) {
    // Fails the assertion when ApproxEqual() reports the values as different.
    KALDI_ASSERT(ApproxEqual(a, b, relative_tolerance));
  }
  // RoundUpToNearestPowerOfTwo does the obvious thing. It crashes if n <= 0.
  // NOTE(review): presumably returns the smallest power of two that is >= n;
  // the definition is not in this header -- confirm there.
  int32 RoundUpToNearestPowerOfTwo(int32 n);
  229. /// Returns a / b, rounding towards negative infinity in all cases.
  230. static inline int32 DivideRoundingDown(int32 a, int32 b) {
  231. KALDI_ASSERT(b != 0);
  232. if (a * b >= 0)
  233. return a / b;
  234. else if (a < 0)
  235. return (a - b + 1) / b;
  236. else
  237. return (a - b - 1) / b;
  238. }
  239. template <class I>
  240. I Gcd(I m, I n) {
  241. if (m == 0 || n == 0) {
  242. if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors.
  243. KALDI_ERR << "Undefined GCD since m = 0, n = 0.";
  244. }
  245. return (m == 0 ? (n > 0 ? n : -n) : (m > 0 ? m : -m));
  246. // return absolute value of whichever is nonzero
  247. }
  248. // could use compile-time assertion
  249. // but involves messing with complex template stuff.
  250. KALDI_ASSERT(std::numeric_limits<I>::is_integer);
  251. while (1) {
  252. m %= n;
  253. if (m == 0) return (n > 0 ? n : -n);
  254. n %= m;
  255. if (n == 0) return (m > 0 ? m : -m);
  256. }
  257. }
  /// Returns the least common multiple of two integers. Both inputs must be
  /// positive; this is asserted.
  template <class I>
  I Lcm(I m, I n) {
    KALDI_ASSERT(m > 0 && n > 0);
    const I divisor = Gcd(m, n);
    // Divide each input by the common divisor first, then recombine.
    const I m_reduced = m / divisor;
    const I n_reduced = n / divisor;
    return divisor * m_reduced * n_reduced;
  }
  /// Splits m into its prime factors and stores them in *factors, in sorted
  /// order from least to greatest, with duplication. Any previous contents of
  /// *factors are discarded. m must be >= 1 and factors must be non-NULL (both
  /// asserted). For m == 1 the result is empty.
  ///
  /// This is plain trial division, mainly intended for use in the mixed-radix
  /// FFT computation (where we assume most factors are small).
  template <class I>
  void Factorize(I m, std::vector<I>* factors) {
    KALDI_ASSERT(factors != NULL);
    KALDI_ASSERT(m >= 1);  // Doesn't work for zero or negative numbers.
    factors->clear();
    const I small_factors[10] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29};
    // First try small factors.
    for (int i = 0; i < 10; i++) {
      if (m == 1) return;  // We're done.
      while (m % small_factors[i] == 0) {
        m /= small_factors[i];
        factors->push_back(small_factors[i]);
      }
    }
    // Next try all odd numbers starting from 31.
    for (I j = 31;; j += 2) {
      if (m == 1) return;
      // Every prime below j has already been divided out. Once j * j > m
      // (written as a division so it cannot overflow) the remaining m has no
      // divisor left to find and is itself prime, so it is the last factor.
      if (j > m / j) {
        factors->push_back(m);
        return;
      }
      while (m % j == 0) {
        m /= j;
        factors->push_back(j);
      }
    }
  }
  /// Returns sqrt(x*x + y*y) in double precision; forwards to hypot().
  inline double Hypot(double x, double y) {
    return hypot(x, y);
  }
  /// Returns sqrt(x*x + y*y) in single precision; forwards to hypotf().
  inline float Hypot(float x, float y) {
    return hypotf(x, y);
  }
  296. } // namespace kaldi
  297. #endif // KALDI_BASE_KALDI_MATH_H_