You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

279 lines
10 KiB

  1. // lm/arpa-file-parser.cc
  2. // Copyright 2014 Guoguo Chen
  3. // Copyright 2016 Smart Action Company LLC (kkm)
  4. // See ../../COPYING for clarification regarding multiple authors
  5. //
  6. // Licensed under the Apache License, Version 2.0 (the "License");
  7. // you may not use this file except in compliance with the License.
  8. // You may obtain a copy of the License at
  9. //
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. //
  12. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  13. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  14. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  15. // MERCHANTABLITY OR NON-INFRINGEMENT.
  16. // See the Apache 2 License for the specific language governing permissions and
  17. // limitations under the License.
  18. #include <fst/fstlib.h>
  19. #include <sstream>
  20. #include "base/kaldi-error.h"
  21. #include "base/kaldi-math.h"
  22. #include "lm/arpa-file-parser.h"
  23. #include "util/text-utils.h"
  24. namespace kaldi {
  25. ArpaFileParser::ArpaFileParser(const ArpaParseOptions& options,
  26. fst::SymbolTable* symbols)
  27. : options_(options),
  28. symbols_(symbols),
  29. line_number_(0),
  30. warning_count_(0) {}
  31. ArpaFileParser::~ArpaFileParser() {}
  32. void TrimTrailingWhitespace(std::string* str) {
  33. str->erase(str->find_last_not_of(" \n\r\t") + 1);
  34. }
  35. void ArpaFileParser::Read(std::istream& is) {
  36. // Argument sanity checks.
  37. if (options_.bos_symbol <= 0 || options_.eos_symbol <= 0 ||
  38. options_.bos_symbol == options_.eos_symbol)
  39. KALDI_ERR << "BOS and EOS symbols are required, must not be epsilons, and "
  40. << "differ from each other. Given:"
  41. << " BOS=" << options_.bos_symbol
  42. << " EOS=" << options_.eos_symbol;
  43. if (symbols_ != NULL &&
  44. options_.oov_handling == ArpaParseOptions::kReplaceWithUnk &&
  45. (options_.unk_symbol <= 0 || options_.unk_symbol == options_.bos_symbol ||
  46. options_.unk_symbol == options_.eos_symbol))
  47. KALDI_ERR << "When symbol table is given and OOV mode is kReplaceWithUnk, "
  48. << "UNK symbol is required, must not be epsilon, and "
  49. << "differ from both BOS and EOS symbols. Given:"
  50. << " UNK=" << options_.unk_symbol
  51. << " BOS=" << options_.bos_symbol
  52. << " EOS=" << options_.eos_symbol;
  53. if (symbols_ != NULL && symbols_->Find(options_.bos_symbol).empty())
  54. KALDI_ERR << "BOS symbol must exist in symbol table";
  55. if (symbols_ != NULL && symbols_->Find(options_.eos_symbol).empty())
  56. KALDI_ERR << "EOS symbol must exist in symbol table";
  57. if (symbols_ != NULL && options_.unk_symbol > 0 &&
  58. symbols_->Find(options_.unk_symbol).empty())
  59. KALDI_ERR << "UNK symbol must exist in symbol table";
  60. ngram_counts_.clear();
  61. line_number_ = 0;
  62. warning_count_ = 0;
  63. current_line_.clear();
  64. #define PARSE_ERR KALDI_ERR << LineReference() << ": "
  65. // Give derived class an opportunity to prepare its state.
  66. ReadStarted();
  67. // Processes "\data\" section.
  68. bool keyword_found = false;
  69. while (++line_number_, getline(is, current_line_) && !is.eof()) {
  70. if (current_line_.find_first_not_of(" \t\n\r") == std::string::npos) {
  71. continue;
  72. }
  73. TrimTrailingWhitespace(&current_line_);
  74. // Continue skipping lines until the \data\ marker alone on a line is found.
  75. if (!keyword_found) {
  76. if (current_line_ == "\\data\\") {
  77. KALDI_LOG << "Reading \\data\\ section.";
  78. keyword_found = true;
  79. }
  80. continue;
  81. }
  82. if (current_line_[0] == '\\') break;
  83. // Enters "\data\" section, and looks for patterns like "ngram 1=1000",
  84. // which means there are 1000 unigrams.
  85. std::size_t equal_symbol_pos = current_line_.find("=");
  86. if (equal_symbol_pos != std::string::npos)
  87. // Guaranteed spaces around the "=".
  88. current_line_.replace(equal_symbol_pos, 1, " = ");
  89. std::vector<std::string> col;
  90. SplitStringToVector(current_line_, " \t", true, &col);
  91. if (col.size() == 4 && col[0] == "ngram" && col[2] == "=") {
  92. int32 order, ngram_count = 0;
  93. if (!ConvertStringToInteger(col[1], &order) ||
  94. !ConvertStringToInteger(col[3], &ngram_count)) {
  95. PARSE_ERR << "cannot parse ngram count";
  96. }
  97. if (ngram_counts_.size() <= order) {
  98. ngram_counts_.resize(order);
  99. }
  100. ngram_counts_[order - 1] = ngram_count;
  101. } else {
  102. KALDI_WARN << LineReference()
  103. << ": uninterpretable line in \\data\\ section";
  104. }
  105. }
  106. if (ngram_counts_.size() == 0)
  107. PARSE_ERR << "\\data\\ section missing or empty.";
  108. // Signal that grammar order and n-gram counts are known.
  109. HeaderAvailable();
  110. NGram ngram;
  111. ngram.words.reserve(ngram_counts_.size());
  112. // Processes "\N-grams:" section.
  113. for (int32 cur_order = 1; cur_order <= ngram_counts_.size(); ++cur_order) {
  114. // Skips n-grams with zero count.
  115. if (ngram_counts_[cur_order - 1] == 0)
  116. KALDI_WARN << "Zero ngram count in ngram order " << cur_order
  117. << "(look for 'ngram " << cur_order << "=0' in the \\data\\ "
  118. << " section). There is possibly a problem with the file.";
  119. // Must be looking at a \k-grams: directive at this point.
  120. std::ostringstream keyword;
  121. keyword << "\\" << cur_order << "-grams:";
  122. if (current_line_ != keyword.str()) {
  123. PARSE_ERR << "invalid directive, expecting '" << keyword.str() << "'";
  124. }
  125. KALDI_LOG << "Reading " << current_line_ << " section.";
  126. int32 ngram_count = 0;
  127. while (++line_number_, getline(is, current_line_) && !is.eof()) {
  128. if (current_line_.find_first_not_of(" \n\t\r") == std::string::npos) {
  129. continue;
  130. }
  131. if (current_line_[0] == '\\') {
  132. TrimTrailingWhitespace(&current_line_);
  133. std::ostringstream next_keyword;
  134. next_keyword << "\\" << cur_order + 1 << "-grams:";
  135. if ((current_line_ != next_keyword.str()) &&
  136. (current_line_ != "\\end\\")) {
  137. if (ShouldWarn()) {
  138. KALDI_WARN << "ignoring possible directive '" << current_line_
  139. << "' expecting '" << next_keyword.str() << "'";
  140. if (warning_count_ > 0 &&
  141. warning_count_ > static_cast<uint32>(options_.max_warnings)) {
  142. KALDI_WARN << "Of " << warning_count_ << " parse warnings, "
  143. << options_.max_warnings << " were reported. "
  144. << "Run program with --max-arpa-warnings=-1 "
  145. << "to see all warnings";
  146. }
  147. }
  148. } else {
  149. break;
  150. }
  151. }
  152. std::vector<std::string> col;
  153. SplitStringToVector(current_line_, " \t", true, &col);
  154. if (col.size() < 1 + cur_order || col.size() > 2 + cur_order ||
  155. (cur_order == ngram_counts_.size() && col.size() != 1 + cur_order)) {
  156. PARSE_ERR << "Invalid n-gram data line";
  157. }
  158. ++ngram_count;
  159. // Parse out n-gram logprob and, if present, backoff weight.
  160. if (!ConvertStringToReal(col[0], &ngram.logprob)) {
  161. PARSE_ERR << "invalid n-gram logprob '" << col[0] << "'";
  162. }
  163. ngram.backoff = 0.0;
  164. if (col.size() > cur_order + 1) {
  165. if (!ConvertStringToReal(col[cur_order + 1], &ngram.backoff))
  166. PARSE_ERR << "invalid backoff weight '" << col[cur_order + 1] << "'";
  167. }
  168. // Convert to natural log.
  169. ngram.logprob *= M_LN10;
  170. ngram.backoff *= M_LN10;
  171. ngram.words.resize(cur_order);
  172. bool skip_ngram = false;
  173. for (int32 index = 0; !skip_ngram && index < cur_order; ++index) {
  174. int32 word;
  175. if (symbols_) {
  176. // Symbol table provided, so symbol labels are expected.
  177. if (options_.oov_handling == ArpaParseOptions::kAddToSymbols) {
  178. word = symbols_->AddSymbol(col[1 + index]);
  179. } else {
  180. word = symbols_->Find(col[1 + index]);
  181. if (word == -1) { // fst::kNoSymbol
  182. switch (options_.oov_handling) {
  183. case ArpaParseOptions::kReplaceWithUnk:
  184. word = options_.unk_symbol;
  185. break;
  186. case ArpaParseOptions::kSkipNGram:
  187. if (ShouldWarn())
  188. KALDI_WARN << LineReference() << " skipped: word '"
  189. << col[1 + index] << "' not in symbol table";
  190. skip_ngram = true;
  191. break;
  192. default:
  193. PARSE_ERR << "word '" << col[1 + index]
  194. << "' not in symbol table";
  195. }
  196. }
  197. }
  198. } else {
  199. // Symbols not provided, LM file should contain integers.
  200. if (!ConvertStringToInteger(col[1 + index], &word) || word < 0) {
  201. PARSE_ERR << "invalid symbol '" << col[1 + index] << "'";
  202. }
  203. }
  204. // Whichever way we got it, an epsilon is invalid.
  205. if (word == 0) {
  206. PARSE_ERR << "epsilon symbol '" << col[1 + index]
  207. << "' is illegal in ARPA LM";
  208. }
  209. ngram.words[index] = word;
  210. }
  211. if (!skip_ngram) {
  212. ConsumeNGram(ngram);
  213. }
  214. }
  215. if (ngram_count > ngram_counts_[cur_order - 1]) {
  216. PARSE_ERR << "header said there would be " << ngram_counts_[cur_order - 1]
  217. << " n-grams of order " << cur_order
  218. << ", but we saw more already.";
  219. }
  220. }
  221. if (current_line_ != "\\end\\") {
  222. PARSE_ERR << "invalid or unexpected directive line, expecting \\end\\";
  223. }
  224. if (warning_count_ > 0 &&
  225. warning_count_ > static_cast<uint32>(options_.max_warnings)) {
  226. KALDI_WARN << "Of " << warning_count_ << " parse warnings, "
  227. << options_.max_warnings << " were reported. Run program with "
  228. << "--max_warnings=-1 to see all warnings";
  229. }
  230. current_line_.clear();
  231. ReadComplete();
  232. #undef PARSE_ERR
  233. }
  234. std::string ArpaFileParser::LineReference() const {
  235. std::ostringstream ss;
  236. ss << "line " << line_number_ << " [" << current_line_ << "]";
  237. return ss.str();
  238. }
  239. bool ArpaFileParser::ShouldWarn() {
  240. return (warning_count_ != -1) &&
  241. (++warning_count_ <= static_cast<uint32>(options_.max_warnings));
  242. }
  243. } // namespace kaldi