You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

780 lines
26 KiB

// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Classes and functions to generate random paths through an FST.
#ifndef FST_RANDGEN_H_
#define FST_RANDGEN_H_
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <limits>
#include <map>
#include <memory>
#include <numeric>
#include <random>
#include <utility>
#include <vector>
#include <fst/log.h>
#include <fst/accumulator.h>
#include <fst/arc.h>
#include <fst/cache.h>
#include <fst/dfs-visit.h>
#include <fst/float-weight.h>
#include <fst/fst-decl.h>
#include <fst/fst.h>
#include <fst/impl-to-fst.h>
#include <fst/mutable-fst.h>
#include <fst/properties.h>
#include <fst/util.h>
#include <fst/weight.h>
#include <vector>
namespace fst {
// The RandGenFst class is roughly similar to ArcMapFst in that it takes two
// template parameters denoting the input and output arc types. However, it also
// takes an additional template parameter which specifies a sampler object which
// samples (with replacement) arcs from an FST state. The sampler in turn takes
// a template parameter for a selector object which actually chooses the arc.
//
// Arc selector functors are used to select a random transition given an FST
// state s, returning a number N such that 0 <= N <= NumArcs(s). If N is
// NumArcs(s), then the final weight is selected; otherwise the N-th arc is
// selected. It is assumed these are not applied to any state which is neither
// final nor has any arcs leaving it.
// Arc selector that treats every outgoing arc of a state, plus the final
// weight when the state is final, as equally likely. The selector keeps a
// mutable random engine, so a single instance must not be shared between
// threads.
template <class Arc>
class UniformArcSelector {
 public:
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  // Seeds the internal engine. Without an explicit seed one is drawn from
  // std::random_device.
  explicit UniformArcSelector(uint64_t seed = std::random_device()())
      : rand_(seed) {}

  // Returns an index drawn uniformly from [0, n - 1], where n is NumArcs(s)
  // plus one when s has a non-Zero final weight. The state must offer at
  // least one choice (n > 0).
  size_t operator()(const Fst<Arc> &fst, StateId s) const {
    const bool is_final = fst.Final(s) != Weight::Zero();
    const auto num_choices = fst.NumArcs(s) + is_final;
    std::uniform_int_distribution<> pick(0, num_choices - 1);
    return static_cast<size_t>(pick(rand_));
  }

 private:
  // Mutable so that the const call operator can advance the engine.
  mutable std::mt19937_64 rand_;
};
// Arc selector that interprets weights as negative log probabilities and
// draws a transition proportionally to its probability, after normalizing by
// the total weight leaving the state (arcs plus final weight). Arcs with Zero
// weight contribute no mass and are therefore never favoured. The weight type
// must be convertible to Log64Weight via WeightConvert. Not thread-safe: the
// random engine is mutated on every call.
template <class Arc>
class LogProbArcSelector {
 public:
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  // Uses a seed obtained from std::random_device.
  LogProbArcSelector() : seed_(std::random_device()()), rand_(seed_) {}

  // Uses the caller-provided seed.
  explicit LogProbArcSelector(uint64_t seed) : seed_(seed), rand_(seed) {}

  // Returns the position of the selected arc, or the number of arcs visited
  // (i.e. one past the last arc) when the final weight is selected.
  size_t operator()(const Fst<Arc> &fst, StateId s) const {
    ArcIterator<Fst<Arc>> aiter(fst, s);
    // Pass 1: accumulate the total weight leaving the state.
    auto total = Log64Weight::Zero();
    while (!aiter.Done()) {
      total = Plus(total, to_log_weight_(aiter.Value().weight));
      aiter.Next();
    }
    total = Plus(total, to_log_weight_(fst.Final(s)));
    // Draw a point in [0, total probability mass).
    std::uniform_real_distribution<> mass(0, exp(-total.Value()));
    const double threshold = mass(rand_);
    // Pass 2: walk the cumulative mass until it exceeds the drawn point.
    auto cumulative = Log64Weight::Zero();
    size_t index = 0;
    aiter.Reset();
    while (!aiter.Done()) {
      cumulative = Plus(cumulative, to_log_weight_(aiter.Value().weight));
      if (exp(-cumulative.Value()) > threshold) break;
      aiter.Next();
      ++index;
    }
    return index;
  }

  // The seed the random engine was initialized with.
  uint64_t Seed() const { return seed_; }

 protected:
  // Converts an arc weight to the Log64 semiring (for derived selectors).
  Log64Weight ToLogWeight(const Weight &weight) const {
    return to_log_weight_(weight);
  }

