You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

243 lines
8.5 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // This file defines the registration mechanism for new operations.
  19. // These operations are designed to enable scripts to work with FST classes
  20. // at a high level.
  21. //
  22. // If you have a new arc type and want these operations to work with FSTs
  23. // with that arc type, see below for the registration steps
  24. // you must take.
  25. //
  26. // These methods are only recommended for use in high-level scripting
  27. // applications. Most users should use the lower-level templated versions
  28. // corresponding to these.
  29. //
  30. // If you have a new arc type you'd like these operations to work with,
  31. // use the REGISTER_FST_OPERATIONS macro defined in fstscript.h.
  32. //
  33. // If you have a custom operation you'd like to define, you need four
  34. // components. In the following, assume you want to create a new operation
  35. // with the signature
  36. //
  37. // void Foo(const FstClass &ifst, MutableFstClass *ofst);
  38. //
  39. // You need:
  40. //
  41. // 1) A way to bundle the args that your new Foo operation will take, as
  42. // a single struct. The template structs in arg-packs.h provide a handy
  43. // way to do this. In Foo's case, that might look like this:
  44. //
  45. // using FooArgs = std::pair<const FstClass &, MutableFstClass *>;
  46. //
  47. // Note: this package of args is going to be passed by non-const pointer.
  48. //
  49. // 2) A function template that is able to perform Foo, given the args and
  50. // arc type. Yours might look like this:
  51. //
  52. // template<class Arc>
  53. // void Foo(FooArgs *args) {
  54. // // Pulls out the actual, arc-templated FSTs.
  55. // const Fst<Arc> &ifst = std::get<0>(*args).GetFst<Arc>();
  56. // MutableFst<Arc> *ofst = std::get<1>(*args)->GetMutableFst<Arc>();
  57. // // Actually perform Foo on ifst and ofst.
  58. // }
  59. //
  60. // 3) A client-facing function for your operation. This would look like
  61. // the following:
  62. //
  63. // void Foo(const FstClass &ifst, MutableFstClass *ofst) {
  64. // // Check that the arc types of the FSTs match
  65. // if (!internal::ArcTypesMatch(ifst, *ofst, "Foo")) return;
  66. // // package the args
  67. // FooArgs args(ifst, ofst);
  68. // // Finally, call the operation
  69. // Apply<Operation<FooArgs>>("Foo", ifst.ArcType(), &args);
  70. // }
  71. //
  72. // The Apply<> function template takes care of the link between 2 and 3,
  73. // provided you also have:
  74. //
  75. // 4) A registration for your new operation, on the arc types you care about.
  76. // This can be provided easily by the REGISTER_FST_OPERATION macro:
  77. //
  78. // REGISTER_FST_OPERATION(Foo, StdArc, FooArgs);
  79. // REGISTER_FST_OPERATION(Foo, MyArc, FooArgs);
  80. // // .. etc
  81. //
  82. // You can also use REGISTER_FST_OPERATION_3ARCS macro to register an
  83. // operation for StdArc, LogArc, and Log64Arc:
  84. //
  85. // REGISTER_FST_OPERATION_3ARCS(Foo, FooArgs);
  86. //
  87. // That's it! Now when you call Foo(const FstClass &, MutableFstClass *),
  88. // it dispatches (in #3) via the Apply<> function to the correct
  89. // instantiation of the template function in #2.
  90. //
  91. #ifndef FST_SCRIPT_SCRIPT_IMPL_H_
  92. #define FST_SCRIPT_SCRIPT_IMPL_H_
  93. // This file contains general-purpose templates which are used in the
  94. // implementation of the operations.
  95. #include <cstdint>
  96. #include <memory>
  97. #include <string>
  98. #include <utility>
  99. #include <vector>
  100. #include <fst/log.h>
  101. #include <fst/arc.h>
  102. #include <fst/generic-register.h>
  103. #include <fst/util.h>
  104. #include <fst/script/fst-class.h>
  105. #include <fst/script/weight-class.h>
  106. #include <string_view>
  107. namespace fst {
  108. namespace script {
  // Selector for a random arc selection strategy. Stored in a single byte.
  enum class RandArcSelection : uint8_t {
    UNIFORM,
    LOG_PROB,
    FAST_LOG_PROB,
  };
  110. // A generic register for operations with various kinds of signatures.
  111. // Needed since every function signature requires a new registration class.
  112. // The std::pair<std::string, std::string> is understood to be the operation
  113. // name and arc type; subclasses (or typedefs) need only provide the operation
  114. // signature.
  115. template <class OperationSignature>
  116. class GenericOperationRegister
  117. : public GenericRegister<std::pair<std::string_view, std::string_view>,
  118. OperationSignature,
  119. GenericOperationRegister<OperationSignature>> {
  120. public:
  121. OperationSignature GetOperation(std::string_view operation_name,
  122. std::string_view arc_type) {
  123. return this->GetEntry(std::make_pair(operation_name, arc_type));
  124. }
  125. protected:
  126. std::string ConvertKeyToSoFilename(
  127. const std::pair<std::string_view, std::string_view> &key) const final {
  128. // Uses the old-style FST for now.
  129. std::string legal_type(key.second); // The arc type.
  130. ConvertToLegalCSymbol(&legal_type);
  131. legal_type.append("-arc.so");
  132. return legal_type;
  133. }
  134. };
  135. // Operation package: everything you need to register a new type of operation.
  136. // The ArgPack should be the type that's passed into each wrapped function;
  137. // for instance, it might be a struct containing all the args. It's always
  138. // passed by pointer, so const members should be used to enforce constness where
  139. // it's needed. Return values should be implemented as a member of ArgPack as
  140. // well.
  141. template <class Args>
  142. struct Operation {
  143. using ArgPack = Args;
  144. using OpType = void (*)(ArgPack *args);
  145. // The register (hash) type.
  146. using Register = GenericOperationRegister<OpType>;
  147. // The register-er type.
  148. using Registerer = GenericRegisterer<Register>;
  149. };
  // Macro for registering new types of operations.
  //
  // Expands to the definition of a static Operation<ArgPack>::Registerer
  // object constructed from the key {"Op", Arc::Type()} and the function
  // Op<Arc>. The object's name is formed by token-pasting ArgPack, Op and Arc,
  // so each argument has to be a single identifier. The expansion carries no
  // trailing semicolon; the invocation site supplies it.
  #define REGISTER_FST_OPERATION(Op, Arc, ArgPack)                \
    static fst::script::Operation<ArgPack>::Registerer            \
        arc_dispatched_operation_##ArgPack##Op##Arc##_registerer( \
            {#Op, Arc::Type()}, Op<Arc>)
  // Convenience macro that invokes REGISTER_FST_OPERATION once for each of the
  // widely-used arc types StdArc, LogArc and Log64Arc. The last registration is
  // left without a semicolon so that the invocation site supplies it.
  #define REGISTER_FST_OPERATION_3ARCS(Op, ArgPack) \
    REGISTER_FST_OPERATION(Op, StdArc, ArgPack);    \
    REGISTER_FST_OPERATION(Op, LogArc, ArgPack);    \
    REGISTER_FST_OPERATION(Op, Log64Arc, ArgPack)
  160. // Template function to apply an operation by name.
  161. template <class OpReg>
  162. void Apply(const std::string &op_name, const std::string &arc_type,
  163. typename OpReg::ArgPack *args) {
  164. const auto op =
  165. OpReg::Register::GetRegister()->GetOperation(op_name, arc_type);
  166. if (!op) {
  167. FSTERROR() << op_name << ": No operation found on arc type " << arc_type;
  168. return;
  169. }
  170. op(args);
  171. }
  172. namespace internal {
  173. // Helper that logs to ERROR if the arc types of m and n don't match,
  174. // assuming that both m and n implement .ArcType(). The op_name argument is
  175. // used to construct the error message.
  176. template <class M, class N>
  177. bool ArcTypesMatch(const M &m, const N &n, const std::string &op_name) {
  178. if (m.ArcType() != n.ArcType()) {
  179. FSTERROR() << op_name << ": Arguments with non-matching arc types "
  180. << m.ArcType() << " and " << n.ArcType();
  181. return false;
  182. }
  183. return true;
  184. }
  185. // From untyped to typed weights.
  186. template <class Weight>
  187. void CopyWeights(const std::vector<WeightClass> &weights,
  188. std::vector<Weight> *typed_weights) {
  189. typed_weights->clear();
  190. typed_weights->reserve(weights.size());
  191. for (const auto &weight : weights) {
  192. typed_weights->emplace_back(*weight.GetWeight<Weight>());
  193. }
  194. }
  195. // From typed to untyped weights.
  196. template <class Weight>
  197. void CopyWeights(const std::vector<Weight> &typed_weights,
  198. std::vector<WeightClass> *weights) {
  199. weights->clear();
  200. weights->reserve(typed_weights.size());
  201. for (const auto &typed_weight : typed_weights) {
  202. weights->emplace_back(typed_weight);
  203. }
  204. }
  205. } // namespace internal
  206. // Used for Replace operations.
  207. inline std::vector<std::pair<int64_t, const FstClass *>> BorrowPairs(
  208. const std::vector<std::pair<int64_t, std::unique_ptr<const FstClass>>>
  209. &pairs) {
  210. std::vector<std::pair<int64_t, const FstClass *>> borrowed_pairs;
  211. borrowed_pairs.reserve(pairs.size());
  212. for (const auto &pair : pairs) {
  213. borrowed_pairs.emplace_back(pair.first, pair.second.get());
  214. }
  215. return borrowed_pairs;
  216. }
  217. } // namespace script
  218. } // namespace fst
  219. #endif // FST_SCRIPT_SCRIPT_IMPL_H_