You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

318 lines
12 KiB

  1. // fstext/remove-eps-local-inl.h
  2. // Copyright 2009-2011 Microsoft Corporation
  3. // 2014 Johns Hopkins University (author: Daniel Povey
  4. // See ../../COPYING for clarification regarding multiple authors
  5. //
  6. // Licensed under the Apache License, Version 2.0 (the "License");
  7. // you may not use this file except in compliance with the License.
  8. // You may obtain a copy of the License at
  9. //
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. //
  12. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  13. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  14. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  15. // MERCHANTABLITY OR NON-INFRINGEMENT.
  16. // See the Apache 2 License for the specific language governing permissions and
  17. // limitations under the License.
  18. #ifndef KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_
  19. #define KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_
  20. #include <vector>
  21. namespace fst {
  // Default weight accumulator used by RemoveEpsLocalClass when it reweights
  // arcs: simply the semiring's own Plus() of the two weights.
  template <class Weight>
  struct ReweightPlusDefault {
    inline Weight operator()(const Weight& lhs, const Weight& rhs) {
      Weight sum = Plus(lhs, rhs);
      return sum;
    }
  };
  28. struct ReweightPlusLogArc {
  29. inline TropicalWeight operator()(const TropicalWeight& a,
  30. const TropicalWeight& b) {
  31. LogWeight a_log(a.Value()), b_log(b.Value());
  32. return TropicalWeight(Plus(a_log, b_log).Value());
  33. }
  34. };
  35. template <class Arc,
  36. class ReweightPlus = ReweightPlusDefault<typename Arc::Weight> >
  37. class RemoveEpsLocalClass {
  38. typedef typename Arc::StateId StateId;
  39. typedef typename Arc::Label Label;
  40. typedef typename Arc::Weight Weight;
  41. public:
  42. explicit RemoveEpsLocalClass(MutableFst<Arc>* fst) : fst_(fst) {
  43. if (fst_->Start() == kNoStateId) return; // empty.
  44. non_coacc_state_ = fst_->AddState();
  45. InitNumArcs();
  46. StateId num_states = fst_->NumStates();
  47. for (StateId s = 0; s < num_states; s++)
  48. for (size_t pos = 0; pos < fst_->NumArcs(s); pos++) RemoveEps(s, pos);
  49. assert(CheckNumArcs());
  50. Connect(fst); // remove inaccessible states.
  51. }
  52. private:
  53. MutableFst<Arc>* fst_;
  54. StateId non_coacc_state_; // use this to delete arcs: make it nextstate
  55. std::vector<StateId> num_arcs_in_; // The number of arcs into the state, plus
  56. // one if it's the start state.
  57. std::vector<StateId> num_arcs_out_; // The number of arcs out of the state,
  58. // plus one if it's a final state.
  59. ReweightPlus reweight_plus_;
  60. bool CanCombineArcs(const Arc& a, const Arc& b, Arc* c) {
  61. if (a.ilabel != 0 && b.ilabel != 0) return false;
  62. if (a.olabel != 0 && b.olabel != 0) return false;
  63. c->weight = Times(a.weight, b.weight);
  64. c->ilabel = (a.ilabel != 0 ? a.ilabel : b.ilabel);
  65. c->olabel = (a.olabel != 0 ? a.olabel : b.olabel);
  66. c->nextstate = b.nextstate;
  67. return true;
  68. }
  69. static bool CanCombineFinal(const Arc& a, Weight final_prob,
  70. Weight* final_prob_out) {
  71. if (a.ilabel != 0 || a.olabel != 0) {
  72. return false;
  73. } else {
  74. *final_prob_out = Times(a.weight, final_prob);
  75. return true;
  76. }
  77. }
  78. void InitNumArcs() { // init num transitions in/out of each state.
  79. StateId num_states = fst_->NumStates();
  80. num_arcs_in_.resize(num_states);
  81. num_arcs_out_.resize(num_states);
  82. num_arcs_in_[fst_->Start()]++; // count start as trans in.
  83. for (StateId s = 0; s < num_states; s++) {
  84. if (fst_->Final(s) != Weight::Zero())
  85. num_arcs_out_[s]++; // count final as transition.
  86. for (ArcIterator<MutableFst<Arc> > aiter(*fst_, s); !aiter.Done();
  87. aiter.Next()) {
  88. num_arcs_in_[aiter.Value().nextstate]++;
  89. num_arcs_out_[s]++;
  90. }
  91. }
  92. }
  93. bool CheckNumArcs() { // check num arcs in/out of each state, at end. Debug.
  94. num_arcs_in_[fst_->Start()]--; // count start as trans in.
  95. StateId num_states = fst_->NumStates();
  96. for (StateId s = 0; s < num_states; s++) {
  97. if (s == non_coacc_state_) continue;
  98. if (fst_->Final(s) != Weight::Zero())
  99. num_arcs_out_[s]--; // count final as transition.
  100. for (ArcIterator<MutableFst<Arc> > aiter(*fst_, s); !aiter.Done();
  101. aiter.Next()) {
  102. if (aiter.Value().nextstate == non_coacc_state_) continue;
  103. num_arcs_in_[aiter.Value().nextstate]--;
  104. num_arcs_out_[s]--;
  105. }
  106. }
  107. for (StateId s = 0; s < num_states; s++) {
  108. assert(num_arcs_in_[s] == 0);
  109. assert(num_arcs_out_[s] == 0);
  110. }
  111. return true; // always does this. so we can assert it w/o warnings.
  112. }
  113. inline void GetArc(StateId s, size_t pos, Arc* arc) const {
  114. ArcIterator<MutableFst<Arc> > aiter(*fst_, s);
  115. aiter.Seek(pos);
  116. *arc = aiter.Value();
  117. }
  118. inline void SetArc(StateId s, size_t pos, const Arc& arc) {
  119. MutableArcIterator<MutableFst<Arc> > aiter(fst_, s);
  120. aiter.Seek(pos);
  121. aiter.SetValue(arc);
  122. }
  123. void Reweight(StateId s, size_t pos, Weight reweight) {
  124. // Reweight is called from RemoveEpsPattern1; it is a step we
  125. // do to preserve stochasticity. This function multiplies the
  126. // arc at (s, pos) by reweight and divides all the arcs [+final-prob]
  127. // out of the next state by the same. This is only valid if
  128. // the next state has only one arc in and is not the start state.
  129. assert(reweight != Weight::Zero());
  130. MutableArcIterator<MutableFst<Arc> > aiter(fst_, s);
  131. aiter.Seek(pos);
  132. Arc arc = aiter.Value();
  133. assert(num_arcs_in_[arc.nextstate] == 1);
  134. arc.weight = Times(arc.weight, reweight);
  135. aiter.SetValue(arc);
  136. for (MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, arc.nextstate);
  137. !aiter_next.Done(); aiter_next.Next()) {
  138. Arc nextarc = aiter_next.Value();
  139. if (nextarc.nextstate != non_coacc_state_) {
  140. nextarc.weight = Divide(nextarc.weight, reweight, DIVIDE_LEFT);
  141. aiter_next.SetValue(nextarc);
  142. }
  143. }
  144. Weight final = fst_->Final(arc.nextstate);
  145. if (final != Weight::Zero()) {
  146. fst_->SetFinal(arc.nextstate, Divide(final, reweight, DIVIDE_LEFT));
  147. }
  148. }
  149. // RemoveEpsPattern1 applies where this arc, which is not a
  150. // self-loop, enters a state which has only one input transition
  151. // [and is not the start state], and has multiple output
  152. // transitions [counting being the final-state as a final-transition].
  153. void RemoveEpsPattern1(StateId s, size_t pos, Arc arc) {
  154. const StateId nextstate = arc.nextstate;
  155. Weight total_removed = Weight::Zero(),
  156. total_kept = Weight::Zero(); // totals out of nextstate.
  157. std::vector<Arc> arcs_to_add; // to add to state s.
  158. for (MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, nextstate);
  159. !aiter_next.Done(); aiter_next.Next()) {
  160. Arc nextarc = aiter_next.Value();
  161. if (nextarc.nextstate == non_coacc_state_) continue; // deleted.
  162. Arc combined;
  163. if (CanCombineArcs(arc, nextarc, &combined)) {
  164. total_removed = reweight_plus_(total_removed, nextarc.weight);
  165. num_arcs_out_[nextstate]--;
  166. num_arcs_in_[nextarc.nextstate]--;
  167. nextarc.nextstate = non_coacc_state_;
  168. aiter_next.SetValue(nextarc);
  169. arcs_to_add.push_back(combined);
  170. } else {
  171. total_kept = reweight_plus_(total_kept, nextarc.weight);
  172. }
  173. }
  174. { // now final-state.
  175. Weight next_final = fst_->Final(nextstate);
  176. if (next_final != Weight::Zero()) {
  177. Weight new_final;
  178. if (CanCombineFinal(arc, next_final, &new_final)) {
  179. total_removed = reweight_plus_(total_removed, next_final);
  180. if (fst_->Final(s) == Weight::Zero())
  181. num_arcs_out_[s]++; // final is counted as arc.
  182. fst_->SetFinal(s, Plus(fst_->Final(s), new_final));
  183. num_arcs_out_[nextstate]--;
  184. fst_->SetFinal(nextstate, Weight::Zero());
  185. } else {
  186. total_kept = reweight_plus_(total_kept, next_final);
  187. }
  188. }
  189. }
  190. if (total_removed != Weight::Zero()) { // did something...
  191. if (total_kept == Weight::Zero()) { // removed everything: remove arc.
  192. num_arcs_out_[s]--;
  193. num_arcs_in_[arc.nextstate]--;
  194. arc.nextstate = non_coacc_state_;
  195. SetArc(s, pos, arc);
  196. } else {
  197. // Have to reweight.
  198. Weight total = reweight_plus_(total_removed, total_kept);
  199. Weight reweight = Divide(total_kept, total, DIVIDE_LEFT); // <=1
  200. Reweight(s, pos, reweight);
  201. }
  202. }
  203. // Now add the arcs we were going to add.
  204. for (size_t i = 0; i < arcs_to_add.size(); i++) {
  205. num_arcs_out_[s]++;
  206. num_arcs_in_[arcs_to_add[i].nextstate]++;
  207. fst_->AddArc(s, arcs_to_add[i]);
  208. }
  209. }
  210. void RemoveEpsPattern2(StateId s, size_t pos, Arc arc) {
  211. // Pattern 2 is where "nextstate" has only one arc out, counting
  212. // being-the-final-state as an arc, but possibly multiple arcs in.
  213. // Also, nextstate != s.
  214. const StateId nextstate = arc.nextstate;
  215. bool can_delete_next = (num_arcs_in_[nextstate] == 1); // if
  216. // we combine, can delete the corresponding out-arc/final-prob
  217. // of nextstate.
  218. bool delete_arc = false; // set to true if this arc to be deleted.
  219. Weight next_final = fst_->Final(arc.nextstate);
  220. if (next_final !=
  221. Weight::Zero()) { // nextstate has no actual arcs out, only final-prob.
  222. Weight new_final;
  223. if (CanCombineFinal(arc, next_final, &new_final)) {
  224. if (fst_->Final(s) == Weight::Zero())
  225. num_arcs_out_[s]++; // final is counted as arc.
  226. fst_->SetFinal(s, Plus(fst_->Final(s), new_final));
  227. delete_arc = true; // will delete "arc".
  228. if (can_delete_next) {
  229. num_arcs_out_[nextstate]--;
  230. fst_->SetFinal(nextstate, Weight::Zero());
  231. }
  232. }
  233. } else { // has an arc but no final prob.
  234. MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, nextstate);
  235. assert(!aiter_next.Done());
  236. while (aiter_next.Value().nextstate == non_coacc_state_) {
  237. aiter_next.Next();
  238. assert(!aiter_next.Done());
  239. }
  240. // now aiter_next points to a real arc out of nextstate.
  241. Arc nextarc = aiter_next.Value();
  242. Arc combined;
  243. if (CanCombineArcs(arc, nextarc, &combined)) {
  244. delete_arc = true;
  245. if (can_delete_next) { // do it before we invalidate iterators
  246. num_arcs_out_[nextstate]--;
  247. num_arcs_in_[nextarc.nextstate]--;
  248. nextarc.nextstate = non_coacc_state_;
  249. aiter_next.SetValue(nextarc);
  250. }
  251. num_arcs_out_[s]++;
  252. num_arcs_in_[combined.nextstate]++;
  253. fst_->AddArc(s, combined);
  254. }
  255. }
  256. if (delete_arc) {
  257. num_arcs_out_[s]--;
  258. num_arcs_in_[nextstate]--;
  259. arc.nextstate = non_coacc_state_;
  260. SetArc(s, pos, arc);
  261. }
  262. }
  263. void RemoveEps(StateId s, size_t pos) {
  264. // Tries to do local epsilon-removal for arc sequences starting with this
  265. // arc
  266. Arc arc;
  267. GetArc(s, pos, &arc);
  268. StateId nextstate = arc.nextstate;
  269. if (nextstate == non_coacc_state_) return; // deleted arc.
  270. if (nextstate == s) return; // don't handle self-loops: too complex.
  271. if (num_arcs_in_[nextstate] == 1 && num_arcs_out_[nextstate] > 1) {
  272. RemoveEpsPattern1(s, pos, arc);
  273. } else if (num_arcs_out_[nextstate] == 1) {
  274. RemoveEpsPattern2(s, pos, arc);
  275. }
  276. }
  277. };
  278. template <class Arc>
  279. void RemoveEpsLocal(MutableFst<Arc>* fst) {
  280. RemoveEpsLocalClass<Arc> c(fst); // work gets done in initializer.
  281. }
  282. void RemoveEpsLocalSpecial(MutableFst<StdArc>* fst) {
  283. // work gets done in initializer.
  284. RemoveEpsLocalClass<StdArc, ReweightPlusLogArc> c(fst);
  285. }
  286. } // end namespace fst.
  287. #endif // KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_