|
|
// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// An FST implementation that caches FST elements of a delayed computation.
#ifndef FST_CACHE_H_
#define FST_CACHE_H_
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <list>
#include <memory>
#include <new>
#include <utility>
#include <vector>
#include <fst/flags.h>
#include <fst/log.h>
#include <fst/fst.h>
#include <fst/memory.h>
#include <fst/mutable-fst.h>
#include <fst/properties.h>
#include <fst/util.h>
#include <fst/vector-fst.h>
#include <unordered_map>
#include <functional>
// Flags (see <fst/flags.h>) that supply the default values of the `gc` and
// `gc_limit` fields of CacheOptions and CacheImplOptions.
DECLARE_bool(fst_default_cache_gc);  // Default for the `gc` field.
DECLARE_int64(fst_default_cache_gc_limit);  // Default for the `gc_limit` field.
namespace fst {
// Options for controlling caching behavior; higher level than CacheImplOptions.
struct CacheOptions {
  // Both defaults are taken from the fst_default_cache_gc* flags.
  explicit CacheOptions(bool gc = FST_FLAGS_fst_default_cache_gc,
                        size_t gc_limit = FST_FLAGS_fst_default_cache_gc_limit)
      : gc(gc), gc_limit(gc_limit) {}

  bool gc;          // Whether garbage collection is enabled.
  size_t gc_limit;  // Number of cached bytes tolerated before GC runs.
};
// Options for controlling caching behavior, at a lower level than
// CacheOptions; templated on the cache store and allows passing the store.
template <class CacheStore>
struct CacheImplOptions {
  // GC defaults come from the fst_default_cache_gc* flags; `store` may be
  // null, in which case the consumer is expected to create its own store.
  explicit CacheImplOptions(
      bool gc = FST_FLAGS_fst_default_cache_gc,
      size_t gc_limit = FST_FLAGS_fst_default_cache_gc_limit,
      CacheStore *store = nullptr)
      : gc(gc), gc_limit(gc_limit), store(store), own_store(true) {}

  // Adopts the GC settings of `opts`; `store` is left null and `own_store`
  // is true.
  explicit CacheImplOptions(const CacheOptions &opts)
      : CacheImplOptions(opts.gc, opts.gc_limit) {}

  bool gc;            // Whether garbage collection is enabled.
  size_t gc_limit;    // Number of cached bytes tolerated before GC runs.
  CacheStore *store;  // Optional externally supplied cache store.
  bool own_store;     // Should CacheImpl take ownership of the store?
};
// Bits kept in a cache state's flags byte.
inline constexpr uint8_t kCacheFinal = 1u << 0;   // Final weight is cached.
inline constexpr uint8_t kCacheArcs = 1u << 1;    // Arcs are cached.
inline constexpr uint8_t kCacheInit = 1u << 2;    // Initialized by GC.
inline constexpr uint8_t kCacheRecent = 1u << 3;  // Visited since last GC.
// Union of every cache flag bit above.
inline constexpr uint8_t kCacheFlags =
    kCacheFinal | kCacheArcs | kCacheInit | kCacheRecent;
// Cache state, with arcs stored in a per-state std::vector.
// Cache state, with arcs stored in a per-state std::vector using the arc
// allocator M.
template <class A, class M = PoolAllocator<A>>
class CacheState {
 public:
  using Arc = A;
  using Label = typename Arc::Label;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  using ArcAllocator = M;
  using StateAllocator = typename std::allocator_traits<
      ArcAllocator>::template rebind_alloc<CacheState<A, M>>;

  // Provides STL allocator for arcs. The new state has final weight Zero(),
  // no arcs, no flags set and a zero reference count.
  explicit CacheState(const ArcAllocator &alloc)
      : final_weight_(Weight::Zero()),
        niepsilons_(0),
        noepsilons_(0),
        arcs_(alloc),
        flags_(0),
        ref_count_(0) {}

  // Copies the final weight, epsilon counts, arcs and flags of `state`, using
  // `alloc` for the arcs; the copy's reference count starts at zero. The
  // source has this exact type (CacheState<A, M>), so states that use a
  // non-default arc allocator M can be copied too.
  CacheState(const CacheState &state, const ArcAllocator &alloc)
      : final_weight_(state.Final()),
        niepsilons_(state.NumInputEpsilons()),
        noepsilons_(state.NumOutputEpsilons()),
        arcs_(state.arcs_.begin(), state.arcs_.end(), alloc),
        flags_(state.Flags()),
        ref_count_(0) {}

  // Returns the state to its freshly-constructed contents (the arc vector is
  // cleared; its capacity may be retained).
  void Reset() {
    final_weight_ = Weight::Zero();
    niepsilons_ = 0;
    noepsilons_ = 0;
    ref_count_ = 0;
    flags_ = 0;
    arcs_.clear();
  }

  Weight Final() const { return final_weight_; }

  size_t NumInputEpsilons() const { return niepsilons_; }

  size_t NumOutputEpsilons() const { return noepsilons_; }

  size_t NumArcs() const { return arcs_.size(); }

  const Arc &GetArc(size_t n) const { return arcs_[n]; }

  // Used by the ArcIterator<Fst<Arc>> efficient implementation.
  const Arc *Arcs() const { return !arcs_.empty() ? &arcs_[0] : nullptr; }

  // Accesses flags; used by the caller.
  uint8_t Flags() const { return flags_; }

  // Accesses ref count; used by the caller.
  int RefCount() const { return ref_count_; }

  void SetFinal(Weight weight = Weight::One()) {
    final_weight_ = std::move(weight);
  }

  void ReserveArcs(size_t n) { arcs_.reserve(n); }

  // Adds one arc at a time with all needed book-keeping; use PushArc and
  // SetArcs for a more efficient alternative.
  void AddArc(const Arc &arc) {
    IncrementNumEpsilons(arc);
    arcs_.push_back(arc);
  }

  void AddArc(Arc &&arc) {
    IncrementNumEpsilons(arc);
    arcs_.push_back(std::move(arc));
  }

  // Adds one arc at a time with delayed book-keeping; finalize with SetArcs().
  void PushArc(const Arc &arc) { arcs_.push_back(arc); }

