You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

381 lines
14 KiB

  1. // lm/arpa-lm-compiler.cc
  2. // Copyright 2009-2011 Gilles Boulianne
  3. // Copyright 2016 Smart Action LLC (kkm)
  4. // Copyright 2017 Xiaohui Zhang
  5. // See ../../COPYING for clarification regarding multiple authors
  6. //
  7. // Licensed under the Apache License, Version 2.0 (the "License");
  8. // you may not use this file except in compliance with the License.
  9. // You may obtain a copy of the License at
  10. //
  11. // http://www.apache.org/licenses/LICENSE-2.0
  12. //
  13. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  15. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  16. // MERCHANTABLITY OR NON-INFRINGEMENT.
  17. // See the Apache 2 License for the specific language governing permissions and
  18. // limitations under the License.
  19. #include <algorithm>
  20. #include <functional>
  21. #include <limits>
  22. #include <sstream>
  23. #include <unordered_map>
  24. #include <utility>
  25. #include <vector>
  26. #include "base/kaldi-math.h"
  27. #include "fstext/remove-eps-local.h"
  28. #include "lm/arpa-lm-compiler.h"
  29. #include "util/stl-utils.h"
  30. #include "util/text-utils.h"
  31. namespace kaldi {
  32. class ArpaLmCompilerImplInterface {
  33. public:
  34. virtual ~ArpaLmCompilerImplInterface() {}
  35. virtual void ConsumeNGram(const NGram& ngram, bool is_highest) = 0;
  36. };
  37. namespace {
  38. typedef int32 StateId;
  39. typedef int32 Symbol;
  40. // GeneralHistKey can represent state history in an arbitrarily large n
  41. // n-gram model with symbol ids fitting int32.
  42. class GeneralHistKey {
  43. public:
  44. // Construct key from being and end iterators.
  45. template <class InputIt>
  46. GeneralHistKey(InputIt begin, InputIt end) : vector_(begin, end) {}
  47. // Construct empty history key.
  48. GeneralHistKey() : vector_() {}
  49. // Return tails of the key as a GeneralHistKey. The tails of an n-gram
  50. // w[1..n] is the sequence w[2..n] (and the heads is w[1..n-1], but the
  51. // key class does not need this operartion).
  52. GeneralHistKey Tails() const {
  53. return GeneralHistKey(vector_.begin() + 1, vector_.end());
  54. }
  55. // Keys are equal if represent same state.
  56. friend bool operator==(const GeneralHistKey& a, const GeneralHistKey& b) {
  57. return a.vector_ == b.vector_;
  58. }
  59. // Public typename HashType for hashing.
  60. struct HashType : public std::unary_function<GeneralHistKey, size_t> {
  61. size_t operator()(const GeneralHistKey& key) const {
  62. return VectorHasher<Symbol>().operator()(key.vector_);
  63. }
  64. };
  65. private:
  66. std::vector<Symbol> vector_;
  67. };
  68. // OptimizedHistKey combines 3 21-bit symbol ID values into one 64-bit
  69. // machine word. allowing significant memory reduction and some runtime
  70. // benefit over GeneralHistKey. Since 3 symbols are enough to track history
  71. // in a 4-gram model, this optimized key is used for smaller models with up
  72. // to 4-gram and symbol values up to 2^21-1.
  73. //
  74. // See GeneralHistKey for interface requirements of a key class.
  75. class OptimizedHistKey {
  76. public:
  77. enum {
  78. kShift = 21, // 21 * 3 = 63 bits for data.
  79. kMaxData = (1 << kShift) - 1
  80. };
  81. template <class InputIt>
  82. OptimizedHistKey(InputIt begin, InputIt end) : data_(0) {
  83. for (uint32 shift = 0; begin != end; ++begin, shift += kShift) {
  84. data_ |= static_cast<uint64>(*begin) << shift;
  85. }
  86. }
  87. OptimizedHistKey() : data_(0) {}
  88. OptimizedHistKey Tails() const { return OptimizedHistKey(data_ >> kShift); }
  89. friend bool operator==(const OptimizedHistKey& a, const OptimizedHistKey& b) {
  90. return a.data_ == b.data_;
  91. }
  92. struct HashType : public std::unary_function<OptimizedHistKey, size_t> {
  93. size_t operator()(const OptimizedHistKey& key) const { return key.data_; }
  94. };
  95. private:
  96. explicit OptimizedHistKey(uint64 data) : data_(data) {}
  97. uint64 data_;
  98. };
  99. } // namespace
// Templated compiler implementation.  HistKey is either GeneralHistKey or
// OptimizedHistKey (see above) and identifies the FST state corresponding
// to a given n-gram history.
template <class HistKey>
class ArpaLmCompilerImpl : public ArpaLmCompilerImplInterface {
 public:
  ArpaLmCompilerImpl(ArpaLmCompiler* parent, fst::StdVectorFst* fst,
                     Symbol sub_eps);

  // Add one n-gram read from the ARPA file to the FST (interface override).
  virtual void ConsumeNGram(const NGram& ngram, bool is_highest);

 private:
  // Find or create the state for the history in key; a newly created state
  // also gets its backoff arc (with weight -backoff) immediately.
  StateId AddStateWithBackoff(HistKey key, float backoff);
  // Add a backoff arc from state to the state for key, falling back to the
  // nearest existing lower-order history if key itself is not in the map.
  void CreateBackoff(HistKey key, StateId state, float weight);

  ArpaLmCompiler* parent_;  // Not owned.
  fst::StdVectorFst* fst_;  // Not owned.
  Symbol bos_symbol_;  // <s> id, cached from parent->Options().
  Symbol eos_symbol_;  // </s> id, cached from parent->Options().
  Symbol sub_eps_;     // Symbol substituted for epsilon, or 0 for none.
  StateId eos_state_;  // Shared final state for </s> arcs; set only when
                       // sub_eps_ == 0 (see the constructor).

  typedef unordered_map<HistKey, StateId, typename HistKey::HashType>
      HistoryMap;
  HistoryMap history_;  // Maps each seen history to its FST state.
};
  119. template <class HistKey>
  120. ArpaLmCompilerImpl<HistKey>::ArpaLmCompilerImpl(ArpaLmCompiler* parent,
  121. fst::StdVectorFst* fst,
  122. Symbol sub_eps)
  123. : parent_(parent),
  124. fst_(fst),
  125. bos_symbol_(parent->Options().bos_symbol),
  126. eos_symbol_(parent->Options().eos_symbol),
  127. sub_eps_(sub_eps) {
  128. // The algorithm maintains state per history. The 0-gram is a special state
  129. // for empty history. All unigrams (including BOS) backoff into this state.
  130. StateId zerogram = fst_->AddState();
  131. history_[HistKey()] = zerogram;
  132. // Also, if </s> is not treated as epsilon, create a common end state for
  133. // all transitions accepting the </s>, since they do not back off. This small
  134. // optimization saves about 2% states in an average grammar.
  135. if (sub_eps_ == 0) {
  136. eos_state_ = fst_->AddState();
  137. fst_->SetFinal(eos_state_, 0);
  138. }
  139. }
  140. template <class HistKey>
  141. void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(const NGram& ngram,
  142. bool is_highest) {
  143. // Generally, we do the following. Suppose we are adding an n-gram "A B
  144. // C". Then find the node for "A B", add a new node for "A B C", and connect
  145. // them with the arc accepting "C" with the specified weight. Also, add a
  146. // backoff arc from the new "A B C" node to its backoff state "B C".
  147. //
  148. // Two notable exceptions are the highest order n-grams, and final n-grams.
  149. //
  150. // When adding a highest order n-gram (e. g., our "A B C" is in a 3-gram LM),
  151. // the following optimization is performed. There is no point adding a node
  152. // for "A B C" with a "C" arc from "A B", since there will be no other
  153. // arcs ingoing to this node, and an epsilon backoff arc into the backoff
  154. // model "B C", with the weight of \bar{1}. To save a node, create an arc
  155. // accepting "C" directly from "A B" to "B C". This saves as many nodes
  156. // as there are the highest order n-grams, which is typically about half
  157. // the size of a large 3-gram model.
  158. //
  159. // Indeed, this does not apply to n-grams ending in EOS, since they do not
  160. // back off. These are special, as they do not have a back-off state, and
  161. // the node for "(..anything..) </s>" is always final. These are handled
  162. // in one of the two possible ways, If symbols <s> and </s> are being
  163. // replaced by epsilons, neither node nor arc is created, and the logprob
  164. // of the n-gram is applied to its source node as final weight. If <s> and
  165. // </s> are preserved, then a special final node for </s> is allocated and
  166. // used as the destination of the "</s>" acceptor arc.
  167. HistKey heads(ngram.words.begin(), ngram.words.end() - 1);
  168. typename HistoryMap::iterator source_it = history_.find(heads);
  169. if (source_it == history_.end()) {
  170. // There was no "A B", therefore the probability of "A B C" is zero.
  171. // Print a warning and discard current n-gram.
  172. if (parent_->ShouldWarn())
  173. KALDI_WARN << parent_->LineReference()
  174. << " skipped: no parent (n-1)-gram exists";
  175. return;
  176. }
  177. StateId source = source_it->second;
  178. StateId dest;
  179. Symbol sym = ngram.words.back();
  180. float weight = -ngram.logprob;
  181. if (sym == sub_eps_ || sym == 0) {
  182. KALDI_ERR << " <eps> or disambiguation symbol " << sym
  183. << "found in the ARPA file. ";
  184. }
  185. if (sym == eos_symbol_) {
  186. if (sub_eps_ == 0) {
  187. // Keep </s> as a real symbol when not substituting.
  188. dest = eos_state_;
  189. } else {
  190. // Treat </s> as if it was epsilon: mark source final, with the weight
  191. // of the n-gram.
  192. fst_->SetFinal(source, weight);
  193. return;
  194. }
  195. } else {
  196. // For the highest order n-gram, this may find an existing state, for
  197. // non-highest, will create one (unless there are duplicate n-grams
  198. // in the grammar, which cannot be reliably detected if highest order,
  199. // so we better do not do that at all).
  200. dest = AddStateWithBackoff(
  201. HistKey(ngram.words.begin() + (is_highest ? 1 : 0), ngram.words.end()),
  202. -ngram.backoff);
  203. }
  204. if (sym == bos_symbol_) {
  205. weight = 0; // Accepting <s> is always free.
  206. if (sub_eps_ == 0) {
  207. // <s> is as a real symbol, only accepted in the start state.
  208. source = fst_->AddState();
  209. fst_->SetStart(source);
  210. } else {
  211. // The new state for <s> unigram history *is* the start state.
  212. fst_->SetStart(dest);
  213. return;
  214. }
  215. }
  216. // Add arc from source to dest, whichever way it was found.
  217. fst_->AddArc(source, fst::StdArc(sym, sym, weight, dest));
  218. return;
  219. }
  220. // Find or create a new state for n-gram defined by key, and ensure it has a
  221. // backoff transition. The key is either the current n-gram for all but
  222. // highest orders, or the tails of the n-gram for the highest order. The
  223. // latter arises from the chain-collapsing optimization described above.
  224. template <class HistKey>
  225. StateId ArpaLmCompilerImpl<HistKey>::AddStateWithBackoff(HistKey key,
  226. float backoff) {
  227. typename HistoryMap::iterator dest_it = history_.find(key);
  228. if (dest_it != history_.end()) {
  229. // Found an existing state in the history map. Invariant: if the state in
  230. // the map, then its backoff arc is in the FST. We are done.
  231. return dest_it->second;
  232. }
  233. // Otherwise create a new state and its backoff arc, and register in the map.
  234. StateId dest = fst_->AddState();
  235. history_[key] = dest;
  236. CreateBackoff(key.Tails(), dest, backoff);
  237. return dest;
  238. }
  239. // Create a backoff arc for a state. Key is a backoff destination that may or
  240. // may not exist. When the destination is not found, naturally fall back to
  241. // the lower order model, and all the way down until one is found (since the
  242. // 0-gram model is always present, the search is guaranteed to terminate).
  243. template <class HistKey>
  244. inline void ArpaLmCompilerImpl<HistKey>::CreateBackoff(HistKey key,
  245. StateId state,
  246. float weight) {
  247. typename HistoryMap::iterator dest_it = history_.find(key);
  248. while (dest_it == history_.end()) {
  249. key = key.Tails();
  250. dest_it = history_.find(key);
  251. }
  252. // The arc should transduce either <eos> or #0 to <eps>, depending on the
  253. // epsilon substitution mode. This is the only case when input and output
  254. // label may differ.
  255. fst_->AddArc(state, fst::StdArc(sub_eps_, 0, weight, dest_it->second));
  256. }
  257. ArpaLmCompiler::~ArpaLmCompiler() {
  258. if (impl_ != NULL) delete impl_;
  259. }
  260. void ArpaLmCompiler::HeaderAvailable() {
  261. KALDI_ASSERT(impl_ == NULL);
  262. // Use optimized implementation if the grammar is 4-gram or less, and the
  263. // maximum attained symbol id will fit into the optimized range.
  264. int64 max_symbol = 0;
  265. if (Symbols() != NULL) max_symbol = Symbols()->AvailableKey() - 1;
  266. // If augmenting the symbol table, assume the worst case when all words in
  267. // the model being read are novel.
  268. if (Options().oov_handling == ArpaParseOptions::kAddToSymbols)
  269. max_symbol += NgramCounts()[0];
  270. if (NgramCounts().size() <= 4 && max_symbol < OptimizedHistKey::kMaxData) {
  271. impl_ = new ArpaLmCompilerImpl<OptimizedHistKey>(this, &fst_, sub_eps_);
  272. } else {
  273. impl_ = new ArpaLmCompilerImpl<GeneralHistKey>(this, &fst_, sub_eps_);
  274. KALDI_LOG << "Reverting to slower state tracking because model is large: "
  275. << NgramCounts().size() << "-gram with symbols up to "
  276. << max_symbol;
  277. }
  278. }
  279. void ArpaLmCompiler::ConsumeNGram(const NGram& ngram) {
  280. // <s> is invalid in tails, </s> in heads of an n-gram.
  281. for (int i = 0; i < ngram.words.size(); ++i) {
  282. if ((i > 0 && ngram.words[i] == Options().bos_symbol) ||
  283. (i + 1 < ngram.words.size() &&
  284. ngram.words[i] == Options().eos_symbol)) {
  285. if (ShouldWarn())
  286. KALDI_WARN << LineReference()
  287. << " skipped: n-gram has invalid BOS/EOS placement";
  288. return;
  289. }
  290. }
  291. bool is_highest = ngram.words.size() == NgramCounts().size();
  292. impl_->ConsumeNGram(ngram, is_highest);
  293. }
// Shrink the compiled FST by epsilon-removing states whose only outgoing
// arc is a backoff arc.  Only runs when a real backoff symbol (#0) is in
// use; see the comment below for why it is disabled otherwise.
void ArpaLmCompiler::RemoveRedundantStates() {
  fst::StdArc::Label backoff_symbol = sub_eps_;
  if (backoff_symbol == 0) {
    // The method of removing redundant states implemented in this function
    // leads to slow determinization of L o G when people use the older style of
    // usage of arpa2fst where the --disambig-symbol option was not specified.
    // The issue seems to be that it creates a non-deterministic FST, while G is
    // supposed to be deterministic. By 'return'ing below, we just disable this
    // method if people were using an older script. This method isn't really
    // that consequential anyway, and people will move to the newer-style
    // scripts (see current utils/format_lm.sh), so this isn't much of a
    // problem.
    return;
  }
  fst::StdArc::StateId num_states = fst_.NumStates();

  // replace the #0 symbols on the input of arcs out of redundant states (states
  // that are not final and have only a backoff arc leaving them), with <eps>.
  for (fst::StdArc::StateId state = 0; state < num_states; state++) {
    // A redundant state: exactly one outgoing arc and not final
    // (Final() == Zero() means "not a final state" in the tropical semiring).
    if (fst_.NumArcs(state) == 1 &&
        fst_.Final(state) == fst::TropicalWeight::Zero()) {
      fst::MutableArcIterator<fst::StdVectorFst> iter(&fst_, state);
      fst::StdArc arc = iter.Value();
      if (arc.ilabel == backoff_symbol) {
        arc.ilabel = 0;  // Turn the #0 input label into <eps> in place.
        iter.SetValue(arc);
      }
    }
  }

  // we could call fst::RemoveEps, and it would have the same effect in normal
  // cases, where backoff_symbol != 0 and there are no epsilons in unexpected
  // places, but RemoveEpsLocal is a bit safer in case something weird is going
  // on; it guarantees not to blow up the FST.
  fst::RemoveEpsLocal(&fst_);
  KALDI_LOG << "Reduced num-states from " << num_states << " to "
            << fst_.NumStates();
}
  330. void ArpaLmCompiler::Check() const {
  331. if (fst_.Start() == fst::kNoStateId) {
  332. KALDI_ERR << "Arpa file did not contain the beginning-of-sentence symbol "
  333. << Symbols()->Find(Options().bos_symbol) << ".";
  334. }
  335. }
  336. void ArpaLmCompiler::ReadComplete() {
  337. fst_.SetInputSymbols(Symbols());
  338. fst_.SetOutputSymbols(Symbols());
  339. RemoveRedundantStates();
  340. Check();
  341. }
  342. } // namespace kaldi