You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

278 lines
10 KiB

  1. // decoder/lattice-faster-online-decoder.cc
  2. // Copyright 2009-2012 Microsoft Corporation Mirko Hannemann
  3. // 2013-2014 Johns Hopkins University (Author: Daniel Povey)
  4. // 2014 Guoguo Chen
  5. // 2014 IMSL, PKU-HKUST (author: Wei Shi)
  6. // 2018 Zhehuai Chen
  7. // See ../../COPYING for clarification regarding multiple authors
  8. //
  9. // Licensed under the Apache License, Version 2.0 (the "License");
  10. // you may not use this file except in compliance with the License.
  11. // You may obtain a copy of the License at
  12. //
  13. // http://www.apache.org/licenses/LICENSE-2.0
  14. //
  15. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  16. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  17. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  18. // MERCHANTABLITY OR NON-INFRINGEMENT.
  19. // See the Apache 2 License for the specific language governing permissions and
  20. // limitations under the License.
  21. // see note at the top of lattice-faster-decoder.cc, about how to maintain this
  22. // file in sync with lattice-faster-decoder.cc
  23. #include <limits>
  24. #include <queue>
  25. #include <unordered_map>
  26. #include <utility>
  27. #include "decoder/lattice-faster-online-decoder.h"
  28. namespace kaldi {
  29. template <typename FST>
  30. bool LatticeFasterOnlineDecoderTpl<FST>::TestGetBestPath(
  31. bool use_final_probs) const {
  32. Lattice lat1;
  33. {
  34. Lattice raw_lat;
  35. this->GetRawLattice(&raw_lat, use_final_probs);
  36. ShortestPath(raw_lat, &lat1);
  37. }
  38. Lattice lat2;
  39. GetBestPath(&lat2, use_final_probs);
  40. BaseFloat delta = 0.1;
  41. int32 num_paths = 1;
  42. if (!fst::RandEquivalent(lat1, lat2, num_paths, delta, rand())) {
  43. KALDI_WARN << "Best-path test failed";
  44. return false;
  45. } else {
  46. return true;
  47. }
  48. }
  49. // Outputs an FST corresponding to the single best path through the lattice.
  50. template <typename FST>
  51. bool LatticeFasterOnlineDecoderTpl<FST>::GetBestPath(
  52. Lattice* olat, bool use_final_probs) const {
  53. olat->DeleteStates();
  54. BaseFloat final_graph_cost;
  55. BestPathIterator iter = BestPathEnd(use_final_probs, &final_graph_cost);
  56. if (iter.Done()) return false; // would have printed warning.
  57. StateId state = olat->AddState();
  58. olat->SetFinal(state, LatticeWeight(final_graph_cost, 0.0));
  59. while (!iter.Done()) {
  60. LatticeArc arc;
  61. iter = TraceBackBestPath(iter, &arc);
  62. arc.nextstate = state;
  63. StateId new_state = olat->AddState();
  64. olat->AddArc(new_state, arc);
  65. state = new_state;
  66. }
  67. olat->SetStart(state);
  68. return true;
  69. }
  70. template <typename FST>
  71. typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator
  72. LatticeFasterOnlineDecoderTpl<FST>::BestPathEnd(
  73. bool use_final_probs, BaseFloat* final_cost_out) const {
  74. if (this->decoding_finalized_ && !use_final_probs)
  75. KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
  76. << "BestPathEnd() with use_final_probs == false";
  77. KALDI_ASSERT(this->NumFramesDecoded() > 0 &&
  78. "You cannot call BestPathEnd if no frames were decoded.");
  79. unordered_map<Token*, BaseFloat> final_costs_local;
  80. const unordered_map<Token*, BaseFloat>& final_costs =
  81. (this->decoding_finalized_ ? this->final_costs_ : final_costs_local);
  82. if (!this->decoding_finalized_ && use_final_probs)
  83. this->ComputeFinalCosts(&final_costs_local, NULL, NULL);
  84. // Singly linked list of tokens on last frame (access list through "next"
  85. // pointer).
  86. BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
  87. BaseFloat best_final_cost = 0;
  88. Token* best_tok = NULL;
  89. for (Token* tok = this->active_toks_.back().toks; tok != NULL;
  90. tok = tok->next) {
  91. BaseFloat cost = tok->tot_cost, final_cost = 0.0;
  92. if (use_final_probs && !final_costs.empty()) {
  93. // if we are instructed to use final-probs, and any final tokens were
  94. // active on final frame, include the final-prob in the cost of the token.
  95. typename unordered_map<Token*, BaseFloat>::const_iterator iter =
  96. final_costs.find(tok);
  97. if (iter != final_costs.end()) {
  98. final_cost = iter->second;
  99. cost += final_cost;
  100. } else {
  101. cost = std::numeric_limits<BaseFloat>::infinity();
  102. }
  103. }
  104. if (cost < best_cost) {
  105. best_cost = cost;
  106. best_tok = tok;
  107. best_final_cost = final_cost;
  108. }
  109. }
  110. if (best_tok ==
  111. NULL) { // this should not happen, and is likely a code error or
  112. // caused by infinities in likelihoods, but I'm not making
  113. // it a fatal error for now.
  114. KALDI_WARN << "No final token found.";
  115. }
  116. if (final_cost_out) *final_cost_out = best_final_cost;
  117. return BestPathIterator(best_tok, this->NumFramesDecoded() - 1);
  118. }
  119. template <typename FST>
  120. typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator
  121. LatticeFasterOnlineDecoderTpl<FST>::TraceBackBestPath(BestPathIterator iter,
  122. LatticeArc* oarc) const {
  123. KALDI_ASSERT(!iter.Done() && oarc != NULL);
  124. Token* tok = static_cast<Token*>(iter.tok);
  125. int32 cur_t = iter.frame, step_t = 0;
  126. if (tok->backpointer != NULL) {
  127. // retrieve the correct forward link(with the best link cost)
  128. BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
  129. ForwardLinkT* link;
  130. for (link = tok->backpointer->links; link != NULL; link = link->next) {
  131. if (link->next_tok == tok) { // this is a link to "tok"
  132. BaseFloat graph_cost = link->graph_cost,
  133. acoustic_cost = link->acoustic_cost;
  134. BaseFloat cost = graph_cost + acoustic_cost;
  135. if (cost < best_cost) {
  136. oarc->ilabel = link->ilabel;
  137. oarc->olabel = link->olabel;
  138. if (link->ilabel != 0) {
  139. KALDI_ASSERT(static_cast<size_t>(cur_t) <
  140. this->cost_offsets_.size());
  141. acoustic_cost -= this->cost_offsets_[cur_t];
  142. step_t = -1;
  143. } else {
  144. step_t = 0;
  145. }
  146. oarc->weight = LatticeWeight(graph_cost, acoustic_cost);
  147. best_cost = cost;
  148. }
  149. }
  150. }
  151. if (link == NULL &&
  152. best_cost ==
  153. std::numeric_limits<BaseFloat>::infinity()) { // Did not find
  154. // correct link.
  155. KALDI_ERR << "Error tracing best-path back (likely "
  156. << "bug in token-pruning algorithm)";
  157. }
  158. } else {
  159. oarc->ilabel = 0;
  160. oarc->olabel = 0;
  161. oarc->weight = LatticeWeight::One(); // zero costs.
  162. }
  163. return BestPathIterator(tok->backpointer, cur_t + step_t);
  164. }
  165. template <typename FST>
  166. bool LatticeFasterOnlineDecoderTpl<FST>::GetRawLatticePruned(
  167. Lattice* ofst, bool use_final_probs, BaseFloat beam) const {
  168. typedef LatticeArc Arc;
  169. typedef Arc::StateId StateId;
  170. typedef Arc::Weight Weight;
  171. typedef Arc::Label Label;
  172. // Note: you can't use the old interface (Decode()) if you want to
  173. // get the lattice with use_final_probs = false. You'd have to do
  174. // InitDecoding() and then AdvanceDecoding().
  175. if (this->decoding_finalized_ && !use_final_probs)
  176. KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
  177. << "GetRawLattice() with use_final_probs == false";
  178. unordered_map<Token*, BaseFloat> final_costs_local;
  179. const unordered_map<Token*, BaseFloat>& final_costs =
  180. (this->decoding_finalized_ ? this->final_costs_ : final_costs_local);
  181. if (!this->decoding_finalized_ && use_final_probs)
  182. this->ComputeFinalCosts(&final_costs_local, NULL, NULL);
  183. ofst->DeleteStates();
  184. // num-frames plus one (since frames are one-based, and we have
  185. // an extra frame for the start-state).
  186. int32 num_frames = this->active_toks_.size() - 1;
  187. KALDI_ASSERT(num_frames > 0);
  188. for (int32 f = 0; f <= num_frames; f++) {
  189. if (this->active_toks_[f].toks == NULL) {
  190. KALDI_WARN << "No tokens active on frame " << f
  191. << ": not producing lattice.\n";
  192. return false;
  193. }
  194. }
  195. unordered_map<Token*, StateId> tok_map;
  196. std::queue<std::pair<Token*, int32> > tok_queue;
  197. // First initialize the queue and states. Put the initial state on the queue;
  198. // this is the last token in the list active_toks_[0].toks.
  199. for (Token* tok = this->active_toks_[0].toks; tok != NULL; tok = tok->next) {
  200. if (tok->next == NULL) {
  201. tok_map[tok] = ofst->AddState();
  202. ofst->SetStart(tok_map[tok]);
  203. std::pair<Token*, int32> tok_pair(tok, 0); // #frame = 0
  204. tok_queue.push(tok_pair);
  205. }
  206. }
  207. // Next create states for "good" tokens
  208. while (!tok_queue.empty()) {
  209. std::pair<Token*, int32> cur_tok_pair = tok_queue.front();
  210. tok_queue.pop();
  211. Token* cur_tok = cur_tok_pair.first;
  212. int32 cur_frame = cur_tok_pair.second;
  213. KALDI_ASSERT(cur_frame >= 0 && cur_frame <= this->cost_offsets_.size());
  214. typename unordered_map<Token*, StateId>::const_iterator iter =
  215. tok_map.find(cur_tok);
  216. KALDI_ASSERT(iter != tok_map.end());
  217. StateId cur_state = iter->second;
  218. for (ForwardLinkT* l = cur_tok->links; l != NULL; l = l->next) {
  219. Token* next_tok = l->next_tok;
  220. if (next_tok->extra_cost < beam) {
  221. // so both the current and the next token are good; create the arc
  222. int32 next_frame = l->ilabel == 0 ? cur_frame : cur_frame + 1;
  223. StateId nextstate;
  224. if (tok_map.find(next_tok) == tok_map.end()) {
  225. nextstate = tok_map[next_tok] = ofst->AddState();
  226. tok_queue.push(std::pair<Token*, int32>(next_tok, next_frame));
  227. } else {
  228. nextstate = tok_map[next_tok];
  229. }
  230. BaseFloat cost_offset =
  231. (l->ilabel != 0 ? this->cost_offsets_[cur_frame] : 0);
  232. Arc arc(l->ilabel, l->olabel,
  233. Weight(l->graph_cost, l->acoustic_cost - cost_offset),
  234. nextstate);
  235. ofst->AddArc(cur_state, arc);
  236. }
  237. }
  238. if (cur_frame == num_frames) {
  239. if (use_final_probs && !final_costs.empty()) {
  240. typename unordered_map<Token*, BaseFloat>::const_iterator iter =
  241. final_costs.find(cur_tok);
  242. if (iter != final_costs.end())
  243. ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
  244. } else {
  245. ofst->SetFinal(cur_state, LatticeWeight::One());
  246. }
  247. }
  248. }
  249. return (ofst->NumStates() != 0);
  250. }
  251. // Instantiate the template for the FST types that we'll need.
  252. template class LatticeFasterOnlineDecoderTpl<fst::Fst<fst::StdArc> >;
  253. template class LatticeFasterOnlineDecoderTpl<fst::VectorFst<fst::StdArc> >;
  254. template class LatticeFasterOnlineDecoderTpl<fst::ConstFst<fst::StdArc> >;
  255. } // end namespace kaldi.