|
|
// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Classes to accumulate arc weights. Useful for weight lookahead.
#ifndef FST_ACCUMULATOR_H_
#define FST_ACCUMULATOR_H_
#include <sys/types.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include <fst/log.h>
#include <fst/arcfilter.h>
#include <fst/arcsort.h>
#include <fst/dfs-visit.h>
#include <fst/expanded-fst.h>
#include <fst/float-weight.h>
#include <fst/fst.h>
#include <fst/replace.h>
#include <fst/util.h>
#include <fst/weight.h>
namespace fst {
// This class accumulates arc weights using the semiring Plus().
// Sum(w, aiter, begin, end) has time complexity O(end - begin).
template <class A> class DefaultAccumulator { public: using Arc = A; using StateId = typename Arc::StateId; using Weight = typename Arc::Weight;
DefaultAccumulator() = default;
DefaultAccumulator(const DefaultAccumulator &acc, bool safe = false) {}
void Init(const Fst<Arc> &fst, bool copy = false) {}
void SetState(StateId state) {}
Weight Sum(Weight w, Weight v) { return Plus(w, v); }
template <class ArcIter> Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) { Adder<Weight> adder(w); // maintains cumulative sum accurately
aiter->Seek(begin); for (auto pos = begin; pos < end; aiter->Next(), ++pos) adder.Add(aiter->Value().weight); return adder.Sum(); }
constexpr bool Error() const { return false; }
private: DefaultAccumulator &operator=(const DefaultAccumulator &) = delete; };
// This class accumulates arc weights using the log semiring Plus() assuming an
// arc weight has a WeightConvert specialization to and from log64 weights.
// Sum(w, aiter, begin, end) has time complexity O(end - begin).
template <class A>
class LogAccumulator {
 public:
  using Arc = A;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  LogAccumulator() = default;

  // Copy constructor. Nothing is copied from acc; the converters are
  // default-constructed and both arguments are ignored.
  LogAccumulator(const LogAccumulator &acc, bool safe = false) {}

  // No-op: nothing is pre-computed from the FST.
  void Init(const Fst<Arc> &fst, bool copy = false) {}

  // No-op: no per-state information is kept.
  void SetState(StateId s) {}

  // Returns the log-semiring sum of w and v.
  Weight Sum(Weight w, Weight v) { return LogPlus(w, v); }

  // Returns the log-semiring sum of w and the weights of the arcs at
  // positions [begin, end) of aiter. The iterator is repositioned by the call.
  template <class ArcIter>
  Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) {
    auto total = w;
    aiter->Seek(begin);
    auto pos = begin;
    while (pos < end) {
      total = LogPlus(total, aiter->Value().weight);
      aiter->Next();
      ++pos;
    }
    return total;
  }

  // This accumulator never enters an error state.
  constexpr bool Error() const { return false; }

 private:
  // Log-semiring addition carried out on the Log64Weight images of w and v.
  // Returns v unchanged when w is Zero().
  Weight LogPlus(Weight w, Weight v) {
    if (w == Weight::Zero()) return v;
    const auto lhs = to_log_weight_(w).Value();
    const auto rhs = to_log_weight_(v).Value();
    // The correction term is always subtracted from the smaller value, so the
    // argument passed to LogPosExp is never negative.
    const auto result = lhs > rhs ? rhs - internal::LogPosExp(lhs - rhs)
                                  : lhs - internal::LogPosExp(rhs - lhs);
    return to_weight_(Log64Weight(result));
  }

  const WeightConvert<Weight, Log64Weight> to_log_weight_{};
  const WeightConvert<Log64Weight, Weight> to_weight_{};

  LogAccumulator &operator=(const LogAccumulator &) = delete;
};
// Interface for shareable data for fast log accumulator copies. Holds pointers
// to data only, storage is provided by derived classes.
class FastLogAccumulatorData {
 public:
  // Records the configuration; the data pointers start out null and the
  // counts start out at zero until a derived class calls Init().
  FastLogAccumulatorData(int arc_limit, int arc_period)
      : arc_limit_(arc_limit), arc_period_(arc_period) {}

  virtual ~FastLogAccumulatorData() = default;

  // Not copyable.
  FastLogAccumulatorData(const FastLogAccumulatorData &) = delete;
  FastLogAccumulatorData &operator=(const FastLogAccumulatorData &) = delete;

  // Cumulative weights, with arcs in order, for every state that has
  // pre-computed weights. The first element per state is
  // Log64Weight::Zero().
  const double *Weights() const { return weights_ptr_; }

  // Number of entries reachable through Weights().
  int NumWeights() const { return num_weights_; }

  // Maps from state to the corresponding beginning weight position in
  // Weights(). Position -1 means no pre-computed weights for that state.
  const int *WeightPositions() const { return weight_positions_ptr_; }

  // Number of entries reachable through WeightPositions().
  int NumPositions() const { return num_positions_; }

  // Configuration values given at construction.
  int ArcLimit() const { return arc_limit_; }
  int ArcPeriod() const { return arc_period_; }

  // Returns true if the data object is mutable and supports SetData().
  virtual bool IsMutable() const = 0;

  // Does not take ownership but may invalidate the contents of weights and
  // weight_positions.
  virtual void SetData(std::vector<double> *weights,
                       std::vector<int> *weight_positions) = 0;

 protected:
  // Points this object at storage provided by the derived class. Only the
  // pointers and element counts are recorded; nothing is copied.
  void Init(int num_weights, const double *weights, int num_positions,
            const int *weight_positions) {
    weights_ptr_ = weights;
    num_weights_ = num_weights;
    weight_positions_ptr_ = weight_positions;
    num_positions_ = num_positions;
  }

 private:
  const int arc_limit_;
  const int arc_period_;
  const double *weights_ptr_ = nullptr;
  int num_weights_ = 0;
  const int *weight_positions_ptr_ = nullptr;
  int num_positions_ = 0;
};
// FastLogAccumulatorData with mutable storage; filled by
// FastLogAccumulator::Init.
class MutableFastLogAccumulatorData : public FastLogAccumulatorData {
 public:
  MutableFastLogAccumulatorData(int arc_limit, int arc_period)
      : FastLogAccumulatorData(arc_limit, arc_period) {}

  // This implementation owns its storage, so it can be (re)filled.
  bool IsMutable() const override { return true; }

  // Takes over the contents of both vectors by swapping them with the
  // internal storage (the arguments receive the previous contents), then
  // publishes the new storage through the base-class pointers.
  void SetData(std::vector<double> *weights,
               std::vector<int> *weight_positions) override {
    weights->swap(weights_);
    weight_positions->swap(weight_positions_);
    Init(static_cast<int>(weights_.size()), weights_.data(),
         static_cast<int>(weight_positions_.size()), weight_positions_.data());
  }

 private:
  std::vector<double> weights_;
  std::vector<int> weight_positions_;

  MutableFastLogAccumulatorData(const MutableFastLogAccumulatorData &) =
      delete;
  MutableFastLogAccumulatorData &operator=(
      const MutableFastLogAccumulatorData &) = delete;
};
// This class accumulates arc weights using the log semiring Plus() assuming an
// arc weight has a WeightConvert specialization to and from log64 weights. The
// member function Init(fst) has to be called to setup pre-computed weight
// information.
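// Sum(w, aiter, begin, end) has time complexity O(arc_limit_) for states with
// fewer than arc_limit_ arcs (their arcs are summed directly) and
// O(arc_period_) for states with at least arc_limit_ arcs (the pre-computed
// cumulative weights are used).
// Space complexity is O(CountStates(fst) + CountArcs(fst) / arc_period_).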
template <class A> class FastLogAccumulator { public: using Arc = A; using StateId = typename Arc::StateId; using Weight = typename Arc::Weight;
explicit FastLogAccumulator(ssize_t arc_limit = 20, ssize_t arc_period = 10) : to_log_weight_(), to_weight_(), arc_limit_(arc_limit), arc_period_(arc_period), data_(std::make_shared<MutableFastLogAccumulatorData>(arc_limit, arc_period)), state_weights_(nullptr), error_(false) {}
explicit FastLogAccumulator(std::shared_ptr<FastLogAccumulatorData> data) : to_log_weight_(), to_weight_(), arc_limit_(data->ArcLimit()), arc_period_(data->ArcPeriod()), data_(data), state_weights_(nullptr), error_(false) {}
FastLogAccumulator(const FastLogAccumulator &acc, bool safe = false) : to_log_weight_(), to_weight_(), arc_limit_(acc.arc_limit_), arc_period_(acc.arc_period_), data_(acc.data_), state_weights_(nullptr), error_(acc.error_) {}
void SetState(StateId s) { const auto *weights = data_->Weights(); const auto *weight_positions = data_->WeightPositions(); state_weights_ = nullptr; if (s < data_->NumPositions()) { const auto pos = weight_positions[s]; if (pos >= 0) state_weights_ = &(weights[pos]); } }
Weight Sum(Weight w, Weight v) const { return LogPlus(w, v); }
template <class ArcIter> Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) const { if (error_) return Weight::NoWeight(); auto sum = w; // Finds begin and end of pre-stored weights.
ssize_t index_begin = -1; ssize_t index_end = -1; ssize_t stored_begin = end; ssize_t stored_end = end; if (state_weights_) { index_begin = begin > 0 ? (begin - 1) / arc_period_ + 1 : 0; index_end = end / arc_period_; stored_begin = index_begin * arc_period_; stored_end = index_end * arc_period_; } // Computes sum before pre-stored weights.
if (begin < stored_begin) { const auto pos_end = std::min(stored_begin, end); aiter->Seek(begin); for (auto pos = begin; pos < pos_end; aiter->Next(), ++pos) { sum = LogPlus(sum, aiter->Value().weight); } } // Computes sum between pre-stored weights.
if (stored_begin < stored_end) { const auto f1 = state_weights_[index_end]; const auto f2 = state_weights_[index_begin]; if (f1 < f2) sum = LogPlus(sum, LogMinus(f1, f2)); // Commented out for efficiency; adds Zero().
/*
else { // explicitly computes if cumulative sum lacks precision
aiter->Seek(stored_begin); for (auto pos = stored_begin; pos < stored_end; aiter->Next(), ++pos) sum = LogPlus(sum, aiter->Value().weight); } */ } // Computes sum after pre-stored weights.
if (stored_end < end) { const auto pos_start = std::max(stored_begin, stored_end); aiter->Seek(pos_start); for (auto pos = pos_start; pos < end; aiter->Next(), ++pos) { sum = LogPlus(sum, aiter->Value().weight); } } return sum; }
template <class FST> void Init(const FST &fst, bool copy = false) { if (copy || !data_->IsMutable()) return; if (data_->NumPositions() != 0 || arc_limit_ < arc_period_) { FSTERROR() << "FastLogAccumulator: Initialization error"; error_ = true; return; } std::vector<double> weights; std::vector<int> weight_positions; weight_positions.reserve(CountStates(fst)); for (StateIterator<FST> siter(fst); !siter.Done(); siter.Next()) { const auto s = siter.Value(); if (fst.NumArcs(s) >= arc_limit_) { auto sum = FloatLimits<double>::PosInfinity(); if (weight_positions.size() <= s) weight_positions.resize(s + 1, -1); weight_positions[s] = weights.size(); weights.push_back(sum); size_t narcs = 0; ArcIterator<FST> aiter(fst, s); aiter.SetFlags(kArcWeightValue | kArcNoCache, kArcFlags); for (; !aiter.Done(); aiter.Next()) { const auto &arc = aiter.Value(); sum = LogPlus(sum, arc.weight); // Stores cumulative weight distribution per arc_period_.
if (++narcs % arc_period_ == 0) weights.push_back(sum); } } } data_->SetData(&weights, &weight_positions); }
bool Error() const { return error_; }
std::shared_ptr<FastLogAccumulatorData> GetData() const { return data_; }
private: static double LogPosExp(double x) { return x == FloatLimits<double>::PosInfinity() ? 0.0 : log(1.0F + exp(-x)); }
static double LogMinusExp(double x) { return x == FloatLimits<double>::PosInfinity() ? 0.0 : log(1.0F - exp(-x)); }
Weight LogPlus(Weight w, Weight v) const { if (w == Weight::Zero()) { return v; } const auto f1 = to_log_weight_(w).Value(); const auto f2 = to_log_weight_(v).Value(); if (f1 > f2) { return to_weight_(Log64Weight(f2 - LogPosExp(f1 - f2))); } else { return to_weight_(Log64Weight(f1 - LogPosExp(f2 - f1))); } }
double LogPlus(double f1, Weight v) const { const auto f2 = to_log_weight_(v).Value(); if (f1 == FloatLimits<double>::PosInfinity()) { return f2; } else if (f1 > f2) { return f2 - LogPosExp(f1 - f2); } else { return f1 - LogPosExp(f2 - f1); } }
// Assumes f1 < f2.
Weight LogMinus(double f1, double f2) const { if (f2 == FloatLimits<double>::PosInfinity()) { return to_weight_(Log64Weight(f1)); } else { return to_weight_(Log64Weight(f1 - LogMinusExp(f2 - f1))); } }
const WeightConvert<Weight, Log64Weight> to_log_weight_{}; const WeightConvert<Log64Weight, Weight> to_weight_{}; const ssize_t arc_limit_; // Minimum number of arcs to pre-compute state.
const ssize_t arc_period_; // Saves cumulative weights per arc_period_.
std::shared_ptr<FastLogAccumulatorData> data_; const double *state_weights_; bool error_;
FastLogAccumulator &operator=(const FastLogAccumulator &) = delete; };
// Stores shareable data for cache log accumulator copies. All copies share the
// same cache.
template <class Arc>
class CacheLogAccumulatorData {
 public:
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  // gc enables garbage collection; gc_limit is the byte budget that triggers
  // it.
  CacheLogAccumulatorData(bool gc, size_t gc_limit)
      : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {}

  // Copies only the garbage-collection settings; the new object starts with
  // an empty cache.
  CacheLogAccumulatorData(const CacheLogAccumulatorData<Arc> &data)
      : cache_gc_(data.cache_gc_),
        cache_limit_(data.cache_limit_),
        cache_size_(0) {}

  // Garbage collection with a zero byte budget means nothing may be cached.
  bool CacheDisabled() const { return cache_gc_ && cache_limit_ == 0; }

  // Returns the cached weights for state s and marks the entry as recently
  // used, or returns nullptr if s is not cached. The cache keeps ownership.
  std::vector<double> *GetWeights(StateId s) {
    const auto found = cache_.find(s);
    if (found == cache_.end()) return nullptr;
    found->second.recent = true;
    return found->second.weights.get();
  }

  // Stores weights for state s, marked as recently used. When garbage
  // collection is enabled, a collection runs first if the budget is already
  // reached, and the vector's capacity is charged against the budget.
  void AddWeights(StateId s, std::unique_ptr<std::vector<double>> weights) {
    if (cache_gc_) {
      if (cache_size_ >= cache_limit_) GC(false);
      cache_size_ += weights->capacity() * sizeof(double);
    }
    cache_.emplace(s, CacheState(std::move(weights), true));
  }

 private:
  // Cached information for a given state.
  struct CacheState {
    std::unique_ptr<std::vector<double>> weights;  // Accumulated weights.
    bool recent;  // Has this state been accessed since last GC?

    CacheState(std::unique_ptr<std::vector<double>> weights, bool recent)
        : weights(std::move(weights)), recent(recent) {}
  };

  // Garbage collect: Deletes from cache states that have not been accessed
  // since the last GC ('free_recent = false') until 'cache_size_' is 2/3 of
  // 'cache_limit_'. If it does not free enough memory, start deleting
  // recently accessed states.
  void GC(bool free_recent) {
    const auto cache_target = (2 * cache_limit_) / 3 + 1;
    for (auto it = cache_.begin();
         it != cache_.end() && cache_size_ > cache_target;) {
      auto &entry = it->second;
      if (!free_recent && entry.recent) {
        // Spared this round; becomes a candidate for the next one.
        entry.recent = false;
        ++it;
        continue;
      }
      cache_size_ -= entry.weights->capacity() * sizeof(double);
      it = cache_.erase(it);
    }
    if (!free_recent && cache_size_ > cache_target) GC(true);
  }

  std::unordered_map<StateId, CacheState> cache_;  // Cache.
  bool cache_gc_;       // Enables garbage collection.
  size_t cache_limit_;  // # of bytes allowed before GC.
  size_t cache_size_;   // # of bytes cached.

  CacheLogAccumulatorData &operator=(const CacheLogAccumulatorData &) = delete;
};
// This class accumulates arc weights using the log semiring Plus() assuming an
// arc weight has a WeightConvert specialization to and from log64 weights. It
// is similar to the FastLogAccumulator. However here, the accumulated weights
// are pre-computed and stored only for the states that are visited. The member
// function Init(fst) has to be called to setup this accumulator. Space
// complexity is O(gc_limit).
template <class Arc> class CacheLogAccumulator { public: using StateId = typename Arc::StateId; using Weight = typename Arc::Weight;
explicit CacheLogAccumulator(ssize_t arc_limit = 10, bool gc = false, size_t gc_limit = 10 * 1024 * 1024) : arc_limit_(arc_limit), data_(std::make_shared<CacheLogAccumulatorData<Arc>>(gc, gc_limit)), s_(kNoStateId), error_(false) {}
CacheLogAccumulator(const CacheLogAccumulator &acc, bool safe = false) : arc_limit_(acc.arc_limit_), fst_(acc.fst_ ? acc.fst_->Copy() : nullptr), data_(safe ? std::make_shared<CacheLogAccumulatorData<Arc>>(*acc.data_) : acc.data_), s_(kNoStateId), error_(acc.error_) {}
// Argument arc_limit specifies the minimum number of arcs to pre-compute.
void Init(const Fst<Arc> &fst, bool copy = false) { if (!copy && fst_) { FSTERROR() << "CacheLogAccumulator: Initialization error"; error_ = true; return; } fst_.reset(fst.Copy()); }
void SetState(StateId s, int depth = 0) { if (s == s_) return; s_ = s; if (data_->CacheDisabled() || error_) { weights_ = nullptr; return; } if (!fst_) { FSTERROR() << "CacheLogAccumulator::SetState: Incorrectly initialized"; error_ = true; weights_ = nullptr; return; } weights_ = data_->GetWeights(s); if ((weights_ == nullptr) && (fst_->NumArcs(s) >= arc_limit_)) { auto weights = std::make_unique<std::vector<double>>(); weights->reserve(fst_->NumArcs(s) + 1); weights->push_back(FloatLimits<double>::PosInfinity()); // `weights` holds a reference to the weight vector, whose ownership is
// transferred to `data_`.
weights_ = weights.get(); data_->AddWeights(s, std::move(weights)); } }
Weight Sum(Weight w, Weight v) { return LogPlus(w, v); }
template <class ArcIter> Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) { if (weights_ == nullptr) { auto sum = w; aiter->Seek(begin); for (auto pos = begin; pos < end; aiter->Next(), ++pos) { sum = LogPlus(sum, aiter->Value().weight); } return sum; } else { Extend(end, aiter); const auto &f1 = (*weights_)[end]; const auto &f2 = (*weights_)[begin]; if (f1 < f2) { return LogPlus(w, LogMinus(f1, f2)); } else { // Commented out for efficiency; adds Zero().
/*
auto sum = w; // Explicitly computes if cumulative sum lacks precision.
aiter->Seek(begin); for (auto pos = begin; pos < end; aiter->Next(), ++pos) { sum = LogPlus(sum, aiter->Value().weight); } return sum; */ return w; } } }
// Returns first position from aiter->Position() whose accumulated
// value is greater or equal to w (w.r.t. Zero() < One()). The
// iterator may be repositioned.
template <class ArcIter> size_t LowerBound(Weight w, ArcIter *aiter) { const auto f = to_log_weight_(w).Value(); auto pos = aiter->Position(); if (weights_) { Extend(fst_->NumArcs(s_), aiter); return std::lower_bound(weights_->begin() + pos + 1, weights_->end(), f, std::greater<double>()) - weights_->begin() - 1; } else { size_t n = 0; auto x = FloatLimits<double>::PosInfinity(); for (aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) { x = LogPlus(x, aiter->Value().weight); if (n >= pos && x <= f) break; } return n; } }
bool Error() const { return error_; }
private: double LogPosExp(double x) { return x == FloatLimits<double>::PosInfinity() ? 0.0 : log(1.0F + exp(-x)); }
double LogMinusExp(double x) { return x == FloatLimits<double>::PosInfinity() ? 0.0 : log(1.0F - exp(-x)); }
Weight LogPlus(Weight w, Weight v) { if (w == Weight::Zero()) { return v; } const auto f1 = to_log_weight_(w).Value(); const auto f2 = to_log_weight_(v).Value(); if (f1 > f2) { return to_weight_(Log64Weight(f2 - LogPosExp(f1 - f2))); } else { return to_weight_(Log64Weight(f1 - LogPosExp(f2 - f1))); } }
double LogPlus(double f1, Weight v) { const auto f2 = to_log_weight_(v).Value(); if (f1 == FloatLimits<double>::PosInfinity()) { return f2; } else if (f1 > f2) { return f2 - LogPosExp(f1 - f2); } else { return f1 - LogPosExp(f2 - f1); } }
// Assumes f1 < f2.
Weight LogMinus(double f1, double f2) { if (f2 == FloatLimits<double>::PosInfinity()) { return to_weight_(Log64Weight(f1)); } else { return to_weight_(Log64Weight(f1 - LogMinusExp(f2 - f1))); } }
// Extends weights up to index 'end'.
template <class ArcIter> void Extend(ssize_t end, ArcIter *aiter) { if (weights_->size() <= end) { for (aiter->Seek(weights_->size() - 1); weights_->size() <= end; aiter->Next()) { weights_->push_back(LogPlus(weights_->back(), aiter->Value().weight)); } } }
const WeightConvert<Weight, Log64Weight> to_log_weight_{}; const WeightConvert<Log64Weight, Weight> to_weight_{}; ssize_t arc_limit_; // Minimum # of arcs to cache a state.
std::vector<double> *weights_; // Accumulated weights for cur. state.
// Pointee owned by `data_`.
std::unique_ptr<const Fst<Arc>> fst_; // Input FST.
std::shared_ptr<CacheLogAccumulatorData<Arc>> data_; // Cache data.
StateId s_; // Current state.
bool error_; };
// Stores shareable data for replace accumulator copies.
template <class Accumulator, class T> class ReplaceAccumulatorData { public: using Arc = typename Accumulator::Arc; using Label = typename Arc::Label; using StateId = typename Arc::StateId; using StateTable = T; using StateTuple = typename StateTable::StateTuple;
ReplaceAccumulatorData() : state_table_(nullptr) {}
explicit ReplaceAccumulatorData( std::vector<std::unique_ptr<Accumulator>> &&accumulators) : state_table_(nullptr), accumulators_(std::move(accumulators)) {}
void Init(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_tuples, const StateTable *state_table) { state_table_ = state_table; accumulators_.resize(fst_tuples.size()); for (Label i = 0; i < accumulators_.size(); ++i) { if (!accumulators_[i]) { accumulators_[i] = std::make_unique<Accumulator>(); accumulators_[i]->Init(*(fst_tuples[i].second)); } fst_array_.emplace_back(fst_tuples[i].second->Copy()); } }
const StateTuple &GetTuple(StateId s) const { return state_table_->Tuple(s); }
Accumulator *GetAccumulator(size_t i) { return accumulators_[i].get(); }
const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i].get(); }
private: const StateTable *state_table_; std::vector<std::unique_ptr<Accumulator>> accumulators_; std::vector<std::unique_ptr<const Fst<Arc>>> fst_array_; };
// This class accumulates weights in a ReplaceFst. The 'Init' method takes as
// input the argument used to build the ReplaceFst and the ReplaceFst state
// table. It uses accumulators of type 'Accumulator' in the underlying FSTs.
template <class Accumulator, class T = DefaultReplaceStateTable<typename Accumulator::Arc>> class ReplaceAccumulator { public: using Arc = typename Accumulator::Arc; using Label = typename Arc::Label; using StateId = typename Arc::StateId; using StateTable = T; using StateTuple = typename StateTable::StateTuple; using Weight = typename Arc::Weight;
ReplaceAccumulator() : init_(false), data_(std::make_shared< ReplaceAccumulatorData<Accumulator, StateTable>>()), error_(false) {}
explicit ReplaceAccumulator( std::vector<std::unique_ptr<Accumulator>> &&accumulators) : init_(false), data_(std::make_shared<ReplaceAccumulatorData<Accumulator, StateTable>>( std::move(accumulators))), error_(false) {}
ReplaceAccumulator(const ReplaceAccumulator<Accumulator, StateTable> &acc, bool safe = false) : init_(acc.init_), data_(acc.data_), error_(acc.error_) { if (!init_) { FSTERROR() << "ReplaceAccumulator: Can't copy unintialized accumulator"; } if (safe) FSTERROR() << "ReplaceAccumulator: Safe copy not supported"; }
// Does not take ownership of the state table, the state table is owned by
// the ReplaceFst.
void Init(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_tuples, const StateTable *state_table) { init_ = true; data_->Init(fst_tuples, state_table); }
// Method required by LookAheadMatcher. However, ReplaceAccumulator needs to
// be initialized by calling the Init method above before being passed to
// LookAheadMatcher.
//
// TODO(allauzen): Revisit this. Consider creating a method
// Init(const ReplaceFst<A, T, C>&, bool) and using friendship to get access
// to the innards of ReplaceFst.
void Init(const Fst<Arc> &fst, bool copy = false) { if (!init_) { FSTERROR() << "ReplaceAccumulator::Init: Accumulator needs to be" << " initialized before being passed to LookAheadMatcher"; error_ = true; } }
void SetState(StateId s) { if (!init_) { FSTERROR() << "ReplaceAccumulator::SetState: Incorrectly initialized"; error_ = true; return; } auto tuple = data_->GetTuple(s); fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based.
data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state); if ((tuple.prefix_id != 0) && (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) { offset_ = 1; offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state); } else { offset_ = 0; offset_weight_ = Weight::Zero(); } aiter_ = std::make_unique<ArcIterator<Fst<Arc>>>(*data_->GetFst(fst_id_), tuple.fst_state); }
Weight Sum(Weight w, Weight v) { if (error_) return Weight::NoWeight(); return data_->GetAccumulator(fst_id_)->Sum(w, v); }
template <class ArcIter> Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) { if (error_) return Weight::NoWeight(); auto sum = begin == end ? Weight::Zero() : data_->GetAccumulator(fst_id_)->Sum( w, aiter_.get(), begin ? begin - offset_ : 0, end - offset_); if (begin == 0 && end != 0 && offset_ > 0) sum = Sum(offset_weight_, sum); return sum; }
bool Error() const { return error_; }
private: bool init_; std::shared_ptr<ReplaceAccumulatorData<Accumulator, StateTable>> data_; Label fst_id_; size_t offset_; Weight offset_weight_; std::unique_ptr<ArcIterator<Fst<Arc>>> aiter_; bool error_; };
// SafeReplaceAccumulator accumulates weights in a ReplaceFst and copies of it
// are always thread-safe copies.
template <class Accumulator, class T> class SafeReplaceAccumulator { public: using Arc = typename Accumulator::Arc; using StateId = typename Arc::StateId; using Label = typename Arc::Label; using Weight = typename Arc::Weight; using StateTable = T; using StateTuple = typename StateTable::StateTuple;
SafeReplaceAccumulator() = default;
SafeReplaceAccumulator(const SafeReplaceAccumulator ©, bool safe) : SafeReplaceAccumulator(copy) {}
explicit SafeReplaceAccumulator( const std::vector<Accumulator> &accumulators) { for (const auto &accumulator : accumulators) { accumulators_.emplace_back(accumulator, true); } }
void Init(const std::vector<std::pair<Label, const Fst<Arc> *>> &fst_tuples, const StateTable *state_table) { state_table_ = state_table; for (Label i = 0; i < fst_tuples.size(); ++i) { if (i == accumulators_.size()) { accumulators_.resize(accumulators_.size() + 1); accumulators_[i].Init(*(fst_tuples[i].second)); } fst_array_.emplace_back(fst_tuples[i].second->Copy(true)); } init_ = true; }
void Init(const Fst<Arc> &fst, bool copy = false) { if (!init_) { FSTERROR() << "SafeReplaceAccumulator::Init: Accumulator needs to be" << " initialized before being passed to LookAheadMatcher"; error_ = true; } }
void SetState(StateId s) { auto tuple = state_table_->Tuple(s); fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based
GetAccumulator(fst_id_)->SetState(tuple.fst_state); offset_ = 0; offset_weight_ = Weight::Zero(); const auto final_weight = GetFst(fst_id_)->Final(tuple.fst_state); if ((tuple.prefix_id != 0) && (final_weight != Weight::Zero())) { offset_ = 1; offset_weight_ = final_weight; } aiter_.Set(*GetFst(fst_id_), tuple.fst_state); }
Weight Sum(Weight w, Weight v) { if (error_) return Weight::NoWeight(); return GetAccumulator(fst_id_)->Sum(w, v); }
template <class ArcIter> Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) { if (error_) return Weight::NoWeight(); if (begin == end) return Weight::Zero(); auto sum = GetAccumulator(fst_id_)->Sum( w, aiter_.get(), begin ? begin - offset_ : 0, end - offset_); if (begin == 0 && end != 0 && offset_ > 0) { sum = Sum(offset_weight_, sum); } return sum; }
bool Error() const { return error_; }
private: class ArcIteratorPtr { public: ArcIteratorPtr() = default;
ArcIteratorPtr(const ArcIteratorPtr ©) {}
void Set(const Fst<Arc> &fst, StateId state_id) { ptr_ = std::make_unique<ArcIterator<Fst<Arc>>>(fst, state_id); }
ArcIterator<Fst<Arc>> *get() { return ptr_.get(); }
private: std::unique_ptr<ArcIterator<Fst<Arc>>> ptr_; };
Accumulator *GetAccumulator(size_t i) { return &accumulators_[i]; }
const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i].get(); }
const StateTable *state_table_; std::vector<Accumulator> accumulators_; std::vector<std::shared_ptr<Fst<Arc>>> fst_array_; ArcIteratorPtr aiter_; bool init_ = false; bool error_ = false; Label fst_id_; size_t offset_; Weight offset_weight_; };
} // namespace fst
#endif // FST_ACCUMULATOR_H_
|