  // Gives derived selectors access to the shared random engine.
  std::mt19937_64 &MutableRand() const { return rand_; }

 private:
  const uint64_t seed_;
  mutable std::mt19937_64 rand_;
  const WeightConvert<Weight, Log64Weight> to_log_weight_{};
};
// Variant of LogProbArcSelector whose three-argument call operator delegates
// the weight accumulation to a CacheLogAccumulator, so that partial sums can
// be cached by the accumulator. The two-argument call operator of the base
// class remains available. Not thread-safe.
template <class Arc>
class FastLogProbArcSelector : public LogProbArcSelector<Arc> {
 public:
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  using LogProbArcSelector<Arc>::MutableRand;
  using LogProbArcSelector<Arc>::ToLogWeight;
  using LogProbArcSelector<Arc>::operator();

  // Uses a non-deterministic seed (see the base class).
  FastLogProbArcSelector() : LogProbArcSelector<Arc>() {}

  // Uses the caller-provided seed.
  explicit FastLogProbArcSelector(uint64_t seed)
      : LogProbArcSelector<Arc>(seed) {}

  // Selects a transition at state s using the given accumulator: a random
  // amount of (negative log) mass is added to the total weight leaving the
  // state, and the accumulator's LowerBound() locates the matching position.
  size_t operator()(const Fst<Arc> &fst, StateId s,
                    CacheLogAccumulator<Arc> *accumulator) const {
    accumulator->SetState(s);
    ArcIterator<Fst<Arc>> aiter(fst, s);
    // Total weight leaving the state: final weight plus all arcs.
    const auto final_weight = fst.Final(s);
    const auto num_arcs = fst.NumArcs(s);
    const double sum =
        ToLogWeight(accumulator->Sum(final_weight, &aiter, 0, num_arcs))
            .Value();
    // Negative log of a uniform draw from [0, 1).
    std::uniform_real_distribution<> unit(0, 1);
    const double r = -log(unit(MutableRand()));
    Weight w = from_log_weight_(r + sum);
    aiter.Reset();
    return accumulator->LowerBound(w, &aiter);
  }

 private:
  const WeightConvert<Log64Weight, Weight> from_log_weight_{};
};
// Per-state bookkeeping for a random path. RandGenFst creates one RandState
// for every output state and hands it to the sampler.
template <typename Arc>
struct RandState {
  using StateId = typename Arc::StateId;

  // State of the input FST this random state corresponds to.
  StateId state_id;
  // How many samples have to be drawn at this state.
  size_t nsamples;
  // Number of transitions taken to reach this random state.
  size_t length;
  // Selection (arc position) that led to this random state.
  size_t select;
  // Random state preceding this one on the path; nullptr if there is none.
  const RandState<Arc> *parent;

  // Default state: no input state, all counters zero, no parent.
  RandState() : RandState(kNoStateId) {}

  // Builds a state from its individual fields; everything except the input
  // state ID defaults to zero / nullptr.
  explicit RandState(StateId state_id, size_t nsamples = 0, size_t length = 0,
                     size_t select = 0, const RandState<Arc> *parent = nullptr)
      : state_id(state_id),
        nsamples(nsamples),
        length(length),
        select(select),
        parent(parent) {}
};
// Generic sampler: draws, with replacement, several random transitions from a
// state of an FST by repeatedly invoking an arc selector. Specializations may
// be provided for particular selectors for efficiency or special behavior.
//
// The sampler keeps references to the FST and the selector passed at
// construction; both must outlive it.
template <class Arc, class Selector>
class ArcSampler {
 public:
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  // max_length bounds the path length; this generic version applies it
  // literally (no sampling once a path has reached exactly that length).
  ArcSampler(const Fst<Arc> &fst, const Selector &selector,
             int32_t max_length = std::numeric_limits<int32_t>::max())
      : fst_(fst), selector_(selector), max_length_(max_length) {}

  // Copies the configuration of another sampler. Pass a non-null fst only
  // when the copy should sample from a different FST object.
  ArcSampler(const ArcSampler<Arc, Selector> &sampler,
             const Fst<Arc> *fst = nullptr)
      : fst_(fst ? *fst : sampler.fst_),
        selector_(sampler.selector_),
        max_length_(sampler.max_length_) {
    Reset();
  }

