|
// Copyright 2005-2024 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the 'License');
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an 'AS IS' BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
// See www.openfst.org for extensive documentation on this weighted
|
|
// finite-state transducer library.
|
|
//
|
|
// Classes and functions to generate random paths through an FST.
|
|
|
|
#ifndef FST_RANDGEN_H_
|
|
#define FST_RANDGEN_H_
|
|
|
|
#include <algorithm>
|
|
#include <cmath>
|
|
#include <cstddef>
|
|
#include <cstdint>
|
|
#include <cstring>
|
|
#include <functional>
|
|
#include <limits>
|
|
#include <map>
|
|
#include <memory>
|
|
#include <numeric>
|
|
#include <random>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include <fst/log.h>
|
|
#include <fst/accumulator.h>
|
|
#include <fst/arc.h>
|
|
#include <fst/cache.h>
|
|
#include <fst/dfs-visit.h>
|
|
#include <fst/float-weight.h>
|
|
#include <fst/fst-decl.h>
|
|
#include <fst/fst.h>
|
|
#include <fst/impl-to-fst.h>
|
|
#include <fst/mutable-fst.h>
|
|
#include <fst/properties.h>
|
|
#include <fst/util.h>
|
|
#include <fst/weight.h>
|
|
#include <vector>
|
|
|
|
namespace fst {
|
|
|
|
// The RandGenFst class is roughly similar to ArcMapFst in that it takes two
|
|
// template parameters denoting the input and output arc types. However, it also
|
|
// takes an additional template parameter which specifies a sampler object which
|
|
// samples (with replacement) arcs from an FST state. The sampler in turn takes
|
|
// a template parameter for a selector object which actually chooses the arc.
|
|
//
|
|
// Arc selector functors are used to select a random transition given an FST
|
|
// state s, returning a number N such that 0 <= N <= NumArcs(s). If N is
|
|
// NumArcs(s), then the final weight is selected; otherwise the N-th arc is
|
|
// selected. It is assumed these are not applied to any state which is neither
|
|
// final nor has any arcs leaving it.
|
|
|
|
// Randomly selects a transition using the uniform distribution. This class is
|
|
// not thread-safe.
|
|
template <class Arc>
|
|
class UniformArcSelector {
|
|
public:
|
|
using StateId = typename Arc::StateId;
|
|
using Weight = typename Arc::Weight;
|
|
|
|
explicit UniformArcSelector(uint64_t seed = std::random_device()())
|
|
: rand_(seed) {}
|
|
|
|
size_t operator()(const Fst<Arc> &fst, StateId s) const {
|
|
const auto n = fst.NumArcs(s) + (fst.Final(s) != Weight::Zero());
|
|
return static_cast<size_t>(
|
|
std::uniform_int_distribution<>(0, n - 1)(rand_));
|
|
}
|
|
|
|
private:
|
|
mutable std::mt19937_64 rand_;
|
|
};
|
|
|
|
// Randomly selects a transition w.r.t. the weights treated as negative log
|
|
// probabilities after normalizing for the total weight leaving the state. Zero
|
|
// transitions are disregarded. It assumed that Arc::Weight::Value() accesses
|
|
// the floating point representation of the weight. This class is not
|
|
// thread-safe.
|
|
template <class Arc>
|
|
class LogProbArcSelector {
|
|
public:
|
|
using StateId = typename Arc::StateId;
|
|
using Weight = typename Arc::Weight;
|
|
|
|
// Constructs a selector with a non-deterministic seed.
|
|
LogProbArcSelector() : seed_(std::random_device()()), rand_(seed_) {}
|
|
|
|
// Constructs a selector with a given seed.
|
|
explicit LogProbArcSelector(uint64_t seed) : seed_(seed), rand_(seed) {}
|
|
|
|
size_t operator()(const Fst<Arc> &fst, StateId s) const {
|
|
// Finds total weight leaving state.
|
|
auto sum = Log64Weight::Zero();
|
|
ArcIterator<Fst<Arc>> aiter(fst, s);
|
|
for (; !aiter.Done(); aiter.Next()) {
|
|
const auto &arc = aiter.Value();
|
|
sum = Plus(sum, to_log_weight_(arc.weight));
|
|
}
|
|
sum = Plus(sum, to_log_weight_(fst.Final(s)));
|
|
const double threshold =
|
|
std::uniform_real_distribution<>(0, exp(-sum.Value()))(rand_);
|
|
auto p = Log64Weight::Zero();
|
|
size_t n = 0;
|
|
for (aiter.Reset(); !aiter.Done(); aiter.Next(), ++n) {
|
|
p = Plus(p, to_log_weight_(aiter.Value().weight));
|
|
if (exp(-p.Value()) > threshold) return n;
|
|
}
|
|
return n;
|
|
}
|
|
|
|
uint64_t Seed() const { return seed_; }
|
|
|
|
protected:
|
|
Log64Weight ToLogWeight(const Weight &weight) const {
|
|
return to_log_weight_(weight);
|
|
}
|
|
|
|
std::mt19937_64 &MutableRand() const { return rand_; }
|
|
|
|
private:
|
|
const uint64_t seed_;
|
|
mutable std::mt19937_64 rand_;
|
|
const WeightConvert<Weight, Log64Weight> to_log_weight_{};
|
|
};
|
|
|
|
// Same as LogProbArcSelector but use CacheLogAccumulator to cache the weight
|
|
// accumulation computations. This class is not thread-safe.
|
|
template <class Arc>
|
|
class FastLogProbArcSelector : public LogProbArcSelector<Arc> {
|
|
public:
|
|
using StateId = typename Arc::StateId;
|
|
using Weight = typename Arc::Weight;
|
|
|
|
using LogProbArcSelector<Arc>::MutableRand;
|
|
using LogProbArcSelector<Arc>::ToLogWeight;
|
|
using LogProbArcSelector<Arc>::operator();
|
|
|
|
// Constructs a selector with a non-deterministic seed.
|
|
FastLogProbArcSelector() : LogProbArcSelector<Arc>() {}
|
|
// Constructs a selector with a given seed.
|
|
explicit FastLogProbArcSelector(uint64_t seed)
|
|
: LogProbArcSelector<Arc>(seed) {}
|
|
|
|
size_t operator()(const Fst<Arc> &fst, StateId s,
|
|
CacheLogAccumulator<Arc> *accumulator) const {
|
|
accumulator->SetState(s);
|
|
ArcIterator<Fst<Arc>> aiter(fst, s);
|
|
// Finds total weight leaving state.
|
|
const double sum =
|
|
ToLogWeight(accumulator->Sum(fst.Final(s), &aiter, 0, fst.NumArcs(s)))
|
|
.Value();
|
|
const double r =
|
|
-log(std::uniform_real_distribution<>(0, 1)(MutableRand()));
|
|
Weight w = from_log_weight_(r + sum);
|
|
aiter.Reset();
|
|
return accumulator->LowerBound(w, &aiter);
|
|
}
|
|
|
|
private:
|
|
const WeightConvert<Log64Weight, Weight> from_log_weight_{};
|
|
};
|
|
|
|
// Random path state info maintained by RandGenFst and passed to samplers.
|
|
template <typename Arc>
|
|
struct RandState {
|
|
using StateId = typename Arc::StateId;
|
|
|
|
StateId state_id; // Current input FST state.
|
|
size_t nsamples; // Number of samples to be sampled at this state.
|
|
size_t length; // Length of path to this random state.
|
|
size_t select; // Previous sample arc selection.
|
|
const RandState<Arc> *parent; // Previous random state on this path.
|
|
|
|
explicit RandState(StateId state_id, size_t nsamples = 0, size_t length = 0,
|
|
size_t select = 0, const RandState<Arc> *parent = nullptr)
|
|
: state_id(state_id),
|
|
nsamples(nsamples),
|
|
length(length),
|
|
select(select),
|
|
parent(parent) {}
|
|
|
|
RandState() : RandState(kNoStateId) {}
|
|
};
|
|
|
|
// This class, given an arc selector, samples, with replacement, multiple random
|
|
// transitions from an FST's state. This is a generic version with a
|
|
// straightforward use of the arc selector. Specializations may be defined for
|
|
// arc selectors for greater efficiency or special behavior.
|
|
template <class Arc, class Selector>
|
|
class ArcSampler {
|
|
public:
|
|
using StateId = typename Arc::StateId;
|
|
using Weight = typename Arc::Weight;
|
|
|
|
// The max_length argument may be interpreted (or ignored) by a selector as
|
|
// it chooses. This generic version interprets this literally.
|
|
ArcSampler(const Fst<Arc> &fst, const Selector &selector,
|
|
int32_t max_length = std::numeric_limits<int32_t>::max())
|
|
: fst_(fst), selector_(selector), max_length_(max_length) {}
|
|
|
|
// Allow updating FST argument; pass only if changed.
|
|
ArcSampler(const ArcSampler<Arc, Selector> &sampler,
|
|
const Fst<Arc> *fst = nullptr)
|
|
: fst_(fst ? *fst : sampler.fst_),
|
|
selector_(sampler.selector_),
|
|
max_length_(sampler.max_length_) {
|
|
Reset();
|
|
}
|
|
|
|
// Samples a fixed number of samples from the given state. The length argument
|
|
// specifies the length of the path to the state. Returns true if the samples
|
|
// were collected. No samples may be collected if either there are no
|
|
// transitions leaving the state and the state is non-final, or if the path
|
|
// length has been exceeded. Iterator members are provided to read the samples
|
|
// in the order in which they were collected.
|
|
bool Sample(const RandState<Arc> &rstate) {
|
|
sample_map_.clear();
|
|
if ((fst_.NumArcs(rstate.state_id) == 0 &&
|
|
fst_.Final(rstate.state_id) == Weight::Zero()) ||
|
|
rstate.length == max_length_) {
|
|
Reset();
|
|
return false;
|
|
}
|
|
for (size_t i = 0; i < rstate.nsamples; ++i) {
|
|
++sample_map_[selector_(fst_, rstate.state_id)];
|
|
}
|
|
Reset();
|
|
return true;
|
|
}
|
|
|
|
// More samples?
|
|
bool Done() const { return sample_iter_ == sample_map_.end(); }
|
|
|
|
// Gets the next sample.
|
|
void Next() { ++sample_iter_; }
|
|
|
|
std::pair<size_t, size_t> Value() const { return *sample_iter_; }
|
|
|
|
void Reset() { sample_iter_ = sample_map_.begin(); }
|
|
|
|
bool Error() const { return false; }
|
|
|
|
private:
|
|
const Fst<Arc> &fst_;
|
|
const Selector &selector_;
|
|
const int32_t max_length_;
|
|
|
|
// Stores (N, K) as described for Value().
|
|
std::map<size_t, size_t> sample_map_;
|
|
std::map<size_t, size_t>::const_iterator sample_iter_;
|
|
|
|
ArcSampler<Arc, Selector> &operator=(const ArcSampler &) = delete;
|
|
};
|
|
|
|
// Draws a single multinomial sample of num_to_sample trials, where probs holds
// the (possibly unnormalized, non-negative) probability of each outcome. For
// every outcome i that receives a non-zero count, (*result)[i] is set to that
// count; other entries of *result are left untouched. The result container
// should therefore be pre-initialized, e.g., an empty map or a zeroed vector
// of size probs.size().
template <class Result, class RNG>
void OneMultinomialSample(const std::vector<double> &probs,
                          size_t num_to_sample, Result *result, RNG *rng) {
  using Binomial = std::binomial_distribution<size_t>;
  // tail_mass[i] is the probability mass of outcomes i, i + 1, .... It is
  // precomputed as suffix sums rather than maintained by repeated subtraction
  // from a scalar, because the accumulated round-off of the latter could make
  // probs[i] exceed the remaining mass.
  std::vector<double> tail_mass(probs.size());
  std::partial_sum(probs.rbegin(), probs.rend(), tail_mass.rbegin());
  // Number of trials not yet assigned to an outcome.
  size_t remaining = num_to_sample;
  for (size_t index = 0; index < probs.size(); ++index) {
    // Outcomes without positive probability never receive samples.
    if (!(probs[index] > 0)) continue;
    // Conditioned on the earlier outcomes, the count for this outcome is
    // binomial over the remaining trials.
    Binomial trial(remaining, probs[index] / tail_mass[index]);
    const Binomial::result_type drawn = trial(*rng);
    if (drawn != 0) (*result)[index] = drawn;
    remaining -= std::min(drawn, remaining);
  }
}
|
|
|
|
// Specialization for FastLogProbArcSelector.
|
|
template <class Arc>
|
|
class ArcSampler<Arc, FastLogProbArcSelector<Arc>> {
|
|
public:
|
|
using StateId = typename Arc::StateId;
|
|
using Weight = typename Arc::Weight;
|
|
|
|
using Accumulator = CacheLogAccumulator<Arc>;
|
|
using Selector = FastLogProbArcSelector<Arc>;
|
|
|
|
ArcSampler(const Fst<Arc> &fst, const Selector &selector,
|
|
int32_t max_length = std::numeric_limits<int32_t>::max())
|
|
: fst_(fst),
|
|
selector_(selector),
|
|
max_length_(max_length),
|
|
accumulator_(new Accumulator()) {
|
|
accumulator_->Init(fst);
|
|
rng_.seed(selector_.Seed());
|
|
}
|
|
|
|
ArcSampler(const ArcSampler<Arc, Selector> &sampler,
|
|
const Fst<Arc> *fst = nullptr)
|
|
: fst_(fst ? *fst : sampler.fst_),
|
|
selector_(sampler.selector_),
|
|
max_length_(sampler.max_length_) {
|
|
if (fst) {
|
|
accumulator_ = std::make_unique<Accumulator>();
|
|
accumulator_->Init(*fst);
|
|
} else { // Shallow copy.
|
|
accumulator_ = std::make_unique<Accumulator>(*sampler.accumulator_);
|
|
}
|
|
}
|
|
|
|
bool Sample(const RandState<Arc> &rstate) {
|
|
sample_map_.clear();
|
|
if ((fst_.NumArcs(rstate.state_id) == 0 &&
|
|
fst_.Final(rstate.state_id) == Weight::Zero()) ||
|
|
rstate.length == max_length_) {
|
|
Reset();
|
|
return false;
|
|
}
|
|
if (fst_.NumArcs(rstate.state_id) + 1 < rstate.nsamples) {
|
|
MultinomialSample(rstate);
|
|
Reset();
|
|
return true;
|
|
}
|
|
for (size_t i = 0; i < rstate.nsamples; ++i) {
|
|
++sample_map_[selector_(fst_, rstate.state_id, accumulator_.get())];
|
|
}
|
|
Reset();
|
|
return true;
|
|
}
|
|
|
|
bool Done() const { return sample_iter_ == sample_map_.end(); }
|
|
|
|
void Next() { ++sample_iter_; }
|
|
|
|
std::pair<size_t, size_t> Value() const { return *sample_iter_; }
|
|
|
|
void Reset() { sample_iter_ = sample_map_.begin(); }
|
|
|
|
bool Error() const { return accumulator_->Error(); }
|
|
|
|
private:
|
|
using RNG = std::mt19937;
|
|
|
|
// Sample according to the multinomial distribution of rstate.nsamples draws
|
|
// from p_.
|
|
void MultinomialSample(const RandState<Arc> &rstate) {
|
|
p_.clear();
|
|
for (ArcIterator<Fst<Arc>> aiter(fst_, rstate.state_id); !aiter.Done();
|
|
aiter.Next()) {
|
|
p_.push_back(exp(-to_log_weight_(aiter.Value().weight).Value()));
|
|
}
|
|
if (fst_.Final(rstate.state_id) != Weight::Zero()) {
|
|
p_.push_back(exp(-to_log_weight_(fst_.Final(rstate.state_id)).Value()));
|
|
}
|
|
if (rstate.nsamples < std::numeric_limits<RNG::result_type>::max()) {
|
|
OneMultinomialSample(p_, rstate.nsamples, &sample_map_, &rng_);
|
|
} else {
|
|
for (size_t i = 0; i < p_.size(); ++i) {
|
|
sample_map_[i] = ceil(p_[i] * rstate.nsamples);
|
|
}
|
|
}
|
|
}
|
|
|
|
const Fst<Arc> &fst_;
|
|
const Selector &selector_;
|
|
const int32_t max_length_;
|
|
|
|
// Stores (N, K) for Value().
|
|
std::map<size_t, size_t> sample_map_;
|
|
std::map<size_t, size_t>::const_iterator sample_iter_;
|
|
|
|
std::unique_ptr<Accumulator> accumulator_;
|
|
RNG rng_; // Random number generator.
|
|
std::vector<double> p_; // Multinomial parameters.
|
|
const WeightConvert<Weight, Log64Weight> to_log_weight_{};
|
|
};
|
|
|
|
// Options for random path generation with RandGenFst. The template argument is
|
|
// a sampler, typically the class ArcSampler. Ownership of the sampler is taken
|
|
// by RandGenFst.
|
|
template <class Sampler>
|
|
struct RandGenFstOptions : public CacheOptions {
|
|
Sampler *sampler; // How to sample transitions at a state.
|
|
int32_t npath; // Number of paths to generate.
|
|
bool weighted; // Is the output tree weighted by path count, or
|
|
// is it just an unweighted DAG?
|
|
bool remove_total_weight; // Remove total weight when output is weighted.
|
|
|
|
RandGenFstOptions(const CacheOptions &opts, Sampler *sampler,
|
|
int32_t npath = 1, bool weighted = true,
|
|
bool remove_total_weight = false)
|
|
: CacheOptions(opts),
|
|
sampler(sampler),
|
|
npath(npath),
|
|
weighted(weighted),
|
|
remove_total_weight(remove_total_weight) {}
|
|
};
|
|
|
|
namespace internal {
|
|
|
|
// Implementation of RandGenFst.
|
|
template <class FromArc, class ToArc, class Sampler>
|
|
class RandGenFstImpl : public CacheImpl<ToArc> {
|
|
public:
|
|
using FstImpl<ToArc>::SetType;
|
|
using FstImpl<ToArc>::SetProperties;
|
|
using FstImpl<ToArc>::SetInputSymbols;
|
|
using FstImpl<ToArc>::SetOutputSymbols;
|
|
|
|
using CacheBaseImpl<CacheState<ToArc>>::EmplaceArc;
|
|
using CacheBaseImpl<CacheState<ToArc>>::HasArcs;
|
|
using CacheBaseImpl<CacheState<ToArc>>::HasFinal;
|
|
using CacheBaseImpl<CacheState<ToArc>>::HasStart;
|
|
using CacheBaseImpl<CacheState<ToArc>>::SetArcs;
|
|
using CacheBaseImpl<CacheState<ToArc>>::SetFinal;
|
|
using CacheBaseImpl<CacheState<ToArc>>::SetStart;
|
|
|
|
using Label = typename FromArc::Label;
|
|
using StateId = typename FromArc::StateId;
|
|
using FromWeight = typename FromArc::Weight;
|
|
|
|
using ToWeight = typename ToArc::Weight;
|
|
|
|
RandGenFstImpl(const Fst<FromArc> &fst,
|
|
const RandGenFstOptions<Sampler> &opts)
|
|
: CacheImpl<ToArc>(opts),
|
|
fst_(fst.Copy()),
|
|
sampler_(opts.sampler),
|
|
npath_(opts.npath),
|
|
weighted_(opts.weighted),
|
|
remove_total_weight_(opts.remove_total_weight),
|
|
superfinal_(kNoLabel) {
|
|
SetType("randgen");
|
|
SetProperties(
|
|
RandGenProperties(fst.Properties(kFstProperties, false), weighted_),
|
|
kCopyProperties);
|
|
SetInputSymbols(fst.InputSymbols());
|
|
SetOutputSymbols(fst.OutputSymbols());
|
|
}
|
|
|
|
RandGenFstImpl(const RandGenFstImpl &impl)
|
|
: CacheImpl<ToArc>(impl),
|
|
fst_(impl.fst_->Copy(true)),
|
|
sampler_(new Sampler(*impl.sampler_, fst_.get())),
|
|
npath_(impl.npath_),
|
|
weighted_(impl.weighted_),
|
|
superfinal_(kNoLabel) {
|
|
SetType("randgen");
|
|
SetProperties(impl.Properties(), kCopyProperties);
|
|
SetInputSymbols(impl.InputSymbols());
|
|
SetOutputSymbols(impl.OutputSymbols());
|
|
}
|
|
|
|
StateId Start() {
|
|
if (!HasStart()) {
|
|
const auto s = fst_->Start();
|
|
if (s == kNoStateId) return kNoStateId;
|
|
SetStart(state_table_.size());
|
|
state_table_.emplace_back(
|
|
new RandState<FromArc>(s, npath_, 0, 0, nullptr));
|
|
}
|
|
return CacheImpl<ToArc>::Start();
|
|
}
|
|
|
|
ToWeight Final(StateId s) {
|
|
if (!HasFinal(s)) Expand(s);
|
|
return CacheImpl<ToArc>::Final(s);
|
|
}
|
|
|
|
size_t NumArcs(StateId s) {
|
|
if (!HasArcs(s)) Expand(s);
|
|
return CacheImpl<ToArc>::NumArcs(s);
|
|
}
|
|
|
|
size_t NumInputEpsilons(StateId s) {
|
|
if (!HasArcs(s)) Expand(s);
|
|
return CacheImpl<ToArc>::NumInputEpsilons(s);
|
|
}
|
|
|
|
size_t NumOutputEpsilons(StateId s) {
|
|
if (!HasArcs(s)) Expand(s);
|
|
return CacheImpl<ToArc>::NumOutputEpsilons(s);
|
|
}
|
|
|
|
uint64_t Properties() const override { return Properties(kFstProperties); }
|
|
|
|
// Sets error if found, and returns other FST impl properties.
|
|
uint64_t Properties(uint64_t mask) const override {
|
|
if ((mask & kError) &&
|
|
(fst_->Properties(kError, false) || sampler_->Error())) {
|
|
SetProperties(kError, kError);
|
|
}
|
|
return FstImpl<ToArc>::Properties(mask);
|
|
}
|
|
|
|
void InitArcIterator(StateId s, ArcIteratorData<ToArc> *data) {
|
|
if (!HasArcs(s)) Expand(s);
|
|
CacheImpl<ToArc>::InitArcIterator(s, data);
|
|
}
|
|
|
|
// Computes the outgoing transitions from a state, creating new destination
|
|
// states as needed.
|
|
void Expand(StateId s) {
|
|
if (s == superfinal_) {
|
|
SetFinal(s);
|
|
SetArcs(s);
|
|
return;
|
|
}
|
|
SetFinal(s, ToWeight::Zero());
|
|
const auto &rstate = *state_table_[s];
|
|
sampler_->Sample(rstate);
|
|
ArcIterator<Fst<FromArc>> aiter(*fst_, rstate.state_id);
|
|
const auto narcs = fst_->NumArcs(rstate.state_id);
|
|
for (; !sampler_->Done(); sampler_->Next()) {
|
|
const auto &sample_pair = sampler_->Value();
|
|
const auto pos = sample_pair.first;
|
|
const auto count = sample_pair.second;
|
|
double prob = static_cast<double>(count) / rstate.nsamples;
|
|
if (pos < narcs) { // Regular transition.
|
|
aiter.Seek(sample_pair.first);
|
|
const auto &aarc = aiter.Value();
|
|
auto weight =
|
|
weighted_ ? to_weight_(Log64Weight(-log(prob))) : ToWeight::One();
|
|
EmplaceArc(s, aarc.ilabel, aarc.olabel, std::move(weight),
|
|
state_table_.size());
|
|
auto nrstate = std::make_unique<RandState<FromArc>>(
|
|
aarc.nextstate, count, rstate.length + 1, pos, &rstate);
|
|
state_table_.push_back(std::move(nrstate));
|
|
} else { // Super-final transition.
|
|
if (weighted_) {
|
|
const auto weight =
|
|
remove_total_weight_
|
|
? to_weight_(Log64Weight(-log(prob)))
|
|
: to_weight_(Log64Weight(-log(prob * npath_)));
|
|
SetFinal(s, weight);
|
|
} else {
|
|
if (superfinal_ == kNoLabel) {
|
|
superfinal_ = state_table_.size();
|
|
state_table_.emplace_back(
|
|
new RandState<FromArc>(kNoStateId, 0, 0, 0, nullptr));
|
|
}
|
|
for (size_t n = 0; n < count; ++n) EmplaceArc(s, 0, 0, superfinal_);
|
|
}
|
|
}
|
|
}
|
|
SetArcs(s);
|
|
}
|
|
|
|
private:
|
|
const std::unique_ptr<Fst<FromArc>> fst_;
|
|
std::unique_ptr<Sampler> sampler_;
|
|
const int32_t npath_;
|
|
std::vector<std::unique_ptr<RandState<FromArc>>> state_table_;
|
|
const bool weighted_;
|
|
bool remove_total_weight_;
|
|
StateId superfinal_;
|
|
const WeightConvert<Log64Weight, ToWeight> to_weight_{};
|
|
};
|
|
|
|
} // namespace internal
|
|
|
|
// FST class to randomly generate paths through an FST, with details controlled
|
|
// by RandGenOptionsFst. Output format is a tree weighted by the path count.
|
|
template <class FromArc, class ToArc, class Sampler>
|
|
class RandGenFst
|
|
: public ImplToFst<internal::RandGenFstImpl<FromArc, ToArc, Sampler>> {
|
|
public:
|
|
using Label = typename FromArc::Label;
|
|
using StateId = typename FromArc::StateId;
|
|
using Weight = typename FromArc::Weight;
|
|
|
|
using Store = DefaultCacheStore<FromArc>;
|
|
using State = typename Store::State;
|
|
|
|
using Impl = internal::RandGenFstImpl<FromArc, ToArc, Sampler>;
|
|
|
|
friend class ArcIterator<RandGenFst<FromArc, ToArc, Sampler>>;
|
|
friend class StateIterator<RandGenFst<FromArc, ToArc, Sampler>>;
|
|
|
|
RandGenFst(const Fst<FromArc> &fst, const RandGenFstOptions<Sampler> &opts)
|
|
: ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
|
|
|
|
// See Fst<>::Copy() for doc.
|
|
RandGenFst(const RandGenFst &fst, bool safe = false)
|
|
: ImplToFst<Impl>(fst, safe) {}
|
|
|
|
// Get a copy of this RandGenFst. See Fst<>::Copy() for further doc.
|
|
RandGenFst *Copy(bool safe = false) const override {
|
|
return new RandGenFst(*this, safe);
|
|
}
|
|
|
|
inline void InitStateIterator(StateIteratorData<ToArc> *data) const override;
|
|
|
|
void InitArcIterator(StateId s, ArcIteratorData<ToArc> *data) const override {
|
|
GetMutableImpl()->InitArcIterator(s, data);
|
|
}
|
|
|
|
private:
|
|
using ImplToFst<Impl>::GetImpl;
|
|
using ImplToFst<Impl>::GetMutableImpl;
|
|
|
|
RandGenFst &operator=(const RandGenFst &) = delete;
|
|
};
|
|
|
|
// Specialization for RandGenFst.
|
|
template <class FromArc, class ToArc, class Sampler>
|
|
class StateIterator<RandGenFst<FromArc, ToArc, Sampler>>
|
|
: public CacheStateIterator<RandGenFst<FromArc, ToArc, Sampler>> {
|
|
public:
|
|
explicit StateIterator(const RandGenFst<FromArc, ToArc, Sampler> &fst)
|
|
: CacheStateIterator<RandGenFst<FromArc, ToArc, Sampler>>(
|
|
fst, fst.GetMutableImpl()) {}
|
|
};
|
|
|
|
// Specialization for RandGenFst.
|
|
template <class FromArc, class ToArc, class Sampler>
|
|
class ArcIterator<RandGenFst<FromArc, ToArc, Sampler>>
|
|
: public CacheArcIterator<RandGenFst<FromArc, ToArc, Sampler>> {
|
|
public:
|
|
using StateId = typename FromArc::StateId;
|
|
|
|
ArcIterator(const RandGenFst<FromArc, ToArc, Sampler> &fst, StateId s)
|
|
: CacheArcIterator<RandGenFst<FromArc, ToArc, Sampler>>(
|
|
fst.GetMutableImpl(), s) {
|
|
if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
|
|
}
|
|
};
|
|
|
|
template <class FromArc, class ToArc, class Sampler>
|
|
inline void RandGenFst<FromArc, ToArc, Sampler>::InitStateIterator(
|
|
StateIteratorData<ToArc> *data) const {
|
|
data->base =
|
|
std::make_unique<StateIterator<RandGenFst<FromArc, ToArc, Sampler>>>(
|
|
*this);
|
|
}
|
|
|
|
// Configuration for the RandGen() functions. The selector is held by
// reference and must outlive the options object.
template <class Selector>
struct RandGenOptions {
  const Selector &selector;  // Chooses an arc at each state.
  int32_t max_length;        // Upper bound on the length of a path.
  int32_t npath;             // How many paths to generate.
  bool weighted;             // True: output tree weighted by path count.
                             // False: unweighted DAG.
  bool remove_total_weight;  // Drop the total weight from a weighted output?

