You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

1337 lines
42 KiB

// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// An FST implementation that caches FST elements of a delayed computation.
#ifndef FST_CACHE_H_
#define FST_CACHE_H_
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <list>
#include <memory>
#include <new>
#include <utility>
#include <vector>
#include <fst/flags.h>
#include <fst/log.h>
#include <fst/fst.h>
#include <fst/memory.h>
#include <fst/mutable-fst.h>
#include <fst/properties.h>
#include <fst/util.h>
#include <fst/vector-fst.h>
#include <unordered_map>
#include <functional>
// Command-line flags (declared here; their definitions live elsewhere) whose
// values are used below as the default `gc` and `gc_limit` arguments of
// CacheOptions and CacheImplOptions.
DECLARE_bool(fst_default_cache_gc);
DECLARE_int64(fst_default_cache_gc_limit);
namespace fst {
// Options for controlling caching behavior; higher level than CacheImplOptions.
struct CacheOptions {
  // When true, cached states may be garbage-collected.
  bool gc;
  // Number of cached bytes tolerated before garbage collection is attempted.
  size_t gc_limit;

  // Both settings default to the process-wide command-line flag values.
  explicit CacheOptions(bool gc = FST_FLAGS_fst_default_cache_gc,
                        size_t gc_limit = FST_FLAGS_fst_default_cache_gc_limit)
      : gc{gc}, gc_limit{gc_limit} {}
};
// Options for controlling caching behavior, at a lower level than
// CacheOptions; templated on the cache store and allows passing the store.
template <class CacheStore>
struct CacheImplOptions {
  bool gc;            // Whether garbage collection is enabled.
  size_t gc_limit;    // Number of cached bytes tolerated before GC.
  CacheStore *store;  // Optional externally supplied store; may be null.
  bool own_store;     // Should CacheImpl take ownership of `store`?

  // Defaults come from the process-wide command-line flags; no store is given
  // unless `store` is passed explicitly.
  explicit CacheImplOptions(
      bool gc = FST_FLAGS_fst_default_cache_gc,
      size_t gc_limit = FST_FLAGS_fst_default_cache_gc_limit,
      CacheStore *store = nullptr)
      : gc{gc}, gc_limit{gc_limit}, store{store}, own_store{true} {}

  // Lowers the higher-level CacheOptions: same GC settings, no store supplied.
  explicit CacheImplOptions(const CacheOptions &opts)
      : CacheImplOptions(opts.gc, opts.gc_limit) {}
};
// Cache flags.
// One bit per piece of cached information recorded on a state.
inline constexpr uint8_t kCacheFinal = 1u << 0;   // Final weight has been cached.
inline constexpr uint8_t kCacheArcs = 1u << 1;    // Arcs have been cached.
inline constexpr uint8_t kCacheInit = 1u << 2;    // Initialized by GC.
inline constexpr uint8_t kCacheRecent = 1u << 3;  // Visited since GC.
// Mask covering every flag above.
inline constexpr uint8_t kCacheFlags =
    kCacheFinal | kCacheArcs | kCacheInit | kCacheRecent;
// Cache state, with arcs stored in a per-state std::vector.
template <class A, class M = PoolAllocator<A>>
class CacheState {
 public:
  using Arc = A;
  using Label = typename Arc::Label;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;
  using ArcAllocator = M;
  using StateAllocator = typename std::allocator_traits<
      ArcAllocator>::template rebind_alloc<CacheState<A, M>>;

  // Creates an empty, non-final state (final weight Weight::Zero(), no arcs,
  // no flags, reference count 0) whose arcs are stored using `alloc`.
  explicit CacheState(const ArcAllocator &alloc)
      : final_weight_(Weight::Zero()),
        niepsilons_(0),
        noepsilons_(0),
        arcs_(alloc),
        flags_(0),
        ref_count_(0) {}

  // Copies the final weight, epsilon counts, arcs and flags of `state`,
  // storing the copied arcs with `alloc`; the reference count starts at 0.
  // Takes the same specialization (including the arc allocator M) via the
  // injected-class-name, so copying also compiles for non-default allocators;
  // for the default M this is the same signature as CacheState<A>.
  CacheState(const CacheState &state, const ArcAllocator &alloc)
      : final_weight_(state.Final()),
        niepsilons_(state.NumInputEpsilons()),
        noepsilons_(state.NumOutputEpsilons()),
        arcs_(state.arcs_.begin(), state.arcs_.end(), alloc),
        flags_(state.Flags()),
        ref_count_(0) {}

  // Returns the state to the same condition as a freshly constructed one.
  void Reset() {
    final_weight_ = Weight::Zero();
    niepsilons_ = 0;
    noepsilons_ = 0;
    ref_count_ = 0;
    flags_ = 0;
    arcs_.clear();
  }

  Weight Final() const { return final_weight_; }

  size_t NumInputEpsilons() const { return niepsilons_; }

  size_t NumOutputEpsilons() const { return noepsilons_; }

  size_t NumArcs() const { return arcs_.size(); }

  const Arc &GetArc(size_t n) const { return arcs_[n]; }

  // Used by the ArcIterator<Fst<Arc>> efficient implementation. Returns
  // nullptr when the state has no arcs.
  const Arc *Arcs() const { return !arcs_.empty() ? &arcs_[0] : nullptr; }

  // Accesses flags; used by the caller.
  uint8_t Flags() const { return flags_; }

  // Accesses ref count; used by the caller.
  int RefCount() const { return ref_count_; }

  void SetFinal(Weight weight = Weight::One()) {
    final_weight_ = std::move(weight);
  }

  void ReserveArcs(size_t n) { arcs_.reserve(n); }

  // Adds one arc at a time with all needed book-keeping; use PushArc and
  // SetArcs for a more efficient alternative.
  void AddArc(const Arc &arc) {
    IncrementNumEpsilons(arc);
    arcs_.push_back(arc);
  }

  void AddArc(Arc &&arc) {
    IncrementNumEpsilons(arc);
    arcs_.push_back(std::move(arc));
  }

  // Adds one arc at a time with delayed book-keeping; finalize with SetArcs().
  void PushArc(const Arc &arc) { arcs_.push_back(arc); }

  void PushArc(Arc &&arc) { arcs_.push_back(std::move(arc)); }