  // Draws rstate.nsamples samples at rstate.state_id. Returns false, leaving
  // no samples to iterate over, when the state has neither arcs nor a
  // non-Zero final weight, or when rstate.length equals the maximum length.
  // Otherwise returns true; the samples are then available through
  // Done()/Next()/Value(), ordered by selection index.
  bool Sample(const RandState<Arc> &rstate) {
    sample_map_.clear();
    const bool dead_end = fst_.NumArcs(rstate.state_id) == 0 &&
                          fst_.Final(rstate.state_id) == Weight::Zero();
    const bool can_sample = !(dead_end || rstate.length == max_length_);
    if (can_sample) {
      for (size_t remaining = rstate.nsamples; remaining > 0; --remaining) {
        const size_t choice = selector_(fst_, rstate.state_id);
        ++sample_map_[choice];
      }
    }
    Reset();
    return can_sample;
  }

  // True once all distinct samples have been visited.
  bool Done() const { return sample_iter_ == sample_map_.end(); }

  // Advances to the next distinct sample.
  void Next() { ++sample_iter_; }

  // Returns (N, K): N is the selection index returned by the selector and K
  // the number of times it was drawn.
  std::pair<size_t, size_t> Value() const { return *sample_iter_; }

  // Rewinds the iteration to the first distinct sample.
  void Reset() { sample_iter_ = sample_map_.begin(); }

  // The generic sampler has no error state.
  bool Error() const { return false; }

 private:
  const Fst<Arc> &fst_;
  const Selector &selector_;
  const int32_t max_length_;

  // Maps selection index N to its draw count K.
  std::map<size_t, size_t> sample_map_;
  std::map<size_t, size_t>::const_iterator sample_iter_;

  ArcSampler<Arc, Selector> &operator=(const ArcSampler &) = delete;
};
// Draws a single sample of num_to_sample trials from the multinomial
// distribution whose (not necessarily normalized) category masses are given
// by probs. Category i receives its count via (*result)[i]; categories that
// receive no samples are not written. The result container must therefore be
// pre-initialized: e.g. an empty map, or a zeroed vector with probs.size()
// elements.
//
// The multinomial is decomposed into a chain of binomial draws: category i
// gets a binomial share of the samples still unassigned, with success
// probability probs[i] divided by the mass of categories i..end.
template <class Result, class RNG>
void OneMultinomialSample(const std::vector<double> &probs,
                          size_t num_to_sample, Result *result, RNG *rng) {
  using Binomial = std::binomial_distribution<size_t>;
  // tail_mass[i] = probs[i] + ... + probs.back(). The sums are precomputed
  // rather than maintained by subtracting probs[i] from a running scalar,
  // because accumulated round-off could then make probs[i] exceed the
  // remaining mass.
  std::vector<double> tail_mass(probs.size());
  std::partial_sum(probs.rbegin(), probs.rend(), tail_mass.rbegin());
  // Samples that have not been assigned to a category yet.
  size_t remaining = num_to_sample;
  size_t index = 0;
  for (const double p : probs) {
    if (p > 0) {
      Binomial draw(remaining, p / tail_mass[index]);
      const Binomial::result_type count = draw(*rng);
      if (count != 0) (*result)[index] = count;
      // Guard against underflow should the draw exceed what is left.
      remaining -= std::min(count, remaining);
    }
    ++index;
  }
}
// Specialization of ArcSampler for FastLogProbArcSelector. It owns a
// CacheLogAccumulator that is handed to the selector on every draw and, when
// many samples are requested at a state, draws all counts at once from a
// multinomial distribution instead of selecting sample by sample.
//
// The sampler keeps references to the FST and the selector passed at
// construction; both must outlive it.
template <class Arc>
class ArcSampler<Arc, FastLogProbArcSelector<Arc>> {
 public:
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;
  using Accumulator = CacheLogAccumulator<Arc>;
  using Selector = FastLogProbArcSelector<Arc>;

  // max_length bounds the path length: no sampling happens at a state whose
  // path length equals it.
  ArcSampler(const Fst<Arc> &fst, const Selector &selector,
             int32_t max_length = std::numeric_limits<int32_t>::max())
      : fst_(fst),
        selector_(selector),
        max_length_(max_length),
        accumulator_(std::make_unique<Accumulator>()) {
    accumulator_->Init(fst);
    rng_.seed(selector_.Seed());
    // Give the sample iterator a well-defined value before the first Sample().
    Reset();
  }

