// Copyright 2005-2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the 'License'); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an 'AS IS' BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Cache implementations for ExpanderFst. // // Expander caches must expose a State type and a FindOrExpand template method: // // class ExpanderCache { // public: // class State; // // template <class Expander> // State* FindOrExpand(Expander& expander, StateId id) { // if (id is found in cache) return cached_state; // // // Use the provided expander to create a new cached state and cache it. // expander.Expand(id, &new_state); // insert new_state into cache; // return new_state; // } // }; // // Cache implementations must be copyable and assignable. It is up to the // implementation whether this means it will discard the contents of the cache, // copy all of the cache, share some of the cache, etc. It is *REQUIRED* that the // copy be "safe": the copy and the original must be usable from concurrent // threads without accessing any internally shared state. #ifndef FST_EXPANDER_CACHE_H_ #define FST_EXPANDER_CACHE_H_ #include #include #include #include #include #include #include #include #include namespace fst { // Stateful allocators can't be used without careful handling in threaded // contexts, so arbitrary stl allocators aren't supported here. 
// Cached state that keeps its arcs in a std::vector, together with the final
// weight and running counts of input/output epsilon arcs.
template <class A>
class SimpleVectorCacheState {
 public:
  using Arc = A;
  using Weight = typename Arc::Weight;
  using StateId = typename Arc::StateId;

  // Restores the default-constructed contents: final weight Zero(), no arcs,
  // zero epsilon counts. The arc vector is cleared, not reallocated.
  void Reset() {
    final_weight_ = Weight::Zero();
    niepsilons_ = 0;
    noepsilons_ = 0;
    arcs_.clear();
  }

  Weight Final() const { return final_weight_; }
  size_t NumInputEpsilons() const { return niepsilons_; }
  size_t NumOutputEpsilons() const { return noepsilons_; }
  size_t NumArcs() const { return arcs_.size(); }

  // Returns the n-th arc. Precondition: n < NumArcs() (not checked).
  const Arc &GetArc(size_t n) const { return arcs_[n]; }

  // Returns a pointer to the first arc, or nullptr when there are no arcs.
  const Arc *Arcs() const { return arcs_.empty() ? nullptr : arcs_.data(); }

  void SetFinal(Weight weight) { final_weight_ = std::move(weight); }

  void ReserveArcs(size_t n) { arcs_.reserve(n); }

  // Appends a copy of `arc` and updates the epsilon counts.
  void AddArc(const Arc &arc) {
    CountEpsilons(arc);
    arcs_.push_back(arc);
  }

  // Appends `arc` by move and updates the epsilon counts.
  void AddArc(Arc &&arc) {
    CountEpsilons(arc);  // Must read the labels before the arc is moved from.
    arcs_.push_back(std::move(arc));
  }

  // This state type carries no reference count.
  int *MutableRefCount() const { return nullptr; }

 private:
  // Label 0 denotes epsilon.
  void CountEpsilons(const Arc &arc) {
    if (arc.ilabel == 0) ++niepsilons_;
    if (arc.olabel == 0) ++noepsilons_;
  }

  Weight final_weight_ = Weight::Zero();
  size_t niepsilons_ = 0;  // Number of input epsilons.
  size_t noepsilons_ = 0;  // Number of output epsilons.
  std::vector<Arc> arcs_;
};

// Cache that keeps exactly one "current" state and reuses its storage for the
// next requested state, unless the current state's reference count is
// positive. Referenced states are parked in a hash map so that their address
// stays valid; parked states are never garbage collected, only handed back
// when their id is requested again.
template <class A>
class NoGcKeepOneExpanderCache {
 public:
  using Arc = A;
  using StateId = typename Arc::StateId;

  // Reference-counted state.
  class State : public SimpleVectorCacheState<Arc> {
   public:
    int *MutableRefCount() { return &ref_count_; }

    void Reset() {
      SimpleVectorCacheState<Arc>::Reset();
      ref_count_ = 0;
    }

   private:
    int ref_count_ = 0;

    friend class NoGcKeepOneExpanderCache<Arc>;
  };

  NoGcKeepOneExpanderCache() : state_(std::make_unique<State>()) {}

  // Deep-copies the current state together with its id so that the copied
  // contents remain addressable; parked states are not copied. The copy's
  // reference count starts at zero because nothing refers to the new object.
  NoGcKeepOneExpanderCache(const NoGcKeepOneExpanderCache &copy)
      : state_id_(copy.state_id_),
        state_(std::make_unique<State>(*copy.state_)) {
    state_->ref_count_ = 0;
  }

  // Same semantics as the copy constructor. All states previously handed out
  // by this cache (current and parked) are destroyed.
  NoGcKeepOneExpanderCache &operator=(const NoGcKeepOneExpanderCache &copy) {
    if (this != &copy) {
      cache_.clear();
      state_id_ = copy.state_id_;
      state_ = std::make_unique<State>(*copy.state_);
      state_->ref_count_ = 0;
    }
    return *this;
  }

