You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

892 lines
31 KiB

  1. // fstext/lattice-weight.h
  2. // Copyright 2009-2012 Microsoft Corporation
  3. // Johns Hopkins University (author: Daniel Povey)
  4. // See ../../COPYING for clarification regarding multiple authors
  5. //
  6. // Licensed under the Apache License, Version 2.0 (the "License");
  7. // you may not use this file except in compliance with the License.
  8. // You may obtain a copy of the License at
  9. //
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. //
  12. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  13. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  14. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  15. // MERCHANTABLITY OR NON-INFRINGEMENT.
  16. // See the Apache 2 License for the specific language governing permissions and
  17. // limitations under the License.
  18. #ifndef KALDI_FSTEXT_LATTICE_WEIGHT_H_
  19. #define KALDI_FSTEXT_LATTICE_WEIGHT_H_
  20. #include <algorithm>
  21. #include <limits>
  22. #include <string>
  23. #include <vector>
  24. #include "base/kaldi-common.h"
  25. //#include "fst/fstlib.h"
  26. namespace fst {
  27. // Declare weight type for lattice... will import to namespace kaldi. It has
  28. // two members, value1_ and value2_, of type FloatType (normally float). It
  29. // is basically the same as the tropical semiring on value1_+value2_, except it
  30. // keeps track of value1_ and value2_ separately. More precisely, it is
  31. // equivalent to the lexicographic semiring on (value1_+value2_),
  31. // (value1_-value2_).
  // Forward declarations of the weight class and of its text-mode stream
  // operators, so that they are known as templates before the class definition.
  template <typename FloatType>
  class LatticeWeightTpl;

  template <typename FloatType>
  inline std::ostream& operator<<(std::ostream& strm,
                                  const LatticeWeightTpl<FloatType>& w);

  template <typename FloatType>
  inline std::istream& operator>>(std::istream& strm,
                                  LatticeWeightTpl<FloatType>& w);
  40. template <class FloatType>
  41. class LatticeWeightTpl {
  42. public:
  43. typedef FloatType T; // normally float.
  44. typedef LatticeWeightTpl ReverseWeight;
  45. inline T Value1() const { return value1_; }
  46. inline T Value2() const { return value2_; }
  47. inline void SetValue1(T f) { value1_ = f; }
  48. inline void SetValue2(T f) { value2_ = f; }
  49. LatticeWeightTpl() : value1_{}, value2_{} {}
  50. LatticeWeightTpl(T a, T b) : value1_(a), value2_(b) {}
  51. LatticeWeightTpl(const LatticeWeightTpl& other)
  52. : value1_(other.value1_), value2_(other.value2_) {}
  53. LatticeWeightTpl& operator=(const LatticeWeightTpl& w) {
  54. value1_ = w.value1_;
  55. value2_ = w.value2_;
  56. return *this;
  57. }
  58. LatticeWeightTpl<FloatType> Reverse() const { return *this; }
  59. static const LatticeWeightTpl Zero() {
  60. return LatticeWeightTpl(std::numeric_limits<T>::infinity(),
  61. std::numeric_limits<T>::infinity());
  62. }
  63. static const LatticeWeightTpl One() { return LatticeWeightTpl(0.0, 0.0); }
  64. static const std::string& Type() {
  65. static const std::string type = (sizeof(T) == 4 ? "lattice4" : "lattice8");
  66. return type;
  67. }
  68. static const LatticeWeightTpl NoWeight() {
  69. return LatticeWeightTpl(std::numeric_limits<FloatType>::quiet_NaN(),
  70. std::numeric_limits<FloatType>::quiet_NaN());
  71. }
  72. bool Member() const {
  73. // value1_ == value1_ tests for NaN.
  74. // also test for no -inf, and either both or neither
  75. // must be +inf, and
  76. if (value1_ != value1_ || value2_ != value2_) return false; // NaN
  77. if (value1_ == -std::numeric_limits<T>::infinity() ||
  78. value2_ == -std::numeric_limits<T>::infinity())
  79. return false; // -infty not allowed
  80. if (value1_ == std::numeric_limits<T>::infinity() ||
  81. value2_ == std::numeric_limits<T>::infinity()) {
  82. if (value1_ != std::numeric_limits<T>::infinity() ||
  83. value2_ != std::numeric_limits<T>::infinity())
  84. return false; // both must be +infty;
  85. // this is necessary so that the semiring has only one zero.
  86. }
  87. return true;
  88. }
  89. LatticeWeightTpl Quantize(float delta = kDelta) const {
  90. if (value1_ + value2_ == -std::numeric_limits<T>::infinity()) {
  91. return LatticeWeightTpl(-std::numeric_limits<T>::infinity(),
  92. -std::numeric_limits<T>::infinity());
  93. } else if (value1_ + value2_ == std::numeric_limits<T>::infinity()) {
  94. return LatticeWeightTpl(std::numeric_limits<T>::infinity(),
  95. std::numeric_limits<T>::infinity());
  96. } else if (value1_ + value2_ != value1_ + value2_) { // NaN
  97. return LatticeWeightTpl(value1_ + value2_, value1_ + value2_);
  98. } else {
  99. return LatticeWeightTpl(floor(value1_ / delta + 0.5F) * delta,
  100. floor(value2_ / delta + 0.5F) * delta);
  101. }
  102. }
  103. static constexpr uint64 Properties() {
  104. return kLeftSemiring | kRightSemiring | kCommutative | kPath | kIdempotent;
  105. }
  106. // This is used in OpenFst for binary I/O. This is OpenFst-style,
  107. // not Kaldi-style, I/O.
  108. std::istream& Read(std::istream& strm) {
  109. // Always read/write as float, even if T is double,
  110. // so we can use OpenFst-style read/write and still maintain
  111. // compatibility when compiling with different FloatTypes
  112. ReadType(strm, &value1_);
  113. ReadType(strm, &value2_);
  114. return strm;
  115. }
  116. // This is used in OpenFst for binary I/O. This is OpenFst-style,
  117. // not Kaldi-style, I/O.
  118. std::ostream& Write(std::ostream& strm) const {
  119. WriteType(strm, value1_);
  120. WriteType(strm, value2_);
  121. return strm;
  122. }
  123. size_t Hash() const {
  124. size_t ans;
  125. union {
  126. T f;
  127. size_t s;
  128. } u;
  129. u.s = 0;
  130. u.f = value1_;
  131. ans = u.s;
  132. u.f = value2_;
  133. ans += u.s;
  134. return ans;
  135. }
  136. protected:
  137. inline static void WriteFloatType(std::ostream& strm, const T& f) {
  138. if (f == std::numeric_limits<T>::infinity())
  139. strm << "Infinity";
  140. else if (f == -std::numeric_limits<T>::infinity())
  141. strm << "-Infinity";
  142. else if (f != f)
  143. strm << "BadNumber";
  144. else
  145. strm << f;
  146. }
  147. // Internal helper function, used in ReadNoParen.
  148. inline static void ReadFloatType(std::istream& strm, T& f) { // NOLINT
  149. std::string s;
  150. strm >> s;
  151. if (s == "Infinity") {
  152. f = std::numeric_limits<T>::infinity();
  153. } else if (s == "-Infinity") {
  154. f = -std::numeric_limits<T>::infinity();
  155. } else if (s == "BadNumber") {
  156. f = std::numeric_limits<T>::quiet_NaN();
  157. } else {
  158. char* p;
  159. f = strtod(s.c_str(), &p);
  160. if (p < s.c_str() + s.size()) strm.clear(std::ios::badbit);
  161. }
  162. }
  163. // Reads LatticeWeight when there are no parentheses around pair terms...
  164. // currently the only form supported.
  165. inline std::istream& ReadNoParen(std::istream& strm, char separator) {
  166. int c;
  167. do {
  168. c = strm.get();
  169. } while (isspace(c));
  170. std::string s1;
  171. while (c != separator) {
  172. if (c == EOF) {
  173. strm.clear(std::ios::badbit);
  174. return strm;
  175. }
  176. s1 += c;
  177. c = strm.get();
  178. }
  179. std::istringstream strm1(s1);
  180. ReadFloatType(strm1, value1_); // ReadFloatType is class member function
  181. // read second element
  182. ReadFloatType(strm, value2_);
  183. return strm;
  184. }
  185. friend std::istream& operator>>
  186. <FloatType>(std::istream&, LatticeWeightTpl<FloatType>&);
  187. friend std::ostream& operator<< <FloatType>(
  188. std::ostream&, const LatticeWeightTpl<FloatType>&);
  189. private:
  190. T value1_;
  191. T value2_;
  192. };
  193. /* ScaleTupleWeight is a function defined for LatticeWeightTpl and
  194. CompactLatticeWeightTpl that mutliplies the pair (value1_, value2_) by a 2x2
  195. matrix. Used, for example, in applying acoustic scaling.
  196. */
  197. template <class FloatType, class ScaleFloatType>
  198. inline LatticeWeightTpl<FloatType> ScaleTupleWeight(
  199. const LatticeWeightTpl<FloatType>& w,
  200. const std::vector<std::vector<ScaleFloatType> >& scale) {
  201. // Without the next special case we'd get NaNs from infinity * 0
  202. if (w.Value1() == std::numeric_limits<FloatType>::infinity())
  203. return LatticeWeightTpl<FloatType>::Zero();
  204. return LatticeWeightTpl<FloatType>(
  205. scale[0][0] * w.Value1() + scale[0][1] * w.Value2(),
  206. scale[1][0] * w.Value1() + scale[1][1] * w.Value2());
  207. }
  208. /* For testing purposes and in case it's ever useful, we define a similar
  209. function to apply to LexicographicWeight and the like, templated on
  210. TropicalWeight<float> etc.; we use PairWeight which is the base class of
  211. LexicographicWeight.
  212. */
  213. template <class FloatType, class ScaleFloatType>
  214. inline PairWeight<TropicalWeightTpl<FloatType>, TropicalWeightTpl<FloatType> >
  215. ScaleTupleWeight(const PairWeight<TropicalWeightTpl<FloatType>,
  216. TropicalWeightTpl<FloatType> >& w,
  217. const std::vector<std::vector<ScaleFloatType> >& scale) {
  218. typedef TropicalWeightTpl<FloatType> BaseType;
  219. typedef PairWeight<BaseType, BaseType> PairType;
  220. const BaseType zero = BaseType::Zero();
  221. // Without the next special case we'd get NaNs from infinity * 0
  222. if (w.Value1() == zero || w.Value2() == zero) return PairType(zero, zero);
  223. FloatType f1 = w.Value1().Value(), f2 = w.Value2().Value();
  224. return PairType(BaseType(scale[0][0] * f1 + scale[0][1] * f2),
  225. BaseType(scale[1][0] * f1 + scale[1][1] * f2));
  226. }
  227. template <class FloatType>
  228. inline bool operator==(const LatticeWeightTpl<FloatType>& wa,
  229. const LatticeWeightTpl<FloatType>& wb) {
  230. // Volatile qualifier thwarts over-aggressive compiler optimizations
  231. // that lead to problems esp. with NaturalLess().
  232. volatile FloatType va1 = wa.Value1(), va2 = wa.Value2(), vb1 = wb.Value1(),
  233. vb2 = wb.Value2();
  234. return (va1 == vb1 && va2 == vb2);
  235. }
  236. template <class FloatType>
  237. inline bool operator!=(const LatticeWeightTpl<FloatType>& wa,
  238. const LatticeWeightTpl<FloatType>& wb) {
  239. // Volatile qualifier thwarts over-aggressive compiler optimizations
  240. // that lead to problems esp. with NaturalLess().
  241. volatile FloatType va1 = wa.Value1(), va2 = wa.Value2(), vb1 = wb.Value1(),
  242. vb2 = wb.Value2();
  243. return (va1 != vb1 || va2 != vb2);
  244. }
  245. // We define a Compare function LatticeWeightTpl even though it's
  246. // not required by the semiring standard-- it's just more efficient
  247. // to do it this way rather than using the NaturalLess template.
  248. /// Compare returns -1 if w1 < w2, +1 if w1 > w2, and 0 if w1 == w2.
  249. template <class FloatType>
  250. inline int Compare(const LatticeWeightTpl<FloatType>& w1,
  251. const LatticeWeightTpl<FloatType>& w2) {
  252. FloatType f1 = w1.Value1() + w1.Value2(), f2 = w2.Value1() + w2.Value2();
  253. if (f1 < f2) { // having smaller cost means you're larger
  254. return 1;
  255. } else if (f1 > f2) { // in the semiring [higher probability]
  256. return -1;
  257. } else if (w1.Value1() < w2.Value1()) {
  258. // mathematically we should be comparing (w1.value1_-w1.value2_ <
  259. // w2.value1_-w2.value2_) in the next line, but add w1.value1_+w1.value2_ =
  260. // w2.value1_+w2.value2_ to both sides and divide by two, and we get the
  261. // simpler equivalent form w1.value1_ < w2.value1_.
  262. return 1;
  263. } else if (w1.Value1() > w2.Value1()) {
  264. return -1;
  265. } else {
  266. return 0;
  267. }
  268. }
  269. template <class FloatType>
  270. inline LatticeWeightTpl<FloatType> Plus(const LatticeWeightTpl<FloatType>& w1,
  271. const LatticeWeightTpl<FloatType>& w2) {
  272. return (Compare(w1, w2) >= 0 ? w1 : w2);
  273. }
  274. // For efficiency, override the NaturalLess template class.
  275. template <class FloatType>
  276. class NaturalLess<LatticeWeightTpl<FloatType> > {
  277. public:
  278. typedef LatticeWeightTpl<FloatType> Weight;
  279. NaturalLess() {}
  280. bool operator()(const Weight& w1, const Weight& w2) const {
  281. // NaturalLess is a negative order (opposite to normal ordering).
  282. // This operator () corresponds to "<" in the negative order, which
  283. // corresponds to the ">" in the normal order.
  284. return (Compare(w1, w2) == 1);
  285. }
  286. };
  287. template <>
  288. class NaturalLess<LatticeWeightTpl<float> > {
  289. public:
  290. typedef LatticeWeightTpl<float> Weight;
  291. NaturalLess() {}
  292. bool operator()(const Weight& w1, const Weight& w2) const {
  293. // NaturalLess is a negative order (opposite to normal ordering).
  294. // This operator () corresponds to "<" in the negative order, which
  295. // corresponds to the ">" in the normal order.
  296. return (Compare(w1, w2) == 1);
  297. }
  298. };
  299. template <>
  300. class NaturalLess<LatticeWeightTpl<double> > {
  301. public:
  302. typedef LatticeWeightTpl<double> Weight;
  303. NaturalLess() {}
  304. bool operator()(const Weight& w1, const Weight& w2) const {
  305. // NaturalLess is a negative order (opposite to normal ordering).
  306. // This operator () corresponds to "<" in the negative order, which
  307. // corresponds to the ">" in the normal order.
  308. return (Compare(w1, w2) == 1);
  309. }
  310. };
  311. template <class FloatType>
  312. inline LatticeWeightTpl<FloatType> Times(
  313. const LatticeWeightTpl<FloatType>& w1,
  314. const LatticeWeightTpl<FloatType>& w2) {
  315. return LatticeWeightTpl<FloatType>(w1.Value1() + w2.Value1(),
  316. w1.Value2() + w2.Value2());
  317. }
  318. // divide w1 by w2 (on left/right/any doesn't matter as
  319. // commutative).
  320. template <class FloatType>
  321. inline LatticeWeightTpl<FloatType> Divide(const LatticeWeightTpl<FloatType>& w1,
  322. const LatticeWeightTpl<FloatType>& w2,
  323. DivideType typ = DIVIDE_ANY) {
  324. typedef FloatType T;
  325. T a = w1.Value1() - w2.Value1(), b = w1.Value2() - w2.Value2();
  326. if (a != a || b != b || a == -std::numeric_limits<T>::infinity() ||
  327. b == -std::numeric_limits<T>::infinity()) {
  328. KALDI_WARN << "LatticeWeightTpl::Divide, NaN or invalid number produced. "
  329. << "[dividing by zero?] Returning zero";
  330. return LatticeWeightTpl<T>::Zero();
  331. }
  332. if (a == std::numeric_limits<T>::infinity() ||
  333. b == std::numeric_limits<T>::infinity())
  334. return LatticeWeightTpl<T>::Zero(); // not a valid number if only one is
  335. // infinite.
  336. return LatticeWeightTpl<T>(a, b);
  337. }
  338. template <class FloatType>
  339. inline bool ApproxEqual(const LatticeWeightTpl<FloatType>& w1,
  340. const LatticeWeightTpl<FloatType>& w2,
  341. float delta = kDelta) {
  342. if (w1.Value1() == w2.Value1() && w1.Value2() == w2.Value2())
  343. return true; // handles Zero().
  344. return (fabs((w1.Value1() + w1.Value2()) - (w2.Value1() + w2.Value2())) <=
  345. delta);
  346. }
  347. template <class FloatType>
  348. inline std::ostream& operator<<(std::ostream& strm,
  349. const LatticeWeightTpl<FloatType>& w) {
  350. LatticeWeightTpl<FloatType>::WriteFloatType(strm, w.Value1());
  351. CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
  352. strm << FLAGS_fst_weight_separator[0]; // comma by default;
  353. // may or may not be settable from Kaldi programs.
  354. LatticeWeightTpl<FloatType>::WriteFloatType(strm, w.Value2());
  355. return strm;
  356. }
  357. template <class FloatType>
  358. inline std::istream& operator>>(std::istream& strm,
  359. LatticeWeightTpl<FloatType>& w1) {
  360. CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
  361. // separator defaults to ','
  362. return w1.ReadNoParen(strm, FLAGS_fst_weight_separator[0]);
  363. }
  364. // CompactLattice will be an acceptor (accepting the words/output-symbols),
  365. // with the weights and input-symbol-seqs on the arcs.
  366. // There must be a total order on W. We assume for the sake of efficiency
  367. // that there is a function
  368. // Compare(W w1, W w2) that returns -1 if w1 < w2, +1 if w1 > w2, and
  369. // zero if w1 == w2, and Plus for type W returns (Compare(w1,w2) >= 0 ? w1 :
  370. // w2).
  371. template <class WeightType, class IntType>
  372. class CompactLatticeWeightTpl {
  373. public:
  374. typedef WeightType W;
  375. typedef CompactLatticeWeightTpl<WeightType, IntType> ReverseWeight;
  376. // Plus is like LexicographicWeight on the pair (weight_, string_), but where
  377. // we use standard lexicographic order on string_ [this is not the same as
  378. // NaturalLess on the StringWeight equivalent, which does not define a
  379. // total order].
  380. // Times, Divide obvious... (support both left & right division..)
  381. // CommonDivisor would need to be coded separately.
  382. CompactLatticeWeightTpl() {}
  383. CompactLatticeWeightTpl(const WeightType& w, const std::vector<IntType>& s)
  384. : weight_(w), string_(s) {}
  385. CompactLatticeWeightTpl& operator=(
  386. const CompactLatticeWeightTpl<WeightType, IntType>& w) {
  387. weight_ = w.weight_;
  388. string_ = w.string_;
  389. return *this;
  390. }
  391. const W& Weight() const { return weight_; }
  392. const std::vector<IntType>& String() const { return string_; }
  393. void SetWeight(const W& w) { weight_ = w; }
  394. void SetString(const std::vector<IntType>& s) { string_ = s; }
  395. static const CompactLatticeWeightTpl<WeightType, IntType> Zero() {
  396. return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::Zero(),
  397. std::vector<IntType>());
  398. }
  399. static const CompactLatticeWeightTpl<WeightType, IntType> One() {
  400. return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::One(),
  401. std::vector<IntType>());
  402. }
  403. inline static std::string GetIntSizeString() {
  404. char buf[2];
  405. buf[0] = '0' + sizeof(IntType);
  406. buf[1] = '\0';
  407. return buf;
  408. }
  409. static const std::string& Type() {
  410. static const std::string type =
  411. "compact" + WeightType::Type() + GetIntSizeString();
  412. return type;
  413. }
  414. static const CompactLatticeWeightTpl<WeightType, IntType> NoWeight() {
  415. return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::NoWeight(),
  416. std::vector<IntType>());
  417. }
  418. CompactLatticeWeightTpl<WeightType, IntType> Reverse() const {
  419. size_t s = string_.size();
  420. std::vector<IntType> v(s);
  421. for (size_t i = 0; i < s; i++) v[i] = string_[s - i - 1];
  422. return CompactLatticeWeightTpl<WeightType, IntType>(weight_, v);
  423. }
  424. bool Member() const {
  425. // a semiring has only one zero, this is the important property
  426. // we're trying to maintain here. So force string_ to be empty if
  427. // w_ == zero.
  428. if (!weight_.Member()) return false;
  429. if (weight_ == WeightType::Zero())
  430. return string_.empty();
  431. else
  432. return true;
  433. }
  434. CompactLatticeWeightTpl Quantize(float delta = kDelta) const {
  435. return CompactLatticeWeightTpl(weight_.Quantize(delta), string_);
  436. }
  437. static constexpr uint64 Properties() {
  438. return kLeftSemiring | kRightSemiring | kPath | kIdempotent;
  439. }
  440. // This is used in OpenFst for binary I/O. This is OpenFst-style,
  441. // not Kaldi-style, I/O.
  442. std::istream& Read(std::istream& strm) {
  443. weight_.Read(strm);
  444. if (strm.fail()) {
  445. return strm;
  446. }
  447. int32 sz;
  448. ReadType(strm, &sz);
  449. if (strm.fail()) {
  450. return strm;
  451. }
  452. if (sz < 0) {
  453. KALDI_WARN << "Negative string size! Read failure";
  454. strm.clear(std::ios::badbit);
  455. return strm;
  456. }
  457. string_.resize(sz);
  458. for (int32 i = 0; i < sz; i++) {
  459. ReadType(strm, &(string_[i]));
  460. }
  461. return strm;
  462. }
  463. // This is used in OpenFst for binary I/O. This is OpenFst-style,
  464. // not Kaldi-style, I/O.
  465. std::ostream& Write(std::ostream& strm) const {
  466. weight_.Write(strm);
  467. if (strm.fail()) {
  468. return strm;
  469. }
  470. int32 sz = static_cast<int32>(string_.size());
  471. WriteType(strm, sz);
  472. for (int32 i = 0; i < sz; i++) WriteType(strm, string_[i]);
  473. return strm;
  474. }
  475. size_t Hash() const {
  476. size_t ans = weight_.Hash();
  477. // any weird numbers here are largish primes
  478. size_t sz = string_.size(), mult = 6967;
  479. for (size_t i = 0; i < sz; i++) {
  480. ans += string_[i] * mult;
  481. mult *= 7499;
  482. }
  483. return ans;
  484. }
  485. private:
  486. W weight_;
  487. std::vector<IntType> string_;
  488. };
  489. template <class WeightType, class IntType>
  490. inline bool operator==(const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  491. const CompactLatticeWeightTpl<WeightType, IntType>& w2) {
  492. return (w1.Weight() == w2.Weight() && w1.String() == w2.String());
  493. }
  494. template <class WeightType, class IntType>
  495. inline bool operator!=(const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  496. const CompactLatticeWeightTpl<WeightType, IntType>& w2) {
  497. return (w1.Weight() != w2.Weight() || w1.String() != w2.String());
  498. }
  499. template <class WeightType, class IntType>
  500. inline bool ApproxEqual(const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  501. const CompactLatticeWeightTpl<WeightType, IntType>& w2,
  502. float delta = kDelta) {
  503. return (ApproxEqual(w1.Weight(), w2.Weight(), delta) &&
  504. w1.String() == w2.String());
  505. }
  506. // Compare is not part of the standard for weight types, but used internally for
  507. // efficiency. The comparison here first compares the weight; if this is the
  508. // same, it compares the string. The comparison on strings is: first compare
  509. // the length, if this is the same, use lexicographical order. We can't just
  510. // use the lexicographical order because this would destroy the distributive
  511. // property of multiplication over addition, taking into account that addition
  512. // uses Compare. The string element of "Compare" isn't super-important in
  513. // practical terms; it's only needed to ensure that Plus always give consistent
  514. // answers and is symmetric. It's essentially for tie-breaking, but we need to
  515. // make sure all the semiring axioms are satisfied otherwise OpenFst might
  516. // break.
  517. template <class WeightType, class IntType>
  518. inline int Compare(const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  519. const CompactLatticeWeightTpl<WeightType, IntType>& w2) {
  520. int c1 = Compare(w1.Weight(), w2.Weight());
  521. if (c1 != 0) return c1;
  522. int l1 = w1.String().size(), l2 = w2.String().size();
  523. // Use opposite order on the string lengths, so that if the costs are the
  524. // same, the shorter string wins.
  525. if (l1 > l2)
  526. return -1;
  527. else if (l1 < l2)
  528. return 1;
  529. for (int i = 0; i < l1; i++) {
  530. if (w1.String()[i] < w2.String()[i])
  531. return -1;
  532. else if (w1.String()[i] > w2.String()[i])
  533. return 1;
  534. }
  535. return 0;
  536. }
  537. // For efficiency, override the NaturalLess template class.
  538. template <class FloatType, class IntType>
  539. class NaturalLess<
  540. CompactLatticeWeightTpl<LatticeWeightTpl<FloatType>, IntType> > {
  541. public:
  542. typedef CompactLatticeWeightTpl<LatticeWeightTpl<FloatType>, IntType> Weight;
  543. NaturalLess() {}
  544. bool operator()(const Weight& w1, const Weight& w2) const {
  545. // NaturalLess is a negative order (opposite to normal ordering).
  546. // This operator () corresponds to "<" in the negative order, which
  547. // corresponds to the ">" in the normal order.
  548. return (Compare(w1, w2) == 1);
  549. }
  550. };
  551. template <>
  552. class NaturalLess<CompactLatticeWeightTpl<LatticeWeightTpl<float>, int32> > {
  553. public:
  554. typedef CompactLatticeWeightTpl<LatticeWeightTpl<float>, int32> Weight;
  555. NaturalLess() {}
  556. bool operator()(const Weight& w1, const Weight& w2) const {
  557. // NaturalLess is a negative order (opposite to normal ordering).
  558. // This operator () corresponds to "<" in the negative order, which
  559. // corresponds to the ">" in the normal order.
  560. return (Compare(w1, w2) == 1);
  561. }
  562. };
  563. template <>
  564. class NaturalLess<CompactLatticeWeightTpl<LatticeWeightTpl<double>, int32> > {
  565. public:
  566. typedef CompactLatticeWeightTpl<LatticeWeightTpl<double>, int32> Weight;
  567. NaturalLess() {}
  568. bool operator()(const Weight& w1, const Weight& w2) const {
  569. // NaturalLess is a negative order (opposite to normal ordering).
  570. // This operator () corresponds to "<" in the negative order, which
  571. // corresponds to the ">" in the normal order.
  572. return (Compare(w1, w2) == 1);
  573. }
  574. };
  575. // Make sure Compare is defined for TropicalWeight, so everything works
  576. // if we substitute LatticeWeight for TropicalWeight.
  577. inline int Compare(const TropicalWeight& w1, const TropicalWeight& w2) {
  578. float f1 = w1.Value(), f2 = w2.Value();
  579. if (f1 == f2)
  580. return 0;
  581. else if (f1 > f2)
  582. return -1;
  583. else
  584. return 1;
  585. }
  586. template <class WeightType, class IntType>
  587. inline CompactLatticeWeightTpl<WeightType, IntType> Plus(
  588. const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  589. const CompactLatticeWeightTpl<WeightType, IntType>& w2) {
  590. return (Compare(w1, w2) >= 0 ? w1 : w2);
  591. }
  592. template <class WeightType, class IntType>
  593. inline CompactLatticeWeightTpl<WeightType, IntType> Times(
  594. const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  595. const CompactLatticeWeightTpl<WeightType, IntType>& w2) {
  596. WeightType w = Times(w1.Weight(), w2.Weight());
  597. if (w == WeightType::Zero()) {
  598. return CompactLatticeWeightTpl<WeightType, IntType>::Zero();
  599. // special case to ensure zero is unique
  600. } else {
  601. std::vector<IntType> v;
  602. v.resize(w1.String().size() + w2.String().size());
  603. typename std::vector<IntType>::iterator iter = v.begin();
  604. iter = std::copy(w1.String().begin(), w1.String().end(),
  605. iter); // returns end of first range.
  606. std::copy(w2.String().begin(), w2.String().end(), iter);
  607. return CompactLatticeWeightTpl<WeightType, IntType>(w, v);
  608. }
  609. }
  610. template <class WeightType, class IntType>
  611. inline CompactLatticeWeightTpl<WeightType, IntType> Divide(
  612. const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  613. const CompactLatticeWeightTpl<WeightType, IntType>& w2,
  614. DivideType div = DIVIDE_ANY) {
  615. if (w1.Weight() == WeightType::Zero()) {
  616. if (w2.Weight() != WeightType::Zero()) {
  617. return CompactLatticeWeightTpl<WeightType, IntType>::Zero();
  618. } else {
  619. KALDI_ERR << "Division by zero [0/0]";
  620. }
  621. } else if (w2.Weight() == WeightType::Zero()) {
  622. KALDI_ERR << "Error: division by zero";
  623. }
  624. WeightType w = Divide(w1.Weight(), w2.Weight());
  625. const std::vector<IntType> v1 = w1.String(), v2 = w2.String();
  626. if (v2.size() > v1.size()) {
  627. KALDI_ERR << "Cannot divide, length mismatch";
  628. }
  629. typename std::vector<IntType>::const_iterator v1b = v1.begin(),
  630. v1e = v1.end(),
  631. v2b = v2.begin(),
  632. v2e = v2.end();
  633. if (div == DIVIDE_LEFT) {
  634. if (!std::equal(v2b, v2e,
  635. v1b)) { // v2 must be identical to first part of v1.
  636. KALDI_ERR << "Cannot divide, data mismatch";
  637. }
  638. return CompactLatticeWeightTpl<WeightType, IntType>(
  639. w, std::vector<IntType>(v1b + (v2e - v2b),
  640. v1e)); // return last part of v1.
  641. } else if (div == DIVIDE_RIGHT) {
  642. if (!std::equal(
  643. v2b, v2e,
  644. v1e - (v2e - v2b))) { // v2 must be identical to last part of v1.
  645. KALDI_ERR << "Cannot divide, data mismatch";
  646. }
  647. return CompactLatticeWeightTpl<WeightType, IntType>(
  648. w, std::vector<IntType>(
  649. v1b, v1e - (v2e - v2b))); // return first part of v1.
  650. } else {
  651. KALDI_ERR << "Cannot divide CompactLatticeWeightTpl with DIVIDE_ANY";
  652. }
  653. return CompactLatticeWeightTpl<WeightType,
  654. IntType>::Zero(); // keep compiler happy.
  655. }
  656. template <class WeightType, class IntType>
  657. inline std::ostream& operator<<(
  658. std::ostream& strm, const CompactLatticeWeightTpl<WeightType, IntType>& w) {
  659. strm << w.Weight();
  660. CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
  661. strm << FLAGS_fst_weight_separator[0]; // comma by default.
  662. for (size_t i = 0; i < w.String().size(); i++) {
  663. strm << w.String()[i];
  664. if (i + 1 < w.String().size())
  665. strm << kStringSeparator; // '_'; defined in string-weight.h in OpenFst
  666. // code.
  667. }
  668. return strm;
  669. }
  670. template <class WeightType, class IntType>
  671. inline std::istream& operator>>(
  672. std::istream& strm, CompactLatticeWeightTpl<WeightType, IntType>& w) {
  673. std::string s;
  674. strm >> s;
  675. if (strm.fail()) {
  676. return strm;
  677. }
  678. CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
  679. size_t pos = s.find_last_of(FLAGS_fst_weight_separator); // normally ","
  680. if (pos == std::string::npos) {
  681. strm.clear(std::ios::badbit);
  682. return strm;
  683. }
  684. // get parts of str before and after the separator (default: ',');
  685. std::string s1(s, 0, pos), s2(s, pos + 1);
  686. std::istringstream strm1(s1);
  687. WeightType weight;
  688. strm1 >> weight;
  689. w.SetWeight(weight);
  690. if (strm1.fail() || !strm1.eof()) {
  691. strm.clear(std::ios::badbit);
  692. return strm;
  693. }
  694. // read string part.
  695. std::vector<IntType> string;
  696. const char* c = s2.c_str();
  697. while (*c != '\0') {
  698. if (*c == kStringSeparator) // '_'
  699. c++;
  700. char* c2;
  701. int64_t i = strtol(c, &c2, 10);
  702. if (c2 == c || static_cast<int64_t>(static_cast<IntType>(i)) != i) {
  703. strm.clear(std::ios::badbit);
  704. return strm;
  705. }
  706. c = c2;
  707. string.push_back(static_cast<IntType>(i));
  708. }
  709. w.SetString(string);
  710. return strm;
  711. }
  712. template <class BaseWeightType, class IntType>
  713. class CompactLatticeWeightCommonDivisorTpl {
  714. public:
  715. typedef CompactLatticeWeightTpl<BaseWeightType, IntType> Weight;
  716. Weight operator()(const Weight& w1, const Weight& w2) const {
  717. // First find longest common prefix of the strings.
  718. typename std::vector<IntType>::const_iterator s1b = w1.String().begin(),
  719. s1e = w1.String().end(),
  720. s2b = w2.String().begin(),
  721. s2e = w2.String().end();
  722. while (s1b < s1e && s2b < s2e && *s1b == *s2b) {
  723. s1b++;
  724. s2b++;
  725. }
  726. return Weight(Plus(w1.Weight(), w2.Weight()),
  727. std::vector<IntType>(w1.String().begin(), s1b));
  728. }
  729. };
  730. /** Scales the pair (a, b) of floating-point weights inside a
  731. CompactLatticeWeight by premultiplying it (viewed as a vector)
  732. by a 2x2 matrix "scale".
  733. Assumes there is a ScaleTupleWeight function that applies to "Weight";
  734. this currently only works if Weight equals LatticeWeightTpl<FloatType>
  735. for some FloatType.
  736. */
  737. template <class Weight, class IntType, class ScaleFloatType>
  738. inline CompactLatticeWeightTpl<Weight, IntType> ScaleTupleWeight(
  739. const CompactLatticeWeightTpl<Weight, IntType>& w,
  740. const std::vector<std::vector<ScaleFloatType> >& scale) {
  741. return CompactLatticeWeightTpl<Weight, IntType>(
  742. Weight(ScaleTupleWeight(w.Weight(), scale)), w.String());
  743. }
  744. /** Define some ConvertLatticeWeight functions that are used in various lattice
  745. conversions... make them all templates, some with no arguments, since some
  746. must be templates.*/
  747. template <class Float1, class Float2>
  748. inline void ConvertLatticeWeight(const LatticeWeightTpl<Float1>& w_in,
  749. LatticeWeightTpl<Float2>* w_out) {
  750. w_out->SetValue1(w_in.Value1());
  751. w_out->SetValue2(w_in.Value2());
  752. }
  753. template <class Float1, class Float2, class Int>
  754. inline void ConvertLatticeWeight(
  755. const CompactLatticeWeightTpl<LatticeWeightTpl<Float1>, Int>& w_in,
  756. CompactLatticeWeightTpl<LatticeWeightTpl<Float2>, Int>* w_out) {
  757. LatticeWeightTpl<Float2> weight2(w_in.Weight().Value1(),
  758. w_in.Weight().Value2());
  759. w_out->SetWeight(weight2);
  760. w_out->SetString(w_in.String());
  761. }
  762. // to convert from Lattice to standard FST
  763. template <class Float1, class Float2>
  764. inline void ConvertLatticeWeight(const LatticeWeightTpl<Float1>& w_in,
  765. TropicalWeightTpl<Float2>* w_out) {
  766. TropicalWeightTpl<Float2> w1(w_in.Value1());
  767. TropicalWeightTpl<Float2> w2(w_in.Value2());
  768. *w_out = Times(w1, w2);
  769. }
  770. template <class Float>
  771. inline double ConvertToCost(const LatticeWeightTpl<Float>& w) {
  772. return static_cast<double>(w.Value1()) + static_cast<double>(w.Value2());
  773. }
  774. template <class Float, class Int>
  775. inline double ConvertToCost(
  776. const CompactLatticeWeightTpl<LatticeWeightTpl<Float>, Int>& w) {
  777. return static_cast<double>(w.Weight().Value1()) +
  778. static_cast<double>(w.Weight().Value2());
  779. }
  780. template <class Float>
  781. inline double ConvertToCost(const TropicalWeightTpl<Float>& w) {
  782. return w.Value();
  783. }
  784. } // namespace fst
  785. #endif // KALDI_FSTEXT_LATTICE_WEIGHT_H_