You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

892 lines
31 KiB

  1. // fstext/lattice-weight.h
  2. // Copyright 2009-2012 Microsoft Corporation
  3. // Johns Hopkins University (author: Daniel Povey)
  4. // See ../../COPYING for clarification regarding multiple authors
  5. //
  6. // Licensed under the Apache License, Version 2.0 (the "License");
  7. // you may not use this file except in compliance with the License.
  8. // You may obtain a copy of the License at
  9. //
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. //
  12. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  13. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  14. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  15. // MERCHANTABLITY OR NON-INFRINGEMENT.
  16. // See the Apache 2 License for the specific language governing permissions and
  17. // limitations under the License.
  18. #ifndef KALDI_FSTEXT_LATTICE_WEIGHT_H_
  19. #define KALDI_FSTEXT_LATTICE_WEIGHT_H_
  #include <algorithm>
  #include <cctype>
  #include <cerrno>
  #include <cmath>
  #include <cstdlib>
  #include <cstring>
  #include <limits>
  #include <sstream>
  #include <string>
  #include <vector>

  #include "base/kaldi-types.h"
  #include "base/kaldi-common.h"
  #include "fst/fstlib.h"
  27. using namespace kaldi_type;
  28. namespace fst {
  29. // Declare weight type for lattice... will import to namespace kaldi. has two
  30. // members, value1_ and value2_, of type BaseFloat (normally equals float). It
  31. // is basically the same as the tropical semiring on value1_+value2_, except it
  32. // keeps track of a and b separately. More precisely, it is equivalent to the
  33. // lexicographic semiring on (value1_+value2_), (value1_-value2_)
  34. template <class FloatType>
  35. class LatticeWeightTpl;
  36. template <class FloatType>
  37. inline std::ostream& operator<<(std::ostream& strm,
  38. const LatticeWeightTpl<FloatType>& w);
  39. template <class FloatType>
  40. inline std::istream& operator>>(std::istream& strm,
  41. LatticeWeightTpl<FloatType>& w);
  42. template <class FloatType>
  43. class LatticeWeightTpl {
  44. public:
  45. typedef FloatType T; // normally float.
  46. typedef LatticeWeightTpl ReverseWeight;
  47. inline T Value1() const { return value1_; }
  48. inline T Value2() const { return value2_; }
  49. inline void SetValue1(T f) { value1_ = f; }
  50. inline void SetValue2(T f) { value2_ = f; }
  51. LatticeWeightTpl() : value1_{}, value2_{} {}
  52. LatticeWeightTpl(T a, T b) : value1_(a), value2_(b) {}
  53. LatticeWeightTpl(const LatticeWeightTpl& other)
  54. : value1_(other.value1_), value2_(other.value2_) {}
  55. LatticeWeightTpl& operator=(const LatticeWeightTpl& w) {
  56. value1_ = w.value1_;
  57. value2_ = w.value2_;
  58. return *this;
  59. }
  60. LatticeWeightTpl<FloatType> Reverse() const { return *this; }
  61. static const LatticeWeightTpl Zero() {
  62. return LatticeWeightTpl(std::numeric_limits<T>::infinity(),
  63. std::numeric_limits<T>::infinity());
  64. }
  65. static const LatticeWeightTpl One() { return LatticeWeightTpl(0.0, 0.0); }
  66. static const std::string& Type() {
  67. static const std::string type = (sizeof(T) == 4 ? "lattice4" : "lattice8");
  68. return type;
  69. }
  70. static const LatticeWeightTpl NoWeight() {
  71. return LatticeWeightTpl(std::numeric_limits<FloatType>::quiet_NaN(),
  72. std::numeric_limits<FloatType>::quiet_NaN());
  73. }
  74. bool Member() const {
  75. // value1_ == value1_ tests for NaN.
  76. // also test for no -inf, and either both or neither
  77. // must be +inf, and
  78. if (value1_ != value1_ || value2_ != value2_) return false; // NaN
  79. if (value1_ == -std::numeric_limits<T>::infinity() ||
  80. value2_ == -std::numeric_limits<T>::infinity())
  81. return false; // -infty not allowed
  82. if (value1_ == std::numeric_limits<T>::infinity() ||
  83. value2_ == std::numeric_limits<T>::infinity()) {
  84. if (value1_ != std::numeric_limits<T>::infinity() ||
  85. value2_ != std::numeric_limits<T>::infinity())
  86. return false; // both must be +infty;
  87. // this is necessary so that the semiring has only one zero.
  88. }
  89. return true;
  90. }
  91. LatticeWeightTpl Quantize(float delta = kDelta) const {
  92. if (value1_ + value2_ == -std::numeric_limits<T>::infinity()) {
  93. return LatticeWeightTpl(-std::numeric_limits<T>::infinity(),
  94. -std::numeric_limits<T>::infinity());
  95. } else if (value1_ + value2_ == std::numeric_limits<T>::infinity()) {
  96. return LatticeWeightTpl(std::numeric_limits<T>::infinity(),
  97. std::numeric_limits<T>::infinity());
  98. } else if (value1_ + value2_ != value1_ + value2_) { // NaN
  99. return LatticeWeightTpl(value1_ + value2_, value1_ + value2_);
  100. } else {
  101. return LatticeWeightTpl(floor(value1_ / delta + 0.5F) * delta,
  102. floor(value2_ / delta + 0.5F) * delta);
  103. }
  104. }
  105. static constexpr uint64 Properties() {
  106. return kLeftSemiring | kRightSemiring | kCommutative | kPath | kIdempotent;
  107. }
  108. // This is used in OpenFst for binary I/O. This is OpenFst-style,
  109. // not Kaldi-style, I/O.
  110. std::istream& Read(std::istream& strm) {
  111. // Always read/write as float, even if T is double,
  112. // so we can use OpenFst-style read/write and still maintain
  113. // compatibility when compiling with different FloatTypes
  114. ReadType(strm, &value1_);
  115. ReadType(strm, &value2_);
  116. return strm;
  117. }
  118. // This is used in OpenFst for binary I/O. This is OpenFst-style,
  119. // not Kaldi-style, I/O.
  120. std::ostream& Write(std::ostream& strm) const {
  121. WriteType(strm, value1_);
  122. WriteType(strm, value2_);
  123. return strm;
  124. }
  125. size_t Hash() const {
  126. size_t ans;
  127. union {
  128. T f;
  129. size_t s;
  130. } u;
  131. u.s = 0;
  132. u.f = value1_;
  133. ans = u.s;
  134. u.f = value2_;
  135. ans += u.s;
  136. return ans;
  137. }
  138. protected:
  139. inline static void WriteFloatType(std::ostream& strm, const T& f) {
  140. if (f == std::numeric_limits<T>::infinity())
  141. strm << "Infinity";
  142. else if (f == -std::numeric_limits<T>::infinity())
  143. strm << "-Infinity";
  144. else if (f != f)
  145. strm << "BadNumber";
  146. else
  147. strm << f;
  148. }
  149. // Internal helper function, used in ReadNoParen.
  150. inline static void ReadFloatType(std::istream& strm, T& f) { // NOLINT
  151. std::string s;
  152. strm >> s;
  153. if (s == "Infinity") {
  154. f = std::numeric_limits<T>::infinity();
  155. } else if (s == "-Infinity") {
  156. f = -std::numeric_limits<T>::infinity();
  157. } else if (s == "BadNumber") {
  158. f = std::numeric_limits<T>::quiet_NaN();
  159. } else {
  160. char* p;
  161. f = strtod(s.c_str(), &p);
  162. if (p < s.c_str() + s.size()) strm.clear(std::ios::badbit);
  163. }
  164. }
  165. // Reads LatticeWeight when there are no parentheses around pair terms...
  166. // currently the only form supported.
  167. inline std::istream& ReadNoParen(std::istream& strm, char separator) {
  168. int c;
  169. do {
  170. c = strm.get();
  171. } while (isspace(c));
  172. std::string s1;
  173. while (c != separator) {
  174. if (c == EOF) {
  175. strm.clear(std::ios::badbit);
  176. return strm;
  177. }
  178. s1 += c;
  179. c = strm.get();
  180. }
  181. std::istringstream strm1(s1);
  182. ReadFloatType(strm1, value1_); // ReadFloatType is class member function
  183. // read second element
  184. ReadFloatType(strm, value2_);
  185. return strm;
  186. }
  187. friend std::istream& operator>>
  188. <FloatType>(std::istream&, LatticeWeightTpl<FloatType>&);
  189. friend std::ostream& operator<< <FloatType>(
  190. std::ostream&, const LatticeWeightTpl<FloatType>&);
  191. private:
  192. T value1_;
  193. T value2_;
  194. };
  195. /* ScaleTupleWeight is a function defined for LatticeWeightTpl and
  196. CompactLatticeWeightTpl that mutliplies the pair (value1_, value2_) by a 2x2
  197. matrix. Used, for example, in applying acoustic scaling.
  198. */
  199. template <class FloatType, class ScaleFloatType>
  200. inline LatticeWeightTpl<FloatType> ScaleTupleWeight(
  201. const LatticeWeightTpl<FloatType>& w,
  202. const std::vector<std::vector<ScaleFloatType> >& scale) {
  203. // Without the next special case we'd get NaNs from infinity * 0
  204. if (w.Value1() == std::numeric_limits<FloatType>::infinity())
  205. return LatticeWeightTpl<FloatType>::Zero();
  206. return LatticeWeightTpl<FloatType>(
  207. scale[0][0] * w.Value1() + scale[0][1] * w.Value2(),
  208. scale[1][0] * w.Value1() + scale[1][1] * w.Value2());
  209. }
  210. /* For testing purposes and in case it's ever useful, we define a similar
  211. function to apply to LexicographicWeight and the like, templated on
  212. TropicalWeight<float> etc.; we use PairWeight which is the base class of
  213. LexicographicWeight.
  214. */
  215. template <class FloatType, class ScaleFloatType>
  216. inline PairWeight<TropicalWeightTpl<FloatType>, TropicalWeightTpl<FloatType> >
  217. ScaleTupleWeight(const PairWeight<TropicalWeightTpl<FloatType>,
  218. TropicalWeightTpl<FloatType> >& w,
  219. const std::vector<std::vector<ScaleFloatType> >& scale) {
  220. typedef TropicalWeightTpl<FloatType> BaseType;
  221. typedef PairWeight<BaseType, BaseType> PairType;
  222. const BaseType zero = BaseType::Zero();
  223. // Without the next special case we'd get NaNs from infinity * 0
  224. if (w.Value1() == zero || w.Value2() == zero) return PairType(zero, zero);
  225. FloatType f1 = w.Value1().Value(), f2 = w.Value2().Value();
  226. return PairType(BaseType(scale[0][0] * f1 + scale[0][1] * f2),
  227. BaseType(scale[1][0] * f1 + scale[1][1] * f2));
  228. }
  229. template <class FloatType>
  230. inline bool operator==(const LatticeWeightTpl<FloatType>& wa,
  231. const LatticeWeightTpl<FloatType>& wb) {
  232. // Volatile qualifier thwarts over-aggressive compiler optimizations
  233. // that lead to problems esp. with NaturalLess().
  234. volatile FloatType va1 = wa.Value1(), va2 = wa.Value2(), vb1 = wb.Value1(),
  235. vb2 = wb.Value2();
  236. return (va1 == vb1 && va2 == vb2);
  237. }
  238. template <class FloatType>
  239. inline bool operator!=(const LatticeWeightTpl<FloatType>& wa,
  240. const LatticeWeightTpl<FloatType>& wb) {
  241. // Volatile qualifier thwarts over-aggressive compiler optimizations
  242. // that lead to problems esp. with NaturalLess().
  243. volatile FloatType va1 = wa.Value1(), va2 = wa.Value2(), vb1 = wb.Value1(),
  244. vb2 = wb.Value2();
  245. return (va1 != vb1 || va2 != vb2);
  246. }
  247. // We define a Compare function LatticeWeightTpl even though it's
  248. // not required by the semiring standard-- it's just more efficient
  249. // to do it this way rather than using the NaturalLess template.
  250. /// Compare returns -1 if w1 < w2, +1 if w1 > w2, and 0 if w1 == w2.
  251. template <class FloatType>
  252. inline int Compare(const LatticeWeightTpl<FloatType>& w1,
  253. const LatticeWeightTpl<FloatType>& w2) {
  254. FloatType f1 = w1.Value1() + w1.Value2(), f2 = w2.Value1() + w2.Value2();
  255. if (f1 < f2) { // having smaller cost means you're larger
  256. return 1;
  257. } else if (f1 > f2) { // in the semiring [higher probability]
  258. return -1;
  259. } else if (w1.Value1() < w2.Value1()) {
  260. // mathematically we should be comparing (w1.value1_-w1.value2_ <
  261. // w2.value1_-w2.value2_) in the next line, but add w1.value1_+w1.value2_ =
  262. // w2.value1_+w2.value2_ to both sides and divide by two, and we get the
  263. // simpler equivalent form w1.value1_ < w2.value1_.
  264. return 1;
  265. } else if (w1.Value1() > w2.Value1()) {
  266. return -1;
  267. } else {
  268. return 0;
  269. }
  270. }
  271. template <class FloatType>
  272. inline LatticeWeightTpl<FloatType> Plus(const LatticeWeightTpl<FloatType>& w1,
  273. const LatticeWeightTpl<FloatType>& w2) {
  274. return (Compare(w1, w2) >= 0 ? w1 : w2);
  275. }
  276. // For efficiency, override the NaturalLess template class.
  277. template <class FloatType>
  278. class NaturalLess<LatticeWeightTpl<FloatType> > {
  279. public:
  280. typedef LatticeWeightTpl<FloatType> Weight;
  281. NaturalLess() {}
  282. bool operator()(const Weight& w1, const Weight& w2) const {
  283. // NaturalLess is a negative order (opposite to normal ordering).
  284. // This operator () corresponds to "<" in the negative order, which
  285. // corresponds to the ">" in the normal order.
  286. return (Compare(w1, w2) == 1);
  287. }
  288. };
  289. template <>
  290. class NaturalLess<LatticeWeightTpl<float> > {
  291. public:
  292. typedef LatticeWeightTpl<float> Weight;
  293. NaturalLess() {}
  294. bool operator()(const Weight& w1, const Weight& w2) const {
  295. // NaturalLess is a negative order (opposite to normal ordering).
  296. // This operator () corresponds to "<" in the negative order, which
  297. // corresponds to the ">" in the normal order.
  298. return (Compare(w1, w2) == 1);
  299. }
  300. };
  301. template <>
  302. class NaturalLess<LatticeWeightTpl<double> > {
  303. public:
  304. typedef LatticeWeightTpl<double> Weight;
  305. NaturalLess() {}
  306. bool operator()(const Weight& w1, const Weight& w2) const {
  307. // NaturalLess is a negative order (opposite to normal ordering).
  308. // This operator () corresponds to "<" in the negative order, which
  309. // corresponds to the ">" in the normal order.
  310. return (Compare(w1, w2) == 1);
  311. }
  312. };
  313. template <class FloatType>
  314. inline LatticeWeightTpl<FloatType> Times(
  315. const LatticeWeightTpl<FloatType>& w1,
  316. const LatticeWeightTpl<FloatType>& w2) {
  317. return LatticeWeightTpl<FloatType>(w1.Value1() + w2.Value1(),
  318. w1.Value2() + w2.Value2());
  319. }
  320. // divide w1 by w2 (on left/right/any doesn't matter as
  321. // commutative).
  322. template <class FloatType>
  323. inline LatticeWeightTpl<FloatType> Divide(const LatticeWeightTpl<FloatType>& w1,
  324. const LatticeWeightTpl<FloatType>& w2,
  325. DivideType typ = DIVIDE_ANY) {
  326. typedef FloatType T;
  327. T a = w1.Value1() - w2.Value1(), b = w1.Value2() - w2.Value2();
  328. if (a != a || b != b || a == -std::numeric_limits<T>::infinity() ||
  329. b == -std::numeric_limits<T>::infinity()) {
  330. KALDI_WARN << "LatticeWeightTpl::Divide, NaN or invalid number produced. "
  331. << "[dividing by zero?] Returning zero";
  332. return LatticeWeightTpl<T>::Zero();
  333. }
  334. if (a == std::numeric_limits<T>::infinity() ||
  335. b == std::numeric_limits<T>::infinity())
  336. return LatticeWeightTpl<T>::Zero(); // not a valid number if only one is
  337. // infinite.
  338. return LatticeWeightTpl<T>(a, b);
  339. }
  340. template <class FloatType>
  341. inline bool ApproxEqual(const LatticeWeightTpl<FloatType>& w1,
  342. const LatticeWeightTpl<FloatType>& w2,
  343. float delta = kDelta) {
  344. if (w1.Value1() == w2.Value1() && w1.Value2() == w2.Value2())
  345. return true; // handles Zero().
  346. return (fabs((w1.Value1() + w1.Value2()) - (w2.Value1() + w2.Value2())) <=
  347. delta);
  348. }
  349. template <class FloatType>
  350. inline std::ostream& operator<<(std::ostream& strm,
  351. const LatticeWeightTpl<FloatType>& w) {
  352. LatticeWeightTpl<FloatType>::WriteFloatType(strm, w.Value1());
  353. CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
  354. strm << FLAGS_fst_weight_separator[0]; // comma by default;
  355. // may or may not be settable from Kaldi programs.
  356. LatticeWeightTpl<FloatType>::WriteFloatType(strm, w.Value2());
  357. return strm;
  358. }
  359. template <class FloatType>
  360. inline std::istream& operator>>(std::istream& strm,
  361. LatticeWeightTpl<FloatType>& w1) {
  362. CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
  363. // separator defaults to ','
  364. return w1.ReadNoParen(strm, FLAGS_fst_weight_separator[0]);
  365. }
  366. // CompactLattice will be an acceptor (accepting the words/output-symbols),
  367. // with the weights and input-symbol-seqs on the arcs.
  368. // There must be a total order on W. We assume for the sake of efficiency
  369. // that there is a function
  370. // Compare(W w1, W w2) that returns -1 if w1 < w2, +1 if w1 > w2, and
  371. // zero if w1 == w2, and Plus for type W returns (Compare(w1,w2) >= 0 ? w1 :
  372. // w2).
  373. template <class WeightType, class IntType>
  374. class CompactLatticeWeightTpl {
  375. public:
  376. typedef WeightType W;
  377. typedef CompactLatticeWeightTpl<WeightType, IntType> ReverseWeight;
  378. // Plus is like LexicographicWeight on the pair (weight_, string_), but where
  379. // we use standard lexicographic order on string_ [this is not the same as
  380. // NaturalLess on the StringWeight equivalent, which does not define a
  381. // total order].
  382. // Times, Divide obvious... (support both left & right division..)
  383. // CommonDivisor would need to be coded separately.
  384. CompactLatticeWeightTpl() {}
  385. CompactLatticeWeightTpl(const WeightType& w, const std::vector<IntType>& s)
  386. : weight_(w), string_(s) {}
  387. CompactLatticeWeightTpl& operator=(
  388. const CompactLatticeWeightTpl<WeightType, IntType>& w) {
  389. weight_ = w.weight_;
  390. string_ = w.string_;
  391. return *this;
  392. }
  393. const W& Weight() const { return weight_; }
  394. const std::vector<IntType>& String() const { return string_; }
  395. void SetWeight(const W& w) { weight_ = w; }
  396. void SetString(const std::vector<IntType>& s) { string_ = s; }
  397. static const CompactLatticeWeightTpl<WeightType, IntType> Zero() {
  398. return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::Zero(),
  399. std::vector<IntType>());
  400. }
  401. static const CompactLatticeWeightTpl<WeightType, IntType> One() {
  402. return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::One(),
  403. std::vector<IntType>());
  404. }
  405. inline static std::string GetIntSizeString() {
  406. char buf[2];
  407. buf[0] = '0' + sizeof(IntType);
  408. buf[1] = '\0';
  409. return buf;
  410. }
  411. static const std::string& Type() {
  412. static const std::string type =
  413. "compact" + WeightType::Type() + GetIntSizeString();
  414. return type;
  415. }
  416. static const CompactLatticeWeightTpl<WeightType, IntType> NoWeight() {
  417. return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::NoWeight(),
  418. std::vector<IntType>());
  419. }
  420. CompactLatticeWeightTpl<WeightType, IntType> Reverse() const {
  421. size_t s = string_.size();
  422. std::vector<IntType> v(s);
  423. for (size_t i = 0; i < s; i++) v[i] = string_[s - i - 1];
  424. return CompactLatticeWeightTpl<WeightType, IntType>(weight_, v);
  425. }
  426. bool Member() const {
  427. // a semiring has only one zero, this is the important property
  428. // we're trying to maintain here. So force string_ to be empty if
  429. // w_ == zero.
  430. if (!weight_.Member()) return false;
  431. if (weight_ == WeightType::Zero())
  432. return string_.empty();
  433. else
  434. return true;
  435. }
  436. CompactLatticeWeightTpl Quantize(float delta = kDelta) const {
  437. return CompactLatticeWeightTpl(weight_.Quantize(delta), string_);
  438. }
  439. static constexpr uint64 Properties() {
  440. return kLeftSemiring | kRightSemiring | kPath | kIdempotent;
  441. }
  442. // This is used in OpenFst for binary I/O. This is OpenFst-style,
  443. // not Kaldi-style, I/O.
  444. std::istream& Read(std::istream& strm) {
  445. weight_.Read(strm);
  446. if (strm.fail()) {
  447. return strm;
  448. }
  449. int32 sz;
  450. ReadType(strm, &sz);
  451. if (strm.fail()) {
  452. return strm;
  453. }
  454. if (sz < 0) {
  455. KALDI_WARN << "Negative string size! Read failure";
  456. strm.clear(std::ios::badbit);
  457. return strm;
  458. }
  459. string_.resize(sz);
  460. for (int32 i = 0; i < sz; i++) {
  461. ReadType(strm, &(string_[i]));
  462. }
  463. return strm;
  464. }
  465. // This is used in OpenFst for binary I/O. This is OpenFst-style,
  466. // not Kaldi-style, I/O.
  467. std::ostream& Write(std::ostream& strm) const {
  468. weight_.Write(strm);
  469. if (strm.fail()) {
  470. return strm;
  471. }
  472. int32 sz = static_cast<int32>(string_.size());
  473. WriteType(strm, sz);
  474. for (int32 i = 0; i < sz; i++) WriteType(strm, string_[i]);
  475. return strm;
  476. }
  477. size_t Hash() const {
  478. size_t ans = weight_.Hash();
  479. // any weird numbers here are largish primes
  480. size_t sz = string_.size(), mult = 6967;
  481. for (size_t i = 0; i < sz; i++) {
  482. ans += string_[i] * mult;
  483. mult *= 7499;
  484. }
  485. return ans;
  486. }
  487. private:
  488. W weight_;
  489. std::vector<IntType> string_;
  490. };
  491. template <class WeightType, class IntType>
  492. inline bool operator==(const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  493. const CompactLatticeWeightTpl<WeightType, IntType>& w2) {
  494. return (w1.Weight() == w2.Weight() && w1.String() == w2.String());
  495. }
  496. template <class WeightType, class IntType>
  497. inline bool operator!=(const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  498. const CompactLatticeWeightTpl<WeightType, IntType>& w2) {
  499. return (w1.Weight() != w2.Weight() || w1.String() != w2.String());
  500. }
  501. template <class WeightType, class IntType>
  502. inline bool ApproxEqual(const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  503. const CompactLatticeWeightTpl<WeightType, IntType>& w2,
  504. float delta = kDelta) {
  505. return (ApproxEqual(w1.Weight(), w2.Weight(), delta) &&
  506. w1.String() == w2.String());
  507. }
  508. // Compare is not part of the standard for weight types, but used internally for
  509. // efficiency. The comparison here first compares the weight; if this is the
  510. // same, it compares the string. The comparison on strings is: first compare
  511. // the length, if this is the same, use lexicographical order. We can't just
  512. // use the lexicographical order because this would destroy the distributive
  513. // property of multiplication over addition, taking into account that addition
  514. // uses Compare. The string element of "Compare" isn't super-important in
  515. // practical terms; it's only needed to ensure that Plus always give consistent
  516. // answers and is symmetric. It's essentially for tie-breaking, but we need to
  517. // make sure all the semiring axioms are satisfied otherwise OpenFst might
  518. // break.
  519. template <class WeightType, class IntType>
  520. inline int Compare(const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  521. const CompactLatticeWeightTpl<WeightType, IntType>& w2) {
  522. int c1 = Compare(w1.Weight(), w2.Weight());
  523. if (c1 != 0) return c1;
  524. int l1 = w1.String().size(), l2 = w2.String().size();
  525. // Use opposite order on the string lengths, so that if the costs are the
  526. // same, the shorter string wins.
  527. if (l1 > l2)
  528. return -1;
  529. else if (l1 < l2)
  530. return 1;
  531. for (int i = 0; i < l1; i++) {
  532. if (w1.String()[i] < w2.String()[i])
  533. return -1;
  534. else if (w1.String()[i] > w2.String()[i])
  535. return 1;
  536. }
  537. return 0;
  538. }
  539. // For efficiency, override the NaturalLess template class.
  540. template <class FloatType, class IntType>
  541. class NaturalLess<
  542. CompactLatticeWeightTpl<LatticeWeightTpl<FloatType>, IntType> > {
  543. public:
  544. typedef CompactLatticeWeightTpl<LatticeWeightTpl<FloatType>, IntType> Weight;
  545. NaturalLess() {}
  546. bool operator()(const Weight& w1, const Weight& w2) const {
  547. // NaturalLess is a negative order (opposite to normal ordering).
  548. // This operator () corresponds to "<" in the negative order, which
  549. // corresponds to the ">" in the normal order.
  550. return (Compare(w1, w2) == 1);
  551. }
  552. };
  553. template <>
  554. class NaturalLess<CompactLatticeWeightTpl<LatticeWeightTpl<float>, int32> > {
  555. public:
  556. typedef CompactLatticeWeightTpl<LatticeWeightTpl<float>, int32> Weight;
  557. NaturalLess() {}
  558. bool operator()(const Weight& w1, const Weight& w2) const {
  559. // NaturalLess is a negative order (opposite to normal ordering).
  560. // This operator () corresponds to "<" in the negative order, which
  561. // corresponds to the ">" in the normal order.
  562. return (Compare(w1, w2) == 1);
  563. }
  564. };
  565. template <>
  566. class NaturalLess<CompactLatticeWeightTpl<LatticeWeightTpl<double>, int32> > {
  567. public:
  568. typedef CompactLatticeWeightTpl<LatticeWeightTpl<double>, int32> Weight;
  569. NaturalLess() {}
  570. bool operator()(const Weight& w1, const Weight& w2) const {
  571. // NaturalLess is a negative order (opposite to normal ordering).
  572. // This operator () corresponds to "<" in the negative order, which
  573. // corresponds to the ">" in the normal order.
  574. return (Compare(w1, w2) == 1);
  575. }
  576. };
  577. // Make sure Compare is defined for TropicalWeight, so everything works
  578. // if we substitute LatticeWeight for TropicalWeight.
  579. inline int Compare(const TropicalWeight& w1, const TropicalWeight& w2) {
  580. float f1 = w1.Value(), f2 = w2.Value();
  581. if (f1 == f2)
  582. return 0;
  583. else if (f1 > f2)
  584. return -1;
  585. else
  586. return 1;
  587. }
  588. template <class WeightType, class IntType>
  589. inline CompactLatticeWeightTpl<WeightType, IntType> Plus(
  590. const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  591. const CompactLatticeWeightTpl<WeightType, IntType>& w2) {
  592. return (Compare(w1, w2) >= 0 ? w1 : w2);
  593. }
  594. template <class WeightType, class IntType>
  595. inline CompactLatticeWeightTpl<WeightType, IntType> Times(
  596. const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  597. const CompactLatticeWeightTpl<WeightType, IntType>& w2) {
  598. WeightType w = Times(w1.Weight(), w2.Weight());
  599. if (w == WeightType::Zero()) {
  600. return CompactLatticeWeightTpl<WeightType, IntType>::Zero();
  601. // special case to ensure zero is unique
  602. } else {
  603. std::vector<IntType> v;
  604. v.resize(w1.String().size() + w2.String().size());
  605. typename std::vector<IntType>::iterator iter = v.begin();
  606. iter = std::copy(w1.String().begin(), w1.String().end(),
  607. iter); // returns end of first range.
  608. std::copy(w2.String().begin(), w2.String().end(), iter);
  609. return CompactLatticeWeightTpl<WeightType, IntType>(w, v);
  610. }
  611. }
  612. template <class WeightType, class IntType>
  613. inline CompactLatticeWeightTpl<WeightType, IntType> Divide(
  614. const CompactLatticeWeightTpl<WeightType, IntType>& w1,
  615. const CompactLatticeWeightTpl<WeightType, IntType>& w2,
  616. DivideType div = DIVIDE_ANY) {
  617. if (w1.Weight() == WeightType::Zero()) {
  618. if (w2.Weight() != WeightType::Zero()) {
  619. return CompactLatticeWeightTpl<WeightType, IntType>::Zero();
  620. } else {
  621. KALDI_ERR << "Division by zero [0/0]";
  622. }
  623. } else if (w2.Weight() == WeightType::Zero()) {
  624. KALDI_ERR << "Error: division by zero";
  625. }
  626. WeightType w = Divide(w1.Weight(), w2.Weight());
  627. const std::vector<IntType> v1 = w1.String(), v2 = w2.String();
  628. if (v2.size() > v1.size()) {
  629. KALDI_ERR << "Cannot divide, length mismatch";
  630. }
  631. typename std::vector<IntType>::const_iterator v1b = v1.begin(),
  632. v1e = v1.end(),
  633. v2b = v2.begin(),
  634. v2e = v2.end();
  635. if (div == DIVIDE_LEFT) {
  636. if (!std::equal(v2b, v2e,
  637. v1b)) { // v2 must be identical to first part of v1.
  638. KALDI_ERR << "Cannot divide, data mismatch";
  639. }
  640. return CompactLatticeWeightTpl<WeightType, IntType>(
  641. w, std::vector<IntType>(v1b + (v2e - v2b),
  642. v1e)); // return last part of v1.
  643. } else if (div == DIVIDE_RIGHT) {
  644. if (!std::equal(
  645. v2b, v2e,
  646. v1e - (v2e - v2b))) { // v2 must be identical to last part of v1.
  647. KALDI_ERR << "Cannot divide, data mismatch";
  648. }
  649. return CompactLatticeWeightTpl<WeightType, IntType>(
  650. w, std::vector<IntType>(
  651. v1b, v1e - (v2e - v2b))); // return first part of v1.
  652. } else {
  653. KALDI_ERR << "Cannot divide CompactLatticeWeightTpl with DIVIDE_ANY";
  654. }
  655. return CompactLatticeWeightTpl<WeightType,
  656. IntType>::Zero(); // keep compiler happy.
  657. }
  658. template <class WeightType, class IntType>
  659. inline std::ostream& operator<<(
  660. std::ostream& strm, const CompactLatticeWeightTpl<WeightType, IntType>& w) {
  661. strm << w.Weight();
  662. CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
  663. strm << FLAGS_fst_weight_separator[0]; // comma by default.
  664. for (size_t i = 0; i < w.String().size(); i++) {
  665. strm << w.String()[i];
  666. if (i + 1 < w.String().size())
  667. strm << kStringSeparator; // '_'; defined in string-weight.h in OpenFst
  668. // code.
  669. }
  670. return strm;
  671. }
  672. template <class WeightType, class IntType>
  673. inline std::istream& operator>>(
  674. std::istream& strm, CompactLatticeWeightTpl<WeightType, IntType>& w) {
  675. std::string s;
  676. strm >> s;
  677. if (strm.fail()) {
  678. return strm;
  679. }
  680. CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
  681. size_t pos = s.find_last_of(FLAGS_fst_weight_separator); // normally ","
  682. if (pos == std::string::npos) {
  683. strm.clear(std::ios::badbit);
  684. return strm;
  685. }
  686. // get parts of str before and after the separator (default: ',');
  687. std::string s1(s, 0, pos), s2(s, pos + 1);
  688. std::istringstream strm1(s1);
  689. WeightType weight;
  690. strm1 >> weight;
  691. w.SetWeight(weight);
  692. if (strm1.fail() || !strm1.eof()) {
  693. strm.clear(std::ios::badbit);
  694. return strm;
  695. }
  696. // read string part.
  697. std::vector<IntType> string;
  698. const char* c = s2.c_str();
  699. while (*c != '\0') {
  700. if (*c == kStringSeparator) // '_'
  701. c++;
  702. char* c2;
  703. int64_t i = strtol(c, &c2, 10);
  704. if (c2 == c || static_cast<int64_t>(static_cast<IntType>(i)) != i) {
  705. strm.clear(std::ios::badbit);
  706. return strm;
  707. }
  708. c = c2;
  709. string.push_back(static_cast<IntType>(i));
  710. }
  711. w.SetString(string);
  712. return strm;
  713. }
  714. template <class BaseWeightType, class IntType>
  715. class CompactLatticeWeightCommonDivisorTpl {
  716. public:
  717. typedef CompactLatticeWeightTpl<BaseWeightType, IntType> Weight;
  718. Weight operator()(const Weight& w1, const Weight& w2) const {
  719. // First find longest common prefix of the strings.
  720. typename std::vector<IntType>::const_iterator s1b = w1.String().begin(),
  721. s1e = w1.String().end(),
  722. s2b = w2.String().begin(),
  723. s2e = w2.String().end();
  724. while (s1b < s1e && s2b < s2e && *s1b == *s2b) {
  725. s1b++;
  726. s2b++;
  727. }
  728. return Weight(Plus(w1.Weight(), w2.Weight()),
  729. std::vector<IntType>(w1.String().begin(), s1b));
  730. }
  731. };
  732. /** Scales the pair (a, b) of floating-point weights inside a
  733. CompactLatticeWeight by premultiplying it (viewed as a vector)
  734. by a 2x2 matrix "scale".
  735. Assumes there is a ScaleTupleWeight function that applies to "Weight";
  736. this currently only works if Weight equals LatticeWeightTpl<FloatType>
  737. for some FloatType.
  738. */
  739. template <class Weight, class IntType, class ScaleFloatType>
  740. inline CompactLatticeWeightTpl<Weight, IntType> ScaleTupleWeight(
  741. const CompactLatticeWeightTpl<Weight, IntType>& w,
  742. const std::vector<std::vector<ScaleFloatType> >& scale) {
  743. return CompactLatticeWeightTpl<Weight, IntType>(
  744. Weight(ScaleTupleWeight(w.Weight(), scale)), w.String());
  745. }
  746. /** Define some ConvertLatticeWeight functions that are used in various lattice
  747. conversions... make them all templates, some with no arguments, since some
  748. must be templates.*/
  749. template <class Float1, class Float2>
  750. inline void ConvertLatticeWeight(const LatticeWeightTpl<Float1>& w_in,
  751. LatticeWeightTpl<Float2>* w_out) {
  752. w_out->SetValue1(w_in.Value1());
  753. w_out->SetValue2(w_in.Value2());
  754. }
  755. template <class Float1, class Float2, class Int>
  756. inline void ConvertLatticeWeight(
  757. const CompactLatticeWeightTpl<LatticeWeightTpl<Float1>, Int>& w_in,
  758. CompactLatticeWeightTpl<LatticeWeightTpl<Float2>, Int>* w_out) {
  759. LatticeWeightTpl<Float2> weight2(w_in.Weight().Value1(),
  760. w_in.Weight().Value2());
  761. w_out->SetWeight(weight2);
  762. w_out->SetString(w_in.String());
  763. }
  764. // to convert from Lattice to standard FST
  765. template <class Float1, class Float2>
  766. inline void ConvertLatticeWeight(const LatticeWeightTpl<Float1>& w_in,
  767. TropicalWeightTpl<Float2>* w_out) {
  768. TropicalWeightTpl<Float2> w1(w_in.Value1());
  769. TropicalWeightTpl<Float2> w2(w_in.Value2());
  770. *w_out = Times(w1, w2);
  771. }
  772. template <class Float>
  773. inline double ConvertToCost(const LatticeWeightTpl<Float>& w) {
  774. return static_cast<double>(w.Value1()) + static_cast<double>(w.Value2());
  775. }
  776. template <class Float, class Int>
  777. inline double ConvertToCost(
  778. const CompactLatticeWeightTpl<LatticeWeightTpl<Float>, Int>& w) {
  779. return static_cast<double>(w.Weight().Value1()) +
  780. static_cast<double>(w.Weight().Value2());
  781. }
  782. template <class Float>
  783. inline double ConvertToCost(const TropicalWeightTpl<Float>& w) {
  784. return w.Value();
  785. }
  786. } // namespace fst
  787. #endif // KALDI_FSTEXT_LATTICE_WEIGHT_H_