  void PushArc(Arc &&arc) { arcs_.push_back(std::move(arc)); }

  // Adds one arc at a time with delayed book-keeping; finalize with SetArcs().
  template <class... T>
  void EmplaceArc(T &&...ctor_args) {
    arcs_.emplace_back(std::forward<T>(ctor_args)...);
  }

  // Finalizes arcs book-keeping; call only once.
  void SetArcs() {
    for (const auto &arc : arcs_) IncrementNumEpsilons(arc);
  }

  // Modifies nth arc, keeping the epsilon counts consistent.
  void SetArc(const Arc &arc, size_t n) {
    if (arcs_[n].ilabel == 0) --niepsilons_;
    if (arcs_[n].olabel == 0) --noepsilons_;
    IncrementNumEpsilons(arc);
    arcs_[n] = arc;
  }

  // Deletes all arcs.
  void DeleteArcs() {
    niepsilons_ = 0;
    noepsilons_ = 0;
    arcs_.clear();
  }

  // Deletes the last n arcs; requires n <= NumArcs().
  void DeleteArcs(size_t n) {
    for (size_t i = 0; i < n; ++i) {
      if (arcs_.back().ilabel == 0) --niepsilons_;
      if (arcs_.back().olabel == 0) --noepsilons_;
      arcs_.pop_back();
    }
  }

  // Sets status flags; used by the caller. Bits in `mask` are cleared before
  // `flags` is OR-ed in.
  void SetFlags(uint8_t flags, uint8_t mask) const {
    flags_ &= ~mask;
    flags_ |= flags;
  }

  // Mutates reference counts; used by the caller.
  int IncrRefCount() const { return ++ref_count_; }

  int DecrRefCount() const { return --ref_count_; }

  // Used by the ArcIterator<Fst<Arc>> efficient implementation.
  int *MutableRefCount() const { return &ref_count_; }

  // Used for state class allocation.
  void *operator new(size_t size, StateAllocator *alloc) {
    return alloc->allocate(1);
  }

  // For state destruction and memory freeing. Takes this exact type so that
  // it matches StateAllocator for any arc allocator M.
  static void Destroy(CacheState *state, StateAllocator *alloc) {
    if (state) {
      state->~CacheState();
      alloc->deallocate(state, 1);
    }
  }

 private:
  // Update the number of epsilons as a result of having added an arc.
  void IncrementNumEpsilons(const Arc &arc) {
    if (arc.ilabel == 0) ++niepsilons_;
    if (arc.olabel == 0) ++noepsilons_;
  }

  Weight final_weight_;                  // Final weight.
  size_t niepsilons_;                    // # of input epsilons.
  size_t noepsilons_;                    // # of output epsilons.
  std::vector<Arc, ArcAllocator> arcs_;  // Arcs representation.
  mutable uint8_t flags_;
  mutable int ref_count_;  // If 0, available for GC.
};
// Cache store, allocating and storing states, providing a mapping from state
// IDs to cached states, and an iterator over these states. The state template
// argument must implement the CacheState interface. The state for a StateId s
// is constructed when requested by GetMutableState(s) if it is not yet stored.
// Initially, a state has a reference count of zero, but the user may increment
// or decrement this to control the time of destruction. In particular, a state
// is destroyed when:
//
// 1. This instance is destroyed, or
// 2. Clear() or Delete() is called, or
// 3. Possibly (implementation-dependently) when:
// - Garbage collection is enabled (as defined by opts.gc),
// - The cache store size exceeds the limit (as defined by opts.gc_limit),
// - The state's reference count is zero, and
// - The state is not the most recently requested state.
//
// template <class S>
// class CacheStore {
// public:
// using State = S;
// using Arc = typename State::Arc;
// using StateId = typename Arc::StateId;
//
// // Required constructors/assignment operators.
// explicit CacheStore(const CacheOptions &opts);
//
// // Returns nullptr if state is not stored.
// const State *GetState(StateId s);
//
// // Creates state if state is not stored.
// State *GetMutableState(StateId s);
//
// // Similar to State::AddArc() but updates cache store book-keeping.
// void AddArc(State *state, const Arc &arc);
//
// // Similar to State::SetArcs() but updates cache store book-keeping; call
// // only once.
// void SetArcs(State *state);
//
// // Similar to State::DeleteArcs() but updates cache store book-keeping.
//
// void DeleteArcs(State *state);
//
// void DeleteArcs(State *state, size_t n);
//
// // Deletes all cached states.
// void Clear();
//
// // Number of cached states.
// StateId CountStates();
//
// // Iterates over cached states (in an arbitrary order); only needed if
// // opts.gc is true.
// bool Done() const; // End of iteration.
// StateId Value() const; // Current state.
// void Next(); // Advances to next state (when !Done).
// void Reset(); // Returns to initial condition.
// void Delete(); // Deletes current state and advances to next.
// };
// Container cache stores.
// This class uses a vector of pointers to states to store cached states.
template <class S>
class VectorCacheStore {
 public:
  using State = S;
  using Arc = typename State::Arc;
  using StateId = typename Arc::StateId;
  using StateList = std::list<StateId, PoolAllocator<StateId>>;

  // Required constructors/assignment operators.
  explicit VectorCacheStore(const CacheOptions &opts) : cache_gc_(opts.gc) {
    Clear();
    Reset();
  }

  VectorCacheStore(const VectorCacheStore<S> &store)
      : cache_gc_(store.cache_gc_) {
    CopyStates(store);
    Reset();
  }

  ~VectorCacheStore() { Clear(); }

  // Deep-copies `store`. Like the copy constructor, this adopts the source's
  // GC mode. cache_gc_ must be assigned before CopyStates(), because
  // CopyStates() consults it to decide whether to rebuild state_list_.
  VectorCacheStore &operator=(const VectorCacheStore &store) {
    if (this != &store) {
      cache_gc_ = store.cache_gc_;
      CopyStates(store);
      Reset();
    }
    return *this;
  }

  // Whether `s` indexes an existing slot of state_vec_ (the slot may be null).
  bool InBounds(StateId s) const {
    return s < static_cast<StateId>(state_vec_.size());
  }

