You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

267 lines
10 KiB

  1. // fstext/lattice-utils-inl.h
  2. // Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author:
  3. // Daniel Povey)
  4. // See ../../COPYING for clarification regarding multiple authors
  5. //
  6. // Licensed under the Apache License, Version 2.0 (the "License");
  7. // you may not use this file except in compliance with the License.
  8. // You may obtain a copy of the License at
  9. //
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. //
  12. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  13. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  14. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  15. // MERCHANTABLITY OR NON-INFRINGEMENT.
  16. // See the Apache 2 License for the specific language governing permissions and
  17. // limitations under the License.
  18. #ifndef KALDI_FSTEXT_LATTICE_UTILS_INL_H_
  19. #define KALDI_FSTEXT_LATTICE_UTILS_INL_H_
  20. // Do not include this file directly. It is included by lattice-utils.h
  21. #include <utility>
  22. #include <vector>
  23. namespace fst {
  24. /* Convert from FST with arc-type Weight, to one with arc-type
  25. CompactLatticeWeight. Uses FactorFst to identify chains
  26. of states which can be turned into a single output arc. */
  27. template <class Weight, class Int>
  28. void ConvertLattice(
  29. const ExpandedFst<ArcTpl<Weight> >& ifst,
  30. MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > >* ofst,
  31. bool invert) {
  32. typedef ArcTpl<Weight> Arc;
  33. typedef typename Arc::StateId StateId;
  34. typedef CompactLatticeWeightTpl<Weight, Int> CompactWeight;
  35. typedef ArcTpl<CompactWeight> CompactArc;
  36. VectorFst<ArcTpl<Weight> > ffst;
  37. std::vector<std::vector<Int> > labels;
  38. if (invert) { // normal case: want the ilabels as sequences on the arcs of
  39. Factor(ifst, &ffst, &labels); // the output... Factor makes seqs of
  40. // ilabels.
  41. } else {
  42. VectorFst<ArcTpl<Weight> > invfst(ifst);
  43. Invert(&invfst);
  44. Factor(invfst, &ffst, &labels);
  45. }
  46. TopSort(&ffst); // Put the states in ffst in topological order, which is
  47. // easier on the eye when reading the text-form lattices and corresponds to
  48. // what we get when we generate the lattices in the decoder.
  49. ofst->DeleteStates();
  50. // The states will be numbered exactly the same as the original FST.
  51. // Add the states to the new FST.
  52. StateId num_states = ffst.NumStates();
  53. for (StateId s = 0; s < num_states; s++) {
  54. StateId news = ofst->AddState();
  55. assert(news == s);
  56. }
  57. ofst->SetStart(ffst.Start());
  58. for (StateId s = 0; s < num_states; s++) {
  59. Weight final_weight = ffst.Final(s);
  60. if (final_weight != Weight::Zero()) {
  61. CompactWeight final_compact_weight(final_weight, std::vector<Int>());
  62. ofst->SetFinal(s, final_compact_weight);
  63. }
  64. for (ArcIterator<ExpandedFst<Arc> > iter(ffst, s); !iter.Done();
  65. iter.Next()) {
  66. const Arc& arc = iter.Value();
  67. KALDI_PARANOID_ASSERT(arc.weight != Weight::Zero());
  68. // note: zero-weight arcs not allowed anyway so weight should not be zero,
  69. // but no harm in checking.
  70. CompactArc compact_arc(arc.olabel, arc.olabel,
  71. CompactWeight(arc.weight, labels[arc.ilabel]),
  72. arc.nextstate);
  73. ofst->AddArc(s, compact_arc);
  74. }
  75. }
  76. }
  77. template <class Weight, class Int>
  78. void ConvertLattice(
  79. const ExpandedFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > >& ifst,
  80. MutableFst<ArcTpl<Weight> >* ofst, bool invert) {
  81. typedef ArcTpl<Weight> Arc;
  82. typedef typename Arc::StateId StateId;
  83. typedef typename Arc::Label Label;
  84. typedef CompactLatticeWeightTpl<Weight, Int> CompactWeight;
  85. typedef ArcTpl<CompactWeight> CompactArc;
  86. ofst->DeleteStates();
  87. // make the states in the new FST have the same numbers as
  88. // the original ones, and add chains of states as necessary
  89. // to encode the string-valued weights.
  90. StateId num_states = ifst.NumStates();
  91. for (StateId s = 0; s < num_states; s++) {
  92. StateId news = ofst->AddState();
  93. assert(news == s);
  94. }
  95. ofst->SetStart(ifst.Start());
  96. for (StateId s = 0; s < num_states; s++) {
  97. CompactWeight final_weight = ifst.Final(s);
  98. if (final_weight != CompactWeight::Zero()) {
  99. StateId cur_state = s;
  100. size_t string_length = final_weight.String().size();
  101. for (size_t n = 0; n < string_length; n++) {
  102. StateId next_state = ofst->AddState();
  103. Label ilabel = 0;
  104. Arc arc(ilabel, final_weight.String()[n],
  105. (n == 0 ? final_weight.Weight() : Weight::One()), next_state);
  106. if (invert) std::swap(arc.ilabel, arc.olabel);
  107. ofst->AddArc(cur_state, arc);
  108. cur_state = next_state;
  109. }
  110. ofst->SetFinal(cur_state,
  111. string_length > 0 ? Weight::One() : final_weight.Weight());
  112. }
  113. for (ArcIterator<ExpandedFst<CompactArc> > iter(ifst, s); !iter.Done();
  114. iter.Next()) {
  115. const CompactArc& arc = iter.Value();
  116. size_t string_length = arc.weight.String().size();
  117. StateId cur_state = s;
  118. // for all but the last element in the string--
  119. // add a temporary state.
  120. for (size_t n = 0; n + 1 < string_length; n++) {
  121. StateId next_state = ofst->AddState();
  122. Label ilabel = (n == 0 ? arc.ilabel : 0),
  123. olabel = static_cast<Label>(arc.weight.String()[n]);
  124. Weight weight = (n == 0 ? arc.weight.Weight() : Weight::One());
  125. Arc new_arc(ilabel, olabel, weight, next_state);
  126. if (invert) std::swap(new_arc.ilabel, new_arc.olabel);
  127. ofst->AddArc(cur_state, new_arc);
  128. cur_state = next_state;
  129. }
  130. Label ilabel = (string_length <= 1 ? arc.ilabel : 0),
  131. olabel = (string_length > 0 ? arc.weight.String()[string_length - 1]
  132. : 0);
  133. Weight weight =
  134. (string_length <= 1 ? arc.weight.Weight() : Weight::One());
  135. Arc new_arc(ilabel, olabel, weight, arc.nextstate);
  136. if (invert) std::swap(new_arc.ilabel, new_arc.olabel);
  137. ofst->AddArc(cur_state, new_arc);
  138. }
  139. }
  140. }
  141. // This function converts lattices between float and double;
  142. // it works for both CompactLatticeWeight and LatticeWeight.
  143. template <class WeightIn, class WeightOut>
  144. void ConvertLattice(const ExpandedFst<ArcTpl<WeightIn> >& ifst,
  145. MutableFst<ArcTpl<WeightOut> >* ofst) {
  146. typedef ArcTpl<WeightIn> ArcIn;
  147. typedef ArcTpl<WeightOut> ArcOut;
  148. typedef typename ArcIn::StateId StateId;
  149. ofst->DeleteStates();
  150. // The states will be numbered exactly the same as the original FST.
  151. // Add the states to the new FST.
  152. StateId num_states = ifst.NumStates();
  153. for (StateId s = 0; s < num_states; s++) {
  154. StateId news = ofst->AddState();
  155. assert(news == s);
  156. }
  157. ofst->SetStart(ifst.Start());
  158. for (StateId s = 0; s < num_states; s++) {
  159. WeightIn final_iweight = ifst.Final(s);
  160. if (final_iweight != WeightIn::Zero()) {
  161. WeightOut final_oweight;
  162. ConvertLatticeWeight(final_iweight, &final_oweight);
  163. ofst->SetFinal(s, final_oweight);
  164. }
  165. for (ArcIterator<ExpandedFst<ArcIn> > iter(ifst, s); !iter.Done();
  166. iter.Next()) {
  167. ArcIn arc = iter.Value();
  168. KALDI_PARANOID_ASSERT(arc.weight != WeightIn::Zero());
  169. ArcOut oarc;
  170. ConvertLatticeWeight(arc.weight, &oarc.weight);
  171. oarc.ilabel = arc.ilabel;
  172. oarc.olabel = arc.olabel;
  173. oarc.nextstate = arc.nextstate;
  174. ofst->AddArc(s, oarc);
  175. }
  176. }
  177. }
  178. template <class Weight, class ScaleFloat>
  179. void ScaleLattice(const std::vector<std::vector<ScaleFloat> >& scale,
  180. MutableFst<ArcTpl<Weight> >* fst) {
  181. assert(scale.size() == 2 && scale[0].size() == 2 && scale[1].size() == 2);
  182. if (scale == DefaultLatticeScale()) // nothing to do.
  183. return;
  184. typedef ArcTpl<Weight> Arc;
  185. typedef MutableFst<Arc> Fst;
  186. typedef typename Arc::StateId StateId;
  187. StateId num_states = fst->NumStates();
  188. for (StateId s = 0; s < num_states; s++) {
  189. for (MutableArcIterator<Fst> aiter(fst, s); !aiter.Done(); aiter.Next()) {
  190. Arc arc = aiter.Value();
  191. arc.weight = Weight(ScaleTupleWeight(arc.weight, scale));
  192. aiter.SetValue(arc);
  193. }
  194. Weight final_weight = fst->Final(s);
  195. if (final_weight != Weight::Zero())
  196. fst->SetFinal(s, Weight(ScaleTupleWeight(final_weight, scale)));
  197. }
  198. }
  199. template <class Weight, class Int>
  200. void RemoveAlignmentsFromCompactLattice(
  201. MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > >* fst) {
  202. typedef CompactLatticeWeightTpl<Weight, Int> W;
  203. typedef ArcTpl<W> Arc;
  204. typedef MutableFst<Arc> Fst;
  205. typedef typename Arc::StateId StateId;
  206. StateId num_states = fst->NumStates();
  207. for (StateId s = 0; s < num_states; s++) {
  208. for (MutableArcIterator<Fst> aiter(fst, s); !aiter.Done(); aiter.Next()) {
  209. Arc arc = aiter.Value();
  210. arc.weight = W(arc.weight.Weight(), std::vector<Int>());
  211. aiter.SetValue(arc);
  212. }
  213. W final_weight = fst->Final(s);
  214. if (final_weight != W::Zero())
  215. fst->SetFinal(s, W(final_weight.Weight(), std::vector<Int>()));
  216. }
  217. }
  218. template <class Weight, class Int>
  219. bool CompactLatticeHasAlignment(
  220. const ExpandedFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > >& fst) {
  221. typedef CompactLatticeWeightTpl<Weight, Int> W;
  222. typedef ArcTpl<W> Arc;
  223. typedef ExpandedFst<Arc> Fst;
  224. typedef typename Arc::StateId StateId;
  225. StateId num_states = fst.NumStates();
  226. for (StateId s = 0; s < num_states; s++) {
  227. for (ArcIterator<Fst> aiter(fst, s); !aiter.Done(); aiter.Next()) {
  228. const Arc& arc = aiter.Value();
  229. if (!arc.weight.String().empty()) return true;
  230. }
  231. W final_weight = fst.Final(s);
  232. if (!final_weight.String().empty()) return true;
  233. }
  234. return false;
  235. }
  236. template <class Real>
  237. void ConvertFstToLattice(const ExpandedFst<ArcTpl<TropicalWeight> >& ifst,
  238. MutableFst<ArcTpl<LatticeWeightTpl<Real> > >* ofst) {
  239. int32 num_states_cache = 50000;
  240. fst::CacheOptions cache_opts(true, num_states_cache);
  241. fst::MapFstOptions mapfst_opts(cache_opts);
  242. StdToLatticeMapper<Real> mapper;
  243. MapFst<StdArc, ArcTpl<LatticeWeightTpl<Real> >, StdToLatticeMapper<Real> >
  244. map_fst(ifst, mapper, mapfst_opts);
  245. *ofst = map_fst;
  246. }
  247. } // namespace fst
  248. #endif // KALDI_FSTEXT_LATTICE_UTILS_INL_H_