You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1265 lines
47 KiB

  1. // fstext/fstext-utils-inl.h
  2. // Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author:
  3. // Daniel Povey)
  4. // 2014 Telepoint Global Hosting Service, LLC. (Author: David
  5. // Snyder)
  6. // See ../../COPYING for clarification regarding multiple authors
  7. //
  8. // Licensed under the Apache License, Version 2.0 (the "License");
  9. // you may not use this file except in compliance with the License.
  10. // You may obtain a copy of the License at
  11. //
  12. // http://www.apache.org/licenses/LICENSE-2.0
  13. //
  14. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  15. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  16. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  17. // MERCHANTABLITY OR NON-INFRINGEMENT.
  18. // See the Apache 2 License for the specific language governing permissions and
  19. // limitations under the License.
  20. #ifndef KALDI_FSTEXT_FSTEXT_UTILS_INL_H_
  21. #define KALDI_FSTEXT_FSTEXT_UTILS_INL_H_
  22. #include <algorithm>
  23. #include <cstring>
  24. #include <map>
  25. #include <set>
  26. #include <sstream>
  27. #include <string>
  28. #include <unordered_map>
  29. #include <unordered_set>
  30. #include <utility>
  31. #include <vector>
  32. #include "base/kaldi-common.h"
  33. #include "fstext/determinize-star.h"
  34. #include "fstext/pre-determinize.h"
  35. #include "util/const-integer-set.h"
  36. #include "util/kaldi-io.h"
  37. #include "util/stl-utils.h"
  38. #include "util/text-utils.h"
  39. namespace fst {
  40. template <class Arc>
  41. typename Arc::Label HighestNumberedOutputSymbol(const Fst<Arc>& fst) {
  42. typename Arc::Label ans = 0;
  43. for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
  44. typename Arc::StateId s = siter.Value();
  45. for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
  46. const Arc& arc = aiter.Value();
  47. ans = std::max(ans, arc.olabel);
  48. }
  49. }
  50. return ans;
  51. }
  52. template <class Arc>
  53. typename Arc::Label HighestNumberedInputSymbol(const Fst<Arc>& fst) {
  54. typename Arc::Label ans = 0;
  55. for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
  56. typename Arc::StateId s = siter.Value();
  57. for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
  58. const Arc& arc = aiter.Value();
  59. ans = std::max(ans, arc.ilabel);
  60. }
  61. }
  62. return ans;
  63. }
  64. template <class Arc>
  65. typename Arc::StateId NumArcs(const ExpandedFst<Arc>& fst) {
  66. typedef typename Arc::StateId StateId;
  67. StateId num_arcs = 0;
  68. for (StateId s = 0; s < fst.NumStates(); s++) num_arcs += fst.NumArcs(s);
  69. return num_arcs;
  70. }
  71. template <class Arc, class I>
  72. void GetOutputSymbols(const Fst<Arc>& fst, bool include_eps,
  73. std::vector<I>* symbols) {
  74. KALDI_ASSERT_IS_INTEGER_TYPE(I);
  75. std::set<I> all_syms;
  76. for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
  77. typename Arc::StateId s = siter.Value();
  78. for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
  79. const Arc& arc = aiter.Value();
  80. all_syms.insert(arc.olabel);
  81. }
  82. }
  83. // Remove epsilon, if instructed.
  84. if (!include_eps && !all_syms.empty() && *all_syms.begin() == 0)
  85. all_syms.erase(0);
  86. KALDI_ASSERT(symbols != NULL);
  87. kaldi::CopySetToVector(all_syms, symbols);
  88. }
  89. template <class Arc, class I>
  90. void GetInputSymbols(const Fst<Arc>& fst, bool include_eps,
  91. std::vector<I>* symbols) {
  92. KALDI_ASSERT_IS_INTEGER_TYPE(I);
  93. unordered_set<I> all_syms;
  94. for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
  95. typename Arc::StateId s = siter.Value();
  96. for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
  97. const Arc& arc = aiter.Value();
  98. all_syms.insert(arc.ilabel);
  99. }
  100. }
  101. // Remove epsilon, if instructed.
  102. if (!include_eps && all_syms.count(0) != 0) all_syms.erase(0);
  103. KALDI_ASSERT(symbols != NULL);
  104. kaldi::CopySetToVector(all_syms, symbols);
  105. std::sort(symbols->begin(), symbols->end());
  106. }
  107. template <class Arc, class I>
  108. class RemoveSomeInputSymbolsMapper {
  109. public:
  110. Arc operator()(const Arc& arc_in) {
  111. Arc ans = arc_in;
  112. if (to_remove_set_.count(ans.ilabel) != 0)
  113. ans.ilabel = 0; // remove this symbol
  114. return ans;
  115. }
  116. MapFinalAction FinalAction() { return MAP_NO_SUPERFINAL; }
  117. MapSymbolsAction InputSymbolsAction() { return MAP_CLEAR_SYMBOLS; }
  118. MapSymbolsAction OutputSymbolsAction() { return MAP_COPY_SYMBOLS; }
  119. uint64 Properties(uint64 props) const {
  120. // remove the following as we don't know now if any of them are true.
  121. uint64 to_remove = kAcceptor | kNotAcceptor | kIDeterministic |
  122. kNonIDeterministic | kNoEpsilons | kNoIEpsilons |
  123. kILabelSorted | kNotILabelSorted;
  124. return props & ~to_remove;
  125. }
  126. explicit RemoveSomeInputSymbolsMapper(const std::vector<I>& to_remove)
  127. : to_remove_set_(to_remove) {
  128. KALDI_ASSERT_IS_INTEGER_TYPE(I);
  129. assert(to_remove_set_.count(0) == 0); // makes no sense to remove epsilon.
  130. }
  131. private:
  132. kaldi::ConstIntegerSet<I> to_remove_set_;
  133. };
  134. template <class Arc, class I>
  135. using LookaheadFst = ArcMapFst<Arc, Arc, RemoveSomeInputSymbolsMapper<Arc, I> >;
  136. // Lookahead composition is used for optimized online
  137. // composition of FSTs during decoding. See
  138. // nnet3/nnet3-latgen-faster-lookahead.cc. For details of compose filters
  139. // see DefaultLookAhead in fst/compose.h
  140. template <class Arc, class I>
  141. LookaheadFst<Arc, I>* LookaheadComposeFst(const Fst<Arc>& ifst1,
  142. const Fst<Arc>& ifst2,
  143. const std::vector<I>& to_remove) {
  144. fst::CacheOptions cache_opts(true, 1 << 25LL);
  145. fst::CacheOptions cache_opts_map(true, 0);
  146. fst::ArcMapFstOptions arcmap_opts(cache_opts);
  147. RemoveSomeInputSymbolsMapper<Arc, I> mapper(to_remove);
  148. return new LookaheadFst<Arc, I>(ComposeFst<Arc>(ifst1, ifst2, cache_opts),
  149. mapper, arcmap_opts);
  150. }
  151. template <class Arc, class I>
  152. void RemoveSomeInputSymbols(const std::vector<I>& to_remove,
  153. MutableFst<Arc>* fst) {
  154. KALDI_ASSERT_IS_INTEGER_TYPE(I);
  155. RemoveSomeInputSymbolsMapper<Arc, I> mapper(to_remove);
  156. Map(fst, mapper);
  157. }
  158. template <class Arc, class I>
  159. class MapInputSymbolsMapper {
  160. public:
  161. Arc operator()(const Arc& arc_in) {
  162. Arc ans = arc_in;
  163. if (ans.ilabel > 0 && ans.ilabel < static_cast<typename Arc::Label>(
  164. (*symbol_mapping_).size()))
  165. ans.ilabel = (*symbol_mapping_)[ans.ilabel];
  166. return ans;
  167. }
  168. MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
  169. MapSymbolsAction InputSymbolsAction() const { return MAP_CLEAR_SYMBOLS; }
  170. MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
  171. uint64 Properties(uint64 props) const { // Not tested.
  172. bool remove_epsilons =
  173. (symbol_mapping_->size() > 0 && (*symbol_mapping_)[0] != 0);
  174. bool add_epsilons = (symbol_mapping_->size() > 1 &&
  175. *std::min_element(symbol_mapping_->begin() + 1,
  176. symbol_mapping_->end()) == 0);
  177. // remove the following as we don't know now if any of them are true.
  178. uint64 props_to_remove = kAcceptor | kNotAcceptor | kIDeterministic |
  179. kNonIDeterministic | kILabelSorted |
  180. kNotILabelSorted;
  181. if (remove_epsilons) props_to_remove |= kEpsilons | kIEpsilons;
  182. if (add_epsilons) props_to_remove |= kNoEpsilons | kNoIEpsilons;
  183. uint64 props_to_add = 0;
  184. if (remove_epsilons && !add_epsilons)
  185. props_to_add |= kNoEpsilons | kNoIEpsilons;
  186. return (props & ~props_to_remove) | props_to_add;
  187. }
  188. // initialize with copy = false only if the "to_remove" argument will not be
  189. // deleted in the lifetime of this object.
  190. MapInputSymbolsMapper(const std::vector<I>& to_remove, bool copy) {
  191. KALDI_ASSERT_IS_INTEGER_TYPE(I);
  192. if (copy)
  193. symbol_mapping_ = new std::vector<I>(to_remove);
  194. else
  195. symbol_mapping_ = &to_remove;
  196. owned = copy;
  197. }
  198. ~MapInputSymbolsMapper() {
  199. if (owned && symbol_mapping_ != NULL) delete symbol_mapping_;
  200. }
  201. private:
  202. bool owned;
  203. const std::vector<I>* symbol_mapping_;
  204. };
  205. template <class Arc, class I>
  206. void MapInputSymbols(const std::vector<I>& symbol_mapping,
  207. MutableFst<Arc>* fst) {
  208. KALDI_ASSERT_IS_INTEGER_TYPE(I);
  209. // false == don't copy the "symbol_mapping", retain pointer--
  210. // safe since short-lived object.
  211. MapInputSymbolsMapper<Arc, I> mapper(symbol_mapping, false);
  212. Map(fst, mapper);
  213. }
  214. template <class Arc, class I>
  215. bool GetLinearSymbolSequence(const Fst<Arc>& fst, std::vector<I>* isymbols_out,
  216. std::vector<I>* osymbols_out,
  217. typename Arc::Weight* tot_weight_out) {
  218. typedef typename Arc::StateId StateId;
  219. typedef typename Arc::Weight Weight;
  220. Weight tot_weight = Weight::One();
  221. std::vector<I> ilabel_seq;
  222. std::vector<I> olabel_seq;
  223. StateId cur_state = fst.Start();
  224. if (cur_state == kNoStateId) { // empty sequence.
  225. if (isymbols_out != NULL) isymbols_out->clear();
  226. if (osymbols_out != NULL) osymbols_out->clear();
  227. if (tot_weight_out != NULL) *tot_weight_out = Weight::Zero();
  228. return true;
  229. }
  230. while (1) {
  231. Weight w = fst.Final(cur_state);
  232. if (w != Weight::Zero()) { // is final..
  233. tot_weight = Times(w, tot_weight);
  234. if (fst.NumArcs(cur_state) != 0) return false;
  235. if (isymbols_out != NULL) *isymbols_out = ilabel_seq;
  236. if (osymbols_out != NULL) *osymbols_out = olabel_seq;
  237. if (tot_weight_out != NULL) *tot_weight_out = tot_weight;
  238. return true;
  239. } else {
  240. if (fst.NumArcs(cur_state) != 1) return false;
  241. ArcIterator<Fst<Arc> > iter(fst, cur_state); // get the only arc.
  242. const Arc& arc = iter.Value();
  243. tot_weight = Times(arc.weight, tot_weight);
  244. if (arc.ilabel != 0) ilabel_seq.push_back(arc.ilabel);
  245. if (arc.olabel != 0) olabel_seq.push_back(arc.olabel);
  246. cur_state = arc.nextstate;
  247. }
  248. }
  249. }
  250. // see fstext-utils.h for comment.
  251. template <class Arc>
  252. void ConvertNbestToVector(const Fst<Arc>& fst,
  253. std::vector<VectorFst<Arc> >* fsts_out) {
  254. typedef typename Arc::Weight Weight;
  255. typedef typename Arc::StateId StateId;
  256. fsts_out->clear();
  257. StateId start_state = fst.Start();
  258. if (start_state == kNoStateId) return; // No output.
  259. size_t n_arcs = fst.NumArcs(start_state);
  260. bool start_is_final = (fst.Final(start_state) != Weight::Zero());
  261. fsts_out->reserve(n_arcs + (start_is_final ? 1 : 0));
  262. if (start_is_final) {
  263. fsts_out->resize(fsts_out->size() + 1);
  264. StateId start_state_out = fsts_out->back().AddState();
  265. fsts_out->back().SetFinal(start_state_out, fst.Final(start_state));
  266. }
  267. for (ArcIterator<Fst<Arc> > start_aiter(fst, start_state);
  268. !start_aiter.Done(); start_aiter.Next()) {
  269. fsts_out->resize(fsts_out->size() + 1);
  270. VectorFst<Arc>& ofst = fsts_out->back();
  271. const Arc& first_arc = start_aiter.Value();
  272. StateId cur_state = start_state, cur_ostate = ofst.AddState();
  273. ofst.SetStart(cur_ostate);
  274. StateId next_ostate = ofst.AddState();
  275. ofst.AddArc(cur_ostate, Arc(first_arc.ilabel, first_arc.olabel,
  276. first_arc.weight, next_ostate));
  277. cur_state = first_arc.nextstate;
  278. cur_ostate = next_ostate;
  279. while (1) {
  280. size_t this_n_arcs = fst.NumArcs(cur_state);
  281. KALDI_ASSERT(this_n_arcs <= 1); // or it violates our assumptions
  282. // about the input.
  283. if (this_n_arcs == 1) {
  284. KALDI_ASSERT(fst.Final(cur_state) == Weight::Zero());
  285. // or problem with ShortestPath.
  286. ArcIterator<Fst<Arc> > aiter(fst, cur_state);
  287. const Arc& arc = aiter.Value();
  288. next_ostate = ofst.AddState();
  289. ofst.AddArc(cur_ostate,
  290. Arc(arc.ilabel, arc.olabel, arc.weight, next_ostate));
  291. cur_state = arc.nextstate;
  292. cur_ostate = next_ostate;
  293. } else {
  294. KALDI_ASSERT(fst.Final(cur_state) != Weight::Zero());
  295. // or problem with ShortestPath.
  296. ofst.SetFinal(cur_ostate, fst.Final(cur_state));
  297. break;
  298. }
  299. }
  300. }
  301. }
  302. // see fstext-utils.sh for comment.
  303. template <class Arc>
  304. void NbestAsFsts(const Fst<Arc>& fst, size_t n,
  305. std::vector<VectorFst<Arc> >* fsts_out) {
  306. KALDI_ASSERT(n > 0);
  307. KALDI_ASSERT(fsts_out != NULL);
  308. VectorFst<Arc> nbest_fst;
  309. ShortestPath(fst, &nbest_fst, n);
  310. ConvertNbestToVector(nbest_fst, fsts_out);
  311. }
  312. template <class Arc, class I>
  313. void MakeLinearAcceptorWithAlternatives(
  314. const std::vector<std::vector<I> >& labels, MutableFst<Arc>* ofst) {
  315. typedef typename Arc::StateId StateId;
  316. typedef typename Arc::Weight Weight;
  317. ofst->DeleteStates();
  318. StateId cur_state = ofst->AddState();
  319. ofst->SetStart(cur_state);
  320. for (size_t i = 0; i < labels.size(); i++) {
  321. KALDI_ASSERT(labels[i].size() != 0);
  322. StateId next_state = ofst->AddState();
  323. for (size_t j = 0; j < labels[i].size(); j++) {
  324. Arc arc(labels[i][j], labels[i][j], Weight::One(), next_state);
  325. ofst->AddArc(cur_state, arc);
  326. }
  327. cur_state = next_state;
  328. }
  329. ofst->SetFinal(cur_state, Weight::One());
  330. }
  331. template <class Arc, class I>
  332. void MakeLinearAcceptor(const std::vector<I>& labels, MutableFst<Arc>* ofst) {
  333. typedef typename Arc::StateId StateId;
  334. typedef typename Arc::Weight Weight;
  335. ofst->DeleteStates();
  336. StateId cur_state = ofst->AddState();
  337. ofst->SetStart(cur_state);
  338. for (size_t i = 0; i < labels.size(); i++) {
  339. StateId next_state = ofst->AddState();
  340. Arc arc(labels[i], labels[i], Weight::One(), next_state);
  341. ofst->AddArc(cur_state, arc);
  342. cur_state = next_state;
  343. }
  344. ofst->SetFinal(cur_state, Weight::One());
  345. }
  346. template <class I>
  347. void GetSymbols(const SymbolTable& symtab, bool include_eps,
  348. std::vector<I>* syms_out) {
  349. KALDI_ASSERT(syms_out != NULL);
  350. syms_out->clear();
  351. for (SymbolTableIterator iter(symtab); !iter.Done(); iter.Next()) {
  352. if (include_eps || iter.Value() != 0) {
  353. syms_out->push_back(iter.Value());
  354. KALDI_ASSERT(syms_out->back() ==
  355. iter.Value()); // an integer-range thing.
  356. }
  357. }
  358. }
  359. template <class Arc>
  360. void SafeDeterminizeWrapper(MutableFst<Arc>* ifst, MutableFst<Arc>* ofst,
  361. float delta) {
  362. typename Arc::Label highest_sym = HighestNumberedInputSymbol(*ifst);
  363. std::vector<typename Arc::Label> extra_syms;
  364. PreDeterminize(ifst, (typename Arc::Label)(highest_sym + 1), &extra_syms);
  365. DeterminizeStar(*ifst, ofst, delta);
  366. RemoveSomeInputSymbols(extra_syms, ofst); // remove the extra symbols.
  367. }
  368. template <class Arc>
  369. void SafeDeterminizeMinimizeWrapper(MutableFst<Arc>* ifst, VectorFst<Arc>* ofst,
  370. float delta) {
  371. typename Arc::Label highest_sym = HighestNumberedInputSymbol(*ifst);
  372. std::vector<typename Arc::Label> extra_syms;
  373. PreDeterminize(ifst, (typename Arc::Label)(highest_sym + 1), &extra_syms);
  374. DeterminizeStar(*ifst, ofst, delta);
  375. RemoveSomeInputSymbols(extra_syms, ofst); // remove the extra symbols.
  376. RemoveEpsLocal(ofst); // this is "safe" and will never hurt.
  377. MinimizeEncoded(ofst, delta);
  378. }
  379. inline void DeterminizeStarInLog(VectorFst<StdArc>* fst, float delta,
  380. bool* debug_ptr, int max_states) {
  381. // DeterminizeStarInLog determinizes 'fst' in the log semiring, using
  382. // the DeterminizeStar algorithm (which also removes epsilons).
  383. ArcSort(fst, ILabelCompare<StdArc>()); // helps DeterminizeStar to be faster.
  384. VectorFst<LogArc>* fst_log =
  385. new VectorFst<LogArc>; // Want to determinize in log semiring.
  386. Cast(*fst, fst_log);
  387. VectorFst<StdArc> tmp;
  388. *fst = tmp; // make fst empty to free up memory. [actually may make no
  389. // difference..]
  390. VectorFst<LogArc>* fst_det_log = new VectorFst<LogArc>;
  391. DeterminizeStar(*fst_log, fst_det_log, delta, debug_ptr, max_states);
  392. Cast(*fst_det_log, fst);
  393. delete fst_log;
  394. delete fst_det_log;
  395. }
  396. inline void DeterminizeInLog(VectorFst<StdArc>* fst) {
  397. // DeterminizeInLog determinizes 'fst' in the log semiring.
  398. ArcSort(fst, ILabelCompare<StdArc>()); // helps DeterminizeStar to be faster.
  399. VectorFst<LogArc>* fst_log =
  400. new VectorFst<LogArc>; // Want to determinize in log semiring.
  401. Cast(*fst, fst_log);
  402. VectorFst<StdArc> tmp;
  403. *fst = tmp; // make fst empty to free up memory. [actually may make no
  404. // difference..]
  405. VectorFst<LogArc>* fst_det_log = new VectorFst<LogArc>;
  406. Determinize(*fst_log, fst_det_log);
  407. Cast(*fst_det_log, fst);
  408. delete fst_log;
  409. delete fst_det_log;
  410. }
  411. // make it inline to avoid having to put it in a .cc file.
  412. // destructive algorithm (changes ifst as well as ofst).
  413. inline void SafeDeterminizeMinimizeWrapperInLog(VectorFst<StdArc>* ifst,
  414. VectorFst<StdArc>* ofst,
  415. float delta) {
  416. VectorFst<LogArc>* ifst_log =
  417. new VectorFst<LogArc>; // Want to determinize in log semiring.
  418. Cast(*ifst, ifst_log);
  419. VectorFst<LogArc>* ofst_log = new VectorFst<LogArc>;
  420. SafeDeterminizeWrapper(ifst_log, ofst_log, delta);
  421. Cast(*ofst_log, ofst);
  422. delete ifst_log;
  423. delete ofst_log;
  424. RemoveEpsLocal(ofst); // this is "safe" and will never hurt. Do this in
  425. // tropical, which is important.
  426. MinimizeEncoded(ofst, delta); // Non-deterministic minimization will fail in
  427. // log semiring so do it with StdARc.
  428. }
  429. inline void SafeDeterminizeWrapperInLog(VectorFst<StdArc>* ifst,
  430. VectorFst<StdArc>* ofst, float delta) {
  431. VectorFst<LogArc>* ifst_log =
  432. new VectorFst<LogArc>; // Want to determinize in log semiring.
  433. Cast(*ifst, ifst_log);
  434. VectorFst<LogArc>* ofst_log = new VectorFst<LogArc>;
  435. SafeDeterminizeWrapper(ifst_log, ofst_log, delta);
  436. Cast(*ofst_log, ofst);
  437. delete ifst_log;
  438. delete ofst_log;
  439. }
  440. template <class Arc>
  441. void RemoveWeights(MutableFst<Arc>* ifst) {
  442. typedef typename Arc::StateId StateId;
  443. typedef typename Arc::Weight Weight;
  444. for (StateIterator<MutableFst<Arc> > siter(*ifst); !siter.Done();
  445. siter.Next()) {
  446. StateId s = siter.Value();
  447. for (MutableArcIterator<MutableFst<Arc> > aiter(ifst, s); !aiter.Done();
  448. aiter.Next()) {
  449. Arc arc(aiter.Value());
  450. arc.weight = Weight::One();
  451. aiter.SetValue(arc);
  452. }
  453. if (ifst->Final(s) != Weight::Zero()) ifst->SetFinal(s, Weight::One());
  454. }
  455. ifst->SetProperties(kUnweighted, kUnweighted);
  456. }
  // Functor that returns its argument unchanged.  Used in
  // PrecedingInputSymbolsAreSame (non-functor version), and similar routines,
  // which need a label -> class functor exposing Arg and Result types.
  template <class T>
  struct IdentityFunction {
    using Arg = T;
    using Result = T;
    T operator()(const T& value) const { return value; }
  };
  465. template <class Arc>
  466. bool PrecedingInputSymbolsAreSame(bool start_is_epsilon, const Fst<Arc>& fst) {
  467. IdentityFunction<typename Arc::Label> f;
  468. return PrecedingInputSymbolsAreSameClass(start_is_epsilon, fst, f);
  469. }
  470. template <class Arc, class F> // F is functor type from labels to classes.
  471. bool PrecedingInputSymbolsAreSameClass(bool start_is_epsilon,
  472. const Fst<Arc>& fst, const F& f) {
  473. typedef typename F::Result ClassType;
  474. typedef typename Arc::StateId StateId;
  475. std::vector<ClassType> classes;
  476. ClassType noClass = f(kNoLabel);
  477. if (start_is_epsilon) {
  478. StateId start_state = fst.Start();
  479. if (start_state < 0 || start_state == kNoStateId)
  480. return true; // empty fst-- doesn't matter.
  481. classes.resize(start_state + 1, noClass);
  482. classes[start_state] = 0;
  483. }
  484. for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
  485. StateId s = siter.Value();
  486. for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
  487. const Arc& arc = aiter.Value();
  488. if (classes.size() <= arc.nextstate)
  489. classes.resize(arc.nextstate + 1, noClass);
  490. if (classes[arc.nextstate] == noClass)
  491. classes[arc.nextstate] = f(arc.ilabel);
  492. else if (classes[arc.nextstate] != f(arc.ilabel))
  493. return false;
  494. }
  495. }
  496. return true;
  497. }
  498. template <class Arc>
  499. bool FollowingInputSymbolsAreSame(bool end_is_epsilon, const Fst<Arc>& fst) {
  500. IdentityFunction<typename Arc::Label> f;
  501. return FollowingInputSymbolsAreSameClass(end_is_epsilon, fst, f);
  502. }
  503. template <class Arc, class F>
  504. bool FollowingInputSymbolsAreSameClass(bool end_is_epsilon, const Fst<Arc>& fst,
  505. const F& f) {
  506. typedef typename Arc::StateId StateId;
  507. typedef typename Arc::Weight Weight;
  508. typedef typename F::Result ClassType;
  509. const ClassType noClass = f(kNoLabel), epsClass = f(0);
  510. for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
  511. StateId s = siter.Value();
  512. ClassType c = noClass;
  513. for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
  514. const Arc& arc = aiter.Value();
  515. if (c == noClass)
  516. c = f(arc.ilabel);
  517. else if (c != f(arc.ilabel))
  518. return false;
  519. }
  520. if (end_is_epsilon && c != noClass && c != epsClass &&
  521. fst.Final(s) != Weight::Zero())
  522. return false;
  523. }
  524. return true;
  525. }
  526. template <class Arc>
  527. void MakePrecedingInputSymbolsSame(bool start_is_epsilon,
  528. MutableFst<Arc>* fst) {
  529. IdentityFunction<typename Arc::Label> f;
  530. MakePrecedingInputSymbolsSameClass(start_is_epsilon, fst, f);
  531. }
  532. template <class Arc, class F>
  533. void MakePrecedingInputSymbolsSameClass(bool start_is_epsilon,
  534. MutableFst<Arc>* fst, const F& f) {
  535. typedef typename F::Result ClassType;
  536. typedef typename Arc::StateId StateId;
  537. typedef typename Arc::Weight Weight;
  538. std::vector<ClassType> classes;
  539. ClassType noClass = f(kNoLabel);
  540. ClassType epsClass = f(0);
  541. if (start_is_epsilon) { // treat having-start-state as epsilon in-transition.
  542. StateId start_state = fst->Start();
  543. if (start_state < 0 || start_state == kNoStateId) // empty FST.
  544. return;
  545. classes.resize(start_state + 1, noClass);
  546. classes[start_state] = epsClass;
  547. }
  548. // Find bad states (states with multiple input-symbols into them).
  549. std::set<StateId> bad_states; // states that we need to change.
  550. for (StateIterator<Fst<Arc> > siter(*fst); !siter.Done(); siter.Next()) {
  551. StateId s = siter.Value();
  552. for (ArcIterator<Fst<Arc> > aiter(*fst, s); !aiter.Done(); aiter.Next()) {
  553. const Arc& arc = aiter.Value();
  554. if (classes.size() <= static_cast<size_t>(arc.nextstate))
  555. classes.resize(arc.nextstate + 1, noClass);
  556. if (classes[arc.nextstate] == noClass)
  557. classes[arc.nextstate] = f(arc.ilabel);
  558. else if (classes[arc.nextstate] != f(arc.ilabel))
  559. bad_states.insert(arc.nextstate);
  560. }
  561. }
  562. if (bad_states.empty()) return; // Nothing to do.
  563. kaldi::ConstIntegerSet<StateId> bad_states_ciset(
  564. bad_states); // faster lookup.
  565. // Work out list of arcs we have to change as (state, arc-offset).
  566. // Can't do the actual changes in this pass, since we have to add new
  567. // states which invalidates the iterators.
  568. std::vector<std::pair<StateId, size_t> > arcs_to_change;
  569. for (StateIterator<Fst<Arc> > siter(*fst); !siter.Done(); siter.Next()) {
  570. StateId s = siter.Value();
  571. for (ArcIterator<Fst<Arc> > aiter(*fst, s); !aiter.Done(); aiter.Next()) {
  572. const Arc& arc = aiter.Value();
  573. if (arc.ilabel != 0 && bad_states_ciset.count(arc.nextstate) != 0)
  574. arcs_to_change.push_back(std::make_pair(s, aiter.Position()));
  575. }
  576. }
  577. KALDI_ASSERT(!arcs_to_change.empty()); // since !bad_states.empty().
  578. std::map<std::pair<StateId, ClassType>, StateId> state_map;
  579. // state_map is a map from (bad-state, input-symbol-class) to dummy-state.
  580. for (size_t i = 0; i < arcs_to_change.size(); i++) {
  581. StateId s = arcs_to_change[i].first;
  582. ArcIterator<MutableFst<Arc> > aiter(*fst, s);
  583. aiter.Seek(arcs_to_change[i].second);
  584. Arc arc = aiter.Value();
  585. // Transition is non-eps transition to "bad" state. Introduce new state (or
  586. // find existing one).
  587. std::pair<StateId, ClassType> p(arc.nextstate, f(arc.ilabel));
  588. if (state_map.count(p) == 0) {
  589. StateId newstate = state_map[p] = fst->AddState();
  590. fst->AddArc(newstate, Arc(0, 0, Weight::One(), arc.nextstate));
  591. }
  592. StateId dst_state = state_map[p];
  593. arc.nextstate = dst_state;
  594. // Initialize the MutableArcIterator only now, as the call to NewState()
  595. // may have invalidated the first arc iterator.
  596. MutableArcIterator<MutableFst<Arc> > maiter(fst, s);
  597. maiter.Seek(arcs_to_change[i].second);
  598. maiter.SetValue(arc);
  599. }
  600. }
  601. template <class Arc>
  602. void MakeFollowingInputSymbolsSame(bool end_is_epsilon, MutableFst<Arc>* fst) {
  603. IdentityFunction<typename Arc::Label> f;
  604. MakeFollowingInputSymbolsSameClass(end_is_epsilon, fst, f);
  605. }
  606. template <class Arc, class F>
  607. void MakeFollowingInputSymbolsSameClass(bool end_is_epsilon,
  608. MutableFst<Arc>* fst, const F& f) {
  609. typedef typename Arc::StateId StateId;
  610. typedef typename Arc::Weight Weight;
  611. typedef typename F::Result ClassType;
  612. std::vector<StateId> bad_states;
  613. ClassType noClass = f(kNoLabel);
  614. ClassType epsClass = f(0);
  615. for (StateIterator<Fst<Arc> > siter(*fst); !siter.Done(); siter.Next()) {
  616. StateId s = siter.Value();
  617. ClassType c = noClass;
  618. bool bad = false;
  619. for (ArcIterator<Fst<Arc> > aiter(*fst, s); !aiter.Done(); aiter.Next()) {
  620. const Arc& arc = aiter.Value();
  621. if (c == noClass) {
  622. c = f(arc.ilabel);
  623. } else if (c != f(arc.ilabel)) {
  624. bad = true;
  625. break;
  626. }
  627. }
  628. if (end_is_epsilon && c != noClass && c != epsClass &&
  629. fst->Final(s) != Weight::Zero())
  630. bad = true;
  631. if (bad) bad_states.push_back(s);
  632. }
  633. std::vector<Arc> my_arcs;
  634. for (size_t i = 0; i < bad_states.size(); i++) {
  635. StateId s = bad_states[i];
  636. my_arcs.clear();
  637. for (ArcIterator<MutableFst<Arc> > aiter(*fst, s); !aiter.Done();
  638. aiter.Next())
  639. my_arcs.push_back(aiter.Value());
  640. for (size_t j = 0; j < my_arcs.size(); j++) {
  641. Arc& arc = my_arcs[j];
  642. if (arc.ilabel != 0) {
  643. StateId newstate = fst->AddState();
  644. // Create a new state for each non-eps arc in original FST, out of each
  645. // bad state. Not as optimal as it could be, but does avoid some
  646. // complicated weight-pushing issues in which, to maintain
  647. // stochasticity, we would have to know which semiring we want to
  648. // maintain stochasticity in.
  649. fst->AddArc(newstate, Arc(arc.ilabel, 0, Weight::One(), arc.nextstate));
  650. MutableArcIterator<MutableFst<Arc> > maiter(fst, s);
  651. maiter.Seek(j);
  652. maiter.SetValue(Arc(0, arc.olabel, arc.weight, newstate));
  653. }
  654. }
  655. }
  656. }
  657. template <class Arc>
  658. VectorFst<Arc>* MakeLoopFst(const std::vector<const ExpandedFst<Arc>*>& fsts) {
  659. typedef typename Arc::Weight Weight;
  660. typedef typename Arc::StateId StateId;
  661. typedef typename Arc::Label Label;
  662. VectorFst<Arc>* ans = new VectorFst<Arc>;
  663. StateId loop_state = ans->AddState(); // = 0.
  664. ans->SetStart(loop_state);
  665. ans->SetFinal(loop_state, Weight::One());
  666. // "cache" is used as an optimization when some of the pointers in "fsts"
  667. // may have the same value.
  668. unordered_map<const ExpandedFst<Arc>*, Arc> cache;
  669. for (Label i = 0; i < static_cast<Label>(fsts.size()); i++) {
  670. const ExpandedFst<Arc>* fst = fsts[i];
  671. if (fst == NULL) continue;
  672. { // optimization with cache: helpful if some members of "fsts" may
  673. // contain the same pointer value (e.g. in GetHTransducer).
  674. typename unordered_map<const ExpandedFst<Arc>*, Arc>::iterator iter =
  675. cache.find(fst);
  676. if (iter != cache.end()) {
  677. Arc arc = iter->second;
  678. arc.olabel = i;
  679. ans->AddArc(0, arc);
  680. continue;
  681. }
  682. }
  683. KALDI_ASSERT(fst->Properties(kAcceptor, true) ==
  684. kAcceptor); // expect acceptor.
  685. StateId fst_num_states = fst->NumStates();
  686. StateId fst_start_state = fst->Start();
  687. if (fst_start_state == kNoStateId) continue; // empty fst.
  688. bool share_start_state =
  689. fst->Properties(kInitialAcyclic, true) == kInitialAcyclic &&
  690. fst->NumArcs(fst_start_state) == 1 &&
  691. fst->Final(fst_start_state) == Weight::Zero();
  692. std::vector<StateId> state_map(fst_num_states); // fst state -> ans state
  693. for (StateId s = 0; s < fst_num_states; s++) {
  694. if (s == fst_start_state && share_start_state)
  695. state_map[s] = loop_state;
  696. else
  697. state_map[s] = ans->AddState();
  698. }
  699. if (!share_start_state) {
  700. Arc arc(0, i, Weight::One(), state_map[fst_start_state]);
  701. cache[fst] = arc;
  702. ans->AddArc(0, arc);
  703. }
  704. for (StateId s = 0; s < fst_num_states; s++) {
  705. // Add arcs out of state s.
  706. for (ArcIterator<ExpandedFst<Arc> > aiter(*fst, s); !aiter.Done();
  707. aiter.Next()) {
  708. const Arc& arc = aiter.Value();
  709. Label olabel = (s == fst_start_state && share_start_state ? i : 0);
  710. Arc newarc(arc.ilabel, olabel, arc.weight, state_map[arc.nextstate]);
  711. ans->AddArc(state_map[s], newarc);
  712. if (s == fst_start_state && share_start_state) cache[fst] = newarc;
  713. }
  714. if (fst->Final(s) != Weight::Zero()) {
  715. KALDI_ASSERT(!(s == fst_start_state && share_start_state));
  716. ans->AddArc(state_map[s], Arc(0, 0, fst->Final(s), loop_state));
  717. }
  718. }
  719. }
  720. return ans;
  721. }
  722. template <class Arc>
  723. void ClearSymbols(bool clear_input, bool clear_output, MutableFst<Arc>* fst) {
  724. for (StateIterator<MutableFst<Arc> > siter(*fst); !siter.Done();
  725. siter.Next()) {
  726. typename Arc::StateId s = siter.Value();
  727. for (MutableArcIterator<MutableFst<Arc> > aiter(fst, s); !aiter.Done();
  728. aiter.Next()) {
  729. Arc arc = aiter.Value();
  730. bool change = false;
  731. if (clear_input && arc.ilabel != 0) {
  732. arc.ilabel = 0;
  733. change = true;
  734. }
  735. if (clear_output && arc.olabel != 0) {
  736. arc.olabel = 0;
  737. change = true;
  738. }
  739. if (change) {
  740. aiter.SetValue(arc);
  741. }
  742. }
  743. }
  744. }
  745. template <class Arc>
  746. void ApplyProbabilityScale(float scale, MutableFst<Arc>* fst) {
  747. typedef typename Arc::Weight Weight;
  748. typedef typename Arc::StateId StateId;
  749. for (StateIterator<MutableFst<Arc> > siter(*fst); !siter.Done();
  750. siter.Next()) {
  751. StateId s = siter.Value();
  752. for (MutableArcIterator<MutableFst<Arc> > aiter(fst, s); !aiter.Done();
  753. aiter.Next()) {
  754. Arc arc = aiter.Value();
  755. arc.weight = Weight(arc.weight.Value() * scale);
  756. aiter.SetValue(arc);
  757. }
  758. if (fst->Final(s) != Weight::Zero())
  759. fst->SetFinal(s, Weight(fst->Final(s).Value() * scale));
  760. }
  761. }
  762. // return arc-offset of self-loop with ilabel (or -1 if none exists).
  763. // if more than one such self-loop, pick first one.
  764. template <class Arc>
  765. ssize_t FindSelfLoopWithILabel(const Fst<Arc>& fst, typename Arc::StateId s) {
  766. for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next())
  767. if (aiter.Value().nextstate == s && aiter.Value().ilabel != 0)
  768. return static_cast<ssize_t>(aiter.Position());
  769. return static_cast<ssize_t>(-1);
  770. }
  771. template <class Arc>
  772. bool EqualAlign(const Fst<Arc>& ifst, typename Arc::StateId length,
  773. int rand_seed, MutableFst<Arc>* ofst, int num_retries) {
  774. srand(rand_seed);
  775. KALDI_ASSERT(ofst->NumStates() == 0); // make sure ofst empty.
  776. // make sure all states can reach final-state (or this algorithm may enter
  777. // infinite loop.
  778. KALDI_ASSERT(ifst.Properties(kCoAccessible, true) == kCoAccessible);
  779. typedef typename Arc::StateId StateId;
  780. typedef typename Arc::Weight Weight;
  781. if (ifst.Start() == kNoStateId) {
  782. KALDI_WARN << "Empty input fst.";
  783. return false;
  784. }
  785. // First select path through ifst.
  786. std::vector<StateId> path;
  787. std::vector<size_t> arc_offsets; // arc taken out of each state.
  788. std::vector<int> nof_ilabels;
  789. StateId num_ilabels = 0;
  790. int retry_no = 0;
  791. // Under normal circumstances, this will be one-pass-only process
  792. // Multiple tries might be needed in special cases, typically when
  793. // the number of frames is close to number of transitions from
  794. // the start node to the final node. It usually happens for really
  795. // short utterances
  796. do {
  797. num_ilabels = 0;
  798. arc_offsets.clear();
  799. path.clear();
  800. path.push_back(ifst.Start());
  801. while (1) {
  802. // Select either an arc or final-prob.
  803. StateId s = path.back();
  804. size_t num_arcs = ifst.NumArcs(s);
  805. size_t num_arcs_tot = num_arcs;
  806. if (ifst.Final(s) != Weight::Zero()) num_arcs_tot++;
  807. // kaldi::RandInt is a bit like Rand(), but gets around situations
  808. // where RAND_MAX is very small.
  809. // Change this to Rand() % num_arcs_tot if compile issues arise
  810. size_t arc_offset =
  811. static_cast<size_t>(kaldi::RandInt(0, num_arcs_tot - 1));
  812. if (arc_offset < num_arcs) { // an actual arc.
  813. ArcIterator<Fst<Arc> > aiter(ifst, s);
  814. aiter.Seek(arc_offset);
  815. const Arc& arc = aiter.Value();
  816. if (arc.nextstate == s) {
  817. continue; // don't take this self-loop arc
  818. } else {
  819. arc_offsets.push_back(arc_offset);
  820. path.push_back(arc.nextstate);
  821. if (arc.ilabel != 0) num_ilabels++;
  822. }
  823. } else {
  824. break; // Chose final-prob.
  825. }
  826. }
  827. nof_ilabels.push_back(num_ilabels);
  828. } while ((++retry_no < num_retries) && (num_ilabels > length));
  829. if (num_ilabels > length) {
  830. std::stringstream ilabel_vec;
  831. std::copy(nof_ilabels.begin(), nof_ilabels.end(),
  832. std::ostream_iterator<int>(ilabel_vec, ","));
  833. std::string s = ilabel_vec.str();
  834. s.erase(s.end() - 1);
  835. KALDI_WARN << "EqualAlign: the randomly constructed paths lengths: " << s;
  836. KALDI_WARN << "EqualAlign: utterance has too few frames " << length
  837. << " to align.";
  838. return false; // can't make it shorter by adding self-loops!.
  839. }
  840. StateId num_self_loops = 0;
  841. std::vector<ssize_t> self_loop_offsets(path.size());
  842. for (size_t i = 0; i < path.size(); i++)
  843. if ((self_loop_offsets[i] = FindSelfLoopWithILabel(ifst, path[i])) !=
  844. static_cast<ssize_t>(-1))
  845. num_self_loops++;
  846. if (num_self_loops == 0 && num_ilabels < length) {
  847. KALDI_WARN << "No self-loops on chosen path; cannot match length.";
  848. return false; // no self-loops to make it longer.
  849. }
  850. StateId num_extra = length - num_ilabels; // Number of self-loops we need.
  851. StateId min_num_loops = 0;
  852. if (num_extra != 0)
  853. min_num_loops = num_extra / num_self_loops; // prevent div by zero.
  854. StateId num_with_one_more_loop = num_extra - (min_num_loops * num_self_loops);
  855. KALDI_ASSERT(num_with_one_more_loop < num_self_loops || num_self_loops == 0);
  856. ofst->AddState();
  857. ofst->SetStart(0);
  858. StateId cur_state = 0;
  859. StateId counter = 0; // tell us when we should stop adding one more loop.
  860. for (size_t i = 0; i < path.size(); i++) {
  861. // First, add any self-loops that are necessary.
  862. StateId num_loops = 0;
  863. if (self_loop_offsets[i] != static_cast<ssize_t>(-1)) {
  864. num_loops = min_num_loops + (counter < num_with_one_more_loop ? 1 : 0);
  865. counter++;
  866. }
  867. for (StateId j = 0; j < num_loops; j++) {
  868. ArcIterator<Fst<Arc> > aiter(ifst, path[i]);
  869. aiter.Seek(self_loop_offsets[i]);
  870. Arc arc = aiter.Value();
  871. KALDI_ASSERT(arc.nextstate == path[i] &&
  872. arc.ilabel != 0); // make sure self-loop with ilabel.
  873. StateId next_state = ofst->AddState();
  874. ofst->AddArc(cur_state,
  875. Arc(arc.ilabel, arc.olabel, arc.weight, next_state));
  876. cur_state = next_state;
  877. }
  878. if (i + 1 < path.size()) { // add forward transition.
  879. ArcIterator<Fst<Arc> > aiter(ifst, path[i]);
  880. aiter.Seek(arc_offsets[i]);
  881. Arc arc = aiter.Value();
  882. KALDI_ASSERT(arc.nextstate == path[i + 1]);
  883. StateId next_state = ofst->AddState();
  884. ofst->AddArc(cur_state,
  885. Arc(arc.ilabel, arc.olabel, arc.weight, next_state));
  886. cur_state = next_state;
  887. } else { // add final-prob.
  888. Weight weight = ifst.Final(path[i]);
  889. KALDI_ASSERT(weight != Weight::Zero());
  890. ofst->SetFinal(cur_state, weight);
  891. }
  892. }
  893. return true;
  894. }
  895. // This function identifies two types of useless arcs:
  896. // those where arc A and arc B both go from state X to
  897. // state Y with the same input symbol (remove the one
  898. // with smaller probability, or an arbitrary one if they
  899. // are the same); and those where A is an arc from state X
  900. // to state X, with epsilon input symbol [remove A].
  901. // Only works for tropical (not log) semiring as it uses
  902. // NaturalLess.
  903. template <class Arc>
  904. void RemoveUselessArcs(MutableFst<Arc>* fst) {
  905. typedef typename Arc::Label Label;
  906. typedef typename Arc::StateId StateId;
  907. typedef typename Arc::Weight Weight;
  908. NaturalLess<Weight> nl;
  909. StateId non_coacc_state = kNoStateId;
  910. size_t num_arcs_removed = 0, tot_arcs = 0;
  911. for (StateIterator<MutableFst<Arc> > siter(*fst); !siter.Done();
  912. siter.Next()) {
  913. std::vector<size_t> arcs_to_delete;
  914. std::vector<Arc> arcs;
  915. // pair2arclist lets us look up the arcs
  916. std::map<std::pair<Label, StateId>, std::vector<size_t> > pair2arclist;
  917. StateId state = siter.Value();
  918. for (ArcIterator<MutableFst<Arc> > aiter(*fst, state); !aiter.Done();
  919. aiter.Next()) {
  920. size_t pos = arcs.size();
  921. const Arc& arc = aiter.Value();
  922. arcs.push_back(arc);
  923. pair2arclist[std::make_pair(arc.ilabel, arc.nextstate)].push_back(pos);
  924. }
  925. typename std::map<std::pair<Label, StateId>, std::vector<size_t> >::iterator
  926. iter = pair2arclist.begin(),
  927. end = pair2arclist.end();
  928. for (; iter != end; ++iter) {
  929. const std::vector<size_t>& poslist = iter->second;
  930. if (poslist.size() > 1) { // >1 arc with same ilabel, dest-state
  931. size_t best_pos = poslist[0];
  932. Weight best_weight = arcs[best_pos].weight;
  933. for (size_t j = 1; j < poslist.size(); j++) {
  934. size_t pos = poslist[j];
  935. Weight this_weight = arcs[pos].weight;
  936. if (nl(this_weight,
  937. best_weight)) { // NaturalLess seems to be somehow
  938. // "backwards".
  939. best_weight = this_weight; // found a better one.
  940. best_pos = pos;
  941. }
  942. }
  943. for (size_t j = 0; j < poslist.size(); j++)
  944. if (poslist[j] != best_pos) arcs_to_delete.push_back(poslist[j]);
  945. } else {
  946. KALDI_ASSERT(poslist.size() == 1);
  947. size_t pos = poslist[0];
  948. Arc& arc = arcs[pos];
  949. if (arc.ilabel == 0 && arc.nextstate == state)
  950. arcs_to_delete.push_back(pos);
  951. }
  952. }
  953. tot_arcs += arcs.size();
  954. if (arcs_to_delete.size() != 0) {
  955. num_arcs_removed += arcs_to_delete.size();
  956. if (non_coacc_state == kNoStateId) non_coacc_state = fst->AddState();
  957. MutableArcIterator<MutableFst<Arc> > maiter(fst, state);
  958. for (size_t j = 0; j < arcs_to_delete.size(); j++) {
  959. size_t pos = arcs_to_delete[j];
  960. maiter.Seek(pos);
  961. arcs[pos].nextstate = non_coacc_state;
  962. maiter.SetValue(arcs[pos]);
  963. }
  964. }
  965. }
  966. if (non_coacc_state != kNoStateId) Connect(fst);
  967. KALDI_VLOG(1) << "removed " << num_arcs_removed << " of " << tot_arcs
  968. << "arcs.";
  969. }
  970. template <class Arc>
  971. void PhiCompose(const Fst<Arc>& fst1, const Fst<Arc>& fst2,
  972. typename Arc::Label phi_label, MutableFst<Arc>* ofst) {
  973. KALDI_ASSERT(phi_label !=
  974. kNoLabel); // just use regular compose in this case.
  975. typedef Fst<Arc> F;
  976. typedef PhiMatcher<SortedMatcher<F> > PM;
  977. CacheOptions base_opts;
  978. base_opts.gc_limit = 0; // Cache only the last state for fastest copy.
  979. // ComposeFstImplOptions templated on matcher for fst1, matcher for fst2.
  980. // The matcher for fst1 doesn't matter; we'll use fst2's matcher.
  981. ComposeFstImplOptions<SortedMatcher<F>, PM> impl_opts(base_opts);
  982. // the false below is something called phi_loop which is something I don't
  983. // fully understand, but I don't think we want it.
  984. // These pointers are taken ownership of, by ComposeFst.
  985. PM* phi_matcher = new PM(fst2, MATCH_INPUT, phi_label, false);
  986. SortedMatcher<F>* sorted_matcher =
  987. new SortedMatcher<F>(fst1, MATCH_NONE); // tell it
  988. // not to use this matcher, as this would mean we would
  989. // not follow phi transitions.
  990. impl_opts.matcher1 = sorted_matcher;
  991. impl_opts.matcher2 = phi_matcher;
  992. *ofst = ComposeFst<Arc>(fst1, fst2, impl_opts);
  993. Connect(ofst);
  994. }
  995. template <class Arc>
  996. void PropagateFinalInternal(typename Arc::Label phi_label,
  997. typename Arc::StateId s, MutableFst<Arc>* fst) {
  998. typedef typename Arc::Weight Weight;
  999. if (fst->Final(s) == Weight::Zero()) {
  1000. // search for phi transition. We assume there
  1001. // is just one-- phi nondeterminism is not allowed
  1002. // anyway.
  1003. int num_phis = 0;
  1004. for (ArcIterator<Fst<Arc> > aiter(*fst, s); !aiter.Done(); aiter.Next()) {
  1005. const Arc& arc = aiter.Value();
  1006. if (arc.ilabel == phi_label) {
  1007. num_phis++;
  1008. if (arc.nextstate == s) continue; // don't expect
  1009. // phi loops but ignore them anyway.
  1010. // If this recurses infinitely, it means there
  1011. // are loops of phi transitions, which there should
  1012. // not be in a normal backoff LM. We could make this
  1013. // routine work for this case, but currently there is
  1014. // no need.
  1015. PropagateFinalInternal(phi_label, arc.nextstate, fst);
  1016. if (fst->Final(arc.nextstate) != Weight::Zero())
  1017. fst->SetFinal(s, Times(fst->Final(arc.nextstate), arc.weight));
  1018. }
  1019. KALDI_ASSERT(num_phis <= 1 && "Phi nondeterminism found");
  1020. }
  1021. }
  1022. }
  1023. template <class Arc>
  1024. void PropagateFinal(typename Arc::Label phi_label, MutableFst<Arc>* fst) {
  1025. typedef typename Arc::StateId StateId;
  1026. if (fst->Properties(kIEpsilons, true)) // just warn.
  1027. KALDI_WARN << "PropagateFinal: this may not work as desired "
  1028. "since your FST has input epsilons.";
  1029. StateId num_states = fst->NumStates();
  1030. for (StateId s = 0; s < num_states; s++)
  1031. PropagateFinalInternal(phi_label, s, fst);
  1032. }
  1033. template <class Arc>
  1034. void RhoCompose(const Fst<Arc>& fst1, const Fst<Arc>& fst2,
  1035. typename Arc::Label rho_label, MutableFst<Arc>* ofst) {
  1036. KALDI_ASSERT(rho_label !=
  1037. kNoLabel); // just use regular compose in this case.
  1038. typedef Fst<Arc> F;
  1039. typedef RhoMatcher<SortedMatcher<F> > RM;
  1040. CacheOptions base_opts;
  1041. base_opts.gc_limit = 0; // Cache only the last state for fastest copy.
  1042. // ComposeFstImplOptions templated on matcher for fst1, matcher for fst2.
  1043. // The matcher for fst1 doesn't matter; we'll use fst2's matcher.
  1044. ComposeFstImplOptions<SortedMatcher<F>, RM> impl_opts(base_opts);
  1045. // the false below is something called rho_loop which is something I don't
  1046. // fully understand, but I don't think we want it.
  1047. // These pointers are taken ownership of, by ComposeFst.
  1048. RM* rho_matcher = new RM(fst2, MATCH_INPUT, rho_label);
  1049. SortedMatcher<F>* sorted_matcher =
  1050. new SortedMatcher<F>(fst1, MATCH_NONE); // tell it
  1051. // not to use this matcher, as this would mean we would
  1052. // not follow rho transitions.
  1053. impl_opts.matcher1 = sorted_matcher;
  1054. impl_opts.matcher2 = rho_matcher;
  1055. *ofst = ComposeFst<Arc>(fst1, fst2, impl_opts);
  1056. Connect(ofst);
  1057. }
  1058. // Declare an override of the template below.
  1059. template <>
  1060. inline bool IsStochasticFst(const Fst<LogArc>& fst, float delta,
  1061. LogArc::Weight* min_sum, LogArc::Weight* max_sum);
  1062. // Will override this for LogArc where NaturalLess will not work.
  1063. template <class Arc>
  1064. inline bool IsStochasticFst(const Fst<Arc>& fst, float delta,
  1065. typename Arc::Weight* min_sum,
  1066. typename Arc::Weight* max_sum) {
  1067. typedef typename Arc::StateId StateId;
  1068. typedef typename Arc::Weight Weight;
  1069. NaturalLess<Weight> nl;
  1070. bool first_time = true;
  1071. bool ans = true;
  1072. if (min_sum) *min_sum = Arc::Weight::One();
  1073. if (max_sum) *max_sum = Arc::Weight::One();
  1074. for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
  1075. StateId s = siter.Value();
  1076. Weight sum = fst.Final(s);
  1077. for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
  1078. const Arc& arc = aiter.Value();
  1079. sum = Plus(sum, arc.weight);
  1080. }
  1081. if (!ApproxEqual(Weight::One(), sum, delta)) ans = false;
  1082. if (first_time) {
  1083. first_time = false;
  1084. if (max_sum) *max_sum = sum;
  1085. if (min_sum) *min_sum = sum;
  1086. } else {
  1087. if (max_sum && nl(*max_sum, sum)) *max_sum = sum;
  1088. if (min_sum && nl(sum, *min_sum)) *min_sum = sum;
  1089. }
  1090. }
  1091. if (first_time) { // just avoid NaNs if FST was empty.
  1092. if (max_sum) *max_sum = Weight::One();
  1093. if (min_sum) *min_sum = Weight::One();
  1094. }
  1095. return ans;
  1096. }
  1097. // Overriding template for LogArc as NaturalLess does not work there.
  1098. template <>
  1099. inline bool IsStochasticFst(const Fst<LogArc>& fst, float delta,
  1100. LogArc::Weight* min_sum, LogArc::Weight* max_sum) {
  1101. typedef LogArc Arc;
  1102. typedef Arc::StateId StateId;
  1103. typedef Arc::Weight Weight;
  1104. bool first_time = true;
  1105. bool ans = true;
  1106. if (min_sum) *min_sum = LogArc::Weight::One();
  1107. if (max_sum) *max_sum = LogArc::Weight::One();
  1108. for (StateIterator<Fst<Arc> > siter(fst); !siter.Done(); siter.Next()) {
  1109. StateId s = siter.Value();
  1110. Weight sum = fst.Final(s);
  1111. for (ArcIterator<Fst<Arc> > aiter(fst, s); !aiter.Done(); aiter.Next()) {
  1112. const Arc& arc = aiter.Value();
  1113. sum = Plus(sum, arc.weight);
  1114. }
  1115. if (!ApproxEqual(Weight::One(), sum, delta)) ans = false;
  1116. if (first_time) {
  1117. first_time = false;
  1118. if (max_sum) *max_sum = sum;
  1119. if (min_sum) *min_sum = sum;
  1120. } else {
  1121. // note that max and min are reversed from their normal
  1122. // meanings here (max and min w.r.t. the underlying probabilities).
  1123. if (max_sum && sum.Value() < max_sum->Value()) *max_sum = sum;
  1124. if (min_sum && sum.Value() > min_sum->Value()) *min_sum = sum;
  1125. }
  1126. }
  1127. if (first_time) { // just avoid NaNs if FST was empty.
  1128. if (max_sum) *max_sum = Weight::One();
  1129. if (min_sum) *min_sum = Weight::One();
  1130. }
  1131. return ans;
  1132. }
  1133. // Tests whether a tropical FST is stochastic in the log
  1134. // semiring. (casts it and does the check.)
  1135. // This function deals with the generic fst.
  1136. // This version currently supports ConstFst<StdArc> or VectorFst<StdArc>.
  1137. // Otherwise, it will be died with an error.
  1138. inline bool IsStochasticFstInLog(const Fst<StdArc>& fst, float delta,
  1139. StdArc::Weight* min_sum,
  1140. StdArc::Weight* max_sum) {
  1141. bool ans = false;
  1142. LogArc::Weight log_min = LogArc::Weight::One(),
  1143. log_max = LogArc::Weight::Zero();
  1144. if (fst.Type() == "const") {
  1145. ConstFst<LogArc> logfst;
  1146. Cast(dynamic_cast<const ConstFst<StdArc>&>(fst), &logfst);
  1147. ans = IsStochasticFst(logfst, delta, &log_min, &log_max);
  1148. } else if (fst.Type() == "vector") {
  1149. VectorFst<LogArc> logfst;
  1150. Cast(dynamic_cast<const VectorFst<StdArc>&>(fst), &logfst);
  1151. ans = IsStochasticFst(logfst, delta, &log_min, &log_max);
  1152. } else {
  1153. KALDI_ERR << "This version currently supports ConstFst<StdArc> "
  1154. << "or VectorFst<StdArc>";
  1155. }
  1156. if (min_sum) *min_sum = StdArc::Weight(log_min.Value());
  1157. if (max_sum) *max_sum = StdArc::Weight(log_max.Value());
  1158. return ans;
  1159. }
  1160. } // namespace fst.
  1161. #endif // KALDI_FSTEXT_FSTEXT_UTILS_INL_H_