  // Copies the configuration of another sampler. Pass a non-null fst only
  // when the copy should sample from a different FST object; in that case a
  // fresh accumulator is initialized on it, otherwise the source sampler's
  // accumulator is copied.
  ArcSampler(const ArcSampler<Arc, Selector> &sampler,
             const Fst<Arc> *fst = nullptr)
      : fst_(fst ? *fst : sampler.fst_),
        selector_(sampler.selector_),
        max_length_(sampler.max_length_) {
    if (fst) {
      accumulator_ = std::make_unique<Accumulator>();
      accumulator_->Init(*fst);
    } else {  // Shallow copy.
      accumulator_ = std::make_unique<Accumulator>(*sampler.accumulator_);
    }
    // Seed the multinomial RNG exactly as the primary constructor does;
    // otherwise a copy would silently ignore the selector's seed.
    rng_.seed(selector_.Seed());
    // As in the generic ArcSampler, make the sample iterator valid so that
    // Done() may be called before the first Sample().
    Reset();
  }

  // Draws rstate.nsamples samples at rstate.state_id. Returns false, leaving
  // no samples to iterate over, when the state has neither arcs nor a
  // non-Zero final weight, or when rstate.length equals the maximum length.
  // Otherwise returns true; the samples are then available through
  // Done()/Next()/Value(), ordered by selection index.
  bool Sample(const RandState<Arc> &rstate) {
    sample_map_.clear();
    if ((fst_.NumArcs(rstate.state_id) == 0 &&
         fst_.Final(rstate.state_id) == Weight::Zero()) ||
        rstate.length == max_length_) {
      Reset();
      return false;
    }
    // With more samples than possible outcomes, one multinomial draw replaces
    // nsamples individual selections.
    if (fst_.NumArcs(rstate.state_id) + 1 < rstate.nsamples) {
      MultinomialSample(rstate);
      Reset();
      return true;
    }
    for (size_t i = 0; i < rstate.nsamples; ++i) {
      ++sample_map_[selector_(fst_, rstate.state_id, accumulator_.get())];
    }
    Reset();
    return true;
  }

  // True once all distinct samples have been visited.
  bool Done() const { return sample_iter_ == sample_map_.end(); }

  // Advances to the next distinct sample.
  void Next() { ++sample_iter_; }

  // Returns (N, K): N is the selection index and K the number of times it
  // was drawn.
  std::pair<size_t, size_t> Value() const { return *sample_iter_; }

  // Rewinds the iteration to the first distinct sample.
  void Reset() { sample_iter_ = sample_map_.begin(); }

  // Reports the accumulator's error state.
  bool Error() const { return accumulator_->Error(); }

 private:
  using RNG = std::mt19937;

  // Fills sample_map_ with counts for rstate.nsamples draws from the
  // multinomial distribution given by the arc weights (and the final weight,
  // if non-Zero) of rstate.state_id, interpreted as negative log
  // probabilities.
  void MultinomialSample(const RandState<Arc> &rstate) {
    p_.clear();
    for (ArcIterator<Fst<Arc>> aiter(fst_, rstate.state_id); !aiter.Done();
         aiter.Next()) {
      p_.push_back(exp(-to_log_weight_(aiter.Value().weight).Value()));
    }
    if (fst_.Final(rstate.state_id) != Weight::Zero()) {
      p_.push_back(exp(-to_log_weight_(fst_.Final(rstate.state_id)).Value()));
    }
    if (rstate.nsamples < std::numeric_limits<RNG::result_type>::max()) {
      OneMultinomialSample(p_, rstate.nsamples, &sample_map_, &rng_);
    } else {
      // Too many samples for a random draw: assign the rounded-up expected
      // count to every outcome instead.
      for (size_t i = 0; i < p_.size(); ++i) {
        sample_map_[i] = ceil(p_[i] * rstate.nsamples);
      }
    }
  }

  const Fst<Arc> &fst_;
  const Selector &selector_;
  const int32_t max_length_;

  // Maps selection index N to its draw count K.
  std::map<size_t, size_t> sample_map_;
  std::map<size_t, size_t>::const_iterator sample_iter_;

  std::unique_ptr<Accumulator> accumulator_;
  RNG rng_;                // Random number generator for multinomial draws.
  std::vector<double> p_;  // Multinomial parameters (scratch buffer).
  const WeightConvert<Weight, Log64Weight> to_log_weight_{};
};
// Configuration for RandGenFst. The template argument is the sampler type,
// typically an ArcSampler instantiation. RandGenFst takes ownership of the
// sampler pointed to by these options.
template <class Sampler>
struct RandGenFstOptions : public CacheOptions {
  // Strategy used to sample transitions at a state (ownership is transferred
  // to the RandGenFst built from these options).
  Sampler *sampler;
  // Number of paths to generate.
  int32_t npath;
  // Whether the output is a tree weighted by path count rather than an
  // unweighted DAG.
  bool weighted;
  // Whether the total weight is removed when the output is weighted.
  bool remove_total_weight;

