// Copyright 2005-2024 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the 'License');
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an 'AS IS' BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
// See www.openfst.org for extensive documentation on this weighted
|
|
// finite-state transducer library.
|
|
//
|
|
// Cache implementations for ExpanderFst.
|
|
//
|
|
// Expander caches must expose a State type and a FindOrExpand template method:
//
//   class ExpanderCache {
//    public:
//     class State;
//
//     template <class Expander>
//     State *FindOrExpand(Expander &expander, StateId id) {
//       if (id is found in cache) return cached_state;
//
//       // Use the provided expander to create a new cached state and cache it.
//       expander.Expand(id, &new_state);
//       insert new_state into cache;
//       return new_state;
//     }
//   };
|
|
//
|
|
// Cache implementations must be copyable and assignable. It is up to the
// implementation whether this means it will discard the contents of the cache,
// copy all of the cache, share some of the cache, etc. It is *REQUIRED* that
// the copy be "safe": the copy and the original must be usable from concurrent
// threads without accessing any internally shared state.
|
|
|
|
#ifndef FST_EXPANDER_CACHE_H_
|
|
#define FST_EXPANDER_CACHE_H_
|
|
|
|
#include <cstddef>
|
|
#include <deque>
|
|
#include <memory>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include <fst/cache.h>
|
|
#include <fst/fst.h>
|
|
#include <unordered_map>
|
|
#include <unordered_map>
|
|
|
|
namespace fst {
|
|
|
|
// Stateful allocators can't be used without careful handling in threaded
// contexts, so arbitrary STL allocators aren't supported here.
|
|
// Cached representation of a single state: its final weight, its outgoing
// arcs (stored contiguously in a std::vector), and running counts of the arcs
// whose input or output label is 0 (epsilon).
//
// Template parameter A is the arc type. It must provide the nested types
// Weight (with a static Zero()) and StateId, and the members ilabel and
// olabel.
template <class A>
class SimpleVectorCacheState {
 public:
  using Arc = A;
  using Weight = typename Arc::Weight;
  using StateId = typename Arc::StateId;

  // Returns the state to its freshly-constructed contents: final weight
  // Weight::Zero(), no arcs, and both epsilon counters at zero.
  void Reset() {
    final_weight_ = Weight::Zero();
    niepsilons_ = 0;
    noepsilons_ = 0;
    arcs_.clear();
  }

  // Returns a copy of the final weight.
  Weight Final() const { return final_weight_; }

  // Number of stored arcs whose ilabel is 0.
  size_t NumInputEpsilons() const { return niepsilons_; }

  // Number of stored arcs whose olabel is 0.
  size_t NumOutputEpsilons() const { return noepsilons_; }

  // Number of stored arcs.
  size_t NumArcs() const { return arcs_.size(); }

  // Returns the n-th arc in insertion order. No bounds check is performed.
  const Arc &GetArc(size_t n) const { return arcs_[n]; }

  // Returns a pointer to the first arc, or nullptr when there are no arcs
  // (even if storage has been reserved).
  const Arc *Arcs() const { return arcs_.empty() ? nullptr : &arcs_[0]; }

  // Replaces the final weight. The parameter is taken by value, so it is
  // moved into place rather than copied a second time.
  void SetFinal(Weight weight) { final_weight_ = std::move(weight); }

  // Reserves storage for at least n arcs without changing NumArcs().
  void ReserveArcs(size_t n) { arcs_.reserve(n); }

  // Appends a copy of arc and updates the epsilon counters.
  void AddArc(const Arc &arc) {
    if (arc.ilabel == 0) ++niepsilons_;
    if (arc.olabel == 0) ++noepsilons_;
    arcs_.push_back(arc);
  }

  // Appends arc by moving from it and updates the epsilon counters. The
  // labels are inspected before the move.
  void AddArc(Arc &&arc) {
    if (arc.ilabel == 0) ++niepsilons_;
    if (arc.olabel == 0) ++noepsilons_;
    arcs_.push_back(std::move(arc));
  }

  // This state type carries no reference count, so there is nothing to point
  // at; always returns nullptr.
  int *MutableRefCount() const { return nullptr; }