  // Constructs one arc in place with delayed book-keeping; finalize with
  // SetArcs().
  template <class... T>
  void EmplaceArc(T &&...ctor_args) {
    arcs_.emplace_back(std::forward<T>(ctor_args)...);
  }

  // Finalizes arcs book-keeping (epsilon counts) for every stored arc; call
  // only once, otherwise the epsilon counts are double-counted.
  void SetArcs() {
    for (const auto &arc : arcs_) {
      IncrementNumEpsilons(arc);
    }
  }

  // Replaces the nth arc, keeping the epsilon counts consistent.
  void SetArc(const Arc &arc, size_t n) {
    if (arcs_[n].ilabel == 0) --niepsilons_;
    if (arcs_[n].olabel == 0) --noepsilons_;
    IncrementNumEpsilons(arc);
    arcs_[n] = arc;
  }

  // Deletes all arcs.
  void DeleteArcs() {
    niepsilons_ = 0;
    noepsilons_ = 0;
    arcs_.clear();
  }

  // Deletes the last n arcs; n must not exceed NumArcs().
  void DeleteArcs(size_t n) {
    for (size_t i = 0; i < n; ++i) {
      if (arcs_.back().ilabel == 0) --niepsilons_;
      if (arcs_.back().olabel == 0) --noepsilons_;
      arcs_.pop_back();
    }
  }

  // Sets status flags: bits in `mask` are cleared, then `flags` is OR-ed in.
  // Used by the caller.
  void SetFlags(uint8_t flags, uint8_t mask) const {
    flags_ &= ~mask;
    flags_ |= flags;
  }

  // Mutates reference counts; used by the caller.
  int IncrRefCount() const { return ++ref_count_; }

  int DecrRefCount() const { return --ref_count_; }

  // Used by the ArcIterator<Fst<Arc>> efficient implementation.
  int *MutableRefCount() const { return &ref_count_; }

  // Used for state class allocation: storage comes from `alloc`.
  void *operator new(size_t size, StateAllocator *alloc) {
    return alloc->allocate(1);
  }

  // For state destruction and memory freeing; a null `state` is ignored.
  // Uses the injected-class-name so the pointer type matches StateAllocator
  // for any arc allocator M.
  static void Destroy(CacheState *state, StateAllocator *alloc) {
    if (state) {
      state->~CacheState();
      alloc->deallocate(state, 1);
    }
  }

 private:
  // Update the number of epsilons as a result of having added an arc.
  void IncrementNumEpsilons(const Arc &arc) {
    if (arc.ilabel == 0) ++niepsilons_;
    if (arc.olabel == 0) ++noepsilons_;
  }

  Weight final_weight_;                  // Final weight.
  size_t niepsilons_;                    // # of input epsilons.
  size_t noepsilons_;                    // # of output epsilons.
  std::vector<Arc, ArcAllocator> arcs_;  // Arcs representation.
  mutable uint8_t flags_;
  mutable int ref_count_;  // If 0, available for GC.
};
// Cache store, allocating and storing states, providing a mapping from state
// IDs to cached states, and an iterator over these states. The state template
// argument must implement the CacheState interface. The state for a StateId s
// is constructed when requested by GetMutableState(s) if it is not yet stored.
// Initially, a state has a reference count of zero, but the user may increment
// or decrement this to control the time of destruction. In particular, a state
// is destroyed when:
//
// 1. This instance is destroyed, or
// 2. Clear() or Delete() is called, or
// 3. Possibly (implementation-dependently) when:
// - Garbage collection is enabled (as defined by opts.gc),
// - The cache store size exceeds the limit (as defined by opts.gc_limit),
// - The state's reference count is zero, and
// - The state is not the most recently requested state.
//
// template <class S>
// class CacheStore {
// public:
// using State = S;
// using Arc = typename State::Arc;
// using StateId = typename Arc::StateId;
//
// // Required constructors/assignment operators.
// explicit CacheStore(const CacheOptions &opts);
//
// // Returns nullptr if state is not stored.
// const State *GetState(StateId s);
//
// // Creates state if state is not stored.
// State *GetMutableState(StateId s);
//
// // Similar to State::AddArc() but updates cache store book-keeping.
// void AddArc(State *state, const Arc &arc);
//
// // Similar to State::SetArcs() but updates cache store book-keeping; call
// // only once.
// void SetArcs(State *state);
//
// // Similar to State::DeleteArcs() but updates cache store book-keeping.
//
// void DeleteArcs(State *state);
//
// void DeleteArcs(State *state, size_t n);
//
// // Deletes all cached states.
// void Clear();
//
// // Number of cached states.
// StateId CountStates();
//
// // Iterates over cached states (in an arbitrary order); only needed if
// // opts.gc is true.
// bool Done() const; // End of iteration.
// StateId Value() const; // Current state.
// void Next(); // Advances to next state (when !Done).
// void Reset(); // Returns to initial condition.
// void Delete(); // Deletes current state and advances to next.
// };
// Container cache stores.
// This class uses a vector of pointers to states to store cached states.
template <class S>
class VectorCacheStore {
 public:
  using State = S;
  using Arc = typename State::Arc;
  using StateId = typename Arc::StateId;
  using StateList = std::list<StateId, PoolAllocator<StateId>>;

  // Required constructors/assignment operators. When opts.gc is true, a list
  // of stored state IDs is maintained so the store can be iterated.
  explicit VectorCacheStore(const CacheOptions &opts) : cache_gc_(opts.gc) {
    Clear();
    Reset();
  }

  VectorCacheStore(const VectorCacheStore<S> &store)
      : cache_gc_(store.cache_gc_) {
    CopyStates(store);
    Reset();
  }

  ~VectorCacheStore() { Clear(); }

  // Deep-copies `store`, including its GC setting, mirroring the copy
  // constructor.
  VectorCacheStore &operator=(const VectorCacheStore &store) {
    if (this != &store) {
      // Must be copied before CopyStates(), which consults cache_gc_ to decide
      // whether to populate state_list_; otherwise a destination created
      // without GC would silently drop the iteration list.
      cache_gc_ = store.cache_gc_;
      CopyStates(store);
      Reset();
    }
    return *this;
  }

  // True iff `s` indexes an existing slot of state_vec_ (negative IDs, such as
  // kNoStateId, are never in bounds).
  bool InBounds(StateId s) const {
    return s >= 0 && s < static_cast<StateId>(state_vec_.size());
  }