  RandGenFstOptions(const CacheOptions &opts, Sampler *sampler,
                    int32_t npath = 1, bool weighted = true,
                    bool remove_total_weight = false)
      : CacheOptions(opts),
        sampler(sampler),
        npath(npath),
        weighted(weighted),
        remove_total_weight(remove_total_weight) {}
};
namespace internal {
// Implementation of RandGenFst. Output states are created lazily: each output
// state is backed by a RandState recording which input state it represents
// and how many samples still have to be drawn there. Expanding a state asks
// the sampler for those samples and creates one output arc (and destination
// state) per distinct selection.
template <class FromArc, class ToArc, class Sampler>
class RandGenFstImpl : public CacheImpl<ToArc> {
 public:
  using FstImpl<ToArc>::SetType;
  using FstImpl<ToArc>::SetProperties;
  using FstImpl<ToArc>::SetInputSymbols;
  using FstImpl<ToArc>::SetOutputSymbols;

  using CacheBaseImpl<CacheState<ToArc>>::EmplaceArc;
  using CacheBaseImpl<CacheState<ToArc>>::HasArcs;
  using CacheBaseImpl<CacheState<ToArc>>::HasFinal;
  using CacheBaseImpl<CacheState<ToArc>>::HasStart;
  using CacheBaseImpl<CacheState<ToArc>>::SetArcs;
  using CacheBaseImpl<CacheState<ToArc>>::SetFinal;
  using CacheBaseImpl<CacheState<ToArc>>::SetStart;

  using Label = typename FromArc::Label;
  using StateId = typename FromArc::StateId;
  using FromWeight = typename FromArc::Weight;
  using ToWeight = typename ToArc::Weight;

  // Copies the input FST and takes ownership of opts.sampler.
  RandGenFstImpl(const Fst<FromArc> &fst,
                 const RandGenFstOptions<Sampler> &opts)
      : CacheImpl<ToArc>(opts),
        fst_(fst.Copy()),
        sampler_(opts.sampler),
        npath_(opts.npath),
        weighted_(opts.weighted),
        remove_total_weight_(opts.remove_total_weight),
        superfinal_(kNoStateId) {
    SetType("randgen");
    SetProperties(
        RandGenProperties(fst.Properties(kFstProperties, false), weighted_),
        kCopyProperties);
    SetInputSymbols(fst.InputSymbols());
    SetOutputSymbols(fst.OutputSymbols());
  }

  // Copies the configuration of impl; the input FST is copied with
  // Copy(true) and a new sampler is created on that copy. The table of random
  // states is not copied.
  RandGenFstImpl(const RandGenFstImpl &impl)
      : CacheImpl<ToArc>(impl),
        fst_(impl.fst_->Copy(true)),
        sampler_(new Sampler(*impl.sampler_, fst_.get())),
        npath_(impl.npath_),
        weighted_(impl.weighted_),
        // Must be carried over: Expand() reads it to weight final states.
        remove_total_weight_(impl.remove_total_weight_),
        superfinal_(kNoStateId) {
    SetType("randgen");
    SetProperties(impl.Properties(), kCopyProperties);
    SetInputSymbols(impl.InputSymbols());
    SetOutputSymbols(impl.OutputSymbols());
  }

  // Returns the start state, creating its random state (with npath samples to
  // draw) on first use. Returns kNoStateId if the input FST has no start
  // state.
  StateId Start() {
    if (!HasStart()) {
      const auto s = fst_->Start();
      if (s == kNoStateId) return kNoStateId;
      SetStart(state_table_.size());
      state_table_.push_back(
          std::make_unique<RandState<FromArc>>(s, npath_, 0, 0, nullptr));
    }
    return CacheImpl<ToArc>::Start();
  }

  // Final weight of s, expanding the state first if necessary.
  ToWeight Final(StateId s) {
    if (!HasFinal(s)) Expand(s);
    return CacheImpl<ToArc>::Final(s);
  }

  // Number of arcs leaving s, expanding the state first if necessary.
  size_t NumArcs(StateId s) {
    if (!HasArcs(s)) Expand(s);
    return CacheImpl<ToArc>::NumArcs(s);
  }

  // Number of input-epsilon arcs leaving s, expanding first if necessary.
  size_t NumInputEpsilons(StateId s) {
    if (!HasArcs(s)) Expand(s);
    return CacheImpl<ToArc>::NumInputEpsilons(s);
  }

