You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

978 lines
32 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // FST abstract base class definition, state and arc iterator interface, and
  19. // suggested base implementation.
  20. #ifndef FST_FST_H_
  21. #define FST_FST_H_
  22. #include <sys/types.h>
  23. #include <atomic>
  24. #include <cmath>
  25. #include <cstddef>
  26. #include <cstdint>
  27. #include <ios>
  28. #include <iostream>
  29. #include <istream>
  30. #include <memory>
  31. #include <optional>
  32. #include <ostream>
  33. #include <sstream>
  34. #include <string>
  35. #include <utility>
  36. #include <fst/compat.h>
  37. #include <fst/flags.h>
  38. #include <fst/log.h>
  39. #include <fst/arc.h>
  40. #include <fstream>
  41. #include <fst/memory.h>
  42. #include <fst/properties.h>
  43. #include <fst/register.h>
  44. #include <fst/symbol-table.h>
  45. #include <fst/util.h>
  46. #include <string_view>
  47. DECLARE_bool(fst_align);
  48. namespace fst {
  // Identifies stream data as an FST (and its endianity).
  inline constexpr int32_t kFstMagicNumber = 2125659606;
  // Forward declarations. FstHeader, ArcIteratorData and StateIteratorData are
  // defined further down in this file; MatcherBase is not defined in this file
  // (presumably it comes from the matcher header — confirm).
  class FstHeader;
  template <class Arc>
  class MatcherBase;
  template <class Arc>
  struct ArcIteratorData;
  template <class Arc>
  struct StateIteratorData;
  58. struct FstReadOptions {
  59. // FileReadMode(s) are advisory, there are many conditions than prevent a
  60. // file from being mapped, READ mode will be selected in these cases with
  61. // a warning indicating why it was chosen.
  62. enum FileReadMode { READ, MAP };
  63. std::string source; // Where you're reading from.
  64. const FstHeader *header; // Pointer to FST header; if non-zero, use
  65. // this info (don't read a stream header).
  66. const SymbolTable *isymbols; // Pointer to input symbols; if non-zero, use
  67. // this info (read and skip stream isymbols)
  68. const SymbolTable *osymbols; // Pointer to output symbols; if non-zero, use
  69. // this info (read and skip stream osymbols)
  70. FileReadMode mode; // Read or map files (advisory, if possible)
  71. bool read_isymbols; // Read isymbols, if any (default: true).
  72. bool read_osymbols; // Read osymbols, if any (default: true).
  73. explicit FstReadOptions(
  74. const std::string_view source = "<unspecified>",
  75. const FstHeader * header = nullptr,
  76. const SymbolTable * isymbols = nullptr,
  77. const SymbolTable * osymbols = nullptr);
  78. explicit FstReadOptions(const std::string_view source,
  79. const SymbolTable *isymbols,
  80. const SymbolTable *osymbols = nullptr);
  81. // Helper function to convert strings FileReadModes into their enum value.
  82. static FileReadMode ReadMode(std::string_view mode);
  83. // Outputs a debug string for the FstReadOptions object.
  84. std::string DebugString() const;
  85. };
  86. struct FstWriteOptions {
  87. std::string source; // Where you're writing to.
  88. bool write_header; // Write the header?
  89. bool write_isymbols; // Write input symbols?
  90. bool write_osymbols; // Write output symbols?
  91. bool align; // Write data aligned (may fail on pipes)?
  92. bool stream_write; // Avoid seek operations in writing.
  93. explicit FstWriteOptions(std::string_view source = "<unspecified>",
  94. bool write_header = true, bool write_isymbols = true,
  95. bool write_osymbols = true,
  96. bool align = FST_FLAGS_fst_align,
  97. bool stream_write = false)
  98. : source(source),
  99. write_header(write_header),
  100. write_isymbols(write_isymbols),
  101. write_osymbols(write_osymbols),
  102. align(align),
  103. stream_write(stream_write) {}
  104. };
  // Header class.
  //
  // This is the recommended file header representation: a plain record of the
  // FST type, arc type, version, flag bits, property bits and sizes, with
  // Read()/Write() (declared here, defined elsewhere) for serialization.
  class FstHeader {
   public:
    // Bits stored in the flags field (see GetFlags()/SetFlags()).
    enum Flags {
      HAS_ISYMBOLS = 0x1,  // Has input symbol table.
      HAS_OSYMBOLS = 0x2,  // Has output symbol table.
      IS_ALIGNED = 0x4,    // Memory-aligned (where appropriate).
    };

    FstHeader() = default;

    // Accessors.
    const std::string &FstType() const { return fsttype_; }
    const std::string &ArcType() const { return arctype_; }
    int32_t Version() const { return version_; }
    uint32_t GetFlags() const { return flags_; }
    uint64_t Properties() const { return properties_; }
    int64_t Start() const { return start_; }
    int64_t NumStates() const { return numstates_; }
    int64_t NumArcs() const { return numarcs_; }

    // Mutators.
    void SetFstType(std::string_view type) { fsttype_.assign(type); }
    void SetArcType(std::string_view type) { arctype_.assign(type); }
    void SetVersion(int32_t version) { version_ = version; }
    void SetFlags(uint32_t flags) { flags_ = flags; }
    void SetProperties(uint64_t properties) { properties_ = properties; }
    void SetStart(int64_t start) { start_ = start; }
    void SetNumStates(int64_t numstates) { numstates_ = numstates; }
    void SetNumArcs(int64_t numarcs) { numarcs_ = numarcs; }

    // Serialization.
    bool Read(std::istream &strm, const std::string &source, bool rewind = false);
    bool Write(std::ostream &strm, std::string_view source) const;

    // Outputs a debug string for the FstHeader object.
    std::string DebugString() const;

   private:
    std::string fsttype_;      // E.g. "vector".
    std::string arctype_;      // E.g. "standard".
    int32_t version_ = 0;      // Type version number.
    uint32_t flags_ = 0;       // File format bits.
    uint64_t properties_ = 0;  // FST property bits.
    int64_t start_ = -1;       // Start state.
    int64_t numstates_ = 0;    // # of states.
    int64_t numarcs_ = 0;      // # of arcs.
  };
  // Specifies matcher action.
  enum MatchType {
    MATCH_INPUT = 1,                          // Match input label.
    MATCH_OUTPUT = 2,                         // Match output label.
    MATCH_BOTH = MATCH_INPUT | MATCH_OUTPUT,  // Match input or output label.
    MATCH_NONE = 4,                           // Match nothing.
    MATCH_UNKNOWN = 5,                        // Otherwise, match type unknown.
  };

  // Sentinel label: not a valid label.
  inline constexpr int kNoLabel = -1;
  // Sentinel state ID: not a valid state ID.
  inline constexpr int kNoStateId = -1;
  156. // A generic FST, templated on the arc definition, with common-demoninator
  157. // methods (use StateIterator and ArcIterator to iterate over its states and
  158. // arcs). Derived classes should be assumed to be thread-unsafe unless
  159. // otherwise specified.
  160. template <class A>
  161. class Fst {
  162. public:
  163. using Arc = A;
  164. using StateId = typename Arc::StateId;
  165. using Weight = typename Arc::Weight;
  166. virtual ~Fst() = default;
  167. // Initial state.
  168. virtual StateId Start() const = 0;
  169. // State's final weight.
  170. virtual Weight Final(StateId) const = 0;
  171. // State's arc count.
  172. virtual size_t NumArcs(StateId) const = 0;
  173. // State's input epsilon count.
  174. virtual size_t NumInputEpsilons(StateId) const = 0;
  175. // State's output epsilon count.
  176. virtual size_t NumOutputEpsilons(StateId) const = 0;
  177. // Returns the number of states if it is finite and can be computed in O(1)
  178. // time. Otherwise returns nullopt.
  179. virtual std::optional<StateId> NumStatesIfKnown() const {
  180. return std::nullopt;
  181. }
  182. // Property bits. If test = false, return stored properties bits for mask
  183. // (some possibly unknown); if test = true, return property bits for mask
  184. // (computing o.w. unknown).
  185. virtual uint64_t Properties(uint64_t mask, bool test) const = 0;
  186. // FST type name.
  187. virtual const std::string &Type() const = 0;
  188. // Gets a copy of this Fst. The copying behaves as follows:
  189. //
  190. // (1) The copying is constant time if safe = false or if safe = true
  191. // and is on an otherwise unaccessed FST.
  192. //
  193. // (2) If safe = true, the copy is thread-safe in that the original
  194. // and copy can be safely accessed (but not necessarily mutated) by
  195. // separate threads. For some FST types, 'Copy(true)' should only be
  196. // called on an FST that has not otherwise been accessed. Behavior is
  197. // otherwise undefined.
  198. //
  199. // (3) If a MutableFst is copied and then mutated, then the original is
  200. // unmodified and vice versa (often by a copy-on-write on the initial
  201. // mutation, which may not be constant time).
  202. virtual Fst *Copy(bool safe = false) const = 0;
  203. // Reads an FST from an input stream; returns nullptr on error.
  204. static Fst *Read(std::istream &strm, const FstReadOptions &opts) {
  205. FstReadOptions ropts(opts);
  206. FstHeader hdr;
  207. if (ropts.header) {
  208. hdr = *opts.header;
  209. } else {
  210. if (!hdr.Read(strm, opts.source)) return nullptr;
  211. ropts.header = &hdr;
  212. }
  213. const auto &fst_type = hdr.FstType();
  214. const auto reader = FstRegister<Arc>::GetRegister()->GetReader(fst_type);
  215. if (!reader) {
  216. LOG(ERROR) << "Fst::Read: Unknown FST type " << fst_type
  217. << " (arc type = " << Arc::Type() << "): " << ropts.source;
  218. return nullptr;
  219. }
  220. return reader(strm, ropts);
  221. }
  222. // Reads an FST from a file; returns nullptr on error. An empty source
  223. // results in reading from standard input.
  224. static Fst *Read(const std::string &source) {
  225. if (!source.empty()) {
  226. std::ifstream strm(source,
  227. std::ios_base::in | std::ios_base::binary);
  228. if (!strm) {
  229. LOG(ERROR) << "Fst::Read: Can't open file: " << source;
  230. return nullptr;
  231. }
  232. return Read(strm, FstReadOptions(source));
  233. } else {
  234. return Read(std::cin, FstReadOptions("standard input"));
  235. }
  236. }
  237. // Writes an FST to an output stream; returns false on error.
  238. virtual bool Write(std::ostream &strm, const FstWriteOptions &opts) const {
  239. LOG(ERROR) << "Fst::Write: No write stream method for " << Type()
  240. << " FST type";
  241. return false;
  242. }
  243. // Writes an FST to a file; returns false on error; an empty source
  244. // results in writing to standard output.
  245. virtual bool Write(const std::string &source) const {
  246. LOG(ERROR) << "Fst::Write: No write source method for " << Type()
  247. << " FST type";
  248. return false;
  249. }
  250. // Some Fst implementations support
  251. // template <class Fst2>
  252. // static bool Fst1::WriteFst(const Fst2 &fst2, ...);
  253. // which is equivalent to Fst1(fst2).Write(...), but uses less memory.
  254. // WriteFst is not part of the general Fst interface.
  255. // Returns input label symbol table; return nullptr if not specified.
  256. virtual const SymbolTable *InputSymbols() const = 0;
  257. // Return output label symbol table; return nullptr if not specified.
  258. virtual const SymbolTable *OutputSymbols() const = 0;
  259. // For generic state iterator construction (not normally called directly by
  260. // users). Does not copy the FST.
  261. virtual void InitStateIterator(StateIteratorData<Arc> *data) const = 0;
  262. // For generic arc iterator construction (not normally called directly by
  263. // users). Does not copy the FST.
  264. virtual void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const = 0;
  265. // For generic matcher construction (not normally called directly by users).
  266. // Does not copy the FST.
  267. virtual MatcherBase<Arc> *InitMatcher(MatchType match_type) const;
  268. protected:
  269. bool WriteFile(const std::string &source) const {
  270. if (!source.empty()) {
  271. std::ofstream strm(source,
  272. std::ios_base::out | std::ios_base::binary);
  273. if (!strm) {
  274. LOG(ERROR) << "Fst::WriteFile: Can't open file: " << source;
  275. return false;
  276. }
  277. if (!Write(strm, FstWriteOptions(source))) {
  278. LOG(ERROR) << "Fst::WriteFile: Write failed: " << source;
  279. return false;
  280. }
  281. return true;
  282. } else {
  283. return Write(std::cout, FstWriteOptions("standard output"));
  284. }
  285. }
  286. };
  287. // A useful alias when using StdArc.
  288. using StdFst = Fst<StdArc>;
  // State and arc iterator definitions.
  //
  // State iterator interface templated on the Arc definition; used for
  // StateIterator specializations that the InitStateIterator FST method
  // installs into StateIteratorData::base.
  template <class Arc>
  class StateIteratorBase {
   public:
    using StateId = typename Arc::StateId;
    virtual ~StateIteratorBase() = default;
    // End of iterator?
    virtual bool Done() const = 0;
    // Returns current state (when !Done()).
    virtual StateId Value() const = 0;
    // Advances to next state (when !Done()).
    virtual void Next() = 0;
    // Resets to initial condition.
    virtual void Reset() = 0;
  };
  307. // StateIterator initialization data.
  308. template <class Arc>
  309. struct StateIteratorData {
  310. using StateId = typename Arc::StateId;
  311. // Specialized iterator if non-null.
  312. std::unique_ptr<StateIteratorBase<Arc>> base;
  313. // Otherwise, the total number of states.
  314. StateId nstates = 0;
  315. StateIteratorData() = default;
  316. StateIteratorData(const StateIteratorData &) = delete;
  317. StateIteratorData &operator=(const StateIteratorData &) = delete;
  318. };
  319. // Generic state iterator, templated on the FST definition (a wrapper
  320. // around a pointer to a specific one). Here is a typical use:
  321. //
  322. // for (StateIterator<StdFst> siter(fst);
  323. // !siter.Done();
  324. // siter.Next()) {
  325. // StateId s = siter.Value();
  326. // ...
  327. // }
  328. // There is no copying of the FST.
  329. //
  330. // Specializations may exist for some FST types.
  331. // StateIterators are thread-unsafe unless otherwise specified.
  332. template <class FST>
  333. class StateIterator {
  334. public:
  335. using Arc = typename FST::Arc;
  336. using StateId = typename Arc::StateId;
  337. explicit StateIterator(const FST &fst) {
  338. fst.InitStateIterator(&data_);
  339. }
  340. bool Done() const {
  341. return data_.base ? data_.base->Done() : s_ >= data_.nstates;
  342. }
  343. StateId Value() const { return data_.base ? data_.base->Value() : s_; }
  344. void Next() {
  345. if (data_.base) {
  346. data_.base->Next();
  347. } else {
  348. ++s_;
  349. }
  350. }
  351. void Reset() {
  352. if (data_.base) {
  353. data_.base->Reset();
  354. } else {
  355. s_ = 0;
  356. }
  357. }
  358. private:
  359. StateIteratorData<Arc> data_;
  360. StateId s_ = 0;
  361. };
  // Flags to control the behavior of an arc iterator via SetFlags(); each
  // individual flag is a single bit.
  //
  // Value() gives a valid ilabel.
  inline constexpr uint8_t kArcILabelValue = 1 << 0;
  // Value() gives a valid olabel.
  inline constexpr uint8_t kArcOLabelValue = 1 << 1;
  // Value() gives a valid weight.
  inline constexpr uint8_t kArcWeightValue = 1 << 2;
  // Value() gives a valid nextstate.
  inline constexpr uint8_t kArcNextStateValue = 1 << 3;
  // Arcs need not be cached.
  inline constexpr uint8_t kArcNoCache = 1 << 4;
  // All Value()-validity bits.
  inline constexpr uint8_t kArcValueFlags =
      kArcILabelValue | kArcOLabelValue | kArcWeightValue | kArcNextStateValue;
  // Every flag bit defined above.
  inline constexpr uint8_t kArcFlags = kArcValueFlags | kArcNoCache;
  // Arc iterator interface, templated on the arc definition; used for arc
  // iterator specializations that the InitArcIterator FST method installs into
  // ArcIteratorData::base.
  template <class Arc>
  class ArcIteratorBase {
   public:
    using StateId = typename Arc::StateId;

    virtual ~ArcIteratorBase() = default;

    // End of iterator?
    virtual bool Done() const = 0;

    // Returns current arc (when !Done()).
    virtual const Arc &Value() const = 0;

    // Advances to next arc (when !Done()).
    virtual void Next() = 0;

    // Returns current position.
    virtual size_t Position() const = 0;

    // Returns to initial condition.
    virtual void Reset() = 0;

    // Advances to arbitrary arc by position.
    virtual void Seek(size_t a) = 0;

    // Returns current behavioral flags, a bitmask of kArcFlags.
    virtual uint8_t Flags() const = 0;

    // Sets behavioral flags, a bitmask of kArcFlags; presumably only the bits
    // selected by `mask` are affected (confirm against implementations).
    virtual void SetFlags(uint8_t flags, uint8_t mask) = 0;
  };
  400. // ArcIterator initialization data.
  401. template <class Arc>
  402. struct ArcIteratorData {
  403. ArcIteratorData() = default;
  404. ArcIteratorData(const ArcIteratorData &) = delete;
  405. ArcIteratorData &operator=(const ArcIteratorData &) = delete;
  406. std::unique_ptr<ArcIteratorBase<Arc>>
  407. base; // Specialized iterator if non-null.
  408. const Arc *arcs = nullptr; // O.w. arcs pointer
  409. size_t narcs = 0; // ... and arc count.
  410. int *ref_count = nullptr; // ... and a reference count of the
  411. // `narcs`-length `arcs` array if non-null.
  412. };
  413. // Generic arc iterator, templated on the FST definition (a wrapper around a
  414. // pointer to a specific one). Here is a typical use:
  415. //
  416. // for (ArcIterator<StdFst> aiter(fst, s);
  417. // !aiter.Done();
  418. // aiter.Next()) {
  419. // StdArc &arc = aiter.Value();
  420. // ...
  421. // }
  422. // There is no copying of the FST.
  423. //
  424. // Specializations may exist for some FST types.
  425. // ArcIterators are thread-unsafe unless otherwise specified.
  426. template <class FST>
  427. class ArcIterator {
  428. public:
  429. using Arc = typename FST::Arc;
  430. using StateId = typename Arc::StateId;
  431. ArcIterator(const FST &fst, StateId s) {
  432. fst.InitArcIterator(s, &data_);
  433. }
  434. explicit ArcIterator(const ArcIteratorData<Arc> &data) = delete;
  435. ~ArcIterator() {
  436. if (data_.ref_count) {
  437. --(*data_.ref_count);
  438. }
  439. }
  440. bool Done() const {
  441. return data_.base ? data_.base->Done() : i_ >= data_.narcs;
  442. }
  443. const Arc &Value() const {
  444. return data_.base ? data_.base->Value() : data_.arcs[i_];
  445. }
  446. void Next() {
  447. if (data_.base) {
  448. data_.base->Next();
  449. } else {
  450. ++i_;
  451. }
  452. }
  453. void Reset() {
  454. if (data_.base) {
  455. data_.base->Reset();
  456. } else {
  457. i_ = 0;
  458. }
  459. }
  460. void Seek(size_t a) {
  461. if (data_.base) {
  462. data_.base->Seek(a);
  463. } else {
  464. i_ = a;
  465. }
  466. }
  467. size_t Position() const { return data_.base ? data_.base->Position() : i_; }
  468. uint8_t Flags() const {
  469. return data_.base ? data_.base->Flags() : kArcValueFlags;
  470. }
  471. void SetFlags(uint8_t flags, uint8_t mask) {
  472. if (data_.base) data_.base->SetFlags(flags, mask);
  473. }
  474. private:
  475. ArcIteratorData<Arc> data_;
  476. size_t i_ = 0;
  477. };
  478. } // namespace fst
  479. // ArcIterator placement operator new and destroy function; new needs to be in
  480. // the global namespace.
  481. template <class FST>
  482. void *operator new(size_t size,
  483. fst::MemoryPool<fst::ArcIterator<FST>> *pool) {
  484. return pool->Allocate();
  485. }
  486. namespace fst {
  487. template <class FST>
  488. void Destroy(ArcIterator<FST> *aiter, MemoryPool<ArcIterator<FST>> *pool) {
  489. if (aiter) {
  490. aiter->~ArcIterator<FST>();
  491. pool->Free(aiter);
  492. }
  493. }
  494. // Matcher definitions.
  495. template <class Arc>
  496. MatcherBase<Arc> *Fst<Arc>::InitMatcher(MatchType match_type) const {
  497. return nullptr; // One should just use the default matcher.
  498. }
  499. // FST accessors, useful in high-performance applications.
  500. namespace internal {
  501. // General case, requires non-abstract, 'final' methods. Use for inlining.
  502. template <class F>
  503. inline typename F::Arc::Weight Final(const F &fst, typename F::Arc::StateId s) {
  504. return fst.F::Final(s);
  505. }
  506. template <class F>
  507. inline ssize_t NumArcs(const F &fst, typename F::Arc::StateId s) {
  508. return fst.F::NumArcs(s);
  509. }
  510. template <class F>
  511. inline ssize_t NumInputEpsilons(const F &fst, typename F::Arc::StateId s) {
  512. return fst.F::NumInputEpsilons(s);
  513. }
  514. template <class F>
  515. inline ssize_t NumOutputEpsilons(const F &fst, typename F::Arc::StateId s) {
  516. return fst.F::NumOutputEpsilons(s);
  517. }
  518. // Fst<Arc> case, abstract methods.
  519. template <class Arc>
  520. inline typename Arc::Weight Final(const Fst<Arc> &fst,
  521. typename Arc::StateId s) {
  522. return fst.Final(s);
  523. }
  524. template <class Arc>
  525. inline size_t NumArcs(const Fst<Arc> &fst, typename Arc::StateId s) {
  526. return fst.NumArcs(s);
  527. }
  528. template <class Arc>
  529. inline size_t NumInputEpsilons(const Fst<Arc> &fst, typename Arc::StateId s) {
  530. return fst.NumInputEpsilons(s);
  531. }
  532. template <class Arc>
  533. inline size_t NumOutputEpsilons(const Fst<Arc> &fst, typename Arc::StateId s) {
  534. return fst.NumOutputEpsilons(s);
  535. }
  536. // FST implementation base.
  537. //
  538. // This is the recommended FST implementation base class. It will handle
  539. // reference counts, property bits, type information and symbols.
  540. //
  541. // Users are discouraged, but not prohibited, from subclassing this outside the
  542. // FST library.
  543. //
  544. // This class is thread-compatible except for the const SetProperties
  545. // overload. Derived classes should be assumed to be thread-unsafe unless
  546. // otherwise specified. Derived-class copy constructors must produce a
  547. // thread-safe copy.
  548. template <class Arc>
  549. class FstImpl {
  550. public:
  551. using StateId = typename Arc::StateId;
  552. using Weight = typename Arc::Weight;
  553. FstImpl() = default;
  554. FstImpl(const FstImpl<Arc> &impl)
  555. : properties_(impl.properties_.load(std::memory_order_relaxed)),
  556. type_(impl.type_),
  557. isymbols_(impl.isymbols_ ? impl.isymbols_->Copy() : nullptr),
  558. osymbols_(impl.osymbols_ ? impl.osymbols_->Copy() : nullptr) {}
  559. FstImpl(FstImpl<Arc> &&impl) noexcept;
  560. virtual ~FstImpl() = default;
  561. FstImpl &operator=(const FstImpl &impl) {
  562. properties_.store(impl.properties_.load(std::memory_order_relaxed),
  563. std::memory_order_relaxed);
  564. type_ = impl.type_;
  565. isymbols_ = impl.isymbols_ ? impl.isymbols_->Copy() : nullptr;
  566. osymbols_ = impl.osymbols_ ? impl.osymbols_->Copy() : nullptr;
  567. return *this;
  568. }
  569. FstImpl &operator=(FstImpl &&impl) noexcept;
  570. const std::string &Type() const { return type_; }
  571. void SetType(std::string_view type) { type_ = std::string(type); }
  572. virtual uint64_t Properties() const {
  573. return properties_.load(std::memory_order_relaxed);
  574. }
  575. virtual uint64_t Properties(uint64_t mask) const {
  576. return properties_.load(std::memory_order_relaxed) & mask;
  577. }
  578. void SetProperties(uint64_t props) {
  579. uint64_t properties = properties_.load(std::memory_order_relaxed);
  580. properties &= kError; // kError can't be cleared.
  581. properties |= props;
  582. properties_.store(properties, std::memory_order_relaxed);
  583. }
  584. void SetProperties(uint64_t props, uint64_t mask) {
  585. // Unlike UpdateProperties, does not require compatibility between props
  586. // and properties_, since it may be used to update properties after
  587. // a mutation.
  588. uint64_t properties = properties_.load(std::memory_order_relaxed);
  589. properties &= ~mask | kError; // kError can't be cleared.
  590. properties |= props & mask;
  591. properties_.store(properties, std::memory_order_relaxed);
  592. }
  593. // Allows (only) setting error bit on const FST implementations.
  594. void SetProperties(uint64_t props, uint64_t mask) const {
  595. if (mask != kError) {
  596. FSTERROR() << "FstImpl::SetProperties() const: Can only set kError";
  597. }
  598. properties_.fetch_or(kError, std::memory_order_relaxed);
  599. }
  600. // Sets the subset of the properties that have changed, in a thread-safe
  601. // manner via atomic bitwise-or..
  602. void UpdateProperties(uint64_t props, uint64_t mask) {
  603. // If properties_ and props are compatible (for example kAcceptor and
  604. // kNoAcceptor cannot both be set), the props can be or-ed in.
  605. // Compatibility is ensured if props comes from ComputeProperties
  606. // and properties_ is set correctly initially. However
  607. // relying on properties to be set correctly is too large an
  608. // assumption, as many places set them incorrectly.
  609. // Therefore, we or in only the newly discovered properties.
  610. // These cannot become inconsistent, but this means that
  611. // incorrectly set properties will remain incorrect.
  612. const uint64_t properties = properties_.load(std::memory_order_relaxed);
  613. DCHECK(internal::CompatProperties(properties, props));
  614. const uint64_t old_props = properties & mask;
  615. const uint64_t old_mask = internal::KnownProperties(old_props);
  616. const uint64_t discovered_mask = mask & ~old_mask;
  617. const uint64_t discovered_props = props & discovered_mask;
  618. // It is always correct to or these bits in, but do this only when
  619. // necessary to avoid extra stores and possible cache flushes.
  620. if (discovered_props != 0) {
  621. properties_.fetch_or(discovered_props, std::memory_order_relaxed);
  622. }
  623. }
  624. const SymbolTable *InputSymbols() const { return isymbols_.get(); }
  625. const SymbolTable *OutputSymbols() const { return osymbols_.get(); }
  626. SymbolTable *InputSymbols() { return isymbols_.get(); }
  627. SymbolTable *OutputSymbols() { return osymbols_.get(); }
  628. void SetInputSymbols(const SymbolTable *isyms) {
  629. isymbols_.reset(isyms ? isyms->Copy() : nullptr);
  630. }
  631. void SetOutputSymbols(const SymbolTable *osyms) {
  632. osymbols_.reset(osyms ? osyms->Copy() : nullptr);
  633. }
  634. // Reads header and symbols from input stream, initializes FST, and returns
  635. // the header. If opts.header is non-null, skips reading and uses the option
  636. // value instead. If opts.[io]symbols is non-null, reads in (if present), but
  637. // uses the option value.
  638. bool ReadHeader(std::istream &strm, const FstReadOptions &opts,
  639. int min_version, FstHeader *hdr);
  640. // Writes header and symbols to output stream. If opts.header is false, skips
  641. // writing header. If opts.[io]symbols is false, skips writing those symbols.
  642. // This method is needed for implementations that implement Write methods.
  643. void WriteHeader(std::ostream &strm, const FstWriteOptions &opts, int version,
  644. FstHeader *hdr) const {
  645. if (opts.write_header) {
  646. hdr->SetFstType(type_);
  647. hdr->SetArcType(Arc::Type());
  648. hdr->SetVersion(version);
  649. hdr->SetProperties(properties_.load(std::memory_order_relaxed));
  650. int32_t file_flags = 0;
  651. if (isymbols_ && opts.write_isymbols) {
  652. file_flags |= FstHeader::HAS_ISYMBOLS;
  653. }
  654. if (osymbols_ && opts.write_osymbols) {
  655. file_flags |= FstHeader::HAS_OSYMBOLS;
  656. }
  657. if (opts.align) file_flags |= FstHeader::IS_ALIGNED;
  658. hdr->SetFlags(file_flags);
  659. hdr->Write(strm, opts.source);
  660. }
  661. if (isymbols_ && opts.write_isymbols) isymbols_->Write(strm);
  662. if (osymbols_ && opts.write_osymbols) osymbols_->Write(strm);
  663. }
  664. // Writes out header and symbols to output stream. If opts.header is false,
  665. // skips writing header. If opts.[io]symbols is false, skips writing those
  666. // symbols. `type` is the FST type being written. This method is used in the
  667. // cross-type serialization methods Fst::WriteFst.
  668. static void WriteFstHeader(const Fst<Arc> &fst, std::ostream &strm,
  669. const FstWriteOptions &opts, int version,
  670. std::string_view type, uint64_t properties,
  671. FstHeader *hdr) {
  672. if (opts.write_header) {
  673. hdr->SetFstType(type);
  674. hdr->SetArcType(Arc::Type());
  675. hdr->SetVersion(version);
  676. hdr->SetProperties(properties);
  677. int32_t file_flags = 0;
  678. if (fst.InputSymbols() && opts.write_isymbols) {
  679. file_flags |= FstHeader::HAS_ISYMBOLS;
  680. }
  681. if (fst.OutputSymbols() && opts.write_osymbols) {
  682. file_flags |= FstHeader::HAS_OSYMBOLS;
  683. }
  684. if (opts.align) file_flags |= FstHeader::IS_ALIGNED;
  685. hdr->SetFlags(file_flags);
  686. hdr->Write(strm, opts.source);
  687. }
  688. if (fst.InputSymbols() && opts.write_isymbols) {
  689. fst.InputSymbols()->Write(strm);
  690. }
  691. if (fst.OutputSymbols() && opts.write_osymbols) {
  692. fst.OutputSymbols()->Write(strm);
  693. }
  694. }
  695. // In serialization routines where the header cannot be written until after
  696. // the machine has been serialized, this routine can be called to seek to the
  697. // beginning of the file an rewrite the header with updated fields. It
  698. // repositions the file pointer back at the end of the file. Returns true on
  699. // success, false on failure.
  700. static bool UpdateFstHeader(const Fst<Arc> &fst, std::ostream &strm,
  701. const FstWriteOptions &opts, int version,
  702. std::string_view type, uint64_t properties,
  703. FstHeader *hdr, size_t header_offset) {
  704. strm.seekp(header_offset);
  705. if (!strm) {
  706. LOG(ERROR) << "Fst::UpdateFstHeader: Write failed: " << opts.source;
  707. return false;
  708. }
  709. WriteFstHeader(fst, strm, opts, version, type, properties, hdr);
  710. if (!strm) {
  711. LOG(ERROR) << "Fst::UpdateFstHeader: Write failed: " << opts.source;
  712. return false;
  713. }
  714. strm.seekp(0, std::ios_base::end);
  715. if (!strm) {
  716. LOG(ERROR) << "Fst::UpdateFstHeader: Write failed: " << opts.source;
  717. return false;
  718. }
  719. return true;
  720. }
  721. protected:
  722. // Use atomic so that UpdateProperties() can be thread-safe.
  723. // This is always used with memory_order_relaxed because it's only used
  724. // as a cache and not used to synchronize other operations.
  725. mutable std::atomic<uint64_t> properties_ = 0; // Property bits.
  726. private:
  727. std::string type_ = "null"; // Unique name of FST class.
  728. std::unique_ptr<SymbolTable> isymbols_;
  729. std::unique_ptr<SymbolTable> osymbols_;
  730. };
  731. template <class Arc>
  732. inline FstImpl<Arc>::FstImpl(FstImpl<Arc> &&) noexcept = default;
  733. template <class Arc>
  734. inline FstImpl<Arc> &FstImpl<Arc>::operator=(FstImpl<Arc> &&) noexcept =
  735. default;
  736. template <class Arc>
  737. bool FstImpl<Arc>::ReadHeader(std::istream &strm, const FstReadOptions &opts,
  738. int min_version, FstHeader *hdr) {
  739. if (opts.header) {
  740. *hdr = *opts.header;
  741. } else if (!hdr->Read(strm, opts.source)) {
  742. return false;
  743. }
  744. VLOG(2) << "FstImpl::ReadHeader: source: " << opts.source
  745. << ", fst_type: " << hdr->FstType() << ", arc_type: " << Arc::Type()
  746. << ", version: " << hdr->Version() << ", flags: " << hdr->GetFlags();
  747. if (hdr->FstType() != type_) {
  748. LOG(ERROR) << "FstImpl::ReadHeader: FST not of type " << type_ << ", found "
  749. << hdr->FstType() << ": " << opts.source;
  750. return false;
  751. }
  752. if (hdr->ArcType() != Arc::Type()) {
  753. LOG(ERROR) << "FstImpl::ReadHeader: Arc not of type " << Arc::Type()
  754. << ", found " << hdr->ArcType() << ": " << opts.source;
  755. return false;
  756. }
  757. if (hdr->Version() < min_version) {
  758. LOG(ERROR) << "FstImpl::ReadHeader: Obsolete " << type_ << " FST version "
  759. << hdr->Version() << ", min_version=" << min_version << ": "
  760. << opts.source;
  761. return false;
  762. }
  763. properties_.store(hdr->Properties(), std::memory_order_relaxed);
  764. if (hdr->GetFlags() & FstHeader::HAS_ISYMBOLS) {
  765. isymbols_.reset(SymbolTable::Read(strm, opts.source));
  766. }
  767. // Deletes input symbol table.
  768. if (!opts.read_isymbols) SetInputSymbols(nullptr);
  769. if (hdr->GetFlags() & FstHeader::HAS_OSYMBOLS) {
  770. osymbols_.reset(SymbolTable::Read(strm, opts.source));
  771. }
  772. // Deletes output symbol table.
  773. if (!opts.read_osymbols) SetOutputSymbols(nullptr);
  774. if (opts.isymbols) {
  775. isymbols_.reset(opts.isymbols->Copy());
  776. }
  777. if (opts.osymbols) {
  778. osymbols_.reset(opts.osymbols->Copy());
  779. }
  780. return true;
  781. }
  782. } // namespace internal
  // Converts FSTs by casting their implementations, where this makes sense
  // (which excludes implementations with weight-dependent virtual methods).
  // Must be a friend of the FST classes involved (currently the concrete FSTs:
  // ConstFst, CompactFst, and VectorFst). This can only be safely used for arc
  // types that have identical storage characteristics. As with an FST
  // copy constructor and Copy() method, this is a constant time operation
  // (but subject to copy-on-write if it is a MutableFst and modified).
  //
  // The output shares ownership of the input's implementation through the
  // aliasing std::shared_ptr constructor; nothing is copied.
  template <class IFST, class OFST>
  void Cast(const IFST &ifst, OFST *ofst) {
    using OImpl = typename OFST::Impl;
    auto *const aliased = reinterpret_cast<OImpl *>(ifst.impl_.get());
    ofst->impl_ = std::shared_ptr<OImpl>(ifst.impl_, aliased);
  }
  796. // FST serialization.
  797. template <class Arc>
  798. std::string FstToString(
  799. const Fst<Arc> &fst,
  800. const FstWriteOptions &options = FstWriteOptions("FstToString")) {
  801. std::ostringstream ostrm;
  802. fst.Write(ostrm, options);
  803. return ostrm.str();
  804. }
  805. template <class Arc>
  806. void FstToString(const Fst<Arc> &fst, std::string *result) {
  807. *result = FstToString(fst);
  808. }
  809. template <class Arc>
  810. void FstToString(const Fst<Arc> &fst, std::string *result,
  811. const FstWriteOptions &options) {
  812. *result = FstToString(fst, options);
  813. }
  814. template <class Arc>
  815. Fst<Arc> *StringToFst(std::string_view s) {
  816. std::istringstream istrm((std::string(s)));
  817. return Fst<Arc>::Read(istrm, FstReadOptions("StringToFst"));
  818. }
  819. } // namespace fst
  820. #endif // FST_FST_H_