You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

115 lines
3.5 KiB

// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#ifndef FST_TEST_RAND_FST_H_
#define FST_TEST_RAND_FST_H_
#include <cstddef>
#include <cstdint>
#include <random>
#include <fst/log.h>
#include <fst/mutable-fst.h>
#include <fst/properties.h>
#include <fst/verify.h>
namespace fst {
// Generates a random FST.
//
// Replaces the contents of `*fst` with a randomly generated machine. The
// structural choices (state count, arc count, labels, endpoints, final states,
// and which property bits are hidden) are drawn from a std::mt19937_64 engine
// seeded with `seed`; weights come from `generate`.
//
// Args:
//   num_random_states: the number of states is drawn uniformly from
//       [0, num_random_states - 1]. Must be >= 1.
//   num_random_arcs: the number of arc-generation attempts is drawn uniformly
//       from [0, num_random_arcs - 1]. Must be >= 1. Fewer arcs may be added
//       because self-loops are skipped when a direction is forced.
//   num_random_labels: input and output labels are drawn uniformly from
//       [0, num_random_labels - 1]. Must be >= 1.
//   acyclic_prob: probability with which all arcs are forced to point in a
//       single direction wrt state numbering, which makes the result acyclic.
//   generate: nullary callable returning the weight used for each arc and for
//       each final state.
//   seed: seed for the random engine.
//   fst: output FST; its existing states are deleted first.
//
// The size arguments are not validated: a value < 1 yields an invalid
// std::uniform_int_distribution range.
template <class Arc, class Generate>
void RandFst(const int num_random_states, const int num_random_arcs,
             const int num_random_labels, const float acyclic_prob,
             Generate generate, uint64_t seed, MutableFst<Arc> *fst) {
  using Label = typename Arc::Label;
  using StateId = typename Arc::StateId;

  // Determines direction of the arcs wrt state numbering. This way we
  // can force acyclicity when desired.
  enum ArcDirection {
    ANY_DIRECTION = 0,
    FORWARD_DIRECTION = 1,
    REVERSE_DIRECTION = 2,
    NUM_DIRECTIONS = 3
  };

  std::mt19937_64 rand(seed);
  const StateId ns =
      std::uniform_int_distribution<>(0, num_random_states - 1)(rand);
  std::uniform_int_distribution<size_t> arc_dist(0, num_random_arcs - 1);
  std::uniform_int_distribution<Label> label_dist(0, num_random_labels - 1);

  // With probability `acyclic_prob`, force every arc to go in one direction
  // (chosen by a fair coin flip); otherwise arcs may go either way.
  ArcDirection arc_direction = ANY_DIRECTION;
  if (std::bernoulli_distribution(acyclic_prob)(rand)) {
    arc_direction = std::bernoulli_distribution(.5)(rand) ? FORWARD_DIRECTION
                                                          : REVERSE_DIRECTION;
  }

  fst->DeleteStates();
  if (ns == 0) return;

  // Built only once ns >= 1 is known: a (0, -1) range would violate the
  // distribution's requirement that the lower bound not exceed the upper one.
  std::uniform_int_distribution<StateId> ns_dist(0, ns - 1);

  fst->AddStates(ns);
  const StateId start = ns_dist(rand);
  fst->SetStart(start);

  const size_t na = arc_dist(rand);
  for (size_t n = 0; n < na; ++n) {
    StateId s = ns_dist(rand);
    Arc arc;
    arc.ilabel = label_dist(rand);
    arc.olabel = label_dist(rand);
    arc.weight = generate();
    arc.nextstate = ns_dist(rand);
    if ((arc_direction == FORWARD_DIRECTION ||
         arc_direction == REVERSE_DIRECTION) &&
        s == arc.nextstate) {
      continue;  // Skips self-loops.
    }
    if ((arc_direction == FORWARD_DIRECTION && s > arc.nextstate) ||
        (arc_direction == REVERSE_DIRECTION && s < arc.nextstate)) {
      StateId t = s;  // reverses arcs
      s = arc.nextstate;
      arc.nextstate = t;
    }
    fst->AddArc(s, arc);
  }

  // Marks up to `ns` states final; the same state may be drawn more than once,
  // in which case the later weight overwrites the earlier one.
  const StateId nf = std::uniform_int_distribution<>(0, ns)(rand);
  for (StateId n = 0; n < nf; ++n) {
    const StateId s = ns_dist(rand);
    fst->SetFinal(s, generate());
  }

  VLOG(1) << "Check FST for sanity (including property bits).";
  CHECK(Verify(*fst));

  // Get/compute all properties.
  const uint64_t props = fst->Properties(kFstProperties, true);

  // Select random set of properties to be unknown. Eight random bytes are
  // assembled into a 64-bit value; shifting before OR-ing keeps every drawn
  // byte in the result.
  uint64_t mask = 0;
  for (int n = 0; n < 8; ++n) {
    mask <<= 8;
    mask |= std::uniform_int_distribution<>(0, 0xff)(rand);
  }
  // Restrict the mask to the trinary property bits (see fst/properties.h):
  // those are the properties that have an "unknown" state, reached by
  // clearing the selected bits below.
  mask &= kTrinaryProperties;
  fst->SetProperties(props & ~mask, mask);
}
} // namespace fst
#endif // FST_TEST_RAND_FST_H_