You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

489 lines
14 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Classes for representing a bijective mapping between an arbitrary entry
  19. // of type T and a signed integral ID.
  20. #ifndef FST_BI_TABLE_H_
  21. #define FST_BI_TABLE_H_
  22. #include <sys/types.h>
  23. #include <cstddef>
  24. #include <cstdint>
  25. #include <deque>
  26. #include <functional>
  27. #include <memory>
  28. #include <type_traits>
  29. #include <unordered_set>
  30. #include <vector>
  31. #include <fst/log.h>
  32. #include <fst/memory.h>
  33. #include <fst/windows_defs.inc>
  34. #include <unordered_map>
  35. #include <unordered_set>
  36. #include <functional>
  37. namespace fst {
  38. // Bitables model bijective mappings between entries of an arbitrary type T and
  39. // a signed integral ID of type I. The IDs are allocated starting from 0 in
  40. // order.
  41. //
  42. // template <class I, class T>
  43. // class BiTable {
  44. // public:
  45. //
  46. // // Required constructors.
  47. // BiTable();
  48. //
  49. // // Looks up integer ID from entry. If it doesn't exist and insert
  50. // // is true, adds it; otherwise, returns -1.
  51. // I FindId(const T &entry, bool insert = true);
  52. //
  53. // // Looks up entry from integer ID.
  54. // const T &FindEntry(I) const;
  55. //
  56. // // Returns number of stored entries.
  57. // I Size() const;
  58. // };
  59. // An implementation using a hash map for the entry to ID mapping. H is the
  60. // hash function and E is the equality function.
  61. template <class I, class T, class H = std::hash<T>, class E = std::equal_to<T>>
  62. class HashBiTable {
  63. public:
  64. // Reserves space for table_size elements.
  65. explicit HashBiTable(size_t table_size = 0, const H &h = H(),
  66. const E &e = E())
  67. : hash_func_(h),
  68. hash_equal_(e),
  69. entry2id_(table_size, hash_func_, hash_equal_) {
  70. if (table_size) id2entry_.reserve(table_size);
  71. }
  72. HashBiTable(const HashBiTable<I, T, H, E> &table)
  73. : hash_func_(table.hash_func_),
  74. hash_equal_(table.hash_equal_),
  75. entry2id_(table.entry2id_.begin(), table.entry2id_.end(),
  76. table.entry2id_.size(), hash_func_, hash_equal_),
  77. id2entry_(table.id2entry_) {}
  78. I FindId(const T &entry, bool insert = true) {
  79. if (!insert) {
  80. const auto it = entry2id_.find(entry);
  81. return it == entry2id_.end() ? -1 : it->second - 1;
  82. }
  83. I &id_ref = entry2id_[entry];
  84. if (id_ref == 0) { // T not found; stores and assigns a new ID.
  85. id2entry_.push_back(entry);
  86. id_ref = id2entry_.size();
  87. }
  88. return id_ref - 1; // NB: id_ref = ID + 1.
  89. }
  90. const T &FindEntry(I s) const { return id2entry_[s]; }
  91. I Size() const { return id2entry_.size(); }
  92. // TODO(riley): Add fancy clear-to-size, as in CompactHashBiTable.
  93. void Clear() {
  94. entry2id_.clear();
  95. id2entry_.clear();
  96. }
  97. private:
  98. H hash_func_;
  99. E hash_equal_;
  100. std::unordered_map<T, I, H, E> entry2id_;
  101. std::vector<T> id2entry_;
  102. };
  103. // Enables alternative hash set representations below.
  104. enum HSType { HS_STL, HS_FLAT };
  105. // Default hash set is STL hash_set.
  106. template <class K, class H, class E, HSType HS>
  107. struct HashSet : public std::unordered_set<K, H, E, PoolAllocator<K>> {
  108. private:
  109. using Base = std::unordered_set<K, H, E, PoolAllocator<K>>;
  110. public:
  111. using Base::Base;
  112. void rehash(size_t n) {}
  113. };
  114. // An implementation using a hash set for the entry to ID mapping. The hash set
  115. // holds keys which are either the ID or kCurrentKey. These keys can be mapped
  116. // to entries either by looking up in the entry vector or, if kCurrentKey, in
  117. // current_entry_. The hash and key equality functions map to entries first. H
  118. // is the hash function and E is the equality function.
  119. template <class I, class T, class H = std::hash<T>, class E = std::equal_to<T>,
  120. HSType HS = HS_FLAT>
  121. class CompactHashBiTable {
  122. static_assert(HS == HS_STL || HS == HS_FLAT, "Unsupported hash set type");
  123. public:
  124. friend class HashFunc;
  125. friend class HashEqual;
  126. // Reserves space for table_size elements.
  127. explicit CompactHashBiTable(size_t table_size = 0, const H &h = H(),
  128. const E &e = E())
  129. : hash_func_(h),
  130. hash_equal_(e),
  131. compact_hash_func_(*this),
  132. compact_hash_equal_(*this),
  133. keys_(table_size, compact_hash_func_, compact_hash_equal_) {
  134. if (table_size) id2entry_.reserve(table_size);
  135. }
  136. CompactHashBiTable(const CompactHashBiTable &table)
  137. : hash_func_(table.hash_func_),
  138. hash_equal_(table.hash_equal_),
  139. compact_hash_func_(*this),
  140. compact_hash_equal_(*this),
  141. id2entry_(table.id2entry_),
  142. keys_(table.keys_.begin(), table.keys_.end(), table.keys_.size(),
  143. compact_hash_func_, compact_hash_equal_) {}
  144. I FindId(const T &entry, bool insert = true) {
  145. current_entry_ = &entry;
  146. if (insert) {
  147. auto [iter, was_inserted] = keys_.insert(kCurrentKey);
  148. if (!was_inserted) return *iter; // Already exists.
  149. // Overwrites kCurrentKey with a new key value; this is safe because it
  150. // doesn't affect hashing or equality testing.
  151. I key = id2entry_.size();
  152. const_cast<I &>(*iter) = key;
  153. id2entry_.push_back(entry);
  154. return key;
  155. }
  156. const auto it = keys_.find(kCurrentKey);
  157. return it == keys_.end() ? -1 : *it;
  158. }
  159. const T &FindEntry(I s) const { return id2entry_[s]; }
  160. I Size() const { return id2entry_.size(); }
  161. // Clears content; with argument, erases last n IDs.
  162. void Clear(ssize_t n = -1) {
  163. if (n < 0 || n >= id2entry_.size()) { // Clears completely.
  164. keys_.clear();
  165. id2entry_.clear();
  166. } else if (n == id2entry_.size() - 1) { // Leaves only key 0.
  167. const T entry = FindEntry(0);
  168. keys_.clear();
  169. id2entry_.clear();
  170. FindId(entry, true);
  171. } else {
  172. while (n-- > 0) {
  173. I key = id2entry_.size() - 1;
  174. keys_.erase(key);
  175. id2entry_.pop_back();
  176. }
  177. keys_.rehash(0);
  178. }
  179. }
  180. private:
  181. static_assert(std::is_signed_v<I>, "I must be a signed type");
  182. // ... otherwise >= kCurrentKey comparisons as used below don't work.
  183. // TODO(rybach): (1) don't use >= for key comparison, (2) allow unsigned key
  184. // types.
  185. static constexpr I kCurrentKey = -1;
  186. class HashFunc {
  187. public:
  188. explicit HashFunc(const CompactHashBiTable &ht) : ht_(&ht) {}
  189. size_t operator()(I k) const {
  190. if (k >= kCurrentKey) {
  191. return (ht_->hash_func_)(ht_->Key2Entry(k));
  192. } else {
  193. return 0;
  194. }
  195. }
  196. private:
  197. const CompactHashBiTable *ht_;
  198. };
  199. class HashEqual {
  200. public:
  201. explicit HashEqual(const CompactHashBiTable &ht) : ht_(&ht) {}
  202. bool operator()(I k1, I k2) const {
  203. if (k1 == k2) {
  204. return true;
  205. } else if (k1 >= kCurrentKey && k2 >= kCurrentKey) {
  206. return (ht_->hash_equal_)(ht_->Key2Entry(k1), ht_->Key2Entry(k2));
  207. } else {
  208. return false;
  209. }
  210. }
  211. private:
  212. const CompactHashBiTable *ht_;
  213. };
  214. using KeyHashSet = HashSet<I, HashFunc, HashEqual, HS>;
  215. const T &Key2Entry(I k) const {
  216. if (k == kCurrentKey) {
  217. return *current_entry_;
  218. } else {
  219. return id2entry_[k];
  220. }
  221. }
  222. H hash_func_;
  223. E hash_equal_;
  224. HashFunc compact_hash_func_;
  225. HashEqual compact_hash_equal_;
  226. std::vector<T> id2entry_;
  227. KeyHashSet keys_;
  228. const T *current_entry_;
  229. };
  230. // An implementation using a vector for the entry to ID mapping. It is passed a
  231. // function object FP that should fingerprint entries uniquely to an integer
  232. // that can used as a vector index. Normally, VectorBiTable constructs the FP
  233. // object. The user can instead pass in this object.
  234. template <class I, class T, class FP>
  235. class VectorBiTable {
  236. public:
  237. // Reserves table_size cells of space.
  238. explicit VectorBiTable(const FP &fp = FP(), size_t table_size = 0) : fp_(fp) {
  239. if (table_size) id2entry_.reserve(table_size);
  240. }
  241. VectorBiTable(const VectorBiTable<I, T, FP> &table)
  242. : fp_(table.fp_), fp2id_(table.fp2id_), id2entry_(table.id2entry_) {}
  243. I FindId(const T &entry, bool insert = true) {
  244. ssize_t fp = (fp_)(entry);
  245. if (fp >= fp2id_.size()) fp2id_.resize(fp + 1);
  246. I &id_ref = fp2id_[fp];
  247. if (id_ref == 0) { // T not found.
  248. if (insert) { // Stores and assigns a new ID.
  249. id2entry_.push_back(entry);
  250. id_ref = id2entry_.size();
  251. } else {
  252. return -1;
  253. }
  254. }
  255. return id_ref - 1; // NB: id_ref = ID + 1.
  256. }
  257. const T &FindEntry(I s) const { return id2entry_[s]; }
  258. I Size() const { return id2entry_.size(); }
  259. const FP &Fingerprint() const { return fp_; }
  260. private:
  261. FP fp_;
  262. std::vector<I> fp2id_;
  263. std::vector<T> id2entry_;
  264. };
  265. // An implementation using a vector and a compact hash table. The selecting
  266. // functor S returns true for entries to be hashed in the vector. The
  267. // fingerprinting functor FP returns a unique fingerprint for each entry to be
  268. // hashed in the vector (these need to be suitable for indexing in a vector).
  269. // The hash functor H is used when hashing entry into the compact hash table.
  270. template <class I, class T, class S, class FP, class H = std::hash<T>,
  271. HSType HS = HS_FLAT>
  272. class VectorHashBiTable {
  273. public:
  274. friend class HashFunc;
  275. friend class HashEqual;
  276. explicit VectorHashBiTable(const S &s = S(), const FP &fp = FP(),
  277. const H &h = H(), size_t vector_size = 0,
  278. size_t entry_size = 0)
  279. : selector_(s),
  280. fp_(fp),
  281. h_(h),
  282. hash_func_(*this),
  283. hash_equal_(*this),
  284. keys_(0, hash_func_, hash_equal_) {
  285. if (vector_size) fp2id_.reserve(vector_size);
  286. if (entry_size) id2entry_.reserve(entry_size);
  287. }
  288. VectorHashBiTable(const VectorHashBiTable<I, T, S, FP, H, HS> &table)
  289. : selector_(table.s_),
  290. fp_(table.fp_),
  291. h_(table.h_),
  292. id2entry_(table.id2entry_),
  293. fp2id_(table.fp2id_),
  294. hash_func_(*this),
  295. hash_equal_(*this),
  296. keys_(table.keys_.size(), hash_func_, hash_equal_) {
  297. keys_.insert(table.keys_.begin(), table.keys_.end());
  298. }
  299. I FindId(const T &entry, bool insert = true) {
  300. if ((selector_)(entry)) { // Uses the vector if selector_(entry) == true.
  301. uint64_t fp = (fp_)(entry);
  302. if (fp2id_.size() <= fp) fp2id_.resize(fp + 1, 0);
  303. if (fp2id_[fp] == 0) { // T not found.
  304. if (insert) { // Stores and assigns a new ID.
  305. id2entry_.push_back(entry);
  306. fp2id_[fp] = id2entry_.size();
  307. } else {
  308. return -1;
  309. }
  310. }
  311. return fp2id_[fp] - 1; // NB: assoc_value = ID + 1.
  312. } else { // Uses the hash table otherwise.
  313. current_entry_ = &entry;
  314. if (const auto it = keys_.find(kCurrentKey); it != keys_.end()) {
  315. return *it;
  316. } else {
  317. if (insert) {
  318. I key = id2entry_.size();
  319. id2entry_.push_back(entry);
  320. keys_.insert(key);
  321. return key;
  322. } else {
  323. return -1;
  324. }
  325. }
  326. }
  327. }
  328. const T &FindEntry(I s) const { return id2entry_[s]; }
  329. I Size() const { return id2entry_.size(); }
  330. const S &Selector() const { return selector_; }
  331. const FP &Fingerprint() const { return fp_; }
  332. const H &HashFunction() const { return h_; }
  333. private:
  334. static constexpr I kCurrentKey = -1;
  335. class HashFunc {
  336. public:
  337. explicit HashFunc(const VectorHashBiTable &ht) : ht_(&ht) {}
  338. size_t operator()(I k) const {
  339. if (k >= kCurrentKey) {
  340. return (ht_->h_)(ht_->Key2Entry(k));
  341. } else {
  342. return 0;
  343. }
  344. }
  345. private:
  346. const VectorHashBiTable *ht_;
  347. };
  348. class HashEqual {
  349. public:
  350. explicit HashEqual(const VectorHashBiTable &ht) : ht_(&ht) {}
  351. bool operator()(I k1, I k2) const {
  352. if (k1 >= kCurrentKey && k2 >= kCurrentKey) {
  353. return ht_->Key2Entry(k1) == ht_->Key2Entry(k2);
  354. } else {
  355. return k1 == k2;
  356. }
  357. }
  358. private:
  359. const VectorHashBiTable *ht_;
  360. };
  361. using KeyHashSet = HashSet<I, HashFunc, HashEqual, HS>;
  362. const T &Key2Entry(I k) const {
  363. if (k == kCurrentKey) {
  364. return *current_entry_;
  365. } else {
  366. return id2entry_[k];
  367. }
  368. }
  369. S selector_; // True if entry hashed into vector.
  370. FP fp_; // Fingerprint used for hashing into vector.
  371. H h_; // Hash funcion used for hashing into hash_set.
  372. std::vector<T> id2entry_; // Maps state IDs to entry.
  373. std::vector<I> fp2id_; // Maps entry fingerprints to IDs.
  374. // Compact implementation of the hash table mapping entries to state IDs
  375. // using the hash function h_.
  376. HashFunc hash_func_;
  377. HashEqual hash_equal_;
  378. KeyHashSet keys_;
  379. const T *current_entry_;
  380. };
  381. // An implementation using a hash map for the entry to ID mapping. This version
  382. // permits erasing of arbitrary states. The entry T must have == defined and
  383. // its default constructor must produce a entry that will never be seen. F is
  384. // the hash function.
  385. template <class I, class T, class F>
  386. class ErasableBiTable {
  387. public:
  388. ErasableBiTable() : first_(0) {}
  389. I FindId(const T &entry, bool insert = true) {
  390. I &id_ref = entry2id_[entry];
  391. if (id_ref == 0) { // T not found.
  392. if (insert) { // Stores and assigns a new ID.
  393. id2entry_.push_back(entry);
  394. id_ref = id2entry_.size() + first_;
  395. } else {
  396. return -1;
  397. }
  398. }
  399. return id_ref - 1; // NB: id_ref = ID + 1.
  400. }
  401. const T &FindEntry(I s) const { return id2entry_[s - first_]; }
  402. I Size() const { return id2entry_.size(); }
  403. void Erase(I s) {
  404. auto &ref = id2entry_[s - first_];
  405. entry2id_.erase(ref);
  406. ref = empty_entry_;
  407. while (!id2entry_.empty() && id2entry_.front() == empty_entry_) {
  408. id2entry_.pop_front();
  409. ++first_;
  410. }
  411. }
  412. private:
  413. std::unordered_map<T, I, F> entry2id_;
  414. std::deque<T> id2entry_;
  415. const T empty_entry_;
  416. I first_; // I of first element in the deque.
  417. };
  418. } // namespace fst
  419. #endif // FST_BI_TABLE_H_