  // Returns nullptr if state is not stored.
  const State *GetState(StateId s) const {
    return InBounds(s) ? state_vec_[s] : nullptr;
  }

  // Creates state if state is not stored, growing state_vec_ as needed. With
  // GC enabled, new states are appended to state_list_ for iteration.
  State *GetMutableState(StateId s) {
    State *state = nullptr;
    if (InBounds(s)) {
      state = state_vec_[s];
    } else {
      state_vec_.resize(s + 1, nullptr);
    }
    if (!state) {
      state = new (&state_alloc_) State(arc_alloc_);
      state_vec_[s] = state;
      if (cache_gc_) state_list_.push_back(s);
    }
    return state;
  }

  // Similar to State::AddArc() but updates cache store book-keeping.
  void AddArc(State *state, const Arc &arc) { state->AddArc(arc); }

  // Similar to State::SetArcs() but updates cache store book-keeping; call
  // only once.
  void SetArcs(State *state) { state->SetArcs(); }

  // Deletes all arcs.
  void DeleteArcs(State *state) { state->DeleteArcs(); }

  // Deletes some arcs.
  void DeleteArcs(State *state, size_t n) { state->DeleteArcs(n); }

  // Deletes all cached states.
  void Clear() {
    for (State *s : state_vec_) State::Destroy(s, &state_alloc_);
    state_vec_.clear();
    state_list_.clear();
  }

  // Number of non-null slots in state_vec_.
  StateId CountStates() const {
    return std::count_if(state_vec_.begin(), state_vec_.end(),
                         [](const State *s) { return s != nullptr; });
  }

  // Iterates over cached states (in an arbitrary order); only works if GC is
  // enabled (o.w. avoiding state_list_ overhead).
  bool Done() const { return iter_ == state_list_.end(); }

  StateId Value() const { return *iter_; }

  void Next() { ++iter_; }

  void Reset() { iter_ = state_list_.begin(); }

  // Deletes current state and advances to next.
  void Delete() {
    State::Destroy(state_vec_[*iter_], &state_alloc_);
    state_vec_[*iter_] = nullptr;
    state_list_.erase(iter_++);
  }

 private:
  // Replaces the contents of this store with deep copies of `store`'s states.
  void CopyStates(const VectorCacheStore<State> &store) {
    Clear();
    state_vec_.reserve(store.state_vec_.size());
    for (size_t s = 0; s < store.state_vec_.size(); ++s) {
      State *state = nullptr;
      const auto *store_state = store.state_vec_[s];
      if (store_state) {
        state = new (&state_alloc_) State(*store_state, arc_alloc_);
        if (cache_gc_) state_list_.push_back(s);
      }
      state_vec_.push_back(state);
    }
  }

  bool cache_gc_;                               // Supports iteration when true.
  std::vector<State *> state_vec_;              // Vector of states (or null).
  StateList state_list_;                        // List of states.
  typename StateList::iterator iter_;           // State list iterator.
  typename State::StateAllocator state_alloc_;  // For state allocation.
  typename State::ArcAllocator arc_alloc_;      // For arc allocation.
};
// This class uses a hash map from state IDs to pointers to cached states.
template <class S>
class HashCacheStore {
 public:
  using State = S;
  using Arc = typename State::Arc;
  using StateId = typename Arc::StateId;

  using StateMap =
      std::unordered_map<StateId, State *, std::hash<StateId>,
                         std::equal_to<StateId>,
                         PoolAllocator<std::pair<const StateId, State *>>>;

  // Required constructors/assignment operators. The options are not
  // consulted; iteration over the map is always available.
  explicit HashCacheStore(const CacheOptions &opts) {
    Clear();
    Reset();
  }

  HashCacheStore(const HashCacheStore<S> &store) {
    CopyStates(store);
    Reset();
  }

  ~HashCacheStore() { Clear(); }

  HashCacheStore &operator=(const HashCacheStore &store) {
    if (this != &store) {
      CopyStates(store);
      Reset();
    }
    return *this;
  }

  // Looks up `s`; yields nullptr when it has no entry.
  const State *GetState(StateId s) const {
    if (const auto it = state_map_.find(s); it != state_map_.end()) {
      return it->second;
    }
    return nullptr;
  }

  // Returns the state for `s`, allocating it on first request.
  State *GetMutableState(StateId s) {
    // operator[] inserts a null slot for an unseen ID; fill it in here.
    State *&slot = state_map_[s];
    if (slot == nullptr) slot = new (&state_alloc_) State(arc_alloc_);
    return slot;
  }

  // Forwards to State::AddArc(); this store keeps no extra book-keeping.
  void AddArc(State *state, const Arc &arc) { state->AddArc(arc); }

  // Forwards to State::SetArcs(); call only once per state.
  void SetArcs(State *state) { state->SetArcs(); }

  // Removes every arc of `state`.
  void DeleteArcs(State *state) { state->DeleteArcs(); }

  // Removes the last `n` arcs of `state`.
  void DeleteArcs(State *state, size_t n) { state->DeleteArcs(n); }

  // Destroys every stored state and empties the map.
  void Clear() {
    for (auto &entry : state_map_) State::Destroy(entry.second, &state_alloc_);
    state_map_.clear();
  }

  StateId CountStates() const { return state_map_.size(); }

  // Iteration over stored states, in unspecified order.
  bool Done() const { return iter_ == state_map_.end(); }

  StateId Value() const { return iter_->first; }

  void Next() { ++iter_; }

  void Reset() { iter_ = state_map_.begin(); }

  // Destroys the state under the cursor and moves the cursor past it.
  void Delete() {
    State::Destroy(iter_->second, &state_alloc_);
    iter_ = state_map_.erase(iter_);
  }

 private:
  // Discards current contents and deep-copies every state held by `store`.
  void CopyStates(const HashCacheStore<State> &store) {
    Clear();
    for (const auto &entry : store.state_map_) {
      state_map_[entry.first] =
          new (&state_alloc_) State(*entry.second, arc_alloc_);
    }
  }

