You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

253 lines
8.5 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. #ifndef FST_SCRIPT_SHORTEST_DISTANCE_H_
  18. #define FST_SCRIPT_SHORTEST_DISTANCE_H_
  19. #include <cstdint>
  20. #include <memory>
  21. #include <tuple>
  22. #include <type_traits>
  23. #include <vector>
  24. #include <fst/log.h>
  25. #include <fst/arcfilter.h>
  26. #include <fst/fst.h>
  27. #include <fst/queue.h>
  28. #include <fst/shortest-distance.h>
  29. #include <fst/util.h>
  30. #include <fst/weight.h>
  31. #include <fst/script/arcfilter-impl.h>
  32. #include <fst/script/arg-packs.h>
  33. #include <fst/script/fst-class.h>
  34. #include <fst/script/prune.h>
  35. #include <fst/script/script-impl.h>
  36. #include <fst/script/weight-class.h>
  37. namespace fst {
  38. namespace script {
  39. struct ShortestDistanceOptions {
  40. const QueueType queue_type;
  41. const ArcFilterType arc_filter_type;
  42. const int64_t source;
  43. const float delta;
  44. ShortestDistanceOptions(QueueType queue_type, ArcFilterType arc_filter_type,
  45. int64_t source, float delta)
  46. : queue_type(queue_type),
  47. arc_filter_type(arc_filter_type),
  48. source(source),
  49. delta(delta) {}
  50. };
  51. namespace internal {
  52. // Code to implement switching on queue and arc filter types.
  53. template <class Arc, class Queue, class ArcFilter>
  54. struct QueueConstructor {
  55. using Weight = typename Arc::Weight;
  56. static std::unique_ptr<Queue> Construct(const Fst<Arc> &,
  57. const std::vector<Weight> *) {
  58. return std::make_unique<Queue>();
  59. }
  60. };
  61. // Specializations to support queues with different constructors.
  62. template <class Arc, class ArcFilter>
  63. struct QueueConstructor<Arc, AutoQueue<typename Arc::StateId>, ArcFilter> {
  64. using StateId = typename Arc::StateId;
  65. using Weight = typename Arc::Weight;
  66. // template<class Arc, class ArcFilter>
  67. static std::unique_ptr<AutoQueue<StateId>> Construct(
  68. const Fst<Arc> &fst, const std::vector<Weight> *distance) {
  69. return std::make_unique<AutoQueue<StateId>>(fst, distance, ArcFilter());
  70. }
  71. };
  72. template <class Arc, class ArcFilter>
  73. struct QueueConstructor<
  74. Arc, NaturalShortestFirstQueue<typename Arc::StateId, typename Arc::Weight>,
  75. ArcFilter> {
  76. using StateId = typename Arc::StateId;
  77. using Weight = typename Arc::Weight;
  78. static std::unique_ptr<NaturalShortestFirstQueue<StateId, Weight>> Construct(
  79. const Fst<Arc> &, const std::vector<Weight> *distance) {
  80. return std::make_unique<NaturalShortestFirstQueue<StateId, Weight>>(
  81. *distance);
  82. }
  83. };
  84. template <class Arc, class ArcFilter>
  85. struct QueueConstructor<Arc, TopOrderQueue<typename Arc::StateId>, ArcFilter> {
  86. using StateId = typename Arc::StateId;
  87. using Weight = typename Arc::Weight;
  88. static std::unique_ptr<TopOrderQueue<StateId>> Construct(
  89. const Fst<Arc> &fst, const std::vector<Weight> *) {
  90. return std::make_unique<TopOrderQueue<StateId>>(fst, ArcFilter());
  91. }
  92. };
  93. template <class Arc, class Queue, class ArcFilter>
  94. void ShortestDistance(const Fst<Arc> &fst,
  95. std::vector<typename Arc::Weight> *distance,
  96. const ShortestDistanceOptions &opts) {
  97. std::unique_ptr<Queue> queue(
  98. QueueConstructor<Arc, Queue, ArcFilter>::Construct(fst, distance));
  99. const fst::ShortestDistanceOptions<Arc, Queue, ArcFilter> sopts(
  100. queue.get(), ArcFilter(), opts.source, opts.delta);
  101. ShortestDistance(fst, distance, sopts);
  102. }
  103. template <class Arc, class Queue>
  104. void ShortestDistance(const Fst<Arc> &fst,
  105. std::vector<typename Arc::Weight> *distance,
  106. const ShortestDistanceOptions &opts) {
  107. switch (opts.arc_filter_type) {
  108. case ArcFilterType::ANY: {
  109. ShortestDistance<Arc, Queue, AnyArcFilter<Arc>>(fst, distance, opts);
  110. return;
  111. }
  112. case ArcFilterType::EPSILON: {
  113. ShortestDistance<Arc, Queue, EpsilonArcFilter<Arc>>(fst, distance, opts);
  114. return;
  115. }
  116. case ArcFilterType::INPUT_EPSILON: {
  117. ShortestDistance<Arc, Queue, InputEpsilonArcFilter<Arc>>(fst, distance,
  118. opts);
  119. return;
  120. }
  121. case ArcFilterType::OUTPUT_EPSILON: {
  122. ShortestDistance<Arc, Queue, OutputEpsilonArcFilter<Arc>>(fst, distance,
  123. opts);
  124. return;
  125. }
  126. default: {
  127. FSTERROR() << "ShortestDistance: Unknown arc filter type: "
  128. << static_cast<std::underlying_type_t<ArcFilterType>>(
  129. opts.arc_filter_type);
  130. distance->clear();
  131. distance->resize(1, Arc::Weight::NoWeight());
  132. return;
  133. }
  134. }
  135. }
  136. } // namespace internal
  137. using FstShortestDistanceArgs1 =
  138. std::tuple<const FstClass &, std::vector<WeightClass> *,
  139. const ShortestDistanceOptions &>;
  140. template <class Arc>
  141. void ShortestDistance(FstShortestDistanceArgs1 *args) {
  142. using StateId = typename Arc::StateId;
  143. using Weight = typename Arc::Weight;
  144. const Fst<Arc> &fst = *std::get<0>(*args).GetFst<Arc>();
  145. const auto &opts = std::get<2>(*args);
  146. std::vector<Weight> typed_distance;
  147. switch (opts.queue_type) {
  148. case AUTO_QUEUE: {
  149. internal::ShortestDistance<Arc, AutoQueue<StateId>>(fst, &typed_distance,
  150. opts);
  151. break;
  152. }
  153. case FIFO_QUEUE: {
  154. internal::ShortestDistance<Arc, FifoQueue<StateId>>(fst, &typed_distance,
  155. opts);
  156. break;
  157. }
  158. case LIFO_QUEUE: {
  159. internal::ShortestDistance<Arc, LifoQueue<StateId>>(fst, &typed_distance,
  160. opts);
  161. break;
  162. }
  163. case SHORTEST_FIRST_QUEUE: {
  164. if constexpr (IsIdempotent<Weight>::value) {
  165. internal::ShortestDistance<Arc,
  166. NaturalShortestFirstQueue<StateId, Weight>>(
  167. fst, &typed_distance, opts);
  168. } else {
  169. FSTERROR() << "ShortestDistance: Bad queue type SHORTEST_FIRST_QUEUE"
  170. << " for non-idempotent Weight " << Weight::Type();
  171. }
  172. break;
  173. }
  174. case STATE_ORDER_QUEUE: {
  175. internal::ShortestDistance<Arc, StateOrderQueue<StateId>>(
  176. fst, &typed_distance, opts);
  177. break;
  178. }
  179. case TOP_ORDER_QUEUE: {
  180. internal::ShortestDistance<Arc, TopOrderQueue<StateId>>(
  181. fst, &typed_distance, opts);
  182. break;
  183. }
  184. default: {
  185. FSTERROR() << "ShortestDistance: Unknown queue type: " << opts.queue_type;
  186. typed_distance.clear();
  187. typed_distance.resize(1, Arc::Weight::NoWeight());
  188. break;
  189. }
  190. }
  191. internal::CopyWeights(typed_distance, std::get<1>(*args));
  192. }
  193. using FstShortestDistanceArgs2 =
  194. std::tuple<const FstClass &, std::vector<WeightClass> *, bool, double>;
  195. template <class Arc>
  196. void ShortestDistance(FstShortestDistanceArgs2 *args) {
  197. using Weight = typename Arc::Weight;
  198. const Fst<Arc> &fst = *std::get<0>(*args).GetFst<Arc>();
  199. std::vector<Weight> typed_distance;
  200. ShortestDistance(fst, &typed_distance, std::get<2>(*args),
  201. std::get<3>(*args));
  202. internal::CopyWeights(typed_distance, std::get<1>(*args));
  203. }
  204. using FstShortestDistanceInnerArgs3 = std::tuple<const FstClass &, double>;
  205. using FstShortestDistanceArgs3 =
  206. WithReturnValue<WeightClass, FstShortestDistanceInnerArgs3>;
  207. template <class Arc>
  208. void ShortestDistance(FstShortestDistanceArgs3 *args) {
  209. const Fst<Arc> &fst = *std::get<0>(args->args).GetFst<Arc>();
  210. args->retval = WeightClass(ShortestDistance(fst, std::get<1>(args->args)));
  211. }
  212. void ShortestDistance(const FstClass &fst, std::vector<WeightClass> *distance,
  213. const ShortestDistanceOptions &opts);
  214. void ShortestDistance(const FstClass &ifst, std::vector<WeightClass> *distance,
  215. bool reverse = false,
  216. double delta = fst::kShortestDelta);
  217. WeightClass ShortestDistance(const FstClass &ifst,
  218. double delta = fst::kShortestDelta);
  219. } // namespace script
  220. } // namespace fst
  221. #endif // FST_SCRIPT_SHORTEST_DISTANCE_H_