You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

81 lines
2.7 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. #ifndef FST_SCRIPT_PRUNE_H_
  18. #define FST_SCRIPT_PRUNE_H_
  19. #include <cstdint>
  20. #include <tuple>
  21. #include <utility>
  22. #include <fst/log.h>
  23. #include <fst/fst.h>
  24. #include <fst/mutable-fst.h>
  25. #include <fst/properties.h>
  26. #include <fst/prune.h>
  27. #include <fst/util.h>
  28. #include <fst/weight.h>
  29. #include <fst/script/fst-class.h>
  30. #include <fst/script/weight-class.h>
  31. namespace fst {
  32. namespace script {
  33. using FstPruneArgs1 = std::tuple<const FstClass &, MutableFstClass *,
  34. const WeightClass &, int64_t, float>;
  35. template <class Arc>
  36. void Prune(FstPruneArgs1 *args) {
  37. using Weight = typename Arc::Weight;
  38. const Fst<Arc> &ifst = *std::get<0>(*args).GetFst<Arc>();
  39. MutableFst<Arc> *ofst = std::get<1>(*args)->GetMutableFst<Arc>();
  40. if constexpr (IsPath<Weight>::value) {
  41. const auto weight_threshold = *std::get<2>(*args).GetWeight<Weight>();
  42. Prune(ifst, ofst, weight_threshold, std::get<3>(*args), std::get<4>(*args));
  43. } else {
  44. FSTERROR() << "Prune: Weight must have path property: " << Weight::Type();
  45. ofst->SetProperties(kError, kError);
  46. }
  47. }
  48. using FstPruneArgs2 =
  49. std::tuple<MutableFstClass *, const WeightClass &, int64_t, float>;
  50. template <class Arc>
  51. void Prune(FstPruneArgs2 *args) {
  52. using Weight = typename Arc::Weight;
  53. MutableFst<Arc> *fst = std::get<0>(*args)->GetMutableFst<Arc>();
  54. if constexpr (IsPath<Weight>::value) {
  55. const auto weight_threshold = *std::get<1>(*args).GetWeight<Weight>();
  56. Prune(fst, weight_threshold, std::get<2>(*args), std::get<3>(*args));
  57. } else {
  58. FSTERROR() << "Prune: Weight must have path property: " << Weight::Type();
  59. fst->SetProperties(kError, kError);
  60. }
  61. }
  62. void Prune(const FstClass &ifst, MutableFstClass *ofst,
  63. const WeightClass &weight_threshold,
  64. int64_t state_threshold = kNoStateId, float delta = kDelta);
  65. void Prune(MutableFstClass *fst, const WeightClass &weight_threshold,
  66. int64_t state_threshold = kNoStateId, float delta = kDelta);
  67. } // namespace script
  68. } // namespace fst
  69. #endif // FST_SCRIPT_PRUNE_H_