You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1204 lines
45 KiB

  1. // fstext/determinize-star-inl.h
  2. // Copyright 2009-2011 Microsoft Corporation; Jan Silovsky
  3. // 2015 Hainan Xu
  4. // See ../../COPYING for clarification regarding multiple authors
  5. //
  6. // Licensed under the Apache License, Version 2.0 (the "License");
  7. // you may not use this file except in compliance with the License.
  8. // You may obtain a copy of the License at
  9. //
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. //
  12. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  13. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  14. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  15. // MERCHANTABLITY OR NON-INFRINGEMENT.
  16. // See the Apache 2 License for the specific language governing permissions and
  17. // limitations under the License.
  18. #ifndef KALDI_FSTEXT_DETERMINIZE_STAR_INL_H_
  19. #define KALDI_FSTEXT_DETERMINIZE_STAR_INL_H_
  20. // Do not include this file directly. It is included by determinize-star.h
  21. #include <algorithm>
  22. #include <climits>
  23. #include <deque>
  24. #include <limits>
  25. #include <string>
  26. #include <unordered_map>
  27. #include <utility>
  28. #include <vector>
  29. using std::unordered_map;
  30. #include "base/kaldi-error.h"
  31. namespace fst {
  32. // This class maps back and forth from/to integer id's to sequences of strings.
  33. // used in determinization algorithm.
  34. template <class Label, class StringId>
  35. class StringRepository {
  36. // Label and StringId are both integer types, possibly the same.
  37. // This is a utility that maps back and forth between a vector<Label> and
  38. // StringId representation of sequences of Labels. It is to save memory, and
  39. // to save compute. We treat sequences of length zero and one separately, for
  40. // efficiency.
  41. public:
  42. class VectorKey { // Hash function object.
  43. public:
  44. size_t operator()(const std::vector<Label>* vec) const {
  45. assert(vec != NULL);
  46. size_t hash = 0, factor = 1;
  47. for (typename std::vector<Label>::const_iterator it = vec->begin();
  48. it != vec->end(); it++) {
  49. hash += factor * (*it);
  50. factor *= 103333; // just an arbitrary prime number.
  51. }
  52. return hash;
  53. }
  54. };
  55. class VectorEqual { // Equality-operator function object.
  56. public:
  57. size_t operator()(const std::vector<Label>* vec1,
  58. const std::vector<Label>* vec2) const {
  59. return (*vec1 == *vec2);
  60. }
  61. };
  62. typedef unordered_map<const std::vector<Label>*, StringId, VectorKey,
  63. VectorEqual>
  64. MapType;
  65. StringId IdOfEmpty() { return no_symbol; }
  66. StringId IdOfLabel(Label l) {
  67. if (l >= 0 && l <= (Label)single_symbol_range) {
  68. return l + single_symbol_start;
  69. } else {
  70. // l is out of the allowed range so we have to treat it as a sequence of
  71. // length one. Should be v. rare.
  72. std::vector<Label> v;
  73. v.push_back(l);
  74. return IdOfSeqInternal(v);
  75. }
  76. }
  77. StringId IdOfSeq(
  78. const std::vector<Label>& v) { // also works for sizes 0 and 1.
  79. size_t sz = v.size();
  80. if (sz == 0)
  81. return no_symbol;
  82. else if (v.size() == 1)
  83. return IdOfLabel(v[0]);
  84. else
  85. return IdOfSeqInternal(v);
  86. }
  87. inline bool IsEmptyString(StringId id) { return id == no_symbol; }
  88. void SeqOfId(StringId id, std::vector<Label>* v) {
  89. if (id == no_symbol) {
  90. v->clear();
  91. } else if (id >= single_symbol_start) {
  92. v->resize(1);
  93. (*v)[0] = id - single_symbol_start;
  94. } else {
  95. assert(static_cast<size_t>(id) < vec_.size());
  96. *v = *(vec_[id]);
  97. }
  98. }
  99. StringId RemovePrefix(StringId id, size_t prefix_len) {
  100. if (prefix_len == 0) {
  101. return id;
  102. } else {
  103. std::vector<Label> v;
  104. SeqOfId(id, &v);
  105. size_t sz = v.size();
  106. assert(sz >= prefix_len);
  107. std::vector<Label> v_noprefix(sz - prefix_len);
  108. for (size_t i = 0; i < sz - prefix_len; i++)
  109. v_noprefix[i] = v[i + prefix_len];
  110. return IdOfSeq(v_noprefix);
  111. }
  112. }
  113. StringRepository() {
  114. // The following are really just constants but don't want to complicate
  115. // compilation so make them class variables. Due to the brokenness of
  116. // <limits>, they can't be accessed as constants.
  117. string_end = (std::numeric_limits<StringId>::max() / 2) -
  118. 1; // all hash values must be <= this.
  119. no_symbol = (std::numeric_limits<StringId>::max() /
  120. 2); // reserved for empty sequence.
  121. single_symbol_start = (std::numeric_limits<StringId>::max() / 2) + 1;
  122. single_symbol_range =
  123. std::numeric_limits<StringId>::max() - single_symbol_start;
  124. }
  125. void Destroy() {
  126. for (typename std::vector<std::vector<Label>*>::iterator iter =
  127. vec_.begin();
  128. iter != vec_.end(); ++iter)
  129. delete *iter;
  130. std::vector<std::vector<Label>*> tmp_vec;
  131. tmp_vec.swap(vec_);
  132. MapType tmp_map;
  133. tmp_map.swap(map_);
  134. }
  135. ~StringRepository() { Destroy(); }
  136. private:
  137. KALDI_DISALLOW_COPY_AND_ASSIGN(StringRepository);
  138. StringId IdOfSeqInternal(const std::vector<Label>& v) {
  139. typename MapType::iterator iter = map_.find(&v);
  140. if (iter != map_.end()) {
  141. return iter->second;
  142. } else { // must add it to map.
  143. StringId this_id = (StringId)vec_.size();
  144. std::vector<Label>* v_new = new std::vector<Label>(v);
  145. vec_.push_back(v_new);
  146. map_[v_new] = this_id;
  147. assert(this_id < string_end); // or we used up the labels.
  148. return this_id;
  149. }
  150. }
  151. std::vector<std::vector<Label>*> vec_;
  152. MapType map_;
  153. static const StringId string_start =
  154. (StringId)0; // This must not change. It's assumed.
  155. StringId string_end; // = (numeric_limits<StringId>::max() / 2) - 1; // all
  156. // hash values must be <= this.
  157. StringId no_symbol; // = (numeric_limits<StringId>::max() / 2); // reserved
  158. // for empty sequence.
  159. StringId
  160. single_symbol_start; // = (numeric_limits<StringId>::max() / 2) + 1;
  161. StringId single_symbol_range; // = numeric_limits<StringId>::max() -
  162. // single_symbol_start;
  163. };
  164. template <class F>
  165. class DeterminizerStar {
  166. typedef typename F::Arc Arc;
  167. public:
  168. // Output to Gallic acceptor (so the strings go on weights, and there is a 1-1
  169. // correspondence between our states and the states in ofst. If destroy ==
  170. // true, release memory as we go (but we cannot output again).
  171. void Output(MutableFst<GallicArc<Arc> >* ofst, bool destroy = true);
  172. // Output to standard FST. We will create extra states to handle sequences of
  173. // symbols on the output. If destroy == true, release memory as we go (but we
  174. // cannot output again).
  175. void Output(MutableFst<Arc>* ofst, bool destroy = true);
  176. // Initializer. After initializing the object you will typically call
  177. // Determinize() and then one of the Output functions.
  178. DeterminizerStar(const Fst<Arc>& ifst, float delta = kDelta,
  179. int max_states = -1, bool allow_partial = false)
  180. : ifst_(ifst.Copy()),
  181. delta_(delta),
  182. max_states_(max_states),
  183. determinized_(false),
  184. allow_partial_(allow_partial),
  185. is_partial_(false),
  186. equal_(delta),
  187. hash_(ifst.Properties(kExpanded, false)
  188. ? down_cast<const ExpandedFst<Arc>*, const Fst<Arc> >(&ifst)
  189. ->NumStates() /
  190. 2 +
  191. 3
  192. : 20,
  193. hasher_, equal_),
  194. epsilon_closure_(ifst_, max_states, &repository_, delta) {}
  195. void Determinize(bool* debug_ptr) {
  196. assert(!determinized_);
  197. // This determinizes the input fst but leaves it in the "special format"
  198. // in "output_arcs_".
  199. InputStateId start_id = ifst_->Start();
  200. if (start_id == kNoStateId) {
  201. determinized_ = true;
  202. return; // Nothing to do.
  203. } else { // Insert start state into hash and queue.
  204. Element elem;
  205. elem.state = start_id;
  206. elem.weight = Weight::One();
  207. elem.string = repository_.IdOfEmpty(); // Id of empty sequence.
  208. std::vector<Element> vec;
  209. vec.push_back(elem);
  210. OutputStateId cur_id = SubsetToStateId(vec);
  211. assert(cur_id == 0 && "Do not call Determinize twice.");
  212. }
  213. while (!Q_.empty()) {
  214. std::pair<std::vector<Element>*, OutputStateId> cur_pair = Q_.front();
  215. Q_.pop_front();
  216. ProcessSubset(cur_pair);
  217. if (debug_ptr && *debug_ptr) Debug(); // will exit.
  218. if (max_states_ > 0 && output_arcs_.size() > max_states_) {
  219. if (allow_partial_ == false) {
  220. KALDI_ERR << "Determinization aborted since passed " << max_states_
  221. << " states";
  222. } else {
  223. KALDI_WARN << "Determinization terminated since passed "
  224. << max_states_
  225. << " states, partial results will be generated";
  226. is_partial_ = true;
  227. break;
  228. }
  229. }
  230. }
  231. determinized_ = true;
  232. }
  233. bool IsPartial() { return is_partial_; }
  234. // frees all except output_arcs_, which contains the important info
  235. // we need to output.
  236. void FreeMostMemory() {
  237. if (ifst_) {
  238. delete ifst_;
  239. ifst_ = NULL;
  240. }
  241. for (typename SubsetHash::iterator iter = hash_.begin();
  242. iter != hash_.end(); ++iter)
  243. delete iter->first;
  244. SubsetHash tmp;
  245. tmp.swap(hash_);
  246. }
  247. ~DeterminizerStar() { FreeMostMemory(); }
  248. private:
  249. typedef typename Arc::Label Label;
  250. typedef typename Arc::Weight Weight;
  251. typedef typename Arc::StateId InputStateId;
  252. typedef typename Arc::StateId
  253. OutputStateId; // same as above but distinguish states in output Fst.
  254. typedef typename Arc::Label StringId; // Id type used in the StringRepository
  255. typedef StringRepository<Label, StringId> StringRepositoryType;
  256. // Element of a subset [of original states]
  257. struct Element {
  258. InputStateId state;
  259. StringId string;
  260. Weight weight;
  261. bool operator!=(const Element& other) const {
  262. return (state != other.state || string != other.string ||
  263. weight != other.weight);
  264. }
  265. };
  266. // Arcs in the format we temporarily create in this class (a representation,
  267. // essentially of a Gallic Fst).
  268. struct TempArc {
  269. Label ilabel;
  270. StringId ostring; // Look it up in the StringRepository, it's a sequence of
  271. // Labels.
  272. OutputStateId nextstate; // or kNoState for final weights.
  273. Weight weight;
  274. };
  275. // Hashing function used in hash of subsets.
  276. // A subset is a pointer to vector<Element>.
  277. // The Elements are in sorted order on state id, and without repeated states.
  278. // Because the order of Elements is fixed, we can use a hashing function that
  279. // is order-dependent. However the weights are not included in the hashing
  280. // function-- we hash subsets that differ only in weight to the same key. This
  281. // is not optimal in terms of the O(N) performance but typically if we have a
  282. // lot of determinized states that differ only in weight then the input
  283. // probably was pathological in some way, or even non-determinizable.
  284. // We don't quantize the weights, in order to avoid inexactness in simple
  285. // cases.
  286. // Instead we apply the delta when comparing subsets for equality, and allow a
  287. // small difference.
  288. class SubsetKey {
  289. public:
  290. size_t operator()(const std::vector<Element>* subset)
  291. const { // hashes only the state and string.
  292. size_t hash = 0, factor = 1;
  293. for (typename std::vector<Element>::const_iterator iter = subset->begin();
  294. iter != subset->end(); ++iter) {
  295. hash *= factor;
  296. hash += iter->state + 103333 * iter->string;
  297. factor *= 23531; // these numbers are primes.
  298. }
  299. return hash;
  300. }
  301. };
  302. // This is the equality operator on subsets. It checks for exact match on
  303. // state-id and string, and approximate match on weights.
  304. class SubsetEqual {
  305. public:
  306. bool operator()(const std::vector<Element>* s1,
  307. const std::vector<Element>* s2) const {
  308. size_t sz = s1->size();
  309. assert(sz >= 0);
  310. if (sz != s2->size()) return false;
  311. typename std::vector<Element>::const_iterator iter1 = s1->begin(),
  312. iter1_end = s1->end(),
  313. iter2 = s2->begin();
  314. for (; iter1 < iter1_end; ++iter1, ++iter2) {
  315. if (iter1->state != iter2->state || iter1->string != iter2->string ||
  316. !ApproxEqual(iter1->weight, iter2->weight, delta_))
  317. return false;
  318. }
  319. return true;
  320. }
  321. float delta_;
  322. explicit SubsetEqual(float delta) : delta_(delta) {}
  323. SubsetEqual() : delta_(kDelta) {}
  324. };
  325. // Operator that says whether two Elements have the same states.
  326. // Used only for debug.
  327. class SubsetEqualStates {
  328. public:
  329. bool operator()(const std::vector<Element>* s1,
  330. const std::vector<Element>* s2) const {
  331. size_t sz = s1->size();
  332. assert(sz >= 0);
  333. if (sz != s2->size()) return false;
  334. typename std::vector<Element>::const_iterator iter1 = s1->begin(),
  335. iter1_end = s1->end(),
  336. iter2 = s2->begin();
  337. for (; iter1 < iter1_end; ++iter1, ++iter2) {
  338. if (iter1->state != iter2->state) return false;
  339. }
  340. return true;
  341. }
  342. };
  343. // Define the hash type we use to store subsets.
  344. typedef unordered_map<const std::vector<Element>*, OutputStateId, SubsetKey,
  345. SubsetEqual>
  346. SubsetHash;
  347. class EpsilonClosure {
  348. public:
  349. EpsilonClosure(const Fst<Arc>* ifst, int max_states,
  350. StringRepository<Label, StringId>* repository, float delta)
  351. : ifst_(ifst),
  352. max_states_(max_states),
  353. repository_(repository),
  354. delta_(delta) {}
  355. // This function computes epsilon closure of subset of states by following
  356. // epsilon links. Called by ProcessSubset. Has no side effects except on the
  357. // repository.
  358. void GetEpsilonClosure(const std::vector<Element>& input_subset,
  359. std::vector<Element>* output_subset);
  360. private:
  361. struct EpsilonClosureInfo {
  362. EpsilonClosureInfo() {}
  363. EpsilonClosureInfo(const Element& e, const Weight& w, bool i)
  364. : element(e), weight_to_process(w), in_queue(i) {}
  365. // the weight in the Element struct is the total current weight
  366. // that has been processed already
  367. Element element;
  368. // this stores the weight that we haven't processed (propagated)
  369. Weight weight_to_process;
  370. // whether "this" struct is in the queue
  371. // we store the info here so that we don't have to look it up every time
  372. bool in_queue;
  373. bool operator<(const EpsilonClosureInfo& other) const {
  374. return this->element.state < other.element.state;
  375. }
  376. };
  377. // to further speed up EpsilonClosure() computation, we have 2 queues
  378. // the 2nd queue is used when we first iterate over the input set -
  379. // if queue_2_.empty() then we directly set output_set equal to input_set
  380. // and return immediately
  381. // Since Epsilon arcs are relatively rare, this way we could efficiently
  382. // detect the epsilon-free case, without having to waste our computation
  383. // e.g. allocating the EpsilonClosureInfo structure; this also lets us do a
  384. // level-by-level traversal, which could avoid some (unfortunately not all)
  385. // duplicate computation if epsilons form a DAG that is not a tree
  386. //
  387. // We put the queues here for better efficiency for memory allocation
  388. std::deque<typename Arc::StateId> queue_;
  389. std::vector<Element> queue_2_;
  390. // the following 2 structures together form our *virtual "map"*
  391. // basically we need a map from state_id to EpsilonClosureInfo that operates
  392. // in O(1) time, while still takes relatively small mem, and this does it
  393. // well for efficiency we don't clear id_to_index_ of its outdated
  394. // information As a result each time we do a look-up, we need to check if
  395. // (ecinfo_[id_to_index_[id]].element.state == id) Yet this is still faster
  396. // than using a std::map<StateId, EpsilonClosureInfo>
  397. std::vector<int> id_to_index_;
  398. // unlike id_to_index_, we clear the content of ecinfo_ each time we call
  399. // EpsilonClosure(). This needed because we need an efficient way to
  400. // traverse the virtual map - it is just too costly to traverse the
  401. // id_to_index_ vector.
  402. std::vector<EpsilonClosureInfo> ecinfo_;
  403. // Add one element (elem) into cur_subset
  404. // it also adds the necessary stuff to queue_, set the correct weight
  405. void AddOneElement(const Element& elem, const Weight& unprocessed_weight);
  406. // Sub-routine that we call in EpsilonClosure()
  407. // It takes the current "unprocessed_weight" and propagate it to the
  408. // states accessible from elem.state by an epsilon arc
  409. // and add the results to cur_subset.
  410. // save_to_queue_2 is set true when we iterate over the initial subset
  411. // - then we save it to queue_2 s.t. if it's empty, we directly return
  412. // the input set
  413. void ExpandOneElement(const Element& elem, bool sorted,
  414. const Weight& unprocessed_weight,
  415. bool save_to_queue_2 = false);
  416. // no pointers below would take the ownership
  417. const Fst<Arc>* ifst_;
  418. int max_states_;
  419. StringRepository<Label, StringId>* repository_;
  420. float delta_;
  421. };
  422. // This function works out the final-weight of the determinized state.
  423. // called by ProcessSubset.
  424. // Has no side effects except on the variable repository_, and output_arcs_.
  425. void ProcessFinal(const std::vector<Element>& closed_subset,
  426. OutputStateId state) {
  427. // processes final-weights for this subset.
  428. bool is_final = false;
  429. StringId final_string = 0; // = 0 to keep compiler happy.
  430. Weight final_weight =
  431. Weight::One(); // This value will never be accessed, and
  432. // we just set it to avoid spurious compiler warnings. We avoid setting it
  433. // to Zero() because floating-point infinities can sometimes generate
  434. // interrupts and slow things down.
  435. typename std::vector<Element>::const_iterator iter = closed_subset.begin(),
  436. end = closed_subset.end();
  437. for (; iter != end; ++iter) {
  438. const Element& elem = *iter;
  439. Weight this_final_weight = ifst_->Final(elem.state);
  440. if (this_final_weight != Weight::Zero()) {
  441. if (!is_final) { // first final-weight
  442. final_string = elem.string;
  443. final_weight = Times(elem.weight, this_final_weight);
  444. is_final = true;
  445. } else { // already have one.
  446. if (final_string != elem.string) {
  447. KALDI_ERR << "FST was not functional -> not determinizable";
  448. }
  449. final_weight =
  450. Plus(final_weight, Times(elem.weight, this_final_weight));
  451. }
  452. }
  453. }
  454. if (is_final) {
  455. // store final weights in TempArc structure, just like a transition.
  456. TempArc temp_arc;
  457. temp_arc.ilabel = 0;
  458. temp_arc.nextstate =
  459. kNoStateId; // special marker meaning "final weight".
  460. temp_arc.ostring = final_string;
  461. temp_arc.weight = final_weight;
  462. output_arcs_[state].push_back(temp_arc);
  463. }
  464. }
  465. // ProcessTransition is called from "ProcessTransitions". Broken out for
  466. // clarity. Has side effects on output_arcs_, and (via SubsetToStateId), Q_
  467. // and hash_.
  468. void ProcessTransition(OutputStateId state, Label ilabel,
  469. std::vector<Element>* subset);
  470. // "less than" operator for pair<Label, Element>. Used in
  471. // ProcessTransitions. Lexicographical order, with comparing the state only
  472. // for "Element".
  473. class PairComparator {
  474. public:
  475. inline bool operator()(const std::pair<Label, Element>& p1,
  476. const std::pair<Label, Element>& p2) {
  477. if (p1.first < p2.first) {
  478. return true;
  479. } else if (p1.first > p2.first) {
  480. return false;
  481. } else {
  482. return p1.second.state < p2.second.state;
  483. }
  484. }
  485. };
  486. // ProcessTransitions handles transitions out of this subset of states.
  487. // Ignores epsilon transitions (epsilon closure already handled that).
  488. // Does not consider final states. Breaks the transitions up by ilabel,
  489. // and creates a new transition in determinized FST, for each ilabel.
  490. // Does this by creating a big vector of pairs <Label, Element> and then
  491. // sorting them using a lexicographical ordering, and calling
  492. // ProcessTransition for each range with the same ilabel. Side effects on
  493. // repository, and (via ProcessTransition) on Q_, hash_, and output_arcs_.
  494. void ProcessTransitions(const std::vector<Element>& closed_subset,
  495. OutputStateId state) {
  496. std::vector<std::pair<Label, Element> > all_elems;
  497. { // Push back into "all_elems", elements corresponding to all
  498. // non-epsilon-input transitions
  499. // out of all states in "closed_subset".
  500. typename std::vector<Element>::const_iterator iter =
  501. closed_subset.begin(),
  502. end = closed_subset.end();
  503. for (; iter != end; ++iter) {
  504. const Element& elem = *iter;
  505. for (ArcIterator<Fst<Arc> > aiter(*ifst_, elem.state); !aiter.Done();
  506. aiter.Next()) {
  507. const Arc& arc = aiter.Value();
  508. if (arc.ilabel !=
  509. 0) { // Non-epsilon transition -- ignore epsilons here.
  510. std::pair<Label, Element> this_pr;
  511. this_pr.first = arc.ilabel;
  512. Element& next_elem(this_pr.second);
  513. next_elem.state = arc.nextstate;
  514. next_elem.weight = Times(elem.weight, arc.weight);
  515. if (arc.olabel == 0) { // output epsilon-- this is simple case so
  516. // handle separately for efficiency
  517. next_elem.string = elem.string;
  518. } else {
  519. std::vector<Label> seq;
  520. repository_.SeqOfId(elem.string, &seq);
  521. seq.push_back(arc.olabel);
  522. next_elem.string = repository_.IdOfSeq(seq);
  523. }
  524. all_elems.push_back(this_pr);
  525. }
  526. }
  527. }
  528. }
  529. PairComparator pc;
  530. std::sort(all_elems.begin(), all_elems.end(), pc);
  531. // now sorted first on input label, then on state.
  532. typedef typename std::vector<std::pair<Label, Element> >::const_iterator
  533. PairIter;
  534. PairIter cur = all_elems.begin(), end = all_elems.end();
  535. std::vector<Element> this_subset;
  536. while (cur != end) {
  537. // Process ranges that share the same input symbol.
  538. Label ilabel = cur->first;
  539. this_subset.clear();
  540. while (cur != end && cur->first == ilabel) {
  541. this_subset.push_back(cur->second);
  542. cur++;
  543. }
  544. // We now have a subset for this ilabel.
  545. ProcessTransition(state, ilabel, &this_subset);
  546. }
  547. }
  548. // SubsetToStateId converts a subset (vector of Elements) to a StateId in the
  549. // output fst. This is a hash lookup; if no such state exists, it adds a new
  550. // state to the hash and adds a new pair to the queue. Side effects on hash_
  551. // and Q_, and on output_arcs_ [just affects the size].
  552. OutputStateId SubsetToStateId(
  553. const std::vector<Element>& subset) { // may add the subset to the queue.
  554. typedef typename SubsetHash::iterator IterType;
  555. IterType iter = hash_.find(&subset);
  556. if (iter == hash_.end()) { // was not there.
  557. std::vector<Element>* new_subset = new std::vector<Element>(subset);
  558. OutputStateId new_state_id = (OutputStateId)output_arcs_.size();
  559. bool ans =
  560. hash_
  561. .insert(std::pair<const std::vector<Element>*, OutputStateId>(
  562. new_subset, new_state_id))
  563. .second;
  564. assert(ans);
  565. output_arcs_.push_back(std::vector<TempArc>());
  566. if (allow_partial_ == false) {
  567. // If --allow-partial is not requested, we do the old way.
  568. Q_.push_front(std::pair<std::vector<Element>*, OutputStateId>(
  569. new_subset, new_state_id));
  570. } else {
  571. // If --allow-partial is requested, we do breadth first search. This
  572. // ensures that when we return partial results, we return the states
  573. // that are reachable by the fewest steps from the start state.
  574. Q_.push_back(std::pair<std::vector<Element>*, OutputStateId>(
  575. new_subset, new_state_id));
  576. }
  577. return new_state_id;
  578. } else {
  579. return iter->second; // the OutputStateId.
  580. }
  581. }
  582. // ProcessSubset does the processing of a determinized state, i.e. it creates
  583. // transitions out of it and adds new determinized states to the queue if
  584. // necessary. The first stage is "EpsilonClosure" (follow epsilons to get a
  585. // possibly larger set of (states, weights)). After that we ignore epsilons.
  586. // We process the final-weight of the state, and then handle transitions out
  587. // (this may add more determinized states to the queue).
  588. void ProcessSubset(
  589. const std::pair<std::vector<Element>*, OutputStateId>& pair) {
  590. const std::vector<Element>* subset = pair.first;
  591. OutputStateId state = pair.second;
  592. std::vector<Element> closed_subset; // subset after epsilon closure.
  593. epsilon_closure_.GetEpsilonClosure(*subset, &closed_subset);
  594. // Now follow non-epsilon arcs [and also process final states]
  595. ProcessFinal(closed_subset, state);
  596. // Now handle transitions out of these states.
  597. ProcessTransitions(closed_subset, state);
  598. }
  599. void Debug();
  600. KALDI_DISALLOW_COPY_AND_ASSIGN(DeterminizerStar);
  601. std::deque<std::pair<std::vector<Element>*, OutputStateId> >
  602. Q_; // queue of subsets to be processed.
  603. std::vector<std::vector<TempArc> >
  604. output_arcs_; // essentially an FST in our format.
  605. const Fst<Arc>* ifst_;
  606. float delta_;
  607. int max_states_;
  608. bool determinized_; // used to check usage.
  609. bool allow_partial_; // output paritial results or not
  610. bool is_partial_; // if we get partial results or not
  611. SubsetKey hasher_; // object that computes keys-- has no data members.
  612. SubsetEqual
  613. equal_; // object that compares subsets-- only data member is delta_.
  614. SubsetHash hash_; // hash from Subset to StateId in final Fst.
  615. StringRepository<Label, StringId>
  616. repository_; // associate integer id's with sequences of labels.
  617. EpsilonClosure epsilon_closure_;
  618. };
  619. template <class F>
  620. bool DeterminizeStar(F& ifst, // NOLINT
  621. MutableFst<typename F::Arc>* ofst, float delta,
  622. bool* debug_ptr, int max_states, bool allow_partial) {
  623. ofst->SetOutputSymbols(ifst.OutputSymbols());
  624. ofst->SetInputSymbols(ifst.InputSymbols());
  625. DeterminizerStar<F> det(ifst, delta, max_states, allow_partial);
  626. det.Determinize(debug_ptr);
  627. det.Output(ofst);
  628. return det.IsPartial();
  629. }
  630. template <class F>
  631. bool DeterminizeStar(F& ifst, // NOLINT
  632. MutableFst<GallicArc<typename F::Arc> >* ofst, float delta,
  633. bool* debug_ptr, int max_states, bool allow_partial) {
  634. ofst->SetOutputSymbols(ifst.InputSymbols());
  635. ofst->SetInputSymbols(ifst.InputSymbols());
  636. DeterminizerStar<F> det(ifst, delta, max_states, allow_partial);
  637. det.Determinize(debug_ptr);
  638. det.Output(ofst);
  639. return det.IsPartial();
  640. }
  641. template <class F>
  642. void DeterminizerStar<F>::EpsilonClosure::GetEpsilonClosure(
  643. const std::vector<Element>& input_subset,
  644. std::vector<Element>* output_subset) {
  645. ecinfo_.resize(0);
  646. size_t size = input_subset.size();
  647. // find whether input fst is known to be sorted in input label.
  648. bool sorted =
  649. ((ifst_->Properties(kILabelSorted, false) & kILabelSorted) != 0);
  650. // size is still the input_subset.size()
  651. for (size_t i = 0; i < size; i++) {
  652. ExpandOneElement(input_subset[i], sorted, input_subset[i].weight, true);
  653. }
  654. size_t s = queue_2_.size();
  655. if (s == 0) {
  656. *output_subset = input_subset;
  657. return;
  658. } else {
  659. // queue_2 not empty. Need to create the vector<info>
  660. for (size_t i = 0; i < size; i++) {
  661. // the weight has not been processed yet,
  662. // so put all of them in the "weight_to_process"
  663. ecinfo_.push_back(
  664. EpsilonClosureInfo(input_subset[i], input_subset[i].weight, false));
  665. ecinfo_.back().element.weight = Weight::Zero(); // clear the weight
  666. if (id_to_index_.size() < input_subset[i].state + 1) {
  667. id_to_index_.resize(2 * input_subset[i].state + 1, -1);
  668. }
  669. id_to_index_[input_subset[i].state] = ecinfo_.size() - 1;
  670. }
  671. }
  672. {
  673. Element elem;
  674. elem.weight = Weight::Zero();
  675. for (size_t i = 0; i < s; i++) {
  676. elem.state = queue_2_[i].state;
  677. elem.string = queue_2_[i].string;
  678. AddOneElement(elem, queue_2_[i].weight);
  679. }
  680. queue_2_.resize(0);
  681. }
  682. int counter = 0; // relates to max-states option, used for test.
  683. while (!queue_.empty()) {
  684. InputStateId id = queue_.front();
  685. // no need to check validity of the index
  686. // since anything in the queue we are sure they're in the "virtual set"
  687. int index = id_to_index_[id];
  688. EpsilonClosureInfo& info = ecinfo_[index];
  689. Element& elem = info.element;
  690. Weight unprocessed_weight = info.weight_to_process;
  691. elem.weight = Plus(elem.weight, unprocessed_weight);
  692. info.weight_to_process = Weight::Zero();
  693. info.in_queue = false;
  694. queue_.pop_front();
  695. if (max_states_ > 0 && counter++ > max_states_) {
  696. KALDI_ERR << "Determinization aborted since looped more than "
  697. << max_states_ << " times during epsilon closure";
  698. }
  699. // generally we need to be careful about iterator-invalidation problem
  700. // here we pass a reference (elem), which could be an issue.
  701. // In the beginning of ExpandOneElement, we make a copy of elem.string
  702. // to avoid that issue
  703. ExpandOneElement(elem, sorted, unprocessed_weight);
  704. }
  705. {
  706. // this sorting is based on StateId
  707. sort(ecinfo_.begin(), ecinfo_.end());
  708. output_subset->clear();
  709. size = ecinfo_.size();
  710. output_subset->reserve(size);
  711. for (size_t i = 0; i < size; i++) {
  712. EpsilonClosureInfo& info = ecinfo_[i];
  713. if (info.weight_to_process != Weight::Zero()) {
  714. info.element.weight = Plus(info.element.weight, info.weight_to_process);
  715. }
  716. output_subset->push_back(info.element);
  717. }
  718. }
  719. }
  720. template <class F>
  721. void DeterminizerStar<F>::EpsilonClosure::AddOneElement(
  722. const Element& elem, const Weight& unprocessed_weight) {
  723. // first we try to find the element info in the ecinfo_ vector
  724. int index = -1;
  725. if (elem.state < id_to_index_.size()) {
  726. index = id_to_index_[elem.state];
  727. }
  728. if (index != -1) {
  729. if (index >= ecinfo_.size()) {
  730. index = -1;
  731. } else if (ecinfo_[index].element.state != elem.state) {
  732. // since ecinfo_ might store outdated information, we need to check
  733. index = -1;
  734. }
  735. }
  736. if (index == -1) {
  737. // was no such StateId: insert and add to queue.
  738. ecinfo_.push_back(EpsilonClosureInfo(elem, unprocessed_weight, true));
  739. size_t size = id_to_index_.size();
  740. if (size < elem.state + 1) {
  741. // double the size to reduce memory operations
  742. id_to_index_.resize(2 * elem.state + 1, -1);
  743. }
  744. id_to_index_[elem.state] = ecinfo_.size() - 1;
  745. queue_.push_back(elem.state);
  746. } else { // one is already there. Add weights.
  747. EpsilonClosureInfo& info = ecinfo_[index];
  748. if (info.element.string != elem.string) {
  749. // Non-functional FST.
  750. std::ostringstream ss;
  751. ss << "FST was not functional -> not determinizable.";
  752. { // Print some debugging information. Can be helpful to debug
  753. // the inputs when FSTs are mysteriously non-functional.
  754. std::vector<Label> tmp_seq;
  755. repository_->SeqOfId(info.element.string, &tmp_seq);
  756. ss << "\nFirst string:";
  757. for (size_t i = 0; i < tmp_seq.size(); i++) ss << ' ' << tmp_seq[i];
  758. ss << "\nSecond string:";
  759. repository_->SeqOfId(elem.string, &tmp_seq);
  760. for (size_t i = 0; i < tmp_seq.size(); i++) ss << ' ' << tmp_seq[i];
  761. }
  762. KALDI_ERR << ss.str();
  763. }
  764. info.weight_to_process = Plus(info.weight_to_process, unprocessed_weight);
  765. if (!info.in_queue) {
  766. // this is because the code in "else" below: the
  767. // iter->second.weight_to_process might not be Zero()
  768. Weight weight = Plus(info.element.weight, info.weight_to_process);
  769. // What is done below is, we propagate the weight (by adding them
  770. // to the queue only when the change is big enough;
  771. // otherwise we just store the weight, until before returning
  772. // we add the element.weight and weight_to_process together
  773. if (!ApproxEqual(weight, info.element.weight, delta_)) {
  774. // add extra part of weight to queue.
  775. info.in_queue = true;
  776. queue_.push_back(elem.state);
  777. }
  778. }
  779. }
  780. }
  781. template <class F>
  782. void DeterminizerStar<F>::EpsilonClosure::ExpandOneElement(
  783. const Element& elem, bool sorted, const Weight& unprocessed_weight,
  784. bool save_to_queue_2) {
  785. StringId str =
  786. elem.string; // copy it here because there is an iterator-
  787. // - invalidation problem (it really happens for some FSTs)
  788. // now we are going to propagate the "unprocessed_weight"
  789. for (ArcIterator<Fst<Arc> > aiter(*ifst_, elem.state); !aiter.Done();
  790. aiter.Next()) {
  791. const Arc& arc = aiter.Value();
  792. if (sorted && arc.ilabel > 0) {
  793. break;
  794. // Break from the loop: due to sorting there will be no
  795. // more transitions with epsilons as input labels.
  796. }
  797. if (arc.ilabel != 0) {
  798. continue; // we only process epsilons here
  799. }
  800. Element next_elem;
  801. next_elem.state = arc.nextstate;
  802. next_elem.weight = Weight::Zero();
  803. Weight next_unprocessed_weight = Times(unprocessed_weight, arc.weight);
  804. // now must append strings
  805. if (arc.olabel == 0) {
  806. next_elem.string = str;
  807. } else {
  808. std::vector<Label> seq;
  809. repository_->SeqOfId(str, &seq);
  810. if (arc.olabel != 0) seq.push_back(arc.olabel);
  811. next_elem.string = repository_->IdOfSeq(seq);
  812. }
  813. if (save_to_queue_2) {
  814. next_elem.weight = next_unprocessed_weight;
  815. queue_2_.push_back(next_elem);
  816. } else {
  817. AddOneElement(next_elem, next_unprocessed_weight);
  818. }
  819. }
  820. }
  821. template <class F>
  822. void DeterminizerStar<F>::Output(MutableFst<GallicArc<Arc> >* ofst,
  823. bool destroy) {
  824. assert(determinized_);
  825. if (destroy) determinized_ = false;
  826. typedef GallicWeight<Label, Weight> ThisGallicWeight;
  827. typedef typename Arc::StateId StateId;
  828. if (destroy) FreeMostMemory();
  829. StateId nStates = static_cast<StateId>(output_arcs_.size());
  830. ofst->DeleteStates();
  831. ofst->SetStart(kNoStateId);
  832. if (nStates == 0) {
  833. return;
  834. }
  835. for (StateId s = 0; s < nStates; s++) {
  836. OutputStateId news = ofst->AddState();
  837. assert(news == s);
  838. }
  839. ofst->SetStart(0);
  840. // now process transitions.
  841. for (StateId this_state = 0; this_state < nStates; this_state++) {
  842. std::vector<TempArc>& this_vec(output_arcs_[this_state]);
  843. typename std::vector<TempArc>::const_iterator iter = this_vec.begin(),
  844. end = this_vec.end();
  845. for (; iter != end; ++iter) {
  846. const TempArc& temp_arc(*iter);
  847. GallicArc<Arc> new_arc;
  848. std::vector<Label> seq;
  849. repository_.SeqOfId(temp_arc.ostring, &seq);
  850. StringWeight<Label, STRING_LEFT> string_weight;
  851. for (size_t i = 0; i < seq.size(); i++) string_weight.PushBack(seq[i]);
  852. ThisGallicWeight gallic_weight(string_weight, temp_arc.weight);
  853. if (temp_arc.nextstate == kNoStateId) { // is really final weight.
  854. ofst->SetFinal(this_state, gallic_weight);
  855. } else { // is really an arc.
  856. new_arc.nextstate = temp_arc.nextstate;
  857. new_arc.ilabel = temp_arc.ilabel;
  858. new_arc.olabel = temp_arc.ilabel; // acceptor. input == output.
  859. new_arc.weight = gallic_weight; // includes string and weight.
  860. ofst->AddArc(this_state, new_arc);
  861. }
  862. }
  863. // Free up memory. Do this inside the loop as ofst is also allocating
  864. // memory
  865. if (destroy) {
  866. std::vector<TempArc> temp;
  867. temp.swap(this_vec);
  868. }
  869. }
  870. if (destroy) {
  871. std::vector<std::vector<TempArc> > temp;
  872. temp.swap(output_arcs_);
  873. }
  874. }
  875. template <class F>
  876. void DeterminizerStar<F>::Output(MutableFst<Arc>* ofst, bool destroy) {
  877. assert(determinized_);
  878. if (destroy) determinized_ = false;
  879. // Outputs to standard fst.
  880. OutputStateId num_states = static_cast<OutputStateId>(output_arcs_.size());
  881. if (destroy) FreeMostMemory();
  882. ofst->DeleteStates();
  883. if (num_states == 0) {
  884. ofst->SetStart(kNoStateId);
  885. return;
  886. }
  887. // Add basic states-- but will add extra ones to account for strings on
  888. // output.
  889. for (OutputStateId s = 0; s < num_states; s++) {
  890. OutputStateId news = ofst->AddState();
  891. assert(news == s);
  892. }
  893. ofst->SetStart(0);
  894. for (OutputStateId this_state = 0; this_state < num_states; this_state++) {
  895. std::vector<TempArc>& this_vec(output_arcs_[this_state]);
  896. typename std::vector<TempArc>::const_iterator iter = this_vec.begin(),
  897. end = this_vec.end();
  898. for (; iter != end; ++iter) {
  899. const TempArc& temp_arc(*iter);
  900. std::vector<Label> seq;
  901. repository_.SeqOfId(temp_arc.ostring, &seq);
  902. if (temp_arc.nextstate == kNoStateId) { // Really a final weight.
  903. // Make a sequence of states going to a final state, with the strings as
  904. // labels. Put the weight on the first arc.
  905. OutputStateId cur_state = this_state;
  906. for (size_t i = 0; i < seq.size(); i++) {
  907. OutputStateId next_state = ofst->AddState();
  908. Arc arc;
  909. arc.nextstate = next_state;
  910. arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
  911. arc.ilabel = 0; // epsilon.
  912. arc.olabel = seq[i];
  913. ofst->AddArc(cur_state, arc);
  914. cur_state = next_state;
  915. }
  916. ofst->SetFinal(cur_state,
  917. (seq.size() == 0 ? temp_arc.weight : Weight::One()));
  918. } else { // Really an arc.
  919. OutputStateId cur_state = this_state;
  920. // Have to be careful with this integer comparison (i+1 < seq.size())
  921. // because unsigned. i < seq.size()-1 could fail for zero-length
  922. // sequences.
  923. for (size_t i = 0; i + 1 < seq.size(); i++) {
  924. // for all but the last element of seq, create new state.
  925. OutputStateId next_state = ofst->AddState();
  926. Arc arc;
  927. arc.nextstate = next_state;
  928. arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
  929. arc.ilabel = (i == 0 ? temp_arc.ilabel
  930. : 0); // put ilabel on first element of seq.
  931. arc.olabel = seq[i];
  932. ofst->AddArc(cur_state, arc);
  933. cur_state = next_state;
  934. }
  935. // Add the final arc in the sequence.
  936. Arc arc;
  937. arc.nextstate = temp_arc.nextstate;
  938. arc.weight = (seq.size() <= 1 ? temp_arc.weight : Weight::One());
  939. arc.ilabel = (seq.size() <= 1 ? temp_arc.ilabel : 0);
  940. arc.olabel = (seq.size() > 0 ? seq.back() : 0);
  941. ofst->AddArc(cur_state, arc);
  942. }
  943. }
  944. // Free up memory. Do this inside the loop as ofst is also allocating
  945. // memory
  946. if (destroy) {
  947. std::vector<TempArc> temp;
  948. temp.swap(this_vec);
  949. }
  950. }
  951. if (destroy) {
  952. std::vector<std::vector<TempArc> > temp;
  953. temp.swap(output_arcs_);
  954. repository_.Destroy();
  955. }
  956. }
  957. template <class F>
  958. void DeterminizerStar<F>::ProcessTransition(OutputStateId state, Label ilabel,
  959. std::vector<Element>* subset) {
  960. // At input, "subset" may contain duplicates for a given dest state (but in
  961. // sorted order). This function removes duplicates from "subset", normalizes
  962. // it, and adds a transition to the dest. state (possibly affecting Q_ and
  963. // hash_, if state did not exist).
  964. typedef typename std::vector<Element>::iterator IterType;
  965. { // This block makes the subset have one unique Element per state, adding
  966. // the weights.
  967. IterType cur_in = subset->begin(), cur_out = cur_in, end = subset->end();
  968. size_t num_out = 0;
  969. // Merge elements with same state-id
  970. while (cur_in != end) { // while we have more elements to process.
  971. // At this point, cur_out points to location of next place we want to put
  972. // an element, cur_in points to location of next element we want to
  973. // process.
  974. if (cur_in != cur_out) *cur_out = *cur_in;
  975. cur_in++;
  976. while (cur_in != end &&
  977. cur_in->state == cur_out->state) { // merge elements.
  978. if (cur_in->string != cur_out->string) {
  979. KALDI_ERR << "FST was not functional -> not determinizable";
  980. }
  981. cur_out->weight = Plus(cur_out->weight, cur_in->weight);
  982. cur_in++;
  983. }
  984. cur_out++;
  985. num_out++;
  986. }
  987. subset->resize(num_out);
  988. }
  989. StringId common_str;
  990. Weight tot_weight;
  991. { // This block computes common_str and tot_weight (essentially: the common
  992. // divisor)
  993. // and removes them from the elements.
  994. std::vector<Label> seq;
  995. IterType begin = subset->begin(), iter, end = subset->end();
  996. { // This block computes "seq", which is the common prefix, and
  997. // "common_str",
  998. // which is the StringId version of "seq".
  999. std::vector<Label> tmp_seq;
  1000. for (iter = begin; iter != end; ++iter) {
  1001. if (iter == begin) {
  1002. repository_.SeqOfId(iter->string, &seq);
  1003. } else {
  1004. repository_.SeqOfId(iter->string, &tmp_seq);
  1005. if (tmp_seq.size() < seq.size())
  1006. seq.resize(tmp_seq.size()); // size of shortest one.
  1007. for (size_t i = 0; i < seq.size();
  1008. i++) // seq.size() is the shorter one at this point.
  1009. if (tmp_seq[i] != seq[i]) seq.resize(i);
  1010. }
  1011. if (seq.size() == 0) break; // will not get any prefix.
  1012. }
  1013. common_str = repository_.IdOfSeq(seq);
  1014. }
  1015. { // This block computes "tot_weight".
  1016. iter = begin;
  1017. tot_weight = iter->weight;
  1018. for (++iter; iter != end; ++iter)
  1019. tot_weight = Plus(tot_weight, iter->weight);
  1020. }
  1021. // Now divide out common stuff from elements.
  1022. size_t prefix_len = seq.size();
  1023. for (iter = begin; iter != end; ++iter) {
  1024. iter->weight = Divide(iter->weight, tot_weight);
  1025. iter->string = repository_.RemovePrefix(iter->string, prefix_len);
  1026. }
  1027. }
  1028. // Now add an arc to the state that the subset represents.
  1029. // We may create a new state id for this (in SubsetToStateId).
  1030. TempArc temp_arc;
  1031. temp_arc.ilabel = ilabel;
  1032. temp_arc.nextstate =
  1033. SubsetToStateId(*subset); // may or may not really add the subset.
  1034. temp_arc.ostring = common_str;
  1035. temp_arc.weight = tot_weight;
  1036. output_arcs_[state].push_back(temp_arc); // record the arc.
  1037. }
  1038. template <class F>
  1039. void DeterminizerStar<F>::Debug() {
  1040. // this function called if you send a signal
  1041. // SIGUSR1 to the process (and it's caught by the handler in
  1042. // fstdeterminizestar). It prints out some traceback
  1043. // info and exits.
  1044. KALDI_WARN << "Debug function called (probably SIGUSR1 caught)";
  1045. // free up memory from the hash as we need a little memory
  1046. {
  1047. SubsetHash hash_tmp;
  1048. std::swap(hash_tmp, hash_);
  1049. }
  1050. if (output_arcs_.size() <= 2) {
  1051. KALDI_ERR << "Nothing to trace back";
  1052. }
  1053. size_t max_state = output_arcs_.size() - 2; // don't take the last
  1054. // one as we might be halfway into constructing it.
  1055. std::vector<OutputStateId> predecessor(max_state + 1, kNoStateId);
  1056. for (size_t i = 0; i < max_state; i++) {
  1057. for (size_t j = 0; j < output_arcs_[i].size(); j++) {
  1058. OutputStateId nextstate = output_arcs_[i][j].nextstate;
  1059. // Always find an earlier-numbered predecessor; this
  1060. // is always possible because of the way the algorithm
  1061. // works.
  1062. if (nextstate <= max_state && nextstate > i) predecessor[nextstate] = i;
  1063. }
  1064. }
  1065. std::vector<std::pair<Label, StringId> > traceback;
  1066. // 'traceback' is a pair of (ilabel, olabel-seq).
  1067. OutputStateId cur_state = max_state; // A recently constructed state.
  1068. while (cur_state != 0 && cur_state != kNoStateId) {
  1069. OutputStateId last_state = predecessor[cur_state];
  1070. std::pair<Label, StringId> p;
  1071. size_t i;
  1072. for (i = 0; i < output_arcs_[last_state].size(); i++) {
  1073. if (output_arcs_[last_state][i].nextstate == cur_state) {
  1074. p.first = output_arcs_[last_state][i].ilabel;
  1075. p.second = output_arcs_[last_state][i].ostring;
  1076. traceback.push_back(p);
  1077. break;
  1078. }
  1079. }
  1080. KALDI_ASSERT(i != output_arcs_[last_state].size()); // Or fell off loop.
  1081. cur_state = last_state;
  1082. }
  1083. if (cur_state == kNoStateId)
  1084. KALDI_WARN << "Traceback did not reach start state "
  1085. << "(possibly debug-code error)";
  1086. std::stringstream ss;
  1087. ss << "Traceback follows in format "
  1088. << "ilabel (olabel olabel) ilabel (olabel) ... :";
  1089. for (ssize_t i = traceback.size() - 1; i >= 0; i--) {
  1090. ss << ' ' << traceback[i].first << " ( ";
  1091. std::vector<Label> seq;
  1092. repository_.SeqOfId(traceback[i].second, &seq);
  1093. for (size_t j = 0; j < seq.size(); j++) ss << seq[j] << ' ';
  1094. ss << ')';
  1095. }
  1096. KALDI_ERR << ss.str();
  1097. }
  1098. } // namespace fst
  1099. #endif // KALDI_FSTEXT_DETERMINIZE_STAR_INL_H_