  // Number of output-epsilon arcs leaving s, expanding first if necessary.
  size_t NumOutputEpsilons(StateId s) {
    if (!HasArcs(s)) Expand(s);
    return CacheImpl<ToArc>::NumOutputEpsilons(s);
  }

  uint64_t Properties() const override { return Properties(kFstProperties); }

  // Sets the error property if the input FST or the sampler reports an error,
  // then returns the requested FST impl properties.
  uint64_t Properties(uint64_t mask) const override {
    if ((mask & kError) &&
        (fst_->Properties(kError, false) || sampler_->Error())) {
      SetProperties(kError, kError);
    }
    return FstImpl<ToArc>::Properties(mask);
  }

  // Initializes an arc iterator over s, expanding the state if necessary.
  void InitArcIterator(StateId s, ArcIteratorData<ToArc> *data) {
    if (!HasArcs(s)) Expand(s);
    CacheImpl<ToArc>::InitArcIterator(s, data);
  }

  // Computes the outgoing transitions from a state, creating new destination
  // states as needed.
  void Expand(StateId s) {
    // The super-final state is final and has no outgoing arcs.
    if (s == superfinal_) {
      SetFinal(s);
      SetArcs(s);
      return;
    }
    SetFinal(s, ToWeight::Zero());
    // Refers to the heap-allocated RandState, so it stays valid even when
    // state_table_ grows below.
    const auto &rstate = *state_table_[s];
    sampler_->Sample(rstate);
    ArcIterator<Fst<FromArc>> aiter(*fst_, rstate.state_id);
    const auto narcs = fst_->NumArcs(rstate.state_id);
    for (; !sampler_->Done(); sampler_->Next()) {
      const auto &sample_pair = sampler_->Value();
      const auto pos = sample_pair.first;
      const auto count = sample_pair.second;
      // Fraction of this state's samples that picked this selection.
      const double prob = static_cast<double>(count) / rstate.nsamples;
      if (pos < narcs) {  // Regular transition.
        aiter.Seek(pos);
        const auto &aarc = aiter.Value();
        auto weight =
            weighted_ ? to_weight_(Log64Weight(-log(prob))) : ToWeight::One();
        EmplaceArc(s, aarc.ilabel, aarc.olabel, std::move(weight),
                   state_table_.size());
        // The destination inherits the samples that chose this arc.
        state_table_.push_back(std::make_unique<RandState<FromArc>>(
            aarc.nextstate, count, rstate.length + 1, pos, &rstate));
      } else {  // Super-final transition.
        if (weighted_) {
          const auto weight =
              remove_total_weight_
                  ? to_weight_(Log64Weight(-log(prob)))
                  : to_weight_(Log64Weight(-log(prob * npath_)));
          SetFinal(s, weight);
        } else {
          // Unweighted output: one epsilon arc to a shared super-final state
          // per sample that ended here.
          if (superfinal_ == kNoStateId) {
            superfinal_ = state_table_.size();
            state_table_.push_back(std::make_unique<RandState<FromArc>>(
                kNoStateId, 0, 0, 0, nullptr));
          }
          for (size_t n = 0; n < count; ++n) EmplaceArc(s, 0, 0, superfinal_);
        }
      }
    }
    SetArcs(s);
  }

 private:
  const std::unique_ptr<Fst<FromArc>> fst_;
  std::unique_ptr<Sampler> sampler_;
  const int32_t npath_;
  // Random state backing each output state, indexed by output state ID.
  std::vector<std::unique_ptr<RandState<FromArc>>> state_table_;
  const bool weighted_;
  // Const so that every constructor is forced to initialize it.
  const bool remove_total_weight_;
  // Shared super-final state (unweighted output only); kNoStateId until it
  // is first needed.
  StateId superfinal_;
  const WeightConvert<Log64Weight, ToWeight> to_weight_{};
};
} // namespace internal
// FST class to randomly generate paths through an FST, with details controlled
// by RandGenFstOptions. Output format is a tree weighted by the path count.
template <class FromArc, class ToArc, class Sampler>
class RandGenFst
    : public ImplToFst<internal::RandGenFstImpl<FromArc, ToArc, Sampler>> {
 public:
  using Label = typename FromArc::Label;
  using StateId = typename FromArc::StateId;
  using Weight = typename FromArc::Weight;

  // The implementation derives from CacheImpl<ToArc>, i.e. it caches states
  // holding output arcs. The store type, which the cache state/arc iterators
  // use to name the impl base class they accept, must therefore be keyed on
  // ToArc. This is identical to the previous FromArc-based definition
  // whenever both arc types coincide.
  using Store = DefaultCacheStore<ToArc>;
  using State = typename Store::State;

  using Impl = internal::RandGenFstImpl<FromArc, ToArc, Sampler>;

  friend class ArcIterator<RandGenFst<FromArc, ToArc, Sampler>>;
  friend class StateIterator<RandGenFst<FromArc, ToArc, Sampler>>;

  // Builds a lazy random-path FST over fst. The sampler referenced by opts is
  // handed over to the implementation.
  RandGenFst(const Fst<FromArc> &fst, const RandGenFstOptions<Sampler> &opts)
      : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}

