You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

300 lines
9.1 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. #ifndef FST_SCRIPT_ENCODEMAPPER_CLASS_H_
  18. #define FST_SCRIPT_ENCODEMAPPER_CLASS_H_
  19. #include <cstdint>
  20. #include <iostream>
  21. #include <istream>
  22. #include <memory>
  23. #include <ostream>
  24. #include <string>
  25. #include <utility>
  26. #include <fst/encode.h>
  27. #include <fst/generic-register.h>
  28. #include <fst/symbol-table.h>
  29. #include <fst/util.h>
  30. #include <fst/script/arc-class.h>
  31. #include <fst/script/fst-class.h>
  32. #include <string_view>
  33. // Scripting API support for EncodeMapper.
  34. namespace fst {
  35. namespace script {
  36. // Virtual interface implemented by each concrete EncodeMapperClassImpl<Arc>.
  37. class EncodeMapperImplBase {
  38. public:
  39. // Returns an encoded ArcClass.
  40. virtual ArcClass operator()(const ArcClass &) = 0;
  41. virtual const std::string &ArcType() const = 0;
  42. virtual const std::string &WeightType() const = 0;
  43. virtual EncodeMapperImplBase *Copy() const = 0;
  44. virtual uint8_t Flags() const = 0;
  45. virtual uint64_t Properties(uint64_t) = 0;
  46. virtual EncodeType Type() const = 0;
  47. virtual bool Write(const std::string &) const = 0;
  48. virtual bool Write(std::ostream &, const std::string &) const = 0;
  49. virtual const SymbolTable *InputSymbols() const = 0;
  50. virtual const SymbolTable *OutputSymbols() const = 0;
  51. virtual void SetInputSymbols(const SymbolTable *) = 0;
  52. virtual void SetOutputSymbols(const SymbolTable *) = 0;
  53. virtual ~EncodeMapperImplBase() = default;
  54. };
  55. // Templated implementation.
  56. template <class Arc>
  57. class EncodeMapperClassImpl : public EncodeMapperImplBase {
  58. public:
  59. explicit EncodeMapperClassImpl(const EncodeMapper<Arc> &mapper)
  60. : mapper_(mapper) {}
  61. ArcClass operator()(const ArcClass &a) final;
  62. const std::string &ArcType() const final { return Arc::Type(); }
  63. const std::string &WeightType() const final { return Arc::Weight::Type(); }
  64. EncodeMapperClassImpl<Arc> *Copy() const final {
  65. return new EncodeMapperClassImpl<Arc>(mapper_);
  66. }
  67. uint8_t Flags() const final { return mapper_.Flags(); }
  68. uint64_t Properties(uint64_t inprops) final {
  69. return mapper_.Properties(inprops);
  70. }
  71. EncodeType Type() const final { return mapper_.Type(); }
  72. bool Write(const std::string &source) const final {
  73. return mapper_.Write(source);
  74. }
  75. bool Write(std::ostream &strm, const std::string &source) const final {
  76. return mapper_.Write(strm, source);
  77. }
  78. const SymbolTable *InputSymbols() const final {
  79. return mapper_.InputSymbols();
  80. }
  81. const SymbolTable *OutputSymbols() const final {
  82. return mapper_.OutputSymbols();
  83. }
  84. void SetInputSymbols(const SymbolTable *syms) final {
  85. mapper_.SetInputSymbols(syms);
  86. }
  87. void SetOutputSymbols(const SymbolTable *syms) final {
  88. mapper_.SetOutputSymbols(syms);
  89. }
  90. ~EncodeMapperClassImpl() override = default;
  91. const EncodeMapper<Arc> *GetImpl() const { return &mapper_; }
  92. EncodeMapper<Arc> *GetImpl() { return &mapper_; }
  93. private:
  94. EncodeMapper<Arc> mapper_;
  95. };
  96. template <class Arc>
  97. inline ArcClass EncodeMapperClassImpl<Arc>::operator()(const ArcClass &a) {
  98. const Arc arc(a.ilabel, a.olabel,
  99. *(a.weight.GetWeight<typename Arc::Weight>()), a.nextstate);
  100. return ArcClass(mapper_(arc));
  101. }
  102. class EncodeMapperClass {
  103. public:
  104. EncodeMapperClass() : impl_(nullptr) {}
  105. EncodeMapperClass(std::string_view arc_type, uint8_t flags,
  106. EncodeType type = ENCODE);
  107. template <class Arc>
  108. explicit EncodeMapperClass(const EncodeMapper<Arc> &mapper)
  109. : impl_(std::make_unique<EncodeMapperClassImpl<Arc>>(mapper)) {}
  110. EncodeMapperClass(const EncodeMapperClass &other)
  111. : impl_(other.impl_ == nullptr ? nullptr : other.impl_->Copy()) {}
  112. EncodeMapperClass &operator=(const EncodeMapperClass &other) {
  113. impl_.reset(other.impl_ == nullptr ? nullptr : other.impl_->Copy());
  114. return *this;
  115. }
  116. ArcClass operator()(const ArcClass &arc) { return (*impl_)(arc); }
  117. const std::string &ArcType() const { return impl_->ArcType(); }
  118. const std::string &WeightType() const { return impl_->WeightType(); }
  119. uint8_t Flags() const { return impl_->Flags(); }
  120. uint64_t Properties(uint64_t inprops) { return impl_->Properties(inprops); }
  121. EncodeType Type() const { return impl_->Type(); }
  122. static std::unique_ptr<EncodeMapperClass> Read(
  123. const std::string &source);
  124. static std::unique_ptr<EncodeMapperClass> Read(
  125. std::istream &strm, const std::string &source);
  126. bool Write(const std::string &source) const { return impl_->Write(source); }
  127. bool Write(std::ostream &strm, const std::string &source) const {
  128. return impl_->Write(strm, source);
  129. }
  130. const SymbolTable *InputSymbols() const { return impl_->InputSymbols(); }
  131. const SymbolTable *OutputSymbols() const { return impl_->OutputSymbols(); }
  132. void SetInputSymbols(const SymbolTable *syms) {
  133. impl_->SetInputSymbols(syms);
  134. }
  135. void SetOutputSymbols(const SymbolTable *syms) {
  136. impl_->SetOutputSymbols(syms);
  137. }
  138. // Implementation stuff.
  139. template <class Arc>
  140. EncodeMapper<Arc> *GetEncodeMapper() {
  141. if (Arc::Type() != ArcType()) {
  142. return nullptr;
  143. } else {
  144. auto *typed_impl = down_cast<EncodeMapperClassImpl<Arc> *>(impl_.get());
  145. return typed_impl->GetImpl();
  146. }
  147. }
  148. template <class Arc>
  149. const EncodeMapper<Arc> *GetEncodeMapper() const {
  150. if (Arc::Type() != ArcType()) {
  151. return nullptr;
  152. } else {
  153. auto *typed_impl = down_cast<EncodeMapperClassImpl<Arc> *>(impl_.get());
  154. return typed_impl->GetImpl();
  155. }
  156. }
  157. // Required for registration.
  158. template <class Arc>
  159. static std::unique_ptr<EncodeMapperClass> Read(std::istream &strm,
  160. std::string_view source) {
  161. std::unique_ptr<EncodeMapper<Arc>> mapper(
  162. EncodeMapper<Arc>::Read(strm, source));
  163. return mapper ? std::make_unique<EncodeMapperClass>(*mapper) : nullptr;
  164. }
  165. template <class Arc>
  166. static std::unique_ptr<EncodeMapperImplBase> Create(
  167. uint8_t flags, EncodeType type = ENCODE) {
  168. return std::make_unique<EncodeMapperClassImpl<Arc>>(
  169. EncodeMapper<Arc>(flags, type));
  170. }
  171. private:
  172. explicit EncodeMapperClass(std::unique_ptr<EncodeMapperImplBase> impl)
  173. : impl_(std::move(impl)) {}
  174. const EncodeMapperImplBase *GetImpl() const { return impl_.get(); }
  175. EncodeMapperImplBase *GetImpl() { return impl_.get(); }
  176. std::unique_ptr<EncodeMapperImplBase> impl_;
  177. };
// Registration for EncodeMapper types.
// This class definition is to avoid a nested class definition inside the
// EncodeMapperClassIORegistration struct.
  181. template <class Reader, class Creator>
  182. struct EncodeMapperClassRegEntry {
  183. Reader reader;
  184. Creator creator;
  185. EncodeMapperClassRegEntry(Reader reader, Creator creator)
  186. : reader(reader), creator(creator) {}
  187. EncodeMapperClassRegEntry() : reader(nullptr), creator(nullptr) {}
  188. };
  189. template <class Reader, class Creator>
  190. class EncodeMapperClassIORegister
  191. : public GenericRegister<std::string,
  192. EncodeMapperClassRegEntry<Reader, Creator>,
  193. EncodeMapperClassIORegister<Reader, Creator>> {
  194. public:
  195. Reader GetReader(std::string_view arc_type) const {
  196. return this->GetEntry(arc_type).reader;
  197. }
  198. Creator GetCreator(std::string_view arc_type) const {
  199. return this->GetEntry(arc_type).creator;
  200. }
  201. protected:
  202. std::string ConvertKeyToSoFilename(std::string_view key) const final {
  203. std::string legal_type(key);
  204. ConvertToLegalCSymbol(&legal_type);
  205. legal_type.append("-arc.so");
  206. return legal_type;
  207. }
  208. };
  209. // Struct containing everything needed to register a particular type
  210. struct EncodeMapperClassIORegistration {
  211. using Reader = std::unique_ptr<EncodeMapperClass> (*)(
  212. std::istream &stream, std::string_view source);
  213. using Creator = std::unique_ptr<EncodeMapperImplBase> (*)(uint8_t flags,
  214. EncodeType type);
  215. using Entry = EncodeMapperClassRegEntry<Reader, Creator>;
  216. // EncodeMapper register.
  217. using Register = EncodeMapperClassIORegister<Reader, Creator>;
  218. // EncodeMapper register-er.
  219. using Registerer =
  220. GenericRegisterer<EncodeMapperClassIORegister<Reader, Creator>>;
  221. };
  222. #define REGISTER_ENCODEMAPPER_CLASS(Arc) \
  223. static EncodeMapperClassIORegistration::Registerer \
  224. EncodeMapperClass_##Arc##_registerer( \
  225. Arc::Type(), \
  226. EncodeMapperClassIORegistration::Entry( \
  227. EncodeMapperClass::Read<Arc>, EncodeMapperClass::Create<Arc>));
  228. } // namespace script
  229. } // namespace fst
  230. #endif // FST_SCRIPT_ENCODEMAPPER_CLASS_H_