  StateMap state_map_;                          // Map from state ID to state.
  typename StateMap::iterator iter_;            // State map iterator.
  typename State::StateAllocator state_alloc_;  // For state allocation.
  typename State::ArcAllocator arc_alloc_;      // For arc allocation.
};
// Garbage-collection cache stores.
// This class implements a simple garbage collection scheme when
// 'opts.gc_limit = 0'. In particular, the first cached state is reused for each
// new state so long as the reference count is zero on the to-be-reused state.
// Otherwise, the full underlying store is used. The caller can increment the
// reference count to inhibit the GC of in-use states (e.g., in an ArcIterator).
//
// The typical use case for this optimization is when a single pass over a
// cached FST is performed with only one state expanded at a time.
template <class CacheStore>
class FirstCacheStore {
 public:
  using State = typename CacheStore::State;
  using Arc = typename State::Arc;
  using StateId = typename Arc::StateId;

  // Required constructors/assignment operators. First-state reuse is enabled
  // only when opts.gc_limit is zero; opts.gc is ignored here (historically).
  explicit FirstCacheStore(const CacheOptions &opts)
      : store_(opts),
        cache_gc_(opts.gc_limit == 0),
        cache_first_state_id_(kNoStateId),
        cache_first_state_(nullptr) {}

  // Copies the underlying store, then points the first cached state at this
  // copy's own slot 0.
  FirstCacheStore(const FirstCacheStore<CacheStore> &store)
      : store_(store.store_),
        cache_gc_(store.cache_gc_),
        cache_first_state_id_(store.cache_first_state_id_),
        cache_first_state_(FirstSlot()) {}

  FirstCacheStore<CacheStore> &operator=(
      const FirstCacheStore<CacheStore> &store) {
    if (this == &store) return *this;
    store_ = store.store_;
    cache_gc_ = store.cache_gc_;
    cache_first_state_id_ = store.cache_first_state_id_;
    cache_first_state_ = FirstSlot();
    return *this;
  }

  // Returns nullptr if state is not stored. Slot 0 of store_ may hold the
  // first cached state; every other ID s lives at slot s + 1.
  const State *GetState(StateId s) const {
    if (s == cache_first_state_id_) return cache_first_state_;
    return store_.GetState(s + 1);
  }

  // Creates state if state is not stored. While reuse is enabled, slot 0 is
  // recycled for each new request as long as nobody holds a reference to it;
  // the first time it is still referenced, reuse is switched off for good.
  State *GetMutableState(StateId s) {
    if (s == cache_first_state_id_) return cache_first_state_;
    if (cache_gc_) {
      if (cache_first_state_id_ == kNoStateId) {
        // First request ever: claim slot 0.
        cache_first_state_id_ = s;
        cache_first_state_ = store_.GetMutableState(0);
        cache_first_state_->SetFlags(kCacheInit, kCacheInit);
        cache_first_state_->ReserveArcs(2 * kAllocSize);
        return cache_first_state_;
      }
      if (cache_first_state_->RefCount() == 0) {
        // Slot 0 is unreferenced: recycle it for `s`.
        cache_first_state_id_ = s;
        cache_first_state_->Reset();
        cache_first_state_->SetFlags(kCacheInit, kCacheInit);
        return cache_first_state_;
      }
      // Slot 0 is still in use: keep it, clear its init bit, stop reusing.
      cache_first_state_->SetFlags(0, kCacheInit);
      cache_gc_ = false;
    }
    return store_.GetMutableState(s + 1);
  }

  // Forwards to the underlying store's AddArc().
  void AddArc(State *state, const Arc &arc) { store_.AddArc(state, arc); }

  // Forwards to the underlying store's SetArcs(); call only once.
  void SetArcs(State *state) { store_.SetArcs(state); }

  // Forwards removal of all arcs to the underlying store.
  void DeleteArcs(State *state) { store_.DeleteArcs(state); }

  // Forwards removal of the last `n` arcs to the underlying store.
  void DeleteArcs(State *state, size_t n) { store_.DeleteArcs(state, n); }

  // Empties the underlying store and forgets the first cached state.
  void Clear() {
    store_.Clear();
    cache_first_state_id_ = kNoStateId;
    cache_first_state_ = nullptr;
  }

  StateId CountStates() const { return store_.CountStates(); }

  // Iteration over cached states, in unspecified order; only needed when GC
  // is enabled.
  bool Done() const { return store_.Done(); }

  // Translates the underlying slot back to a state ID (slot 0 is the first
  // cached state; slot k > 0 is state k - 1).
  StateId Value() const {
    const auto slot = store_.Value();
    if (slot == 0) return cache_first_state_id_;
    return slot - 1;
  }

  void Next() { store_.Next(); }

  void Reset() { store_.Reset(); }

  // Deletes current state and advances to next, forgetting the first cached
  // state if that is the one being removed.
  void Delete() {
    if (Value() == cache_first_state_id_) {
      cache_first_state_id_ = kNoStateId;
      cache_first_state_ = nullptr;
    }
    store_.Delete();
  }

 private:
  // Returns store_'s slot 0 when a first state is being tracked, else null.
  State *FirstSlot() {
    return cache_first_state_id_ == kNoStateId ? nullptr
                                               : store_.GetMutableState(0);
  }

  CacheStore store_;              // Underlying store.
  bool cache_gc_;                 // GC enabled.
  StateId cache_first_state_id_;  // First cached state ID.
  State *cache_first_state_;      // First cached state.
};
// This class implements mark-sweep garbage collection on an underlying cache
// store. If GC is enabled, garbage collection of states is performed in a
// rough approximation of LRU order once when 'gc_limit' bytes is reached. The
// caller can increment the reference count to inhibit the GC of in-use state
// (e.g., in an ArcIterator). With GC enabled, the 'gc_limit' parameter allows
// the caller to trade-off time vs. space.
template <class CacheStore>
class GCCacheStore {
 public:
  using State = typename CacheStore::State;
  using Arc = typename State::Arc;
  using StateId = typename Arc::StateId;

  // Required constructors/assignment operators. GC is only *requested* here
  // (opts.gc); it is switched on once an uninitialized state is handed out.
  // The byte limit is clamped from below to kMinCacheLimit.
  explicit GCCacheStore(const CacheOptions &opts)
      : store_(opts),
        cache_gc_request_(opts.gc),
        cache_limit_(opts.gc_limit < kMinCacheLimit ? kMinCacheLimit
                                                    : opts.gc_limit),
        cache_gc_(false),
        cache_size_(0) {}