  // See Fst<>::Copy() for doc.
  RandGenFst(const RandGenFst &fst, bool safe = false)
      : ImplToFst<Impl>(fst, safe) {}

  // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc.
  RandGenFst *Copy(bool safe = false) const override {
    return new RandGenFst(*this, safe);
  }

  inline void InitStateIterator(StateIteratorData<ToArc> *data) const override;

  // Delegates to the implementation, which expands state s on demand.
  void InitArcIterator(StateId s, ArcIteratorData<ToArc> *data) const override {
    GetMutableImpl()->InitArcIterator(s, data);
  }

 private:
  using ImplToFst<Impl>::GetImpl;
  using ImplToFst<Impl>::GetMutableImpl;

  RandGenFst &operator=(const RandGenFst &) = delete;
};
// State iterator for RandGenFst: a CacheStateIterator running over the FST's
// mutable implementation.
template <class FromArc, class ToArc, class Sampler>
class StateIterator<RandGenFst<FromArc, ToArc, Sampler>>
    : public CacheStateIterator<RandGenFst<FromArc, ToArc, Sampler>> {
  using Base = CacheStateIterator<RandGenFst<FromArc, ToArc, Sampler>>;

 public:
  explicit StateIterator(const RandGenFst<FromArc, ToArc, Sampler> &fst)
      : Base(fst, fst.GetMutableImpl()) {}
};
// Arc iterator for RandGenFst: a CacheArcIterator over the FST's mutable
// implementation. The constructor makes sure the arcs of the requested state
// have been computed.
template <class FromArc, class ToArc, class Sampler>
class ArcIterator<RandGenFst<FromArc, ToArc, Sampler>>
    : public CacheArcIterator<RandGenFst<FromArc, ToArc, Sampler>> {
  using Base = CacheArcIterator<RandGenFst<FromArc, ToArc, Sampler>>;

 public:
  using StateId = typename FromArc::StateId;

  ArcIterator(const RandGenFst<FromArc, ToArc, Sampler> &fst, StateId s)
      : Base(fst.GetMutableImpl(), s) {
    // Expand the state on demand if its arcs are not cached yet.
    const bool has_arcs = fst.GetImpl()->HasArcs(s);
    if (!has_arcs) {
      fst.GetMutableImpl()->Expand(s);
    }
  }
};
// Installs a StateIterator specialized for RandGenFst as the base iterator of
// the given iterator data.
template <class FromArc, class ToArc, class Sampler>
inline void RandGenFst<FromArc, ToArc, Sampler>::InitStateIterator(
    StateIteratorData<ToArc> *data) const {
  using Iterator = StateIterator<RandGenFst<FromArc, ToArc, Sampler>>;
  data->base = std::make_unique<Iterator>(*this);
}
// Configuration for the RandGen() functions. The selector is held by
// reference, so it must outlive the options object.
template <class Selector>
struct RandGenOptions {
  // Strategy used to pick an arc at a state.
  const Selector &selector;
  // Upper bound on the length of a generated path.
  int32_t max_length;
  // Number of paths to generate.
  int32_t npath;
  // Whether the output is a tree weighted by path count rather than an
  // unweighted DAG.
  bool weighted;
  // Whether the total weight is removed when the output is weighted.
  bool remove_total_weight;

  // By default: unbounded length (int32_t maximum), a single path,
  // unweighted output, total weight kept.
  explicit RandGenOptions(
      const Selector &selector,
      int32_t max_length = std::numeric_limits<int32_t>::max(),
      int32_t npath = 1, bool weighted = false,
      bool remove_total_weight = false)
      : selector(selector),
        max_length(max_length),
        npath(npath),
        weighted(weighted),
        remove_total_weight(remove_total_weight) {}
};
namespace internal {
// DFS visitor that rewrites the visited FST into the output FST as a set of
// separate paths sharing only the start state: whenever the traversal reaches
// a final state, the labels collected along the current DFS path are emitted
// as a fresh chain of states ending in a final state.
//
// NOTE(review): TreeArc() takes a ToArc while the other arc callbacks and
// InitVisit() are declared in terms of FromArc; this appears to rely on both
// arc types being the same in practice -- confirm against the callers.
template <class FromArc, class ToArc>
class RandGenVisitor {
 public:
  using StateId = typename FromArc::StateId;
  using Weight = typename FromArc::Weight;