  // Returns nullptr if state is not stored.
  const State *GetState(StateId s) const {
    return InBounds(s) ? state_vec_[s] : nullptr;
  }

  // Creates state if state is not stored; `s` must be non-negative.
  State *GetMutableState(StateId s) {
    State *state = nullptr;
    if (InBounds(s)) {
      state = state_vec_[s];
    } else {
      state_vec_.resize(s + 1, nullptr);
    }
    if (!state) {
      state = new (&state_alloc_) State(arc_alloc_);
      state_vec_[s] = state;
      if (cache_gc_) state_list_.push_back(s);
    }
    return state;
  }

  // Similar to State::AddArc() but updates cache store book-keeping
  void AddArc(State *state, const Arc &arc) { state->AddArc(arc); }

  // Similar to State::SetArcs() but updates cache store book-keeping; call
  // only once.
  void SetArcs(State *state) { state->SetArcs(); }

  // Deletes all arcs.
  void DeleteArcs(State *state) { state->DeleteArcs(); }

  // Deletes some arcs.
  void DeleteArcs(State *state, size_t n) { state->DeleteArcs(n); }

  // Deletes all cached states.
  void Clear() {
    for (State *s : state_vec_) {
      State::Destroy(s, &state_alloc_);
    }
    state_vec_.clear();
    state_list_.clear();
  }

  // Number of non-null slots.
  StateId CountStates() const {
    return std::count_if(state_vec_.begin(), state_vec_.end(),
                         [](const State *s) { return s != nullptr; });
  }

  // Iterates over cached states (in an arbitrary order); only works if GC is
  // enabled (o.w. avoiding state_list_ overhead).
  bool Done() const { return iter_ == state_list_.end(); }

  StateId Value() const { return *iter_; }

  void Next() { ++iter_; }

  void Reset() { iter_ = state_list_.begin(); }

  // Deletes current state and advances to next.
  void Delete() {
    State::Destroy(state_vec_[*iter_], &state_alloc_);
    state_vec_[*iter_] = nullptr;
    state_list_.erase(iter_++);
  }

 private:
  // Replaces the contents of this store with deep copies of `store`'s states,
  // allocated from this store's allocators.
  void CopyStates(const VectorCacheStore<State> &store) {
    Clear();
    state_vec_.reserve(store.state_vec_.size());
    for (size_t s = 0; s < store.state_vec_.size(); ++s) {
      State *state = nullptr;
      const auto *store_state = store.state_vec_[s];
      if (store_state) {
        state = new (&state_alloc_) State(*store_state, arc_alloc_);
        if (cache_gc_) state_list_.push_back(s);
      }
      state_vec_.push_back(state);
    }
  }

  bool cache_gc_;                               // Supports iteration when true.
  std::vector<State *> state_vec_;              // Vector of states (or null).
  StateList state_list_;                        // List of states.
  typename StateList::iterator iter_;           // State list iterator.
  typename State::StateAllocator state_alloc_;  // For state allocation.
  typename State::ArcAllocator arc_alloc_;      // For arc allocation.
};
// This class uses a hash map from state IDs to pointers to cached states.
template <class S>
class HashCacheStore {
 public:
  using State = S;
  using Arc = typename State::Arc;
  using StateId = typename Arc::StateId;
  using StateMap =
      std::unordered_map<StateId, State *, std::hash<StateId>,
                         std::equal_to<StateId>,
                         PoolAllocator<std::pair<const StateId, State *>>>;

  // Required constructors/assignment operators. The options are accepted for
  // interface uniformity with the other stores but are not consulted.
  explicit HashCacheStore(const CacheOptions &opts) {
    Clear();
    Reset();
  }

  HashCacheStore(const HashCacheStore<S> &store) {
    CopyStates(store);
    Reset();
  }

  ~HashCacheStore() { Clear(); }

  HashCacheStore &operator=(const HashCacheStore &store) {
    if (this == &store) return *this;
    CopyStates(store);
    Reset();
    return *this;
  }

  // Looks up `s`; yields nullptr when no state is stored under that ID.
  const State *GetState(StateId s) const {
    if (auto found = state_map_.find(s); found != state_map_.end()) {
      return found->second;
    }
    return nullptr;
  }

  // Returns the state stored under `s`, allocating an empty one on first use.
  State *GetMutableState(StateId s) {
    State *&slot = state_map_[s];
    if (slot == nullptr) slot = new (&state_alloc_) State(arc_alloc_);
    return slot;
  }

  // Forwards to State::AddArc(); this store keeps no extra book-keeping.
  void AddArc(State *state, const Arc &arc) { state->AddArc(arc); }

  // Forwards to State::SetArcs(); call only once per state.
  void SetArcs(State *state) { state->SetArcs(); }

  // Removes every arc of `state`.
  void DeleteArcs(State *state) { state->DeleteArcs(); }

  // Removes the last `n` arcs of `state`.
  void DeleteArcs(State *state, size_t n) { state->DeleteArcs(n); }

  // Destroys every cached state and empties the map.
  void Clear() {
    for (const auto &entry : state_map_) {
      State::Destroy(entry.second, &state_alloc_);
    }
    state_map_.clear();
  }

  StateId CountStates() const { return state_map_.size(); }

  // Iteration over stored states, in unspecified (hash) order.
  bool Done() const { return iter_ == state_map_.end(); }

  StateId Value() const { return iter_->first; }

  void Next() { ++iter_; }

  void Reset() { iter_ = state_map_.begin(); }

  // Destroys the state under the cursor and moves the cursor to the next one.
  void Delete() {
    State::Destroy(iter_->second, &state_alloc_);
    iter_ = state_map_.erase(iter_);
  }

 private:
  // Discards current contents, then deep-copies every state of `store` using
  // this store's allocators.
  void CopyStates(const HashCacheStore<State> &store) {
    Clear();
    for (const auto &entry : store.state_map_) {
      state_map_.emplace(entry.first,
                         new (&state_alloc_) State(*entry.second, arc_alloc_));
    }
  }

