You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1337 lines
42 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // An FST implementation that caches FST elements of a delayed computation.
  19. #ifndef FST_CACHE_H_
  20. #define FST_CACHE_H_
  21. #include <algorithm>
  22. #include <cstddef>
  23. #include <cstdint>
  24. #include <functional>
  25. #include <list>
  26. #include <memory>
  27. #include <new>
  28. #include <utility>
  29. #include <vector>
  30. #include <fst/flags.h>
  31. #include <fst/log.h>
  32. #include <fst/fst.h>
  33. #include <fst/memory.h>
  34. #include <fst/mutable-fst.h>
  35. #include <fst/properties.h>
  36. #include <fst/util.h>
  37. #include <fst/vector-fst.h>
  38. #include <unordered_map>
  39. #include <functional>
  40. DECLARE_bool(fst_default_cache_gc);
  41. DECLARE_int64(fst_default_cache_gc_limit);
  42. namespace fst {
  43. // Options for controlling caching behavior; higher level than CacheImplOptions.
  44. struct CacheOptions {
  45. bool gc; // Enables GC.
  46. size_t gc_limit; // Number of bytes allowed before GC.
  47. explicit CacheOptions(
  48. bool gc = FST_FLAGS_fst_default_cache_gc,
  49. size_t gc_limit = FST_FLAGS_fst_default_cache_gc_limit)
  50. : gc(gc), gc_limit(gc_limit) {}
  51. };
  52. // Options for controlling caching behavior, at a lower level than
  53. // CacheOptions; templated on the cache store and allows passing the store.
  54. template <class CacheStore>
  55. struct CacheImplOptions {
  56. bool gc; // Enables GC.
  57. size_t gc_limit; // Number of bytes allowed before GC.
  58. CacheStore *store; // Cache store.
  59. bool own_store; // Should CacheImpl takes ownership of the store?
  60. explicit CacheImplOptions(
  61. bool gc = FST_FLAGS_fst_default_cache_gc,
  62. size_t gc_limit = FST_FLAGS_fst_default_cache_gc_limit,
  63. CacheStore *store = nullptr)
  64. : gc(gc), gc_limit(gc_limit), store(store), own_store(true) {}
  65. explicit CacheImplOptions(const CacheOptions &opts)
  66. : gc(opts.gc), gc_limit(opts.gc_limit), store(nullptr), own_store(true) {}
  67. };
  68. // Cache flags.
  69. inline constexpr uint8_t kCacheFinal = 0x01; // Final weight has been cached.
  70. inline constexpr uint8_t kCacheArcs = 0x02; // Arcs have been cached.
  71. inline constexpr uint8_t kCacheInit = 0x04; // Initialized by GC.
  72. inline constexpr uint8_t kCacheRecent = 0x08; // Visited since GC.
  73. inline constexpr uint8_t kCacheFlags =
  74. kCacheFinal | kCacheArcs | kCacheInit | kCacheRecent;
  75. // Cache state, with arcs stored in a per-state std::vector.
  76. template <class A, class M = PoolAllocator<A>>
  77. class CacheState {
  78. public:
  79. using Arc = A;
  80. using Label = typename Arc::Label;
  81. using StateId = typename Arc::StateId;
  82. using Weight = typename Arc::Weight;
  83. using ArcAllocator = M;
  84. using StateAllocator = typename std::allocator_traits<
  85. ArcAllocator>::template rebind_alloc<CacheState<A, M>>;
  86. // Provides STL allocator for arcs.
  87. explicit CacheState(const ArcAllocator &alloc)
  88. : final_weight_(Weight::Zero()),
  89. niepsilons_(0),
  90. noepsilons_(0),
  91. arcs_(alloc),
  92. flags_(0),
  93. ref_count_(0) {}
  94. CacheState(const CacheState<A> &state, const ArcAllocator &alloc)
  95. : final_weight_(state.Final()),
  96. niepsilons_(state.NumInputEpsilons()),
  97. noepsilons_(state.NumOutputEpsilons()),
  98. arcs_(state.arcs_.begin(), state.arcs_.end(), alloc),
  99. flags_(state.Flags()),
  100. ref_count_(0) {}
  101. void Reset() {
  102. final_weight_ = Weight::Zero();
  103. niepsilons_ = 0;
  104. noepsilons_ = 0;
  105. ref_count_ = 0;
  106. flags_ = 0;
  107. arcs_.clear();
  108. }
  109. Weight Final() const { return final_weight_; }
  110. size_t NumInputEpsilons() const { return niepsilons_; }
  111. size_t NumOutputEpsilons() const { return noepsilons_; }
  112. size_t NumArcs() const { return arcs_.size(); }
  113. const Arc &GetArc(size_t n) const { return arcs_[n]; }
  114. // Used by the ArcIterator<Fst<Arc>> efficient implementation.
  115. const Arc *Arcs() const { return !arcs_.empty() ? &arcs_[0] : nullptr; }
  116. // Accesses flags; used by the caller.
  117. uint8_t Flags() const { return flags_; }
  118. // Accesses ref count; used by the caller.
  119. int RefCount() const { return ref_count_; }
  120. void SetFinal(Weight weight = Weight::One()) {
  121. final_weight_ = std::move(weight);
  122. }
  123. void ReserveArcs(size_t n) { arcs_.reserve(n); }
  124. // Adds one arc at a time with all needed book-keeping; use PushArc and
  125. // SetArcs for a more efficient alternative.
  126. void AddArc(const Arc &arc) {
  127. IncrementNumEpsilons(arc);
  128. arcs_.push_back(arc);
  129. }
  130. void AddArc(Arc &&arc) {
  131. IncrementNumEpsilons(arc);
  132. arcs_.push_back(std::move(arc));
  133. }
  134. // Adds one arc at a time with delayed book-keeping; finalize with SetArcs().
  135. void PushArc(const Arc &arc) { arcs_.push_back(arc); }
  136. void PushArc(Arc &&arc) { arcs_.push_back(std::move(arc)); }
  137. // Adds one arc at a time with delayed book-keeping; finalize with SetArcs().
  138. template <class... T>
  139. void EmplaceArc(T &&...ctor_args) {
  140. arcs_.emplace_back(std::forward<T>(ctor_args)...);
  141. }
  142. // Finalizes arcs book-keeping; call only once.
  143. void SetArcs() {
  144. for (const auto &arc : arcs_) {
  145. IncrementNumEpsilons(arc);
  146. }
  147. }
  148. // Modifies nth arc.
  149. void SetArc(const Arc &arc, size_t n) {
  150. if (arcs_[n].ilabel == 0) --niepsilons_;
  151. if (arcs_[n].olabel == 0) --noepsilons_;
  152. IncrementNumEpsilons(arc);
  153. arcs_[n] = arc;
  154. }
  155. // Deletes all arcs.
  156. void DeleteArcs() {
  157. niepsilons_ = 0;
  158. noepsilons_ = 0;
  159. arcs_.clear();
  160. }
  161. void DeleteArcs(size_t n) {
  162. for (size_t i = 0; i < n; ++i) {
  163. if (arcs_.back().ilabel == 0) --niepsilons_;
  164. if (arcs_.back().olabel == 0) --noepsilons_;
  165. arcs_.pop_back();
  166. }
  167. }
  168. // Sets status flags; used by the caller.
  169. void SetFlags(uint8_t flags, uint8_t mask) const {
  170. flags_ &= ~mask;
  171. flags_ |= flags;
  172. }
  173. // Mutates reference counts; used by the caller.
  174. int IncrRefCount() const { return ++ref_count_; }
  175. int DecrRefCount() const { return --ref_count_; }
  176. // Used by the ArcIterator<Fst<Arc>> efficient implementation.
  177. int *MutableRefCount() const { return &ref_count_; }
  178. // Used for state class allocation.
  179. void *operator new(size_t size, StateAllocator *alloc) {
  180. return alloc->allocate(1);
  181. }
  182. // For state destruction and memory freeing.
  183. static void Destroy(CacheState<Arc> *state, StateAllocator *alloc) {
  184. if (state) {
  185. state->~CacheState<Arc>();
  186. alloc->deallocate(state, 1);
  187. }
  188. }
  189. private:
  190. // Update the number of epsilons as a result of having added an arc.
  191. void IncrementNumEpsilons(const Arc &arc) {
  192. if (arc.ilabel == 0) ++niepsilons_;
  193. if (arc.olabel == 0) ++noepsilons_;
  194. }
  195. Weight final_weight_; // Final weight.
  196. size_t niepsilons_; // # of input epsilons.
  197. size_t noepsilons_; // # of output epsilons.
  198. std::vector<Arc, ArcAllocator> arcs_; // Arcs representation.
  199. mutable uint8_t flags_;
  200. mutable int ref_count_; // If 0, available for GC.
  201. };
  202. // Cache store, allocating and storing states, providing a mapping from state
  203. // IDs to cached states, and an iterator over these states. The state template
  204. // argument must implement the CacheState interface. The state for a StateId s
  205. // is constructed when requested by GetMutableState(s) if it is not yet stored.
  206. // Initially, a state has a reference count of zero, but the user may increment
  207. // or decrement this to control the time of destruction. In particular, a state
  208. // is destroyed when:
  209. //
  210. // 1. This instance is destroyed, or
  211. // 2. Clear() or Delete() is called, or
  212. // 3. Possibly (implementation-dependently) when:
  213. // - Garbage collection is enabled (as defined by opts.gc),
  214. // - The cache store size exceeds the limits (as defined by opts.gc_limits),
  215. // - The state's reference count is zero, and
  216. // - The state is not the most recently requested state.
  217. //
  218. // template <class S>
  219. // class CacheStore {
  220. // public:
  221. // using State = S;
  222. // using Arc = typename State::Arc;
  223. // using StateId = typename Arc::StateId;
  224. //
  225. // // Required constructors/assignment operators.
  226. // explicit CacheStore(const CacheOptions &opts);
  227. //
  228. // // Returns nullptr if state is not stored.
  229. // const State *GetState(StateId s);
  230. //
  231. // // Creates state if state is not stored.
  232. // State *GetMutableState(StateId s);
  233. //
  234. // // Similar to State::AddArc() but updates cache store book-keeping.
  235. // void AddArc(State *state, const Arc &arc);
  236. //
  237. // // Similar to State::SetArcs() but updates cache store book-keeping; call
  238. // // only once.
  239. // void SetArcs(State *state);
  240. //
  241. // // Similar to State::DeleteArcs() but updates cache store book-keeping.
  242. //
  243. // void DeleteArcs(State *state);
  244. //
  245. // void DeleteArcs(State *state, size_t n);
  246. //
  247. // // Deletes all cached states.
  248. // void Clear();
  249. //
  250. // // Number of cached states.
  251. // StateId CountStates();
  252. //
  253. // // Iterates over cached states (in an arbitrary order); only needed if
  254. // // opts.gc is true.
  255. // bool Done() const; // End of iteration.
  256. // StateId Value() const; // Current state.
  257. // void Next(); // Advances to next state (when !Done).
  258. // void Reset(); // Returns to initial condition.
  259. // void Delete(); // Deletes current state and advances to next.
  260. // };
  261. // Container cache stores.
  262. // This class uses a vector of pointers to states to store cached states.
  263. template <class S>
  264. class VectorCacheStore {
  265. public:
  266. using State = S;
  267. using Arc = typename State::Arc;
  268. using StateId = typename Arc::StateId;
  269. using StateList = std::list<StateId, PoolAllocator<StateId>>;
  270. // Required constructors/assignment operators.
  271. explicit VectorCacheStore(const CacheOptions &opts) : cache_gc_(opts.gc) {
  272. Clear();
  273. Reset();
  274. }
  275. VectorCacheStore(const VectorCacheStore<S> &store)
  276. : cache_gc_(store.cache_gc_) {
  277. CopyStates(store);
  278. Reset();
  279. }
  280. ~VectorCacheStore() { Clear(); }
  281. VectorCacheStore &operator=(const VectorCacheStore &store) {
  282. if (this != &store) {
  283. CopyStates(store);
  284. Reset();
  285. }
  286. return *this;
  287. }
  288. bool InBounds(StateId s) const {
  289. return s < static_cast<StateId>(state_vec_.size());
  290. }
  291. // Returns nullptr if state is not stored.
  292. const State *GetState(StateId s) const {
  293. return InBounds(s) ? state_vec_[s] : nullptr;
  294. }
  295. // Creates state if state is not stored.
  296. State *GetMutableState(StateId s) {
  297. State *state = nullptr;
  298. if (InBounds(s)) {
  299. state = state_vec_[s];
  300. } else {
  301. state_vec_.resize(s + 1, nullptr);
  302. }
  303. if (!state) {
  304. state = new (&state_alloc_) State(arc_alloc_);
  305. state_vec_[s] = state;
  306. if (cache_gc_) state_list_.push_back(s);
  307. }
  308. return state;
  309. }
  310. // Similar to State::AddArc() but updates cache store book-keeping
  311. void AddArc(State *state, const Arc &arc) { state->AddArc(arc); }
  312. // Similar to State::SetArcs() but updates cache store book-keeping; call
  313. // only once.
  314. void SetArcs(State *state) { state->SetArcs(); }
  315. // Deletes all arcs.
  316. void DeleteArcs(State *state) { state->DeleteArcs(); }
  317. // Deletes some arcs.
  318. void DeleteArcs(State *state, size_t n) { state->DeleteArcs(n); }
  319. // Deletes all cached states.
  320. void Clear() {
  321. for (State *s : state_vec_) {
  322. State::Destroy(s, &state_alloc_);
  323. }
  324. state_vec_.clear();
  325. state_list_.clear();
  326. }
  327. StateId CountStates() const {
  328. return std::count_if(state_vec_.begin(), state_vec_.end(),
  329. [](const State *s) { return s != nullptr; });
  330. }
  331. // Iterates over cached states (in an arbitrary order); only works if GC is
  332. // enabled (o.w. avoiding state_list_ overhead).
  333. bool Done() const { return iter_ == state_list_.end(); }
  334. StateId Value() const { return *iter_; }
  335. void Next() { ++iter_; }
  336. void Reset() { iter_ = state_list_.begin(); }
  337. // Deletes current state and advances to next.
  338. void Delete() {
  339. State::Destroy(state_vec_[*iter_], &state_alloc_);
  340. state_vec_[*iter_] = nullptr;
  341. state_list_.erase(iter_++);
  342. }
  343. private:
  344. void CopyStates(const VectorCacheStore<State> &store) {
  345. Clear();
  346. state_vec_.reserve(store.state_vec_.size());
  347. for (size_t s = 0; s < store.state_vec_.size(); ++s) {
  348. State *state = nullptr;
  349. const auto *store_state = store.state_vec_[s];
  350. if (store_state) {
  351. state = new (&state_alloc_) State(*store_state, arc_alloc_);
  352. if (cache_gc_) state_list_.push_back(s);
  353. }
  354. state_vec_.push_back(state);
  355. }
  356. }
  357. bool cache_gc_; // Supports iteration when true.
  358. std::vector<State *> state_vec_; // Vector of states (or null).
  359. StateList state_list_; // List of states.
  360. typename StateList::iterator iter_; // State list iterator.
  361. typename State::StateAllocator state_alloc_; // For state allocation.
  362. typename State::ArcAllocator arc_alloc_; // For arc allocation.
  363. };
  364. // This class uses a hash map from state IDs to pointers to cached states.
  365. template <class S>
  366. class HashCacheStore {
  367. public:
  368. using State = S;
  369. using Arc = typename State::Arc;
  370. using StateId = typename Arc::StateId;
  371. using StateMap =
  372. std::unordered_map<StateId, State *, std::hash<StateId>,
  373. std::equal_to<StateId>,
  374. PoolAllocator<std::pair<const StateId, State *>>>;
  375. // Required constructors/assignment operators.
  376. explicit HashCacheStore(const CacheOptions &opts) {
  377. Clear();
  378. Reset();
  379. }
  380. HashCacheStore(const HashCacheStore<S> &store) {
  381. CopyStates(store);
  382. Reset();
  383. }
  384. ~HashCacheStore() { Clear(); }
  385. HashCacheStore &operator=(const HashCacheStore &store) {
  386. if (this != &store) {
  387. CopyStates(store);
  388. Reset();
  389. }
  390. return *this;
  391. }
  392. // Returns nullptr if state is not stored.
  393. const State *GetState(StateId s) const {
  394. const auto it = state_map_.find(s);
  395. return it != state_map_.end() ? it->second : nullptr;
  396. }
  397. // Creates state if state is not stored.
  398. State *GetMutableState(StateId s) {
  399. auto *&state = state_map_[s];
  400. if (!state) state = new (&state_alloc_) State(arc_alloc_);
  401. return state;
  402. }
  403. // Similar to State::AddArc() but updates cache store book-keeping.
  404. void AddArc(State *state, const Arc &arc) { state->AddArc(arc); }
  405. // Similar to State::SetArcs() but updates internal cache size; call only
  406. // once.
  407. void SetArcs(State *state) { state->SetArcs(); }
  408. // Deletes all arcs.
  409. void DeleteArcs(State *state) { state->DeleteArcs(); }
  410. // Deletes some arcs.
  411. void DeleteArcs(State *state, size_t n) { state->DeleteArcs(n); }
  412. // Deletes all cached states.
  413. void Clear() {
  414. for (auto &[unused_state_id, state_ptr] : state_map_) {
  415. State::Destroy(state_ptr, &state_alloc_);
  416. }
  417. state_map_.clear();
  418. }
  419. StateId CountStates() const { return state_map_.size(); }
  420. // Iterates over cached states (in an arbitrary order).
  421. bool Done() const { return iter_ == state_map_.end(); }
  422. StateId Value() const { return iter_->first; }
  423. void Next() { ++iter_; }
  424. void Reset() { iter_ = state_map_.begin(); }
  425. // Deletes current state and advances to next.
  426. void Delete() {
  427. State::Destroy(iter_->second, &state_alloc_);
  428. state_map_.erase(iter_++);
  429. }
  430. private:
  431. void CopyStates(const HashCacheStore<State> &store) {
  432. Clear();
  433. for (auto &[state_id, state_ptr] : store.state_map_) {
  434. state_map_[state_id] = new (&state_alloc_) State(*state_ptr, arc_alloc_);
  435. }
  436. }
  437. StateMap state_map_; // Map from state ID to state.
  438. typename StateMap::iterator iter_; // State map iterator.
  439. typename State::StateAllocator state_alloc_; // For state allocation.
  440. typename State::ArcAllocator arc_alloc_; // For arc allocation.
  441. };
  442. // Garbage-colllection cache stores.
  443. // This class implements a simple garbage collection scheme when
  444. // 'opts.gc_limit = 0'. In particular, the first cached state is reused for each
  445. // new state so long as the reference count is zero on the to-be-reused state.
  446. // Otherwise, the full underlying store is used. The caller can increment the
  447. // reference count to inhibit the GC of in-use states (e.g., in an ArcIterator).
  448. //
  449. // The typical use case for this optimization is when a single pass over a
  450. // cached
  451. // FST is performed with only one-state expanded at a time.
  452. template <class CacheStore>
  453. class FirstCacheStore {
  454. public:
  455. using State = typename CacheStore::State;
  456. using Arc = typename State::Arc;
  457. using StateId = typename Arc::StateId;
  458. // Required constructors/assignment operators.
  459. explicit FirstCacheStore(const CacheOptions &opts)
  460. : store_(opts),
  461. cache_gc_(opts.gc_limit == 0), // opts.gc ignored historically.
  462. cache_first_state_id_(kNoStateId),
  463. cache_first_state_(nullptr) {}
  464. FirstCacheStore(const FirstCacheStore<CacheStore> &store)
  465. : store_(store.store_),
  466. cache_gc_(store.cache_gc_),
  467. cache_first_state_id_(store.cache_first_state_id_),
  468. cache_first_state_(store.cache_first_state_id_ != kNoStateId
  469. ? store_.GetMutableState(0)
  470. : nullptr) {}
  471. FirstCacheStore<CacheStore> &operator=(
  472. const FirstCacheStore<CacheStore> &store) {
  473. if (this != &store) {
  474. store_ = store.store_;
  475. cache_gc_ = store.cache_gc_;
  476. cache_first_state_id_ = store.cache_first_state_id_;
  477. cache_first_state_ = store.cache_first_state_id_ != kNoStateId
  478. ? store_.GetMutableState(0)
  479. : nullptr;
  480. }
  481. return *this;
  482. }
  483. // Returns nullptr if state is not stored.
  484. const State *GetState(StateId s) const {
  485. // store_ state 0 may hold first cached state; the rest are shifted by 1.
  486. return s == cache_first_state_id_ ? cache_first_state_
  487. : store_.GetState(s + 1);
  488. }
  489. // Creates state if state is not stored.
  490. State *GetMutableState(StateId s) {
  491. // store_ state 0 used to hold first cached state; the rest are shifted by
  492. // 1.
  493. if (cache_first_state_id_ == s) {
  494. return cache_first_state_; // Request for first cached state.
  495. }
  496. if (cache_gc_) {
  497. if (cache_first_state_id_ == kNoStateId) {
  498. cache_first_state_id_ = s; // Sets first cached state.
  499. cache_first_state_ = store_.GetMutableState(0);
  500. cache_first_state_->SetFlags(kCacheInit, kCacheInit);
  501. cache_first_state_->ReserveArcs(2 * kAllocSize);
  502. return cache_first_state_;
  503. } else if (cache_first_state_->RefCount() == 0) {
  504. cache_first_state_id_ = s; // Updates first cached state.
  505. cache_first_state_->Reset();
  506. cache_first_state_->SetFlags(kCacheInit, kCacheInit);
  507. return cache_first_state_;
  508. } else { // Keeps first cached state.
  509. cache_first_state_->SetFlags(0, kCacheInit); // Clears initialized bit.
  510. cache_gc_ = false; // Disables GC.
  511. }
  512. }
  513. auto *state = store_.GetMutableState(s + 1);
  514. return state;
  515. }
  516. // Similar to State::AddArc() but updates cache store book-keeping.
  517. void AddArc(State *state, const Arc &arc) { store_.AddArc(state, arc); }
  518. // Similar to State::SetArcs() but updates internal cache size; call only
  519. // once.
  520. void SetArcs(State *state) { store_.SetArcs(state); }
  521. // Deletes all arcs
  522. void DeleteArcs(State *state) { store_.DeleteArcs(state); }
  523. // Deletes some arcs
  524. void DeleteArcs(State *state, size_t n) { store_.DeleteArcs(state, n); }
  525. // Deletes all cached states
  526. void Clear() {
  527. store_.Clear();
  528. cache_first_state_id_ = kNoStateId;
  529. cache_first_state_ = nullptr;
  530. }
  531. StateId CountStates() const { return store_.CountStates(); }
  532. // Iterates over cached states (in an arbitrary order). Only needed if GC is
  533. // enabled.
  534. bool Done() const { return store_.Done(); }
  535. StateId Value() const {
  536. // store_ state 0 may hold first cached state; rest shifted + 1.
  537. const auto s = store_.Value();
  538. return s ? s - 1 : cache_first_state_id_;
  539. }
  540. void Next() { store_.Next(); }
  541. void Reset() { store_.Reset(); }
  542. // Deletes current state and advances to next.
  543. void Delete() {
  544. if (Value() == cache_first_state_id_) {
  545. cache_first_state_id_ = kNoStateId;
  546. cache_first_state_ = nullptr;
  547. }
  548. store_.Delete();
  549. }
  550. private:
  551. CacheStore store_; // Underlying store.
  552. bool cache_gc_; // GC enabled.
  553. StateId cache_first_state_id_; // First cached state ID.
  554. State *cache_first_state_; // First cached state.
  555. };
  556. // This class implements mark-sweep garbage collection on an underlying cache
  557. // store. If GC is enabled, garbage collection of states is performed in a
  558. // rough approximation of LRU order once when 'gc_limit' bytes is reached. The
  559. // caller can increment the reference count to inhibit the GC of in-use state
  560. // (e.g., in an ArcIterator). With GC enabled, the 'gc_limit' parameter allows
  561. // the caller to trade-off time vs. space.
  562. template <class CacheStore>
  563. class GCCacheStore {
  564. public:
  565. using State = typename CacheStore::State;
  566. using Arc = typename State::Arc;
  567. using StateId = typename Arc::StateId;
  568. // Required constructors/assignment operators.
  569. explicit GCCacheStore(const CacheOptions &opts)
  570. : store_(opts),
  571. cache_gc_request_(opts.gc),
  572. cache_limit_(opts.gc_limit > kMinCacheLimit ? opts.gc_limit
  573. : kMinCacheLimit),
  574. cache_gc_(false),
  575. cache_size_(0) {}
  576. // Returns 0 if state is not stored.
  577. const State *GetState(StateId s) const { return store_.GetState(s); }
  578. // Creates state if state is not stored
  579. State *GetMutableState(StateId s) {
  580. auto *state = store_.GetMutableState(s);
  581. if (cache_gc_request_ && !(state->Flags() & kCacheInit)) {
  582. state->SetFlags(kCacheInit, kCacheInit);
  583. cache_size_ += sizeof(State) + state->NumArcs() * sizeof(Arc);
  584. // GC is enabled once an uninited state (from underlying store) is seen.
  585. cache_gc_ = true;
  586. if (cache_size_ > cache_limit_) GC(state, false);
  587. }
  588. return state;
  589. }
  590. // Similar to State::AddArc() but updates cache store book-keeping.
  591. void AddArc(State *state, const Arc &arc) {
  592. store_.AddArc(state, arc);
  593. if (cache_gc_ && (state->Flags() & kCacheInit)) {
  594. cache_size_ += sizeof(Arc);
  595. if (cache_size_ > cache_limit_) GC(state, false);
  596. }
  597. }
  598. // Similar to State::SetArcs() but updates internal cache size; call only
  599. // once.
  600. void SetArcs(State *state) {
  601. store_.SetArcs(state);
  602. if (cache_gc_ && (state->Flags() & kCacheInit)) {
  603. cache_size_ += state->NumArcs() * sizeof(Arc);
  604. if (cache_size_ > cache_limit_) GC(state, false);
  605. }
  606. }
  607. // Deletes all arcs.
  608. void DeleteArcs(State *state) {
  609. if (cache_gc_ && (state->Flags() & kCacheInit)) {
  610. cache_size_ -= state->NumArcs() * sizeof(Arc);
  611. }
  612. store_.DeleteArcs(state);
  613. }
  614. // Deletes some arcs.
  615. void DeleteArcs(State *state, size_t n) {
  616. if (cache_gc_ && (state->Flags() & kCacheInit)) {
  617. cache_size_ -= n * sizeof(Arc);
  618. }
  619. store_.DeleteArcs(state, n);
  620. }
  621. // Deletes all cached states.
  622. void Clear() {
  623. store_.Clear();
  624. cache_size_ = 0;
  625. }
  626. StateId CountStates() const { return store_.CountStates(); }
  627. // Iterates over cached states (in an arbitrary order); only needed if GC is
  628. // enabled.
  629. bool Done() const { return store_.Done(); }
  630. StateId Value() const { return store_.Value(); }
  631. void Next() { store_.Next(); }
  632. void Reset() { store_.Reset(); }
  633. // Deletes current state and advances to next.
  634. void Delete() {
  635. if (cache_gc_) {
  636. const auto *state = store_.GetState(Value());
  637. if (state->Flags() & kCacheInit) {
  638. cache_size_ -= sizeof(State) + state->NumArcs() * sizeof(Arc);
  639. }
  640. }
  641. store_.Delete();
  642. }
  643. // Removes from the cache store (not referenced-counted and not the current)
  644. // states that have not been accessed since the last GC until at most
  645. // cache_fraction * cache_limit_ bytes are cached. If that fails to free
  646. // enough, attempts to uncaching recently visited states as well. If still
  647. // unable to free enough memory, then widens cache_limit_.
  648. void GC(const State *current, bool free_recent, float cache_fraction = 0.666);
  649. // Returns the current cache size in bytes or 0 if GC is disabled.
  650. size_t CacheSize() const { return cache_size_; }
  651. // Returns the cache limit in bytes.
  652. size_t CacheLimit() const { return cache_limit_; }
  653. private:
  654. static constexpr size_t kMinCacheLimit = 8096; // Minimum cache limit.
  655. CacheStore store_; // Underlying store.
  656. bool cache_gc_request_; // GC requested but possibly not yet enabled.
  657. size_t cache_limit_; // Number of bytes allowed before GC.
  658. bool cache_gc_; // GC enabled
  659. size_t cache_size_; // Number of bytes cached.
  660. };
  661. template <class CacheStore>
  662. void GCCacheStore<CacheStore>::GC(const State *current, bool free_recent,
  663. float cache_fraction) {
  664. if (!cache_gc_) return;
  665. VLOG(2) << "GCCacheStore: Enter GC: object = "
  666. << "(" << this << "), free recently cached = " << free_recent
  667. << ", cache size = " << cache_size_
  668. << ", cache frac = " << cache_fraction
  669. << ", cache limit = " << cache_limit_ << "\n";
  670. size_t cache_target = cache_fraction * cache_limit_;
  671. store_.Reset();
  672. while (!store_.Done()) {
  673. auto *state = store_.GetMutableState(store_.Value());
  674. if (cache_size_ > cache_target && state->RefCount() == 0 &&
  675. (free_recent || !(state->Flags() & kCacheRecent)) && state != current) {
  676. if (state->Flags() & kCacheInit) {
  677. size_t size = sizeof(State) + state->NumArcs() * sizeof(Arc);
  678. if (size < cache_size_) {
  679. cache_size_ -= size;
  680. }
  681. }
  682. store_.Delete();
  683. } else {
  684. state->SetFlags(0, kCacheRecent);
  685. store_.Next();
  686. }
  687. }
  688. if (!free_recent && cache_size_ > cache_target) { // Recurses on recent.
  689. GC(current, true, cache_fraction);
  690. } else if (cache_target > 0) { // Widens cache limit.
  691. while (cache_size_ > cache_target) {
  692. cache_limit_ *= 2;
  693. cache_target *= 2;
  694. }
  695. } else if (cache_size_ > 0) {
  696. FSTERROR() << "GCCacheStore:GC: Unable to free all cached states";
  697. }
  698. VLOG(2) << "GCCacheStore: Exit GC: object = "
  699. << "(" << this << "), free recently cached = " << free_recent
  700. << ", cache size = " << cache_size_
  701. << ", cache frac = " << cache_fraction
  702. << ", cache limit = " << cache_limit_ << "\n";
  703. }
  704. // This class is the default cache state and store used by CacheBaseImpl.
  705. // It uses VectorCacheStore for storage decorated by FirstCacheStore
  706. // and GCCacheStore to do (optional) garbage collection.
  707. template <class Arc>
  708. class DefaultCacheStore
  709. : public GCCacheStore<FirstCacheStore<VectorCacheStore<CacheState<Arc>>>> {
  710. public:
  711. explicit DefaultCacheStore(const CacheOptions &opts)
  712. : GCCacheStore<FirstCacheStore<VectorCacheStore<CacheState<Arc>>>>(opts) {
  713. }
  714. };
  715. namespace internal {
  716. // This class is used to cache FST elements stored in states of type State
  717. // (see CacheState) with the flags used to indicate what has been cached. Use
  718. // HasStart(), HasFinal(), and HasArcs() to determine if cached and SetStart(),
  719. // SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note that you
  720. // must set the final weight even if the state is non-final to mark it as
  721. // cached. The state storage method and any garbage collection policy are
  722. // determined by the cache store. If the store is passed in with the options,
  723. // CacheBaseImpl takes ownership.
  724. template <class State,
  725. class CacheStore = DefaultCacheStore<typename State::Arc>>
  726. class CacheBaseImpl : public FstImpl<typename State::Arc> {
  727. public:
  728. using Arc = typename State::Arc;
  729. using StateId = typename Arc::StateId;
  730. using Weight = typename Arc::Weight;
  731. using Store = CacheStore;
  732. using FstImpl<Arc>::Type;
  733. using FstImpl<Arc>::Properties;
  734. explicit CacheBaseImpl(const CacheOptions &opts = CacheOptions())
  735. : has_start_(false),
  736. cache_start_(kNoStateId),
  737. nknown_states_(0),
  738. min_unexpanded_state_id_(0),
  739. max_expanded_state_id_(-1),
  740. cache_gc_(opts.gc),
  741. cache_limit_(opts.gc_limit),
  742. cache_store_(new CacheStore(opts)),
  743. new_cache_store_(true),
  744. own_cache_store_(true) {}
  745. explicit CacheBaseImpl(const CacheImplOptions<CacheStore> &opts)
  746. : has_start_(false),
  747. cache_start_(kNoStateId),
  748. nknown_states_(0),
  749. min_unexpanded_state_id_(0),
  750. max_expanded_state_id_(-1),
  751. cache_gc_(opts.gc),
  752. cache_limit_(opts.gc_limit),
  753. cache_store_(
  754. opts.store ? opts.store
  755. : new CacheStore(CacheOptions(opts.gc, opts.gc_limit))),
  756. new_cache_store_(!opts.store),
  757. own_cache_store_(opts.store ? opts.own_store : true) {}
  758. // Preserve gc parameters. If preserve_cache is true, also preserves
  759. // cache data.
  760. CacheBaseImpl(const CacheBaseImpl<State, CacheStore> &impl,
  761. bool preserve_cache = false)
  762. : FstImpl<Arc>(),
  763. has_start_(false),
  764. cache_start_(kNoStateId),
  765. nknown_states_(0),
  766. min_unexpanded_state_id_(0),
  767. max_expanded_state_id_(-1),
  768. cache_gc_(impl.cache_gc_),
  769. cache_limit_(impl.cache_limit_),
  770. cache_store_(new CacheStore(CacheOptions(cache_gc_, cache_limit_))),
  771. new_cache_store_(impl.new_cache_store_ || !preserve_cache),
  772. own_cache_store_(true) {
  773. if (preserve_cache) {
  774. *cache_store_ = *impl.cache_store_;
  775. has_start_ = impl.has_start_;
  776. cache_start_ = impl.cache_start_;
  777. nknown_states_ = impl.nknown_states_;
  778. expanded_states_ = impl.expanded_states_;
  779. min_unexpanded_state_id_ = impl.min_unexpanded_state_id_;
  780. max_expanded_state_id_ = impl.max_expanded_state_id_;
  781. }
  782. }
  783. ~CacheBaseImpl() override {
  784. if (own_cache_store_) delete cache_store_;
  785. }
  786. void SetStart(StateId s) {
  787. cache_start_ = s;
  788. has_start_ = true;
  789. if (s >= nknown_states_) nknown_states_ = s + 1;
  790. }
  791. void SetFinal(StateId s, Weight weight = Weight::One()) {
  792. auto *state = cache_store_->GetMutableState(s);
  793. state->SetFinal(std::move(weight));
  794. static constexpr auto flags = kCacheFinal | kCacheRecent;
  795. state->SetFlags(flags, flags);
  796. }
  797. // Adds a single arc to a state but delays cache book-keeping. SetArcs must
  798. // be called when all PushArc and EmplaceArc calls at a state are complete.
  799. // Do not mix with calls to AddArc.
  800. void PushArc(StateId s, const Arc &arc) {
  801. auto *state = cache_store_->GetMutableState(s);
  802. state->PushArc(arc);
  803. }
  804. void PushArc(StateId s, Arc &&arc) {
  805. auto *state = cache_store_->GetMutableState(s);
  806. state->PushArc(std::move(arc));
  807. }
  808. // Adds a single arc to a state but delays cache book-keeping. SetArcs must
  809. // be called when all PushArc and EmplaceArc calls at a state are complete.
  810. // Do not mix with calls to AddArc.
  811. template <class... T>
  812. void EmplaceArc(StateId s, T &&...ctor_args) {
  813. auto *state = cache_store_->GetMutableState(s);
  814. state->EmplaceArc(std::forward<T>(ctor_args)...);
  815. }
  816. // Marks arcs of a state as cached and does cache book-keeping after all
  817. // calls to PushArc have been completed. Do not mix with calls to AddArc.
  818. void SetArcs(StateId s) {
  819. auto *state = cache_store_->GetMutableState(s);
  820. cache_store_->SetArcs(state);
  821. const auto narcs = state->NumArcs();
  822. for (size_t a = 0; a < narcs; ++a) {
  823. const auto &arc = state->GetArc(a);
  824. if (arc.nextstate >= nknown_states_) nknown_states_ = arc.nextstate + 1;
  825. }
  826. SetExpandedState(s);
  827. static constexpr auto flags = kCacheArcs | kCacheRecent;
  828. state->SetFlags(flags, flags);
  829. }
  830. void ReserveArcs(StateId s, size_t n) {
  831. auto *state = cache_store_->GetMutableState(s);
  832. state->ReserveArcs(n);
  833. }
  834. void DeleteArcs(StateId s) {
  835. auto *state = cache_store_->GetMutableState(s);
  836. cache_store_->DeleteArcs(state);
  837. }
  838. void DeleteArcs(StateId s, size_t n) {
  839. auto *state = cache_store_->GetMutableState(s);
  840. cache_store_->DeleteArcs(state, n);
  841. }
  842. void Clear() {
  843. nknown_states_ = 0;
  844. min_unexpanded_state_id_ = 0;
  845. max_expanded_state_id_ = -1;
  846. has_start_ = false;
  847. cache_start_ = kNoStateId;
  848. cache_store_->Clear();
  849. }
  850. // Is the start state cached?
  851. bool HasStart() const {
  852. if (!has_start_ && Properties(kError)) has_start_ = true;
  853. return has_start_;
  854. }
  855. // Is the final weight of the state cached?
  856. bool HasFinal(StateId s) const {
  857. const auto *state = cache_store_->GetState(s);
  858. if (state && state->Flags() & kCacheFinal) {
  859. state->SetFlags(kCacheRecent, kCacheRecent);
  860. return true;
  861. } else {
  862. return false;
  863. }
  864. }
  865. // Are arcs of the state cached?
  866. bool HasArcs(StateId s) const {
  867. const auto *state = cache_store_->GetState(s);
  868. if (state && state->Flags() & kCacheArcs) {
  869. state->SetFlags(kCacheRecent, kCacheRecent);
  870. return true;
  871. } else {
  872. return false;
  873. }
  874. }
  875. StateId Start() const { return cache_start_; }
  876. Weight Final(StateId s) const {
  877. const auto *state = cache_store_->GetState(s);
  878. return state->Final();
  879. }
  880. size_t NumArcs(StateId s) const {
  881. const auto *state = cache_store_->GetState(s);
  882. return state->NumArcs();
  883. }
  884. size_t NumInputEpsilons(StateId s) const {
  885. const auto *state = cache_store_->GetState(s);
  886. return state->NumInputEpsilons();
  887. }
  888. size_t NumOutputEpsilons(StateId s) const {
  889. const auto *state = cache_store_->GetState(s);
  890. return state->NumOutputEpsilons();
  891. }
  892. // Provides information needed for generic arc iterator.
  893. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
  894. const auto *state = cache_store_->GetState(s);
  895. data->base = nullptr;
  896. data->narcs = state->NumArcs();
  897. data->arcs = state->Arcs();
  898. data->ref_count = state->MutableRefCount();
  899. state->IncrRefCount();
  900. }
  901. // Number of known states.
  902. StateId NumKnownStates() const { return nknown_states_; }
  903. // Updates number of known states, taking into account the passed state ID.
  904. void UpdateNumKnownStates(StateId s) {
  905. if (s >= nknown_states_) nknown_states_ = s + 1;
  906. }
  907. // Finds the mininum never-expanded state ID.
  908. StateId MinUnexpandedState() const {
  909. while (min_unexpanded_state_id_ <= max_expanded_state_id_ &&
  910. ExpandedState(min_unexpanded_state_id_)) {
  911. ++min_unexpanded_state_id_;
  912. }
  913. return min_unexpanded_state_id_;
  914. }
  915. // Returns maximum ever-expanded state ID.
  916. StateId MaxExpandedState() const { return max_expanded_state_id_; }
  917. void SetExpandedState(StateId s) {
  918. if (s > max_expanded_state_id_) max_expanded_state_id_ = s;
  919. if (s < min_unexpanded_state_id_) return;
  920. if (s == min_unexpanded_state_id_) ++min_unexpanded_state_id_;
  921. if (cache_gc_ || cache_limit_ == 0) {
  922. if (expanded_states_.size() <= static_cast<size_t>(s))
  923. expanded_states_.resize(s + 1, false);
  924. expanded_states_[s] = true;
  925. }
  926. }
  927. bool ExpandedState(StateId s) const {
  928. if (cache_gc_ || cache_limit_ == 0) {
  929. return expanded_states_[s];
  930. } else if (new_cache_store_) {
  931. return cache_store_->GetState(s) != nullptr;
  932. } else {
  933. // If the cache was not created by this class, then the cached state needs
  934. // to be inspected to update nknown_states_.
  935. return false;
  936. }
  937. }
  938. const CacheStore *GetCacheStore() const { return cache_store_; }
  939. CacheStore *GetCacheStore() { return cache_store_; }
  940. // Caching on/off switch, limit and size accessors.
  941. bool GetCacheGc() const { return cache_gc_; }
  942. size_t GetCacheLimit() const { return cache_limit_; }
  943. private:
  944. mutable bool has_start_; // Is the start state cached?
  945. StateId cache_start_; // ID of start state.
  946. StateId nknown_states_; // Number of known states.
  947. std::vector<bool> expanded_states_; // States that have been expanded.
  948. mutable StateId min_unexpanded_state_id_; // Minimum never-expanded state ID
  949. mutable StateId max_expanded_state_id_; // Maximum ever-expanded state ID
  950. bool cache_gc_; // GC enabled.
  951. size_t cache_limit_; // Number of bytes allowed before GC.
  952. CacheStore *cache_store_; // The store of cached states.
  953. bool new_cache_store_; // Was the store was created by class?
  954. bool own_cache_store_; // Is the store owned by class?
  955. CacheBaseImpl &operator=(const CacheBaseImpl &impl) = delete;
  956. };
  957. // A CacheBaseImpl with the default cache state type.
  958. template <class Arc>
  959. class CacheImpl : public CacheBaseImpl<CacheState<Arc>> {
  960. public:
  961. using State = CacheState<Arc>;
  962. CacheImpl() = default;
  963. explicit CacheImpl(const CacheOptions &opts)
  964. : CacheBaseImpl<CacheState<Arc>>(opts) {}
  965. CacheImpl(const CacheImpl<Arc> &impl, bool preserve_cache = false)
  966. : CacheBaseImpl<State>(impl, preserve_cache) {}
  967. private:
  968. CacheImpl &operator=(const CacheImpl &impl) = delete;
  969. };
  970. } // namespace internal
  971. // Use this to make a state iterator for a CacheBaseImpl-derived FST, which must
  972. // have Arc and Store types defined. Note this iterator only returns those
  973. // states reachable from the initial state, so consider implementing a
  974. // class-specific one.
  975. //
  976. // This class may be derived from.
  977. template <class FST>
  978. class CacheStateIterator : public StateIteratorBase<typename FST::Arc> {
  979. public:
  980. using Arc = typename FST::Arc;
  981. using StateId = typename Arc::StateId;
  982. using Weight = typename Arc::Weight;
  983. using Store = typename FST::Store;
  984. using State = typename Store::State;
  985. using Impl = internal::CacheBaseImpl<State, Store>;
  986. CacheStateIterator(const FST &fst, Impl *impl)
  987. : fst_(fst), impl_(impl), s_(0) {
  988. fst_.Start(); // Forces start state.
  989. }
  990. bool Done() const final {
  991. if (s_ < impl_->NumKnownStates()) return false;
  992. for (StateId u = impl_->MinUnexpandedState(); u < impl_->NumKnownStates();
  993. u = impl_->MinUnexpandedState()) {
  994. // Forces state expansion.
  995. ArcIterator<FST> aiter(fst_, u);
  996. aiter.SetFlags(kArcValueFlags, kArcValueFlags | kArcNoCache);
  997. for (; !aiter.Done(); aiter.Next()) {
  998. impl_->UpdateNumKnownStates(aiter.Value().nextstate);
  999. }
  1000. impl_->SetExpandedState(u);
  1001. if (s_ < impl_->NumKnownStates()) return false;
  1002. }
  1003. return true;
  1004. }
  1005. StateId Value() const final { return s_; }
  1006. void Next() final { ++s_; }
  1007. void Reset() final { s_ = 0; }
  1008. private:
  1009. const FST &fst_;
  1010. Impl *impl_;
  1011. StateId s_;
  1012. };
  1013. // Used to make an arc iterator for a CacheBaseImpl-derived FST, which must
  1014. // have Arc and State types defined.
  1015. template <class FST>
  1016. class CacheArcIterator {
  1017. public:
  1018. using Arc = typename FST::Arc;
  1019. using StateId = typename Arc::StateId;
  1020. using Weight = typename Arc::Weight;
  1021. using Store = typename FST::Store;
  1022. using State = typename Store::State;
  1023. using Impl = internal::CacheBaseImpl<State, Store>;
  1024. CacheArcIterator(Impl *impl, StateId s) : i_(0) {
  1025. state_ = impl->GetCacheStore()->GetMutableState(s);
  1026. state_->IncrRefCount();
  1027. }
  1028. ~CacheArcIterator() { state_->DecrRefCount(); }
  1029. bool Done() const { return i_ >= state_->NumArcs(); }
  1030. const Arc &Value() const { return state_->GetArc(i_); }
  1031. void Next() { ++i_; }
  1032. size_t Position() const { return i_; }
  1033. void Reset() { i_ = 0; }
  1034. void Seek(size_t a) { i_ = a; }
  1035. constexpr uint8_t Flags() const { return kArcValueFlags; }
  1036. void SetFlags(uint8_t flags, uint8_t mask) {}
  1037. private:
  1038. const State *state_;
  1039. size_t i_;
  1040. CacheArcIterator(const CacheArcIterator &) = delete;
  1041. CacheArcIterator &operator=(const CacheArcIterator &) = delete;
  1042. };
  1043. // Use this to make a mutable arc iterator for a CacheBaseImpl-derived FST,
  1044. // which must have types Arc and Store defined.
  1045. template <class FST>
  1046. class CacheMutableArcIterator
  1047. : public MutableArcIteratorBase<typename FST::Arc> {
  1048. public:
  1049. using Arc = typename FST::Arc;
  1050. using StateId = typename Arc::StateId;
  1051. using Weight = typename Arc::Weight;
  1052. using Store = typename FST::Store;
  1053. using State = typename Store::State;
  1054. using Impl = internal::CacheBaseImpl<State, Store>;
  1055. // User must call MutateCheck() in the constructor.
  1056. CacheMutableArcIterator(Impl *impl, StateId s) : i_(0), s_(s), impl_(impl) {
  1057. state_ = impl_->GetCacheStore()->GetMutableState(s_);
  1058. state_->IncrRefCount();
  1059. }
  1060. ~CacheMutableArcIterator() override { state_->DecrRefCount(); }
  1061. bool Done() const final { return i_ >= state_->NumArcs(); }
  1062. const Arc &Value() const final { return state_->GetArc(i_); }
  1063. void Next() final { ++i_; }
  1064. size_t Position() const final { return i_; }
  1065. void Reset() final { i_ = 0; }
  1066. void Seek(size_t a) final { i_ = a; }
  1067. void SetValue(const Arc &arc) final { state_->SetArc(arc, i_); }
  1068. uint8_t Flags() const final { return kArcValueFlags; }
  1069. void SetFlags(uint8_t, uint8_t) final {}
  1070. private:
  1071. size_t i_;
  1072. StateId s_;
  1073. Impl *impl_;
  1074. State *state_;
  1075. CacheMutableArcIterator(const CacheMutableArcIterator &) = delete;
  1076. CacheMutableArcIterator &operator=(const CacheMutableArcIterator &) = delete;
  1077. };
  1078. // Wrap existing CacheStore implementation to use with ExpanderFst.
  1079. template <class CacheStore>
  1080. class ExpanderCacheStore {
  1081. public:
  1082. using State = typename CacheStore::State;
  1083. using Arc = typename CacheStore::Arc;
  1084. using StateId = typename Arc::StateId;
  1085. using Weight = typename Arc::Weight;
  1086. explicit ExpanderCacheStore(const CacheOptions &opts = CacheOptions())
  1087. : store_(opts) {}
  1088. template <class Expander>
  1089. State *FindOrExpand(Expander &expander, StateId s) {
  1090. auto *state = store_.GetMutableState(s);
  1091. if (state->Flags()) {
  1092. state->SetFlags(kCacheRecent, kCacheRecent);
  1093. } else {
  1094. StateBuilder builder(state);
  1095. expander.Expand(s, &builder);
  1096. state->SetFlags(kCacheFlags, kCacheFlags);
  1097. store_.SetArcs(state);
  1098. }
  1099. return state;
  1100. }
  1101. private:
  1102. CacheStore store_;
  1103. struct StateBuilder {
  1104. State *state;
  1105. explicit StateBuilder(State *state_) : state(state_) {}
  1106. void AddArc(const Arc &arc) { state->PushArc(arc); }
  1107. void AddArc(Arc &&arc) { state->PushArc(std::move(arc)); }
  1108. void SetFinal(Weight weight = Weight::One()) {
  1109. state->SetFinal(std::move(weight));
  1110. }
  1111. };
  1112. };
  1113. } // namespace fst
  1114. #endif // FST_CACHE_H_