You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

149 lines
5.3 KiB

  1. // lm/arpa-file-parser.h
  2. // Copyright 2014 Guoguo Chen
  3. // Copyright 2016 Smart Action Company LLC (kkm)
  4. // See ../../COPYING for clarification regarding multiple authors
  5. //
  6. // Licensed under the Apache License, Version 2.0 (the "License");
  7. // you may not use this file except in compliance with the License.
  8. // You may obtain a copy of the License at
  9. //
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. //
  12. // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  13. // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  14. // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  15. // MERCHANTABILITY OR NON-INFRINGEMENT.
  16. // See the Apache 2 License for the specific language governing permissions and
  17. // limitations under the License.
  18. #ifndef KALDI_LM_ARPA_FILE_PARSER_H_
  19. #define KALDI_LM_ARPA_FILE_PARSER_H_
  20. #include <fst/fst-decl.h>
  21. #include <string>
  22. #include <vector>
  23. #include "base/kaldi-types.h"
  24. #include "itf/options-itf.h"
  25. namespace kaldi {
  26. /**
  27. Options that control ArpaFileParser
  28. */
  29. struct ArpaParseOptions {
  30. enum OovHandling {
  31. kRaiseError, ///< Abort on OOV words
  32. kAddToSymbols, ///< Add novel words to the symbol table.
  33. kReplaceWithUnk, ///< Replace OOV words with <unk>.
  34. kSkipNGram ///< Skip n-gram with OOV word and continue.
  35. };
  36. ArpaParseOptions()
  37. : bos_symbol(-1),
  38. eos_symbol(-1),
  39. unk_symbol(-1),
  40. oov_handling(kRaiseError),
  41. max_warnings(30) {}
  42. void Register(OptionsItf* opts) {
  43. // Registering only the max_warnings count, since other options are
  44. // treated differently by client programs: some want integer symbols,
  45. // while other are passed words in their command line.
  46. opts->Register("max-arpa-warnings", &max_warnings,
  47. "Maximum warnings to report on ARPA parsing, "
  48. "0 to disable, -1 to show all");
  49. }
  50. int32 bos_symbol; ///< Symbol for <s>, Required non-epsilon.
  51. int32 eos_symbol; ///< Symbol for </s>, Required non-epsilon.
  52. int32 unk_symbol; ///< Symbol for <unk>, Required for kReplaceWithUnk.
  53. OovHandling oov_handling; ///< How to handle OOV words in the file.
  54. int32 max_warnings; ///< Maximum warnings to report, <0 unlimited.
  55. };
  56. /**
  57. A parsed n-gram from ARPA LM file.
  58. */
  59. struct NGram {
  60. NGram() : logprob(0.0), backoff(0.0) {}
  61. std::vector<int32> words; ///< Symbols in left to right order.
  62. float logprob; ///< Log-prob of the n-gram.
  63. float backoff; ///< log-backoff weight of the n-gram.
  64. ///< Defaults to zero if not specified.
  65. };
  66. /**
  67. ArpaFileParser is an abstract base class for ARPA LM file conversion.
  68. See ConstArpaLmBuilder and ArpaLmCompiler for usage examples.
  69. */
  70. class ArpaFileParser {
  71. public:
  72. /// Constructs the parser with the given options and optional symbol table.
  73. /// If symbol table is provided, then the file should contain text n-grams,
  74. /// and the words are mapped to symbols through it. bos_symbol and
  75. /// eos_symbol in the options structure must be valid symbols in the table,
  76. /// and so must be unk_symbol if provided. The table is not owned by the
  77. /// parser, but may be augmented, if oov_handling is set to kAddToSymbols.
  78. /// If symbol table is a null pointer, the file should contain integer
  79. /// symbol values, and oov_handling has no effect. bos_symbol and eos_symbol
  80. /// must be valid symbols still.
  81. ArpaFileParser(const ArpaParseOptions& options, fst::SymbolTable* symbols);
  82. virtual ~ArpaFileParser();
  83. /// Read ARPA LM file from a stream.
  84. void Read(std::istream& is);
  85. /// Parser options.
  86. const ArpaParseOptions& Options() const { return options_; }
  87. protected:
  88. /// Override called before reading starts. This is the point to prepare
  89. /// any state in the derived class.
  90. virtual void ReadStarted() {}
  91. /// Override function called to signal that ARPA header with the expected
  92. /// number of n-grams has been read, and ngram_counts() is now valid.
  93. virtual void HeaderAvailable() {}
  94. /// Pure override that must be implemented to process current n-gram. The
  95. /// n-grams are sent in the file order, which guarantees that all
  96. /// (k-1)-grams are processed before the first k-gram is.
  97. virtual void ConsumeNGram(const NGram&) = 0;
  98. /// Override function called after the last n-gram has been consumed.
  99. virtual void ReadComplete() {}
  100. /// Read-only access to symbol table. Not owned, do not make public.
  101. const fst::SymbolTable* Symbols() const { return symbols_; }
  102. /// Inside ConsumeNGram(), provides the current line number.
  103. int32 LineNumber() const { return line_number_; }
  104. /// Inside ConsumeNGram(), returns a formatted reference to the line being
  105. /// compiled, to print out as part of diagnostics.
  106. std::string LineReference() const;
  107. /// Increments warning count, and returns true if a warning should be
  108. /// printed or false if the count has exceeded the set maximum.
  109. bool ShouldWarn();
  110. /// N-gram counts. Valid from the point when HeaderAvailable() is called.
  111. const std::vector<int32>& NgramCounts() const { return ngram_counts_; }
  112. private:
  113. ArpaParseOptions options_;
  114. fst::SymbolTable* symbols_; // the pointer is not owned here.
  115. int32 line_number_;
  116. uint32 warning_count_;
  117. std::string current_line_;
  118. std::vector<int32> ngram_counts_;
  119. };
  120. } // namespace kaldi
  121. #endif // KALDI_LM_ARPA_FILE_PARSER_H_