  StateMap state_map_;                          // Map from state ID to state.
  typename StateMap::iterator iter_;            // State map iterator.
  typename State::StateAllocator state_alloc_;  // For state allocation.
  typename State::ArcAllocator arc_alloc_;      // For arc allocation.
};
// Garbage-collection cache stores.
// This class implements a simple garbage collection scheme when
// 'opts.gc_limit = 0'. In particular, the first cached state is reused for each
// new state so long as the reference count is zero on the to-be-reused state.
// Otherwise, the full underlying store is used. The caller can increment the
// reference count to inhibit the GC of in-use states (e.g., in an ArcIterator).
//
// The typical use case for this optimization is when a single pass over a
// cached FST is performed with only one state expanded at a time.
template <class CacheStore>
class FirstCacheStore {
 public:
  using State = typename CacheStore::State;
  using Arc = typename State::Arc;
  using StateId = typename Arc::StateId;

  // Required constructors/assignment operators. First-state reuse is enabled
  // iff opts.gc_limit == 0; opts.gc is ignored here for historical reasons.
  explicit FirstCacheStore(const CacheOptions &opts)
      : store_(opts),
        cache_gc_(opts.gc_limit == 0),
        cache_first_state_id_(kNoStateId),
        cache_first_state_(nullptr) {}

  FirstCacheStore(const FirstCacheStore<CacheStore> &store)
      : store_(store.store_),
        cache_gc_(store.cache_gc_),
        cache_first_state_id_(store.cache_first_state_id_),
        cache_first_state_(FirstStateInStore()) {}

  FirstCacheStore<CacheStore> &operator=(
      const FirstCacheStore<CacheStore> &store) {
    if (this == &store) return *this;
    store_ = store.store_;
    cache_gc_ = store.cache_gc_;
    cache_first_state_id_ = store.cache_first_state_id_;
    // Re-point at this object's own copy of slot 0, not the source's.
    cache_first_state_ = FirstStateInStore();
    return *this;
  }

  // Returns nullptr if state is not stored. Slot 0 of store_ may hold the
  // first cached state; any other state s lives in slot s + 1.
  const State *GetState(StateId s) const {
    if (s == cache_first_state_id_) return cache_first_state_;
    return store_.GetState(s + 1);
  }

  // Creates state if state is not stored, using the same slot layout as
  // GetState().
  State *GetMutableState(StateId s) {
    if (s == cache_first_state_id_) return cache_first_state_;
    if (cache_gc_) {
      if (cache_first_state_id_ == kNoStateId) {
        // No first state yet: claim slot 0 for s.
        cache_first_state_id_ = s;
        cache_first_state_ = store_.GetMutableState(0);
        cache_first_state_->SetFlags(kCacheInit, kCacheInit);
        cache_first_state_->ReserveArcs(2 * kAllocSize);
        return cache_first_state_;
      }
      if (cache_first_state_->RefCount() == 0) {
        // Slot 0 is unreferenced: recycle it for s.
        cache_first_state_id_ = s;
        cache_first_state_->Reset();
        cache_first_state_->SetFlags(kCacheInit, kCacheInit);
        return cache_first_state_;
      }
      // Slot 0 is still referenced: keep it, clear its initialized bit, and
      // permanently fall back to the full underlying store.
      cache_first_state_->SetFlags(0, kCacheInit);
      cache_gc_ = false;
    }
    return store_.GetMutableState(s + 1);
  }

  // Similar to State::AddArc() but updates cache store book-keeping.
  void AddArc(State *state, const Arc &arc) { store_.AddArc(state, arc); }

  // Similar to State::SetArcs() but updates internal cache size; call only
  // once.
  void SetArcs(State *state) { store_.SetArcs(state); }

  // Deletes all arcs.
  void DeleteArcs(State *state) { store_.DeleteArcs(state); }

  // Deletes some arcs.
  void DeleteArcs(State *state, size_t n) { store_.DeleteArcs(state, n); }

  // Deletes all cached states and forgets the first cached state.
  void Clear() {
    store_.Clear();
    cache_first_state_id_ = kNoStateId;
    cache_first_state_ = nullptr;
  }

  StateId CountStates() const { return store_.CountStates(); }

  // Iterates over cached states (in an arbitrary order). Only needed if GC is
  // enabled.
  bool Done() const { return store_.Done(); }

  // Maps the underlying slot back to a state ID (slot 0 is the first state).
  StateId Value() const {
    const StateId slot = store_.Value();
    if (slot == 0) return cache_first_state_id_;
    return slot - 1;
  }

  void Next() { store_.Next(); }

  void Reset() { store_.Reset(); }

  // Deletes current state and advances to next.
  void Delete() {
    const bool deleting_first = Value() == cache_first_state_id_;
    if (deleting_first) {
      cache_first_state_id_ = kNoStateId;
      cache_first_state_ = nullptr;
    }
    store_.Delete();
  }

 private:
  // Returns slot 0 of store_ if a first cached state is recorded, else null.
  State *FirstStateInStore() {
    return cache_first_state_id_ == kNoStateId ? nullptr
                                               : store_.GetMutableState(0);
  }

  CacheStore store_;              // Underlying store.
  bool cache_gc_;                 // GC enabled.
  StateId cache_first_state_id_;  // First cached state ID.
  State *cache_first_state_;      // First cached state.
};
// This class implements mark-sweep garbage collection on an underlying cache
// store. If GC is enabled, garbage collection of states is performed in a
// rough approximation of LRU order whenever more than 'gc_limit' bytes are
// cached. The caller can increment the reference count to inhibit the GC of
// in-use states (e.g., in an ArcIterator). With GC enabled, the 'gc_limit'
// parameter allows the caller to trade off time vs. space.
template <class CacheStore>
class GCCacheStore {
 public:
  using State = typename CacheStore::State;
  using Arc = typename State::Arc;
  using StateId = typename Arc::StateId;

  // Required constructors/assignment operators. GC is only requested here; it
  // becomes active once the first uninitialized state is fetched. The byte
  // limit is raised to kMinCacheLimit if opts.gc_limit is smaller.
  explicit GCCacheStore(const CacheOptions &opts)
      : store_(opts),
        cache_gc_request_(opts.gc),
        cache_limit_(std::max(opts.gc_limit, kMinCacheLimit)),
        cache_gc_(false),
        cache_size_(0) {}

  // Returns nullptr if state is not stored.
  const State *GetState(StateId s) const { return store_.GetState(s); }

