You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

257 lines
7.8 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // FST implementation class to attach an arbitrary object with a read/write
  19. // method to an FST and its file representation. The FST is given a new type
  20. // name.
  21. #ifndef FST_ADD_ON_H_
  22. #define FST_ADD_ON_H_
  23. #include <cstddef>
  24. #include <cstdint>
  25. #include <istream>
  26. #include <memory>
  27. #include <ostream>
  28. #include <string>
  29. #include <utility>
  30. #include <fst/log.h>
  31. #include <fst/fst.h>
  32. #include <fst/properties.h>
  33. #include <fst/util.h>
  34. #include <string_view>
  35. namespace fst {
  36. // Identifies stream data as an add-on FST.
  37. inline constexpr int32_t kAddOnMagicNumber = 446681434;
  38. // Nothing to save.
  39. class NullAddOn {
  40. public:
  41. NullAddOn() = default;
  42. static NullAddOn *Read(std::istream &strm, const FstReadOptions &opts) {
  43. return new NullAddOn();
  44. }
  45. bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const {
  46. return true;
  47. }
  48. };
  49. // Create a new add-on from a pair of add-ons.
  50. template <class A1, class A2>
  51. class AddOnPair {
  52. public:
  53. // Argument reference count incremented.
  54. AddOnPair(std::shared_ptr<A1> a1, std::shared_ptr<A2> a2)
  55. : a1_(std::move(a1)), a2_(std::move(a2)) {}
  56. const A1 *First() const { return a1_.get(); }
  57. const A2 *Second() const { return a2_.get(); }
  58. std::shared_ptr<A1> SharedFirst() const { return a1_; }
  59. std::shared_ptr<A2> SharedSecond() const { return a2_; }
  60. static AddOnPair *Read(std::istream &istrm, const FstReadOptions &opts) {
  61. bool have_addon1 = false;
  62. ReadType(istrm, &have_addon1);
  63. std::unique_ptr<A1> a1;
  64. if (have_addon1) a1 = fst::WrapUnique(A1::Read(istrm, opts));
  65. bool have_addon2 = false;
  66. ReadType(istrm, &have_addon2);
  67. std::unique_ptr<A1> a2;
  68. if (have_addon2) a2 = fst::WrapUnique(A2::Read(istrm, opts));
  69. return new AddOnPair(std::move(a1), std::move(a2));
  70. }
  71. bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const {
  72. bool have_addon1 = a1_ != nullptr;
  73. WriteType(ostrm, have_addon1);
  74. if (have_addon1) a1_->Write(ostrm, opts);
  75. bool have_addon2 = a2_ != nullptr;
  76. WriteType(ostrm, have_addon2);
  77. if (have_addon2) a2_->Write(ostrm, opts);
  78. return true;
  79. }
  80. private:
  81. std::shared_ptr<A1> a1_;
  82. std::shared_ptr<A2> a2_;
  83. };
  84. namespace internal {
  85. // Adds an object of type T to an FST. T must support:
  86. //
  87. // T* Read(std::istream &);
  88. // bool Write(std::ostream &);
  89. //
  90. // The resulting type is a new FST implementation.
  91. template <class FST, class T>
  92. class AddOnImpl : public FstImpl<typename FST::Arc> {
  93. public:
  94. using FstType = FST;
  95. using Arc = typename FST::Arc;
  96. using Label = typename Arc::Label;
  97. using StateId = typename Arc::StateId;
  98. using Weight = typename Arc::Weight;
  99. using FstImpl<Arc>::SetType;
  100. using FstImpl<Arc>::SetInputSymbols;
  101. using FstImpl<Arc>::SetOutputSymbols;
  102. using FstImpl<Arc>::SetProperties;
  103. using FstImpl<Arc>::WriteHeader;
  104. // We make a thread-safe copy of the FST by default since an FST
  105. // implementation is expected to not share mutable data between objects.
  106. AddOnImpl(const FST &fst, std::string_view type,
  107. std::shared_ptr<T> t = nullptr)
  108. : fst_(fst, true), t_(std::move(t)) {
  109. SetType(type);
  110. SetProperties(fst_.Properties(kFstProperties, false));
  111. SetInputSymbols(fst_.InputSymbols());
  112. SetOutputSymbols(fst_.OutputSymbols());
  113. }
  114. // Conversion from const Fst<Arc> & to F always copies the underlying
  115. // implementation.
  116. AddOnImpl(const Fst<Arc> &fst, std::string_view type,
  117. std::shared_ptr<T> t = nullptr)
  118. : fst_(fst), t_(std::move(t)) {
  119. SetType(type);
  120. SetProperties(fst_.Properties(kFstProperties, false));
  121. SetInputSymbols(fst_.InputSymbols());
  122. SetOutputSymbols(fst_.OutputSymbols());
  123. }
  124. // We make a thread-safe copy of the FST by default since an FST
  125. // implementation is expected to not share mutable data between objects.
  126. AddOnImpl(const AddOnImpl &impl) : fst_(impl.fst_, true), t_(impl.t_) {
  127. SetType(impl.Type());
  128. SetProperties(fst_.Properties(kCopyProperties, false));
  129. SetInputSymbols(fst_.InputSymbols());
  130. SetOutputSymbols(fst_.OutputSymbols());
  131. }
  132. StateId Start() const { return fst_.Start(); }
  133. Weight Final(StateId s) const { return fst_.Final(s); }
  134. size_t NumArcs(StateId s) const { return fst_.NumArcs(s); }
  135. size_t NumInputEpsilons(StateId s) const { return fst_.NumInputEpsilons(s); }
  136. size_t NumOutputEpsilons(StateId s) const {
  137. return fst_.NumOutputEpsilons(s);
  138. }
  139. size_t NumStates() const { return fst_.NumStates(); }
  140. static AddOnImpl *Read(std::istream &strm, const FstReadOptions &opts) {
  141. FstReadOptions nopts(opts);
  142. FstHeader hdr;
  143. if (!nopts.header) {
  144. hdr.Read(strm, nopts.source);
  145. nopts.header = &hdr;
  146. }
  147. // Using `new` to access private constructor for `AddOnImpl`.
  148. auto impl = fst::WrapUnique(new AddOnImpl(nopts.header->FstType()));
  149. if (!impl->ReadHeader(strm, nopts, kMinFileVersion, &hdr)) return nullptr;
  150. impl.reset();
  151. int32_t magic_number = 0;
  152. ReadType(strm, &magic_number); // Ensures this is an add-on FST.
  153. if (magic_number != kAddOnMagicNumber) {
  154. LOG(ERROR) << "AddOnImpl::Read: Bad add-on header: " << nopts.source;
  155. return nullptr;
  156. }
  157. FstReadOptions fopts(opts);
  158. fopts.header = nullptr; // Contained header was written out.
  159. std::unique_ptr<FST> fst(FST::Read(strm, fopts));
  160. if (!fst) return nullptr;
  161. std::shared_ptr<T> t;
  162. bool have_addon = false;
  163. ReadType(strm, &have_addon);
  164. if (have_addon) { // Reads add-on object if present.
  165. t = std::shared_ptr<T>(T::Read(strm, fopts));
  166. if (!t) return nullptr;
  167. }
  168. return new AddOnImpl(*fst, nopts.header->FstType(), t);
  169. }
  170. bool Write(std::ostream &strm, const FstWriteOptions &opts) const {
  171. FstHeader hdr;
  172. FstWriteOptions nopts(opts);
  173. nopts.write_isymbols = false; // Allows contained FST to hold any symbols.
  174. nopts.write_osymbols = false;
  175. WriteHeader(strm, nopts, kFileVersion, &hdr);
  176. WriteType(strm, kAddOnMagicNumber); // Ensures this is an add-on FST.
  177. FstWriteOptions fopts(opts);
  178. fopts.write_header = true; // Forces writing contained header.
  179. if (!fst_.Write(strm, fopts)) return false;
  180. bool have_addon = !!t_;
  181. WriteType(strm, have_addon);
  182. // Writes add-on object if present.
  183. if (have_addon) t_->Write(strm, opts);
  184. return true;
  185. }
  186. void InitStateIterator(StateIteratorData<Arc> *data) const {
  187. fst_.InitStateIterator(data);
  188. }
  189. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
  190. fst_.InitArcIterator(s, data);
  191. }
  192. FST &GetFst() { return fst_; }
  193. const FST &GetFst() const { return fst_; }
  194. const T *GetAddOn() const { return t_.get(); }
  195. std::shared_ptr<T> GetSharedAddOn() const { return t_; }
  196. void SetAddOn(std::shared_ptr<T> t) { t_ = t; }
  197. private:
  198. explicit AddOnImpl(std::string_view type) : t_() {
  199. SetType(type);
  200. SetProperties(kExpanded);
  201. }
  202. // Current file format version.
  203. static constexpr int kFileVersion = 1;
  204. // Minimum file format version supported.
  205. static constexpr int kMinFileVersion = 1;
  206. FST fst_;
  207. std::shared_ptr<T> t_;
  208. AddOnImpl &operator=(const AddOnImpl &) = delete;
  209. };
  210. } // namespace internal
  211. } // namespace fst
  212. #endif // FST_ADD_ON_H_