You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

163 lines
5.2 KiB

  1. // base/kaldi-math.cc
  2. // Copyright 2009-2011 Microsoft Corporation; Yanmin Qian;
  3. // Saarland University; Jan Silovsky
  4. // See ../../COPYING for clarification regarding multiple authors
  5. //
  6. // Licensed under the Apache License, Version 2.0 (the "License");
  7. // you may not use this file except in compliance with the License.
  8. // You may obtain a copy of the License at
  9. //
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. //
  12. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  13. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  14. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  15. // MERCHANTABLITY OR NON-INFRINGEMENT.
  16. // See the Apache 2 License for the specific language governing permissions and
  17. // limitations under the License.
  #include "base/kaldi-math.h"
  #ifndef _MSC_VER
  #include <stdlib.h>
  #include <unistd.h>
  #endif
  #include <cstdint>
  #include <mutex>
  #include <string>
  25. namespace kaldi {
  26. // These routines are tested in matrix/matrix-test.cc
  27. int32 RoundUpToNearestPowerOfTwo(int32 n) {
  28. KALDI_ASSERT(n > 0);
  29. n--;
  30. n |= n >> 1;
  31. n |= n >> 2;
  32. n |= n >> 4;
  33. n |= n >> 8;
  34. n |= n >> 16;
  35. return n + 1;
  36. }
  #if defined(_POSIX_THREAD_SAFE_FUNCTIONS)
  // Serializes calls to the shared, non-reentrant rand() generator when Rand()
  // is called without a RandomState. It is only used on the rand_r() path
  // below, so it is only defined there. (The name deliberately avoids a leading
  // underscore followed by an uppercase letter, which C++ reserves for the
  // implementation.)
  static std::mutex rand_mutex;
  #endif

  // Returns a pseudo-random integer in [0, RAND_MAX].
  // When _POSIX_THREAD_SAFE_FUNCTIONS is defined and `state` is non-null, the
  // number is drawn with rand_r() from `state->seed`, giving each RandomState
  // its own stream; with a null `state` the shared rand() generator is used
  // under rand_mutex. When the macro is not defined, `state` is ignored and
  // rand() is called directly.
  int Rand(struct RandomState* state) {
  #if !defined(_POSIX_THREAD_SAFE_FUNCTIONS)
    // On Windows and Cygwin, just call rand().
    return rand();
  #else
    if (state) {
      return rand_r(&(state->seed));
    } else {
      std::lock_guard<std::mutex> lock(rand_mutex);
      return rand();
    }
  #endif
  }
  51. RandomState::RandomState() {
  52. // we initialize it as Rand() + 27437 instead of just Rand(), because on some
  53. // systems, e.g. at the very least Mac OSX Yosemite and later, it seems to be
  54. // the case that rand_r when initialized with rand() will give you the exact
  55. // same sequence of numbers that rand() will give if you keep calling rand()
  56. // after that initial call. This can cause problems with repeated sequences.
  57. // For example if you initialize two RandomState structs one after the other
  58. // without calling rand() in between, they would give you the same sequence
  59. // offset by one (if we didn't have the "+ 27437" in the code). 27437 is just
  60. // a randomly chosen prime number.
  61. seed = unsigned(Rand()) + 27437;
  62. }
  63. bool WithProb(BaseFloat prob, struct RandomState* state) {
  64. KALDI_ASSERT(prob >= 0 && prob <= 1.1); // prob should be <= 1.0,
  65. // but we allow slightly larger values that could arise from roundoff in
  66. // previous calculations.
  67. KALDI_COMPILE_TIME_ASSERT(RAND_MAX > 128 * 128);
  68. if (prob == 0) {
  69. return false;
  70. } else if (prob == 1.0) {
  71. return true;
  72. } else if (prob * RAND_MAX < 128.0) {
  73. // prob is very small but nonzero, and the "main algorithm"
  74. // wouldn't work that well. So: with probability 1/128, we
  75. // return WithProb (prob * 128), else return false.
  76. if (Rand(state) < RAND_MAX / 128) { // with probability 128...
  77. // Note: we know that prob * 128.0 < 1.0, because
  78. // we asserted RAND_MAX > 128 * 128.
  79. return WithProb(prob * 128.0);
  80. } else {
  81. return false;
  82. }
  83. } else {
  84. return (Rand(state) < ((RAND_MAX + static_cast<BaseFloat>(1.0)) * prob));
  85. }
  86. }
  87. int32 RandInt(int32 min_val, int32 max_val, struct RandomState* state) {
  88. // This is not exact.
  89. KALDI_ASSERT(max_val >= min_val);
  90. if (max_val == min_val) return min_val;
  91. #ifdef _MSC_VER
  92. // RAND_MAX is quite small on Windows -> may need to handle larger numbers.
  93. if (RAND_MAX > (max_val - min_val) * 8) {
  94. // *8 to avoid large inaccuracies in probability, from the modulus...
  95. return min_val +
  96. ((unsigned int)Rand(state) % (unsigned int)(max_val + 1 - min_val));
  97. } else {
  98. if ((unsigned int)(RAND_MAX * RAND_MAX) >
  99. (unsigned int)((max_val + 1 - min_val) * 8)) {
  100. // *8 to avoid inaccuracies in probability, from the modulus...
  101. return min_val + ((unsigned int)((Rand(state) + RAND_MAX * Rand(state))) %
  102. (unsigned int)(max_val + 1 - min_val));
  103. } else {
  104. KALDI_ERR << "rand_int failed because we do not support such large "
  105. "random numbers. (Extend this function).";
  106. }
  107. }
  108. #else
  109. return min_val + (static_cast<int32>(Rand(state)) %
  110. static_cast<int32>(max_val + 1 - min_val));
  111. #endif
  112. }
  113. // Returns poisson-distributed random number.
  114. // Take care: this takes time proportional
  115. // to lambda. Faster algorithms exist but are more complex.
  116. int32 RandPoisson(float lambda, struct RandomState* state) {
  117. // Knuth's algorithm.
  118. KALDI_ASSERT(lambda >= 0);
  119. float L = expf(-lambda), p = 1.0;
  120. int32 k = 0;
  121. do {
  122. k++;
  123. float u = RandUniform(state);
  124. p *= u;
  125. } while (p > L);
  126. return k - 1;
  127. }
  128. void RandGauss2(float* a, float* b, RandomState* state) {
  129. KALDI_ASSERT(a);
  130. KALDI_ASSERT(b);
  131. float u1 = RandUniform(state);
  132. float u2 = RandUniform(state);
  133. u1 = sqrtf(-2.0f * logf(u1));
  134. u2 = 2.0f * M_PI * u2;
  135. *a = u1 * cosf(u2);
  136. *b = u1 * sinf(u2);
  137. }
  138. void RandGauss2(double* a, double* b, RandomState* state) {
  139. KALDI_ASSERT(a);
  140. KALDI_ASSERT(b);
  141. float a_float, b_float;
  142. // Just because we're using doubles doesn't mean we need super-high-quality
  143. // random numbers, so we just use the floating-point version internally.
  144. RandGauss2(&a_float, &b_float, state);
  145. *a = a_float;
  146. *b = b_float;
  147. }
  148. } // end namespace kaldi