You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

477 lines
16 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Simple concrete immutable FST whose states and arcs are each stored in
  19. // single arrays.
  20. #ifndef FST_CONST_FST_H_
  21. #define FST_CONST_FST_H_
  22. #include <climits>
  23. #include <cstddef>
  24. #include <cstdint>
  25. #include <ios>
  26. #include <istream>
  27. #include <memory>
  28. #include <ostream>
  29. #include <string>
  30. #include <vector>
  31. #include <fst/log.h>
  32. #include <fst/arc.h>
  33. #include <fst/expanded-fst.h>
  34. #include <fst/float-weight.h>
  35. #include <fst/fst-decl.h>
  36. #include <fst/fst.h>
  37. #include <fst/impl-to-fst.h>
  38. #include <fst/mapped-file.h>
  39. #include <fst/properties.h>
  40. #include <fst/test-properties.h>
  41. #include <fst/util.h>
  42. #include <string_view>
  43. namespace fst {
  44. template <class A, class Unsigned>
  45. class ConstFst;
  46. template <class F, class G>
  47. void Cast(const F &, G *);
  48. namespace internal {
  49. // States and arcs each implemented by single arrays, templated on the
  50. // Arc definition. Unsigned is used to represent indices into the arc array.
  51. template <class A, class Unsigned>
  52. class ConstFstImpl : public FstImpl<A> {
  53. public:
  54. using Arc = A;
  55. using StateId = typename Arc::StateId;
  56. using Weight = typename Arc::Weight;
  57. using FstImpl<A>::SetInputSymbols;
  58. using FstImpl<A>::SetOutputSymbols;
  59. using FstImpl<A>::SetType;
  60. using FstImpl<A>::SetProperties;
  61. using FstImpl<A>::Properties;
  62. ConstFstImpl() {
  63. std::string type = "const";
  64. if (sizeof(Unsigned) != sizeof(uint32_t)) {
  65. type += std::to_string(CHAR_BIT * sizeof(Unsigned));
  66. }
  67. SetType(type);
  68. SetProperties(kNullProperties | kStaticProperties);
  69. }
  70. explicit ConstFstImpl(const Fst<Arc> &fst);
  71. StateId Start() const { return start_; }
  72. Weight Final(StateId s) const { return states_[s].final_weight; }
  73. StateId NumStates() const { return nstates_; }
  74. size_t NumArcs(StateId s) const { return states_[s].narcs; }
  75. size_t NumInputEpsilons(StateId s) const { return states_[s].niepsilons; }
  76. size_t NumOutputEpsilons(StateId s) const { return states_[s].noepsilons; }
  77. static ConstFstImpl *Read(std::istream &strm, const FstReadOptions &opts);
  78. const Arc *Arcs(StateId s) const { return arcs_ + states_[s].pos; }
  79. // Provide information needed for generic state iterator.
  80. void InitStateIterator(StateIteratorData<Arc> *data) const {
  81. data->base = nullptr;
  82. data->nstates = nstates_;
  83. }
  84. // Provide information needed for the generic arc iterator.
  85. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
  86. data->base = nullptr;
  87. data->arcs = arcs_ + states_[s].pos;
  88. data->narcs = states_[s].narcs;
  89. data->ref_count = nullptr;
  90. }
  91. private:
  92. // Used to find narcs_ and nstates_ in Write.
  93. friend class ConstFst<Arc, Unsigned>;
  94. // States implemented by array *states_ below, arcs by (single) *arcs_.
  95. struct ConstState {
  96. Weight final_weight; // Final weight.
  97. Unsigned pos; // Start of state's arcs in *arcs_.
  98. Unsigned narcs; // Number of arcs (per state).
  99. Unsigned niepsilons; // Number of input epsilons.
  100. Unsigned noepsilons; // Number of output epsilons.
  101. ConstState() : final_weight(Weight::Zero()) {}
  102. };
  103. // Properties always true of this FST class.
  104. static constexpr uint64_t kStaticProperties = kExpanded;
  105. // Current unaligned file format version. The unaligned version was added and
  106. // made the default since the aligned version does not work on pipes.
  107. static constexpr int kFileVersion = 2;
  108. // Current aligned file format version.
  109. static constexpr int kAlignedFileVersion = 1;
  110. // Minimum file format version supported.
  111. static constexpr int kMinFileVersion = 1;
  112. std::unique_ptr<MappedFile> states_region_; // Mapped file for states.
  113. std::unique_ptr<MappedFile> arcs_region_; // Mapped file for arcs.
  114. ConstState *states_ = nullptr; // States representation.
  115. Arc *arcs_ = nullptr; // Arcs representation.
  116. size_t narcs_ = 0; // Number of arcs.
  117. StateId nstates_ = 0; // Number of states.
  118. StateId start_ = kNoStateId; // Initial state.
  119. ConstFstImpl(const ConstFstImpl &) = delete;
  120. ConstFstImpl &operator=(const ConstFstImpl &) = delete;
  121. };
  122. template <class Arc, class Unsigned>
  123. ConstFstImpl<Arc, Unsigned>::ConstFstImpl(const Fst<Arc> &fst) {
  124. std::string type = "const";
  125. if (sizeof(Unsigned) != sizeof(uint32_t)) {
  126. type += std::to_string(CHAR_BIT * sizeof(Unsigned));
  127. }
  128. SetType(type);
  129. SetInputSymbols(fst.InputSymbols());
  130. SetOutputSymbols(fst.OutputSymbols());
  131. start_ = fst.Start();
  132. // Counts states and arcs.
  133. for (StateIterator<Fst<Arc>> siter(fst); !siter.Done(); siter.Next()) {
  134. ++nstates_;
  135. narcs_ += fst.NumArcs(siter.Value());
  136. }
  137. states_region_.reset(MappedFile::AllocateType<ConstState>(nstates_));
  138. arcs_region_.reset(MappedFile::AllocateType<Arc>(narcs_));
  139. states_ = static_cast<ConstState *>(states_region_->mutable_data());
  140. arcs_ = static_cast<Arc *>(arcs_region_->mutable_data());
  141. size_t pos = 0;
  142. for (StateId s = 0; s < nstates_; ++s) {
  143. states_[s].final_weight = fst.Final(s);
  144. states_[s].pos = pos;
  145. states_[s].narcs = 0;
  146. states_[s].niepsilons = 0;
  147. states_[s].noepsilons = 0;
  148. for (ArcIterator<Fst<Arc>> aiter(fst, s); !aiter.Done(); aiter.Next()) {
  149. const auto &arc = aiter.Value();
  150. ++states_[s].narcs;
  151. if (arc.ilabel == 0) ++states_[s].niepsilons;
  152. if (arc.olabel == 0) ++states_[s].noepsilons;
  153. arcs_[pos] = arc;
  154. ++pos;
  155. }
  156. }
  157. const auto props =
  158. fst.Properties(kMutable, false)
  159. ? fst.Properties(kCopyProperties, true)
  160. : CheckProperties(
  161. fst, kCopyProperties & ~kWeightedCycles & ~kUnweightedCycles,
  162. kCopyProperties);
  163. SetProperties(props | kStaticProperties);
  164. }
  165. template <class Arc, class Unsigned>
  166. ConstFstImpl<Arc, Unsigned> *ConstFstImpl<Arc, Unsigned>::Read(
  167. std::istream &strm, const FstReadOptions &opts) {
  168. auto impl = std::make_unique<ConstFstImpl>();
  169. FstHeader hdr;
  170. if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return nullptr;
  171. impl->start_ = hdr.Start();
  172. impl->nstates_ = hdr.NumStates();
  173. impl->narcs_ = hdr.NumArcs();
  174. // Ensures compatibility.
  175. if (hdr.Version() == kAlignedFileVersion) {
  176. hdr.SetFlags(hdr.GetFlags() | FstHeader::IS_ALIGNED);
  177. }
  178. if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) {
  179. LOG(ERROR) << "ConstFst::Read: Alignment failed: " << opts.source;
  180. return nullptr;
  181. }
  182. size_t b = impl->nstates_ * sizeof(ConstState);
  183. impl->states_region_.reset(
  184. MappedFile::Map(strm, opts.mode == FstReadOptions::MAP, opts.source, b));
  185. if (!strm || !impl->states_region_) {
  186. LOG(ERROR) << "ConstFst::Read: Read failed: " << opts.source;
  187. return nullptr;
  188. }
  189. impl->states_ =
  190. static_cast<ConstState *>(impl->states_region_->mutable_data());
  191. if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) {
  192. LOG(ERROR) << "ConstFst::Read: Alignment failed: " << opts.source;
  193. return nullptr;
  194. }
  195. b = impl->narcs_ * sizeof(Arc);
  196. impl->arcs_region_.reset(
  197. MappedFile::Map(strm, opts.mode == FstReadOptions::MAP, opts.source, b));
  198. if (!strm || !impl->arcs_region_) {
  199. LOG(ERROR) << "ConstFst::Read: Read failed: " << opts.source;
  200. return nullptr;
  201. }
  202. impl->arcs_ = static_cast<Arc *>(impl->arcs_region_->mutable_data());
  203. return impl.release();
  204. }
  205. } // namespace internal
  206. // Simple concrete immutable FST. This class attaches interface to
  207. // implementation and handles reference counting, delegating most methods to
  208. // ImplToExpandedFst. The unsigned type U is used to represent indices into the
  209. // arc array (default declared in fst-decl.h).
  210. //
  211. // ConstFst is thread-safe.
  212. template <class A, class Unsigned>
  213. class ConstFst : public ImplToExpandedFst<internal::ConstFstImpl<A, Unsigned>> {
  214. public:
  215. using Arc = A;
  216. using StateId = typename Arc::StateId;
  217. using Impl = internal::ConstFstImpl<A, Unsigned>;
  218. using ConstState = typename Impl::ConstState;
  219. friend class StateIterator<ConstFst<Arc, Unsigned>>;
  220. friend class ArcIterator<ConstFst<Arc, Unsigned>>;
  221. template <class F, class G>
  222. void friend Cast(const F &, G *);
  223. ConstFst() : ImplToExpandedFst<Impl>(std::make_shared<Impl>()) {}
  224. explicit ConstFst(const Fst<Arc> &fst)
  225. : ImplToExpandedFst<Impl>(std::make_shared<Impl>(fst)) {}
  226. ConstFst(const ConstFst &fst, bool unused_safe = false)
  227. : ImplToExpandedFst<Impl>(fst.GetSharedImpl()) {}
  228. // Gets a copy of this ConstFst. See Fst<>::Copy() for further doc.
  229. ConstFst *Copy(bool safe = false) const override {
  230. return new ConstFst(*this, safe);
  231. }
  232. // Reads a ConstFst from an input stream, returning nullptr on error.
  233. static ConstFst *Read(std::istream &strm, const FstReadOptions &opts) {
  234. auto *impl = Impl::Read(strm, opts);
  235. return impl ? new ConstFst(std::shared_ptr<Impl>(impl)) : nullptr;
  236. }
  237. // Read a ConstFst from a file; return nullptr on error; empty source reads
  238. // from standard input.
  239. static ConstFst *Read(std::string_view source) {
  240. auto *impl = ImplToExpandedFst<Impl>::Read(source);
  241. return impl ? new ConstFst(std::shared_ptr<Impl>(impl)) : nullptr;
  242. }
  243. bool Write(std::ostream &strm, const FstWriteOptions &opts) const override {
  244. return WriteFst(*this, strm, opts);
  245. }
  246. bool Write(const std::string &source) const override {
  247. return Fst<Arc>::WriteFile(source);
  248. }
  249. template <class FST>
  250. static bool WriteFst(const FST &fst, std::ostream &strm,
  251. const FstWriteOptions &opts);
  252. void InitStateIterator(StateIteratorData<Arc> *data) const override {
  253. GetImpl()->InitStateIterator(data);
  254. }
  255. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
  256. GetImpl()->InitArcIterator(s, data);
  257. }
  258. private:
  259. explicit ConstFst(std::shared_ptr<Impl> impl)
  260. : ImplToExpandedFst<Impl>(impl) {}
  261. using ImplToFst<Impl, ExpandedFst<Arc>>::GetImpl;
  262. // Uses overloading to extract the type of the argument.
  263. static const Impl *GetImplIfConstFst(const ConstFst &const_fst) {
  264. return const_fst.GetImpl();
  265. }
  266. // NB: this does not give privileged treatment to subtypes of ConstFst.
  267. template <typename FST>
  268. static Impl *GetImplIfConstFst(const FST &fst) {
  269. return nullptr;
  270. }
  271. ConstFst &operator=(const ConstFst &) = delete;
  272. };
  273. // Writes FST in Const format, potentially with a pass over the machine before
  274. // writing to compute number of states and arcs.
  275. template <class Arc, class Unsigned>
  276. template <class FST>
  277. bool ConstFst<Arc, Unsigned>::WriteFst(const FST &fst, std::ostream &strm,
  278. const FstWriteOptions &opts) {
  279. const auto file_version =
  280. opts.align ? internal::ConstFstImpl<Arc, Unsigned>::kAlignedFileVersion
  281. : internal::ConstFstImpl<Arc, Unsigned>::kFileVersion;
  282. size_t num_arcs = 0; // To silence -Wsometimes-uninitialized warnings.
  283. size_t num_states = 0; // Ditto.
  284. std::streamoff start_offset = 0;
  285. bool update_header = true;
  286. if (const auto *impl = GetImplIfConstFst(fst)) {
  287. num_arcs = impl->narcs_;
  288. num_states = impl->nstates_;
  289. update_header = false;
  290. } else if (opts.stream_write || (start_offset = strm.tellp()) == -1) {
  291. // precompute values needed for header when we cannot seek to rewrite it.
  292. num_arcs = 0;
  293. num_states = 0;
  294. for (StateIterator<FST> siter(fst); !siter.Done(); siter.Next()) {
  295. num_arcs += fst.NumArcs(siter.Value());
  296. ++num_states;
  297. }
  298. update_header = false;
  299. }
  300. FstHeader hdr;
  301. hdr.SetStart(fst.Start());
  302. hdr.SetNumStates(num_states);
  303. hdr.SetNumArcs(num_arcs);
  304. std::string type = "const";
  305. if (sizeof(Unsigned) != sizeof(uint32_t)) {
  306. type += std::to_string(CHAR_BIT * sizeof(Unsigned));
  307. }
  308. const auto properties =
  309. fst.Properties(kCopyProperties, true) |
  310. internal::ConstFstImpl<Arc, Unsigned>::kStaticProperties;
  311. internal::FstImpl<Arc>::WriteFstHeader(fst, strm, opts, file_version, type,
  312. properties, &hdr);
  313. if (opts.align && !AlignOutput(strm)) {
  314. LOG(ERROR) << "Could not align file during write after header";
  315. return false;
  316. }
  317. size_t pos = 0;
  318. size_t states = 0;
  319. ConstState state;
  320. for (StateIterator<FST> siter(fst); !siter.Done(); siter.Next()) {
  321. const auto s = siter.Value();
  322. state.final_weight = fst.Final(s);
  323. state.pos = pos;
  324. state.narcs = fst.NumArcs(s);
  325. state.niepsilons = fst.NumInputEpsilons(s);
  326. state.noepsilons = fst.NumOutputEpsilons(s);
  327. strm.write(reinterpret_cast<const char *>(&state), sizeof(state));
  328. pos += state.narcs;
  329. ++states;
  330. }
  331. hdr.SetNumStates(states);
  332. hdr.SetNumArcs(pos);
  333. if (opts.align && !AlignOutput(strm)) {
  334. LOG(ERROR) << "Could not align file during write after writing states";
  335. }
  336. for (StateIterator<FST> siter(fst); !siter.Done(); siter.Next()) {
  337. for (ArcIterator<FST> aiter(fst, siter.Value()); !aiter.Done();
  338. aiter.Next()) {
  339. const auto &arc = aiter.Value();
  340. strm.write(reinterpret_cast<const char *>(&arc), sizeof(arc));
  341. }
  342. }
  343. strm.flush();
  344. if (!strm) {
  345. LOG(ERROR) << "ConstFst::WriteFst: Write failed: " << opts.source;
  346. return false;
  347. }
  348. if (update_header) {
  349. return internal::FstImpl<Arc>::UpdateFstHeader(
  350. fst, strm, opts, file_version, type, properties, &hdr, start_offset);
  351. } else {
  352. if (hdr.NumStates() != num_states) {
  353. LOG(ERROR) << "Inconsistent number of states observed during write";
  354. return false;
  355. }
  356. if (hdr.NumArcs() != num_arcs) {
  357. LOG(ERROR) << "Inconsistent number of arcs observed during write";
  358. return false;
  359. }
  360. }
  361. return true;
  362. }
  363. // Specialization for ConstFst; see generic version in fst.h for sample usage
  364. // (but use the ConstFst type instead). This version should inline.
  365. template <class Arc, class Unsigned>
  366. class StateIterator<ConstFst<Arc, Unsigned>> {
  367. public:
  368. using StateId = typename Arc::StateId;
  369. explicit StateIterator(const ConstFst<Arc, Unsigned> &fst)
  370. : nstates_(fst.GetImpl()->NumStates()), s_(0) {}
  371. bool Done() const { return s_ >= nstates_; }
  372. StateId Value() const { return s_; }
  373. void Next() { ++s_; }
  374. void Reset() { s_ = 0; }
  375. private:
  376. const StateId nstates_;
  377. StateId s_;
  378. };
  379. // Specialization for ConstFst; see generic version in fst.h for sample usage
  380. // (but use the ConstFst type instead). This version should inline.
  381. template <class Arc, class Unsigned>
  382. class ArcIterator<ConstFst<Arc, Unsigned>> {
  383. public:
  384. using StateId = typename Arc::StateId;
  385. ArcIterator(const ConstFst<Arc, Unsigned> &fst, StateId s)
  386. : arcs_(fst.GetImpl()->Arcs(s)),
  387. narcs_(fst.GetImpl()->NumArcs(s)),
  388. i_(0) {}
  389. bool Done() const { return i_ >= narcs_; }
  390. const Arc &Value() const { return arcs_[i_]; }
  391. void Next() { ++i_; }
  392. size_t Position() const { return i_; }
  393. void Reset() { i_ = 0; }
  394. void Seek(size_t a) { i_ = a; }
  395. constexpr uint8_t Flags() const { return kArcValueFlags; }
  396. void SetFlags(uint8_t, uint8_t) {}
  397. private:
  398. const Arc *arcs_;
  399. size_t narcs_;
  400. size_t i_;
  401. };
  402. // A useful alias when using StdArc.
  403. using StdConstFst = ConstFst<StdArc>;
  404. } // namespace fst
  405. #endif // FST_CONST_FST_H_