You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

780 lines
26 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Classes and functions to generate random paths through an FST.
  19. #ifndef FST_RANDGEN_H_
  20. #define FST_RANDGEN_H_
  21. #include <algorithm>
  22. #include <cmath>
  23. #include <cstddef>
  24. #include <cstdint>
  25. #include <cstring>
  26. #include <functional>
  27. #include <limits>
  28. #include <map>
  29. #include <memory>
  30. #include <numeric>
  31. #include <random>
  32. #include <utility>
  33. #include <vector>
  34. #include <fst/log.h>
  35. #include <fst/accumulator.h>
  36. #include <fst/arc.h>
  37. #include <fst/cache.h>
  38. #include <fst/dfs-visit.h>
  39. #include <fst/float-weight.h>
  40. #include <fst/fst-decl.h>
  41. #include <fst/fst.h>
  42. #include <fst/impl-to-fst.h>
  43. #include <fst/mutable-fst.h>
  44. #include <fst/properties.h>
  45. #include <fst/util.h>
  46. #include <fst/weight.h>
  47. #include <vector>
  48. namespace fst {
  49. // The RandGenFst class is roughly similar to ArcMapFst in that it takes two
  50. // template parameters denoting the input and output arc types. However, it also
  51. // takes an additional template parameter which specifies a sampler object which
  52. // samples (with replacement) arcs from an FST state. The sampler in turn takes
  53. // a template parameter for a selector object which actually chooses the arc.
  54. //
  55. // Arc selector functors are used to select a random transition given an FST
  56. // state s, returning a number N such that 0 <= N <= NumArcs(s). If N is
  57. // NumArcs(s), then the final weight is selected; otherwise the N-th arc is
  58. // selected. It is assumed these are not applied to any state which is neither
  59. // final nor has any arcs leaving it.
  60. // Randomly selects a transition using the uniform distribution. This class is
  61. // not thread-safe.
  62. template <class Arc>
  63. class UniformArcSelector {
  64. public:
  65. using StateId = typename Arc::StateId;
  66. using Weight = typename Arc::Weight;
  67. explicit UniformArcSelector(uint64_t seed = std::random_device()())
  68. : rand_(seed) {}
  69. size_t operator()(const Fst<Arc> &fst, StateId s) const {
  70. const auto n = fst.NumArcs(s) + (fst.Final(s) != Weight::Zero());
  71. return static_cast<size_t>(
  72. std::uniform_int_distribution<>(0, n - 1)(rand_));
  73. }
  74. private:
  75. mutable std::mt19937_64 rand_;
  76. };
  77. // Randomly selects a transition w.r.t. the weights treated as negative log
  78. // probabilities after normalizing for the total weight leaving the state. Zero
  79. // transitions are disregarded. It assumed that Arc::Weight::Value() accesses
  80. // the floating point representation of the weight. This class is not
  81. // thread-safe.
  82. template <class Arc>
  83. class LogProbArcSelector {
  84. public:
  85. using StateId = typename Arc::StateId;
  86. using Weight = typename Arc::Weight;
  87. // Constructs a selector with a non-deterministic seed.
  88. LogProbArcSelector() : seed_(std::random_device()()), rand_(seed_) {}
  89. // Constructs a selector with a given seed.
  90. explicit LogProbArcSelector(uint64_t seed) : seed_(seed), rand_(seed) {}
  91. size_t operator()(const Fst<Arc> &fst, StateId s) const {
  92. // Finds total weight leaving state.
  93. auto sum = Log64Weight::Zero();
  94. ArcIterator<Fst<Arc>> aiter(fst, s);
  95. for (; !aiter.Done(); aiter.Next()) {
  96. const auto &arc = aiter.Value();
  97. sum = Plus(sum, to_log_weight_(arc.weight));
  98. }
  99. sum = Plus(sum, to_log_weight_(fst.Final(s)));
  100. const double threshold =
  101. std::uniform_real_distribution<>(0, exp(-sum.Value()))(rand_);
  102. auto p = Log64Weight::Zero();
  103. size_t n = 0;
  104. for (aiter.Reset(); !aiter.Done(); aiter.Next(), ++n) {
  105. p = Plus(p, to_log_weight_(aiter.Value().weight));
  106. if (exp(-p.Value()) > threshold) return n;
  107. }
  108. return n;
  109. }
  110. uint64_t Seed() const { return seed_; }
  111. protected:
  112. Log64Weight ToLogWeight(const Weight &weight) const {
  113. return to_log_weight_(weight);
  114. }
  115. std::mt19937_64 &MutableRand() const { return rand_; }
  116. private:
  117. const uint64_t seed_;
  118. mutable std::mt19937_64 rand_;
  119. const WeightConvert<Weight, Log64Weight> to_log_weight_{};
  120. };
  121. // Same as LogProbArcSelector but use CacheLogAccumulator to cache the weight
  122. // accumulation computations. This class is not thread-safe.
  123. template <class Arc>
  124. class FastLogProbArcSelector : public LogProbArcSelector<Arc> {
  125. public:
  126. using StateId = typename Arc::StateId;
  127. using Weight = typename Arc::Weight;
  128. using LogProbArcSelector<Arc>::MutableRand;
  129. using LogProbArcSelector<Arc>::ToLogWeight;
  130. using LogProbArcSelector<Arc>::operator();
  131. // Constructs a selector with a non-deterministic seed.
  132. FastLogProbArcSelector() : LogProbArcSelector<Arc>() {}
  133. // Constructs a selector with a given seed.
  134. explicit FastLogProbArcSelector(uint64_t seed)
  135. : LogProbArcSelector<Arc>(seed) {}
  136. size_t operator()(const Fst<Arc> &fst, StateId s,
  137. CacheLogAccumulator<Arc> *accumulator) const {
  138. accumulator->SetState(s);
  139. ArcIterator<Fst<Arc>> aiter(fst, s);
  140. // Finds total weight leaving state.
  141. const double sum =
  142. ToLogWeight(accumulator->Sum(fst.Final(s), &aiter, 0, fst.NumArcs(s)))
  143. .Value();
  144. const double r =
  145. -log(std::uniform_real_distribution<>(0, 1)(MutableRand()));
  146. Weight w = from_log_weight_(r + sum);
  147. aiter.Reset();
  148. return accumulator->LowerBound(w, &aiter);
  149. }
  150. private:
  151. const WeightConvert<Log64Weight, Weight> from_log_weight_{};
  152. };
  153. // Random path state info maintained by RandGenFst and passed to samplers.
  154. template <typename Arc>
  155. struct RandState {
  156. using StateId = typename Arc::StateId;
  157. StateId state_id; // Current input FST state.
  158. size_t nsamples; // Number of samples to be sampled at this state.
  159. size_t length; // Length of path to this random state.
  160. size_t select; // Previous sample arc selection.
  161. const RandState<Arc> *parent; // Previous random state on this path.
  162. explicit RandState(StateId state_id, size_t nsamples = 0, size_t length = 0,
  163. size_t select = 0, const RandState<Arc> *parent = nullptr)
  164. : state_id(state_id),
  165. nsamples(nsamples),
  166. length(length),
  167. select(select),
  168. parent(parent) {}
  169. RandState() : RandState(kNoStateId) {}
  170. };
  171. // This class, given an arc selector, samples, with replacement, multiple random
  172. // transitions from an FST's state. This is a generic version with a
  173. // straightforward use of the arc selector. Specializations may be defined for
  174. // arc selectors for greater efficiency or special behavior.
  175. template <class Arc, class Selector>
  176. class ArcSampler {
  177. public:
  178. using StateId = typename Arc::StateId;
  179. using Weight = typename Arc::Weight;
  180. // The max_length argument may be interpreted (or ignored) by a selector as
  181. // it chooses. This generic version interprets this literally.
  182. ArcSampler(const Fst<Arc> &fst, const Selector &selector,
  183. int32_t max_length = std::numeric_limits<int32_t>::max())
  184. : fst_(fst), selector_(selector), max_length_(max_length) {}
  185. // Allow updating FST argument; pass only if changed.
  186. ArcSampler(const ArcSampler<Arc, Selector> &sampler,
  187. const Fst<Arc> *fst = nullptr)
  188. : fst_(fst ? *fst : sampler.fst_),
  189. selector_(sampler.selector_),
  190. max_length_(sampler.max_length_) {
  191. Reset();
  192. }
  193. // Samples a fixed number of samples from the given state. The length argument
  194. // specifies the length of the path to the state. Returns true if the samples
  195. // were collected. No samples may be collected if either there are no
  196. // transitions leaving the state and the state is non-final, or if the path
  197. // length has been exceeded. Iterator members are provided to read the samples
  198. // in the order in which they were collected.
  199. bool Sample(const RandState<Arc> &rstate) {
  200. sample_map_.clear();
  201. if ((fst_.NumArcs(rstate.state_id) == 0 &&
  202. fst_.Final(rstate.state_id) == Weight::Zero()) ||
  203. rstate.length == max_length_) {
  204. Reset();
  205. return false;
  206. }
  207. for (size_t i = 0; i < rstate.nsamples; ++i) {
  208. ++sample_map_[selector_(fst_, rstate.state_id)];
  209. }
  210. Reset();
  211. return true;
  212. }
  213. // More samples?
  214. bool Done() const { return sample_iter_ == sample_map_.end(); }
  215. // Gets the next sample.
  216. void Next() { ++sample_iter_; }
  217. std::pair<size_t, size_t> Value() const { return *sample_iter_; }
  218. void Reset() { sample_iter_ = sample_map_.begin(); }
  219. bool Error() const { return false; }
  220. private:
  221. const Fst<Arc> &fst_;
  222. const Selector &selector_;
  223. const int32_t max_length_;
  224. // Stores (N, K) as described for Value().
  225. std::map<size_t, size_t> sample_map_;
  226. std::map<size_t, size_t>::const_iterator sample_iter_;
  227. ArcSampler<Arc, Selector> &operator=(const ArcSampler &) = delete;
  228. };
  // Draws one sample of num_to_sample trials from a multinomial distribution
  // whose category weights are given by probs. The weights need not sum to
  // one: index i is drawn from the mass remaining from i onward. For each index
  // i that receives a non-zero count, (*result)[i] is set to that count; other
  // entries are left untouched. result must therefore be pre-initialized,
  // e.g. an empty map or a zeroed vector of size probs.size().
  //
  // Sampling is sequential with conditional binomial draws: index i receives
  // Binomial(remaining trials, probs[i] / (probs[i] + ... + probs.back())).
  template <class Result, class RNG>
  void OneMultinomialSample(const std::vector<double> &probs,
                            size_t num_to_sample, Result *result, RNG *rng) {
    using Binomial = std::binomial_distribution<size_t>;
    // tail_mass[i] = probs[i] + probs[i + 1] + ... + probs.back(). A suffix-sum
    // array is used instead of subtracting probs[i] from a running scalar,
    // whose round-off could leave probs[i] greater than the remaining mass.
    std::vector<double> tail_mass(probs.size());
    std::partial_sum(probs.rbegin(), probs.rend(), tail_mass.rbegin());
    // Trials not yet assigned to any category.
    size_t remaining = num_to_sample;
    for (size_t i = 0; i != probs.size(); ++i) {
      Binomial::result_type drawn = 0;
      if (probs[i] > 0) {
        Binomial binomial(remaining, probs[i] / tail_mass[i]);
        drawn = binomial(*rng);
      }
      if (drawn != 0) (*result)[i] = drawn;
      remaining -= std::min(drawn, remaining);
    }
  }
  254. // Specialization for FastLogProbArcSelector.
  255. template <class Arc>
  256. class ArcSampler<Arc, FastLogProbArcSelector<Arc>> {
  257. public:
  258. using StateId = typename Arc::StateId;
  259. using Weight = typename Arc::Weight;
  260. using Accumulator = CacheLogAccumulator<Arc>;
  261. using Selector = FastLogProbArcSelector<Arc>;
  262. ArcSampler(const Fst<Arc> &fst, const Selector &selector,
  263. int32_t max_length = std::numeric_limits<int32_t>::max())
  264. : fst_(fst),
  265. selector_(selector),
  266. max_length_(max_length),
  267. accumulator_(new Accumulator()) {
  268. accumulator_->Init(fst);
  269. rng_.seed(selector_.Seed());
  270. }
  271. ArcSampler(const ArcSampler<Arc, Selector> &sampler,
  272. const Fst<Arc> *fst = nullptr)
  273. : fst_(fst ? *fst : sampler.fst_),
  274. selector_(sampler.selector_),
  275. max_length_(sampler.max_length_) {
  276. if (fst) {
  277. accumulator_ = std::make_unique<Accumulator>();
  278. accumulator_->Init(*fst);
  279. } else { // Shallow copy.
  280. accumulator_ = std::make_unique<Accumulator>(*sampler.accumulator_);
  281. }
  282. }
  283. bool Sample(const RandState<Arc> &rstate) {
  284. sample_map_.clear();
  285. if ((fst_.NumArcs(rstate.state_id) == 0 &&
  286. fst_.Final(rstate.state_id) == Weight::Zero()) ||
  287. rstate.length == max_length_) {
  288. Reset();
  289. return false;
  290. }
  291. if (fst_.NumArcs(rstate.state_id) + 1 < rstate.nsamples) {
  292. MultinomialSample(rstate);
  293. Reset();
  294. return true;
  295. }
  296. for (size_t i = 0; i < rstate.nsamples; ++i) {
  297. ++sample_map_[selector_(fst_, rstate.state_id, accumulator_.get())];
  298. }
  299. Reset();
  300. return true;
  301. }
  302. bool Done() const { return sample_iter_ == sample_map_.end(); }
  303. void Next() { ++sample_iter_; }
  304. std::pair<size_t, size_t> Value() const { return *sample_iter_; }
  305. void Reset() { sample_iter_ = sample_map_.begin(); }
  306. bool Error() const { return accumulator_->Error(); }
  307. private:
  308. using RNG = std::mt19937;
  309. // Sample according to the multinomial distribution of rstate.nsamples draws
  310. // from p_.
  311. void MultinomialSample(const RandState<Arc> &rstate) {
  312. p_.clear();
  313. for (ArcIterator<Fst<Arc>> aiter(fst_, rstate.state_id); !aiter.Done();
  314. aiter.Next()) {
  315. p_.push_back(exp(-to_log_weight_(aiter.Value().weight).Value()));
  316. }
  317. if (fst_.Final(rstate.state_id) != Weight::Zero()) {
  318. p_.push_back(exp(-to_log_weight_(fst_.Final(rstate.state_id)).Value()));
  319. }
  320. if (rstate.nsamples < std::numeric_limits<RNG::result_type>::max()) {
  321. OneMultinomialSample(p_, rstate.nsamples, &sample_map_, &rng_);
  322. } else {
  323. for (size_t i = 0; i < p_.size(); ++i) {
  324. sample_map_[i] = ceil(p_[i] * rstate.nsamples);
  325. }
  326. }
  327. }
  328. const Fst<Arc> &fst_;
  329. const Selector &selector_;
  330. const int32_t max_length_;
  331. // Stores (N, K) for Value().
  332. std::map<size_t, size_t> sample_map_;
  333. std::map<size_t, size_t>::const_iterator sample_iter_;
  334. std::unique_ptr<Accumulator> accumulator_;
  335. RNG rng_; // Random number generator.
  336. std::vector<double> p_; // Multinomial parameters.
  337. const WeightConvert<Weight, Log64Weight> to_log_weight_{};
  338. };
  339. // Options for random path generation with RandGenFst. The template argument is
  340. // a sampler, typically the class ArcSampler. Ownership of the sampler is taken
  341. // by RandGenFst.
  342. template <class Sampler>
  343. struct RandGenFstOptions : public CacheOptions {
  344. Sampler *sampler; // How to sample transitions at a state.
  345. int32_t npath; // Number of paths to generate.
  346. bool weighted; // Is the output tree weighted by path count, or
  347. // is it just an unweighted DAG?
  348. bool remove_total_weight; // Remove total weight when output is weighted.
  349. RandGenFstOptions(const CacheOptions &opts, Sampler *sampler,
  350. int32_t npath = 1, bool weighted = true,
  351. bool remove_total_weight = false)
  352. : CacheOptions(opts),
  353. sampler(sampler),
  354. npath(npath),
  355. weighted(weighted),
  356. remove_total_weight(remove_total_weight) {}
  357. };
  358. namespace internal {
  359. // Implementation of RandGenFst.
  360. template <class FromArc, class ToArc, class Sampler>
  361. class RandGenFstImpl : public CacheImpl<ToArc> {
  362. public:
  363. using FstImpl<ToArc>::SetType;
  364. using FstImpl<ToArc>::SetProperties;
  365. using FstImpl<ToArc>::SetInputSymbols;
  366. using FstImpl<ToArc>::SetOutputSymbols;
  367. using CacheBaseImpl<CacheState<ToArc>>::EmplaceArc;
  368. using CacheBaseImpl<CacheState<ToArc>>::HasArcs;
  369. using CacheBaseImpl<CacheState<ToArc>>::HasFinal;
  370. using CacheBaseImpl<CacheState<ToArc>>::HasStart;
  371. using CacheBaseImpl<CacheState<ToArc>>::SetArcs;
  372. using CacheBaseImpl<CacheState<ToArc>>::SetFinal;
  373. using CacheBaseImpl<CacheState<ToArc>>::SetStart;
  374. using Label = typename FromArc::Label;
  375. using StateId = typename FromArc::StateId;
  376. using FromWeight = typename FromArc::Weight;
  377. using ToWeight = typename ToArc::Weight;
  378. RandGenFstImpl(const Fst<FromArc> &fst,
  379. const RandGenFstOptions<Sampler> &opts)
  380. : CacheImpl<ToArc>(opts),
  381. fst_(fst.Copy()),
  382. sampler_(opts.sampler),
  383. npath_(opts.npath),
  384. weighted_(opts.weighted),
  385. remove_total_weight_(opts.remove_total_weight),
  386. superfinal_(kNoLabel) {
  387. SetType("randgen");
  388. SetProperties(
  389. RandGenProperties(fst.Properties(kFstProperties, false), weighted_),
  390. kCopyProperties);
  391. SetInputSymbols(fst.InputSymbols());
  392. SetOutputSymbols(fst.OutputSymbols());
  393. }
  394. RandGenFstImpl(const RandGenFstImpl &impl)
  395. : CacheImpl<ToArc>(impl),
  396. fst_(impl.fst_->Copy(true)),
  397. sampler_(new Sampler(*impl.sampler_, fst_.get())),
  398. npath_(impl.npath_),
  399. weighted_(impl.weighted_),
  400. superfinal_(kNoLabel) {
  401. SetType("randgen");
  402. SetProperties(impl.Properties(), kCopyProperties);
  403. SetInputSymbols(impl.InputSymbols());
  404. SetOutputSymbols(impl.OutputSymbols());
  405. }
  406. StateId Start() {
  407. if (!HasStart()) {
  408. const auto s = fst_->Start();
  409. if (s == kNoStateId) return kNoStateId;
  410. SetStart(state_table_.size());
  411. state_table_.emplace_back(
  412. new RandState<FromArc>(s, npath_, 0, 0, nullptr));
  413. }
  414. return CacheImpl<ToArc>::Start();
  415. }
  416. ToWeight Final(StateId s) {
  417. if (!HasFinal(s)) Expand(s);
  418. return CacheImpl<ToArc>::Final(s);
  419. }
  420. size_t NumArcs(StateId s) {
  421. if (!HasArcs(s)) Expand(s);
  422. return CacheImpl<ToArc>::NumArcs(s);
  423. }
  424. size_t NumInputEpsilons(StateId s) {
  425. if (!HasArcs(s)) Expand(s);
  426. return CacheImpl<ToArc>::NumInputEpsilons(s);
  427. }
  428. size_t NumOutputEpsilons(StateId s) {
  429. if (!HasArcs(s)) Expand(s);
  430. return CacheImpl<ToArc>::NumOutputEpsilons(s);
  431. }
  432. uint64_t Properties() const override { return Properties(kFstProperties); }
  433. // Sets error if found, and returns other FST impl properties.
  434. uint64_t Properties(uint64_t mask) const override {
  435. if ((mask & kError) &&
  436. (fst_->Properties(kError, false) || sampler_->Error())) {
  437. SetProperties(kError, kError);
  438. }
  439. return FstImpl<ToArc>::Properties(mask);
  440. }
  441. void InitArcIterator(StateId s, ArcIteratorData<ToArc> *data) {
  442. if (!HasArcs(s)) Expand(s);
  443. CacheImpl<ToArc>::InitArcIterator(s, data);
  444. }
  445. // Computes the outgoing transitions from a state, creating new destination
  446. // states as needed.
  447. void Expand(StateId s) {
  448. if (s == superfinal_) {
  449. SetFinal(s);
  450. SetArcs(s);
  451. return;
  452. }
  453. SetFinal(s, ToWeight::Zero());
  454. const auto &rstate = *state_table_[s];
  455. sampler_->Sample(rstate);
  456. ArcIterator<Fst<FromArc>> aiter(*fst_, rstate.state_id);
  457. const auto narcs = fst_->NumArcs(rstate.state_id);
  458. for (; !sampler_->Done(); sampler_->Next()) {
  459. const auto &sample_pair = sampler_->Value();
  460. const auto pos = sample_pair.first;
  461. const auto count = sample_pair.second;
  462. double prob = static_cast<double>(count) / rstate.nsamples;
  463. if (pos < narcs) { // Regular transition.
  464. aiter.Seek(sample_pair.first);
  465. const auto &aarc = aiter.Value();
  466. auto weight =
  467. weighted_ ? to_weight_(Log64Weight(-log(prob))) : ToWeight::One();
  468. EmplaceArc(s, aarc.ilabel, aarc.olabel, std::move(weight),
  469. state_table_.size());
  470. auto nrstate = std::make_unique<RandState<FromArc>>(
  471. aarc.nextstate, count, rstate.length + 1, pos, &rstate);
  472. state_table_.push_back(std::move(nrstate));
  473. } else { // Super-final transition.
  474. if (weighted_) {
  475. const auto weight =
  476. remove_total_weight_
  477. ? to_weight_(Log64Weight(-log(prob)))
  478. : to_weight_(Log64Weight(-log(prob * npath_)));
  479. SetFinal(s, weight);
  480. } else {
  481. if (superfinal_ == kNoLabel) {
  482. superfinal_ = state_table_.size();
  483. state_table_.emplace_back(
  484. new RandState<FromArc>(kNoStateId, 0, 0, 0, nullptr));
  485. }
  486. for (size_t n = 0; n < count; ++n) EmplaceArc(s, 0, 0, superfinal_);
  487. }
  488. }
  489. }
  490. SetArcs(s);
  491. }
  492. private:
  493. const std::unique_ptr<Fst<FromArc>> fst_;
  494. std::unique_ptr<Sampler> sampler_;
  495. const int32_t npath_;
  496. std::vector<std::unique_ptr<RandState<FromArc>>> state_table_;
  497. const bool weighted_;
  498. bool remove_total_weight_;
  499. StateId superfinal_;
  500. const WeightConvert<Log64Weight, ToWeight> to_weight_{};
  501. };
  502. } // namespace internal
  503. // FST class to randomly generate paths through an FST, with details controlled
  504. // by RandGenOptionsFst. Output format is a tree weighted by the path count.
  505. template <class FromArc, class ToArc, class Sampler>
  506. class RandGenFst
  507. : public ImplToFst<internal::RandGenFstImpl<FromArc, ToArc, Sampler>> {
  508. public:
  509. using Label = typename FromArc::Label;
  510. using StateId = typename FromArc::StateId;
  511. using Weight = typename FromArc::Weight;
  512. using Store = DefaultCacheStore<FromArc>;
  513. using State = typename Store::State;
  514. using Impl = internal::RandGenFstImpl<FromArc, ToArc, Sampler>;
  515. friend class ArcIterator<RandGenFst<FromArc, ToArc, Sampler>>;
  516. friend class StateIterator<RandGenFst<FromArc, ToArc, Sampler>>;
  517. RandGenFst(const Fst<FromArc> &fst, const RandGenFstOptions<Sampler> &opts)
  518. : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
  519. // See Fst<>::Copy() for doc.
  520. RandGenFst(const RandGenFst &fst, bool safe = false)
  521. : ImplToFst<Impl>(fst, safe) {}
  522. // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc.
  523. RandGenFst *Copy(bool safe = false) const override {
  524. return new RandGenFst(*this, safe);
  525. }
  526. inline void InitStateIterator(StateIteratorData<ToArc> *data) const override;
  527. void InitArcIterator(StateId s, ArcIteratorData<ToArc> *data) const override {
  528. GetMutableImpl()->InitArcIterator(s, data);
  529. }
  530. private:
  531. using ImplToFst<Impl>::GetImpl;
  532. using ImplToFst<Impl>::GetMutableImpl;
  533. RandGenFst &operator=(const RandGenFst &) = delete;
  534. };
  535. // Specialization for RandGenFst.
  536. template <class FromArc, class ToArc, class Sampler>
  537. class StateIterator<RandGenFst<FromArc, ToArc, Sampler>>
  538. : public CacheStateIterator<RandGenFst<FromArc, ToArc, Sampler>> {
  539. public:
  540. explicit StateIterator(const RandGenFst<FromArc, ToArc, Sampler> &fst)
  541. : CacheStateIterator<RandGenFst<FromArc, ToArc, Sampler>>(
  542. fst, fst.GetMutableImpl()) {}
  543. };
  544. // Specialization for RandGenFst.
  545. template <class FromArc, class ToArc, class Sampler>
  546. class ArcIterator<RandGenFst<FromArc, ToArc, Sampler>>
  547. : public CacheArcIterator<RandGenFst<FromArc, ToArc, Sampler>> {
  548. public:
  549. using StateId = typename FromArc::StateId;
  550. ArcIterator(const RandGenFst<FromArc, ToArc, Sampler> &fst, StateId s)
  551. : CacheArcIterator<RandGenFst<FromArc, ToArc, Sampler>>(
  552. fst.GetMutableImpl(), s) {
  553. if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
  554. }
  555. };
  556. template <class FromArc, class ToArc, class Sampler>
  557. inline void RandGenFst<FromArc, ToArc, Sampler>::InitStateIterator(
  558. StateIteratorData<ToArc> *data) const {
  559. data->base =
  560. std::make_unique<StateIterator<RandGenFst<FromArc, ToArc, Sampler>>>(
  561. *this);
  562. }
  // Options for random path generation with RandGen().
  template <class Selector>
  struct RandGenOptions {
    // How an arc is selected at a state. Held by reference, so the selector
    // must outlive these options.
    const Selector &selector;
    // Maximum path length.
    int32_t max_length;
    // Number of paths to generate.
    int32_t npath;
    // True: the output tree is weighted by path count. False: it is just an
    // unweighted DAG.
    bool weighted;
    // Remove the total weight when the output is weighted?
    bool remove_total_weight;

    explicit RandGenOptions(
        const Selector &selector,
        int32_t max_length = std::numeric_limits<int32_t>::max(),
        int32_t npath = 1, bool weighted = false,
        bool remove_total_weight = false)
        : selector(selector),
          max_length(max_length),
          npath(npath),
          weighted(weighted),
          remove_total_weight(remove_total_weight) {}
  };
  583. namespace internal {
  584. template <class FromArc, class ToArc>
  585. class RandGenVisitor {
  586. public:
  587. using StateId = typename FromArc::StateId;
  588. using Weight = typename FromArc::Weight;
  589. explicit RandGenVisitor(MutableFst<ToArc> *ofst) : ofst_(ofst) {}
  590. void InitVisit(const Fst<FromArc> &ifst) {
  591. ifst_ = &ifst;
  592. ofst_->DeleteStates();
  593. ofst_->SetInputSymbols(ifst.InputSymbols());
  594. ofst_->SetOutputSymbols(ifst.OutputSymbols());
  595. if (ifst.Properties(kError, false)) ofst_->SetProperties(kError, kError);
  596. path_.clear();
  597. }
  598. constexpr bool InitState(StateId, StateId) const { return true; }
  599. bool TreeArc(StateId, const ToArc &arc) {
  600. if (ifst_->Final(arc.nextstate) == Weight::Zero()) {
  601. path_.push_back(arc);
  602. } else {
  603. OutputPath();
  604. }
  605. return true;
  606. }
  607. bool BackArc(StateId, const FromArc &) {
  608. FSTERROR() << "RandGenVisitor: cyclic input";
  609. ofst_->SetProperties(kError, kError);
  610. return false;
  611. }
  612. bool ForwardOrCrossArc(StateId, const FromArc &) {
  613. OutputPath();
  614. return true;
  615. }
  616. void FinishState(StateId s, StateId p, const FromArc *) {
  617. if (p != kNoStateId && ifst_->Final(s) == Weight::Zero()) path_.pop_back();
  618. }
  619. void FinishVisit() {}
  620. private:
  621. void OutputPath() {
  622. if (ofst_->Start() == kNoStateId) {
  623. const auto start = ofst_->AddState();
  624. ofst_->SetStart(start);
  625. }
  626. auto src = ofst_->Start();
  627. for (size_t i = 0; i < path_.size(); ++i) {
  628. const auto dest = ofst_->AddState();
  629. const ToArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest);
  630. ofst_->AddArc(src, arc);
  631. src = dest;
  632. }
  633. ofst_->SetFinal(src);
  634. }
  635. const Fst<FromArc> *ifst_;
  636. MutableFst<ToArc> *ofst_;
  637. std::vector<ToArc> path_;
  638. RandGenVisitor(const RandGenVisitor &) = delete;
  639. RandGenVisitor &operator=(const RandGenVisitor &) = delete;
  640. };
  641. } // namespace internal
  642. // Randomly generate paths through an FST; details controlled by
  643. // RandGenOptions.
  644. template <class FromArc, class ToArc, class Selector>
  645. void RandGen(const Fst<FromArc> &ifst, MutableFst<ToArc> *ofst,
  646. const RandGenOptions<Selector> &opts) {
  647. using Sampler = ArcSampler<FromArc, Selector>;
  648. auto sampler =
  649. std::make_unique<Sampler>(ifst, opts.selector, opts.max_length);
  650. RandGenFstOptions<Sampler> fopts(CacheOptions(true, 0), sampler.release(),
  651. opts.npath, opts.weighted,
  652. opts.remove_total_weight);
  653. RandGenFst<FromArc, ToArc, Sampler> rfst(ifst, fopts);
  654. if (opts.weighted) {
  655. *ofst = rfst;
  656. } else {
  657. internal::RandGenVisitor<FromArc, ToArc> rand_visitor(ofst);
  658. DfsVisit(rfst, &rand_visitor);
  659. }
  660. }
  661. // Randomly generate a path through an FST with the uniform distribution
  662. // over the transitions.
  663. template <class FromArc, class ToArc>
  664. void RandGen(const Fst<FromArc> &ifst, MutableFst<ToArc> *ofst,
  665. uint64_t seed = std::random_device()()) {
  666. const UniformArcSelector<FromArc> uniform_selector(seed);
  667. RandGenOptions<UniformArcSelector<ToArc>> opts(uniform_selector);
  668. RandGen(ifst, ofst, opts);
  669. }
  670. } // namespace fst
  671. #endif // FST_RANDGEN_H_