You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

597 lines
19 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Classes to provide symbol-to-integer and integer-to-symbol mappings.
  19. #ifndef FST_SYMBOL_TABLE_H_
  20. #define FST_SYMBOL_TABLE_H_
  21. #include <sys/types.h>
  22. #include <cstddef>
  23. #include <cstdint>
  24. #include <functional>
  25. #include <ios>
  26. #include <iostream>
  27. #include <istream>
  28. #include <iterator>
  29. #include <memory>
  30. #include <ostream>
  31. #include <sstream>
  32. #include <string>
  33. #include <type_traits>
  34. #include <utility>
  35. #include <vector>
  36. #include <fst/compat.h>
  37. #include <fst/flags.h>
  38. #include <fst/log.h>
  39. #include <fstream>
  40. #include <fst/windows_defs.inc>
  41. #include <map>
  42. #include <functional>
  43. #include <string_view>
  44. #include <fst/lock.h>
  45. DECLARE_bool(fst_compat_symbols);
  46. DECLARE_string(fst_field_separator);
  47. namespace fst {
// Sentinel key meaning "no such symbol"; the symbol-to-key Find() overloads
// declared below document returning it when a symbol is not found.
inline constexpr int64_t kNoSymbol = -1;
  49. class SymbolTable;
  50. namespace internal {
  51. // Maximum line length in textual symbols file.
  52. inline constexpr int kLineLen = 8096;
  53. // List of symbols with a dense hash for looking up symbol index, rehashing at
  54. // 75% occupancy.
  55. class DenseSymbolMap {
  56. public:
  57. DenseSymbolMap();
  58. std::pair<int64_t, bool> InsertOrFind(std::string_view key);
  59. int64_t Find(std::string_view key) const;
  60. size_t Size() const { return symbols_.size(); }
  61. const std::string &GetSymbol(size_t idx) const { return symbols_[idx]; }
  62. void RemoveSymbol(size_t idx);
  63. void ShrinkToFit();
  64. private:
  65. static constexpr int64_t kEmptyBucket = -1;
  66. // num_buckets must be power of 2.
  67. void Rehash(size_t num_buckets);
  68. size_t GetHash(std::string_view key) const {
  69. return str_hash_(key) & hash_mask_;
  70. }
  71. const std::hash<std::string_view> str_hash_;
  72. std::vector<std::string> symbols_;
  73. std::vector<int64_t> buckets_;
  74. uint64_t hash_mask_;
  75. };
  76. // Base class for SymbolTable implementations.
  77. // Use either MutableSymbolTableImpl or ConstSymbolTableImpl to derive
  78. // implementation classes.
  79. class SymbolTableImplBase {
  80. public:
  81. SymbolTableImplBase() = default;
  82. virtual ~SymbolTableImplBase() = default;
  83. // Enforce copying through Copy().
  84. SymbolTableImplBase(const SymbolTableImplBase &) = delete;
  85. SymbolTableImplBase &operator=(const SymbolTableImplBase &) = delete;
  86. virtual std::unique_ptr<SymbolTableImplBase> Copy() const = 0;
  87. virtual bool Write(std::ostream &strm) const = 0;
  88. virtual int64_t AddSymbol(std::string_view symbol, int64_t key) = 0;
  89. virtual int64_t AddSymbol(std::string_view symbol) = 0;
  90. // Removes the symbol with the specified key. Subsequent Find() calls
  91. // for this key will return the empty string. Does not affect the keys
  92. // of other symbols.
  93. virtual void RemoveSymbol(int64_t key) = 0;
  94. // Returns the symbol for the specified key, or the empty string if not found.
  95. virtual std::string Find(int64_t key) const = 0;
  96. // Returns the key for the specified symbol, or kNoSymbol if not found.
  97. virtual int64_t Find(std::string_view symbol) const = 0;
  98. virtual bool Member(int64_t key) const { return !Find(key).empty(); }
  99. virtual bool Member(std::string_view symbol) const {
  100. return Find(symbol) != kNoSymbol;
  101. }
  102. virtual void AddTable(const SymbolTable &table) = 0;
  103. virtual int64_t GetNthKey(ssize_t pos) const = 0;
  104. virtual const std::string &Name() const = 0;
  105. virtual void SetName(std::string_view new_name) = 0;
  106. virtual const std::string &CheckSum() const = 0;
  107. virtual const std::string &LabeledCheckSum() const = 0;
  108. virtual int64_t AvailableKey() const = 0;
  109. virtual size_t NumSymbols() const = 0;
  110. virtual bool IsMutable() const = 0;
  111. };
  112. // Base class for SymbolTable implementations supporting Add/Remove.
  113. class MutableSymbolTableImpl : public SymbolTableImplBase {
  114. public:
  115. void AddTable(const SymbolTable &table) override;
  116. bool IsMutable() const final { return true; }
  117. };
  118. // Base class for immutable SymbolTable implementations.
  119. class ConstSymbolTableImpl : public SymbolTableImplBase {
  120. public:
  121. std::unique_ptr<SymbolTableImplBase> Copy() const final;
  122. int64_t AddSymbol(std::string_view symbol, int64_t key) final;
  123. int64_t AddSymbol(std::string_view symbol) final;
  124. void RemoveSymbol(int64_t key) final;
  125. void SetName(std::string_view new_name) final;
  126. void AddTable(const SymbolTable &table) final;
  127. bool IsMutable() const final { return false; }
  128. };
  129. // Default SymbolTable implementation using DenseSymbolMap and std::map.
  130. // Provides the common text and binary format serialization.
  131. class SymbolTableImpl final : public MutableSymbolTableImpl {
  132. public:
  133. explicit SymbolTableImpl(std::string_view name)
  134. : name_(name),
  135. available_key_(0),
  136. dense_key_limit_(0),
  137. check_sum_finalized_(false) {}
  138. SymbolTableImpl(const SymbolTableImpl &impl)
  139. : name_(impl.name_),
  140. available_key_(impl.available_key_),
  141. dense_key_limit_(impl.dense_key_limit_),
  142. symbols_(impl.symbols_),
  143. idx_key_(impl.idx_key_),
  144. key_map_(impl.key_map_),
  145. check_sum_finalized_(false) {}
  146. std::unique_ptr<SymbolTableImplBase> Copy() const override {
  147. return std::make_unique<SymbolTableImpl>(*this);
  148. }
  149. int64_t AddSymbol(std::string_view symbol, int64_t key) override;
  150. int64_t AddSymbol(std::string_view symbol) override {
  151. return AddSymbol(symbol, available_key_);
  152. }
  153. // Removes the symbol with the given key. The removal is costly
  154. // (O(NumSymbols)) and may reduce the efficiency of Find() because of a
  155. // potentially reduced size of the dense key interval.
  156. void RemoveSymbol(int64_t key) override;
  157. static SymbolTableImpl * ReadText(
  158. std::istream &strm, std::string_view name,
  159. // Characters to be used as a separator between fields in a textual
  160. // `SymbolTable` file, encoded as a string. Each byte in the string is
  161. // considered a valid separator. Multi-byte separators are not permitted.
  162. // The default value, "\t ", accepts space and tab.
  163. const std::string &sep = FST_FLAGS_fst_field_separator);
  164. // Reads a binary SymbolTable from stream, using source in error messages.
  165. static SymbolTableImpl * Read(std::istream &strm,
  166. std::string_view source);
  167. bool Write(std::ostream &strm) const override;
  168. // Returns the string associated with the key. If the key is out of
  169. // range (<0, >max), return an empty string.
  170. std::string Find(int64_t key) const override;
  171. // Returns the key associated with the symbol; if the symbol
  172. // does not exists, returns kNoSymbol.
  173. int64_t Find(std::string_view symbol) const override {
  174. int64_t idx = symbols_.Find(symbol);
  175. if (idx == kNoSymbol || idx < dense_key_limit_) return idx;
  176. return idx_key_[idx - dense_key_limit_];
  177. }
  178. int64_t GetNthKey(ssize_t pos) const override {
  179. if (pos < 0 || static_cast<size_t>(pos) >= symbols_.Size()) {
  180. return kNoSymbol;
  181. } else if (pos < dense_key_limit_) {
  182. return pos;
  183. }
  184. return Find(symbols_.GetSymbol(pos));
  185. }
  186. const std::string &Name() const override { return name_; }
  187. void SetName(std::string_view new_name) override {
  188. name_ = std::string(new_name);
  189. }
  190. const std::string &CheckSum() const override {
  191. MaybeRecomputeCheckSum();
  192. return check_sum_string_;
  193. }
  194. const std::string &LabeledCheckSum() const override {
  195. MaybeRecomputeCheckSum();
  196. return labeled_check_sum_string_;
  197. }
  198. int64_t AvailableKey() const override { return available_key_; }
  199. size_t NumSymbols() const override { return symbols_.Size(); }
  200. void ShrinkToFit();
  201. private:
  202. // Recomputes the checksums (both of them) if we've had changes since the last
  203. // computation (i.e., if check_sum_finalized_ is false).
  204. // Takes ~2.5 microseconds (dbg) or ~230 nanoseconds (opt) on a 2.67GHz Xeon
  205. // if the checksum is up-to-date (requiring no recomputation).
  206. void MaybeRecomputeCheckSum() const;
  207. std::string name_;
  208. int64_t available_key_;
  209. int64_t dense_key_limit_;
  210. DenseSymbolMap symbols_;
  211. // Maps index to key for index >= dense_key_limit:
  212. // key = idx_key_[index - dense_key_limit]
  213. std::vector<int64_t> idx_key_;
  214. // Maps key to index for key >= dense_key_limit_.
  215. // index = key_map_[key]
  216. std::map<int64_t, int64_t> key_map_;
  217. mutable bool check_sum_finalized_;
  218. mutable std::string check_sum_string_;
  219. mutable std::string labeled_check_sum_string_;
  220. mutable Mutex check_sum_mutex_;
  221. };
  222. } // namespace internal
  223. // Symbol (string) to integer (and reverse) mapping.
  224. //
  225. // The SymbolTable implements the mappings of labels to strings and reverse.
  226. // SymbolTables are used to describe the alphabet of the input and output
  227. // labels for arcs in a Finite State Transducer.
  228. //
  229. // SymbolTables are reference-counted and can therefore be shared across
  230. // multiple machines. For example a language model grammar G, with a
  231. // SymbolTable for the words in the language model can share this symbol
  232. // table with the lexical representation L o G.
  233. class SymbolTable {
  234. public:
  235. class iterator {
  236. public:
  237. // TODO(wolfsonkin): Expand `SymbolTable::iterator` to be a random access
  238. // iterator.
  239. using iterator_category = std::input_iterator_tag;
  240. class value_type {
  241. public:
  242. // Return the label of the current symbol.
  243. int64_t Label() const { return key_; }
  244. // Return the string of the current symbol.
  245. // TODO(wolfsonkin): Consider adding caching.
  246. std::string Symbol() const { return table_->Find(key_); }
  247. private:
  248. explicit value_type(const SymbolTable &table, ssize_t pos)
  249. : table_(&table), key_(table.GetNthKey(pos)) {}
  250. // Sets this item to the pos'th element in the symbol table
  251. void SetPosition(ssize_t pos) { key_ = table_->GetNthKey(pos); }
  252. friend class SymbolTable::iterator;
  253. const SymbolTable *table_; // Does not own the underlying SymbolTable.
  254. int64_t key_;
  255. };
  256. using difference_type = std::ptrdiff_t;
  257. using pointer = const value_type *const;
  258. using reference = const value_type &;
  259. iterator &operator++() {
  260. ++pos_;
  261. if (static_cast<size_t>(pos_) < nsymbols_) iter_item_.SetPosition(pos_);
  262. return *this;
  263. }
  264. iterator operator++(int) {
  265. iterator retval = *this;
  266. ++(*this);
  267. return retval;
  268. }
  269. bool operator==(const iterator &that) const { return (pos_ == that.pos_); }
  270. bool operator!=(const iterator &that) const { return !(*this == that); }
  271. reference operator*() { return iter_item_; }
  272. pointer operator->() const { return &iter_item_; }
  273. private:
  274. explicit iterator(const SymbolTable &table, ssize_t pos = 0)
  275. : pos_(pos), nsymbols_(table.NumSymbols()), iter_item_(table, pos) {}
  276. friend class SymbolTable;
  277. ssize_t pos_;
  278. size_t nsymbols_;
  279. value_type iter_item_;
  280. };
  281. using const_iterator = iterator;
  282. // Constructs symbol table with an optional name.
  283. explicit SymbolTable(std::string_view name = "<unspecified>")
  284. : impl_(std::make_shared<internal::SymbolTableImpl>(name)) {}
  285. virtual ~SymbolTable() = default;
  286. // Reads a text representation of the symbol table from an istream. Pass a
  287. // name to give the resulting SymbolTable.
  288. static SymbolTable *ReadText(
  289. std::istream &strm, std::string_view name,
  290. const std::string &sep = FST_FLAGS_fst_field_separator) {
  291. auto impl =
  292. fst::WrapUnique(internal::SymbolTableImpl::ReadText(strm, name, sep));
  293. return impl ? new SymbolTable(std::move(impl)) : nullptr;
  294. }
  295. // Reads a text representation of the symbol table.
  296. static SymbolTable * ReadText(
  297. const std::string &source,
  298. const std::string &sep = FST_FLAGS_fst_field_separator);
  299. // Reads a binary dump of the symbol table from a stream.
  300. static SymbolTable *Read(std::istream &strm, std::string_view source) {
  301. auto impl = fst::WrapUnique(internal::SymbolTableImpl::Read(strm, source));
  302. return impl ? new SymbolTable(std::move(impl)) : nullptr;
  303. }
  304. // Reads a binary dump of the symbol table.
  305. static SymbolTable *Read(const std::string &source) {
  306. std::ifstream strm(source, std::ios_base::in | std::ios_base::binary);
  307. if (!strm.good()) {
  308. LOG(ERROR) << "SymbolTable::Read: Can't open file: " << source;
  309. return nullptr;
  310. }
  311. return Read(strm, source);
  312. }
  313. // Creates a reference counted copy.
  314. virtual SymbolTable *Copy() const { return new SymbolTable(*this); }
  315. // Adds another symbol table to this table. All keys will be offset by the
  316. // current available key (highest key in the symbol table). Note string
  317. // symbols with the same key will still have the same key after the symbol
  318. // table has been merged, but a different value. Adding symbol tables do not
  319. // result in changes in the base table.
  320. void AddTable(const SymbolTable &table) {
  321. MutateCheck();
  322. impl_->AddTable(table);
  323. }
  324. // Adds a symbol with given key to table. A symbol table also keeps track of
  325. // the last available key (highest key value in the symbol table).
  326. int64_t AddSymbol(std::string_view symbol, int64_t key) {
  327. MutateCheck();
  328. return impl_->AddSymbol(symbol, key);
  329. }
  330. // Adds a symbol to the table. The associated value key is automatically
  331. // assigned by the symbol table.
  332. int64_t AddSymbol(std::string_view symbol) {
  333. MutateCheck();
  334. return impl_->AddSymbol(symbol);
  335. }
  336. // Returns the current available key (i.e., highest key + 1) in the symbol
  337. // table.
  338. int64_t AvailableKey() const { return impl_->AvailableKey(); }
  339. // Return the label-agnostic MD5 check-sum for this table. All new symbols
  340. // added to the table will result in an updated checksum.
  341. OPENFST_DEPRECATED("Use `LabeledCheckSum()` instead.")
  342. const std::string &CheckSum() const { return impl_->CheckSum(); }
  343. int64_t GetNthKey(ssize_t pos) const { return impl_->GetNthKey(pos); }
  344. // Returns the string associated with the key; if the key is out of
  345. // range (<0, >max), returns an empty string.
  346. std::string Find(int64_t key) const { return impl_->Find(key); }
  347. // Returns the key associated with the symbol; if the symbol does not exist,
  348. // kNoSymbol is returned.
  349. int64_t Find(std::string_view symbol) const { return impl_->Find(symbol); }
  350. // Same as CheckSum(), but returns an label-dependent version.
  351. const std::string &LabeledCheckSum() const {
  352. return impl_->LabeledCheckSum();
  353. }
  354. bool Member(int64_t key) const { return impl_->Member(key); }
  355. bool Member(std::string_view symbol) const { return impl_->Member(symbol); }
  356. // Returns the name of the symbol table.
  357. const std::string &Name() const { return impl_->Name(); }
  358. // Returns the current number of symbols in table (not necessarily equal to
  359. // AvailableKey()).
  360. size_t NumSymbols() const { return impl_->NumSymbols(); }
  361. void RemoveSymbol(int64_t key) {
  362. MutateCheck();
  363. return impl_->RemoveSymbol(key);
  364. }
  365. // Sets the name of the symbol table.
  366. void SetName(std::string_view new_name) {
  367. MutateCheck();
  368. impl_->SetName(new_name);
  369. }
  370. bool Write(std::ostream &strm) const { return impl_->Write(strm); }
  371. bool Write(const std::string &source) const;
  372. // Dumps a text representation of the symbol table via a stream.
  373. bool WriteText(
  374. std::ostream &strm,
  375. // Characters to be used as a separator between fields in a textual
  376. // `SymbolTable` file, encoded as a string. Each byte in the string is
  377. // considered a valid separator. Multi-byte separators are not permitted.
  378. // The default value, "\t ", outputs tab.
  379. const std::string &sep = FST_FLAGS_fst_field_separator) const;
  380. // Dumps a text representation of the symbol table.
  381. bool WriteText(const std::string &sink,
  382. const std::string &sep = FST_FLAGS_fst_field_separator) const;
  383. const_iterator begin() const { return const_iterator(*this, 0); }
  384. const_iterator end() const { return const_iterator(*this, NumSymbols()); }
  385. const_iterator cbegin() const { return begin(); }
  386. const_iterator cend() const { return end(); }
  387. protected:
  388. explicit SymbolTable(std::shared_ptr<internal::SymbolTableImplBase> impl)
  389. : impl_(std::move(impl)) {}
  390. template <class T = internal::SymbolTableImplBase>
  391. const T *Impl() const {
  392. return down_cast<const T *>(impl_.get());
  393. }
  394. template <class T = internal::SymbolTableImplBase>
  395. T *MutableImpl() {
  396. MutateCheck();
  397. return down_cast<T *>(impl_.get());
  398. }
  399. private:
  400. void MutateCheck() {
  401. if (impl_.unique() || !impl_->IsMutable()) return;
  402. std::unique_ptr<internal::SymbolTableImplBase> copy = impl_->Copy();
  403. CHECK(copy != nullptr);
  404. impl_ = std::move(copy);
  405. }
  406. std::shared_ptr<internal::SymbolTableImplBase> impl_;
  407. };
  408. // Iterator class for symbols in a symbol table.
  409. class OPENFST_DEPRECATED(
  410. "Use SymbolTable::iterator, a C++ compliant iterator, instead")
  411. SymbolTableIterator {
  412. public:
  413. explicit SymbolTableIterator(const SymbolTable &table)
  414. : table_(table), iter_(table.begin()), end_(table.end()) {}
  415. ~SymbolTableIterator() = default;
  416. // Returns whether iterator is done.
  417. bool Done() const { return (iter_ == end_); }
  418. // Return the key of the current symbol.
  419. int64_t Value() const { return iter_->Label(); }
  420. // Return the string of the current symbol.
  421. std::string Symbol() const { return iter_->Symbol(); }
  422. // Advances iterator.
  423. void Next() { ++iter_; }
  424. // Resets iterator.
  425. void Reset() { iter_ = table_.begin(); }
  426. private:
  427. const SymbolTable &table_;
  428. SymbolTable::iterator iter_;
  429. const SymbolTable::iterator end_;
  430. };
  431. // Relabels a symbol table as specified by the input vector of pairs
  432. // (old label, new label). The new symbol table only retains symbols
  433. // for which a relabeling is explicitly specified.
  434. //
  435. // TODO(allauzen): consider adding options to allow for some form of implicit
  436. // identity relabeling.
  437. template <class Label>
  438. SymbolTable *RelabelSymbolTable(
  439. const SymbolTable *table,
  440. const std::vector<std::pair<Label, Label>> &pairs) {
  441. auto new_table = std::make_unique<SymbolTable>(
  442. table->Name().empty() ? std::string()
  443. : (std::string("relabeled_") + table->Name()));
  444. for (const auto &[old_label, new_label] : pairs) {
  445. new_table->AddSymbol(table->Find(old_label), new_label);
  446. }
  447. return new_table.release();
  448. }
// Returns true if the two symbol tables have equal checksums. Passing in
// nullptr for either table always returns true.
// NOTE(review): `warning` presumably controls whether a mismatch is logged;
// the definition is not in this header, so confirm there.
bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
                   bool warning = true);
// Symbol table serialization to and from a std::string.
// NOTE(review): the two functions are presumably inverses of each other, and
// the raw pointer from StringToSymbolTable is presumably owned by the caller.
// Both are defined outside this header; confirm there.
void SymbolTableToString(const SymbolTable *table, std::string *result);
SymbolTable *StringToSymbolTable(const std::string &str);
  456. } // namespace fst
  457. #endif // FST_SYMBOL_TABLE_H_