You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

621 lines
19 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // LogWeight along with sign information that represents the value X in the
  19. // linear domain as <sign(X), -ln(|X|)>
  20. //
  21. // The sign is a TropicalWeight:
  22. // positive, TropicalWeight.Value() > 0.0, recommended value 1.0
  23. // negative, TropicalWeight.Value() <= 0.0, recommended value -1.0
  24. #ifndef FST_SIGNED_LOG_WEIGHT_H_
  25. #define FST_SIGNED_LOG_WEIGHT_H_
  26. #include <climits>
  27. #include <cmath>
  28. #include <cstddef>
  29. #include <cstdint>
  30. #include <cstdlib>
  31. #include <random>
  32. #include <string>
  33. #include <fst/log.h>
  34. #include <fst/float-weight.h>
  35. #include <fst/pair-weight.h>
  36. #include <fst/product-weight.h>
  37. #include <fst/util.h>
  38. #include <fst/weight.h>
  39. namespace fst {
  40. template <class T>
  41. class SignedLogWeightTpl : public PairWeight<TropicalWeight, LogWeightTpl<T>> {
  42. public:
  43. using W1 = TropicalWeight;
  44. using W2 = LogWeightTpl<T>;
  45. using ReverseWeight = SignedLogWeightTpl;
  46. using PairWeight<W1, W2>::Value1;
  47. using PairWeight<W1, W2>::Value2;
  48. SignedLogWeightTpl() noexcept : PairWeight<W1, W2>() {}
  49. // Conversion from plain LogWeightTpl.
  50. // NOLINTNEXTLINE(google-explicit-constructor)
  51. SignedLogWeightTpl(const W2 &w2) : PairWeight<W1, W2>(W1(1.0), w2) {}
  52. explicit SignedLogWeightTpl(const PairWeight<W1, W2> &weight)
  53. : PairWeight<W1, W2>(weight) {}
  54. SignedLogWeightTpl(const W1 &w1, const W2 &w2) : PairWeight<W1, W2>(w1, w2) {}
  55. static const SignedLogWeightTpl &Zero() {
  56. static const SignedLogWeightTpl zero(W1(1.0), W2::Zero());
  57. return zero;
  58. }
  59. static const SignedLogWeightTpl &One() {
  60. static const SignedLogWeightTpl one(W1(1.0), W2::One());
  61. return one;
  62. }
  63. static const SignedLogWeightTpl &NoWeight() {
  64. static const SignedLogWeightTpl no_weight(W1(1.0), W2::NoWeight());
  65. return no_weight;
  66. }
  67. static const std::string &Type() {
  68. static const std::string *const type =
  69. new std::string("signed_log_" + W1::Type() + "_" + W2::Type());
  70. return *type;
  71. }
  72. bool IsPositive() const { return Value1().Value() > 0; }
  73. SignedLogWeightTpl Quantize(float delta = kDelta) const {
  74. return SignedLogWeightTpl(PairWeight<W1, W2>::Quantize(delta));
  75. }
  76. ReverseWeight Reverse() const {
  77. return SignedLogWeightTpl(PairWeight<W1, W2>::Reverse());
  78. }
  79. bool Member() const { return PairWeight<W1, W2>::Member(); }
  80. // Neither idempotent nor path.
  81. static constexpr uint64_t Properties() {
  82. return kLeftSemiring | kRightSemiring | kCommutative;
  83. }
  84. size_t Hash() const {
  85. size_t h1;
  86. if (Value2() == W2::Zero() || IsPositive()) {
  87. h1 = TropicalWeight(1.0).Hash();
  88. } else {
  89. h1 = TropicalWeight(-1.0).Hash();
  90. }
  91. size_t h2 = Value2().Hash();
  92. static constexpr int lshift = 5;
  93. static constexpr int rshift = CHAR_BIT * sizeof(size_t) - 5;
  94. return h1 << lshift ^ h1 >> rshift ^ h2;
  95. }
  96. };
  97. template <class T>
  98. inline SignedLogWeightTpl<T> Plus(const SignedLogWeightTpl<T> &w1,
  99. const SignedLogWeightTpl<T> &w2) {
  100. using W1 = TropicalWeight;
  101. using W2 = LogWeightTpl<T>;
  102. if (!w1.Member() || !w2.Member()) return SignedLogWeightTpl<T>::NoWeight();
  103. const auto s1 = w1.IsPositive();
  104. const auto s2 = w2.IsPositive();
  105. const bool equal = (s1 == s2);
  106. const auto f1 = w1.Value2().Value();
  107. const auto f2 = w2.Value2().Value();
  108. if (f1 == FloatLimits<T>::PosInfinity()) {
  109. return w2;
  110. } else if (f2 == FloatLimits<T>::PosInfinity()) {
  111. return w1;
  112. } else if (f1 == f2) {
  113. if (equal) {
  114. return SignedLogWeightTpl<T>(W1(w1.Value1()), W2(f2 - M_LN2));
  115. } else {
  116. return SignedLogWeightTpl<T>::Zero();
  117. }
  118. } else if (f1 > f2) {
  119. if (equal) {
  120. return SignedLogWeightTpl<T>(W1(w1.Value1()),
  121. W2(f2 - internal::LogPosExp(f1 - f2)));
  122. } else {
  123. return SignedLogWeightTpl<T>(W1(w2.Value1()),
  124. W2((f2 - internal::LogNegExp(f1 - f2))));
  125. }
  126. } else {
  127. if (equal) {
  128. return SignedLogWeightTpl<T>(W1(w2.Value1()),
  129. W2((f1 - internal::LogPosExp(f2 - f1))));
  130. } else {
  131. return SignedLogWeightTpl<T>(W1(w1.Value1()),
  132. W2((f1 - internal::LogNegExp(f2 - f1))));
  133. }
  134. }
  135. }
  136. template <class T>
  137. inline SignedLogWeightTpl<T> Minus(const SignedLogWeightTpl<T> &w1,
  138. const SignedLogWeightTpl<T> &w2) {
  139. SignedLogWeightTpl<T> minus_w2(-w2.Value1().Value(), w2.Value2());
  140. return Plus(w1, minus_w2);
  141. }
  142. template <class T>
  143. inline SignedLogWeightTpl<T> Times(const SignedLogWeightTpl<T> &w1,
  144. const SignedLogWeightTpl<T> &w2) {
  145. using W2 = LogWeightTpl<T>;
  146. if (!w1.Member() || !w2.Member()) return SignedLogWeightTpl<T>::NoWeight();
  147. const auto s1 = w1.IsPositive();
  148. const auto s2 = w2.IsPositive();
  149. const auto f1 = w1.Value2().Value();
  150. const auto f2 = w2.Value2().Value();
  151. if (s1 == s2) {
  152. return SignedLogWeightTpl<T>(TropicalWeight(1.0), W2(f1 + f2));
  153. } else {
  154. return SignedLogWeightTpl<T>(TropicalWeight(-1.0), W2(f1 + f2));
  155. }
  156. }
  157. template <class T>
  158. inline SignedLogWeightTpl<T> Divide(const SignedLogWeightTpl<T> &w1,
  159. const SignedLogWeightTpl<T> &w2,
  160. DivideType typ = DIVIDE_ANY) {
  161. using W2 = LogWeightTpl<T>;
  162. if (!w1.Member() || !w2.Member()) return SignedLogWeightTpl<T>::NoWeight();
  163. const auto s1 = w1.IsPositive();
  164. const auto s2 = w2.IsPositive();
  165. const auto f1 = w1.Value2().Value();
  166. const auto f2 = w2.Value2().Value();
  167. if (f2 == FloatLimits<T>::PosInfinity()) {
  168. return SignedLogWeightTpl<T>(TropicalWeight(1.0),
  169. W2(FloatLimits<T>::NumberBad()));
  170. } else if (f1 == FloatLimits<T>::PosInfinity()) {
  171. return SignedLogWeightTpl<T>(TropicalWeight(1.0),
  172. W2(FloatLimits<T>::PosInfinity()));
  173. } else if (s1 == s2) {
  174. return SignedLogWeightTpl<T>(TropicalWeight(1.0), W2(f1 - f2));
  175. } else {
  176. return SignedLogWeightTpl<T>(TropicalWeight(-1.0), W2(f1 - f2));
  177. }
  178. }
  179. template <class T>
  180. inline bool ApproxEqual(const SignedLogWeightTpl<T> &w1,
  181. const SignedLogWeightTpl<T> &w2, float delta = kDelta) {
  182. using W2 = LogWeightTpl<T>;
  183. if (w1.IsPositive() == w2.IsPositive()) {
  184. return ApproxEqual(w1.Value2(), w2.Value2(), delta);
  185. } else {
  186. return ApproxEqual(w1.Value2(), W2::Zero(), delta) &&
  187. ApproxEqual(w2.Value2(), W2::Zero(), delta);
  188. }
  189. }
  190. template <class T>
  191. inline bool operator==(const SignedLogWeightTpl<T> &w1,
  192. const SignedLogWeightTpl<T> &w2) {
  193. using W2 = LogWeightTpl<T>;
  194. if (w1.IsPositive() == w2.IsPositive()) {
  195. return w1.Value2() == w2.Value2();
  196. } else {
  197. return w1.Value2() == W2::Zero() && w2.Value2() == W2::Zero();
  198. }
  199. }
  200. template <class T>
  201. inline bool operator!=(const SignedLogWeightTpl<T> &w1,
  202. const SignedLogWeightTpl<T> &w2) {
  203. return !(w1 == w2);
  204. }
  205. // All functions and operators with a LogWeightTpl arg need to be
  206. // explicitly specified since the implicit constructor will not be
  207. // tried in conjunction with function overloading.
  208. template <class T>
  209. inline SignedLogWeightTpl<T> Plus(const LogWeightTpl<T> &w1,
  210. const SignedLogWeightTpl<T> &w2) {
  211. return Plus(SignedLogWeightTpl<T>(w1), w2);
  212. }
  213. template <class T>
  214. inline SignedLogWeightTpl<T> Plus(const SignedLogWeightTpl<T> &w1,
  215. const LogWeightTpl<T> &w2) {
  216. return Plus(w1, SignedLogWeightTpl<T>(w2));
  217. }
  218. template <class T>
  219. inline SignedLogWeightTpl<T> Minus(const LogWeightTpl<T> &w1,
  220. const SignedLogWeightTpl<T> &w2) {
  221. return Minus(SignedLogWeightTpl<T>(w1), w2);
  222. }
  223. template <class T>
  224. inline SignedLogWeightTpl<T> Minus(const SignedLogWeightTpl<T> &w1,
  225. const LogWeightTpl<T> &w2) {
  226. return Minus(w1, SignedLogWeightTpl<T>(w2));
  227. }
  228. template <class T>
  229. inline SignedLogWeightTpl<T> Times(const LogWeightTpl<T> &w1,
  230. const SignedLogWeightTpl<T> &w2) {
  231. return Times(SignedLogWeightTpl<T>(w1), w2);
  232. }
  233. template <class T>
  234. inline SignedLogWeightTpl<T> Times(const SignedLogWeightTpl<T> &w1,
  235. const LogWeightTpl<T> &w2) {
  236. return Times(w1, SignedLogWeightTpl<T>(w2));
  237. }
  238. template <class T>
  239. inline SignedLogWeightTpl<T> Divide(const LogWeightTpl<T> &w1,
  240. const SignedLogWeightTpl<T> &w2,
  241. DivideType typ = DIVIDE_ANY) {
  242. return Divide(SignedLogWeightTpl<T>(w1), w2, typ);
  243. }
  244. template <class T>
  245. inline SignedLogWeightTpl<T> Divide(const SignedLogWeightTpl<T> &w1,
  246. const LogWeightTpl<T> &w2,
  247. DivideType typ = DIVIDE_ANY) {
  248. return Divide(w1, SignedLogWeightTpl<T>(w2), typ);
  249. }
  250. template <class T>
  251. inline bool ApproxEqual(const LogWeightTpl<T> &w1,
  252. const SignedLogWeightTpl<T> &w2, float delta = kDelta) {
  253. return ApproxEqual(LogWeightTpl<T>(w1), w2, delta);
  254. }
  255. template <class T>
  256. inline bool ApproxEqual(const SignedLogWeightTpl<T> &w1,
  257. const LogWeightTpl<T> &w2, float delta = kDelta) {
  258. return ApproxEqual(w1, LogWeightTpl<T>(w2), delta);
  259. }
  260. template <class T>
  261. inline bool operator==(const LogWeightTpl<T> &w1,
  262. const SignedLogWeightTpl<T> &w2) {
  263. return SignedLogWeightTpl<T>(w1) == w2;
  264. }
  265. template <class T>
  266. inline bool operator==(const SignedLogWeightTpl<T> &w1,
  267. const LogWeightTpl<T> &w2) {
  268. return w1 == SignedLogWeightTpl<T>(w2);
  269. }
  270. template <class T>
  271. inline bool operator!=(const LogWeightTpl<T> &w1,
  272. const SignedLogWeightTpl<T> &w2) {
  273. return SignedLogWeightTpl<T>(w1) != w2;
  274. }
  275. template <class T>
  276. inline bool operator!=(const SignedLogWeightTpl<T> &w1,
  277. const LogWeightTpl<T> &w2) {
  278. return w1 != SignedLogWeightTpl<T>(w2);
  279. }
  280. // Single-precision signed-log weight.
  281. using SignedLogWeight = SignedLogWeightTpl<float>;
  282. // Double-precision signed-log weight.
  283. using SignedLog64Weight = SignedLogWeightTpl<double>;
  284. template <class W1, class W2>
  285. bool SignedLogConvertCheck(W1 weight) {
  286. if (weight.Value1().Value() < 0.0) {
  287. FSTERROR() << "WeightConvert: Can't convert weight " << weight << " from "
  288. << W1::Type() << " to " << W2::Type();
  289. return false;
  290. }
  291. return true;
  292. }
  293. // Specialization using the Kahan compensated summation
  294. template <class T>
  295. class Adder<SignedLogWeightTpl<T>> {
  296. public:
  297. using Weight = SignedLogWeightTpl<T>;
  298. using W1 = TropicalWeight;
  299. using W2 = LogWeightTpl<T>;
  300. explicit Adder(Weight w = Weight::Zero())
  301. : ssum_(w.IsPositive()), sum_(w.Value2().Value()), c_(0.0) {}
  302. Weight Add(const Weight &w) {
  303. const auto sw = w.IsPositive();
  304. const auto f = w.Value2().Value();
  305. const bool equal = (ssum_ == sw);
  306. if (!Sum().Member() || f == FloatLimits<T>::PosInfinity()) {
  307. return Sum();
  308. } else if (!w.Member() || sum_ == FloatLimits<T>::PosInfinity()) {
  309. sum_ = f;
  310. ssum_ = sw;
  311. c_ = 0.0;
  312. } else if (f == sum_) {
  313. if (equal) {
  314. sum_ = internal::KahanLogSum(sum_, f, &c_);
  315. } else {
  316. sum_ = FloatLimits<T>::PosInfinity();
  317. ssum_ = true;
  318. c_ = 0.0;
  319. }
  320. } else if (f > sum_) {
  321. if (equal) {
  322. sum_ = internal::KahanLogSum(sum_, f, &c_);
  323. } else {
  324. sum_ = internal::KahanLogDiff(sum_, f, &c_);
  325. }
  326. } else {
  327. if (equal) {
  328. sum_ = internal::KahanLogSum(f, sum_, &c_);
  329. } else {
  330. sum_ = internal::KahanLogDiff(f, sum_, &c_);
  331. ssum_ = sw;
  332. }
  333. }
  334. return Sum();
  335. }
  336. Weight Sum() const { return Weight(W1(ssum_ ? 1.0 : -1.0), W2(sum_)); }
  337. void Reset(Weight w = Weight::Zero()) {
  338. ssum_ = w.IsPositive();
  339. sum_ = w.Value2().Value();
  340. c_ = 0.0;
  341. }
  342. private:
  343. bool ssum_; // true iff sign of sum is positive
  344. double sum_; // unsigned sum
  345. double c_; // Kahan compensation
  346. };
  347. // Converts to tropical.
  348. template <>
  349. struct WeightConvert<SignedLogWeight, TropicalWeight> {
  350. TropicalWeight operator()(const SignedLogWeight &weight) const {
  351. if (!SignedLogConvertCheck<SignedLogWeight, TropicalWeight>(weight)) {
  352. return TropicalWeight::NoWeight();
  353. }
  354. return TropicalWeight(weight.Value2().Value());
  355. }
  356. };
  357. template <>
  358. struct WeightConvert<SignedLog64Weight, TropicalWeight> {
  359. TropicalWeight operator()(const SignedLog64Weight &weight) const {
  360. if (!SignedLogConvertCheck<SignedLog64Weight, TropicalWeight>(weight)) {
  361. return TropicalWeight::NoWeight();
  362. }
  363. return TropicalWeight(weight.Value2().Value());
  364. }
  365. };
  366. // Converts to log.
  367. template <>
  368. struct WeightConvert<SignedLogWeight, LogWeight> {
  369. LogWeight operator()(const SignedLogWeight &weight) const {
  370. if (!SignedLogConvertCheck<SignedLogWeight, LogWeight>(weight)) {
  371. return LogWeight::NoWeight();
  372. }
  373. return LogWeight(weight.Value2().Value());
  374. }
  375. };
  376. template <>
  377. struct WeightConvert<SignedLog64Weight, LogWeight> {
  378. LogWeight operator()(const SignedLog64Weight &weight) const {
  379. if (!SignedLogConvertCheck<SignedLog64Weight, LogWeight>(weight)) {
  380. return LogWeight::NoWeight();
  381. }
  382. return LogWeight(weight.Value2().Value());
  383. }
  384. };
  385. // Converts to log64.
  386. template <>
  387. struct WeightConvert<SignedLogWeight, Log64Weight> {
  388. Log64Weight operator()(const SignedLogWeight &weight) const {
  389. if (!SignedLogConvertCheck<SignedLogWeight, Log64Weight>(weight)) {
  390. return Log64Weight::NoWeight();
  391. }
  392. return Log64Weight(weight.Value2().Value());
  393. }
  394. };
  395. template <>
  396. struct WeightConvert<SignedLog64Weight, Log64Weight> {
  397. Log64Weight operator()(const SignedLog64Weight &weight) const {
  398. if (!SignedLogConvertCheck<SignedLog64Weight, Log64Weight>(weight)) {
  399. return Log64Weight::NoWeight();
  400. }
  401. return Log64Weight(weight.Value2().Value());
  402. }
  403. };
  404. // Converts to real.
  405. template <>
  406. struct WeightConvert<SignedLogWeight, RealWeight> {
  407. RealWeight operator()(const SignedLogWeight &weight) const {
  408. return RealWeight(weight.Value1().Value() * exp(-weight.Value2().Value()));
  409. }
  410. };
  411. template <>
  412. struct WeightConvert<SignedLog64Weight, RealWeight> {
  413. RealWeight operator()(const SignedLog64Weight &weight) const {
  414. return RealWeight(weight.Value1().Value() * exp(-weight.Value2().Value()));
  415. }
  416. };
  417. // Converts to real64.
  418. template <>
  419. struct WeightConvert<SignedLogWeight, Real64Weight> {
  420. Real64Weight operator()(const SignedLogWeight &weight) const {
  421. return Real64Weight(weight.Value1().Value() *
  422. exp(-weight.Value2().Value()));
  423. }
  424. };
  425. template <>
  426. struct WeightConvert<SignedLog64Weight, Real64Weight> {
  427. Real64Weight operator()(const SignedLog64Weight &weight) const {
  428. return Real64Weight(weight.Value1().Value() *
  429. exp(-weight.Value2().Value()));
  430. }
  431. };
  432. // Converts to signed log.
  433. template <>
  434. struct WeightConvert<TropicalWeight, SignedLogWeight> {
  435. SignedLogWeight operator()(const TropicalWeight &weight) const {
  436. return SignedLogWeight(1.0, weight.Value());
  437. }
  438. };
  439. template <>
  440. struct WeightConvert<LogWeight, SignedLogWeight> {
  441. SignedLogWeight operator()(const LogWeight &weight) const {
  442. return SignedLogWeight(1.0, weight.Value());
  443. }
  444. };
  445. template <>
  446. struct WeightConvert<Log64Weight, SignedLogWeight> {
  447. SignedLogWeight operator()(const Log64Weight &weight) const {
  448. return SignedLogWeight(1.0, weight.Value());
  449. }
  450. };
  451. template <>
  452. struct WeightConvert<RealWeight, SignedLogWeight> {
  453. SignedLogWeight operator()(const RealWeight &weight) const {
  454. return SignedLogWeight(weight.Value() >= 0 ? 1.0 : -1.0,
  455. -log(std::abs(weight.Value())));
  456. }
  457. };
  458. template <>
  459. struct WeightConvert<Real64Weight, SignedLogWeight> {
  460. SignedLogWeight operator()(const Real64Weight &weight) const {
  461. return SignedLogWeight(weight.Value() >= 0 ? 1.0 : -1.0,
  462. -log(std::abs(weight.Value())));
  463. }
  464. };
  465. template <>
  466. struct WeightConvert<SignedLog64Weight, SignedLogWeight> {
  467. SignedLogWeight operator()(const SignedLog64Weight &weight) const {
  468. return SignedLogWeight(weight.Value1(), weight.Value2().Value());
  469. }
  470. };
  471. // Converts to signed log64.
  472. template <>
  473. struct WeightConvert<TropicalWeight, SignedLog64Weight> {
  474. SignedLog64Weight operator()(const TropicalWeight &weight) const {
  475. return SignedLog64Weight(1.0, weight.Value());
  476. }
  477. };
  478. template <>
  479. struct WeightConvert<LogWeight, SignedLog64Weight> {
  480. SignedLog64Weight operator()(const LogWeight &weight) const {
  481. return SignedLog64Weight(1.0, weight.Value());
  482. }
  483. };
  484. template <>
  485. struct WeightConvert<Log64Weight, SignedLog64Weight> {
  486. SignedLog64Weight operator()(const Log64Weight &weight) const {
  487. return SignedLog64Weight(1.0, weight.Value());
  488. }
  489. };
  490. template <>
  491. struct WeightConvert<RealWeight, SignedLog64Weight> {
  492. SignedLog64Weight operator()(const RealWeight &weight) const {
  493. return SignedLog64Weight(weight.Value() >= 0 ? 1.0 : -1.0,
  494. -log(std::abs(weight.Value())));
  495. }
  496. };
  497. template <>
  498. struct WeightConvert<Real64Weight, SignedLog64Weight> {
  499. SignedLog64Weight operator()(const Real64Weight &weight) const {
  500. return SignedLog64Weight(weight.Value() >= 0 ? 1.0 : -1.0,
  501. -log(std::abs(weight.Value())));
  502. }
  503. };
  504. template <>
  505. struct WeightConvert<SignedLogWeight, SignedLog64Weight> {
  506. SignedLog64Weight operator()(const SignedLogWeight &weight) const {
  507. return SignedLog64Weight(weight.Value1(), weight.Value2().Value());
  508. }
  509. };
  510. // This function object returns SignedLogWeightTpl<T>'s that are random integers
  511. // chosen from [0, num_random_weights) times a random sign. This is intended
  512. // primarily for testing.
  513. template <class T>
  514. class WeightGenerate<SignedLogWeightTpl<T>> {
  515. public:
  516. using Weight = SignedLogWeightTpl<T>;
  517. using W1 = typename Weight::W1;
  518. using W2 = typename Weight::W2;
  519. explicit WeightGenerate(uint64_t seed = std::random_device()(),
  520. bool allow_zero = true,
  521. size_t num_random_weights = kNumRandomWeights)
  522. : rand_(seed),
  523. allow_zero_(allow_zero),
  524. num_random_weights_(num_random_weights) {}
  525. Weight operator()() const {
  526. static constexpr W1 negative(-1.0);
  527. static constexpr W1 positive(+1.0);
  528. const bool sign = std::bernoulli_distribution(.5)(rand_);
  529. const int sample = std::uniform_int_distribution<>(
  530. 0, num_random_weights_ + allow_zero_ - 1)(rand_);
  531. if (allow_zero_ && sample == num_random_weights_) {
  532. return Weight(sign ? positive : negative, W2::Zero());
  533. }
  534. return Weight(sign ? positive : negative, W2(sample));
  535. }
  536. private:
  537. mutable std::mt19937_64 rand_;
  538. const bool allow_zero_;
  539. const size_t num_random_weights_;
  540. };
  541. } // namespace fst
  542. #endif // FST_SIGNED_LOG_WEIGHT_H_