  // Returns 0 if state is not stored.
  const State *GetState(StateId s) const { return store_.GetState(s); }

  // Creates state if state is not stored. When GC has been requested, the
  // first sight of a state lacking kCacheInit marks it, charges its size to
  // the cache, and enables GC.
  State *GetMutableState(StateId s) {
    auto *state = store_.GetMutableState(s);
    if (!cache_gc_request_ || (state->Flags() & kCacheInit)) return state;
    state->SetFlags(kCacheInit, kCacheInit);
    cache_size_ += sizeof(State) + state->NumArcs() * sizeof(Arc);
    cache_gc_ = true;
    MaybeGC(state);
    return state;
  }

  // Adds `arc` via the underlying store and charges it to the cache.
  void AddArc(State *state, const Arc &arc) {
    store_.AddArc(state, arc);
    if (!IsAccounted(state)) return;
    cache_size_ += sizeof(Arc);
    MaybeGC(state);
  }

  // Finalizes arcs via the underlying store and charges them to the cache;
  // call only once.
  void SetArcs(State *state) {
    store_.SetArcs(state);
    if (!IsAccounted(state)) return;
    cache_size_ += state->NumArcs() * sizeof(Arc);
    MaybeGC(state);
  }

  // Uncharges and then removes every arc of `state`.
  void DeleteArcs(State *state) {
    if (IsAccounted(state)) cache_size_ -= state->NumArcs() * sizeof(Arc);
    store_.DeleteArcs(state);
  }

  // Uncharges and then removes the last `n` arcs of `state`.
  void DeleteArcs(State *state, size_t n) {
    if (IsAccounted(state)) cache_size_ -= n * sizeof(Arc);
    store_.DeleteArcs(state, n);
  }

  // Empties the underlying store and zeroes the byte count.
  void Clear() {
    store_.Clear();
    cache_size_ = 0;
  }

  StateId CountStates() const { return store_.CountStates(); }

  // Iteration over cached states, in unspecified order; only needed when GC
  // is enabled.
  bool Done() const { return store_.Done(); }

  StateId Value() const { return store_.Value(); }

  void Next() { store_.Next(); }

  void Reset() { store_.Reset(); }

  // Uncharges the state under the cursor (when it is accounted), deletes it,
  // and advances.
  void Delete() {
    if (cache_gc_) {
      const auto *doomed = store_.GetState(Value());
      if (doomed->Flags() & kCacheInit) {
        cache_size_ -= sizeof(State) + doomed->NumArcs() * sizeof(Arc);
      }
    }
    store_.Delete();
  }

  // Evicts states that are unreferenced, are not `current`, and (unless
  // `free_recent`) were not visited since the previous pass, until at most
  // cache_fraction * cache_limit_ bytes remain. Falls back to evicting recent
  // states too, and finally widens cache_limit_ if still over target.
  void GC(const State *current, bool free_recent, float cache_fraction = 0.666);

  // Returns the current cache size in bytes or 0 if GC is disabled.
  size_t CacheSize() const { return cache_size_; }

  // Returns the cache limit in bytes.
  size_t CacheLimit() const { return cache_limit_; }

 private:
  static constexpr size_t kMinCacheLimit = 8096;  // Minimum cache limit.

  // Whether `state`'s bytes are included in cache_size_.
  bool IsAccounted(const State *state) const {
    return cache_gc_ && (state->Flags() & kCacheInit);
  }

  // Runs a collection pass, sparing `current`, once over the limit.
  void MaybeGC(const State *current) {
    if (cache_size_ > cache_limit_) GC(current, false);
  }