 private:
  Weight final_weight_ = Weight::Zero();
  size_t niepsilons_ = 0;  // Number of input epsilons.
  size_t noepsilons_ = 0;  // Number of output epsilons.
  std::vector<Arc> arcs_;
};
|
|
|
|
template <class A>
|
|
class NoGcKeepOneExpanderCache {
|
|
public:
|
|
using Arc = A;
|
|
using StateId = typename Arc::StateId;
|
|
|
|
// Reference-counted state.
|
|
class State : public SimpleVectorCacheState<Arc> {
|
|
public:
|
|
int *MutableRefCount() { return &ref_count_; }
|
|
|
|
void Reset() {
|
|
SimpleVectorCacheState<Arc>::Reset();
|
|
ref_count_ = 0;
|
|
}
|
|
|
|
private:
|
|
int ref_count_ = 0;
|
|
|
|
friend class NoGcKeepOneExpanderCache;
|
|
};
|
|
|
|
NoGcKeepOneExpanderCache() : state_(new State) {}
|
|
|
|
NoGcKeepOneExpanderCache(const NoGcKeepOneExpanderCache ©)
|
|
: state_(new State(*copy.state_)) {}
|
|
|
|
template <class Expander>
|
|
State *FindOrExpand(Expander &expander, StateId state_id) {
|
|
if (state_id == state_id_) return state_.get();
|
|
if (state_->ref_count_ > 0) cache_[state_id_] = std::move(state_);
|
|
state_id_ = state_id;
|
|
if (cache_.empty()) {
|
|
state_->Reset();
|
|
expander.Expand(state_id_, state_.get());
|
|
return state_.get();
|
|
}
|
|
if (auto i = cache_.find(state_id_); i != cache_.end()) {
|
|
state_ = std::move(i->second);
|
|
}
|
|
if (state_ == nullptr) {
|
|
state_ = std::make_unique<State>();
|
|
expander.Expand(state_id_, state_.get());
|
|
}
|
|
return state_.get();
|
|
}
|
|
|
|
StateId state_id_ = kNoStateId;
|
|
std::unique_ptr<State> state_;
|
|
std::unordered_map<StateId, std::unique_ptr<State>> cache_;
|
|
};
|
|
|
|
template <class A>
|
|
class HashExpanderCache {
|
|
public:
|
|
using Arc = A;
|
|
using StateId = typename Arc::StateId;
|
|
|
|
using State = SimpleVectorCacheState<Arc>;
|
|
|
|
HashExpanderCache(const HashExpanderCache ©) { *this = copy; }
|
|
|
|
HashExpanderCache &operator=(const HashExpanderCache ©) {
|
|
for (const auto &[id, state] : copy.cache_) {
|
|
cache_[id] = std::make_unique<State>(*state);
|
|
}
|
|
return *this;
|
|
}
|
|
|
|
~HashExpanderCache() = default;
|
|
|
|
template <class Expander>
|
|
State *FindOrExpand(Expander &expander, StateId state_id) {
|
|
auto [it, inserted] = cache_.emplace(state_id, nullptr);
|
|
if (inserted) {
|
|
it->second = std::make_unique<State>();
|
|
expander.Expand(state_id, it->second.get());
|
|
}
|
|
return it->second.get();
|
|
}
|
|
|
|
private:
|
|
std::unordered_map<StateId, std::unique_ptr<State>> cache_;
|
|
};
|
|
|
|
template <class A>
|
|
class VectorExpanderCache {
|
|
public:
|
|
using Arc = A;
|
|
using StateId = typename Arc::StateId;
|
|
|
|
using State = SimpleVectorCacheState<Arc>;
|
|
|
|
VectorExpanderCache() : vec_(0, nullptr) {}
|
|
|
|
VectorExpanderCache(const VectorExpanderCache ©) { *this = copy; }
|
|
|
|
VectorExpanderCache &operator=(const VectorExpanderCache ©) {
|
|
vec_.resize(copy.vec_.size());
|
|
for (StateId i = 0; i < copy.vec_.size(); ++i) {
|
|
const auto *state = copy.vec_[i];
|
|
if (state != nullptr) {
|
|
states_.emplace_back(*state);
|
|
vec_[i] = &states_.back();
|
|
}
|
|
}
|
|
return *this;
|
|
}
|
|
|
|
template <class Expander>
|
|
State *FindOrExpand(Expander &expander, StateId state_id) {
|
|
if (state_id >= vec_.size()) vec_.resize(state_id + 1);
|
|
auto **slot = &vec_[state_id];
|
|
if (*slot == nullptr) {
|
|
states_.emplace_back();
|
|
*slot = &states_.back();
|
|
expander.Expand(state_id, *slot);
|
|
}
|
|
return *slot;
|
|
}
|
|
|
|
private:
|
|
std::deque<State> states_;
|
|
std::vector<State *> vec_;
|
|
};
|
|
|
|
template <class Expander>
|
|
using DefaultExpanderCache = VectorExpanderCache<typename Expander::Arc>;
|
|
|
|
} // namespace fst
|
|
#endif // FST_EXPANDER_CACHE_H_
|