  // Creates state if state is not stored. When GC was requested, a state not
  // yet marked kCacheInit is marked, charged to the cache size, and turns GC
  // on.
  State *GetMutableState(StateId s) {
    State *state = store_.GetMutableState(s);
    const bool needs_init =
        cache_gc_request_ && !(state->Flags() & kCacheInit);
    if (needs_init) {
      state->SetFlags(kCacheInit, kCacheInit);
      cache_size_ += StateBytes(state);
      cache_gc_ = true;
      GCIfOverLimit(state);
    }
    return state;
  }

  // Similar to State::AddArc() but charges the new arc to the cache size.
  void AddArc(State *state, const Arc &arc) {
    store_.AddArc(state, arc);
    if (!IsTracked(state)) return;
    cache_size_ += sizeof(Arc);
    GCIfOverLimit(state);
  }

  // Similar to State::SetArcs() but charges all arcs to the cache size; call
  // only once.
  void SetArcs(State *state) {
    store_.SetArcs(state);
    if (!IsTracked(state)) return;
    cache_size_ += state->NumArcs() * sizeof(Arc);
    GCIfOverLimit(state);
  }

  // Deletes all arcs, crediting their size back.
  void DeleteArcs(State *state) {
    if (IsTracked(state)) cache_size_ -= state->NumArcs() * sizeof(Arc);
    store_.DeleteArcs(state);
  }

  // Deletes the last n arcs, crediting their size back.
  void DeleteArcs(State *state, size_t n) {
    if (IsTracked(state)) cache_size_ -= n * sizeof(Arc);
    store_.DeleteArcs(state, n);
  }

  // Deletes all cached states.
  void Clear() {
    store_.Clear();
    cache_size_ = 0;
  }

  StateId CountStates() const { return store_.CountStates(); }

  // Iterates over cached states (in an arbitrary order); only needed if GC is
  // enabled.
  bool Done() const { return store_.Done(); }

  StateId Value() const { return store_.Value(); }

  void Next() { store_.Next(); }

  void Reset() { store_.Reset(); }

  // Deletes current state and advances to next.
  void Delete() {
    if (cache_gc_) {
      const State *state = store_.GetState(Value());
      if (state->Flags() & kCacheInit) cache_size_ -= StateBytes(state);
    }
    store_.Delete();
  }

  // Removes from the cache store (not reference-counted and not the current)
  // states that have not been accessed since the last GC until at most
  // cache_fraction * cache_limit_ bytes are cached. If that fails to free
  // enough, attempts to uncache recently visited states as well. If still
  // unable to free enough memory, then widens cache_limit_.
  void GC(const State *current, bool free_recent, float cache_fraction = 0.666);

  // Returns the current cache size in bytes (only accumulated while GC is
  // requested/enabled).
  size_t CacheSize() const { return cache_size_; }

  // Returns the cache limit in bytes.
  size_t CacheLimit() const { return cache_limit_; }

 private:
  // Bytes charged for a state: the state object plus its arcs.
  static size_t StateBytes(const State *state) {
    return sizeof(State) + state->NumArcs() * sizeof(Arc);
  }

  // Whether `state`'s size is being accounted in cache_size_.
  bool IsTracked(const State *state) const {
    return cache_gc_ && (state->Flags() & kCacheInit);
  }

  // Runs a (non-recent) collection if the cache has outgrown its limit.
  void GCIfOverLimit(const State *state) {
    if (cache_size_ > cache_limit_) GC(state, false);
  }

  static constexpr size_t kMinCacheLimit = 8096;  // Minimum cache limit.

