You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

84 lines
3.0 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. #ifndef FST_SCRIPT_RANDEQUIVALENT_H_
  18. #define FST_SCRIPT_RANDEQUIVALENT_H_
  19. #include <cstdint>
  20. #include <random>
  21. #include <tuple>
  22. #include <fst/fst.h>
  23. #include <fst/randequivalent.h>
  24. #include <fst/randgen.h>
  25. #include <fst/weight.h>
  26. #include <fst/script/arg-packs.h>
  27. #include <fst/script/fst-class.h>
  28. #include <fst/script/script-impl.h>
  29. namespace fst {
  30. namespace script {
  31. using FstRandEquivalentInnerArgs =
  32. std::tuple<const FstClass &, const FstClass &, int32_t,
  33. const RandGenOptions<RandArcSelection> &, float, uint64_t>;
  34. using FstRandEquivalentArgs = WithReturnValue<bool, FstRandEquivalentInnerArgs>;
  35. template <class Arc>
  36. void RandEquivalent(FstRandEquivalentArgs *args) {
  37. const Fst<Arc> &fst1 = *std::get<0>(args->args).GetFst<Arc>();
  38. const Fst<Arc> &fst2 = *std::get<1>(args->args).GetFst<Arc>();
  39. const int32_t npath = std::get<2>(args->args);
  40. const auto &opts = std::get<3>(args->args);
  41. const float delta = std::get<4>(args->args);
  42. const uint64_t seed = std::get<5>(args->args);
  43. switch (opts.selector) {
  44. case RandArcSelection::UNIFORM: {
  45. const UniformArcSelector<Arc> selector(seed);
  46. const RandGenOptions<UniformArcSelector<Arc>> ropts(selector,
  47. opts.max_length);
  48. args->retval = RandEquivalent(fst1, fst2, npath, ropts, delta, seed);
  49. return;
  50. }
  51. case RandArcSelection::FAST_LOG_PROB: {
  52. const FastLogProbArcSelector<Arc> selector(seed);
  53. const RandGenOptions<FastLogProbArcSelector<Arc>> ropts(selector,
  54. opts.max_length);
  55. args->retval = RandEquivalent(fst1, fst2, npath, ropts, delta, seed);
  56. return;
  57. }
  58. case RandArcSelection::LOG_PROB: {
  59. const LogProbArcSelector<Arc> selector(seed);
  60. const RandGenOptions<LogProbArcSelector<Arc>> ropts(selector,
  61. opts.max_length);
  62. args->retval = RandEquivalent(fst1, fst2, npath, ropts, delta, seed);
  63. return;
  64. }
  65. }
  66. }
  67. bool RandEquivalent(
  68. const FstClass &fst1, const FstClass &fst2, int32_t npath = 1,
  69. const RandGenOptions<RandArcSelection> &opts =
  70. RandGenOptions<RandArcSelection>(RandArcSelection::UNIFORM),
  71. float delta = kDelta, uint64_t seed = std::random_device()());
  72. } // namespace script
  73. } // namespace fst
  74. #endif // FST_SCRIPT_RANDEQUIVALENT_H_