  CacheStore store_;       // Underlying store.
  bool cache_gc_request_;  // GC requested but possibly not yet enabled.
  size_t cache_limit_;     // Number of bytes allowed before GC.
  bool cache_gc_;          // GC enabled.
  size_t cache_size_;      // Number of bytes cached.
};
// Mark-sweep pass over the underlying store. States that are unreferenced,
// are not `current`, and (unless `free_recent`) lack kCacheRecent are deleted
// while the cache exceeds cache_fraction * cache_limit_; survivors have
// kCacheRecent cleared. If that is not enough, the pass is repeated with
// `free_recent` set; failing that, cache_limit_ is doubled until the cache
// fits.
template <class CacheStore>
void GCCacheStore<CacheStore>::GC(const State *current, bool free_recent,
                                  float cache_fraction) {
  if (!cache_gc_) return;
  VLOG(2) << "GCCacheStore: Enter GC: object = "
          << "(" << this << "), free recently cached = " << free_recent
          << ", cache size = " << cache_size_
          << ", cache frac = " << cache_fraction
          << ", cache limit = " << cache_limit_ << "\n";
  size_t cache_target = cache_fraction * cache_limit_;
  store_.Reset();
  while (!store_.Done()) {
    auto *state = store_.GetMutableState(store_.Value());
    if (cache_size_ > cache_target && state->RefCount() == 0 &&
        (free_recent || !(state->Flags() & kCacheRecent)) &&
        state != current) {
      if (state->Flags() & kCacheInit) {
        const size_t size = sizeof(State) + state->NumArcs() * sizeof(Arc);
        // Saturating subtraction: freeing a state whose size equals (or, if
        // the accounting has drifted, exceeds) the remaining total brings the
        // total to zero instead of leaving that state's bytes counted.
        cache_size_ -= std::min(size, cache_size_);
      }
      store_.Delete();
    } else {
      state->SetFlags(0, kCacheRecent);
      store_.Next();
    }
  }
  if (!free_recent && cache_size_ > cache_target) {  // Recurses on recent.
    GC(current, true, cache_fraction);
  } else if (cache_target > 0) {  // Widens cache limit.
    while (cache_size_ > cache_target) {
      cache_limit_ *= 2;
      cache_target *= 2;
    }
  } else if (cache_size_ > 0) {
    FSTERROR() << "GCCacheStore:GC: Unable to free all cached states";
  }
  VLOG(2) << "GCCacheStore: Exit GC: object = "
          << "(" << this << "), free recently cached = " << free_recent
          << ", cache size = " << cache_size_
          << ", cache frac = " << cache_fraction
          << ", cache limit = " << cache_limit_ << "\n";
}
// This class is the default cache state and store used by CacheBaseImpl.
// It uses VectorCacheStore for storage decorated by FirstCacheStore
// and GCCacheStore to do (optional) garbage collection.
template <class Arc>
class DefaultCacheStore
    : public GCCacheStore<FirstCacheStore<VectorCacheStore<CacheState<Arc>>>> {
  // The decorated store stack this class names.
  using DecoratedStore =
      GCCacheStore<FirstCacheStore<VectorCacheStore<CacheState<Arc>>>>;

 public:
  // Forwards the options unchanged to the GC/first-state/vector store stack.
  explicit DefaultCacheStore(const CacheOptions &opts) : DecoratedStore(opts) {}
};
namespace internal {
// This class is used to cache FST elements stored in states of type State
// (see CacheState) with the flags used to indicate what has been cached. Use
// HasStart(), HasFinal(), and HasArcs() to determine if cached and SetStart(),
// SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note that you
// must set the final weight even if the state is non-final to mark it as
// cached. The state storage method and any garbage collection policy are
// determined by the cache store. If the store is passed in with the options,
// CacheBaseImpl takes ownership.
template <class State, class CacheStore = DefaultCacheStore<typename State::Arc>> class CacheBaseImpl : public FstImpl<typename State::Arc> { public: using Arc = typename State::Arc; using StateId = typename Arc::StateId; using Weight = typename Arc::Weight;
using Store = CacheStore;
using FstImpl<Arc>::Type; using FstImpl<Arc>::Properties;
explicit CacheBaseImpl(const CacheOptions &opts = CacheOptions()) : has_start_(false), cache_start_(kNoStateId), nknown_states_(0), min_unexpanded_state_id_(0), max_expanded_state_id_(-1), cache_gc_(opts.gc), cache_limit_(opts.gc_limit), cache_store_(new CacheStore(opts)), new_cache_store_(true), own_cache_store_(true) {}
explicit CacheBaseImpl(const CacheImplOptions<CacheStore> &opts) : has_start_(false), cache_start_(kNoStateId), nknown_states_(0), min_unexpanded_state_id_(0), max_expanded_state_id_(-1), cache_gc_(opts.gc), cache_limit_(opts.gc_limit), cache_store_( opts.store ? opts.store : new CacheStore(CacheOptions(opts.gc, opts.gc_limit))), new_cache_store_(!opts.store), own_cache_store_(opts.store ? opts.own_store : true) {}
// Preserve gc parameters. If preserve_cache is true, also preserves
// cache data.
CacheBaseImpl(const CacheBaseImpl<State, CacheStore> &impl, bool preserve_cache = false) : FstImpl<Arc>(), has_start_(false), cache_start_(kNoStateId), nknown_states_(0), min_unexpanded_state_id_(0), max_expanded_state_id_(-1), cache_gc_(impl.cache_gc_), cache_limit_(impl.cache_limit_), cache_store_(new CacheStore(CacheOptions(cache_gc_, cache_limit_))), new_cache_store_(impl.new_cache_store_ || !preserve_cache), own_cache_store_(true) { if (preserve_cache) { *cache_store_ = *impl.cache_store_; has_start_ = impl.has_start_; cache_start_ = impl.cache_start_; nknown_states_ = impl.nknown_states_; expanded_states_ = impl.expanded_states_; min_unexpanded_state_id_ = impl.min_unexpanded_state_id_; max_expanded_state_id_ = impl.max_expanded_state_id_; } }
~CacheBaseImpl() override { if (own_cache_store_) delete cache_store_; }
void SetStart(StateId s) { cache_start_ = s; has_start_ = true; if (s >= nknown_states_) nknown_states_ = s + 1; }
void SetFinal(StateId s, Weight weight = Weight::One()) { auto *state = cache_store_->GetMutableState(s); state->SetFinal(std::move(weight)); static constexpr auto flags = kCacheFinal | kCacheRecent; state->SetFlags(flags, flags); }
// Adds a single arc to a state but delays cache book-keeping. SetArcs must
// be called when all PushArc and EmplaceArc calls at a state are complete.
// Do not mix with calls to AddArc.
void PushArc(StateId s, const Arc &arc) { auto *state = cache_store_->GetMutableState(s); state->PushArc(arc); }
void PushArc(StateId s, Arc &&arc) { auto *state = cache_store_->GetMutableState(s); state->PushArc(std::move(arc)); }
// Adds a single arc to a state but delays cache book-keeping. SetArcs must
// be called when all PushArc and EmplaceArc calls at a state are complete.
// Do not mix with calls to AddArc.
template <class... T> void EmplaceArc(StateId s, T &&...ctor_args) { auto *state = cache_store_->GetMutableState(s); state->EmplaceArc(std::forward<T>(ctor_args)...); }
// Marks arcs of a state as cached and does cache book-keeping after all
// calls to PushArc have been completed. Do not mix with calls to AddArc.
void SetArcs(StateId s) { auto *state = cache_store_->GetMutableState(s); cache_store_->SetArcs(state); const auto narcs = state->NumArcs(); for (size_t a = 0; a < narcs; ++a) { const auto &arc = state->GetArc(a); if (arc.nextstate >= nknown_states_) nknown_states_ = arc.nextstate + 1; } SetExpandedState(s); static constexpr auto flags = kCacheArcs | kCacheRecent; state->SetFlags(flags, flags); }
void ReserveArcs(StateId s, size_t n) { auto *state = cache_store_->GetMutableState(s); state->ReserveArcs(n); }
void DeleteArcs(StateId s) { auto *state = cache_store_->GetMutableState(s); cache_store_->DeleteArcs(state); }
void DeleteArcs(StateId s, size_t n) { auto *state = cache_store_->GetMutableState(s); cache_store_->DeleteArcs(state, n); }
void Clear() { nknown_states_ = 0; min_unexpanded_state_id_ = 0; max_expanded_state_id_ = -1; has_start_ = false; cache_start_ = kNoStateId; cache_store_->Clear(); }
// Is the start state cached?
bool HasStart() const { if (!has_start_ && Properties(kError)) has_start_ = true; return has_start_; }
// Is the final weight of the state cached?
bool HasFinal(StateId s) const { const auto *state = cache_store_->GetState(s); if (state && state->Flags() & kCacheFinal) { state->SetFlags(kCacheRecent, kCacheRecent); return true; } else { return false; } }
// Are arcs of the state cached?
bool HasArcs(StateId s) const { const auto *state = cache_store_->GetState(s); if (state && state->Flags() & kCacheArcs) { state->SetFlags(kCacheRecent, kCacheRecent); return true; } else { return false; } }
StateId Start() const { return cache_start_; }
Weight Final(StateId s) const { const auto *state = cache_store_->GetState(s); return state->Final(); }
size_t NumArcs(StateId s) const { const auto *state = cache_store_->GetState(s); return state->NumArcs(); }
size_t NumInputEpsilons(StateId s) const { const auto *state = cache_store_->GetState(s); return state->NumInputEpsilons(); }
size_t NumOutputEpsilons(StateId s) const { const auto *state = cache_store_->GetState(s); return state->NumOutputEpsilons(); }
// Provides information needed for generic arc iterator.
void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const { const auto *state = cache_store_->GetState(s); data->base = nullptr; data->narcs = state->NumArcs(); data->arcs = state->Arcs(); data->ref_count = state->MutableRefCount(); state->IncrRefCount(); }
// Number of known states.
StateId NumKnownStates() const { return nknown_states_; }
// Updates number of known states, taking into account the passed state ID.
void UpdateNumKnownStates(StateId s) { if (s >= nknown_states_) nknown_states_ = s + 1; }
// Finds the mininum never-expanded state ID.
StateId MinUnexpandedState() const { while (min_unexpanded_state_id_ <= max_expanded_state_id_ && ExpandedState(min_unexpanded_state_id_)) { ++min_unexpanded_state_id_; } return min_unexpanded_state_id_; }
// Returns maximum ever-expanded state ID.
StateId MaxExpandedState() const { return max_expanded_state_id_; }
void SetExpandedState(StateId s) { if (s > max_expanded_state_id_) max_expanded_state_id_ = s; if (s < min_unexpanded_state_id_) return; if (s == min_unexpanded_state_id_) ++min_unexpanded_state_id_; if (cache_gc_ || cache_limit_ == 0) { if (expanded_states_.size() <= static_cast<size_t>(s)) expanded_states_.resize(s + 1, false); expanded_states_[s] = true; } }
// Returns true if state `s` is known to have been expanded.
//
// - GC enabled or cache limit zero: expansion is tracked explicitly in
//   expanded_states_ (filled by SetExpandedState()). IDs outside the tracked
//   range, including negative IDs, report false instead of being read out
//   of bounds.
// - Otherwise, if this class created the cache store: a state counts as
//   expanded exactly when the store holds an entry for it.
// - For an externally supplied store: this returns false.
bool ExpandedState(StateId s) const {
  if (cache_gc_ || cache_limit_ == 0) {
    // std::vector<bool>::operator[] performs no bounds checking, and this is
    // a public accessor that may be queried for IDs never recorded.
    return s >= 0 && static_cast<size_t>(s) < expanded_states_.size() &&
           expanded_states_[s];
  } else if (new_cache_store_) {
    return cache_store_->GetState(s) != nullptr;
  } else {
    // If the cache was not created by this class, then the cached state
    // needs to be inspected to update nknown_states_.
    return false;
  }
}
// Read-only access to the store of cached states.
const CacheStore *GetCacheStore() const {
  return cache_store_;
}
// Mutable access to the store of cached states.
CacheStore *GetCacheStore() {
  return cache_store_;
}
// Caching configuration accessors.
// Whether garbage collection of cached states is enabled.
bool GetCacheGc() const {
  return cache_gc_;
}
// Number of bytes the cache may hold before GC is triggered.
size_t GetCacheLimit() const {
  return cache_limit_;
}
private: mutable bool has_start_; // Is the start state cached?
StateId cache_start_; // ID of start state.
StateId nknown_states_; // Number of known states.
// Only maintained when cache_gc_ is set or cache_limit_ == 0 (see
// SetExpandedState()); in other configurations ExpandedState() consults the
// cache store instead of this vector.
std::vector<bool> expanded_states_; // States that have been expanded.
// Mutable because the const MinUnexpandedState() advances this cursor.
mutable StateId min_unexpanded_state_id_; // Minimum never-expanded state ID
mutable StateId max_expanded_state_id_; // Maximum ever-expanded state ID
bool cache_gc_; // GC enabled.
size_t cache_limit_; // Number of bytes allowed before GC.
CacheStore *cache_store_; // The store of cached states.
bool new_cache_store_; // Was the store created by this class?
bool own_cache_store_; // Is the store owned by this class?
// Copy assignment is disabled.
CacheBaseImpl &operator=(const CacheBaseImpl &impl) = delete; };
// CacheBaseImpl instantiated with the default per-state cache
// representation, CacheState<Arc>.
template <class Arc>
class CacheImpl : public CacheBaseImpl<CacheState<Arc>> {
 public:
  using State = CacheState<Arc>;

