You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

242 lines
6.5 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Cache implementations for ExpanderFst.
  19. //
  20. // Expander caches must expose a State type and a FindOrExpand template method:
  21. //
  22. // class ExpanderCache {
  23. // public:
  24. // class State;
  25. //
  26. // template <class Expander>
  27. // State* FindOrExpand(Expander& expander, StateId id) {
  28. // if (id is found in cache) return cached_state;
  29. //
  30. // // Use the provided expander to create a new cached state and cache it.
  31. // expander.Expand(id, &new_state);
  32. // insert new_state into cache;
  33. // return new_state;
  34. // }
  35. // };
  36. //
  37. // Cache implementations must be copyable and assignable. It is up to the
  38. // implementation whether this means it will discard the contents of the cache,
  39. // copy all of the cache, share some of the cache etc. It is *REQUIRED* that the
  40. // copy be "safe", the copy and the original must be usable from concurrent
  41. // threads without accessing any internally shared state.
  42. #ifndef FST_EXPANDER_CACHE_H_
  43. #define FST_EXPANDER_CACHE_H_
  44. #include <cstddef>
  45. #include <deque>
  46. #include <memory>
  47. #include <utility>
  48. #include <vector>
  49. #include <fst/cache.h>
  50. #include <fst/fst.h>
  51. #include <unordered_map>
  52. #include <unordered_map>
  53. namespace fst {
  54. // Stateful allocators can't be used without careful handling in threaded
  55. // contexts, so arbitrary stl allocators aren't supported here.
  // Cached form of a single expanded state: its final weight, its outgoing
  // arcs (held in a std::vector with the default allocator), and running
  // counts of arcs whose input / output label is 0 (epsilon).
  template <class A>
  class SimpleVectorCacheState {
   public:
    using Arc = A;
    using Weight = typename Arc::Weight;
    using StateId = typename Arc::StateId;

    // Puts the state back into its just-constructed condition: no arcs, zero
    // epsilon counts and a final weight of Weight::Zero().
    void Reset() {
      arcs_.clear();
      noepsilons_ = 0;
      niepsilons_ = 0;
      final_weight_ = Weight::Zero();
    }

    // Final weight of the state (Weight::Zero() unless SetFinal was called).
    Weight Final() const { return final_weight_; }

    // Number of stored arcs with ilabel == 0.
    size_t NumInputEpsilons() const { return niepsilons_; }

    // Number of stored arcs with olabel == 0.
    size_t NumOutputEpsilons() const { return noepsilons_; }

    // Number of stored arcs.
    size_t NumArcs() const { return arcs_.size(); }

    // n-th stored arc; n is not bounds-checked.
    const Arc &GetArc(size_t n) const { return arcs_[n]; }

    // Pointer to the first stored arc, or nullptr when there are none.
    const Arc *Arcs() const {
      if (arcs_.empty()) return nullptr;
      return &arcs_.front();
    }

    void SetFinal(Weight weight) { final_weight_ = weight; }

    // Pre-allocates storage for at least n arcs.
    void ReserveArcs(size_t n) { arcs_.reserve(n); }

    // Appends a copy of arc and updates the epsilon counts.
    void AddArc(const Arc &arc) {
      CountEpsilons(arc);
      arcs_.push_back(arc);
    }

    // Appends arc by moving it and updates the epsilon counts.
    void AddArc(Arc &&arc) {
      // The labels are inspected before the arc is moved from.
      CountEpsilons(arc);
      arcs_.push_back(std::move(arc));
    }

    // This state type carries no reference count.
    int *MutableRefCount() const { return nullptr; }

   private:
    // Bumps the epsilon counters for each label of arc that equals 0.
    void CountEpsilons(const Arc &arc) {
      if (arc.ilabel == 0) ++niepsilons_;
      if (arc.olabel == 0) ++noepsilons_;
    }

    Weight final_weight_ = Weight::Zero();
    size_t niepsilons_ = 0;  // Number of input epsilons.
    size_t noepsilons_ = 0;  // Number of output epsilons.
    std::vector<Arc> arcs_;
  };
  93. template <class A>
  94. class NoGcKeepOneExpanderCache {
  95. public:
  96. using Arc = A;
  97. using StateId = typename Arc::StateId;
  98. // Reference-counted state.
  99. class State : public SimpleVectorCacheState<Arc> {
  100. public:
  101. int *MutableRefCount() { return &ref_count_; }
  102. void Reset() {
  103. SimpleVectorCacheState<Arc>::Reset();
  104. ref_count_ = 0;
  105. }
  106. private:
  107. int ref_count_ = 0;
  108. friend class NoGcKeepOneExpanderCache;
  109. };
  110. NoGcKeepOneExpanderCache() : state_(new State) {}
  111. NoGcKeepOneExpanderCache(const NoGcKeepOneExpanderCache &copy)
  112. : state_(new State(*copy.state_)) {}
  113. template <class Expander>
  114. State *FindOrExpand(Expander &expander, StateId state_id) {
  115. if (state_id == state_id_) return state_.get();
  116. if (state_->ref_count_ > 0) cache_[state_id_] = std::move(state_);
  117. state_id_ = state_id;
  118. if (cache_.empty()) {
  119. state_->Reset();
  120. expander.Expand(state_id_, state_.get());
  121. return state_.get();
  122. }
  123. if (auto i = cache_.find(state_id_); i != cache_.end()) {
  124. state_ = std::move(i->second);
  125. }
  126. if (state_ == nullptr) {
  127. state_ = std::make_unique<State>();
  128. expander.Expand(state_id_, state_.get());
  129. }
  130. return state_.get();
  131. }
  132. StateId state_id_ = kNoStateId;
  133. std::unique_ptr<State> state_;
  134. std::unordered_map<StateId, std::unique_ptr<State>> cache_;
  135. };
  136. template <class A>
  137. class HashExpanderCache {
  138. public:
  139. using Arc = A;
  140. using StateId = typename Arc::StateId;
  141. using State = SimpleVectorCacheState<Arc>;
  142. HashExpanderCache(const HashExpanderCache &copy) { *this = copy; }
  143. HashExpanderCache &operator=(const HashExpanderCache &copy) {
  144. for (const auto &[id, state] : copy.cache_) {
  145. cache_[id] = std::make_unique<State>(*state);
  146. }
  147. return *this;
  148. }
  149. ~HashExpanderCache() = default;
  150. template <class Expander>
  151. State *FindOrExpand(Expander &expander, StateId state_id) {
  152. auto [it, inserted] = cache_.emplace(state_id, nullptr);
  153. if (inserted) {
  154. it->second = std::make_unique<State>();
  155. expander.Expand(state_id, it->second.get());
  156. }
  157. return it->second.get();
  158. }
  159. private:
  160. std::unordered_map<StateId, std::unique_ptr<State>> cache_;
  161. };
  162. template <class A>
  163. class VectorExpanderCache {
  164. public:
  165. using Arc = A;
  166. using StateId = typename Arc::StateId;
  167. using State = SimpleVectorCacheState<Arc>;
  168. VectorExpanderCache() : vec_(0, nullptr) {}
  169. VectorExpanderCache(const VectorExpanderCache &copy) { *this = copy; }
  170. VectorExpanderCache &operator=(const VectorExpanderCache &copy) {
  171. vec_.resize(copy.vec_.size());
  172. for (StateId i = 0; i < copy.vec_.size(); ++i) {
  173. const auto *state = copy.vec_[i];
  174. if (state != nullptr) {
  175. states_.emplace_back(*state);
  176. vec_[i] = &states_.back();
  177. }
  178. }
  179. return *this;
  180. }
  181. template <class Expander>
  182. State *FindOrExpand(Expander &expander, StateId state_id) {
  183. if (state_id >= vec_.size()) vec_.resize(state_id + 1);
  184. auto **slot = &vec_[state_id];
  185. if (*slot == nullptr) {
  186. states_.emplace_back();
  187. *slot = &states_.back();
  188. expander.Expand(state_id, *slot);
  189. }
  190. return *slot;
  191. }
  192. private:
  193. std::deque<State> states_;
  194. std::vector<State *> vec_;
  195. };
  196. template <class Expander>
  197. using DefaultExpanderCache = VectorExpanderCache<typename Expander::Arc>;
  198. } // namespace fst
  199. #endif // FST_EXPANDER_CACHE_H_