  CacheStore store_;       // Underlying store.
  bool cache_gc_request_;  // GC requested but possibly not yet enabled.
  size_t cache_limit_;     // Number of bytes allowed before GC.
  bool cache_gc_;          // GC enabled
  size_t cache_size_;      // Number of bytes cached.
};
// Mark-sweep pass over the underlying store. Deletes states that are
// unreferenced, not `current`, and (unless free_recent) not visited since the
// last pass, until cache_size_ drops to cache_fraction * cache_limit_. Each
// surviving state has its kCacheRecent bit cleared. If not enough was freed,
// retries once with free_recent = true; if that still fails, doubles
// cache_limit_ until the target covers the current size. No-op unless GC is
// enabled.
template <class CacheStore>
void GCCacheStore<CacheStore>::GC(const State *current, bool free_recent,
                                  float cache_fraction) {
  if (!cache_gc_) return;
  VLOG(2) << "GCCacheStore: Enter GC: object = "
          << "(" << this << "), free recently cached = " << free_recent
          << ", cache size = " << cache_size_
          << ", cache frac = " << cache_fraction
          << ", cache limit = " << cache_limit_ << "\n";
  size_t cache_target = static_cast<size_t>(cache_fraction * cache_limit_);
  store_.Reset();
  while (!store_.Done()) {
    auto *state = store_.GetMutableState(store_.Value());
    if (cache_size_ > cache_target && state->RefCount() == 0 &&
        (free_recent || !(state->Flags() & kCacheRecent)) && state != current) {
      if (state->Flags() & kCacheInit) {
        const size_t size = sizeof(State) + state->NumArcs() * sizeof(Arc);
        // Guards against unsigned underflow when the recorded size is smaller
        // than this state's nominal size. `<=` (not `<`) so that freeing a
        // state whose size exactly equals cache_size_ brings it to zero
        // instead of leaving a phantom byte count behind.
        if (size <= cache_size_) {
          cache_size_ -= size;
        }
      }
      store_.Delete();
    } else {
      state->SetFlags(0, kCacheRecent);
      store_.Next();
    }
  }
  if (!free_recent && cache_size_ > cache_target) {  // Recurses on recent.
    GC(current, true, cache_fraction);
  } else if (cache_target > 0) {  // Widens cache limit.
    while (cache_size_ > cache_target) {
      cache_limit_ *= 2;
      cache_target *= 2;
    }
  } else if (cache_size_ > 0) {
    FSTERROR() << "GCCacheStore:GC: Unable to free all cached states";
  }
  VLOG(2) << "GCCacheStore: Exit GC: object = "
          << "(" << this << "), free recently cached = " << free_recent
          << ", cache size = " << cache_size_
          << ", cache frac = " << cache_fraction
          << ", cache limit = " << cache_limit_ << "\n";
}
// This class is the default cache state and store used by CacheBaseImpl.
// It uses VectorCacheStore for storage decorated by FirstCacheStore
// and GCCacheStore to do (optional) garbage collection.
template <class Arc>
class DefaultCacheStore
    : public GCCacheStore<FirstCacheStore<VectorCacheStore<CacheState<Arc>>>> {
  // The decorated store stack this class gives a short name to.
  using Base =
      GCCacheStore<FirstCacheStore<VectorCacheStore<CacheState<Arc>>>>;

 public:
  // All behavior comes from the base; this only forwards the options.
  explicit DefaultCacheStore(const CacheOptions &opts) : Base(opts) {}
};
namespace internal {
// This class is used to cache FST elements stored in states of type State
// (see CacheState) with the flags used to indicate what has been cached. Use
// HasStart(), HasFinal(), and HasArcs() to determine if cached and SetStart(),
// SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note that you
// must set the final weight even if the state is non-final to mark it as
// cached. The state storage method and any garbage collection policy are
// determined by the cache store. If the store is passed in with the options,
// CacheBaseImpl takes ownership.
template <class State,
class CacheStore = DefaultCacheStore<typename State::Arc>>
class CacheBaseImpl : public FstImpl<typename State::Arc> {
public:
using Arc = typename State::Arc;
using StateId = typename Arc::StateId;
using Weight = typename Arc::Weight;
using Store = CacheStore;
using FstImpl<Arc>::Type;
using FstImpl<Arc>::Properties;
explicit CacheBaseImpl(const CacheOptions &opts = CacheOptions())
: has_start_(false),
cache_start_(kNoStateId),
nknown_states_(0),
min_unexpanded_state_id_(0),
max_expanded_state_id_(-1),
cache_gc_(opts.gc),
cache_limit_(opts.gc_limit),
cache_store_(new CacheStore(opts)),
new_cache_store_(true),
own_cache_store_(true) {}
explicit CacheBaseImpl(const CacheImplOptions<CacheStore> &opts)
: has_start_(false),
cache_start_(kNoStateId),
nknown_states_(0),
min_unexpanded_state_id_(0),
max_expanded_state_id_(-1),
cache_gc_(opts.gc),
cache_limit_(opts.gc_limit),
cache_store_(
opts.store ? opts.store
: new CacheStore(CacheOptions(opts.gc, opts.gc_limit))),
new_cache_store_(!opts.store),
own_cache_store_(opts.store ? opts.own_store : true) {}
// Preserve gc parameters. If preserve_cache is true, also preserves
// cache data.
CacheBaseImpl(const CacheBaseImpl<State, CacheStore> &impl,
bool preserve_cache = false)
: FstImpl<Arc>(),
has_start_(false),
cache_start_(kNoStateId),
nknown_states_(0),
min_unexpanded_state_id_(0),
max_expanded_state_id_(-1),
cache_gc_(impl.cache_gc_),
cache_limit_(impl.cache_limit_),
cache_store_(new CacheStore(CacheOptions(cache_gc_, cache_limit_))),
new_cache_store_(impl.new_cache_store_ || !preserve_cache),
own_cache_store_(true) {
if (preserve_cache) {
*cache_store_ = *impl.cache_store_;
has_start_ = impl.has_start_;
cache_start_ = impl.cache_start_;
nknown_states_ = impl.nknown_states_;
expanded_states_ = impl.expanded_states_;
min_unexpanded_state_id_ = impl.min_unexpanded_state_id_;
max_expanded_state_id_ = impl.max_expanded_state_id_;
}
}
~CacheBaseImpl() override {
if (own_cache_store_) delete cache_store_;
}
void SetStart(StateId s) {
cache_start_ = s;
has_start_ = true;
if (s >= nknown_states_) nknown_states_ = s + 1;
}
void SetFinal(StateId s, Weight weight = Weight::One()) {
auto *state = cache_store_->GetMutableState(s);
state->SetFinal(std::move(weight));
static constexpr auto flags = kCacheFinal | kCacheRecent;
state->SetFlags(flags, flags);
}
// Adds a single arc to a state but delays cache book-keeping. SetArcs must
// be called when all PushArc and EmplaceArc calls at a state are complete.
// Do not mix with calls to AddArc.
void PushArc(StateId s, const Arc &arc) {
auto *state = cache_store_->GetMutableState(s);
state->PushArc(arc);
}
void PushArc(StateId s, Arc &&arc) {
auto *state = cache_store_->GetMutableState(s);
state->PushArc(std::move(arc));
}
// Adds a single arc to a state but delays cache book-keeping. SetArcs must
// be called when all PushArc and EmplaceArc calls at a state are complete.
// Do not mix with calls to AddArc.
template <class... T>
void EmplaceArc(StateId s, T &&...ctor_args) {
auto *state = cache_store_->GetMutableState(s);
state->EmplaceArc(std::forward<T>(ctor_args)...);
}
// Marks arcs of a state as cached and does cache book-keeping after all
// calls to PushArc have been completed. Do not mix with calls to AddArc.
void SetArcs(StateId s) {
auto *state = cache_store_->GetMutableState(s);
cache_store_->SetArcs(state);
const auto narcs = state->NumArcs();
for (size_t a = 0; a < narcs; ++a) {
const auto &arc = state->GetArc(a);
if (arc.nextstate >= nknown_states_) nknown_states_ = arc.nextstate + 1;
}
SetExpandedState(s);
static constexpr auto flags = kCacheArcs | kCacheRecent;
state->SetFlags(flags, flags);
}
void ReserveArcs(StateId s, size_t n) {
auto *state = cache_store_->GetMutableState(s);
state->ReserveArcs(n);
}
void DeleteArcs(StateId s) {
auto *state = cache_store_->GetMutableState(s);
cache_store_->DeleteArcs(state);
}
void DeleteArcs(StateId s, size_t n) {
auto *state = cache_store_->GetMutableState(s);
cache_store_->DeleteArcs(state, n);
}
void Clear() {
nknown_states_ = 0;
min_unexpanded_state_id_ = 0;
max_expanded_state_id_ = -1;
has_start_ = false;
cache_start_ = kNoStateId;
cache_store_->Clear();
}
// Is the start state cached?
bool HasStart() const {
if (!has_start_ && Properties(kError)) has_start_ = true;
return has_start_;
}
// Is the final weight of the state cached?
bool HasFinal(StateId s) const {
const auto *state = cache_store_->GetState(s);
if (state && state->Flags() & kCacheFinal) {
state->SetFlags(kCacheRecent, kCacheRecent);
return true;
} else {
return false;
}
}
// Are arcs of the state cached?
bool HasArcs(StateId s) const {
const auto *state = cache_store_->GetState(s);
if (state && state->Flags() & kCacheArcs) {
state->SetFlags(kCacheRecent, kCacheRecent);
return true;
} else {
return false;
}
}
StateId Start() const { return cache_start_; }
Weight Final(StateId s) const {
const auto *state = cache_store_->GetState(s);
return state->Final();
}
size_t NumArcs(StateId s) const {
const auto *state = cache_store_->GetState(s);
return state->NumArcs();
}
size_t NumInputEpsilons(StateId s) const {
const auto *state = cache_store_->GetState(s);
return state->NumInputEpsilons();
}
size_t NumOutputEpsilons(StateId s) const {
const auto *state = cache_store_->GetState(s);
return state->NumOutputEpsilons();
}
// Provides information needed for generic arc iterator.
void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
const auto *state = cache_store_->GetState(s);
data->base = nullptr;
data->narcs = state->NumArcs();
data->arcs = state->Arcs();
data->ref_count = state->MutableRefCount();
state->IncrRefCount();
}
// Number of known states.
StateId NumKnownStates() const { return nknown_states_; }
// Updates number of known states, taking into account the passed state ID.
void UpdateNumKnownStates(StateId s) {
if (s >= nknown_states_) nknown_states_ = s + 1;
}
// Finds the mininum never-expanded state ID.
StateId MinUnexpandedState() const {
while (min_unexpanded_state_id_ <= max_expanded_state_id_ &&
ExpandedState(min_unexpanded_state_id_)) {
++min_unexpanded_state_id_;
}
return min_unexpanded_state_id_;
}
// Returns maximum ever-expanded state ID.
StateId MaxExpandedState() const { return max_expanded_state_id_; }
// Records that state s has been expanded.
// - Always raises the maximum expanded ID if needed.
// - IDs below the minimum-unexpanded cursor are already covered and stop here.
// - Otherwise the cursor steps past s when it pointed at s.
// - The per-state bit is kept only when GC is on or the cache limit is 0.
void SetExpandedState(StateId s) {
  if (max_expanded_state_id_ < s) max_expanded_state_id_ = s;
  if (s < min_unexpanded_state_id_) return;
  if (s == min_unexpanded_state_id_) ++min_unexpanded_state_id_;
  if (!cache_gc_ && cache_limit_ != 0) return;  // Bits are not tracked.
  const auto index = static_cast<size_t>(s);
  if (expanded_states_.size() <= index) {
    expanded_states_.resize(index + 1, false);
  }
  expanded_states_[index] = true;
}
// Reports whether state s has ever been expanded.
//
// When expansion is tracked per state (GC enabled or a cache limit of 0), the
// answer comes from expanded_states_. An ID outside that vector reports false
// instead of reading past its end. Otherwise, a store created by this class is
// asked whether it holds s.
bool ExpandedState(StateId s) const {
  if (cache_gc_ || cache_limit_ == 0) {
    // This method is public and std::vector<bool>::operator[] does no bounds
    // checking, so guard the index. A negative s converts to a very large
    // size_t and is rejected by the same test.
    const auto index = static_cast<size_t>(s);
    return index < expanded_states_.size() && expanded_states_[index];
  } else if (new_cache_store_) {
    return cache_store_->GetState(s) != nullptr;
  } else {
    // If the cache was not created by this class, then the cached state needs
    // to be inspected to update nknown_states_.
    return false;
  }
}
// Read-only access to the underlying store of cached states.
const CacheStore *GetCacheStore() const {
  return cache_store_;
}
// Mutable access to the underlying store of cached states.
CacheStore *GetCacheStore() {
  return cache_store_;
}
// Caching configuration accessors. Returns true when cache GC is enabled.
bool GetCacheGc() const {
  return cache_gc_;
}
// Returns the number of bytes allowed before GC.
size_t GetCacheLimit() const {
  return cache_limit_;
}
private:
// has_start_ is mutable because the const HasStart() may set it.
mutable bool has_start_; // Is the start state cached?
StateId cache_start_; // ID of start state.
StateId nknown_states_; // Number of known states.
// Indexed by state ID. SetExpandedState() maintains it only when cache_gc_ is
// set or cache_limit_ == 0.
std::vector<bool> expanded_states_; // States that have been expanded.
// min_unexpanded_state_id_ is mutable because the const MinUnexpandedState()
// advances it lazily.
mutable StateId min_unexpanded_state_id_; // Minimum never-expanded state ID
mutable StateId max_expanded_state_id_; // Maximum ever-expanded state ID
bool cache_gc_; // GC enabled.
size_t cache_limit_; // Number of bytes allowed before GC.
CacheStore *cache_store_; // The store of cached states.
bool new_cache_store_; // Was the store created by this class?
// NOTE(review): presumably means cache_store_ is deleted by this class; the
// destructor is not visible here, so confirm.
bool own_cache_store_; // Is the store owned by this class?
// Assignment is disabled.
CacheBaseImpl &operator=(const CacheBaseImpl &impl) = delete;
};
// CacheBaseImpl instantiated with the default cache state type, CacheState<Arc>.
template <class Arc>
class CacheImpl : public CacheBaseImpl<CacheState<Arc>> {
 public:
  using State = CacheState<Arc>;

