You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

2033 lines
80 KiB

  1. // lat/lattice-functions.cc
  2. // Copyright 2009-2011 Saarland University (Author: Arnab Ghoshal)
  3. // 2012-2013 Johns Hopkins University (Author: Daniel Povey); Chao
  4. // Weng;
  5. // Bagher BabaAli
  6. // 2013 Cisco Systems (author: Neha Agrawal) [code modified
  7. // from original code in ../gmmbin/gmm-rescore-lattice.cc]
  8. // 2014 Guoguo Chen
  9. // See ../../COPYING for clarification regarding multiple authors
  10. //
  11. // Licensed under the Apache License, Version 2.0 (the "License");
  12. // you may not use this file except in compliance with the License.
  13. // You may obtain a copy of the License at
  14. //
  15. // http://www.apache.org/licenses/LICENSE-2.0
  16. //
  17. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  18. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  19. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  20. // MERCHANTABLITY OR NON-INFRINGEMENT.
  21. // See the Apache 2 License for the specific language governing permissions and
  22. // limitations under the License.
  23. #include "lat/lattice-functions.h"
  24. // #include "hmm/transition-model.h"
  25. // #include "util/stl-utils.h"
  26. #include "base/kaldi-math.h"
  27. // #include "hmm/hmm-utils.h"
  28. namespace kaldi {
  29. using std::map;
  30. using std::vector;
  31. // void GetPerFrameAcousticCosts(const Lattice &nbest,
  32. // Vector<BaseFloat> *per_frame_loglikes) {
  33. // using namespace fst;
  34. // typedef Lattice::Arc::Weight Weight;
  35. // vector<BaseFloat> loglikes;
  36. //
  37. // int32 cur_state = nbest.Start();
  38. // int32 prev_frame = -1;
  39. // BaseFloat eps_acwt = 0.0;
  40. // while(1) {
  41. // Weight w = nbest.Final(cur_state);
  42. // if (w != Weight::Zero()) {
  43. // KALDI_ASSERT(nbest.NumArcs(cur_state) == 0);
  44. // if (per_frame_loglikes != NULL) {
  45. // SubVector<BaseFloat> subvec(&(loglikes[0]), loglikes.size());
  46. // Vector<BaseFloat> vec(subvec);
  47. // *per_frame_loglikes = vec;
  48. // }
  49. // break;
  50. // } else {
  51. // KALDI_ASSERT(nbest.NumArcs(cur_state) == 1);
  52. // fst::ArcIterator<Lattice> iter(nbest, cur_state);
  53. // const Lattice::Arc &arc = iter.Value();
  54. // BaseFloat acwt = arc.weight.Value2();
  55. // if (arc.ilabel != 0) {
  56. // if (eps_acwt > 0) {
  57. // acwt += eps_acwt;
  58. // eps_acwt = 0.0;
  59. // }
  60. // loglikes.push_back(acwt);
  61. // prev_frame++;
  62. // } else if (acwt == acwt){
  63. // if (prev_frame > -1) {
  64. // loglikes[prev_frame] += acwt;
  65. // } else {
  66. // eps_acwt += acwt;
  67. // }
  68. // }
  69. // cur_state = arc.nextstate;
  70. // }
  71. // }
  72. // }
  73. //
  74. // int32 LatticeStateTimes(const Lattice &lat, vector<int32> *times) {
  75. // if (!lat.Properties(fst::kTopSorted, true))
  76. // KALDI_ERR << "Input lattice must be topologically sorted.";
  77. // KALDI_ASSERT(lat.Start() == 0);
  78. // int32 num_states = lat.NumStates();
  79. // times->clear();
  80. // times->resize(num_states, -1);
  81. // (*times)[0] = 0;
  82. // for (int32 state = 0; state < num_states; state++) {
  83. // int32 cur_time = (*times)[state];
  84. // for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
  85. // aiter.Next()) {
  86. // const LatticeArc &arc = aiter.Value();
  87. //
  88. // if (arc.ilabel != 0) { // Non-epsilon input label on arc
  89. // // next time instance
  90. // if ((*times)[arc.nextstate] == -1) {
  91. // (*times)[arc.nextstate] = cur_time + 1;
  92. // } else {
  93. // KALDI_ASSERT((*times)[arc.nextstate] == cur_time + 1);
  94. // }
  95. // } else { // epsilon input label on arc
  96. // // Same time instance
  97. // if ((*times)[arc.nextstate] == -1)
  98. // (*times)[arc.nextstate] = cur_time;
  99. // else
  100. // KALDI_ASSERT((*times)[arc.nextstate] == cur_time);
  101. // }
  102. // }
  103. // }
  104. // return (*std::max_element(times->begin(), times->end()));
  105. // }
  106. //
  107. // int32 CompactLatticeStateTimes(const CompactLattice &lat,
  108. // vector<int32> *times) {
  109. // if (!lat.Properties(fst::kTopSorted, true))
  110. // KALDI_ERR << "Input lattice must be topologically sorted.";
  111. // KALDI_ASSERT(lat.Start() == 0);
  112. // int32 num_states = lat.NumStates();
  113. // times->clear();
  114. // times->resize(num_states, -1);
  115. // (*times)[0] = 0;
  116. // int32 utt_len = -1;
  117. // for (int32 state = 0; state < num_states; state++) {
  118. // int32 cur_time = (*times)[state];
  119. // for (fst::ArcIterator<CompactLattice> aiter(lat, state); !aiter.Done();
  120. // aiter.Next()) {
  121. // const CompactLatticeArc &arc = aiter.Value();
  122. // int32 arc_len = static_cast<int32>(arc.weight.String().size());
  123. // if ((*times)[arc.nextstate] == -1)
  124. // (*times)[arc.nextstate] = cur_time + arc_len;
  125. // else
  126. // KALDI_ASSERT((*times)[arc.nextstate] == cur_time + arc_len);
  127. // }
  128. // if (lat.Final(state) != CompactLatticeWeight::Zero()) {
  129. // int32 this_utt_len = (*times)[state] +
  130. // lat.Final(state).String().size(); if (utt_len == -1) utt_len =
  131. // this_utt_len; else {
  132. // if (this_utt_len != utt_len) {
  133. // KALDI_WARN << "Utterance does not "
  134. // "seem to have a consistent length.";
  135. // utt_len = std::max(utt_len, this_utt_len);
  136. // }
  137. // }
  138. // }
  139. // }
  140. // if (utt_len == -1) {
  141. // KALDI_WARN << "Utterance does not have a final-state.";
  142. // return 0;
  143. // }
  144. // return utt_len;
  145. // }
  146. //
  147. // bool ComputeCompactLatticeAlphas(const CompactLattice &clat,
  148. // vector<double> *alpha) {
  149. // using namespace fst;
  150. //
  151. // // typedef the arc, weight types
  152. // typedef CompactLattice::Arc Arc;
  153. // typedef Arc::Weight Weight;
  154. // typedef Arc::StateId StateId;
  155. //
  156. // //Make sure the lattice is topologically sorted.
  157. // if (clat.Properties(fst::kTopSorted, true) == 0) {
  158. // KALDI_WARN << "Input lattice must be topologically sorted.";
  159. // return false;
  160. // }
  161. // if (clat.Start() != 0) {
  162. // KALDI_WARN << "Input lattice must start from state 0.";
  163. // return false;
  164. // }
  165. //
  166. // int32 num_states = clat.NumStates();
  167. // (*alpha).resize(0);
  168. // (*alpha).resize(num_states, kLogZeroDouble);
  169. //
  170. // // Now propagate alphas forward. Note that we don't account for the weight of
  171. // the
  172. // // final state in alpha[final_state] -- we account for it in beta[final_state];
  173. // (*alpha)[0] = 0.0;
  174. // for (StateId s = 0; s < num_states; s++) {
  175. // double this_alpha = (*alpha)[s];
  176. // for (ArcIterator<CompactLattice> aiter(clat, s);
  177. // !aiter.Done(); aiter.Next()) {
  178. // const Arc &arc = aiter.Value();
  179. // double arc_like = -(arc.weight.Weight().Value1() +
  180. // arc.weight.Weight().Value2());
  181. // (*alpha)[arc.nextstate] = LogAdd((*alpha)[arc.nextstate],
  182. // this_alpha + arc_like);
  183. // }
  184. // }
  185. //
  186. // return true;
  187. // }
  188. //
  189. // bool ComputeCompactLatticeBetas(const CompactLattice &clat,
  190. // vector<double> *beta) {
  191. // using namespace fst;
  192. //
  193. // // typedef the arc, weight types
  194. // typedef CompactLattice::Arc Arc;
  195. // typedef Arc::Weight Weight;
  196. // typedef Arc::StateId StateId;
  197. //
  198. // // Make sure the lattice is topologically sorted.
  199. // if (clat.Properties(fst::kTopSorted, true) == 0) {
  200. // KALDI_WARN << "Input lattice must be topologically sorted.";
  201. // return false;
  202. // }
  203. // if (clat.Start() != 0) {
  204. // KALDI_WARN << "Input lattice must start from state 0.";
  205. // return false;
  206. // }
  207. //
  208. // int32 num_states = clat.NumStates();
  209. // (*beta).resize(0);
  210. // (*beta).resize(num_states, kLogZeroDouble);
  211. //
  212. // // Now propagate betas backward. Note that beta[final_state] contains the
  213. // // weight of the final state in the lattice -- compare that with alpha.
  214. // for (StateId s = num_states-1; s >= 0; s--) {
  215. // Weight f = clat.Final(s);
  216. // double this_beta = -(f.Weight().Value1()+f.Weight().Value2());
  217. // for (ArcIterator<CompactLattice> aiter(clat, s);
  218. // !aiter.Done(); aiter.Next()) {
  219. // const Arc &arc = aiter.Value();
  220. // double arc_like = -(arc.weight.Weight().Value1() +
  221. // arc.weight.Weight().Value2());
  222. // double arc_beta = (*beta)[arc.nextstate] + arc_like;
  223. // this_beta = LogAdd(this_beta, arc_beta);
  224. // }
  225. // (*beta)[s] = this_beta;
  226. // }
  227. //
  228. // return true;
  229. // }
  230. template <class LatType> // could be Lattice or CompactLattice
  231. bool PruneLattice(BaseFloat beam, LatType* lat) {
  232. typedef typename LatType::Arc Arc;
  233. typedef typename Arc::Weight Weight;
  234. typedef typename Arc::StateId StateId;
  235. KALDI_ASSERT(beam > 0.0);
  236. if (!lat->Properties(fst::kTopSorted, true)) {
  237. if (fst::TopSort(lat) == false) {
  238. KALDI_WARN << "Cycles detected in lattice";
  239. return false;
  240. }
  241. }
  242. // We assume states before "start" are not reachable, since
  243. // the lattice is topologically sorted.
  244. int32 start = lat->Start();
  245. int32 num_states = lat->NumStates();
  246. if (num_states == 0) return false;
  247. std::vector<double> forward_cost(
  248. num_states,
  249. std::numeric_limits<double>::infinity()); // viterbi forward.
  250. forward_cost[start] = 0.0; // lattice can't have cycles so couldn't be
  251. // less than this.
  252. double best_final_cost = std::numeric_limits<double>::infinity();
  253. // Update the forward probs.
  254. // Thanks to Jing Zheng for finding a bug here.
  255. for (int32 state = 0; state < num_states; state++) {
  256. double this_forward_cost = forward_cost[state];
  257. for (fst::ArcIterator<LatType> aiter(*lat, state); !aiter.Done();
  258. aiter.Next()) {
  259. const Arc& arc(aiter.Value());
  260. StateId nextstate = arc.nextstate;
  261. KALDI_ASSERT(nextstate > state && nextstate < num_states);
  262. double next_forward_cost = this_forward_cost + ConvertToCost(arc.weight);
  263. if (forward_cost[nextstate] > next_forward_cost)
  264. forward_cost[nextstate] = next_forward_cost;
  265. }
  266. Weight final_weight = lat->Final(state);
  267. double this_final_cost = this_forward_cost + ConvertToCost(final_weight);
  268. if (this_final_cost < best_final_cost) best_final_cost = this_final_cost;
  269. }
  270. int32 bad_state = lat->AddState(); // this state is not final.
  271. double cutoff = best_final_cost + beam;
  272. // Go backwards updating the backward probs (which share memory with the
  273. // forward probs), and pruning arcs and deleting final-probs. We prune arcs
  274. // by making them point to the non-final state "bad_state". We'll then use
  275. // Trim() to remove unnecessary arcs and states. [this is just easier than
  276. // doing it ourselves.]
  277. std::vector<double>& backward_cost(forward_cost);
  278. for (int32 state = num_states - 1; state >= 0; state--) {
  279. double this_forward_cost = forward_cost[state];
  280. double this_backward_cost = ConvertToCost(lat->Final(state));
  281. if (this_backward_cost + this_forward_cost > cutoff &&
  282. this_backward_cost != std::numeric_limits<double>::infinity())
  283. lat->SetFinal(state, Weight::Zero());
  284. for (fst::MutableArcIterator<LatType> aiter(lat, state); !aiter.Done();
  285. aiter.Next()) {
  286. Arc arc(aiter.Value());
  287. StateId nextstate = arc.nextstate;
  288. KALDI_ASSERT(nextstate > state && nextstate < num_states);
  289. double arc_cost = ConvertToCost(arc.weight),
  290. arc_backward_cost = arc_cost + backward_cost[nextstate],
  291. this_fb_cost = this_forward_cost + arc_backward_cost;
  292. if (arc_backward_cost < this_backward_cost)
  293. this_backward_cost = arc_backward_cost;
  294. if (this_fb_cost > cutoff) { // Prune the arc.
  295. arc.nextstate = bad_state;
  296. aiter.SetValue(arc);
  297. }
  298. }
  299. backward_cost[state] = this_backward_cost;
  300. }
  301. fst::Connect(lat);
  302. return (lat->NumStates() > 0);
  303. }
  304. // instantiate the template for lattice and CompactLattice.
  305. template bool PruneLattice(BaseFloat beam, Lattice* lat);
  306. template bool PruneLattice(BaseFloat beam, CompactLattice* lat);
  307. // BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *post,
  308. // double *acoustic_like_sum) {
  309. // // Note, Posterior is defined as follows: Indexed [frame], then a list
  310. // // of (transition-id, posterior-probability) pairs.
  311. // // typedef std::vector<std::vector<std::pair<int32, BaseFloat> > >
  312. // Posterior; using namespace fst; typedef Lattice::Arc Arc; typedef
  313. // Arc::Weight Weight; typedef Arc::StateId StateId;
  314. //
  315. // if (acoustic_like_sum) *acoustic_like_sum = 0.0;
  316. //
  317. // // Make sure the lattice is topologically sorted.
  318. // if (lat.Properties(fst::kTopSorted, true) == 0)
  319. // KALDI_ERR << "Input lattice must be topologically sorted.";
  320. // KALDI_ASSERT(lat.Start() == 0);
  321. //
  322. // int32 num_states = lat.NumStates();
  323. // vector<int32> state_times;
  324. // int32 max_time = LatticeStateTimes(lat, &state_times);
  325. // std::vector<double> alpha(num_states, kLogZeroDouble);
  326. // std::vector<double> &beta(alpha); // we re-use the same memory for
  327. // // this, but it's semantically distinct so we name it differently.
  328. // double tot_forward_prob = kLogZeroDouble;
  329. //
  330. // post->clear();
  331. // post->resize(max_time);
  332. //
  333. // alpha[0] = 0.0;
  334. // // Propagate alphas forward.
  335. // for (StateId s = 0; s < num_states; s++) {
  336. // double this_alpha = alpha[s];
  337. // for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
  338. // const Arc &arc = aiter.Value();
  339. // double arc_like = -ConvertToCost(arc.weight);
  340. // alpha[arc.nextstate] = LogAdd(alpha[arc.nextstate], this_alpha +
  341. // arc_like);
  342. // }
  343. // Weight f = lat.Final(s);
  344. // if (f != Weight::Zero()) {
  345. // double final_like = this_alpha - (f.Value1() + f.Value2());
  346. // tot_forward_prob = LogAdd(tot_forward_prob, final_like);
  347. // KALDI_ASSERT(state_times[s] == max_time &&
  348. // "Lattice is inconsistent (final-prob not at max_time)");
  349. // }
  350. // }
  351. // for (StateId s = num_states-1; s >= 0; s--) {
  352. // Weight f = lat.Final(s);
  353. // double this_beta = -(f.Value1() + f.Value2());
  354. // for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
  355. // const Arc &arc = aiter.Value();
  356. // double arc_like = -ConvertToCost(arc.weight),
  357. // arc_beta = beta[arc.nextstate] + arc_like;
  358. // this_beta = LogAdd(this_beta, arc_beta);
  359. // int32 transition_id = arc.ilabel;
  360. //
  361. // // The following "if" is an optimization to avoid un-needed exp().
  362. // if (transition_id != 0 || acoustic_like_sum != NULL) {
  363. // double posterior = Exp(alpha[s] + arc_beta - tot_forward_prob);
  364. //
  365. // if (transition_id != 0) // Arc has a transition-id on it [not
  366. // epsilon]
  367. // (*post)[state_times[s]].push_back(std::make_pair(transition_id,
  368. // static_cast<kaldi::BaseFloat>(posterior)));
  369. // if (acoustic_like_sum != NULL)
  370. // *acoustic_like_sum -= posterior * arc.weight.Value2();
  371. // }
  372. // }
  373. // if (acoustic_like_sum != NULL && f != Weight::Zero()) {
  374. // double final_logprob = - ConvertToCost(f),
  375. // posterior = Exp(alpha[s] + final_logprob - tot_forward_prob);
  376. // *acoustic_like_sum -= posterior * f.Value2();
  377. // }
  378. // beta[s] = this_beta;
  379. // }
  380. // double tot_backward_prob = beta[0];
  381. // if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) {
  382. // KALDI_WARN << "Total forward probability over lattice = " <<
  383. // tot_forward_prob
  384. // << ", while total backward probability = " <<
  385. // tot_backward_prob;
  386. // }
  387. // // Now combine any posteriors with the same transition-id.
  388. // for (int32 t = 0; t < max_time; t++)
  389. // MergePairVectorSumming(&((*post)[t]));
  390. // return tot_backward_prob;
  391. // }
  392. //
  393. //
  394. // void LatticeActivePhones(const Lattice &lat, const TransitionModel &trans,
  395. // const vector<int32> &silence_phones,
  396. // vector< std::set<int32> > *active_phones) {
  397. // KALDI_ASSERT(IsSortedAndUniq(silence_phones));
  398. // vector<int32> state_times;
  399. // int32 num_states = lat.NumStates();
  400. // int32 max_time = LatticeStateTimes(lat, &state_times);
  401. // active_phones->clear();
  402. // active_phones->resize(max_time);
  403. // for (int32 state = 0; state < num_states; state++) {
  404. // int32 cur_time = state_times[state];
  405. // for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
  406. // aiter.Next()) {
  407. // const LatticeArc &arc = aiter.Value();
  408. // if (arc.ilabel != 0) { // Non-epsilon arc
  409. // int32 phone = trans.TransitionIdToPhone(arc.ilabel);
  410. // if (!std::binary_search(silence_phones.begin(),
  411. // silence_phones.end(), phone))
  412. // (*active_phones)[cur_time].insert(phone);
  413. // }
  414. // } // end looping over arcs
  415. // } // end looping over states
  416. // }
  417. //
  418. // void ConvertLatticeToPhones(const TransitionModel &trans,
  419. // Lattice *lat) {
  420. // typedef LatticeArc Arc;
  421. // int32 num_states = lat->NumStates();
  422. // for (int32 state = 0; state < num_states; state++) {
  423. // for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
  424. // aiter.Next()) {
  425. // Arc arc(aiter.Value());
  426. // arc.olabel = 0; // remove any word.
  427. // if ((arc.ilabel != 0) // has a transition-id on input..
  428. // && (trans.TransitionIdToHmmState(arc.ilabel) == 0)
  429. // && (!trans.IsSelfLoop(arc.ilabel))) {
  430. // // && trans.IsFinal(arc.ilabel)) // there is one of these per
  431. // phone...
  432. // arc.olabel = trans.TransitionIdToPhone(arc.ilabel);
  433. // }
  434. // aiter.SetValue(arc);
  435. // } // end looping over arcs
  436. // } // end looping over states
  437. // }
  438. //
  439. //
  440. // static inline double LogAddOrMax(bool viterbi, double a, double b) {
  441. // if (viterbi)
  442. // return std::max(a, b);
  443. // else
  444. // return LogAdd(a, b);
  445. // }
  446. //
  447. // template<typename LatticeType>
  448. // double ComputeLatticeAlphasAndBetas(const LatticeType &lat,
  449. // bool viterbi,
  450. // vector<double> *alpha,
  451. // vector<double> *beta) {
  452. // typedef typename LatticeType::Arc Arc;
  453. // typedef typename Arc::Weight Weight;
  454. // typedef typename Arc::StateId StateId;
  455. //
  456. // StateId num_states = lat.NumStates();
  457. // KALDI_ASSERT(lat.Properties(fst::kTopSorted, true) == fst::kTopSorted);
  458. // KALDI_ASSERT(lat.Start() == 0);
  459. // alpha->clear();
  460. // beta->clear();
  461. // alpha->resize(num_states, kLogZeroDouble);
  462. // beta->resize(num_states, kLogZeroDouble);
  463. //
  464. // double tot_forward_prob = kLogZeroDouble;
  465. // (*alpha)[0] = 0.0;
  466. // // Propagate alphas forward.
  467. // for (StateId s = 0; s < num_states; s++) {
  468. // double this_alpha = (*alpha)[s];
  469. // for (fst::ArcIterator<LatticeType> aiter(lat, s); !aiter.Done();
  470. // aiter.Next()) {
  471. // const Arc &arc = aiter.Value();
  472. // double arc_like = -ConvertToCost(arc.weight);
  473. // (*alpha)[arc.nextstate] = LogAddOrMax(viterbi, (*alpha)[arc.nextstate],
  474. // this_alpha + arc_like);
  475. // }
  476. // Weight f = lat.Final(s);
  477. // if (f != Weight::Zero()) {
  478. // double final_like = this_alpha - ConvertToCost(f);
  479. // tot_forward_prob = LogAddOrMax(viterbi, tot_forward_prob, final_like);
  480. // }
  481. // }
  482. // for (StateId s = num_states-1; s >= 0; s--) { // it's guaranteed signed.
  483. // double this_beta = -ConvertToCost(lat.Final(s));
  484. // for (fst::ArcIterator<LatticeType> aiter(lat, s); !aiter.Done();
  485. // aiter.Next()) {
  486. // const Arc &arc = aiter.Value();
  487. // double arc_like = -ConvertToCost(arc.weight),
  488. // arc_beta = (*beta)[arc.nextstate] + arc_like;
  489. // this_beta = LogAddOrMax(viterbi, this_beta, arc_beta);
  490. // }
  491. // (*beta)[s] = this_beta;
  492. // }
  493. // double tot_backward_prob = (*beta)[lat.Start()];
  494. // if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) {
  495. // KALDI_WARN << "Total forward probability over lattice = " <<
  496. // tot_forward_prob
  497. // << ", while total backward probability = " <<
  498. // tot_backward_prob;
  499. // }
  500. // // Split the difference when returning... they should be the same.
  501. // return 0.5 * (tot_backward_prob + tot_forward_prob);
  502. // }
  503. //
  504. // // instantiate the template for Lattice and CompactLattice
  505. // template
  506. // double ComputeLatticeAlphasAndBetas(const Lattice &lat,
  507. // bool viterbi,
  508. // vector<double> *alpha,
  509. // vector<double> *beta);
  510. //
  511. // template
  512. // double ComputeLatticeAlphasAndBetas(const CompactLattice &lat,
  513. // bool viterbi,
  514. // vector<double> *alpha,
  515. // vector<double> *beta);
  516. //
  517. //
  518. //
  519. // /// This is used in CompactLatticeLimitDepth.
  520. // struct LatticeArcRecord {
  521. // BaseFloat logprob; // logprob <= 0 is the best Viterbi logprob of this arc,
  522. // // minus the overall best-cost of the lattice.
  523. // CompactLatticeArc::StateId state; // state in the lattice.
  524. // size_t arc; // arc index within the state.
  525. // bool operator < (const LatticeArcRecord &other) const {
  526. // return logprob < other.logprob;
  527. // }
  528. // };
  529. //
  530. // void CompactLatticeLimitDepth(int32 max_depth_per_frame,
  531. // CompactLattice *clat) {
  532. // typedef CompactLatticeArc Arc;
  533. // typedef Arc::Weight Weight;
  534. // typedef Arc::StateId StateId;
  535. //
  536. // if (clat->Start() == fst::kNoStateId) {
  537. // KALDI_WARN << "Limiting depth of empty lattice.";
  538. // return;
  539. // }
  540. // if (clat->Properties(fst::kTopSorted, true) == 0) {
  541. // if (!TopSort(clat))
  542. // KALDI_ERR << "Topological sorting of lattice failed.";
  543. // }
  544. //
  545. // vector<int32> state_times;
  546. // int32 T = CompactLatticeStateTimes(*clat, &state_times);
  547. //
  548. // // The alpha and beta quantities here are "viterbi" alphas and beta.
  549. // std::vector<double> alpha;
  550. // std::vector<double> beta;
  551. // bool viterbi = true;
  552. // double best_prob = ComputeLatticeAlphasAndBetas(*clat, viterbi,
  553. // &alpha, &beta);
  554. //
  555. // std::vector<std::vector<LatticeArcRecord> > arc_records(T);
  556. //
  557. // StateId num_states = clat->NumStates();
  558. // for (StateId s = 0; s < num_states; s++) {
  559. // for (fst::ArcIterator<CompactLattice> aiter(*clat, s); !aiter.Done();
  560. // aiter.Next()) {
  561. // const Arc &arc = aiter.Value();
  562. // LatticeArcRecord arc_record;
  563. // arc_record.state = s;
  564. // arc_record.arc = aiter.Position();
  565. // arc_record.logprob =
  566. // (alpha[s] + beta[arc.nextstate] - ConvertToCost(arc.weight))
  567. // - best_prob;
  568. // KALDI_ASSERT(arc_record.logprob < 0.1); // Should be zero or negative.
  569. // int32 num_frames = arc.weight.String().size(), start_t =
  570. // state_times[s]; for (int32 t = start_t; t < start_t + num_frames; t++)
  571. // {
  572. // KALDI_ASSERT(t < T);
  573. // arc_records[t].push_back(arc_record);
  574. // }
  575. // }
  576. // }
  577. // StateId dead_state = clat->AddState(); // A non-coaccessible state which we
  578. // use
  579. // // to remove arcs (make them end
  580. // // there).
  581. // size_t max_depth = max_depth_per_frame;
  582. // for (int32 t = 0; t < T; t++) {
  583. // size_t size = arc_records[t].size();
  584. // if (size > max_depth) {
  585. // // we sort from worst to best, so we keep the later-numbered ones,
  586. // // and delete the lower-numbered ones.
  587. // size_t cutoff = size - max_depth;
  588. // std::nth_element(arc_records[t].begin(),
  589. // arc_records[t].begin() + cutoff,
  590. // arc_records[t].end());
  591. // for (size_t index = 0; index < cutoff; index++) {
  592. // LatticeArcRecord record(arc_records[t][index]);
  593. // fst::MutableArcIterator<CompactLattice> aiter(clat, record.state);
  594. // aiter.Seek(record.arc);
  595. // Arc arc = aiter.Value();
  596. // if (arc.nextstate != dead_state) { // not already killed.
  597. // arc.nextstate = dead_state;
  598. // aiter.SetValue(arc);
  599. // }
  600. // }
  601. // }
  602. // }
  603. // Connect(clat);
  604. // TopSortCompactLatticeIfNeeded(clat);
  605. // }
  606. //
  607. //
  608. // void TopSortCompactLatticeIfNeeded(CompactLattice *clat) {
  609. // if (clat->Properties(fst::kTopSorted, true) == 0) {
  610. // if (fst::TopSort(clat) == false) {
  611. // KALDI_ERR << "Topological sorting failed";
  612. // }
  613. // }
  614. // }
  615. //
  616. // void TopSortLatticeIfNeeded(Lattice *lat) {
  617. // if (lat->Properties(fst::kTopSorted, true) == 0) {
  618. // if (fst::TopSort(lat) == false) {
  619. // KALDI_ERR << "Topological sorting failed";
  620. // }
  621. // }
  622. // }
  623. //
  624. //
  625. // /// Returns the depth of the lattice, defined as the average number of
  626. // /// arcs crossing any given frame. Returns 1 for empty lattices.
  627. // /// Requires that input is topologically sorted.
  628. // BaseFloat CompactLatticeDepth(const CompactLattice &clat,
  629. // int32 *num_frames) {
  630. // typedef CompactLattice::Arc::StateId StateId;
  631. // if (clat.Properties(fst::kTopSorted, true) == 0) {
  632. // KALDI_ERR << "Lattice input to CompactLatticeDepth was not topologically
  633. // "
  634. // << "sorted.";
  635. // }
  636. // if (clat.Start() == fst::kNoStateId) {
  637. // *num_frames = 0;
  638. // return 1.0;
  639. // }
  640. // size_t num_arc_frames = 0;
  641. // int32 t;
  642. // {
  643. // vector<int32> state_times;
  644. // t = CompactLatticeStateTimes(clat, &state_times);
  645. // }
  646. // if (num_frames != NULL)
  647. // *num_frames = t;
  648. // for (StateId s = 0; s < clat.NumStates(); s++) {
  649. // for (fst::ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done();
  650. // aiter.Next()) {
  651. // const CompactLatticeArc &arc = aiter.Value();
  652. // num_arc_frames += arc.weight.String().size();
  653. // }
  654. // num_arc_frames += clat.Final(s).String().size();
  655. // }
  656. // return num_arc_frames / static_cast<BaseFloat>(t);
  657. // }
  658. //
  659. //
  660. // void CompactLatticeDepthPerFrame(const CompactLattice &clat,
  661. // std::vector<int32> *depth_per_frame) {
  662. // typedef CompactLattice::Arc::StateId StateId;
  663. // if (clat.Properties(fst::kTopSorted, true) == 0) {
  664. // KALDI_ERR << "Lattice input to CompactLatticeDepthPerFrame was not "
  665. // << "topologically sorted.";
  666. // }
  667. // if (clat.Start() == fst::kNoStateId) {
  668. // depth_per_frame->clear();
  669. // return;
  670. // }
  671. // vector<int32> state_times;
  672. // int32 T = CompactLatticeStateTimes(clat, &state_times);
  673. //
  674. // depth_per_frame->clear();
  675. // if (T <= 0) {
  676. // return;
  677. // } else {
  678. // depth_per_frame->resize(T, 0);
  679. // for (StateId s = 0; s < clat.NumStates(); s++) {
  680. // int32 start_time = state_times[s];
  681. // for (fst::ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done();
  682. // aiter.Next()) {
  683. // const CompactLatticeArc &arc = aiter.Value();
  684. // int32 len = arc.weight.String().size();
  685. // for (int32 t = start_time; t < start_time + len; t++) {
  686. // KALDI_ASSERT(t < T);
  687. // (*depth_per_frame)[t]++;
  688. // }
  689. // }
  690. // int32 final_len = clat.Final(s).String().size();
  691. // for (int32 t = start_time; t < start_time + final_len; t++) {
  692. // KALDI_ASSERT(t < T);
  693. // (*depth_per_frame)[t]++;
  694. // }
  695. // }
  696. // }
  697. // }
  698. //
  699. //
  700. //
  701. // void ConvertCompactLatticeToPhones(const TransitionModel &trans,
  702. // CompactLattice *clat) {
  703. // typedef CompactLatticeArc Arc;
  704. // typedef Arc::Weight Weight;
  705. // int32 num_states = clat->NumStates();
  706. // for (int32 state = 0; state < num_states; state++) {
  707. // for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
  708. // !aiter.Done();
  709. // aiter.Next()) {
  710. // Arc arc(aiter.Value());
  711. // std::vector<int32> phone_seq;
  712. // const std::vector<int32> &tid_seq = arc.weight.String();
  713. // for (std::vector<int32>::const_iterator iter = tid_seq.begin();
  714. // iter != tid_seq.end(); ++iter) {
  715. // if (trans.IsFinal(*iter))// note: there is one of these per phone...
  716. // phone_seq.push_back(trans.TransitionIdToPhone(*iter));
  717. // }
  718. // arc.weight.SetString(phone_seq);
  719. // aiter.SetValue(arc);
  720. // } // end looping over arcs
  721. // Weight f = clat->Final(state);
  722. // if (f != Weight::Zero()) {
  723. // std::vector<int32> phone_seq;
  724. // const std::vector<int32> &tid_seq = f.String();
  725. // for (std::vector<int32>::const_iterator iter = tid_seq.begin();
  726. // iter != tid_seq.end(); ++iter) {
  727. // if (trans.IsFinal(*iter))// note: there is one of these per phone...
  728. // phone_seq.push_back(trans.TransitionIdToPhone(*iter));
  729. // }
  730. // f.SetString(phone_seq);
  731. // clat->SetFinal(state, f);
  732. // }
  733. // } // end looping over states
  734. // }
  735. //
  736. // bool LatticeBoost(const TransitionModel &trans,
  737. // const std::vector<int32> &alignment,
  738. // const std::vector<int32> &silence_phones,
  739. // BaseFloat b,
  740. // BaseFloat max_silence_error,
  741. // Lattice *lat) {
  742. // TopSortLatticeIfNeeded(lat);
  743. //
  744. // // get all stored properties (test==false means don't test if not known).
  745. // uint64 props = lat->Properties(fst::kFstProperties,
  746. // false);
  747. //
  748. // KALDI_ASSERT(IsSortedAndUniq(silence_phones));
  749. // KALDI_ASSERT(max_silence_error >= 0.0 && max_silence_error <= 1.0);
  750. // vector<int32> state_times;
  751. // int32 num_states = lat->NumStates();
  752. // int32 num_frames = LatticeStateTimes(*lat, &state_times);
  753. // KALDI_ASSERT(num_frames == static_cast<int32>(alignment.size()));
  754. // for (int32 state = 0; state < num_states; state++) {
  755. // int32 cur_time = state_times[state];
  756. // for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
  757. // aiter.Next()) {
  758. // LatticeArc arc = aiter.Value();
  759. // if (arc.ilabel != 0) { // Non-epsilon arc
  760. // if (arc.ilabel < 0 || arc.ilabel > trans.NumTransitionIds()) {
  761. // KALDI_WARN << "Lattice has out-of-range transition-ids: "
  762. // << "lattice/model mismatch?";
  763. // return false;
  764. // }
  765. // int32 phone = trans.TransitionIdToPhone(arc.ilabel),
  766. // ref_phone = trans.TransitionIdToPhone(alignment[cur_time]);
  767. // BaseFloat frame_error;
  768. // if (phone == ref_phone) {
  769. // frame_error = 0.0;
  770. // } else { // an error...
  771. // if (std::binary_search(silence_phones.begin(),
  772. // silence_phones.end(), phone))
  773. // frame_error = max_silence_error;
  774. // else
  775. // frame_error = 1.0;
  776. // }
  777. // BaseFloat delta_cost = -b * frame_error; // negative cost if
  778. // // frame is wrong, to boost likelihood of arcs with errors on them.
  779. // // Add this cost to the graph part.
  780. // arc.weight.SetValue1(arc.weight.Value1() + delta_cost);
  781. // aiter.SetValue(arc);
  782. // }
  783. // }
  784. // }
  785. // // All we changed is the weights, so any properties that were
  786. // // known before, are still known, except for whether or not the
  787. // // lattice was weighted.
  788. // lat->SetProperties(props,
  789. // ~(fst::kWeighted|fst::kUnweighted));
  790. //
  791. // return true;
  792. // }
  793. //
  794. //
  795. //
  796. // BaseFloat LatticeForwardBackwardMpeVariants(
  797. // const TransitionModel &trans,
  798. // const std::vector<int32> &silence_phones,
  799. // const Lattice &lat,
  800. // const std::vector<int32> &num_ali,
  801. // std::string criterion,
  802. // bool one_silence_class,
  803. // Posterior *post) {
  804. // using namespace fst;
  805. // typedef Lattice::Arc Arc;
  806. // typedef Arc::Weight Weight;
  807. // typedef Arc::StateId StateId;
  808. //
  809. // KALDI_ASSERT(criterion == "mpfe" || criterion == "smbr");
  810. // bool is_mpfe = (criterion == "mpfe");
  811. //
  812. // if (lat.Properties(fst::kTopSorted, true) == 0)
  813. // KALDI_ERR << "Input lattice must be topologically sorted.";
  814. // KALDI_ASSERT(lat.Start() == 0);
  815. //
  816. // int32 num_states = lat.NumStates();
  817. // vector<int32> state_times;
  818. // int32 max_time = LatticeStateTimes(lat, &state_times);
  819. // KALDI_ASSERT(max_time == static_cast<int32>(num_ali.size()));
  820. // std::vector<double> alpha(num_states, kLogZeroDouble),
  821. // alpha_smbr(num_states, 0), //forward variable for sMBR
  822. // beta(num_states, kLogZeroDouble),
  823. // beta_smbr(num_states, 0); //backward variable for sMBR
  824. //
  825. // double tot_forward_prob = kLogZeroDouble;
  826. // double tot_forward_score = 0;
  827. //
  828. // post->clear();
  829. // post->resize(max_time);
  830. //
  831. // alpha[0] = 0.0;
  832. // // First Pass Forward,
  833. // for (StateId s = 0; s < num_states; s++) {
  834. // double this_alpha = alpha[s];
  835. // for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
  836. // const Arc &arc = aiter.Value();
  837. // double arc_like = -ConvertToCost(arc.weight);
  838. // alpha[arc.nextstate] = LogAdd(alpha[arc.nextstate], this_alpha +
  839. // arc_like);
  840. // }
  841. // Weight f = lat.Final(s);
  842. // if (f != Weight::Zero()) {
  843. // double final_like = this_alpha - (f.Value1() + f.Value2());
  844. // tot_forward_prob = LogAdd(tot_forward_prob, final_like);
  845. // KALDI_ASSERT(state_times[s] == max_time &&
  846. // "Lattice is inconsistent (final-prob not at max_time)");
  847. // }
  848. // }
  849. // // First Pass Backward,
  850. // for (StateId s = num_states-1; s >= 0; s--) {
  851. // Weight f = lat.Final(s);
  852. // double this_beta = -(f.Value1() + f.Value2());
  853. // for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
  854. // const Arc &arc = aiter.Value();
  855. // double arc_like = -ConvertToCost(arc.weight),
  856. // arc_beta = beta[arc.nextstate] + arc_like;
  857. // this_beta = LogAdd(this_beta, arc_beta);
  858. // }
  859. // beta[s] = this_beta;
  860. // }
  861. // // First Pass Forward-Backward Check
  862. // double tot_backward_prob = beta[0];
  863. // // may loosen the condition somewhat here: 1e-6 (was 1e-8)
  864. // if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-6)) {
  865. // KALDI_ERR << "Total forward probability over lattice = " <<
  866. // tot_forward_prob
  867. // << ", while total backward probability = " <<
  868. // tot_backward_prob;
  869. // }
  870. //
  871. // alpha_smbr[0] = 0.0;
  872. // // Second Pass Forward, calculate forward for MPFE/SMBR
  873. // for (StateId s = 0; s < num_states; s++) {
  874. // double this_alpha = alpha[s];
  875. // for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
  876. // const Arc &arc = aiter.Value();
  877. // double arc_like = -ConvertToCost(arc.weight);
  878. // double frame_acc = 0.0;
  879. // if (arc.ilabel != 0) {
  880. // int32 cur_time = state_times[s];
  881. // int32 phone = trans.TransitionIdToPhone(arc.ilabel),
  882. // ref_phone = trans.TransitionIdToPhone(num_ali[cur_time]);
  883. // bool phone_is_sil = std::binary_search(silence_phones.begin(),
  884. // silence_phones.end(),
  885. // phone),
  886. // ref_phone_is_sil = std::binary_search(silence_phones.begin(),
  887. // silence_phones.end(),
  888. // ref_phone),
  889. // both_sil = phone_is_sil && ref_phone_is_sil;
  890. // if (!is_mpfe) { // smbr.
  891. // int32 pdf = trans.TransitionIdToPdf(arc.ilabel),
  892. // ref_pdf = trans.TransitionIdToPdf(num_ali[cur_time]);
  893. // if (!one_silence_class) // old behavior
  894. // frame_acc = (pdf == ref_pdf && !phone_is_sil) ? 1.0 : 0.0;
  895. // else
  896. // frame_acc = (pdf == ref_pdf || both_sil) ? 1.0 : 0.0;
  897. // } else {
  898. // if (!one_silence_class) // old behavior
  899. // frame_acc = (phone == ref_phone && !phone_is_sil) ? 1.0 : 0.0;
  900. // else
  901. // frame_acc = (phone == ref_phone || both_sil) ? 1.0 : 0.0;
  902. // }
  903. // }
  904. // double arc_scale = Exp(alpha[s] + arc_like - alpha[arc.nextstate]);
  905. // alpha_smbr[arc.nextstate] += arc_scale * (alpha_smbr[s] + frame_acc);
  906. // }
  907. // Weight f = lat.Final(s);
  908. // if (f != Weight::Zero()) {
  909. // double final_like = this_alpha - (f.Value1() + f.Value2());
  910. // double arc_scale = Exp(final_like - tot_forward_prob);
  911. // tot_forward_score += arc_scale * alpha_smbr[s];
  912. // KALDI_ASSERT(state_times[s] == max_time &&
  913. // "Lattice is inconsistent (final-prob not at max_time)");
  914. // }
  915. // }
  916. // // Second Pass Backward, collect Mpe style posteriors
  917. // for (StateId s = num_states-1; s >= 0; s--) {
  918. // for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
  919. // const Arc &arc = aiter.Value();
  920. // double arc_like = -ConvertToCost(arc.weight),
  921. // arc_beta = beta[arc.nextstate] + arc_like;
  922. // double frame_acc = 0.0;
  923. // int32 transition_id = arc.ilabel;
  924. // if (arc.ilabel != 0) {
  925. // int32 cur_time = state_times[s];
  926. // int32 phone = trans.TransitionIdToPhone(arc.ilabel),
  927. // ref_phone = trans.TransitionIdToPhone(num_ali[cur_time]);
  928. // bool phone_is_sil = std::binary_search(silence_phones.begin(),
  929. // silence_phones.end(), phone),
  930. // ref_phone_is_sil = std::binary_search(silence_phones.begin(),
  931. // silence_phones.end(),
  932. // ref_phone),
  933. // both_sil = phone_is_sil && ref_phone_is_sil;
  934. // if (!is_mpfe) { // smbr.
  935. // int32 pdf = trans.TransitionIdToPdf(arc.ilabel),
  936. // ref_pdf = trans.TransitionIdToPdf(num_ali[cur_time]);
  937. // if (!one_silence_class) // old behavior
  938. // frame_acc = (pdf == ref_pdf && !phone_is_sil) ? 1.0 : 0.0;
  939. // else
  940. // frame_acc = (pdf == ref_pdf || both_sil) ? 1.0 : 0.0;
  941. // } else {
  942. // if (!one_silence_class) // old behavior
  943. // frame_acc = (phone == ref_phone && !phone_is_sil) ? 1.0 : 0.0;
  944. // else
  945. // frame_acc = (phone == ref_phone || both_sil) ? 1.0 : 0.0;
  946. // }
  947. // }
  948. // double arc_scale = Exp(beta[arc.nextstate] + arc_like - beta[s]);
  949. // // Check whether arc_scale is NaN and, if so, zero it out.
  950. // // This handles partial paths in the lattice,
  951. // // i.e. paths that do not survive to a final state.
  952. // if (KALDI_ISNAN(arc_scale)) arc_scale = 0;
  953. // beta_smbr[s] += arc_scale * (beta_smbr[arc.nextstate] + frame_acc);
  954. //
  955. // if (transition_id != 0) {
  956. // // Arc has a transition-id on it [not epsilon]
  957. // double posterior = Exp(alpha[s] + arc_beta - tot_forward_prob);
  958. // double acc_diff = alpha_smbr[s] + frame_acc +
  959. // beta_smbr[arc.nextstate]
  960. // - tot_forward_score;
  961. // double posterior_smbr = posterior * acc_diff;
  962. // (*post)[state_times[s]].push_back(std::make_pair(transition_id,
  963. // static_cast<BaseFloat>(posterior_smbr)));
  964. // }
  965. // }
  966. // }
  967. //
  968. // //Second Pass Forward Backward check
  969. // double tot_backward_score = beta_smbr[0]; // Initial state id == 0
  970. // // may loosen the condition somewhat here: 1e-5/1e-4
  971. // if (!ApproxEqual(tot_forward_score, tot_backward_score, 1e-4)) {
  972. // KALDI_ERR << "Total forward score over lattice = " << tot_forward_score
  973. // << ", while total backward score = " << tot_backward_score;
  974. // }
  975. //
  976. // // Output the computed posteriors
  977. // for (int32 t = 0; t < max_time; t++)
  978. // MergePairVectorSumming(&((*post)[t]));
  979. // return tot_forward_score;
  980. // }
  981. //
  982. // bool CompactLatticeToWordAlignment(const CompactLattice &clat,
  983. // std::vector<int32> *words,
  984. // std::vector<int32> *begin_times,
  985. // std::vector<int32> *lengths) {
  986. // words->clear();
  987. // begin_times->clear();
  988. // lengths->clear();
  989. // typedef CompactLattice::Arc Arc;
  990. // typedef Arc::Label Label;
  991. // typedef CompactLattice::StateId StateId;
  992. // typedef CompactLattice::Weight Weight;
  993. // using namespace fst;
  994. // StateId state = clat.Start();
  995. // int32 cur_time = 0;
  996. // if (state == kNoStateId) {
  997. // KALDI_WARN << "Empty lattice.";
  998. // return false;
  999. // }
  1000. // while (1) {
  1001. // Weight final = clat.Final(state);
  1002. // size_t num_arcs = clat.NumArcs(state);
  1003. // if (final != Weight::Zero()) {
  1004. // if (num_arcs != 0) {
  1005. // KALDI_WARN << "Lattice is not linear.";
  1006. // return false;
  1007. // }
  1008. // if (! final.String().empty()) {
  1009. // KALDI_WARN << "Lattice has alignments on final-weight: probably "
  1010. // "was not word-aligned (alignments will be approximate)";
  1011. // }
  1012. // return true;
  1013. // } else {
  1014. // if (num_arcs != 1) {
  1015. // KALDI_WARN << "Lattice is not linear: num-arcs = " << num_arcs;
  1016. // return false;
  1017. // }
  1018. // fst::ArcIterator<CompactLattice> aiter(clat, state);
  1019. // const Arc &arc = aiter.Value();
  1020. // Label word_id = arc.ilabel; // Note: ilabel==olabel, since acceptor.
  1021. // // Also note: word_id may be zero; we output it anyway.
  1022. // int32 length = arc.weight.String().size();
  1023. // words->push_back(word_id);
  1024. // begin_times->push_back(cur_time);
  1025. // lengths->push_back(length);
  1026. // cur_time += length;
  1027. // state = arc.nextstate;
  1028. // }
  1029. // }
  1030. // }
  1031. //
  1032. //
  1033. // bool CompactLatticeToWordProns(
  1034. // const TransitionModel &tmodel,
  1035. // const CompactLattice &clat,
  1036. // std::vector<int32> *words,
  1037. // std::vector<int32> *begin_times,
  1038. // std::vector<int32> *lengths,
  1039. // std::vector<std::vector<int32> > *prons,
  1040. // std::vector<std::vector<int32> > *phone_lengths) {
  1041. // words->clear();
  1042. // begin_times->clear();
  1043. // lengths->clear();
  1044. // prons->clear();
  1045. // phone_lengths->clear();
  1046. // typedef CompactLattice::Arc Arc;
  1047. // typedef Arc::Label Label;
  1048. // typedef CompactLattice::StateId StateId;
  1049. // typedef CompactLattice::Weight Weight;
  1050. // using namespace fst;
  1051. // StateId state = clat.Start();
  1052. // int32 cur_time = 0;
  1053. // if (state == kNoStateId) {
  1054. // KALDI_WARN << "Empty lattice.";
  1055. // return false;
  1056. // }
  1057. // while (1) {
  1058. // Weight final = clat.Final(state);
  1059. // size_t num_arcs = clat.NumArcs(state);
  1060. // if (final != Weight::Zero()) {
  1061. // if (num_arcs != 0) {
  1062. // KALDI_WARN << "Lattice is not linear.";
  1063. // return false;
  1064. // }
  1065. // if (! final.String().empty()) {
  1066. // KALDI_WARN << "Lattice has alignments on final-weight: probably "
  1067. // "was not word-aligned (alignments will be approximate)";
  1068. // }
  1069. // return true;
  1070. // } else {
  1071. // if (num_arcs != 1) {
  1072. // KALDI_WARN << "Lattice is not linear: num-arcs = " << num_arcs;
  1073. // return false;
  1074. // }
  1075. // fst::ArcIterator<CompactLattice> aiter(clat, state);
  1076. // const Arc &arc = aiter.Value();
  1077. // Label word_id = arc.ilabel; // Note: ilabel==olabel, since acceptor.
  1078. // // Also note: word_id may be zero; we output it anyway.
  1079. // int32 length = arc.weight.String().size();
  1080. // words->push_back(word_id);
  1081. // begin_times->push_back(cur_time);
  1082. // lengths->push_back(length);
  1083. // const std::vector<int32> &arc_alignment = arc.weight.String();
  1084. // std::vector<std::vector<int32> > split_alignment;
  1085. // SplitToPhones(tmodel, arc_alignment, &split_alignment);
  1086. // std::vector<int32> phones(split_alignment.size());
  1087. // std::vector<int32> plengths(split_alignment.size());
  1088. // for (size_t i = 0; i < split_alignment.size(); i++) {
  1089. // KALDI_ASSERT(!split_alignment[i].empty());
  1090. // phones[i] = tmodel.TransitionIdToPhone(split_alignment[i][0]);
  1091. // plengths[i] = split_alignment[i].size();
  1092. // }
  1093. // prons->push_back(phones);
  1094. // phone_lengths->push_back(plengths);
  1095. //
  1096. // cur_time += length;
  1097. // state = arc.nextstate;
  1098. // }
  1099. // }
  1100. // }
  1101. //
  1102. //
  1103. //
  1104. // void CompactLatticeShortestPath(const CompactLattice &clat,
  1105. // CompactLattice *shortest_path) {
  1106. // using namespace fst;
  1107. // if (clat.Properties(fst::kTopSorted, true) == 0) {
  1108. // CompactLattice clat_copy(clat);
  1109. // if (!TopSort(&clat_copy))
  1110. // KALDI_ERR << "Was not able to topologically sort lattice "
  1111. // "(cycles found?)";
  1112. // CompactLatticeShortestPath(clat_copy, shortest_path);
  1113. // return;
  1114. // }
  1115. // // Now we can assume it's topologically sorted.
  1116. // shortest_path->DeleteStates();
  1117. // if (clat.Start() == kNoStateId) return;
  1118. // typedef CompactLatticeArc Arc;
  1119. // typedef Arc::StateId StateId;
  1120. // typedef CompactLatticeWeight Weight;
  1121. // vector<std::pair<double, StateId> > best_cost_and_pred(clat.NumStates() + 1);
  1122. // StateId superfinal = clat.NumStates();
  1123. // for (StateId s = 0; s <= clat.NumStates(); s++) {
  1124. // best_cost_and_pred[s].first = std::numeric_limits<double>::infinity();
  1125. // best_cost_and_pred[s].second = fst::kNoStateId;
  1126. // }
  1127. // best_cost_and_pred[clat.Start()].first = 0;
  1128. // for (StateId s = 0; s < clat.NumStates(); s++) {
  1129. // double my_cost = best_cost_and_pred[s].first;
  1130. // for (ArcIterator<CompactLattice> aiter(clat, s);
  1131. // !aiter.Done();
  1132. // aiter.Next()) {
  1133. // const Arc &arc = aiter.Value();
  1134. // double arc_cost = ConvertToCost(arc.weight),
  1135. // next_cost = my_cost + arc_cost;
  1136. // if (next_cost < best_cost_and_pred[arc.nextstate].first) {
  1137. // best_cost_and_pred[arc.nextstate].first = next_cost;
  1138. // best_cost_and_pred[arc.nextstate].second = s;
  1139. // }
  1140. // }
  1141. // double final_cost = ConvertToCost(clat.Final(s)),
  1142. // tot_final = my_cost + final_cost;
  1143. // if (tot_final < best_cost_and_pred[superfinal].first) {
  1144. // best_cost_and_pred[superfinal].first = tot_final;
  1145. // best_cost_and_pred[superfinal].second = s;
  1146. // }
  1147. // }
  1148. // std::vector<StateId> states; // states on best path.
  1149. // StateId cur_state = superfinal, start_state = clat.Start();
  1150. // while (cur_state != start_state) {
  1151. // StateId prev_state = best_cost_and_pred[cur_state].second;
  1152. // if (prev_state == kNoStateId) {
  1153. // KALDI_WARN << "Failure in best-path algorithm for lattice (infinite costs?)";
  1154. // return; // return empty best-path.
  1155. // }
  1156. // states.push_back(prev_state);
  1157. // KALDI_ASSERT(cur_state != prev_state && "Lattice with cycles");
  1158. // cur_state = prev_state;
  1159. // }
  1160. // std::reverse(states.begin(), states.end());
  1161. // for (size_t i = 0; i < states.size(); i++)
  1162. // shortest_path->AddState();
  1163. // for (StateId s = 0; static_cast<size_t>(s) < states.size(); s++) {
  1164. // if (s == 0) shortest_path->SetStart(s);
  1165. // if (static_cast<size_t>(s + 1) < states.size()) {
  1166. // // transition to next state.
  1167. // bool have_arc = false;
  1168. // Arc cur_arc;
  1169. // for (ArcIterator<CompactLattice> aiter(clat, states[s]);
  1170. // !aiter.Done();
  1171. // aiter.Next()) {
  1172. // const Arc &arc = aiter.Value();
  1173. // if (arc.nextstate == states[s+1]) {
  1174. // if (!have_arc ||
  1175. // ConvertToCost(arc.weight) < ConvertToCost(cur_arc.weight)) {
  1176. // cur_arc = arc;
  1177. // have_arc = true;
  1178. // }
  1179. // }
  1180. // }
  1181. // KALDI_ASSERT(have_arc && "Code error.");
  1182. // shortest_path->AddArc(s, Arc(cur_arc.ilabel, cur_arc.olabel,
  1183. // cur_arc.weight, s+1));
  1184. // } else { // final-prob.
  1185. // shortest_path->SetFinal(s, clat.Final(states[s]));
  1186. // }
  1187. // }
  1188. // }
  1189. //
  1190. //
  1191. // void ExpandCompactLattice(const CompactLattice &clat,
  1192. // double epsilon,
  1193. // CompactLattice *expand_clat) {
  1194. // using namespace fst;
  1195. // typedef CompactLattice::Arc Arc;
  1196. // typedef Arc::Weight Weight;
  1197. // typedef Arc::StateId StateId;
  1198. // typedef std::pair<StateId, StateId> StatePair;
  1199. // typedef unordered_map<StatePair, StateId, PairHasher<StateId> > MapType;
  1200. // typedef MapType::iterator IterType;
  1201. //
  1202. // if (clat.Start() == kNoStateId) return;
  1203. // // Make sure the input lattice is topologically sorted.
  1204. // if (clat.Properties(kTopSorted, true) == 0) {
  1205. // CompactLattice clat_copy(clat);
  1206. // KALDI_LOG << "Topsort this lattice.";
  1207. // if (!TopSort(&clat_copy))
  1208. // KALDI_ERR << "Was not able to topologically sort lattice "
  1209. // "(cycles found?)";
  1210. // ExpandCompactLattice(clat_copy, epsilon, expand_clat);
  1211. // return;
  1212. // }
  1213. //
  1214. // // Compute backward logprobs betas for the expanded lattice.
  1215. // // Note: the backward logprobs in the original lattice <clat> and the
  1216. // // expanded lattice <expand_clat> are the same.
  1217. // int32 num_states = clat.NumStates();
  1218. // std::vector<double> beta(num_states, kLogZeroDouble);
  1219. // ComputeCompactLatticeBetas(clat, &beta);
  1220. // double tot_backward_logprob = beta[0];
  1221. // std::vector<double> alpha;
  1222. // alpha.push_back(0.0);
  1223. // expand_clat->DeleteStates();
  1224. // MapType state_map; // Map from state pair (orig_state, copy_state) to
  1225. // // copy_state, where orig_state is a state in the original lattice, and
  1226. // // copy_state is its corresponding one in the expanded lattice.
  1227. // unordered_map<StateId, StateId> states; // Map from orig_state to its
  1228. // // copy_state for states with incoming arcs' posteriors <= epsilon.
  1229. // std::queue<StatePair> state_queue;
  1230. //
  1231. // // Set start state in the expanded lattice.
  1232. // StateId start_state = expand_clat->AddState();
  1233. // expand_clat->SetStart(start_state);
  1234. // StatePair start_pair(clat.Start(), start_state);
  1235. // state_queue.push(start_pair);
  1236. // std::pair<IterType, bool> result =
  1237. // state_map.insert(std::make_pair(start_pair, start_state));
  1238. // KALDI_ASSERT(result.second == true);
  1239. //
  1240. // // Expand <clat> and update forward logprobs alphas in <expand_clat>.
  1241. // while (!state_queue.empty()) {
  1242. // StatePair s = state_queue.front();
  1243. // StateId s1 = s.first,
  1244. // s2 = s.second;
  1245. // state_queue.pop();
  1246. //
  1247. // Weight f = clat.Final(s1);
  1248. // if (f != Weight::Zero()) {
  1249. // KALDI_ASSERT(state_map.find(s) != state_map.end());
  1250. // expand_clat->SetFinal(state_map[s], f);
  1251. // }
  1252. //
  1253. // for (ArcIterator<CompactLattice> aiter(clat, s1);
  1254. // !aiter.Done(); aiter.Next()) {
  1255. // const Arc &arc = aiter.Value();
  1256. // StateId orig_state = arc.nextstate;
  1257. // double arc_like = -ConvertToCost(arc.weight),
  1258. // this_alpha = alpha[s2] + arc_like,
  1259. // arc_post = Exp(this_alpha + beta[orig_state] -
  1260. // tot_backward_logprob);
  1261. // // Generate the expanded lattice.
  1262. // StateId copy_state;
  1263. // if (arc_post > epsilon) {
  1264. // copy_state = expand_clat->AddState();
  1265. // StatePair next_pair(orig_state, copy_state);
  1266. // std::pair<IterType, bool> result =
  1267. // state_map.insert(std::make_pair(next_pair, copy_state));
  1268. // KALDI_ASSERT(result.second == true);
  1269. // state_queue.push(next_pair);
  1270. // } else {
  1271. // unordered_map<StateId, StateId>::iterator iter =
  1272. // states.find(orig_state);
  1273. // if (iter == states.end() ) {
  1274. // // The counterpart state of orig_state has not been created in
  1275. // // <expand_clat> yet.
  1276. // copy_state = expand_clat->AddState();
  1277. // StatePair next_pair(orig_state, copy_state);
  1278. // std::pair<IterType, bool> result =
  1279. // state_map.insert(std::make_pair(next_pair, copy_state));
  1280. // KALDI_ASSERT(result.second == true);
  1281. // state_queue.push(next_pair);
  1282. // states[orig_state] = copy_state;
  1283. // } else {
  1284. // copy_state = iter->second;
  1285. // }
  1286. // }
  1287. // // Create an arc from state_map[s] to copy_state in the expanded lattice.
  1288. // expand_clat->AddArc(state_map[s], Arc(arc.ilabel, arc.olabel,
  1289. // arc.weight,
  1290. // copy_state));
  1291. // // Compute forward logprobs alpha for the expanded lattice.
  1292. // if ((alpha.size() - 1) < copy_state) {
  1293. // // The first time to compute
  1294. // // alpha for copy_state in
  1295. // // <expand_clat>.
  1296. // alpha.push_back(this_alpha);
  1297. // } else { // Accumulate alpha.
  1298. // alpha[copy_state] = LogAdd(alpha[copy_state], this_alpha);
  1299. // }
  1300. // }
  1301. // } // end while
  1302. // }
  1303. //
  1304. //
  1305. // void CompactLatticeBestCostsAndTracebacks(
  1306. // const CompactLattice &clat,
  1307. // CostTraceType *forward_best_cost_and_pred,
  1308. // CostTraceType *backward_best_cost_and_pred) {
  1309. //
  1310. // // typedef the arc, weight types
  1311. // typedef CompactLatticeArc Arc;
  1312. // typedef Arc::Weight Weight;
  1313. // typedef Arc::StateId StateId;
  1314. //
  1315. // forward_best_cost_and_pred->clear();
  1316. // backward_best_cost_and_pred->clear();
  1317. // forward_best_cost_and_pred->resize(clat.NumStates());
  1318. // backward_best_cost_and_pred->resize(clat.NumStates());
  1319. // // Initialize the cost and predecessor state for each state.
  1320. // for (StateId s = 0; s < clat.NumStates(); s++) {
  1321. // (*forward_best_cost_and_pred)[s].first =
  1322. // std::numeric_limits<double>::infinity();
  1323. // (*backward_best_cost_and_pred)[s].first =
  1324. // std::numeric_limits<double>::infinity();
  1325. // (*forward_best_cost_and_pred)[s].second = fst::kNoStateId;
  1326. // (*backward_best_cost_and_pred)[s].second = fst::kNoStateId;
  1327. // }
  1328. //
  1329. // StateId start_state = clat.Start();
  1330. // (*forward_best_cost_and_pred)[start_state].first = 0;
  1331. // // Traverse the lattice forward to compute the best cost from the start
  1332. // // state to each state and the best predecessor state of each state.
  1333. // for (StateId s = 0; s < clat.NumStates(); s++) {
  1334. // double cur_cost = (*forward_best_cost_and_pred)[s].first;
  1335. // for (fst::ArcIterator<CompactLattice> aiter(clat, s);
  1336. // !aiter.Done(); aiter.Next()) {
  1337. // const Arc &arc = aiter.Value();
  1338. // double next_cost = cur_cost + ConvertToCost(arc.weight);
  1339. // if (next_cost < (*forward_best_cost_and_pred)[arc.nextstate].first) {
  1340. // (*forward_best_cost_and_pred)[arc.nextstate].first = next_cost;
  1341. // (*forward_best_cost_and_pred)[arc.nextstate].second = s;
  1342. // }
  1343. // }
  1344. // }
  1345. // // Traverse the lattice backward to compute the best cost from a final
  1346. // // state to each state and the best predecessor state of each state.
  1347. // for (StateId s = clat.NumStates() - 1; s >= 0; s--) {
  1348. // double this_cost = ConvertToCost(clat.Final(s));
  1349. // for (fst::ArcIterator<CompactLattice> aiter(clat, s);
  1350. // !aiter.Done(); aiter.Next()) {
  1351. // const Arc &arc = aiter.Value();
  1352. // double next_cost = (*backward_best_cost_and_pred)[arc.nextstate].first
  1353. // +
  1354. // ConvertToCost(arc.weight);
  1355. // if (next_cost < this_cost) {
  1356. // this_cost = next_cost;
  1357. // (*backward_best_cost_and_pred)[s].second = arc.nextstate;
  1358. // }
  1359. // }
  1360. // (*backward_best_cost_and_pred)[s].first = this_cost;
  1361. // }
  1362. // }
  1363. //
  1364. //
  1365. // void AddNnlmScoreToCompactLattice(const MapT &nnlm_scores,
  1366. // CompactLattice *clat) {
  1367. // if (clat->Start() == fst::kNoStateId) return;
  1368. // // Make sure the input lattice is topologically sorted.
  1369. // if (clat->Properties(fst::kTopSorted, true) == 0) {
  1370. // KALDI_LOG << "Topsort this lattice.";
  1371. // if (!TopSort(clat))
  1372. // KALDI_ERR << "Was not able to topologically sort lattice "
  1373. // "(cycles found?)";
  1374. // AddNnlmScoreToCompactLattice(nnlm_scores, clat);
  1375. // return;
  1376. // }
  1377. //
  1378. // // typedef the arc, weight types
  1379. // typedef CompactLatticeArc Arc;
  1380. // typedef Arc::Weight Weight;
  1381. // typedef Arc::StateId StateId;
  1382. // typedef std::pair<int32, int32> StatePair;
  1383. //
  1384. // int32 num_states = clat->NumStates();
  1385. // unordered_map<StatePair, bool, PairHasher<int32> > final_state_check;
  1386. // for (StateId s = 0; s < num_states; s++) {
  1387. // for (fst::MutableArcIterator<CompactLattice> aiter(clat, s);
  1388. // !aiter.Done(); aiter.Next()) {
  1389. // Arc arc(aiter.Value());
  1390. // StatePair arc_index = std::make_pair(static_cast<int32>(s),
  1391. // static_cast<int32>(arc.nextstate));
  1392. // MapT::const_iterator it = nnlm_scores.find(arc_index);
  1393. // double nnlm_score;
  1394. // if (it != nnlm_scores.end())
  1395. // nnlm_score = it->second;
  1396. // else
  1397. // KALDI_ERR << "Some arc does not have neural language model score.";
  1398. // if (arc.ilabel != 0) { // if there is a word on this arc
  1399. // LatticeWeight weight = arc.weight.Weight();
  1400. // // Add associated neural LM score to each arc.
  1401. // weight.SetValue1(weight.Value1() + nnlm_score);
  1402. // arc.weight.SetWeight(weight);
  1403. // aiter.SetValue(arc);
  1404. // }
  1405. // Weight clat_final = clat->Final(arc.nextstate);
  1406. // StatePair final_pair = std::make_pair(arc.nextstate, arc.nextstate);
  1407. // // Add neural LM scores to each final state only once.
  1408. // if (clat_final != CompactLatticeWeight::Zero() &&
  1409. // final_state_check.find(final_pair) == final_state_check.end()) {
  1410. // MapT::const_iterator final_it = nnlm_scores.find(final_pair);
  1411. // double final_nnlm_score = 0.0;
  1412. // if (final_it != nnlm_scores.end())
  1413. // final_nnlm_score = final_it->second;
  1414. // // Add neural LM scores to the final weight.
  1415. // Weight final_weight(LatticeWeight(clat_final.Weight().Value1() +
  1416. // final_nnlm_score,
  1417. // clat_final.Weight().Value2()),
  1418. // clat_final.String());
  1419. // clat->SetFinal(arc.nextstate, final_weight);
  1420. // final_state_check[final_pair] = true;
  1421. // }
  1422. // } // end looping over arcs
  1423. // } // end looping over states
  1424. // }
  1425. //
  1426. // void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty,
  1427. // CompactLattice *clat) {
  1428. // typedef CompactLatticeArc Arc;
  1429. // int32 num_states = clat->NumStates();
  1430. //
  1431. // //scan the lattice
  1432. // for (int32 state = 0; state < num_states; state++) {
  1433. // for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
  1434. // !aiter.Done(); aiter.Next()) {
  1435. //
  1436. // Arc arc(aiter.Value());
  1437. //
  1438. // if (arc.ilabel != 0) { // if there is a word on this arc
  1439. // LatticeWeight weight = arc.weight.Weight();
  1440. // // add word insertion penalty to lattice
  1441. // weight.SetValue1( weight.Value1() + word_ins_penalty);
  1442. // arc.weight.SetWeight(weight);
  1443. // aiter.SetValue(arc);
  1444. // }
  1445. // } // end looping over arcs
  1446. // } // end looping over states
  1447. // }
  1448. //
  1449. // struct ClatRescoreTuple {
  1450. // ClatRescoreTuple(int32 state, int32 arc, int32 tid):
  1451. // state_id(state), arc_id(arc), tid(tid) { }
  1452. // int32 state_id;
  1453. // int32 arc_id;
  1454. // int32 tid;
  1455. // };
  1456. //
  1457. // /** RescoreCompactLatticeInternal is the internal code for both
  1458. // RescoreCompactLattice and RescoreCompactLatticeSpeedup. For
  1459. // RescoreCompactLattice, "tmodel" will be NULL and speedup_factor will
  1460. // be 1.0.
  1461. // */
  1462. // bool RescoreCompactLatticeInternal(
  1463. // const TransitionModel *tmodel,
  1464. // BaseFloat speedup_factor,
  1465. // DecodableInterface *decodable,
  1466. // CompactLattice *clat) {
  1467. // KALDI_ASSERT(speedup_factor >= 1.0);
  1468. // if (clat->NumStates() == 0) {
  1469. // KALDI_WARN << "Rescoring empty lattice";
  1470. // return false;
  1471. // }
  1472. // if (!clat->Properties(fst::kTopSorted, true)) {
  1473. // if (fst::TopSort(clat) == false) {
  1474. // KALDI_WARN << "Cycles detected in lattice.";
  1475. // return false;
  1476. // }
  1477. // }
  1478. // std::vector<int32> state_times;
  1479. // int32 utt_len = kaldi::CompactLatticeStateTimes(*clat, &state_times);
  1480. //
  1481. // std::vector<std::vector<ClatRescoreTuple> > time_to_state(utt_len);
  1482. //
  1483. // int32 num_states = clat->NumStates();
  1484. // KALDI_ASSERT(num_states == state_times.size());
  1485. // for (size_t state = 0; state < num_states; state++) {
  1486. // KALDI_ASSERT(state_times[state] >= 0);
  1487. // int32 t = state_times[state];
  1488. // int32 arc_id = 0;
  1489. // for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
  1490. // !aiter.Done(); aiter.Next(), arc_id++) {
  1491. // CompactLatticeArc arc = aiter.Value();
  1492. // std::vector<int32> arc_string = arc.weight.String();
  1493. //
  1494. // for (size_t offset = 0; offset < arc_string.size(); offset++) {
  1495. // if (t < utt_len) { // end state may be past this..
  1496. // int32 tid = arc_string[offset];
  1497. // time_to_state[t+offset].push_back(ClatRescoreTuple(state, arc_id,
  1498. // tid));
  1499. // } else {
  1500. // if (t != utt_len) {
  1501. // KALDI_WARN << "There appears to be lattice/feature mismatch, "
  1502. // << "aborting.";
  1503. // return false;
  1504. // }
  1505. // }
  1506. // }
  1507. // }
  1508. // if (clat->Final(state) != CompactLatticeWeight::Zero()) {
  1509. // arc_id = -1;
  1510. // std::vector<int32> arc_string = clat->Final(state).String();
  1511. // for (size_t offset = 0; offset < arc_string.size(); offset++) {
  1512. // KALDI_ASSERT(t + offset < utt_len); // already checked in
  1513. // // CompactLatticeStateTimes, so would be code error.
  1514. // time_to_state[t+offset].push_back(
  1515. // ClatRescoreTuple(state, arc_id, arc_string[offset]));
  1516. // }
  1517. // }
  1518. // }
  1519. //
  1520. // for (int32 t = 0; t < utt_len; t++) {
  1521. // if ((t < utt_len - 1) && decodable->IsLastFrame(t)) {
  1522. // KALDI_WARN << "Features are too short for lattice: utt-len is "
  1523. // << utt_len << ", " << t << " is last frame";
  1524. // return false;
  1525. // }
  1526. // // frame_scale is the scale we put on the computed acoustic probs for
  1527. // this
  1528. // // frame. It will always be 1.0 if tmodel == NULL (i.e. if we are not
  1529. // doing
  1530. // // the "speedup" code). For frames with multiple pdf-ids it will be one.
  1531. // // For frames with only one pdf-id, it will equal speedup_factor (>=1.0)
  1532. // // with probability 1.0 / speedup_factor, and zero otherwise. If it is
  1533. // zero,
  1534. // // we can avoid computing the probabilities.
  1535. // BaseFloat frame_scale = 1.0;
  1536. // KALDI_ASSERT(!time_to_state[t].empty());
  1537. // if (tmodel != NULL) {
  1538. // int32 pdf_id = tmodel->TransitionIdToPdf(time_to_state[t][0].tid);
  1539. // bool frame_has_multiple_pdfs = false;
  1540. // for (size_t i = 1; i < time_to_state[t].size(); i++) {
  1541. // if (tmodel->TransitionIdToPdf(time_to_state[t][i].tid) != pdf_id) {
  1542. // frame_has_multiple_pdfs = true;
  1543. // break;
  1544. // }
  1545. // }
  1546. // if (frame_has_multiple_pdfs) {
  1547. // frame_scale = 1.0;
  1548. // } else {
  1549. // if (WithProb(1.0 / speedup_factor)) {
  1550. // frame_scale = speedup_factor;
  1551. // } else {
  1552. // frame_scale = 0.0;
  1553. // }
  1554. // }
  1555. // if (frame_scale == 0.0)
  1556. // continue; // the code below would be pointless.
  1557. // }
  1558. //
  1559. // for (size_t i = 0; i < time_to_state[t].size(); i++) {
  1560. // int32 state = time_to_state[t][i].state_id;
  1561. // int32 arc_id = time_to_state[t][i].arc_id;
  1562. // int32 tid = time_to_state[t][i].tid;
  1563. //
  1564. // if (arc_id == -1) { // Final state
  1565. // // Access the trans_id
  1566. // CompactLatticeWeight curr_clat_weight = clat->Final(state);
  1567. //
  1568. // // Calculate likelihood
  1569. // BaseFloat log_like = decodable->LogLikelihood(t, tid) * frame_scale;
  1570. // // update weight
  1571. // CompactLatticeWeight new_clat_weight = curr_clat_weight;
  1572. // LatticeWeight new_lat_weight = new_clat_weight.Weight();
  1573. // new_lat_weight.SetValue2(-log_like +
  1574. // curr_clat_weight.Weight().Value2());
  1575. // new_clat_weight.SetWeight(new_lat_weight);
  1576. // clat->SetFinal(state, new_clat_weight);
  1577. // } else {
  1578. // fst::MutableArcIterator<CompactLattice> aiter(clat, state);
  1579. //
  1580. // aiter.Seek(arc_id);
  1581. // CompactLatticeArc arc = aiter.Value();
  1582. //
  1583. // // Calculate likelihood
  1584. // BaseFloat log_like = decodable->LogLikelihood(t, tid) * frame_scale;
  1585. // // update weight
  1586. // LatticeWeight new_weight = arc.weight.Weight();
  1587. // new_weight.SetValue2(-log_like + arc.weight.Weight().Value2());
  1588. // arc.weight.SetWeight(new_weight);
  1589. // aiter.SetValue(arc);
  1590. // }
  1591. // }
  1592. // }
  1593. // return true;
  1594. // }
  1595. //
  1596. //
  1597. // bool RescoreCompactLatticeSpeedup(
  1598. // const TransitionModel &tmodel,
  1599. // BaseFloat speedup_factor,
  1600. // DecodableInterface *decodable,
  1601. // CompactLattice *clat) {
  1602. // return RescoreCompactLatticeInternal(&tmodel, speedup_factor, decodable,
  1603. // clat);
  1604. // }
  1605. //
  1606. // bool RescoreCompactLattice(DecodableInterface *decodable,
  1607. // CompactLattice *clat) {
  1608. // return RescoreCompactLatticeInternal(NULL, 1.0, decodable, clat);
  1609. // }
  1610. //
  1611. //
  1612. // bool RescoreLattice(DecodableInterface *decodable,
  1613. // Lattice *lat) {
  1614. // if (lat->NumStates() == 0) {
  1615. // KALDI_WARN << "Rescoring empty lattice";
  1616. // return false;
  1617. // }
  1618. // if (!lat->Properties(fst::kTopSorted, true)) {
  1619. // if (fst::TopSort(lat) == false) {
  1620. // KALDI_WARN << "Cycles detected in lattice.";
  1621. // return false;
  1622. // }
  1623. // }
  1624. // std::vector<int32> state_times;
  1625. // int32 utt_len = kaldi::LatticeStateTimes(*lat, &state_times);
  1626. //
  1627. // std::vector<std::vector<int32> > time_to_state(utt_len );
  1628. //
  1629. // int32 num_states = lat->NumStates();
  1630. // KALDI_ASSERT(num_states == state_times.size());
  1631. // for (size_t state = 0; state < num_states; state++) {
  1632. // int32 t = state_times[state];
  1633. // // Don't check t >= 0 because non-accessible states could have t = -1.
  1634. // KALDI_ASSERT(t <= utt_len);
  1635. // if (t >= 0 && t < utt_len)
  1636. // time_to_state[t].push_back(state);
  1637. // }
  1638. //
  1639. // for (int32 t = 0; t < utt_len; t++) {
  1640. // if ((t < utt_len - 1) && decodable->IsLastFrame(t)) {
  1641. // KALDI_WARN << "Features are too short for lattice: utt-len is "
  1642. // << utt_len << ", " << t << " is last frame";
  1643. // return false;
  1644. // }
  1645. // for (size_t i = 0; i < time_to_state[t].size(); i++) {
  1646. // int32 state = time_to_state[t][i];
  1647. // for (fst::MutableArcIterator<Lattice> aiter(lat, state);
  1648. // !aiter.Done(); aiter.Next()) {
  1649. // LatticeArc arc = aiter.Value();
  1650. // if (arc.ilabel != 0) {
  1651. // int32 trans_id = arc.ilabel; // Note: it doesn't necessarily
  1652. // // have to be a transition-id, just whatever the Decodable
  1653. // // object is expecting, but it's normally a transition-id.
  1654. //
  1655. // BaseFloat log_like = decodable->LogLikelihood(t, trans_id);
  1656. // arc.weight.SetValue2(-log_like + arc.weight.Value2());
  1657. // aiter.SetValue(arc);
  1658. // }
  1659. // }
  1660. // }
  1661. // }
  1662. // return true;
  1663. // }
  1664. //
  1665. //
  1666. // BaseFloat LatticeForwardBackwardMmi(
  1667. // const TransitionModel &tmodel,
  1668. // const Lattice &lat,
  1669. // const std::vector<int32> &num_ali,
  1670. // bool drop_frames,
  1671. // bool convert_to_pdf_ids,
  1672. // bool cancel,
  1673. // Posterior *post) {
  1674. // // First compute the MMI posteriors.
  1675. //
  1676. // Posterior den_post;
  1677. // BaseFloat ans = LatticeForwardBackward(lat,
  1678. // &den_post,
  1679. // NULL);
  1680. //
  1681. // Posterior num_post;
  1682. // AlignmentToPosterior(num_ali, &num_post);
  1683. //
  1684. // // Now negate the MMI posteriors and add the numerator
  1685. // // posteriors.
  1686. // ScalePosterior(-1.0, &den_post);
  1687. //
  1688. // if (convert_to_pdf_ids) {
  1689. // Posterior num_tmp;
  1690. // ConvertPosteriorToPdfs(tmodel, num_post, &num_tmp);
  1691. // num_tmp.swap(num_post);
  1692. // Posterior den_tmp;
  1693. // ConvertPosteriorToPdfs(tmodel, den_post, &den_tmp);
  1694. // den_tmp.swap(den_post);
  1695. // }
  1696. //
  1697. // MergePosteriors(num_post, den_post,
  1698. // cancel, drop_frames, post);
  1699. //
  1700. // return ans;
  1701. // }
  1702. //
  1703. //
  1704. // int32 LongestSentenceLength(const Lattice &lat) {
  1705. // typedef Lattice::Arc Arc;
  1706. // typedef Arc::Label Label;
  1707. // typedef Arc::StateId StateId;
  1708. //
  1709. // if (lat.Properties(fst::kTopSorted, true) == 0) {
  1710. // Lattice lat_copy(lat);
  1711. // if (!TopSort(&lat_copy))
  1712. // KALDI_ERR << "Was not able to topologically sort lattice (cycles
  1713. // found?)";
  1714. // return LongestSentenceLength(lat_copy);
  1715. // }
  1716. // std::vector<int32> max_length(lat.NumStates(), 0);
  1717. // int32 lattice_max_length = 0;
  1718. // for (StateId s = 0; s < lat.NumStates(); s++) {
  1719. // int32 this_max_length = max_length[s];
  1720. // for (fst::ArcIterator<Lattice> aiter(lat, s); !aiter.Done();
  1721. // aiter.Next()) {
  1722. // const Arc &arc = aiter.Value();
  1723. // bool arc_has_word = (arc.olabel != 0);
  1724. // StateId nextstate = arc.nextstate;
  1725. // KALDI_ASSERT(static_cast<size_t>(nextstate) < max_length.size());
  1726. // if (arc_has_word) {
  1727. // // A lattice should ideally not have cycles anyway; a cycle with a
  1728. // word
  1729. // // on is something very bad.
  1730. // KALDI_ASSERT(nextstate > s && "Lattice has cycles with words on.");
  1731. // max_length[nextstate] = std::max(max_length[nextstate],
  1732. // this_max_length + 1);
  1733. // } else {
  1734. // max_length[nextstate] = std::max(max_length[nextstate],
  1735. // this_max_length);
  1736. // }
  1737. // }
  1738. // if (lat.Final(s) != LatticeWeight::Zero())
  1739. // lattice_max_length = std::max(lattice_max_length, max_length[s]);
  1740. // }
  1741. // return lattice_max_length;
  1742. // }
  1743. //
  1744. // int32 LongestSentenceLength(const CompactLattice &clat) {
  1745. // typedef CompactLattice::Arc Arc;
  1746. // typedef Arc::Label Label;
  1747. // typedef Arc::StateId StateId;
  1748. //
  1749. // if (clat.Properties(fst::kTopSorted, true) == 0) {
  1750. // CompactLattice clat_copy(clat);
  1751. // if (!TopSort(&clat_copy))
  1752. // KALDI_ERR << "Was not able to topologically sort lattice (cycles
  1753. // found?)";
  1754. // return LongestSentenceLength(clat_copy);
  1755. // }
  1756. // std::vector<int32> max_length(clat.NumStates(), 0);
  1757. // int32 lattice_max_length = 0;
  1758. // for (StateId s = 0; s < clat.NumStates(); s++) {
  1759. // int32 this_max_length = max_length[s];
  1760. // for (fst::ArcIterator<CompactLattice> aiter(clat, s);
  1761. // !aiter.Done(); aiter.Next()) {
  1762. // const Arc &arc = aiter.Value();
  1763. // bool arc_has_word = (arc.ilabel != 0); // note: olabel == ilabel.
  1764. // // also note: for normal CompactLattice, e.g. as produced by
  1765. // // determinization, all arcs will have nonzero labels, but the user
  1766. // might
  1767. // // decide to replace some of the labels with zero for some reason, and
  1768. // we
  1769. // // want to support this.
  1770. // StateId nextstate = arc.nextstate;
  1771. // KALDI_ASSERT(static_cast<size_t>(nextstate) < max_length.size());
  1772. // KALDI_ASSERT(nextstate > s && "CompactLattice has cycles");
  1773. // if (arc_has_word)
  1774. // max_length[nextstate] = std::max(max_length[nextstate],
  1775. // this_max_length + 1);
  1776. // else
  1777. // max_length[nextstate] = std::max(max_length[nextstate],
  1778. // this_max_length);
  1779. // }
  1780. // if (clat.Final(s) != CompactLatticeWeight::Zero())
  1781. // lattice_max_length = std::max(lattice_max_length, max_length[s]);
  1782. // }
  1783. // return lattice_max_length;
  1784. // }
  1785. //
  1786. // void ComposeCompactLatticeDeterministic(
  1787. // const CompactLattice& clat,
  1788. // fst::DeterministicOnDemandFst<fst::StdArc>* det_fst,
  1789. // CompactLattice* composed_clat) {
  1790. // // StdFst::Arc and CompactLatticeArc have the same StateId type.
  1791. // typedef fst::StdArc::StateId StateId;
  1792. // typedef fst::StdArc::Weight Weight1;
  1793. // typedef CompactLatticeArc::Weight Weight2;
  1794. // typedef std::pair<StateId, StateId> StatePair;
  1795. // typedef unordered_map<StatePair, StateId, PairHasher<StateId> > MapType;
  1796. // typedef MapType::iterator IterType;
  1797. //
  1798. // // Empties the output FST.
  1799. // KALDI_ASSERT(composed_clat != NULL);
  1800. // composed_clat->DeleteStates();
  1801. //
  1802. // MapType state_map;
  1803. // std::queue<StatePair> state_queue;
  1804. //
  1805. // // Sets start state in <composed_clat>.
  1806. // StateId start_state = composed_clat->AddState();
  1807. // StatePair start_pair(clat.Start(), det_fst->Start());
  1808. // composed_clat->SetStart(start_state);
  1809. // state_queue.push(start_pair);
  1810. // std::pair<IterType, bool> result =
  1811. // state_map.insert(std::make_pair(start_pair, start_state));
  1812. // KALDI_ASSERT(result.second == true);
  1813. //
  1814. // // Starts composition here.
  1815. // while (!state_queue.empty()) {
  1816. // // Gets the first state in the queue.
  1817. // StatePair s = state_queue.front();
  1818. // StateId s1 = s.first;
  1819. // StateId s2 = s.second;
  1820. // state_queue.pop();
  1821. //
  1822. //
  1823. // Weight2 clat_final = clat.Final(s1);
  1824. // if (clat_final.Weight().Value1() !=
  1825. // std::numeric_limits<BaseFloat>::infinity()) {
  1826. // // Test for whether the final-prob of state s1 was zero.
  1827. // Weight1 det_fst_final = det_fst->Final(s2);
  1828. // if (det_fst_final.Value() !=
  1829. // std::numeric_limits<BaseFloat>::infinity()) {
  1830. // // Test for whether the final-prob of state s2 was zero. If neither
  1831. // // source-state final prob was zero, then we should create final
  1832. // state
  1833. // // in fst_composed. We compute the product manually since this is
  1834. // more
  1835. // // efficient.
  1836. // Weight2 final_weight(LatticeWeight(clat_final.Weight().Value1() +
  1837. // det_fst_final.Value(),
  1838. // clat_final.Weight().Value2()),
  1839. // clat_final.String());
  1840. // // we can assume final_weight is not Zero(), since neither of
  1841. // // the sources was zero.
  1842. // KALDI_ASSERT(state_map.find(s) != state_map.end());
  1843. // composed_clat->SetFinal(state_map[s], final_weight);
  1844. // }
  1845. // }
  1846. //
  1847. // // Loops over pair of edges at s1 and s2.
  1848. // for (fst::ArcIterator<CompactLattice> aiter(clat, s1);
  1849. // !aiter.Done(); aiter.Next()) {
  1850. // const CompactLatticeArc& arc1 = aiter.Value();
  1851. // fst::StdArc arc2;
  1852. // StateId next_state1 = arc1.nextstate, next_state2;
  1853. // bool matched = false;
  1854. //
  1855. // if (arc1.olabel == 0) {
  1856. // // If the symbol on <arc1> is <epsilon>, we transit to the next state
  1857. // // for <clat>, but keep <det_fst> at the current state.
  1858. // matched = true;
  1859. // next_state2 = s2;
  1860. // } else {
  1861. // // Otherwise try to find the matched arc in <det_fst>.
  1862. // matched = det_fst->GetArc(s2, arc1.olabel, &arc2);
  1863. // if (matched) {
  1864. // next_state2 = arc2.nextstate;
  1865. // }
  1866. // }
  1867. //
  1868. // // If matched arc is found in <det_fst>, then we have to add new arcs
  1869. // to
  1870. // // <composed_clat>.
  1871. // if (matched) {
  1872. // StatePair next_state_pair(next_state1, next_state2);
  1873. // IterType siter = state_map.find(next_state_pair);
  1874. // StateId next_state;
  1875. //
  1876. // // Adds composed state to <state_map>.
  1877. // if (siter == state_map.end()) {
  1878. // // If the composed state has not been created yet, create it.
  1879. // next_state = composed_clat->AddState();
  1880. // std::pair<const StatePair, StateId> next_state_map(next_state_pair,
  1881. // next_state);
  1882. // std::pair<IterType, bool> result =
  1883. // state_map.insert(next_state_map); KALDI_ASSERT(result.second);
  1884. // state_queue.push(next_state_pair);
  1885. // } else {
  1886. // // If the composed state is already in <state_map>, we can directly
  1887. // // use that.
  1888. // next_state = siter->second;
  1889. // }
  1890. //
  1891. // // Adds arc to <composed_clat>.
  1892. // if (arc1.olabel == 0) {
  1893. // composed_clat->AddArc(state_map[s],
  1894. // CompactLatticeArc(arc1.ilabel, 0,
  1895. // arc1.weight, next_state));
  1896. // } else {
  1897. // Weight2 composed_weight(
  1898. // LatticeWeight(arc1.weight.Weight().Value1() +
  1899. // arc2.weight.Value(),
  1900. // arc1.weight.Weight().Value2()),
  1901. // arc1.weight.String());
  1902. // composed_clat->AddArc(state_map[s],
  1903. // CompactLatticeArc(arc1.ilabel, arc2.olabel,
  1904. // composed_weight,
  1905. // next_state));
  1906. // }
  1907. // }
  1908. // }
  1909. // }
  1910. // fst::Connect(composed_clat);
  1911. // }
  1912. //
  1913. //
  1914. // void ComputeAcousticScoresMap(
  1915. // const Lattice &lat,
  1916. // unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
  1917. // PairHasher<int32> > *acoustic_scores)
  1918. // {
  1919. // // typedef the arc, weight types
  1920. // typedef Lattice::Arc Arc;
  1921. // typedef Arc::Weight LatticeWeight;
  1922. // typedef Arc::StateId StateId;
  1923. //
  1924. // acoustic_scores->clear();
  1925. //
  1926. // std::vector<int32> state_times;
  1927. // LatticeStateTimes(lat, &state_times); // Assumes the input is top sorted
  1928. //
  1929. // KALDI_ASSERT(lat.Start() == 0);
  1930. //
  1931. // for (StateId s = 0; s < lat.NumStates(); s++) {
  1932. // int32 t = state_times[s];
  1933. // for (fst::ArcIterator<Lattice> aiter(lat, s); !aiter.Done();
  1934. // aiter.Next()) {
  1935. // const Arc &arc = aiter.Value();
  1936. // const LatticeWeight &weight = arc.weight;
  1937. //
  1938. // int32 tid = arc.ilabel;
  1939. //
  1940. // if (tid != 0) {
  1941. // unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
  1942. // PairHasher<int32> >::iterator it =
  1943. // acoustic_scores->find(std::make_pair(t, tid));
  1944. // if (it == acoustic_scores->end()) {
  1945. // acoustic_scores->insert(std::make_pair(std::make_pair(t, tid),
  1946. // std::make_pair(weight.Value2(),
  1947. // 1)));
  1948. // } else {
  1949. // if (it->second.second == 2
  1950. // && it->second.first / it->second.second != weight.Value2()) {
  1951. // KALDI_VLOG(2) << "Transitions on the same frame have different "
  1952. // << "acoustic costs for tid " << tid << "; "
  1953. // << it->second.first / it->second.second
  1954. // << " vs " << weight.Value2();
  1955. // }
  1956. // it->second.first += weight.Value2();
  1957. // it->second.second++;
  1958. // }
  1959. // } else {
  1960. // // Arcs with epsilon input label (tid) must have 0 acoustic cost
  1961. // KALDI_ASSERT(weight.Value2() == 0);
  1962. // }
  1963. // }
  1964. //
  1965. // LatticeWeight f = lat.Final(s);
  1966. // if (f != LatticeWeight::Zero()) {
  1967. // // Final acoustic cost must be 0 as we are reading from
  1968. // // non-determinized, non-compact lattice
  1969. // KALDI_ASSERT(f.Value2() == 0.0);
  1970. // }
  1971. // }
  1972. // }
  1973. //
  1974. // void ReplaceAcousticScoresFromMap(
  1975. // const unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
  1976. // PairHasher<int32> > &acoustic_scores,
  1977. // Lattice *lat) {
  1978. // // typedef the arc, weight types
  1979. // typedef Lattice::Arc Arc;
  1980. // typedef Arc::Weight LatticeWeight;
  1981. // typedef Arc::StateId StateId;
  1982. //
  1983. // TopSortLatticeIfNeeded(lat);
  1984. //
  1985. // std::vector<int32> state_times;
  1986. // LatticeStateTimes(*lat, &state_times);
  1987. //
  1988. // KALDI_ASSERT(lat->Start() == 0);
  1989. //
  1990. // for (StateId s = 0; s < lat->NumStates(); s++) {
  1991. // int32 t = state_times[s];
  1992. // for (fst::MutableArcIterator<Lattice> aiter(lat, s);
  1993. // !aiter.Done(); aiter.Next()) {
  1994. // Arc arc(aiter.Value());
  1995. //
  1996. // int32 tid = arc.ilabel;
  1997. // if (tid != 0) {
  1998. // unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
  1999. // PairHasher<int32> >::const_iterator it =
  2000. // acoustic_scores.find(std::make_pair(t, tid));
  2001. // if (it == acoustic_scores.end()) {
  2002. // KALDI_ERR << "Could not find tid " << tid << " at time " << t
  2003. // << " in the acoustic scores map.";
  2004. // } else {
  2005. // arc.weight.SetValue2(it->second.first / it->second.second);
  2006. // }
  2007. // } else {
  2008. // // For epsilon arcs, set acoustic cost to 0.0
  2009. // arc.weight.SetValue2(0.0);
  2010. // }
  2011. // aiter.SetValue(arc);
  2012. // }
  2013. //
  2014. // LatticeWeight f = lat->Final(s);
  2015. // if (f != LatticeWeight::Zero()) {
  2016. // // Set final acoustic cost to 0.0
  2017. // f.SetValue2(0.0);
  2018. // lat->SetFinal(s, f);
  2019. // }
  2020. // }
  2021. // }
  2022. } // namespace kaldi