You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

426 lines
13 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Expanded FST augmented with mutators; interface class definition and
  19. // mutable arc iterator interface.
  20. #ifndef FST_MUTABLE_FST_H_
  21. #define FST_MUTABLE_FST_H_
  22. #include <sys/types.h>
  23. #include <cstddef>
  24. #include <cstdint>
  25. #include <ios>
  26. #include <iostream>
  27. #include <istream>
  28. #include <memory>
  29. #include <string>
  30. #include <utility>
  31. #include <vector>
  32. #include <fst/log.h>
  33. #include <fst/arc.h>
  34. #include <fst/expanded-fst.h>
  35. #include <fstream>
  36. #include <fst/fst.h>
  37. #include <fst/properties.h>
  38. #include <fst/register.h>
  39. #include <fst/symbol-table.h>
  40. #include <string_view>
  41. namespace fst {
  42. template <class Arc>
  43. struct MutableArcIteratorData;
  44. // Abstract interface for an expanded FST which also supports mutation
  45. // operations. To modify arcs, use MutableArcIterator.
  46. template <class A>
  47. class MutableFst : public ExpandedFst<A> {
  48. public:
  49. using Arc = A;
  50. using StateId = typename Arc::StateId;
  51. using Weight = typename Arc::Weight;
  52. virtual MutableFst<Arc> &operator=(const Fst<Arc> &fst) = 0;
  53. MutableFst &operator=(const MutableFst &fst) {
  54. return operator=(static_cast<const Fst<Arc> &>(fst));
  55. }
  56. // Sets the initial state.
  57. virtual void SetStart(StateId) = 0;
  58. // Sets a state's final weight.
  59. virtual void SetFinal(StateId s, Weight weight = Weight::One()) = 0;
  60. // Sets property bits w.r.t. mask.
  61. virtual void SetProperties(uint64_t props, uint64_t mask) = 0;
  62. // Adds a state and returns its ID.
  63. virtual StateId AddState() = 0;
  64. // Adds multiple states.
  65. virtual void AddStates(size_t) = 0;
  66. // Adds an arc to state.
  67. virtual void AddArc(StateId, const Arc &) = 0;
  68. // Adds an arc (passed by rvalue reference) to state. Allows subclasses
  69. // to optionally implement move semantics. Defaults to lvalue overload.
  70. virtual void AddArc(StateId state, Arc &&arc) { AddArc(state, arc); }
  71. // Deletes some states, preserving original StateId ordering.
  72. virtual void DeleteStates(const std::vector<StateId> &) = 0;
  73. // Delete all states.
  74. virtual void DeleteStates() = 0;
  75. // Delete some arcs at a given state.
  76. virtual void DeleteArcs(StateId, size_t) = 0;
  77. // Delete all arcs at a given state.
  78. virtual void DeleteArcs(StateId) = 0;
  79. // Optional, best effort only.
  80. virtual void ReserveStates(size_t) {}
  81. // Optional, best effort only.
  82. virtual void ReserveArcs(StateId, size_t) {}
  83. // Returns input label symbol table or nullptr if not specified.
  84. const SymbolTable *InputSymbols() const override = 0;
  85. // Returns output label symbol table or nullptr if not specified.
  86. const SymbolTable *OutputSymbols() const override = 0;
  87. // Returns input label symbol table or nullptr if not specified.
  88. virtual SymbolTable *MutableInputSymbols() = 0;
  89. // Returns output label symbol table or nullptr if not specified.
  90. virtual SymbolTable *MutableOutputSymbols() = 0;
  91. // Sets input label symbol table; pass nullptr to delete table.
  92. virtual void SetInputSymbols(const SymbolTable *isyms) = 0;
  93. // Sets output label symbol table; pass nullptr to delete table.
  94. virtual void SetOutputSymbols(const SymbolTable *osyms) = 0;
  95. // Gets a copy of this MutableFst. See Fst<>::Copy() for further doc.
  96. MutableFst *Copy(bool safe = false) const override = 0;
  97. // Reads a MutableFst from an input stream, returning nullptr on error.
  98. static MutableFst *Read(std::istream &strm, const FstReadOptions &opts) {
  99. FstReadOptions ropts(opts);
  100. FstHeader hdr;
  101. if (ropts.header) {
  102. hdr = *opts.header;
  103. } else {
  104. if (!hdr.Read(strm, opts.source)) return nullptr;
  105. ropts.header = &hdr;
  106. }
  107. if (!(hdr.Properties() & kMutable)) {
  108. LOG(ERROR) << "MutableFst::Read: Not a MutableFst: " << ropts.source;
  109. return nullptr;
  110. }
  111. const auto &fst_type = hdr.FstType();
  112. const auto reader = FstRegister<Arc>::GetRegister()->GetReader(fst_type);
  113. if (!reader) {
  114. LOG(ERROR) << "MutableFst::Read: Unknown FST type \"" << fst_type
  115. << "\" (arc type = \"" << A::Type() << "\"): " << ropts.source;
  116. return nullptr;
  117. }
  118. auto *fst = reader(strm, ropts);
  119. if (!fst) return nullptr;
  120. return down_cast<MutableFst *>(fst);
  121. }
  122. // Reads a MutableFst from a file; returns nullptr on error. An empty
  123. // source results in reading from standard input. If convert is true,
  124. // convert to a mutable FST subclass (given by convert_type) in the case
  125. // that the input FST is non-mutable.
  126. static MutableFst *Read(const std::string &source, bool convert = false,
  127. std::string_view convert_type = "vector") {
  128. if (convert == false) {
  129. if (!source.empty()) {
  130. std::ifstream strm(source,
  131. std::ios_base::in | std::ios_base::binary);
  132. if (!strm) {
  133. LOG(ERROR) << "MutableFst::Read: Can't open file: " << source;
  134. return nullptr;
  135. }
  136. return Read(strm, FstReadOptions(source));
  137. } else {
  138. return Read(std::cin, FstReadOptions("standard input"));
  139. }
  140. } else { // Converts to 'convert_type' if not mutable.
  141. std::unique_ptr<Fst<Arc>> ifst(Fst<Arc>::Read(source));
  142. if (!ifst) return nullptr;
  143. if (ifst->Properties(kMutable, false)) {
  144. return down_cast<MutableFst *>(ifst.release());
  145. } else {
  146. std::unique_ptr<Fst<Arc>> ofst(Convert(*ifst, convert_type));
  147. ifst.reset();
  148. if (!ofst) return nullptr;
  149. if (!ofst->Properties(kMutable, false)) {
  150. LOG(ERROR) << "MutableFst: Bad convert type: " << convert_type;
  151. }
  152. return down_cast<MutableFst *>(ofst.release());
  153. }
  154. }
  155. }
  156. // For generic mutuble arc iterator construction; not normally called
  157. // directly by users.
  158. virtual void InitMutableArcIterator(StateId s,
  159. MutableArcIteratorData<Arc> *data) = 0;
  160. };
  161. // Mutable arc iterator interface, templated on the Arc definition. This is
  162. // used by mutable arc iterator specializations that are returned by the
  163. // InitMutableArcIterator MutableFst method.
  164. template <class Arc>
  165. class MutableArcIteratorBase : public ArcIteratorBase<Arc> {
  166. public:
  167. // Sets current arc.
  168. virtual void SetValue(const Arc &) = 0;
  169. };
  170. template <class Arc>
  171. struct MutableArcIteratorData {
  172. std::unique_ptr<MutableArcIteratorBase<Arc>> base; // Specific iterator.
  173. };
  174. // Generic mutable arc iterator, templated on the FST definition; a wrapper
  175. // around a pointer to a more specific one.
  176. //
  177. // Here is a typical use:
  178. //
  179. // for (MutableArcIterator<StdFst> aiter(&fst, s);
  180. // !aiter.Done();
  181. // aiter.Next()) {
  182. // StdArc arc = aiter.Value();
  183. // arc.ilabel = 7;
  184. // aiter.SetValue(arc);
  185. // ...
  186. // }
  187. //
  188. // This version requires function calls.
  189. template <class FST>
  190. class MutableArcIterator {
  191. public:
  192. using Arc = typename FST::Arc;
  193. using StateId = typename Arc::StateId;
  194. MutableArcIterator(FST *fst, StateId s) {
  195. fst->InitMutableArcIterator(s, &data_);
  196. }
  197. bool Done() const { return data_.base->Done(); }
  198. const Arc &Value() const { return data_.base->Value(); }
  199. void Next() { data_.base->Next(); }
  200. size_t Position() const { return data_.base->Position(); }
  201. void Reset() { data_.base->Reset(); }
  202. void Seek(size_t a) { data_.base->Seek(a); }
  203. void SetValue(const Arc &arc) { data_.base->SetValue(arc); }
  204. uint8_t Flags() const { return data_.base->Flags(); }
  205. void SetFlags(uint8_t flags, uint8_t mask) {
  206. return data_.base->SetFlags(flags, mask);
  207. }
  208. private:
  209. MutableArcIteratorData<Arc> data_;
  210. MutableArcIterator(const MutableArcIterator &) = delete;
  211. MutableArcIterator &operator=(const MutableArcIterator &) = delete;
  212. };
  213. namespace internal {
  214. // MutableFst<A> case: abstract methods.
  215. template <class Arc>
  216. inline typename Arc::Weight Final(const MutableFst<Arc> &fst,
  217. typename Arc::StateId s) {
  218. return fst.Final(s);
  219. }
  220. template <class Arc>
  221. inline ssize_t NumArcs(const MutableFst<Arc> &fst, typename Arc::StateId s) {
  222. return fst.NumArcs(s);
  223. }
  224. template <class Arc>
  225. inline ssize_t NumInputEpsilons(const MutableFst<Arc> &fst,
  226. typename Arc::StateId s) {
  227. return fst.NumInputEpsilons(s);
  228. }
  229. template <class Arc>
  230. inline ssize_t NumOutputEpsilons(const MutableFst<Arc> &fst,
  231. typename Arc::StateId s) {
  232. return fst.NumOutputEpsilons(s);
  233. }
  234. } // namespace internal
  235. // A useful alias when using StdArc.
  236. using StdMutableFst = MutableFst<StdArc>;
  237. // This is a helper class template useful for attaching a MutableFst interface
  238. // to its implementation, handling reference counting and COW semantics.
  239. template <class Impl, class FST = MutableFst<typename Impl::Arc>>
  240. class ImplToMutableFst : public ImplToExpandedFst<Impl, FST> {
  241. public:
  242. using Arc = typename Impl::Arc;
  243. using StateId = typename Arc::StateId;
  244. using Weight = typename Arc::Weight;
  245. using ImplToExpandedFst<Impl, FST>::operator=;
  246. void SetStart(StateId s) override {
  247. MutateCheck();
  248. GetMutableImpl()->SetStart(s);
  249. }
  250. void SetFinal(StateId s, Weight weight = Weight::One()) override {
  251. MutateCheck();
  252. GetMutableImpl()->SetFinal(s, std::move(weight));
  253. }
  254. void SetProperties(uint64_t props, uint64_t mask) override {
  255. // Can skip mutate check if extrinsic properties don't change,
  256. // since it is then safe to update all (shallow) copies
  257. const auto exprops = kExtrinsicProperties & mask;
  258. if (GetImpl()->Properties(exprops) != (props & exprops)) MutateCheck();
  259. GetMutableImpl()->SetProperties(props, mask);
  260. }
  261. StateId AddState() override {
  262. MutateCheck();
  263. return GetMutableImpl()->AddState();
  264. }
  265. void AddStates(size_t n) override {
  266. MutateCheck();
  267. return GetMutableImpl()->AddStates(n);
  268. }
  269. void AddArc(StateId s, const Arc &arc) override {
  270. MutateCheck();
  271. GetMutableImpl()->AddArc(s, arc);
  272. }
  273. void AddArc(StateId s, Arc &&arc) override {
  274. MutateCheck();
  275. GetMutableImpl()->AddArc(s, std::forward<Arc>(arc));
  276. }
  277. void DeleteStates(const std::vector<StateId> &dstates) override {
  278. MutateCheck();
  279. GetMutableImpl()->DeleteStates(dstates);
  280. }
  281. void DeleteStates() override {
  282. if (!Unique()) {
  283. const auto *isymbols = GetImpl()->InputSymbols();
  284. const auto *osymbols = GetImpl()->OutputSymbols();
  285. SetImpl(std::make_shared<Impl>());
  286. GetMutableImpl()->SetInputSymbols(isymbols);
  287. GetMutableImpl()->SetOutputSymbols(osymbols);
  288. } else {
  289. GetMutableImpl()->DeleteStates();
  290. }
  291. }
  292. void DeleteArcs(StateId s, size_t n) override {
  293. MutateCheck();
  294. GetMutableImpl()->DeleteArcs(s, n);
  295. }
  296. void DeleteArcs(StateId s) override {
  297. MutateCheck();
  298. GetMutableImpl()->DeleteArcs(s);
  299. }
  300. void ReserveStates(size_t n) override {
  301. MutateCheck();
  302. GetMutableImpl()->ReserveStates(n);
  303. }
  304. void ReserveArcs(StateId s, size_t n) override {
  305. MutateCheck();
  306. GetMutableImpl()->ReserveArcs(s, n);
  307. }
  308. const SymbolTable *InputSymbols() const override {
  309. return GetImpl()->InputSymbols();
  310. }
  311. const SymbolTable *OutputSymbols() const override {
  312. return GetImpl()->OutputSymbols();
  313. }
  314. SymbolTable *MutableInputSymbols() override {
  315. MutateCheck();
  316. return GetMutableImpl()->InputSymbols();
  317. }
  318. SymbolTable *MutableOutputSymbols() override {
  319. MutateCheck();
  320. return GetMutableImpl()->OutputSymbols();
  321. }
  322. void SetInputSymbols(const SymbolTable *isyms) override {
  323. MutateCheck();
  324. GetMutableImpl()->SetInputSymbols(isyms);
  325. }
  326. void SetOutputSymbols(const SymbolTable *osyms) override {
  327. MutateCheck();
  328. GetMutableImpl()->SetOutputSymbols(osyms);
  329. }
  330. protected:
  331. using ImplToExpandedFst<Impl, FST>::GetImpl;
  332. using ImplToExpandedFst<Impl, FST>::GetMutableImpl;
  333. using ImplToExpandedFst<Impl, FST>::Unique;
  334. using ImplToExpandedFst<Impl, FST>::SetImpl;
  335. using ImplToExpandedFst<Impl, FST>::InputSymbols;
  336. explicit ImplToMutableFst(std::shared_ptr<Impl> impl)
  337. : ImplToExpandedFst<Impl, FST>(impl) {}
  338. ImplToMutableFst(const ImplToMutableFst &fst, bool safe)
  339. : ImplToExpandedFst<Impl, FST>(fst, safe) {}
  340. void MutateCheck() {
  341. if (!Unique()) SetImpl(std::make_shared<Impl>(*this));
  342. }
  343. };
  344. } // namespace fst
  345. #endif // FST_MUTABLE_FST_H_