  // Default-constructs the underlying cache implementation.
  CacheImpl() = default;

  // Forwards the caching options to CacheBaseImpl.
  explicit CacheImpl(const CacheOptions &opts) : CacheBaseImpl<State>(opts) {}

  // Copy-constructs from impl; preserve_cache is forwarded unchanged to
  // CacheBaseImpl's copy constructor.
  CacheImpl(const CacheImpl &impl, bool preserve_cache = false)
      : CacheBaseImpl<State>(impl, preserve_cache) {}

 private:
  // Assignment is disabled.
  CacheImpl &operator=(const CacheImpl &) = delete;
};
} // namespace internal
// State iterator for a CacheBaseImpl-derived FST; FST must define the Arc and
// Store types. States are discovered lazily. Whenever the known states are
// exhausted, the smallest never-expanded state is expanded and each arc's
// destination is recorded as known. Only states discovered by expanding outward
// from the initial state are therefore visited, so consider implementing a
// class-specific iterator.
//
// This class may be derived from.
template <class FST>
class CacheStateIterator : public StateIteratorBase<typename FST::Arc> {
 public:
  using Arc = typename FST::Arc;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;
  using Store = typename FST::Store;
  using State = typename Store::State;
  using Impl = internal::CacheBaseImpl<State, Store>;

