You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

387 lines
14 KiB

  1. // fstext/table-matcher.h
  2. // Copyright 2009-2011 Microsoft Corporation
  3. // See ../../COPYING for clarification regarding multiple authors
  4. //
  5. // Licensed under the Apache License, Version 2.0 (the "License");
  6. // you may not use this file except in compliance with the License.
  7. // You may obtain a copy of the License at
  8. //
  9. // http://www.apache.org/licenses/LICENSE-2.0
  10. //
  11. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  12. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  13. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  14. // MERCHANTABLITY OR NON-INFRINGEMENT.
  15. // See the Apache 2 License for the specific language governing permissions and
  16. // limitations under the License.
  17. #ifndef KALDI_FSTEXT_TABLE_MATCHER_H_
  18. #define KALDI_FSTEXT_TABLE_MATCHER_H_
  19. #include <fst/fst-decl.h>
  20. #include <fst/fstlib.h>
  21. #include <memory>
  22. #include <vector>
  23. namespace fst {
  24. /// TableMatcher is a matcher specialized for the case where the output
  25. /// side of the left FST always has either all-epsilons coming out of
  26. /// a state, or a majority of the symbol table. Therefore we can
  27. /// either store nothing (for the all-epsilon case) or store a lookup
  28. /// table from Labels to arc offsets. Since the TableMatcher has to
  29. /// iterate over all arcs in each left-hand state the first time it sees
  30. /// it, this matcher type is not efficient if you compose with
  31. /// something very small on the right-- unless you do it multiple
  32. /// times and keep the matcher around. To do this requires using the
  33. /// most advanced form of ComposeFst in Compose.h, that initializes
  34. /// with ComposeFstImplOptions.
  /// Tuning knobs that decide whether a state gets a lookup table.
  struct TableMatcherOptions {
    /// We construct the table only if it would be at least this full.
    float table_ratio = 0.25;
    /// Minimum number of arcs a state needs before a table is considered.
    int min_table_size = 4;

    // Kept user-provided (rather than "= default") so that the struct remains
    // a non-aggregate; the defaults come from the member initializers above.
    TableMatcherOptions() {}
  };
  41. // Introducing an "impl" class for TableMatcher because
  42. // we need to do a shallow copy of the Matcher for when
  43. // we want to cache tables for multiple compositions.
  44. template <class F, class BackoffMatcher = SortedMatcher<F> >
  45. class TableMatcherImpl : public MatcherBase<typename F::Arc> {
  46. public:
  47. typedef F FST;
  48. typedef typename F::Arc Arc;
  49. typedef typename Arc::Label Label;
  50. typedef typename Arc::StateId StateId;
  51. typedef StateId
  52. ArcId; // Use this type to store arc offsets [it's actually size_t
  53. // in the Seek function of ArcIterator, but StateId should be big enough].
  54. typedef typename Arc::Weight Weight;
  55. public:
  56. TableMatcherImpl(const FST& fst, MatchType match_type,
  57. const TableMatcherOptions& opts = TableMatcherOptions())
  58. : match_type_(match_type),
  59. fst_(fst.Copy()),
  60. loop_(match_type == MATCH_INPUT
  61. ? Arc(kNoLabel, 0, Weight::One(), kNoStateId)
  62. : Arc(0, kNoLabel, Weight::One(), kNoStateId)),
  63. aiter_(NULL),
  64. s_(kNoStateId),
  65. opts_(opts),
  66. backoff_matcher_(fst, match_type) {
  67. assert(opts_.min_table_size > 0);
  68. if (match_type == MATCH_INPUT)
  69. assert(fst_->Properties(kILabelSorted, true) == kILabelSorted);
  70. else if (match_type == MATCH_OUTPUT)
  71. assert(fst_->Properties(kOLabelSorted, true) == kOLabelSorted);
  72. else
  73. assert(0 && "Invalid FST properties");
  74. }
  75. virtual const FST& GetFst() const { return *fst_; }
  76. virtual ~TableMatcherImpl() {
  77. std::vector<ArcId>* const empty =
  78. ((std::vector<ArcId>*)(NULL)) + 1; // special marker.
  79. for (size_t i = 0; i < tables_.size(); i++) {
  80. if (tables_[i] != NULL && tables_[i] != empty) delete tables_[i];
  81. }
  82. delete aiter_;
  83. delete fst_;
  84. }
  85. virtual MatchType Type(bool test) const { return match_type_; }
  86. void SetState(StateId s) {
  87. if (aiter_) {
  88. delete aiter_;
  89. aiter_ = NULL;
  90. }
  91. if (match_type_ == MATCH_NONE) LOG(FATAL) << "TableMatcher: bad match type";
  92. s_ = s;
  93. std::vector<ArcId>* const empty =
  94. ((std::vector<ArcId>*)(NULL)) + 1; // special marker.
  95. if (static_cast<size_t>(s) >= tables_.size()) {
  96. assert(s >= 0);
  97. tables_.resize(s + 1, NULL);
  98. }
  99. std::vector<ArcId>*& this_table_ = tables_[s]; // note: ref to ptr.
  100. if (this_table_ == empty) {
  101. backoff_matcher_.SetState(s);
  102. return;
  103. } else if (this_table_ == NULL) { // NULL means has not been set.
  104. ArcId num_arcs = fst_->NumArcs(s);
  105. if (num_arcs == 0 || num_arcs < opts_.min_table_size) {
  106. this_table_ = empty;
  107. backoff_matcher_.SetState(s);
  108. return;
  109. }
  110. ArcIterator<FST> aiter(*fst_, s);
  111. aiter.SetFlags(
  112. kArcNoCache |
  113. (match_type_ == MATCH_OUTPUT ? kArcOLabelValue : kArcILabelValue),
  114. kArcNoCache | kArcValueFlags);
  115. // the statement above, says: "Don't cache stuff; and I only need the
  116. // ilabel/olabel to be computed.
  117. aiter.Seek(num_arcs - 1);
  118. Label highest_label =
  119. (match_type_ == MATCH_OUTPUT ? aiter.Value().olabel
  120. : aiter.Value().ilabel);
  121. if ((highest_label + 1) * opts_.table_ratio > num_arcs) {
  122. this_table_ = empty;
  123. backoff_matcher_.SetState(s);
  124. return; // table would be too sparse.
  125. }
  126. // OK, now we are creating the table.
  127. this_table_ = new std::vector<ArcId>(highest_label + 1, kNoStateId);
  128. ArcId pos = 0;
  129. for (aiter.Seek(0); !aiter.Done(); aiter.Next(), pos++) {
  130. Label label = (match_type_ == MATCH_OUTPUT ? aiter.Value().olabel
  131. : aiter.Value().ilabel);
  132. assert(static_cast<size_t>(label) <=
  133. static_cast<size_t>(highest_label)); // also checks >= 0.
  134. if ((*this_table_)[label] == kNoStateId) (*this_table_)[label] = pos;
  135. // set this_table_[label] to first position where arc has this
  136. // label.
  137. }
  138. }
  139. // At this point in the code, this_table_ != NULL and != empty.
  140. aiter_ = new ArcIterator<FST>(*fst_, s);
  141. aiter_->SetFlags(kArcNoCache,
  142. kArcNoCache); // don't need to cache arcs as may only
  143. // need a small subset.
  144. loop_.nextstate = s;
  145. // aiter_ = NULL;
  146. // backoff_matcher_.SetState(s);
  147. }
  148. bool Find(Label match_label) {
  149. if (!aiter_) {
  150. return backoff_matcher_.Find(match_label);
  151. } else {
  152. match_label_ = match_label;
  153. current_loop_ = (match_label == 0);
  154. // kNoLabel means the implicit loop on the other FST --
  155. // matches real epsilons but not the self-loop.
  156. match_label_ = (match_label_ == kNoLabel ? 0 : match_label_);
  157. if (static_cast<size_t>(match_label_) < tables_[s_]->size() &&
  158. (*(tables_[s_]))[match_label_] != kNoStateId) {
  159. aiter_->Seek((*(tables_[s_]))[match_label_]); // label exists.
  160. return true;
  161. }
  162. return current_loop_;
  163. }
  164. }
  165. const Arc& Value() const {
  166. if (aiter_)
  167. return current_loop_ ? loop_ : aiter_->Value();
  168. else
  169. return backoff_matcher_.Value();
  170. }
  171. void Next() {
  172. if (aiter_) {
  173. if (current_loop_)
  174. current_loop_ = false;
  175. else
  176. aiter_->Next();
  177. } else {
  178. backoff_matcher_.Next();
  179. }
  180. }
  181. bool Done() const {
  182. if (aiter_ != NULL) {
  183. if (current_loop_) return false;
  184. if (aiter_->Done()) return true;
  185. Label label = (match_type_ == MATCH_OUTPUT ? aiter_->Value().olabel
  186. : aiter_->Value().ilabel);
  187. return (label != match_label_);
  188. } else {
  189. return backoff_matcher_.Done();
  190. }
  191. }
  192. const Arc& Value() {
  193. if (aiter_ != NULL) {
  194. return (current_loop_ ? loop_ : aiter_->Value());
  195. } else {
  196. return backoff_matcher_.Value();
  197. }
  198. }
  199. virtual TableMatcherImpl<FST>* Copy(bool safe = false) const {
  200. assert(0); // shouldn't be called. This is not a "real" matcher,
  201. // although we derive from MatcherBase for convenience.
  202. return NULL;
  203. }
  204. virtual uint64 Properties(uint64 props) const {
  205. return props;
  206. } // simple matcher that does
  207. // not change its FST, so properties are properties of FST it is applied to
  208. private:
  209. virtual void SetState_(StateId s) { SetState(s); }
  210. virtual bool Find_(Label label) { return Find(label); }
  211. virtual bool Done_() const { return Done(); }
  212. virtual const Arc& Value_() const { return Value(); }
  213. virtual void Next_() { Next(); }
  214. MatchType match_type_;
  215. FST* fst_;
  216. bool current_loop_;
  217. Label match_label_;
  218. Arc loop_;
  219. ArcIterator<FST>* aiter_;
  220. StateId s_;
  221. std::vector<std::vector<ArcId>*> tables_;
  222. TableMatcherOptions opts_;
  223. BackoffMatcher backoff_matcher_;
  224. };
  225. template <class F, class BackoffMatcher = SortedMatcher<F> >
  226. class TableMatcher : public MatcherBase<typename F::Arc> {
  227. public:
  228. typedef F FST;
  229. typedef typename F::Arc Arc;
  230. typedef typename Arc::Label Label;
  231. typedef typename Arc::StateId StateId;
  232. typedef StateId
  233. ArcId; // Use this type to store arc offsets [it's actually size_t
  234. // in the Seek function of ArcIterator, but StateId should be big enough].
  235. typedef typename Arc::Weight Weight;
  236. typedef TableMatcherImpl<F, BackoffMatcher> Impl;
  237. TableMatcher(const FST& fst, MatchType match_type,
  238. const TableMatcherOptions& opts = TableMatcherOptions())
  239. : impl_(std::make_shared<Impl>(fst, match_type, opts)) {}
  240. TableMatcher(const TableMatcher<FST, BackoffMatcher>& matcher,
  241. bool safe = false)
  242. : impl_(matcher.impl_) {
  243. if (safe == true) {
  244. LOG(FATAL) << "TableMatcher: Safe copy not supported";
  245. }
  246. }
  247. virtual const FST& GetFst() const { return impl_->GetFst(); }
  248. virtual MatchType Type(bool test) const { return impl_->Type(test); }
  249. void SetState(StateId s) { return impl_->SetState(s); }
  250. bool Find(Label match_label) { return impl_->Find(match_label); }
  251. const Arc& Value() const { return impl_->Value(); }
  252. void Next() { return impl_->Next(); }
  253. bool Done() const { return impl_->Done(); }
  254. const Arc& Value() { return impl_->Value(); }
  255. virtual TableMatcher<FST, BackoffMatcher>* Copy(bool safe = false) const {
  256. return new TableMatcher<FST, BackoffMatcher>(*this, safe);
  257. }
  258. virtual uint64 Properties(uint64 props) const {
  259. return impl_->Properties(props);
  260. } // simple matcher that does
  261. // not change its FST, so properties are properties of FST it is applied to
  262. private:
  263. std::shared_ptr<Impl> impl_;
  264. virtual void SetState_(StateId s) { impl_->SetState(s); }
  265. virtual bool Find_(Label label) { return impl_->Find(label); }
  266. virtual bool Done_() const { return impl_->Done(); }
  267. virtual const Arc& Value_() const { return impl_->Value(); }
  268. virtual void Next_() { impl_->Next(); }
  269. TableMatcher& operator=(const TableMatcher&) = delete;
  270. };
  271. struct TableComposeOptions : public TableMatcherOptions {
  272. bool connect; // Connect output
  273. ComposeFilter filter_type; // Which pre-defined filter to use
  274. MatchType table_match_type;
  275. explicit TableComposeOptions(const TableMatcherOptions& mo, bool c = true,
  276. ComposeFilter ft = SEQUENCE_FILTER,
  277. MatchType tms = MATCH_OUTPUT)
  278. : TableMatcherOptions(mo),
  279. connect(c),
  280. filter_type(ft),
  281. table_match_type(tms) {}
  282. TableComposeOptions()
  283. : connect(true),
  284. filter_type(SEQUENCE_FILTER),
  285. table_match_type(MATCH_OUTPUT) {}
  286. };
  287. template <class Arc>
  288. void TableCompose(const Fst<Arc>& ifst1, const Fst<Arc>& ifst2,
  289. MutableFst<Arc>* ofst,
  290. const TableComposeOptions& opts = TableComposeOptions()) {
  291. typedef Fst<Arc> F;
  292. CacheOptions nopts;
  293. nopts.gc_limit = 0; // Cache only the last state for fastest copy.
  294. if (opts.table_match_type == MATCH_OUTPUT) {
  295. // ComposeFstImplOptions templated on matcher for fst1, matcher for fst2.
  296. ComposeFstImplOptions<TableMatcher<F>, SortedMatcher<F> > impl_opts(nopts);
  297. impl_opts.matcher1 = new TableMatcher<F>(ifst1, MATCH_OUTPUT, opts);
  298. *ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
  299. } else {
  300. assert(opts.table_match_type == MATCH_INPUT);
  301. // ComposeFstImplOptions templated on matcher for fst1, matcher for fst2.
  302. ComposeFstImplOptions<SortedMatcher<F>, TableMatcher<F> > impl_opts(nopts);
  303. impl_opts.matcher2 = new TableMatcher<F>(ifst2, MATCH_INPUT, opts);
  304. *ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
  305. }
  306. if (opts.connect) Connect(ofst);
  307. }
  308. /// TableComposeCache lets us do multiple compositions while caching the same
  309. /// matcher.
  310. template <class F>
  311. struct TableComposeCache {
  312. TableMatcher<F>* matcher;
  313. TableComposeOptions opts;
  314. explicit TableComposeCache(
  315. const TableComposeOptions& opts = TableComposeOptions())
  316. : matcher(NULL), opts(opts) {}
  317. ~TableComposeCache() { delete (matcher); }
  318. };
  319. template <class Arc>
  320. void TableCompose(const Fst<Arc>& ifst1, const Fst<Arc>& ifst2,
  321. MutableFst<Arc>* ofst, TableComposeCache<Fst<Arc> >* cache) {
  322. typedef Fst<Arc> F;
  323. assert(cache != NULL);
  324. CacheOptions nopts;
  325. nopts.gc_limit = 0; // Cache only the last state for fastest copy.
  326. if (cache->opts.table_match_type == MATCH_OUTPUT) {
  327. ComposeFstImplOptions<TableMatcher<F>, SortedMatcher<F> > impl_opts(nopts);
  328. if (cache->matcher == NULL)
  329. cache->matcher = new TableMatcher<F>(ifst1, MATCH_OUTPUT, cache->opts);
  330. impl_opts.matcher1 = cache->matcher->Copy(); // not passing "safe": may not
  331. // be thread-safe-- anway I don't understand this part.
  332. *ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
  333. } else {
  334. assert(cache->opts.table_match_type == MATCH_INPUT);
  335. ComposeFstImplOptions<SortedMatcher<F>, TableMatcher<F> > impl_opts(nopts);
  336. if (cache->matcher == NULL)
  337. cache->matcher = new TableMatcher<F>(ifst2, MATCH_INPUT, cache->opts);
  338. impl_opts.matcher2 = cache->matcher->Copy();
  339. *ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
  340. }
  341. if (cache->opts.connect) Connect(ofst);
  342. }
  343. } // namespace fst
  344. #endif // KALDI_FSTEXT_TABLE_MATCHER_H_