  CacheImpl() = default;

  explicit CacheImpl(const CacheOptions &opts)
      : CacheBaseImpl<State>(opts) {}

  // Copies `impl`; `preserve_cache` is forwarded unchanged to the
  // CacheBaseImpl copy constructor.
  CacheImpl(const CacheImpl<Arc> &impl, bool preserve_cache = false)
      : CacheBaseImpl<State>(impl, preserve_cache) {}

 private:
  CacheImpl &operator=(const CacheImpl &impl) = delete;
};
} // namespace internal
// State iterator for a CacheBaseImpl-derived FST, which must define Arc and
// Store types. States are discovered by expanding them on demand. As a
// result, only states reachable from the initial state are returned, so
// consider implementing a class-specific iterator.
//
// This class may be derived from.
template <class FST>
class CacheStateIterator : public StateIteratorBase<typename FST::Arc> {
 public:
  using Arc = typename FST::Arc;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  using Store = typename FST::Store;
  using State = typename Store::State;
  using Impl = internal::CacheBaseImpl<State, Store>;

  CacheStateIterator(const FST &fst, Impl *impl)
      : fst_(fst), impl_(impl), s_(0) {
    // Querying the start state forces it to be computed.
    fst_.Start();
  }

  // Returns false while the cursor points at a known state. Once the cursor
  // catches up, never-expanded states are expanded one at a time, with
  // kArcNoCache requested on the arc iterator. Each arc target becomes a
  // known state. This stops as soon as a state at the cursor appears, or
  // when nothing remains to expand.
  bool Done() const final {
    if (s_ < impl_->NumKnownStates()) return false;
    for (StateId next = impl_->MinUnexpandedState();
         next < impl_->NumKnownStates();
         next = impl_->MinUnexpandedState()) {
      // Forces expansion of `next`.
      ArcIterator<FST> arcs(fst_, next);
      arcs.SetFlags(kArcValueFlags, kArcValueFlags | kArcNoCache);
      while (!arcs.Done()) {
        impl_->UpdateNumKnownStates(arcs.Value().nextstate);
        arcs.Next();
      }
      impl_->SetExpandedState(next);
      if (s_ < impl_->NumKnownStates()) return false;
    }
    return true;
  }