  // Returns the state for `state_id`, expanding it with `expander` unless it
  // is the current state or a parked one. The returned pointer is only
  // guaranteed to stay valid across later calls while the state's reference
  // count is positive.
  template <class Expander>
  State *FindOrExpand(Expander &expander, StateId state_id) {
    if (state_id == state_id_) return state_.get();

    // A state that is still referenced must survive at the same address, so
    // move it into the map instead of recycling it.
    if (state_->ref_count_ > 0) cache_[state_id_] = std::move(state_);
    state_id_ = state_id;

    // Hand back a parked state, removing its map entry so that the map only
    // ever holds live states.
    if (auto it = cache_.find(state_id); it != cache_.end()) {
      std::unique_ptr<State> parked = std::move(it->second);
      cache_.erase(it);
      if (parked != nullptr) {
        state_ = std::move(parked);
        return state_.get();
      }
    }

    // Not cached: recycle the unreferenced current state if we still own it,
    // otherwise allocate. Either way the state must be (re-)expanded.
    if (state_ == nullptr) {
      state_ = std::make_unique<State>();
    } else {
      state_->Reset();
    }
    expander.Expand(state_id_, state_.get());
    return state_.get();
  }

  StateId state_id_ = kNoStateId;  // Id of the state held in state_.
  std::unique_ptr<State> state_;   // The current state.

  // Parked states, i.e. those that were referenced when they stopped being
  // the current state.
  std::unordered_map<StateId, std::unique_ptr<State>> cache_;
};

// Cache that stores every expanded state in a hash map keyed by state id.
// States are never evicted.
template <class A>
class HashExpanderCache {
 public:
  using Arc = A;
  using StateId = typename Arc::StateId;
  using State = SimpleVectorCacheState<Arc>;

  HashExpanderCache() = default;

  HashExpanderCache(const HashExpanderCache &copy) { *this = copy; }

  // Replaces the contents with deep copies of the states in `copy`. Pointers
  // previously returned by this cache are invalidated.
  HashExpanderCache &operator=(const HashExpanderCache &copy) {
    if (this == &copy) return *this;
    cache_.clear();
    cache_.reserve(copy.cache_.size());
    for (const auto &[id, state] : copy.cache_) {
      if (state != nullptr) {
        cache_.emplace(id, std::make_unique<State>(*state));
      }
    }
    return *this;
  }

  ~HashExpanderCache() = default;

  // Returns the cached state for `state_id`, expanding it on first request.
  template <class Expander>
  State *FindOrExpand(Expander &expander, StateId state_id) {
    auto [it, inserted] = cache_.emplace(state_id, nullptr);
    if (!inserted) return it->second.get();
    it->second = std::make_unique<State>();
    // Use the raw pointer from here on rather than the iterator: the pointee
    // does not move even if the map is rehashed during Expand().
    State *state = it->second.get();
    expander.Expand(state_id, state);
    return state;
  }

 private:
  std::unordered_map<StateId, std::unique_ptr<State>> cache_;
};

// Cache that indexes states by id in a vector of pointers; the states
// themselves live in a deque so that their addresses stay stable as the cache
// grows. States are never evicted.
template <class A>
class VectorExpanderCache {
 public:
  using Arc = A;
  using StateId = typename Arc::StateId;
  using State = SimpleVectorCacheState<Arc>;

  VectorExpanderCache() = default;

  VectorExpanderCache(const VectorExpanderCache &copy) { *this = copy; }

  // Replaces the contents with deep copies of the states in `copy`. Pointers
  // previously returned by this cache are invalidated.
  VectorExpanderCache &operator=(const VectorExpanderCache &copy) {
    if (this == &copy) return *this;
    states_.clear();
    vec_.assign(copy.vec_.size(), nullptr);
    for (size_t i = 0; i < copy.vec_.size(); ++i) {
      if (const State *state = copy.vec_[i]; state != nullptr) {
        vec_[i] = &states_.emplace_back(*state);
      }
    }
    return *this;
  }

  // Returns the cached state for `state_id`, expanding it on first request.
  // Negative ids (e.g. kNoStateId) cannot be indexed and yield nullptr.
  template <class Expander>
  State *FindOrExpand(Expander &expander, StateId state_id) {
    if (state_id < 0) return nullptr;
    const auto index = static_cast<size_t>(state_id);
    if (index >= vec_.size()) vec_.resize(index + 1);
    State *state = vec_[index];
    if (state == nullptr) {
      // Growing a deque at the end never moves existing elements, so pointers
      // handed out earlier remain valid.
      state = &states_.emplace_back();
      vec_[index] = state;
      expander.Expand(state_id, state);
    }
    return state;
  }

 private:
  std::deque<State> states_;  // Owns the states.
  std::vector<State *> vec_;  // vec_[id] points into states_, or is nullptr.
};

template <class A>
using DefaultExpanderCache = VectorExpanderCache<A>;

}  // namespace fst

#endif  // FST_EXPANDER_CACHE_H_