You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

253 lines
7.0 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Allocators for contiguous arrays of arcs.
  19. #ifndef FST_ARC_ARENA_H_
  20. #define FST_ARC_ARENA_H_
  21. #include <algorithm>
  22. #include <cstddef>
  23. #include <cstdint>
  24. #include <deque>
  25. #include <list>
  26. #include <memory>
  27. #include <utility>
  28. #include <fst/fst.h>
  29. #include <fst/memory.h>
  30. #include <unordered_map>
  31. namespace fst {
  32. // ArcArena is used for fast allocation of contiguous arrays of arcs.
  33. //
  34. // To create an arc array:
  35. // for each state:
  36. // for each arc:
  37. // arena.PushArc(arc);
  38. // // Commits these arcs and returns pointer to them.
  39. // const Arc *arcs = arena.GetArcs();
  40. //
  41. // OR
  42. //
  43. // arena.DropArcs(); // Throws away current arcs, reuse the space.
  44. //
  45. // The arcs returned are guaranteed to be contiguous and the pointer returned
  46. // will never be invalidated until the arena is cleared for reuse.
  47. //
  48. // The contents of the arena can be released with a call to arena.Clear() after
  49. // which the arena will restart with an initial allocation capable of holding at
  50. // least all of the arcs requested in the last usage before Clear() making
  51. // subsequent uses of the Arena more efficient.
  52. //
  53. // The max_retained_size option can limit the amount of arc space requested on
  54. // Clear() to avoid excess growth from intermittent high usage.
  55. template <typename Arc>
  56. class ArcArena {
  57. public:
  58. explicit ArcArena(size_t block_size = 256, size_t max_retained_size = 1e6)
  59. : block_size_(block_size), max_retained_size_(max_retained_size) {
  60. blocks_.emplace_back(MakeSharedBlock(block_size_));
  61. first_block_size_ = block_size_;
  62. total_size_ = block_size_;
  63. arcs_ = blocks_.back().get();
  64. end_ = arcs_ + block_size_;
  65. next_ = arcs_;
  66. }
  67. ArcArena(const ArcArena &copy)
  68. : arcs_(copy.arcs_),
  69. next_(copy.next_),
  70. end_(copy.end_),
  71. block_size_(copy.block_size_),
  72. first_block_size_(copy.first_block_size_),
  73. total_size_(copy.total_size_),
  74. max_retained_size_(copy.max_retained_size_),
  75. blocks_(copy.blocks_) {
  76. NewBlock(block_size_);
  77. }
  78. void ReserveArcs(size_t n) {
  79. if (next_ + n < end_) return;
  80. NewBlock(n);
  81. }
  82. void PushArc(const Arc &arc) {
  83. if (next_ == end_) {
  84. size_t length = next_ - arcs_;
  85. NewBlock(length * 2);
  86. }
  87. *next_ = arc;
  88. ++next_;
  89. }
  90. const Arc *GetArcs() {
  91. const auto *arcs = arcs_;
  92. arcs_ = next_;
  93. return arcs;
  94. }
  95. void DropArcs() { next_ = arcs_; }
  96. size_t Size() { return total_size_; }
  97. void Clear() {
  98. blocks_.resize(1);
  99. if (total_size_ > first_block_size_) {
  100. first_block_size_ = std::min(max_retained_size_, total_size_);
  101. blocks_.back() = MakeSharedBlock(first_block_size_);
  102. }
  103. total_size_ = first_block_size_;
  104. arcs_ = blocks_.back().get();
  105. end_ = arcs_ + first_block_size_;
  106. next_ = arcs_;
  107. }
  108. private:
  109. // Allocates a new block with capacity of at least n or block_size,
  110. // copying incomplete arc sequence from old block to new block.
  111. void NewBlock(size_t n) {
  112. const auto length = next_ - arcs_;
  113. const auto new_block_size = std::max(n, block_size_);
  114. total_size_ += new_block_size;
  115. blocks_.emplace_back(MakeSharedBlock(new_block_size));
  116. std::copy(arcs_, next_, blocks_.back().get());
  117. arcs_ = blocks_.back().get();
  118. next_ = arcs_ + length;
  119. end_ = arcs_ + new_block_size;
  120. }
  121. std::shared_ptr<Arc[]> MakeSharedBlock(size_t size) {
  122. return std::shared_ptr<Arc[]>(new Arc[size]);
  123. }
  124. Arc *arcs_;
  125. Arc *next_;
  126. const Arc *end_;
  127. size_t block_size_;
  128. size_t first_block_size_;
  129. size_t total_size_;
  130. size_t max_retained_size_;
  131. std::list<std::shared_ptr<Arc[]>> blocks_;
  132. };
  133. // ArcArenaStateStore uses a reusable ArcArena to store arc arrays and does not
  134. // require that the Expander call ReserveArcs first.
  135. //
  136. // TODO(tombagby): Make cache type configurable.
  137. // TODO(tombagby): Provide ThreadLocal/Concurrent configuration.
  138. template <class A>
  139. class ArcArenaStateStore {
  140. public:
  141. using Arc = A;
  142. using Weight = typename Arc::Weight;
  143. using StateId = typename Arc::StateId;
  144. class State {
  145. public:
  146. Weight Final() const { return final_weight_; }
  147. size_t NumInputEpsilons() const { return niepsilons_; }
  148. size_t NumOutputEpsilons() const { return noepsilons_; }
  149. size_t NumArcs() const { return narcs_; }
  150. const Arc &GetArc(size_t n) const { return arcs_[n]; }
  151. const Arc *Arcs() const { return arcs_; }
  152. int *MutableRefCount() const { return nullptr; }
  153. private:
  154. State(Weight final_weight, int32_t niepsilons, int32_t noepsilons,
  155. int32_t narcs, const Arc *arcs)
  156. : final_weight_(std::move(final_weight)),
  157. niepsilons_(niepsilons),
  158. noepsilons_(noepsilons),
  159. narcs_(narcs),
  160. arcs_(arcs) {}
  161. Weight final_weight_;
  162. size_t niepsilons_;
  163. size_t noepsilons_;
  164. size_t narcs_;
  165. const Arc *arcs_;
  166. friend class ArcArenaStateStore<Arc>;
  167. };
  168. template <class Expander>
  169. State *FindOrExpand(Expander &expander, StateId state_id) {
  170. const auto &[it, success] = cache_.emplace(state_id, nullptr);
  171. if (!success) return it->second;
  172. // Needs a new state.
  173. StateBuilder builder(&arena_);
  174. expander.Expand(state_id, &builder);
  175. const auto arcs = arena_.GetArcs();
  176. size_t narcs = builder.narcs_;
  177. size_t niepsilons = 0;
  178. size_t noepsilons = 0;
  179. for (size_t i = 0; i < narcs; ++i) {
  180. if (arcs[i].ilabel == 0) ++niepsilons;
  181. if (arcs[i].olabel == 0) ++noepsilons;
  182. }
  183. states_.emplace_back(
  184. State(builder.final_weight_, niepsilons, noepsilons, narcs, arcs));
  185. // Places it in the cache.
  186. auto state = &states_.back();
  187. it->second = state;
  188. return state;
  189. }
  190. State *Find(StateId state_id) const {
  191. auto it = cache_.find(state_id);
  192. return (it == cache_.end()) ? nullptr : it->second;
  193. }
  194. private:
  195. class StateBuilder {
  196. public:
  197. explicit StateBuilder(ArcArena<Arc> *arena)
  198. : arena_(arena), final_weight_(Weight::Zero()), narcs_(0) {}
  199. void SetFinal(Weight weight) { final_weight_ = std::move(weight); }
  200. void ReserveArcs(size_t n) { arena_->ReserveArcs(n); }
  201. void AddArc(const Arc &arc) {
  202. ++narcs_;
  203. arena_->PushArc(arc);
  204. }
  205. private:
  206. friend class ArcArenaStateStore<Arc>;
  207. ArcArena<Arc> *arena_;
  208. Weight final_weight_;
  209. size_t narcs_;
  210. };
  211. std::unordered_map<StateId, State *> cache_;
  212. std::deque<State> states_;
  213. ArcArena<Arc> arena_;
  214. };
  215. } // namespace fst
  216. #endif // FST_ARC_ARENA_H_