You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

115 lines
3.5 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. #ifndef FST_TEST_RAND_FST_H_
  16. #define FST_TEST_RAND_FST_H_
  17. #include <cstddef>
  18. #include <cstdint>
  19. #include <random>
  20. #include <fst/log.h>
  21. #include <fst/mutable-fst.h>
  22. #include <fst/properties.h>
  23. #include <fst/verify.h>
  24. namespace fst {
  25. // Generates a random FST.
  26. template <class Arc, class Generate>
  27. void RandFst(const int num_random_states, const int num_random_arcs,
  28. const int num_random_labels, const float acyclic_prob,
  29. Generate generate, uint64_t seed, MutableFst<Arc> *fst) {
  30. using Label = typename Arc::Label;
  31. using StateId = typename Arc::StateId;
  32. using Weight = typename Arc::Weight;
  33. // Determines direction of the arcs wrt state numbering. This way we
  34. // can force acyclicity when desired.
  35. enum ArcDirection {
  36. ANY_DIRECTION = 0,
  37. FORWARD_DIRECTION = 1,
  38. REVERSE_DIRECTION = 2,
  39. NUM_DIRECTIONS = 3
  40. };
  41. std::mt19937_64 rand(seed);
  42. const StateId ns =
  43. std::uniform_int_distribution<>(0, num_random_states - 1)(rand);
  44. std::uniform_int_distribution<size_t> arc_dist(0, num_random_arcs - 1);
  45. std::uniform_int_distribution<Label> label_dist(0, num_random_labels - 1);
  46. std::uniform_int_distribution<StateId> ns_dist(0, ns - 1);
  47. ArcDirection arc_direction = ANY_DIRECTION;
  48. if (!std::bernoulli_distribution(acyclic_prob)(rand)) {
  49. arc_direction = std::bernoulli_distribution(.5)(rand) ? FORWARD_DIRECTION
  50. : REVERSE_DIRECTION;
  51. }
  52. fst->DeleteStates();
  53. if (ns == 0) return;
  54. fst->AddStates(ns);
  55. const StateId start = ns_dist(rand);
  56. fst->SetStart(start);
  57. const size_t na = arc_dist(rand);
  58. for (size_t n = 0; n < na; ++n) {
  59. StateId s = ns_dist(rand);
  60. Arc arc;
  61. arc.ilabel = label_dist(rand);
  62. arc.olabel = label_dist(rand);
  63. arc.weight = generate();
  64. arc.nextstate = ns_dist(rand);
  65. if ((arc_direction == FORWARD_DIRECTION ||
  66. arc_direction == REVERSE_DIRECTION) &&
  67. s == arc.nextstate) {
  68. continue; // Skips self-loops.
  69. }
  70. if ((arc_direction == FORWARD_DIRECTION && s > arc.nextstate) ||
  71. (arc_direction == REVERSE_DIRECTION && s < arc.nextstate)) {
  72. StateId t = s; // reverses arcs
  73. s = arc.nextstate;
  74. arc.nextstate = t;
  75. }
  76. fst->AddArc(s, arc);
  77. }
  78. const StateId nf = std::uniform_int_distribution<>(0, ns)(rand);
  79. for (StateId n = 0; n < nf; ++n) {
  80. const StateId s = ns_dist(rand);
  81. fst->SetFinal(s, generate());
  82. }
  83. VLOG(1) << "Check FST for sanity (including property bits).";
  84. CHECK(Verify(*fst));
  85. // Get/compute all properties.
  86. const uint64_t props = fst->Properties(kFstProperties, true);
  87. // Select random set of properties to be unknown.
  88. uint64_t mask = 0;
  89. for (int n = 0; n < 8; ++n) {
  90. mask |= std::uniform_int_distribution<>(0, 0xff)(rand);
  91. mask <<= 8;
  92. }
  93. mask &= ~kTrinaryProperties;
  94. fst->SetProperties(props & ~mask, mask);
  95. }
  96. } // namespace fst
  97. #endif // FST_TEST_RAND_FST_H_