You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

676 lines
22 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. #ifndef FST_SCRIPT_FST_CLASS_H_
  18. #define FST_SCRIPT_FST_CLASS_H_
  19. #include <algorithm>
  20. #include <cstddef>
  21. #include <cstdint>
  22. #include <istream>
  23. #include <limits>
  24. #include <memory>
  25. #include <ostream>
  26. #include <string>
  27. #include <type_traits>
  28. #include <utility>
  29. #include <vector>
  30. #include <fst/log.h>
  31. #include <fst/expanded-fst.h>
  32. #include <fst/fst.h>
  33. #include <fst/generic-register.h>
  34. #include <fst/mutable-fst.h>
  35. #include <fst/properties.h>
  36. #include <fst/symbol-table.h>
  37. #include <fst/util.h>
  38. #include <fst/vector-fst.h>
  39. #include <fst/script/arc-class.h>
  40. #include <fst/script/weight-class.h>
  41. #include <string_view>
  42. // Classes to support "boxing" all existing types of FST arcs in a single
  43. // FstClass which hides the arc types. This allows clients to load
  44. // and work with FSTs without knowing the arc type. These classes are only
  45. // recommended for use in high-level scripting applications. Most users should
  46. // use the lower-level templated versions corresponding to these classes.
  47. namespace fst {
  48. namespace script {
  49. // Abstract base class defining the set of functionalities implemented in all
  50. // impls and passed through by all bases. Below FstClassBase the class
  51. // hierarchy bifurcates; FstClassImplBase serves as the base class for all
  52. // implementations (of which FstClassImpl is currently the only one) and
  53. // FstClass serves as the base class for all interfaces.
  54. class FstClassBase {
  55. public:
  56. virtual const std::string &ArcType() const = 0;
  57. virtual WeightClass Final(int64_t) const = 0;
  58. virtual const std::string &FstType() const = 0;
  59. virtual const SymbolTable *InputSymbols() const = 0;
  60. virtual size_t NumArcs(int64_t) const = 0;
  61. virtual size_t NumInputEpsilons(int64_t) const = 0;
  62. virtual size_t NumOutputEpsilons(int64_t) const = 0;
  63. virtual const SymbolTable *OutputSymbols() const = 0;
  64. virtual uint64_t Properties(uint64_t, bool) const = 0;
  65. virtual int64_t Start() const = 0;
  66. virtual const std::string &WeightType() const = 0;
  67. virtual bool ValidStateId(int64_t) const = 0;
  68. virtual bool Write(const std::string &) const = 0;
  69. virtual bool Write(std::ostream &, const std::string &) const = 0;
  70. virtual ~FstClassBase() = default;
  71. };
  72. // Adds all the MutableFst methods.
  73. class FstClassImplBase : public FstClassBase {
  74. public:
  75. virtual bool AddArc(int64_t, const ArcClass &) = 0;
  76. virtual int64_t AddState() = 0;
  77. virtual void AddStates(size_t) = 0;
  78. virtual FstClassImplBase *Copy() = 0;
  79. virtual bool DeleteArcs(int64_t, size_t) = 0;
  80. virtual bool DeleteArcs(int64_t) = 0;
  81. virtual bool DeleteStates(const std::vector<int64_t> &) = 0;
  82. virtual void DeleteStates() = 0;
  83. virtual SymbolTable *MutableInputSymbols() = 0;
  84. virtual SymbolTable *MutableOutputSymbols() = 0;
  85. virtual int64_t NumStates() const = 0;
  86. virtual bool ReserveArcs(int64_t, size_t) = 0;
  87. virtual void ReserveStates(int64_t) = 0;
  88. virtual void SetInputSymbols(const SymbolTable *) = 0;
  89. virtual bool SetFinal(int64_t, const WeightClass &) = 0;
  90. virtual void SetOutputSymbols(const SymbolTable *) = 0;
  91. virtual void SetProperties(uint64_t, uint64_t) = 0;
  92. virtual bool SetStart(int64_t) = 0;
  93. ~FstClassImplBase() override = default;
  94. };
  95. // Containiner class wrapping an Fst<Arc>, hiding its arc type. Whether this
  96. // Fst<Arc> pointer refers to a special kind of FST (e.g. a MutableFst) is
  97. // known by the type of interface class that owns the pointer to this
  98. // container.
  99. template <class Arc>
  100. class FstClassImpl : public FstClassImplBase {
  101. public:
  102. explicit FstClassImpl(std::unique_ptr<Fst<Arc>> impl)
  103. : impl_(std::move(impl)) {}
  104. explicit FstClassImpl(const Fst<Arc> &impl) : impl_(impl.Copy()) {}
  105. // Warning: calling this method casts the FST to a mutable FST.
  106. bool AddArc(int64_t s, const ArcClass &ac) final {
  107. if (!ValidStateId(s)) return false;
  108. // Note that we do not check that the destination state is valid, so users
  109. // can add arcs before they add the corresponding states. Verify can be
  110. // used to determine whether any arc has a nonexisting destination.
  111. Arc arc(ac.ilabel, ac.olabel, *ac.weight.GetWeight<typename Arc::Weight>(),
  112. ac.nextstate);
  113. down_cast<MutableFst<Arc> *>(impl_.get())->AddArc(s, arc);
  114. return true;
  115. }
  116. // Warning: calling this method casts the FST to a mutable FST.
  117. int64_t AddState() final {
  118. return down_cast<MutableFst<Arc> *>(impl_.get())->AddState();
  119. }
  120. // Warning: calling this method casts the FST to a mutable FST.
  121. void AddStates(size_t n) final {
  122. return down_cast<MutableFst<Arc> *>(impl_.get())->AddStates(n);
  123. }
  124. const std::string &ArcType() const final { return Arc::Type(); }
  125. FstClassImpl *Copy() final { return new FstClassImpl<Arc>(*impl_); }
  126. // Warning: calling this method casts the FST to a mutable FST.
  127. bool DeleteArcs(int64_t s, size_t n) final {
  128. if (!ValidStateId(s)) return false;
  129. down_cast<MutableFst<Arc> *>(impl_.get())->DeleteArcs(s, n);
  130. return true;
  131. }
  132. // Warning: calling this method casts the FST to a mutable FST.
  133. bool DeleteArcs(int64_t s) final {
  134. if (!ValidStateId(s)) return false;
  135. down_cast<MutableFst<Arc> *>(impl_.get())->DeleteArcs(s);
  136. return true;
  137. }
  138. // Warning: calling this method casts the FST to a mutable FST.
  139. bool DeleteStates(const std::vector<int64_t> &dstates) final {
  140. for (const auto &state : dstates)
  141. if (!ValidStateId(state)) return false;
  142. // Warning: calling this method with any integers beyond the precision of
  143. // the underlying FST will result in truncation.
  144. std::vector<typename Arc::StateId> typed_dstates(dstates.size());
  145. std::copy(dstates.begin(), dstates.end(), typed_dstates.begin());
  146. down_cast<MutableFst<Arc> *>(impl_.get())->DeleteStates(typed_dstates);
  147. return true;
  148. }
  149. // Warning: calling this method casts the FST to a mutable FST.
  150. void DeleteStates() final {
  151. down_cast<MutableFst<Arc> *>(impl_.get())->DeleteStates();
  152. }
  153. WeightClass Final(int64_t s) const final {
  154. if (!ValidStateId(s)) return WeightClass::NoWeight(WeightType());
  155. WeightClass w(impl_->Final(s));
  156. return w;
  157. }
  158. const std::string &FstType() const final { return impl_->Type(); }
  159. const SymbolTable *InputSymbols() const final {
  160. return impl_->InputSymbols();
  161. }
  162. // Warning: calling this method casts the FST to a mutable FST.
  163. SymbolTable *MutableInputSymbols() final {
  164. return down_cast<MutableFst<Arc> *>(impl_.get())->MutableInputSymbols();
  165. }
  166. // Warning: calling this method casts the FST to a mutable FST.
  167. SymbolTable *MutableOutputSymbols() final {
  168. return down_cast<MutableFst<Arc> *>(impl_.get())->MutableOutputSymbols();
  169. }
  170. // Signals failure by returning size_t max.
  171. size_t NumArcs(int64_t s) const final {
  172. return ValidStateId(s) ? impl_->NumArcs(s)
  173. : std::numeric_limits<size_t>::max();
  174. }
  175. // Signals failure by returning size_t max.
  176. size_t NumInputEpsilons(int64_t s) const final {
  177. return ValidStateId(s) ? impl_->NumInputEpsilons(s)
  178. : std::numeric_limits<size_t>::max();
  179. }
  180. // Signals failure by returning size_t max.
  181. size_t NumOutputEpsilons(int64_t s) const final {
  182. return ValidStateId(s) ? impl_->NumOutputEpsilons(s)
  183. : std::numeric_limits<size_t>::max();
  184. }
  185. // Warning: calling this method casts the FST to a mutable FST.
  186. int64_t NumStates() const final {
  187. return down_cast<MutableFst<Arc> *>(impl_.get())->NumStates();
  188. }
  189. uint64_t Properties(uint64_t mask, bool test) const final {
  190. return impl_->Properties(mask, test);
  191. }
  192. // Warning: calling this method casts the FST to a mutable FST.
  193. bool ReserveArcs(int64_t s, size_t n) final {
  194. if (!ValidStateId(s)) return false;
  195. down_cast<MutableFst<Arc> *>(impl_.get())->ReserveArcs(s, n);
  196. return true;
  197. }
  198. // Warning: calling this method casts the FST to a mutable FST.
  199. void ReserveStates(int64_t n) final {
  200. down_cast<MutableFst<Arc> *>(impl_.get())->ReserveStates(n);
  201. }
  202. const SymbolTable *OutputSymbols() const final {
  203. return impl_->OutputSymbols();
  204. }
  205. // Warning: calling this method casts the FST to a mutable FST.
  206. void SetInputSymbols(const SymbolTable *isyms) final {
  207. down_cast<MutableFst<Arc> *>(impl_.get())->SetInputSymbols(isyms);
  208. }
  209. // Warning: calling this method casts the FST to a mutable FST.
  210. bool SetFinal(int64_t s, const WeightClass &weight) final {
  211. if (!ValidStateId(s)) return false;
  212. down_cast<MutableFst<Arc> *>(impl_.get())
  213. ->SetFinal(s, *weight.GetWeight<typename Arc::Weight>());
  214. return true;
  215. }
  216. // Warning: calling this method casts the FST to a mutable FST.
  217. void SetOutputSymbols(const SymbolTable *osyms) final {
  218. down_cast<MutableFst<Arc> *>(impl_.get())->SetOutputSymbols(osyms);
  219. }
  220. // Warning: calling this method casts the FST to a mutable FST.
  221. void SetProperties(uint64_t props, uint64_t mask) final {
  222. down_cast<MutableFst<Arc> *>(impl_.get())->SetProperties(props, mask);
  223. }
  224. // Warning: calling this method casts the FST to a mutable FST.
  225. bool SetStart(int64_t s) final {
  226. if (!ValidStateId(s)) return false;
  227. down_cast<MutableFst<Arc> *>(impl_.get())->SetStart(s);
  228. return true;
  229. }
  230. int64_t Start() const final { return impl_->Start(); }
  231. bool ValidStateId(int64_t s) const final {
  232. // This cowardly refuses to count states if the FST is not yet expanded.
  233. const auto num_states = impl_->NumStatesIfKnown();
  234. if (!num_states.has_value()) {
  235. FSTERROR() << "Cannot get number of states for unexpanded FST";
  236. return false;
  237. }
  238. if (s < 0 || s >= *num_states) {
  239. FSTERROR() << "State ID " << s << " not valid";
  240. return false;
  241. }
  242. return true;
  243. }
  244. const std::string &WeightType() const final { return Arc::Weight::Type(); }
  245. bool Write(const std::string &source) const final {
  246. return impl_->Write(source);
  247. }
  248. bool Write(std::ostream &ostr, const std::string &source) const final {
  249. const FstWriteOptions opts(source);
  250. return impl_->Write(ostr, opts);
  251. }
  252. ~FstClassImpl() override = default;
  253. Fst<Arc> *GetImpl() const { return impl_.get(); }
  254. private:
  255. std::unique_ptr<Fst<Arc>> impl_;
  256. };
  257. // BASE CLASS DEFINITIONS
  258. class MutableFstClass;
  259. class FstClass : public FstClassBase {
  260. public:
  261. FstClass() : impl_(nullptr) {}
  262. template <class Arc>
  263. explicit FstClass(std::unique_ptr<Fst<Arc>> fst)
  264. : impl_(std::make_unique<FstClassImpl<Arc>>(std::move(fst))) {}
  265. template <class Arc>
  266. explicit FstClass(const Fst<Arc> &fst)
  267. : impl_(std::make_unique<FstClassImpl<Arc>>(fst)) {}
  268. FstClass(const FstClass &other)
  269. : impl_(other.impl_ == nullptr ? nullptr : other.impl_->Copy()) {}
  270. FstClass &operator=(const FstClass &other) {
  271. impl_.reset(other.impl_ == nullptr ? nullptr : other.impl_->Copy());
  272. return *this;
  273. }
  274. WeightClass Final(int64_t s) const final { return impl_->Final(s); }
  275. const std::string &ArcType() const final { return impl_->ArcType(); }
  276. const std::string &FstType() const final { return impl_->FstType(); }
  277. const SymbolTable *InputSymbols() const final {
  278. return impl_->InputSymbols();
  279. }
  280. size_t NumArcs(int64_t s) const final { return impl_->NumArcs(s); }
  281. size_t NumInputEpsilons(int64_t s) const final {
  282. return impl_->NumInputEpsilons(s);
  283. }
  284. size_t NumOutputEpsilons(int64_t s) const final {
  285. return impl_->NumOutputEpsilons(s);
  286. }
  287. const SymbolTable *OutputSymbols() const final {
  288. return impl_->OutputSymbols();
  289. }
  290. uint64_t Properties(uint64_t mask, bool test) const final {
  291. // Special handling for FSTs with a null impl.
  292. if (!impl_) return kError & mask;
  293. return impl_->Properties(mask, test);
  294. }
  295. static std::unique_ptr<FstClass> Read(
  296. const std::string &source);
  297. static std::unique_ptr<FstClass> Read(
  298. std::istream &istrm, const std::string &source);
  299. int64_t Start() const final { return impl_->Start(); }
  300. bool ValidStateId(int64_t s) const final { return impl_->ValidStateId(s); }
  301. const std::string &WeightType() const final { return impl_->WeightType(); }
  302. // Helper that logs an ERROR if the weight type of an FST and a WeightClass
  303. // don't match.
  304. bool WeightTypesMatch(const WeightClass &weight,
  305. std::string_view op_name) const;
  306. bool Write(const std::string &source) const final {
  307. return impl_->Write(source);
  308. }
  309. bool Write(std::ostream &ostr, const std::string &source) const final {
  310. return impl_->Write(ostr, source);
  311. }
  312. ~FstClass() override = default;
  313. // These methods are required by IO registration.
  314. template <class Arc>
  315. static std::unique_ptr<FstClassImplBase> Convert(const FstClass &other) {
  316. FSTERROR() << "Doesn't make sense to convert any class to type FstClass";
  317. return nullptr;
  318. }
  319. template <class Arc>
  320. static std::unique_ptr<FstClassImplBase> Create() {
  321. FSTERROR() << "Doesn't make sense to create an FstClass with a "
  322. << "particular arc type";
  323. return nullptr;
  324. }
  325. template <class Arc>
  326. const Fst<Arc> *GetFst() const {
  327. if (Arc::Type() != ArcType()) {
  328. return nullptr;
  329. } else {
  330. FstClassImpl<Arc> *typed_impl =
  331. down_cast<FstClassImpl<Arc> *>(impl_.get());
  332. return typed_impl->GetImpl();
  333. }
  334. }
  335. template <class Arc>
  336. static std::unique_ptr<FstClass> Read(std::istream &stream,
  337. const FstReadOptions &opts) {
  338. if (!opts.header) {
  339. LOG(ERROR) << "FstClass::Read: Options header not specified";
  340. return nullptr;
  341. }
  342. const FstHeader &hdr = *opts.header;
  343. if (hdr.Properties() & kMutable) {
  344. return ReadTypedFst<MutableFstClass, MutableFst<Arc>>(stream, opts);
  345. } else {
  346. return ReadTypedFst<FstClass, Fst<Arc>>(stream, opts);
  347. }
  348. }
  349. protected:
  350. explicit FstClass(std::unique_ptr<FstClassImplBase> impl)
  351. : impl_(std::move(impl)) {}
  352. const FstClassImplBase *GetImpl() const { return impl_.get(); }
  353. FstClassImplBase *GetImpl() { return impl_.get(); }
  354. // Generic template method for reading an arc-templated FST of type
  355. // UnderlyingT, and returning it wrapped as FstClassT, with appropriate
  356. // error checking. Called from arc-templated Read() static methods.
  357. template <class FstClassT, class UnderlyingT>
  358. static std::unique_ptr<FstClassT> ReadTypedFst(std::istream &stream,
  359. const FstReadOptions &opts) {
  360. std::unique_ptr<UnderlyingT> u(UnderlyingT::Read(stream, opts));
  361. return u ? std::make_unique<FstClassT>(std::move(u)) : nullptr;
  362. }
  363. private:
  364. std::unique_ptr<FstClassImplBase> impl_;
  365. };
  366. // Specific types of FstClass with special properties
  367. class MutableFstClass : public FstClass {
  368. public:
  369. bool AddArc(int64_t s, const ArcClass &ac) {
  370. if (!WeightTypesMatch(ac.weight, "AddArc")) return false;
  371. return GetImpl()->AddArc(s, ac);
  372. }
  373. int64_t AddState() { return GetImpl()->AddState(); }
  374. void AddStates(size_t n) { return GetImpl()->AddStates(n); }
  375. bool DeleteArcs(int64_t s, size_t n) { return GetImpl()->DeleteArcs(s, n); }
  376. bool DeleteArcs(int64_t s) { return GetImpl()->DeleteArcs(s); }
  377. bool DeleteStates(const std::vector<int64_t> &dstates) {
  378. return GetImpl()->DeleteStates(dstates);
  379. }
  380. void DeleteStates() { GetImpl()->DeleteStates(); }
  381. SymbolTable *MutableInputSymbols() {
  382. return GetImpl()->MutableInputSymbols();
  383. }
  384. SymbolTable *MutableOutputSymbols() {
  385. return GetImpl()->MutableOutputSymbols();
  386. }
  387. int64_t NumStates() const { return GetImpl()->NumStates(); }
  388. bool ReserveArcs(int64_t s, size_t n) { return GetImpl()->ReserveArcs(s, n); }
  389. void ReserveStates(int64_t n) { GetImpl()->ReserveStates(n); }
  390. static std::unique_ptr<MutableFstClass> Read(
  391. const std::string &source, bool convert = false);
  392. void SetInputSymbols(const SymbolTable *isyms) {
  393. GetImpl()->SetInputSymbols(isyms);
  394. }
  395. bool SetFinal(int64_t s, const WeightClass &weight) {
  396. if (!WeightTypesMatch(weight, "SetFinal")) return false;
  397. return GetImpl()->SetFinal(s, weight);
  398. }
  399. void SetOutputSymbols(const SymbolTable *osyms) {
  400. GetImpl()->SetOutputSymbols(osyms);
  401. }
  402. void SetProperties(uint64_t props, uint64_t mask) {
  403. GetImpl()->SetProperties(props, mask);
  404. }
  405. bool SetStart(int64_t s) { return GetImpl()->SetStart(s); }
  406. template <class Arc>
  407. explicit MutableFstClass(std::unique_ptr<MutableFst<Arc>> fst)
  408. // NB: The natural cast-less way to do this doesn't compile for some
  409. // arcane reason.
  410. : FstClass(
  411. fst::implicit_cast<std::unique_ptr<Fst<Arc>>>(std::move(fst))) {}
  412. template <class Arc>
  413. explicit MutableFstClass(const MutableFst<Arc> &fst) : FstClass(fst) {}
  414. // These methods are required by IO registration.
  415. template <class Arc>
  416. static std::unique_ptr<FstClassImplBase> Convert(const FstClass &other) {
  417. FSTERROR() << "Doesn't make sense to convert any class to type "
  418. << "MutableFstClass";
  419. return nullptr;
  420. }
  421. template <class Arc>
  422. static std::unique_ptr<FstClassImplBase> Create() {
  423. FSTERROR() << "Doesn't make sense to create a MutableFstClass with a "
  424. << "particular arc type";
  425. return nullptr;
  426. }
  427. template <class Arc>
  428. MutableFst<Arc> *GetMutableFst() {
  429. Fst<Arc> *fst = const_cast<Fst<Arc> *>(this->GetFst<Arc>());
  430. MutableFst<Arc> *mfst = down_cast<MutableFst<Arc> *>(fst);
  431. return mfst;
  432. }
  433. template <class Arc>
  434. static std::unique_ptr<MutableFstClass> Read(std::istream &stream,
  435. const FstReadOptions &opts) {
  436. std::unique_ptr<MutableFst<Arc>> mfst(MutableFst<Arc>::Read(stream, opts));
  437. return mfst ? std::make_unique<MutableFstClass>(std::move(mfst)) : nullptr;
  438. }
  439. protected:
  440. explicit MutableFstClass(std::unique_ptr<FstClassImplBase> impl)
  441. : FstClass(std::move(impl)) {}
  442. };
  443. class VectorFstClass : public MutableFstClass {
  444. public:
  445. explicit VectorFstClass(std::unique_ptr<FstClassImplBase> impl)
  446. : MutableFstClass(std::move(impl)) {}
  447. explicit VectorFstClass(const FstClass &other);
  448. explicit VectorFstClass(std::string_view arc_type);
  449. static std::unique_ptr<VectorFstClass> Read(
  450. const std::string &source);
  451. template <class Arc>
  452. static std::unique_ptr<VectorFstClass> Read(std::istream &stream,
  453. const FstReadOptions &opts) {
  454. std::unique_ptr<VectorFst<Arc>> vfst(VectorFst<Arc>::Read(stream, opts));
  455. return vfst ? std::make_unique<VectorFstClass>(std::move(vfst)) : nullptr;
  456. }
  457. template <class Arc>
  458. explicit VectorFstClass(std::unique_ptr<VectorFst<Arc>> fst)
  459. // NB: The natural cast-less way to do this doesn't compile for some
  460. // arcane reason.
  461. : MutableFstClass(fst::implicit_cast<std::unique_ptr<MutableFst<Arc>>>(
  462. std::move(fst))) {}
  463. template <class Arc>
  464. explicit VectorFstClass(const VectorFst<Arc> &fst) : MutableFstClass(fst) {}
  465. template <class Arc>
  466. static std::unique_ptr<FstClassImplBase> Convert(const FstClass &other) {
  467. return std::make_unique<FstClassImpl<Arc>>(
  468. std::make_unique<VectorFst<Arc>>(*other.GetFst<Arc>()));
  469. }
  470. template <class Arc>
  471. static std::unique_ptr<FstClassImplBase> Create() {
  472. return std::make_unique<FstClassImpl<Arc>>(
  473. std::make_unique<VectorFst<Arc>>());
  474. }
  475. };
  476. // Registration stuff.
  477. // This class definition is to avoid a nested class definition inside the
  478. // FstClassIORegistration struct.
  479. template <class Reader, class Creator, class Converter>
  480. struct FstClassRegEntry {
  481. Reader reader;
  482. Creator creator;
  483. Converter converter;
  484. FstClassRegEntry(Reader r, Creator cr, Converter co)
  485. : reader(r), creator(cr), converter(co) {}
  486. FstClassRegEntry() : reader(nullptr), creator(nullptr), converter(nullptr) {}
  487. };
  488. // Actual FST IO method register.
  489. template <class Reader, class Creator, class Converter>
  490. class FstClassIORegister
  491. : public GenericRegister<std::string,
  492. FstClassRegEntry<Reader, Creator, Converter>,
  493. FstClassIORegister<Reader, Creator, Converter>> {
  494. public:
  495. Reader GetReader(std::string_view arc_type) const {
  496. return this->GetEntry(arc_type).reader;
  497. }
  498. Creator GetCreator(std::string_view arc_type) const {
  499. return this->GetEntry(arc_type).creator;
  500. }
  501. Converter GetConverter(std::string_view arc_type) const {
  502. return this->GetEntry(arc_type).converter;
  503. }
  504. protected:
  505. std::string ConvertKeyToSoFilename(std::string_view key) const final {
  506. std::string legal_type(key);
  507. ConvertToLegalCSymbol(&legal_type);
  508. legal_type.append("-arc.so");
  509. return legal_type;
  510. }
  511. };
  512. // Struct containing everything needed to register a particular type
  513. // of FST class (e.g., a plain FstClass, or a MutableFstClass, etc.).
  514. template <class FstClassType>
  515. struct FstClassIORegistration {
  516. using Reader = std::unique_ptr<FstClassType> (*)(std::istream &stream,
  517. const FstReadOptions &opts);
  518. using Creator = std::unique_ptr<FstClassImplBase> (*)();
  519. using Converter =
  520. std::unique_ptr<FstClassImplBase> (*)(const FstClass &other);
  521. using Entry = FstClassRegEntry<Reader, Creator, Converter>;
  522. // FST class Register.
  523. using Register = FstClassIORegister<Reader, Creator, Converter>;
  524. // FST class Register-er.
  525. using Registerer =
  526. GenericRegisterer<FstClassIORegister<Reader, Creator, Converter>>;
  527. };
// Macros for registering other arc types.
//
// REGISTER_FST_CLASS(Class, Arc) defines a static
// FstClassIORegistration<Class>::Registerer object named
// Class_Arc_registerer. It is constructed from Arc::Type() and an Entry built
// from Class::Read<Arc>, Class::Create<Arc> and Class::Convert<Arc>. The
// expansion has no trailing semicolon; the invocation site supplies it. The
// names are unqualified, so the macro must be used where FstClassIORegistration
// and Class are visible.
#define REGISTER_FST_CLASS(Class, Arc)                                         \
  static FstClassIORegistration<Class>::Registerer Class##_##Arc##_registerer( \
      Arc::Type(),                                                             \
      FstClassIORegistration<Class>::Entry(                                    \
          Class::Read<Arc>, Class::Create<Arc>, Class::Convert<Arc>))

// REGISTER_FST_CLASSES(Arc) registers Arc for FstClass, MutableFstClass and
// VectorFstClass. Each of the three registrations in the expansion is already
// terminated by a semicolon, including the last one.
#define REGISTER_FST_CLASSES(Arc)           \
  REGISTER_FST_CLASS(FstClass, Arc);        \
  REGISTER_FST_CLASS(MutableFstClass, Arc); \
  REGISTER_FST_CLASS(VectorFstClass, Arc);
  538. } // namespace script
  539. } // namespace fst
  540. #endif // FST_SCRIPT_FST_CLASS_H_