  StateId Value() const final { return s_; }

  void Next() final { ++s_; }

  void Reset() final { s_ = 0; }

 private:
  const FST &fst_;
  Impl *impl_;
  StateId s_;
};
// Arc iterator over a state cached by a CacheBaseImpl-derived FST, which
// must define Arc and Store types. The state's reference count is
// incremented on construction and decremented on destruction.
template <class FST>
class CacheArcIterator {
 public:
  using Arc = typename FST::Arc;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  using Store = typename FST::Store;
  using State = typename Store::State;
  using Impl = internal::CacheBaseImpl<State, Store>;

  CacheArcIterator(Impl *impl, StateId s)
      : state_(impl->GetCacheStore()->GetMutableState(s)), i_(0) {
    state_->IncrRefCount();
  }

  ~CacheArcIterator() { state_->DecrRefCount(); }

  bool Done() const { return i_ >= state_->NumArcs(); }

  const Arc &Value() const { return state_->GetArc(i_); }

  void Next() { ++i_; }

  size_t Position() const { return i_; }

  void Reset() { i_ = 0; }

  void Seek(size_t a) { i_ = a; }

  // Only arc values are supported; requested flags are ignored.
  constexpr uint8_t Flags() const { return kArcValueFlags; }

  void SetFlags(uint8_t /*flags*/, uint8_t /*mask*/) {}

 private:
  const State *state_;
  size_t i_;

  CacheArcIterator(const CacheArcIterator &) = delete;
  CacheArcIterator &operator=(const CacheArcIterator &) = delete;
};
// Mutable arc iterator over a state cached by a CacheBaseImpl-derived FST,
// which must define Arc and Store types. Holds a reference on the state
// (incremented on construction, decremented on destruction); SetValue()
// overwrites the arc at the current position in place.
template <class FST>
class CacheMutableArcIterator
    : public MutableArcIteratorBase<typename FST::Arc> {
 public:
  using Arc = typename FST::Arc;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  using Store = typename FST::Store;
  using State = typename Store::State;
  using Impl = internal::CacheBaseImpl<State, Store>;

  // The user must call MutateCheck() in the constructor (presumably the
  // owning FST's mutable-iterator setup, before creating this — confirm).
  CacheMutableArcIterator(Impl *impl, StateId s)
      : i_(0),
        s_(s),
        impl_(impl),
        state_(impl->GetCacheStore()->GetMutableState(s)) {
    state_->IncrRefCount();
  }

  ~CacheMutableArcIterator() override { state_->DecrRefCount(); }

  bool Done() const final { return i_ >= state_->NumArcs(); }

  const Arc &Value() const final { return state_->GetArc(i_); }

  void Next() final { ++i_; }

  size_t Position() const final { return i_; }

  void Reset() final { i_ = 0; }

  void Seek(size_t a) final { i_ = a; }

  void SetValue(const Arc &arc) final { state_->SetArc(arc, i_); }

  // Only arc values are supported; requested flags are ignored.
  uint8_t Flags() const final { return kArcValueFlags; }

  void SetFlags(uint8_t, uint8_t) final {}

 private:
  size_t i_;
  StateId s_;
  Impl *impl_;
  State *state_;

  CacheMutableArcIterator(const CacheMutableArcIterator &) = delete;
  CacheMutableArcIterator &operator=(const CacheMutableArcIterator &) = delete;
};
// Adapts an existing CacheStore implementation for use with ExpanderFst.
template <class CacheStore>
class ExpanderCacheStore {
 public:
  using State = typename CacheStore::State;
  using Arc = typename CacheStore::Arc;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  explicit ExpanderCacheStore(const CacheOptions &opts = CacheOptions())
      : store_(opts) {}

  // Returns the cached state for `s`, computing it first if necessary.
  // - A state with any flag set counts as already computed and is just
  //   marked with kCacheRecent.
  // - Otherwise `expander.Expand(s, &builder)` fills it in, all kCacheFlags
  //   are set, and the store is told about the new arcs via SetArcs().
  template <class Expander>
  State *FindOrExpand(Expander &expander, StateId s) {
    auto *state = store_.GetMutableState(s);
    if (!state->Flags()) {
      StateBuilder builder(state);
      expander.Expand(s, &builder);
      state->SetFlags(kCacheFlags, kCacheFlags);
      store_.SetArcs(state);
      return state;
    }
    state->SetFlags(kCacheRecent, kCacheRecent);
    return state;
  }

 private:
  CacheStore store_;

  // Sink handed to Expander::Expand(); forwards arcs and the final weight
  // directly into the cached state.
  struct StateBuilder {
    State *state;

    explicit StateBuilder(State *target) : state(target) {}

    void AddArc(const Arc &arc) { state->PushArc(arc); }

    void AddArc(Arc &&arc) { state->PushArc(std::move(arc)); }

    void SetFinal(Weight weight = Weight::One()) {
      state->SetFinal(std::move(weight));
    }
  };
};
} // namespace fst
#endif // FST_CACHE_H_
|