  explicit RandGenVisitor(MutableFst<ToArc> *ofst) : ofst_(ofst) {}

  // Empties the output FST, copies the symbol tables and error property of
  // the visited FST, and clears the current path.
  void InitVisit(const Fst<FromArc> &ifst) {
    ifst_ = &ifst;
    ofst_->DeleteStates();
    ofst_->SetInputSymbols(ifst.InputSymbols());
    ofst_->SetOutputSymbols(ifst.OutputSymbols());
    if (ifst.Properties(kError, false)) {
      ofst_->SetProperties(kError, kError);
    }
    path_.clear();
  }

  // Every state is visited.
  constexpr bool InitState(StateId, StateId) const { return true; }

  // An arc into a final state completes a path, which is then written out;
  // any other arc is appended to the current path.
  bool TreeArc(StateId, const ToArc &arc) {
    const bool reaches_final = ifst_->Final(arc.nextstate) != Weight::Zero();
    if (reaches_final) {
      OutputPath();
    } else {
      path_.push_back(arc);
    }
    return true;
  }

  // A back arc means the visited FST is cyclic: flag the error and stop.
  bool BackArc(StateId, const FromArc &) {
    FSTERROR() << "RandGenVisitor: cyclic input";
    ofst_->SetProperties(kError, kError);
    return false;
  }

  // Reaching an already visited state also completes a path.
  bool ForwardOrCrossArc(StateId, const FromArc &) {
    OutputPath();
    return true;
  }

  // Leaving a non-final, non-root state removes its incoming arc from the
  // current path.
  void FinishState(StateId s, StateId p, const FromArc *) {
    const bool has_parent = p != kNoStateId;
    if (has_parent && ifst_->Final(s) == Weight::Zero()) {
      path_.pop_back();
    }
  }

  void FinishVisit() {}

 private:
  // Appends the current path to the output FST as a new chain of states that
  // starts at the (lazily created) start state and ends in a final state.
  void OutputPath() {
    if (ofst_->Start() == kNoStateId) {
      ofst_->SetStart(ofst_->AddState());
    }
    auto src = ofst_->Start();
    for (const auto &step : path_) {
      const auto dest = ofst_->AddState();
      const ToArc arc(step.ilabel, step.olabel, Weight::One(), dest);
      ofst_->AddArc(src, arc);
      src = dest;
    }
    ofst_->SetFinal(src);
  }

  const Fst<FromArc> *ifst_;  // Visited FST; set by InitVisit().
  MutableFst<ToArc> *ofst_;   // Output FST; not owned.
  std::vector<ToArc> path_;   // Arcs on the current DFS path.

  RandGenVisitor(const RandGenVisitor &) = delete;
  RandGenVisitor &operator=(const RandGenVisitor &) = delete;
};
} // namespace internal
// Randomly generate paths through an FST; details controlled by
// RandGenOptions.
template <class FromArc, class ToArc, class Selector>
void RandGen(const Fst<FromArc> &ifst, MutableFst<ToArc> *ofst,
const RandGenOptions<Selector> &opts) {
using Sampler = ArcSampler<FromArc, Selector>;
auto sampler =
std::make_unique<Sampler>(ifst, opts.selector, opts.max_length);
RandGenFstOptions<Sampler> fopts(CacheOptions(true, 0), sampler.release(),
opts.npath, opts.weighted,
opts.remove_total_weight);
RandGenFst<FromArc, ToArc, Sampler> rfst(ifst, fopts);
if (opts.weighted) {
*ofst = rfst;
} else {
internal::RandGenVisitor<FromArc, ToArc> rand_visitor(ofst);
DfsVisit(rfst, &rand_visitor);
}
}
// Randomly generate a path through an FST with the uniform distribution
// over the transitions.
template <class FromArc, class ToArc>
void RandGen(const Fst<FromArc> &ifst, MutableFst<ToArc> *ofst,
uint64_t seed = std::random_device()()) {
const UniformArcSelector<FromArc> uniform_selector(seed);
RandGenOptions<UniformArcSelector<ToArc>> opts(uniform_selector);
RandGen(ifst, ofst, opts);
}
} // namespace fst
#endif // FST_RANDGEN_H_