  CacheStateIterator(const FST &fst, Impl *impl)
      : fst_(fst), impl_(impl), s_(0) {
    fst_.Start();  // Forces computation of the start state.
  }

  bool Done() const final {
    if (s_ < impl_->NumKnownStates()) return false;
    // Keep expanding the smallest unexpanded state. Stop as soon as the cursor
    // s_ falls within the known states (not done), or when every known state
    // has been expanded (done).
    while (true) {
      const StateId next = impl_->MinUnexpandedState();
      if (next >= impl_->NumKnownStates()) return true;
      // Iterating the arcs forces expansion of `next`.
      ArcIterator<FST> arcs(fst_, next);
      arcs.SetFlags(kArcValueFlags, kArcValueFlags | kArcNoCache);
      for (; !arcs.Done(); arcs.Next()) {
        impl_->UpdateNumKnownStates(arcs.Value().nextstate);
      }
      impl_->SetExpandedState(next);
      if (s_ < impl_->NumKnownStates()) return false;
    }
  }

  StateId Value() const final {
    return s_;
  }

  void Next() final {
    ++s_;
  }

  void Reset() final {
    s_ = 0;
  }

 private:
  const FST &fst_;
  Impl *impl_;
  StateId s_;  // Current state ID.
};
// Arc iterator for a CacheBaseImpl-derived FST; FST must define the Arc and
// Store types. The state is fetched from the cache store once at construction.
// It holds a reference (IncrRefCount) for the iterator's lifetime, released in
// the destructor.
template <class FST>
class CacheArcIterator {
 public:
  using Arc = typename FST::Arc;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;
  using Store = typename FST::Store;
  using State = typename Store::State;
  using Impl = internal::CacheBaseImpl<State, Store>;

  CacheArcIterator(Impl *impl, StateId s)
      : state_(impl->GetCacheStore()->GetMutableState(s)), i_(0) {
    state_->IncrRefCount();
  }

  ~CacheArcIterator() { state_->DecrRefCount(); }

  bool Done() const { return i_ >= state_->NumArcs(); }

  const Arc &Value() const { return state_->GetArc(i_); }

  void Next() { ++i_; }

  size_t Position() const { return i_; }

  void Reset() { i_ = 0; }

  void Seek(size_t a) { i_ = a; }

  constexpr uint8_t Flags() const { return kArcValueFlags; }

  // Flags are fixed; requests to change them are ignored.
  void SetFlags(uint8_t, uint8_t) {}

 private:
  const State *state_;  // Referenced cached state.
  size_t i_;            // Current arc position.

  CacheArcIterator(const CacheArcIterator &) = delete;
  CacheArcIterator &operator=(const CacheArcIterator &) = delete;
};
// Mutable arc iterator for a CacheBaseImpl-derived FST; FST must define the Arc
// and Store types. SetValue() writes directly into the cached state. That state
// is reference-counted for the lifetime of the iterator.
template <class FST>
class CacheMutableArcIterator
    : public MutableArcIteratorBase<typename FST::Arc> {
 public:
  using Arc = typename FST::Arc;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;
  using Store = typename FST::Store;
  using State = typename Store::State;
  using Impl = internal::CacheBaseImpl<State, Store>;

  // User must call MutateCheck() in the constructor.
  CacheMutableArcIterator(Impl *impl, StateId s)
      : i_(0),
        s_(s),
        impl_(impl),
        state_(impl->GetCacheStore()->GetMutableState(s)) {
    state_->IncrRefCount();
  }

  ~CacheMutableArcIterator() override { state_->DecrRefCount(); }

  bool Done() const final { return i_ >= state_->NumArcs(); }

  const Arc &Value() const final { return state_->GetArc(i_); }

  void Next() final { ++i_; }

  size_t Position() const final { return i_; }

  void Reset() final { i_ = 0; }

  void Seek(size_t a) final { i_ = a; }

  // Overwrites the arc at the current position in the cached state.
  void SetValue(const Arc &arc) final { state_->SetArc(arc, i_); }

  uint8_t Flags() const final { return kArcValueFlags; }

  // Flags are fixed; requests to change them are ignored.
  void SetFlags(uint8_t, uint8_t) final {}

 private:
  size_t i_;      // Current arc position.
  StateId s_;     // State being iterated.
  Impl *impl_;    // Owning cache implementation.
  State *state_;  // Referenced cached state.

  CacheMutableArcIterator(const CacheMutableArcIterator &) = delete;
  CacheMutableArcIterator &operator=(const CacheMutableArcIterator &) = delete;
};
// Wraps an existing CacheStore implementation so it can be used with
// ExpanderFst.
template <class CacheStore>
class ExpanderCacheStore {
 public:
  using State = typename CacheStore::State;
  using Arc = typename CacheStore::Arc;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  explicit ExpanderCacheStore(const CacheOptions &opts = CacheOptions())
      : store_(opts) {}

  // Returns the cached state for s.
  // - If any flag is already set on it, the state is tagged kCacheRecent and
  //   returned as is.
  // - Otherwise `expander` fills it through a StateBuilder. All kCacheFlags are
  //   then set and the store is notified via SetArcs() before returning.
  template <class Expander>
  State *FindOrExpand(Expander &expander, StateId s) {
    auto *state = store_.GetMutableState(s);
    if (state->Flags()) {
      state->SetFlags(kCacheRecent, kCacheRecent);
      return state;
    }
    StateBuilder builder(state);
    expander.Expand(s, &builder);
    state->SetFlags(kCacheFlags, kCacheFlags);
    store_.SetArcs(state);
    return state;
  }

 private:
  // Sink handed to Expander::Expand(). It appends arcs to, and sets the final
  // weight of, the wrapped state.
  struct StateBuilder {
    State *state;

    explicit StateBuilder(State *target) : state(target) {}

    void AddArc(const Arc &arc) { state->PushArc(arc); }

    void AddArc(Arc &&arc) { state->PushArc(std::move(arc)); }

    void SetFinal(Weight weight = Weight::One()) {
      state->SetFinal(std::move(weight));
    }
  };

  CacheStore store_;
};
} // namespace fst
#endif // FST_CACHE_H_