  // By default a single unweighted path of unbounded length (up to the int32
  // maximum) is requested.
  explicit RandGenOptions(
      const Selector &selector,
      int32_t max_length = std::numeric_limits<int32_t>::max(),
      int32_t npath = 1, bool weighted = false,
      bool remove_total_weight = false)
      : selector(selector),
        max_length(max_length),
        npath(npath),
        weighted(weighted),
        remove_total_weight(remove_total_weight) {}
};
|
|
|
|
namespace internal {
|
|
|
|
template <class FromArc, class ToArc>
|
|
class RandGenVisitor {
|
|
public:
|
|
using StateId = typename FromArc::StateId;
|
|
using Weight = typename FromArc::Weight;
|
|
|
|
explicit RandGenVisitor(MutableFst<ToArc> *ofst) : ofst_(ofst) {}
|
|
|
|
void InitVisit(const Fst<FromArc> &ifst) {
|
|
ifst_ = &ifst;
|
|
ofst_->DeleteStates();
|
|
ofst_->SetInputSymbols(ifst.InputSymbols());
|
|
ofst_->SetOutputSymbols(ifst.OutputSymbols());
|
|
if (ifst.Properties(kError, false)) ofst_->SetProperties(kError, kError);
|
|
path_.clear();
|
|
}
|
|
|
|
constexpr bool InitState(StateId, StateId) const { return true; }
|
|
|
|
bool TreeArc(StateId, const ToArc &arc) {
|
|
if (ifst_->Final(arc.nextstate) == Weight::Zero()) {
|
|
path_.push_back(arc);
|
|
} else {
|
|
OutputPath();
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool BackArc(StateId, const FromArc &) {
|
|
FSTERROR() << "RandGenVisitor: cyclic input";
|
|
ofst_->SetProperties(kError, kError);
|
|
return false;
|
|
}
|
|
|
|
bool ForwardOrCrossArc(StateId, const FromArc &) {
|
|
OutputPath();
|
|
return true;
|
|
}
|
|
|
|
void FinishState(StateId s, StateId p, const FromArc *) {
|
|
if (p != kNoStateId && ifst_->Final(s) == Weight::Zero()) path_.pop_back();
|
|
}
|
|
|
|
void FinishVisit() {}
|
|
|
|
private:
|
|
void OutputPath() {
|
|
if (ofst_->Start() == kNoStateId) {
|
|
const auto start = ofst_->AddState();
|
|
ofst_->SetStart(start);
|
|
}
|
|
auto src = ofst_->Start();
|
|
for (size_t i = 0; i < path_.size(); ++i) {
|
|
const auto dest = ofst_->AddState();
|
|
const ToArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest);
|
|
ofst_->AddArc(src, arc);
|
|
src = dest;
|
|
}
|
|
ofst_->SetFinal(src);
|
|
}
|
|
|
|
const Fst<FromArc> *ifst_;
|
|
MutableFst<ToArc> *ofst_;
|
|
std::vector<ToArc> path_;
|
|
|
|
RandGenVisitor(const RandGenVisitor &) = delete;
|
|
RandGenVisitor &operator=(const RandGenVisitor &) = delete;
|
|
};
|
|
|
|
} // namespace internal
|
|
|
|
// Randomly generate paths through an FST; details controlled by
|
|
// RandGenOptions.
|
|
template <class FromArc, class ToArc, class Selector>
|
|
void RandGen(const Fst<FromArc> &ifst, MutableFst<ToArc> *ofst,
|
|
const RandGenOptions<Selector> &opts) {
|
|
using Sampler = ArcSampler<FromArc, Selector>;
|
|
auto sampler =
|
|
std::make_unique<Sampler>(ifst, opts.selector, opts.max_length);
|
|
RandGenFstOptions<Sampler> fopts(CacheOptions(true, 0), sampler.release(),
|
|
opts.npath, opts.weighted,
|
|
opts.remove_total_weight);
|
|
RandGenFst<FromArc, ToArc, Sampler> rfst(ifst, fopts);
|
|
if (opts.weighted) {
|
|
*ofst = rfst;
|
|
} else {
|
|
internal::RandGenVisitor<FromArc, ToArc> rand_visitor(ofst);
|
|
DfsVisit(rfst, &rand_visitor);
|
|
}
|
|
}
|
|
|
|
// Randomly generate a path through an FST with the uniform distribution
|
|
// over the transitions.
|
|
template <class FromArc, class ToArc>
|
|
void RandGen(const Fst<FromArc> &ifst, MutableFst<ToArc> *ofst,
|
|
uint64_t seed = std::random_device()()) {
|
|
const UniformArcSelector<FromArc> uniform_selector(seed);
|
|
RandGenOptions<UniformArcSelector<ToArc>> opts(uniform_selector);
|
|
RandGen(ifst, ofst, opts);
|
|
}
|
|
|
|
} // namespace fst
|
|
|
|
#endif // FST_RANDGEN_H_
|