You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

892 lines
31 KiB

  1. // fstext/lattice-weight.h
  2. // Copyright 2009-2012 Microsoft Corporation
  3. // Johns Hopkins University (author: Daniel Povey)
  4. // See ../../COPYING for clarification regarding multiple authors
  5. //
  6. // Licensed under the Apache License, Version 2.0 (the "License");
  7. // you may not use this file except in compliance with the License.
  8. // You may obtain a copy of the License at
  9. //
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. //
  12. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  13. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  14. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  15. // MERCHANTABLITY OR NON-INFRINGEMENT.
  16. // See the Apache 2 License for the specific language governing permissions and
  17. // limitations under the License.
  18. #ifndef KALDI_FSTEXT_LATTICE_WEIGHT_H_
  19. #define KALDI_FSTEXT_LATTICE_WEIGHT_H_
  #include <algorithm>
  #include <cerrno>
  #include <cstdlib>
  #include <limits>
  #include <string>
  #include <vector>
  #include "base/kaldi-types.h"
  #include "base/kaldi-common.h"
  #include "fst/fstlib.h"
  27. namespace fst {
  // Declare the weight type for lattices; it will be imported into namespace
  // kaldi.  It has two members, value1_ and value2_, of type FloatType
  // (normally float).  It is basically the same as the tropical semiring on
  // value1_+value2_, except that it keeps track of value1_ and value2_
  // separately.  More precisely, it is equivalent to the lexicographic
  // semiring on (value1_+value2_), (value1_-value2_).
  // Forward declarations.  The stream operators are declared ahead of the
  // class definition so that LatticeWeightTpl can name them as friends.
  template <typename FloatType>
  class LatticeWeightTpl;

  template <typename FloatType>
  inline std::ostream& operator<<(std::ostream& strm,
                                  const LatticeWeightTpl<FloatType>& w);

  template <typename FloatType>
  inline std::istream& operator>>(std::istream& strm,
                                  LatticeWeightTpl<FloatType>& w);
  41. template <class FloatType>
  42. class LatticeWeightTpl {
  43. public:
  44. typedef FloatType T; // normally float.
  45. typedef LatticeWeightTpl ReverseWeight;
  46. inline T Value1() const { return value1_; }
  47. inline T Value2() const { return value2_; }
  48. inline void SetValue1(T f) { value1_ = f; }
  49. inline void SetValue2(T f) { value2_ = f; }
  50. LatticeWeightTpl() : value1_{}, value2_{} {}
  51. LatticeWeightTpl(T a, T b) : value1_(a), value2_(b) {}
  52. LatticeWeightTpl(const LatticeWeightTpl& other)
  53. : value1_(other.value1_), value2_(other.value2_) {}
  54. LatticeWeightTpl& operator=(const LatticeWeightTpl& w) {
  55. value1_ = w.value1_;
  56. value2_ = w.value2_;
  57. return *this;
  58. }
  59. LatticeWeightTpl<FloatType> Reverse() const { return *this; }
  60. static const LatticeWeightTpl Zero() {
  61. return LatticeWeightTpl(std::numeric_limits<T>::infinity(),
  62. std::numeric_limits<T>::infinity());
  63. }
  64. static const LatticeWeightTpl One() { return LatticeWeightTpl(0.0, 0.0); }
  65. static const std::string& Type() {
  66. static const std::string type = (sizeof(T) == 4 ? "lattice4" : "lattice8");
  67. return type;
  68. }
  69. static const LatticeWeightTpl NoWeight() {
  70. return LatticeWeightTpl(std::numeric_limits<FloatType>::quiet_NaN(),
  71. std::numeric_limits<FloatType>::quiet_NaN());
  72. }
  73. bool Member() const {
  74. // value1_ == value1_ tests for NaN.
  75. // also test for no -inf, and either both or neither
  76. // must be +inf, and
  77. if (value1_ != value1_ || value2_ != value2_) return false; // NaN
  78. if (value1_ == -std::numeric_limits<T>::infinity() ||
  79. value2_ == -std::numeric_limits<T>::infinity())
  80. return false; // -infty not allowed
  81. if (value1_ == std::numeric_limits<T>::infinity() ||
  82. value2_ == std::numeric_limits<T>::infinity()) {
  83. if (value1_ != std::numeric_limits<T>::infinity() ||
  84. value2_ != std::numeric_limits<T>::infinity())
  85. return false; // both must be +infty;
  86. // this is necessary so that the semiring has only one zero.
  87. }
  88. return true;
  89. }
  90. LatticeWeightTpl Quantize(float delta = kDelta) const {
  91. if (value1_ + value2_ == -std::numeric_limits<T>::infinity()) {
  92. return LatticeWeightTpl(-std::numeric_limits<T>::infinity(),
  93. -std::numeric_limits<T>::infinity());
  94. } else if (value1_ + value2_ == std::numeric_limits<T>::infinity()) {
  95. return LatticeWeightTpl(std::numeric_limits<T>::infinity(),
  96. std::numeric_limits<T>::infinity());
  97. } else if (value1_ + value2_ != value1_ + value2_) { // NaN
  98. return LatticeWeightTpl(value1_ + value2_, value1_ + value2_);
  99. } else {
  100. return LatticeWeightTpl(floor(value1_ / delta + 0.5F) * delta,
  101. floor(value2_ / delta + 0.5F) * delta);
  102. }
  103. }
  104. static constexpr uint64 Properties() {
  105. return kLeftSemiring | kRightSemiring | kCommutative | kPath | kIdempotent;
  106. }
  107. // This is used in OpenFst for binary I/O. This is OpenFst-style,
  108. // not Kaldi-style, I/O.
  109. std::istream& Read(std::istream& strm) {
  110. // Always read/write as float, even if T is double,
  111. // so we can use OpenFst-style read/write and still maintain
  112. // compatibility when compiling with different FloatTypes
  113. ReadType(strm, &value1_);
  114. ReadType(strm, &value2_);
  115. return strm;
  116. }
  117. // This is used in OpenFst for binary I/O. This is OpenFst-style,
  118. // not Kaldi-style, I/O.
  119. std::ostream& Write(std::ostream& strm) const {
  120. WriteType(strm, value1_);
  121. WriteType(strm, value2_);
  122. return strm;
  123. }
  124. size_t Hash() const {
  125. size_t ans;
  126. union {
  127. T f;
  128. size_t s;
  129. } u;
  130. u.s = 0;
  131. u.f = value1_;
  132. ans = u.s;
  133. u.f = value2_;
  134. ans += u.s;
  135. return ans;
  136. }
  137. protected:
  138. inline static void WriteFloatType(std::ostream& strm, const T& f) {
  139. if (f == std::numeric_limits<T>::infinity())
  140. strm << "Infinity";
  141. else if (f == -std::numeric_limits<T>::infinity())
  142. strm << "-Infinity";
  143. else if (f != f)
  144. strm << "BadNumber";
  145. else
  146. strm << f;
  147. }
  148. // Internal helper function, used in ReadNoParen.
  149. inline static void ReadFloatType(std::istream& strm, T& f) { // NOLINT
  150. std::string s;
  151. strm >> s;
  152. if (s == "Infinity") {
  153. f = std::numeric_limits<T>::infinity();
  154. } else if (s == "-Infinity") {
  155. f = -std::numeric_limits<T>::infinity();
  156. } else if (s == "BadNumber") {
  157. f = std::numeric_limits<T>::quiet_NaN();
  158. } else {
  159. char* p;
  160. f = strtod(s.c_str(), &p);
  161. if (p < s.c_str() + s.size()) strm.clear(std::ios::badbit);
  162. }
  163. }
  164. // Reads LatticeWeight when there are no parentheses around pair terms...
  165. // currently the only form supported.
  166. inline std::istream& ReadNoParen(std::istream& strm, char separator) {
  167. int c;
  168. do {
  169. c = strm.get();
  170. } while (isspace(c));
  171. std::string s1;
  172. while (c != separator) {
  173. if (c == EOF) {
  174. strm.clear(std::ios::badbit);
  175. return strm;
  176. }
  177. s1 += c;
  178. c = strm.get();
  179. }
  180. std::istringstream strm1(s1);
  181. ReadFloatType(strm1, value1_); // ReadFloatType is class member function
  182. // read second element
  183. ReadFloatType(strm, value2_);
  184. return strm;
  185. }
  186. friend std::istream& operator>>
  187. <FloatType>(std::istream&, LatticeWeightTpl<FloatType>&);
  188. friend std::ostream& operator<< <FloatType>(
  189. std::ostream&, const LatticeWeightTpl<FloatType>&);
  190. private:
  191. T value1_;
  192. T value2_;
  193. };
  194. /* ScaleTupleWeight is a function defined for LatticeWeightTpl and
  195. CompactLatticeWeightTpl that mutliplies the pair (value1_, value2_) by a 2x2
  196. matrix. Used, for example, in applying acoustic scaling.
  197. */
  198. template <class FloatType, class ScaleFloatType>
  199. inline LatticeWeightTpl<FloatType> ScaleTupleWeight(
  200. const LatticeWeightTpl<FloatType>& w,
  201. const std::vector<std::vector<ScaleFloatType> >& scale) {
  202. // Without the next special case we'd get NaNs from infinity * 0
  203. if (w.Value1() == std::numeric_limits<FloatType>::infinity())
  204. return LatticeWeightTpl<FloatType>::Zero();
  205. return LatticeWeightTpl<FloatType>(
  206. scale[0][0] * w.Value1() + scale[0][1] * w.Value2(),
  207. scale[1][0] * w.Value1() + scale[1][1] * w.Value2());
  208. }
  209. /* For testing purposes and in case it's ever useful, we define a similar
  210. function to apply to LexicographicWeight and the like, templated on
  211. TropicalWeight<float> etc.; we use PairWeight which is the base class of
  212. LexicographicWeight.
  213. */
  214. template <class FloatType, class ScaleFloatType>
  215. inline PairWeight<TropicalWeightTpl<FloatType>, TropicalWeightTpl<FloatType> >
  216. ScaleTupleWeight(const PairWeight<TropicalWeightTpl<FloatType>,
  217. TropicalWeightTpl<FloatType> >& w,
  218. const std::vector<std::vector<ScaleFloatType> >& scale) {
  219. typedef TropicalWeightTpl<FloatType> BaseType;
  220. typedef PairWeight<BaseType, BaseType> PairType;
  221. const BaseType zero = BaseType::Zero();
  222. // Without the next special case we'd get NaNs from infinity * 0
  223. if (w.Value1() == zero || w.Value2() == zero) return PairType(zero, zero);
  224. FloatType f1 = w.Value1().Value(), f2 = w.Value2().Value();
  225. return PairType(BaseType(scale[0][0] * f1 + scale[0][1] * f2),
  226. BaseType(scale[1][0] * f1 + scale[1][1] * f2));
  227. }
  228. template <class FloatType>
  229. inline bool operator==(const LatticeWeightTpl<FloatType>& wa,
  230. const LatticeWeightTpl<FloatType>& wb) {
  231. // Volatile qualifier thwarts over-aggressive compiler optimizations
  232. // that lead to problems esp. with NaturalLess().
  233. volatile FloatType va1 = wa.Value1(), va2 = wa.Value2(), vb1 = wb.Value1(),
  234. vb2 = wb.Value2();
  235. return (va1 == vb1 && va2 == vb2);
  236. }
  237. template <class FloatType>
  238. inline bool operator!=(const LatticeWeightTpl<FloatType>& wa,
  239. const LatticeWeightTpl<FloatType>& wb) {
  240. // Volatile qualifier thwarts over-aggressive compiler optimizations
  241. // that lead to problems esp. with NaturalLess().
  242. volatile FloatType va1 = wa.Value1(), va2 = wa.Value2(), vb1 = wb.Value1(),
  243. vb2 = wb.Value2();
  244. return (va1 != vb1 || va2 != vb2);
  245. }
  246. // We define a Compare function LatticeWeightTpl even though it's
  247. // not required by the semiring standard-- it's just more efficient
  248. // to do it this way rather than using the NaturalLess template.
  249. /// Compare returns -1 if w1 < w2, +1 if w1 > w2, and 0 if w1 == w2.
  250. template <class FloatType>
  251. inline int Compare(const LatticeWeightTpl<FloatType>& w1,
  252. const LatticeWeightTpl<FloatType>& w2) {
  253. FloatType f1 = w1.Value1() + w1.Value2(), f2 = w2.Value1() + w2.Value2();
  254. if (f1 < f2) { // having smaller cost means you're larger
  255. return 1;
  256. } else if (f1 > f2) { // in the semiring [higher probability]
  257. return -1;
  258. } else if (w1.Value1() < w2.Value1()) {
  259. // mathematically we should be comparing (w1.value1_-w1.value2_ <
  260. // w2.value1_-w2.value2_) in the next line, but add w1.value1_+w1.value2_ =
  261. // w2.value1_+w2.value2_ to both sides and divide by two, and we get the
  262. // simpler equivalent form w1.value1_ < w2.value1_.
  263. return 1;
  264. } else if (w1.Value1() > w2.Value1()) {
  265. return -1;
  266. } else {
  267. return 0;
  268. }
  269. }
  270. template <class FloatType>
  271. inline LatticeWeightTpl<FloatType> Plus(const LatticeWeightTpl<FloatType>& w1,
  272. const LatticeWeightTpl<FloatType>& w2) {
  273. return (Compare(w1, w2) >= 0 ? w1 : w2);
  274. }
  275. // For efficiency, override the NaturalLess template class.
  276. template <class FloatType>
  277. class NaturalLess<LatticeWeightTpl<FloatType> > {
  278. public:
  279. typedef LatticeWeightTpl<FloatType> Weight;
  280. NaturalLess() {}
  281. bool operator()(const Weight& w1, const Weight& w2) const {
  282. // NaturalLess is a negative order (opposite to normal ordering).
  283. // This operator () corresponds to "<" in the negative order, which
  284. // corresponds to the ">" in the normal order.
  285. return (Compare(w1, w2) == 1);
  286. }
  287. };
  288. template <>
  289. class NaturalLess<LatticeWeightTpl<float> > {
  290. public:
  291. typedef LatticeWeightTpl<float> Weight;
  292. NaturalLess() {}
  293. bool operator()(const Weight& w1, const Weight& w2) const {
  294. // NaturalLess is a negative order (opposite to normal ordering).
  295. // This operator () corresponds to "<" in the negative order, which
  296. // corresponds to the ">" in the normal order.
  297. return (Compare(w1, w2) == 1);
  298. }
  299. };
  300. template <>
  301. class NaturalLess<LatticeWeightTpl<double> > {
  302. public:
  303. typedef LatticeWeightTpl<double> Weight;
  304. NaturalLess() {}
  305. bool operator()(const Weight& w1, const Weight& w2) const {
  306. // NaturalLess is a negative order (opposite to normal ordering).
  307. // This operator () corresponds to "<" in the negative order, which
  308. // corresponds to the ">" in the normal order.
  309. return (Compare(w1, w2) == 1);
  310. }
  311. };
  312. template <class FloatType>
  313. inline LatticeWeightTpl<FloatType> Times(
  314. const LatticeWeightTpl<FloatType>& w1,
  315. const LatticeWeightTpl<FloatType>& w2) {
  316. return LatticeWeightTpl<FloatType>(w1.Value1() + w2.Value1(),
  317. w1.Value2() + w2.Value2());
  318. }
  319. // divide w1 by w2 (on left/right/any doesn't matter as
  320. // commutative).
  321. template <class FloatType>
  322. inline LatticeWeightTpl<FloatType> Divide(const LatticeWeightTpl<FloatType>& w1,
  323. const LatticeWeightTpl<FloatType>& w2,
  324. DivideType typ = DIVIDE_ANY) {
  325. typedef FloatType T;
  326. T a = w1.Value1() - w2.Value1(), b = w1.Value2() - w2.Value2();
  327. if (a != a || b != b || a == -std::numeric_limits<T>::infinity() ||
  328. b == -std::numeric_limits<T>::infinity()) {
  329. KALDI_WARN << "LatticeWeightTpl::Divide, NaN or invalid number produced. "
  330. << "[dividing by zero?] Returning zero";
  331. return LatticeWeightTpl<T>::Zero();
  332. }
  333. if (a == std::numeric_limits<T>::infinity() ||
  334. b == std::numeric_limits<T>::infinity())
  335. return LatticeWeightTpl<T>::Zero(); // not a valid number if only one is
  336. // infinite.
  337. return LatticeWeightTpl<T>(a, b);
  338. }
  339. template <class FloatType>
  340. inline bool ApproxEqual(const LatticeWeightTpl<FloatType>& w1,
  341. const LatticeWeightTpl<FloatType>& w2,
  342. float delta = kDelta) {
  343. if (w1.Value1() == w2.Value1() && w1.Value2() == w2.Value2())
  344. return true; // handles Zero().
  345. return (fabs((w1.Value1() + w1.Value2()) - (w2.Value1() + w2.Value2())) <=
  346. delta);
  347. }
  348. template <class FloatType>
  349. inline std::ostream& operator<<(std::ostream& strm,
  350. const LatticeWeightTpl<FloatType>& w) {
  351. LatticeWeightTpl<FloatType>::WriteFloatType(strm, w.Value1());
  352. CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
  353. strm << FLAGS_fst_weight_separator[0]; // comma by default;
  354. // may or may not be settable from Kaldi programs.
  355. LatticeWeightTpl<FloatType>::WriteFloatType(strm, w.Value2());
  356. return strm;
  357. }
  358. template <class FloatType>
  359. inline std::istream& operator>>(std::istream& strm,
  360. LatticeWeightTpl<FloatType>& w1) {
  361. CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
  362. // separator defaults to ','
  363. return w1.ReadNoParen(strm, FLAGS_fst_weight_separator[0]);
  364. }
  365. // CompactLattice will be an acceptor (accepting the words/output-symbols),
  366. // with the weights and input-symbol-seqs on the arcs.
  367. // There must be a total order on W. We assume for the sake of efficiency
  368. // that there is a function
  369. // Compare(W w1, W w2) that returns -1 if w1 < w2, +1 if w1 > w2, and
  370. // zero if w1 == w2, and Plus for type W returns (Compare(w1,w2) >= 0 ? w1 :
  371. // w2).
  372. template <class WeightType, class IntType>
  373. class CompactLatticeWeightTpl {
  374. public:
  375. typedef WeightType W;
  376. typedef CompactLatticeWeightTpl<WeightType, IntType> ReverseWeight;
  377. // Plus is like LexicographicWeight on the pair (weight_, string_), but where
  378. // we use standard lexicographic order on string_ [this is not the same as
  379. // NaturalLess on the StringWeight equivalent, which does not define a
  380. // total order].
  381. // Times, Divide obvious... (support both left & right division..)
  382. // CommonDivisor would need to be coded separately.
  383. CompactLatticeWeightTpl() {}
  384. CompactLatticeWeightTpl(const WeightType& w, const std::vector<IntType>& s)
  385. : weight_(w), string_(s) {}
  386. CompactLatticeWeightTpl& operator=(
  387. const CompactLatticeWeightTpl<WeightType, IntType>& w) {
  388. weight_ = w.weight_;
  389. string_ = w.string_;
  390. return *this;
  391. }
  392. const W& Weight() const { return weight_; }
  393. const std::vector<IntType>& String() const { return string_; }
  394. void SetWeight(const W& w) { weight_ = w; }
  395. void SetString(const std::vector<IntType>& s) { string_ = s; }
  396. static const CompactLatticeWeightTpl<WeightType, IntType> Zero() {
  397. return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::Zero(),
  398. std::vector<IntType>());
  399. }
  400. static const CompactLatticeWeightTpl<WeightType, IntType> One() {
  401. return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::One(),
  402. std::vector<IntType>());
  403. }
  404. inline static std::string GetIntSizeString() {
  405. char buf[2];
  406. buf[0] = '0' + sizeof(IntType);
  407. buf[1] = '\0';
  408. return buf;
  409. }
  410. static const std::string& Type() {
  411. static const std::string type =
  412. "compact" + WeightType::Type() + GetIntSizeString();
  413. return type;
  414. }
  415. static const CompactLatticeWeightTpl<WeightType, IntType> NoWeight() {
  416. return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::NoWeight(),
  417. std::vector<IntType>());
  418. }
  419. CompactLatticeWeightTpl<WeightType, IntType> Reverse() const {
  420. size_t s = string_.size();
  421. std::vector<IntType> v(s);
  422. for (size_t i = 0; i < s; i++) v[i] = string_[s - i - 1];
  423. return CompactLatticeWeightTpl<WeightType, IntType>(weight_, v);
  424. }
  425. bool Member() const {
  426. // a semiring has only one zero, this is the important property
  427. // we're trying to maintain here. So force string_ to be empty if
  428. // w_ == zero.
  429. if (!weight_.Member()) return false;
  430. if (weight_ == WeightType::Zero())
  431. return string_.empty();
  432. else
  433. return true;
  434. }
  435. CompactLatticeWeightTpl Quantize(float delta = kDelta) const {
  436. return CompactLatticeWeightTpl(weight_.Quantize(delta), string_);
  437. }
  438. static constexpr uint64 Properties() {
  439. return kLeftSemiring | kRightSemiring | kPath | kIdempotent;
  440. }
  441. // This is used in OpenFst for binary I/O. This is OpenFst-style,
  442. // not Kaldi-style, I/O.
  443. std::istream& Read(std::istream& strm) {
  444. weight_.Read(strm);
  445. if (strm.fail()) {
  446. return strm;
  447. }
  448. int32 sz;
  449. ReadType(strm, &sz);
  450. if (strm.fail()) {
  451. return strm;
  452. }
  453. if (sz < 0) {
  454. KALDI_WARN << "Negative string size! Read failure";
  455. strm.clear(std::ios::badbit);
  456. return strm;
  457. }
  458. string_.resize(sz);
  459. for (int32 i = 0; i < sz; i++) {
  460. ReadType(strm, &(string_[i]));
  461. }
  462. return strm;
  463. }
  464. // This is used in OpenFst for binary I/O. This is OpenFst-style,
  465. // not Kaldi-style, I/O.
  466. std::ostream& Write(std::ostream& strm) const {
  467. weight_.Write(strm);
  468. if (strm.fail()) {
  469. return strm;
  470. }
  471. int32 sz = static_cast<int32>(string_.size());
  472. WriteType(strm, sz);
  473. for (int32 i = 0; i < sz; i++) WriteType(strm, string_[i]);
  474. return strm;
  475. }
  476. size_t Hash() const {
  477. size_t ans = weight_.Hash();
  478. // any weird numbers here are largish primes
  479. size_t sz = string_.size(), mult = 6967;
  480. for (size_t i = 0; i < sz; i++) {
  481. ans += string_[i] * mult;
  482. mult *= 7499;
  483. }
  484. return ans;
  485. }
  486. private:
  487. W weight_;
  488. std::vector<IntType> string_;
  489. };
  490. template <class WeightType, class IntType>
  491. inline bool operator==(const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  492. const CompactLatticeWeightTpl<WeightType, IntType>& w2) {
  493. return (w1.Weight() == w2.Weight() && w1.String() == w2.String());
  494. }
  495. template <class WeightType, class IntType>
  496. inline bool operator!=(const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  497. const CompactLatticeWeightTpl<WeightType, IntType>& w2) {
  498. return (w1.Weight() != w2.Weight() || w1.String() != w2.String());
  499. }
  500. template <class WeightType, class IntType>
  501. inline bool ApproxEqual(const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  502. const CompactLatticeWeightTpl<WeightType, IntType>& w2,
  503. float delta = kDelta) {
  504. return (ApproxEqual(w1.Weight(), w2.Weight(), delta) &&
  505. w1.String() == w2.String());
  506. }
  507. // Compare is not part of the standard for weight types, but used internally for
  508. // efficiency. The comparison here first compares the weight; if this is the
  509. // same, it compares the string. The comparison on strings is: first compare
  510. // the length, if this is the same, use lexicographical order. We can't just
  511. // use the lexicographical order because this would destroy the distributive
  512. // property of multiplication over addition, taking into account that addition
  513. // uses Compare. The string element of "Compare" isn't super-important in
  514. // practical terms; it's only needed to ensure that Plus always give consistent
  515. // answers and is symmetric. It's essentially for tie-breaking, but we need to
  516. // make sure all the semiring axioms are satisfied otherwise OpenFst might
  517. // break.
  518. template <class WeightType, class IntType>
  519. inline int Compare(const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  520. const CompactLatticeWeightTpl<WeightType, IntType>& w2) {
  521. int c1 = Compare(w1.Weight(), w2.Weight());
  522. if (c1 != 0) return c1;
  523. int l1 = w1.String().size(), l2 = w2.String().size();
  524. // Use opposite order on the string lengths, so that if the costs are the
  525. // same, the shorter string wins.
  526. if (l1 > l2)
  527. return -1;
  528. else if (l1 < l2)
  529. return 1;
  530. for (int i = 0; i < l1; i++) {
  531. if (w1.String()[i] < w2.String()[i])
  532. return -1;
  533. else if (w1.String()[i] > w2.String()[i])
  534. return 1;
  535. }
  536. return 0;
  537. }
  538. // For efficiency, override the NaturalLess template class.
  539. template <class FloatType, class IntType>
  540. class NaturalLess<
  541. CompactLatticeWeightTpl<LatticeWeightTpl<FloatType>, IntType> > {
  542. public:
  543. typedef CompactLatticeWeightTpl<LatticeWeightTpl<FloatType>, IntType> Weight;
  544. NaturalLess() {}
  545. bool operator()(const Weight& w1, const Weight& w2) const {
  546. // NaturalLess is a negative order (opposite to normal ordering).
  547. // This operator () corresponds to "<" in the negative order, which
  548. // corresponds to the ">" in the normal order.
  549. return (Compare(w1, w2) == 1);
  550. }
  551. };
  552. template <>
  553. class NaturalLess<CompactLatticeWeightTpl<LatticeWeightTpl<float>, int32> > {
  554. public:
  555. typedef CompactLatticeWeightTpl<LatticeWeightTpl<float>, int32> Weight;
  556. NaturalLess() {}
  557. bool operator()(const Weight& w1, const Weight& w2) const {
  558. // NaturalLess is a negative order (opposite to normal ordering).
  559. // This operator () corresponds to "<" in the negative order, which
  560. // corresponds to the ">" in the normal order.
  561. return (Compare(w1, w2) == 1);
  562. }
  563. };
  564. template <>
  565. class NaturalLess<CompactLatticeWeightTpl<LatticeWeightTpl<double>, int32> > {
  566. public:
  567. typedef CompactLatticeWeightTpl<LatticeWeightTpl<double>, int32> Weight;
  568. NaturalLess() {}
  569. bool operator()(const Weight& w1, const Weight& w2) const {
  570. // NaturalLess is a negative order (opposite to normal ordering).
  571. // This operator () corresponds to "<" in the negative order, which
  572. // corresponds to the ">" in the normal order.
  573. return (Compare(w1, w2) == 1);
  574. }
  575. };
  576. // Make sure Compare is defined for TropicalWeight, so everything works
  577. // if we substitute LatticeWeight for TropicalWeight.
  578. inline int Compare(const TropicalWeight& w1, const TropicalWeight& w2) {
  579. float f1 = w1.Value(), f2 = w2.Value();
  580. if (f1 == f2)
  581. return 0;
  582. else if (f1 > f2)
  583. return -1;
  584. else
  585. return 1;
  586. }
  587. template <class WeightType, class IntType>
  588. inline CompactLatticeWeightTpl<WeightType, IntType> Plus(
  589. const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  590. const CompactLatticeWeightTpl<WeightType, IntType>& w2) {
  591. return (Compare(w1, w2) >= 0 ? w1 : w2);
  592. }
  593. template <class WeightType, class IntType>
  594. inline CompactLatticeWeightTpl<WeightType, IntType> Times(
  595. const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  596. const CompactLatticeWeightTpl<WeightType, IntType>& w2) {
  597. WeightType w = Times(w1.Weight(), w2.Weight());
  598. if (w == WeightType::Zero()) {
  599. return CompactLatticeWeightTpl<WeightType, IntType>::Zero();
  600. // special case to ensure zero is unique
  601. } else {
  602. std::vector<IntType> v;
  603. v.resize(w1.String().size() + w2.String().size());
  604. typename std::vector<IntType>::iterator iter = v.begin();
  605. iter = std::copy(w1.String().begin(), w1.String().end(),
  606. iter); // returns end of first range.
  607. std::copy(w2.String().begin(), w2.String().end(), iter);
  608. return CompactLatticeWeightTpl<WeightType, IntType>(w, v);
  609. }
  610. }
  611. template <class WeightType, class IntType>
  612. inline CompactLatticeWeightTpl<WeightType, IntType> Divide(
  613. const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  614. const CompactLatticeWeightTpl<WeightType, IntType>& w2,
  615. DivideType div = DIVIDE_ANY) {
  616. if (w1.Weight() == WeightType::Zero()) {
  617. if (w2.Weight() != WeightType::Zero()) {
  618. return CompactLatticeWeightTpl<WeightType, IntType>::Zero();
  619. } else {
  620. KALDI_ERR << "Division by zero [0/0]";
  621. }
  622. } else if (w2.Weight() == WeightType::Zero()) {
  623. KALDI_ERR << "Error: division by zero";
  624. }
  625. WeightType w = Divide(w1.Weight(), w2.Weight());
  626. const std::vector<IntType> v1 = w1.String(), v2 = w2.String();
  627. if (v2.size() > v1.size()) {
  628. KALDI_ERR << "Cannot divide, length mismatch";
  629. }
  630. typename std::vector<IntType>::const_iterator v1b = v1.begin(),
  631. v1e = v1.end(),
  632. v2b = v2.begin(),
  633. v2e = v2.end();
  634. if (div == DIVIDE_LEFT) {
  635. if (!std::equal(v2b, v2e,
  636. v1b)) { // v2 must be identical to first part of v1.
  637. KALDI_ERR << "Cannot divide, data mismatch";
  638. }
  639. return CompactLatticeWeightTpl<WeightType, IntType>(
  640. w, std::vector<IntType>(v1b + (v2e - v2b),
  641. v1e)); // return last part of v1.
  642. } else if (div == DIVIDE_RIGHT) {
  643. if (!std::equal(
  644. v2b, v2e,
  645. v1e - (v2e - v2b))) { // v2 must be identical to last part of v1.
  646. KALDI_ERR << "Cannot divide, data mismatch";
  647. }
  648. return CompactLatticeWeightTpl<WeightType, IntType>(
  649. w, std::vector<IntType>(
  650. v1b, v1e - (v2e - v2b))); // return first part of v1.
  651. } else {
  652. KALDI_ERR << "Cannot divide CompactLatticeWeightTpl with DIVIDE_ANY";
  653. }
  654. return CompactLatticeWeightTpl<WeightType,
  655. IntType>::Zero(); // keep compiler happy.
  656. }
  657. template <class WeightType, class IntType>
  658. inline std::ostream& operator<<(
  659. std::ostream& strm, const CompactLatticeWeightTpl<WeightType, IntType>& w) {
  660. strm << w.Weight();
  661. CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
  662. strm << FLAGS_fst_weight_separator[0]; // comma by default.
  663. for (size_t i = 0; i < w.String().size(); i++) {
  664. strm << w.String()[i];
  665. if (i + 1 < w.String().size())
  666. strm << kStringSeparator; // '_'; defined in string-weight.h in OpenFst
  667. // code.
  668. }
  669. return strm;
  670. }
  671. template <class WeightType, class IntType>
  672. inline std::istream& operator>>(
  673. std::istream& strm, CompactLatticeWeightTpl<WeightType, IntType>& w) {
  674. std::string s;
  675. strm >> s;
  676. if (strm.fail()) {
  677. return strm;
  678. }
  679. CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
  680. size_t pos = s.find_last_of(FLAGS_fst_weight_separator); // normally ","
  681. if (pos == std::string::npos) {
  682. strm.clear(std::ios::badbit);
  683. return strm;
  684. }
  685. // get parts of str before and after the separator (default: ',');
  686. std::string s1(s, 0, pos), s2(s, pos + 1);
  687. std::istringstream strm1(s1);
  688. WeightType weight;
  689. strm1 >> weight;
  690. w.SetWeight(weight);
  691. if (strm1.fail() || !strm1.eof()) {
  692. strm.clear(std::ios::badbit);
  693. return strm;
  694. }
  695. // read string part.
  696. std::vector<IntType> string;
  697. const char* c = s2.c_str();
  698. while (*c != '\0') {
  699. if (*c == kStringSeparator) // '_'
  700. c++;
  701. char* c2;
  702. int64_t i = strtol(c, &c2, 10);
  703. if (c2 == c || static_cast<int64_t>(static_cast<IntType>(i)) != i) {
  704. strm.clear(std::ios::badbit);
  705. return strm;
  706. }
  707. c = c2;
  708. string.push_back(static_cast<IntType>(i));
  709. }
  710. w.SetString(string);
  711. return strm;
  712. }
  713. template <class BaseWeightType, class IntType>
  714. class CompactLatticeWeightCommonDivisorTpl {
  715. public:
  716. typedef CompactLatticeWeightTpl<BaseWeightType, IntType> Weight;
  717. Weight operator()(const Weight& w1, const Weight& w2) const {
  718. // First find longest common prefix of the strings.
  719. typename std::vector<IntType>::const_iterator s1b = w1.String().begin(),
  720. s1e = w1.String().end(),
  721. s2b = w2.String().begin(),
  722. s2e = w2.String().end();
  723. while (s1b < s1e && s2b < s2e && *s1b == *s2b) {
  724. s1b++;
  725. s2b++;
  726. }
  727. return Weight(Plus(w1.Weight(), w2.Weight()),
  728. std::vector<IntType>(w1.String().begin(), s1b));
  729. }
  730. };
  731. /** Scales the pair (a, b) of floating-point weights inside a
  732. CompactLatticeWeight by premultiplying it (viewed as a vector)
  733. by a 2x2 matrix "scale".
  734. Assumes there is a ScaleTupleWeight function that applies to "Weight";
  735. this currently only works if Weight equals LatticeWeightTpl<FloatType>
  736. for some FloatType.
  737. */
  738. template <class Weight, class IntType, class ScaleFloatType>
  739. inline CompactLatticeWeightTpl<Weight, IntType> ScaleTupleWeight(
  740. const CompactLatticeWeightTpl<Weight, IntType>& w,
  741. const std::vector<std::vector<ScaleFloatType> >& scale) {
  742. return CompactLatticeWeightTpl<Weight, IntType>(
  743. Weight(ScaleTupleWeight(w.Weight(), scale)), w.String());
  744. }
  745. /** Define some ConvertLatticeWeight functions that are used in various lattice
  746. conversions... make them all templates, some with no arguments, since some
  747. must be templates.*/
  748. template <class Float1, class Float2>
  749. inline void ConvertLatticeWeight(const LatticeWeightTpl<Float1>& w_in,
  750. LatticeWeightTpl<Float2>* w_out) {
  751. w_out->SetValue1(w_in.Value1());
  752. w_out->SetValue2(w_in.Value2());
  753. }
  754. template <class Float1, class Float2, class Int>
  755. inline void ConvertLatticeWeight(
  756. const CompactLatticeWeightTpl<LatticeWeightTpl<Float1>, Int>& w_in,
  757. CompactLatticeWeightTpl<LatticeWeightTpl<Float2>, Int>* w_out) {
  758. LatticeWeightTpl<Float2> weight2(w_in.Weight().Value1(),
  759. w_in.Weight().Value2());
  760. w_out->SetWeight(weight2);
  761. w_out->SetString(w_in.String());
  762. }
  763. // to convert from Lattice to standard FST
  764. template <class Float1, class Float2>
  765. inline void ConvertLatticeWeight(const LatticeWeightTpl<Float1>& w_in,
  766. TropicalWeightTpl<Float2>* w_out) {
  767. TropicalWeightTpl<Float2> w1(w_in.Value1());
  768. TropicalWeightTpl<Float2> w2(w_in.Value2());
  769. *w_out = Times(w1, w2);
  770. }
  771. template <class Float>
  772. inline double ConvertToCost(const LatticeWeightTpl<Float>& w) {
  773. return static_cast<double>(w.Value1()) + static_cast<double>(w.Value2());
  774. }
  775. template <class Float, class Int>
  776. inline double ConvertToCost(
  777. const CompactLatticeWeightTpl<LatticeWeightTpl<Float>, Int>& w) {
  778. return static_cast<double>(w.Weight().Value1()) +
  779. static_cast<double>(w.Weight().Value2());
  780. }
  781. template <class Float>
  782. inline double ConvertToCost(const TropicalWeightTpl<Float>& w) {
  783. return w.Value();
  784. }
  785. } // namespace fst
  786. #endif // KALDI_FSTEXT_LATTICE_WEIGHT_H_