You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

828 lines
25 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Simple concrete, mutable FST whose states and arcs are stored in STL vectors.
  19. #ifndef FST_VECTOR_FST_H_
  20. #define FST_VECTOR_FST_H_
  #include <algorithm>
  #include <atomic>
  #include <cstddef>
  #include <cstdint>
  #include <ios>
  #include <iosfwd>
  #include <istream>
  #include <memory>
  #include <new>
  #include <optional>
  #include <ostream>
  #include <string>
  #include <string_view>
  #include <utility>
  #include <vector>

  #include <fst/log.h>
  #include <fst/arc.h>
  #include <fst/expanded-fst.h>
  #include <fst/float-weight.h>
  #include <fst/fst-decl.h>  // For optional argument declarations
  #include <fst/fst.h>
  #include <fst/mutable-fst.h>
  #include <fst/properties.h>
  #include <fst/util.h>
  44. namespace fst {
  // Forward declarations, needed by friend declarations further below.
  template <class ArcT, class StateT>
  class VectorFst;

  template <class FromFst, class ToFst>
  void Cast(const FromFst &, ToFst *);

  // One state of a VectorFst. It holds a final weight, cached counts of
  // input-epsilon and output-epsilon arcs, and the outgoing arcs (of type A)
  // in an STL vector. M is the allocator used for those arcs; its default is
  // declared in fst-decl.h.
  template <class A, class M /* = std::allocator<A> */>
  class VectorState {
   public:
    using Arc = A;
    using StateId = typename Arc::StateId;
    using Weight = typename Arc::Weight;
    using ArcAllocator = M;
    using StateAllocator = typename std::allocator_traits<
        ArcAllocator>::template rebind_alloc<VectorState<Arc, M>>;

    // Builds an empty, non-final state whose arc storage uses `alloc`.
    explicit VectorState(const ArcAllocator &alloc) : arcs_(alloc) {}

    // Copies the final weight, epsilon counts and arcs of `state`; the copied
    // arcs are stored using `alloc`.
    VectorState(const VectorState<A, M> &state, const ArcAllocator &alloc)
        : final_weight_(state.Final()),
          niepsilons_(state.NumInputEpsilons()),
          noepsilons_(state.NumOutputEpsilons()),
          arcs_(state.arcs_.begin(), state.arcs_.end(), alloc) {}

    // Returns the state to its freshly-constructed condition: no arcs, zero
    // epsilon counts, final weight Weight::Zero().
    void Reset() {
      DeleteArcs();
      final_weight_ = Weight::Zero();
    }

    Weight Final() const { return final_weight_; }
    size_t NumInputEpsilons() const { return niepsilons_; }
    size_t NumOutputEpsilons() const { return noepsilons_; }
    size_t NumArcs() const { return arcs_.size(); }
    const Arc &GetArc(size_t n) const { return arcs_[n]; }

    // Pointer to the first arc, or nullptr when the state has no arcs.
    const Arc *Arcs() const { return arcs_.empty() ? nullptr : &arcs_[0]; }
    Arc *MutableArcs() { return arcs_.empty() ? nullptr : &arcs_[0]; }

    void ReserveArcs(size_t n) { arcs_.reserve(n); }
    void SetFinal(Weight weight) { final_weight_ = std::move(weight); }
    void SetNumInputEpsilons(size_t n) { niepsilons_ = n; }
    void SetNumOutputEpsilons(size_t n) { noepsilons_ = n; }

    // Appends an arc, updating the epsilon counts first.
    void AddArc(const Arc &arc) {
      IncrementNumEpsilons(arc);
      arcs_.push_back(arc);
    }

    // Appends an arc by move; the counts are updated before `arc` is moved.
    void AddArc(Arc &&arc) {
      IncrementNumEpsilons(arc);
      arcs_.push_back(std::move(arc));
    }

    // Constructs an arc in place from `ctor_args`, then counts its epsilons.
    template <class... T>
    void EmplaceArc(T &&...ctor_args) {
      arcs_.emplace_back(std::forward<T>(ctor_args)...);
      IncrementNumEpsilons(arcs_.back());
    }

    // Overwrites arc `n` with `arc`, keeping the epsilon counts consistent.
    void SetArc(const Arc &arc, size_t n) {
      Arc &slot = arcs_[n];
      if (slot.ilabel == 0) --niepsilons_;
      if (slot.olabel == 0) --noepsilons_;
      IncrementNumEpsilons(arc);
      slot = arc;
    }

    // Removes every arc and zeroes both epsilon counts.
    void DeleteArcs() {
      niepsilons_ = 0;
      noepsilons_ = 0;
      arcs_.clear();
    }

    // Removes the last `n` arcs (n must not exceed NumArcs()), un-counting
    // their epsilons.
    void DeleteArcs(size_t n) {
      for (; n > 0; --n) {
        const Arc &last = arcs_.back();
        if (last.ilabel == 0) --niepsilons_;
        if (last.olabel == 0) --noepsilons_;
        arcs_.pop_back();
      }
    }

    // Class-specific allocation: obtains storage for one state from `alloc`.
    void *operator new(size_t /*size*/, StateAllocator *alloc) {
      return alloc->allocate(1);
    }

    // Destroys `state` and returns its storage to `alloc`; null is ignored.
    static void Destroy(VectorState<A, M> *state, StateAllocator *alloc) {
      if (state == nullptr) return;
      state->~VectorState<A, M>();
      alloc->deallocate(state, 1);
    }

   private:
    // Counts `arc` toward the input/output epsilon totals (label 0).
    void IncrementNumEpsilons(const Arc &arc) {
      if (arc.ilabel == 0) ++niepsilons_;
      if (arc.olabel == 0) ++noepsilons_;
    }

    Weight final_weight_ = Weight::Zero();  // Final weight.
    size_t niepsilons_ = 0;                 // # of input epsilons.
    size_t noepsilons_ = 0;                 // # of output epsilons.
    std::vector<A, ArcAllocator> arcs_;     // Arc container.
  };
  137. namespace internal {
  138. // States are implemented by STL vectors, templated on the
  139. // State definition. This does not manage the Fst properties.
  140. template <class S>
  141. class VectorFstBaseImpl : public FstImpl<typename S::Arc> {
  142. public:
  143. using State = S;
  144. using Arc = typename State::Arc;
  145. using StateId = typename Arc::StateId;
  146. using Weight = typename Arc::Weight;
  147. VectorFstBaseImpl() = default;
  148. ~VectorFstBaseImpl() override {
  149. for (auto *state : states_) State::Destroy(state, &state_alloc_);
  150. }
  151. // Copying is not permitted.
  152. VectorFstBaseImpl(const VectorFstBaseImpl<S> &) = delete;
  153. VectorFstBaseImpl &operator=(const VectorFstBaseImpl &) = delete;
  154. // Moving is permitted.
  155. VectorFstBaseImpl(VectorFstBaseImpl &&impl) noexcept
  156. : FstImpl<typename S::Arc>(),
  157. states_(std::move(impl.states_)),
  158. start_(impl.start_) {
  159. impl.states_.clear();
  160. impl.start_ = kNoStateId;
  161. }
  162. VectorFstBaseImpl &operator=(VectorFstBaseImpl &&impl) noexcept {
  163. for (auto *state : states_) {
  164. State::Destroy(state, &state_alloc_);
  165. }
  166. states_.clear();
  167. std::swap(states_, impl.states_);
  168. start_ = impl.start_;
  169. impl.start_ = kNoStateId;
  170. return *this;
  171. }
  172. StateId Start() const { return start_; }
  173. Weight Final(StateId state) const { return states_[state]->Final(); }
  174. StateId NumStates() const { return states_.size(); }
  175. size_t NumArcs(StateId state) const { return states_[state]->NumArcs(); }
  176. size_t NumInputEpsilons(StateId state) const {
  177. return GetState(state)->NumInputEpsilons();
  178. }
  179. size_t NumOutputEpsilons(StateId state) const {
  180. return GetState(state)->NumOutputEpsilons();
  181. }
  182. void SetStart(StateId state) { start_ = state; }
  183. void SetFinal(StateId state, Weight weight) {
  184. states_[state]->SetFinal(std::move(weight));
  185. }
  186. StateId AddState(State *state) {
  187. states_.push_back(state);
  188. return states_.size() - 1;
  189. }
  190. StateId AddState() { return AddState(CreateState()); }
  191. void AddStates(size_t n) {
  192. const auto curr_num_states = NumStates();
  193. states_.resize(n + curr_num_states);
  194. std::generate(states_.begin() + curr_num_states, states_.end(),
  195. [this] { return CreateState(); });
  196. }
  197. void AddArc(StateId state, const Arc &arc) { states_[state]->AddArc(arc); }
  198. void AddArc(StateId state, Arc &&arc) {
  199. states_[state]->AddArc(std::move(arc));
  200. }
  201. template <class... T>
  202. void EmplaceArc(StateId state, T &&...ctor_args) {
  203. states_[state]->EmplaceArc(std::forward<T>(ctor_args)...);
  204. }
  205. void DeleteStates(const std::vector<StateId> &dstates) {
  206. std::vector<StateId> newid(states_.size(), 0);
  207. for (size_t i = 0; i < dstates.size(); ++i) newid[dstates[i]] = kNoStateId;
  208. StateId nstates = 0;
  209. for (StateId state = 0; state < states_.size(); ++state) {
  210. if (newid[state] != kNoStateId) {
  211. newid[state] = nstates;
  212. if (state != nstates) states_[nstates] = states_[state];
  213. ++nstates;
  214. } else {
  215. State::Destroy(states_[state], &state_alloc_);
  216. }
  217. }
  218. states_.resize(nstates);
  219. for (StateId state = 0; state < states_.size(); ++state) {
  220. auto *arcs = states_[state]->MutableArcs();
  221. size_t narcs = 0;
  222. auto nieps = states_[state]->NumInputEpsilons();
  223. auto noeps = states_[state]->NumOutputEpsilons();
  224. for (size_t i = 0; i < states_[state]->NumArcs(); ++i) {
  225. const auto t = newid[arcs[i].nextstate];
  226. if (t != kNoStateId) {
  227. arcs[i].nextstate = t;
  228. if (i != narcs) arcs[narcs] = arcs[i];
  229. ++narcs;
  230. } else {
  231. if (arcs[i].ilabel == 0) --nieps;
  232. if (arcs[i].olabel == 0) --noeps;
  233. }
  234. }
  235. states_[state]->DeleteArcs(states_[state]->NumArcs() - narcs);
  236. states_[state]->SetNumInputEpsilons(nieps);
  237. states_[state]->SetNumOutputEpsilons(noeps);
  238. }
  239. if (Start() != kNoStateId) SetStart(newid[Start()]);
  240. }
  241. void DeleteStates() {
  242. for (size_t state = 0; state < states_.size(); ++state) {
  243. State::Destroy(states_[state], &state_alloc_);
  244. }
  245. states_.clear();
  246. SetStart(kNoStateId);
  247. }
  248. void DeleteArcs(StateId state, size_t n) { states_[state]->DeleteArcs(n); }
  249. void DeleteArcs(StateId state) { states_[state]->DeleteArcs(); }
  250. State *GetState(StateId state) { return states_[state]; }
  251. const State *GetState(StateId state) const { return states_[state]; }
  252. void SetState(StateId state, State *vstate) { states_[state] = vstate; }
  253. void ReserveStates(size_t n) { states_.reserve(n); }
  254. void ReserveArcs(StateId state, size_t n) { states_[state]->ReserveArcs(n); }
  255. // Provide information needed for generic state iterator.
  256. void InitStateIterator(StateIteratorData<Arc> *data) const {
  257. data->base = nullptr;
  258. data->nstates = states_.size();
  259. }
  260. // Provide information needed for generic arc iterator.
  261. void InitArcIterator(StateId state, ArcIteratorData<Arc> *data) const {
  262. data->base = nullptr;
  263. data->narcs = states_[state]->NumArcs();
  264. data->arcs = states_[state]->Arcs();
  265. data->ref_count = nullptr;
  266. }
  267. private:
  268. State *CreateState() { return new (&state_alloc_) State(arc_alloc_); }
  269. std::vector<State *> states_;
  270. StateId start_ = kNoStateId;
  271. typename State::StateAllocator state_alloc_;
  272. typename State::ArcAllocator arc_alloc_;
  273. };
  274. // This is a VectorFstBaseImpl container that holds VectorStates and manages FST
  275. // properties.
  276. template <class S>
  277. class VectorFstImpl : public VectorFstBaseImpl<S> {
  278. public:
  279. using State = S;
  280. using Arc = typename State::Arc;
  281. using Label = typename Arc::Label;
  282. using StateId = typename Arc::StateId;
  283. using Weight = typename Arc::Weight;
  284. using FstImpl<Arc>::SetInputSymbols;
  285. using FstImpl<Arc>::SetOutputSymbols;
  286. using FstImpl<Arc>::SetType;
  287. using FstImpl<Arc>::SetProperties;
  288. using FstImpl<Arc>::Properties;
  289. using VectorFstBaseImpl<S>::Start;
  290. using VectorFstBaseImpl<S>::NumStates;
  291. using VectorFstBaseImpl<S>::GetState;
  292. using VectorFstBaseImpl<S>::ReserveArcs;
  293. friend class MutableArcIterator<VectorFst<Arc, S>>;
  294. using BaseImpl = VectorFstBaseImpl<S>;
  295. VectorFstImpl() {
  296. SetType("vector");
  297. SetProperties(kNullProperties | kStaticProperties);
  298. }
  299. explicit VectorFstImpl(const Fst<Arc> &fst);
  300. static VectorFstImpl *Read(std::istream &strm, const FstReadOptions &opts);
  301. void SetStart(StateId state) {
  302. BaseImpl::SetStart(state);
  303. SetProperties(SetStartProperties(Properties()));
  304. }
  305. void SetFinal(StateId state, Weight weight) {
  306. const auto old_weight = BaseImpl::Final(state);
  307. const auto properties =
  308. SetFinalProperties(Properties(), old_weight, weight);
  309. BaseImpl::SetFinal(state, std::move(weight));
  310. SetProperties(properties);
  311. }
  312. StateId AddState() {
  313. const auto state = BaseImpl::AddState();
  314. SetProperties(AddStateProperties(Properties()));
  315. return state;
  316. }
  317. void AddStates(size_t n) {
  318. BaseImpl::AddStates(n);
  319. SetProperties(AddStateProperties(Properties()));
  320. }
  321. void AddArc(StateId state, const Arc &arc) {
  322. BaseImpl::AddArc(state, arc);
  323. UpdatePropertiesAfterAddArc(state);
  324. }
  325. void AddArc(StateId state, Arc &&arc) {
  326. BaseImpl::AddArc(state, std::move(arc));
  327. UpdatePropertiesAfterAddArc(state);
  328. }
  329. template <class... T>
  330. void EmplaceArc(StateId state, T &&...ctor_args) {
  331. BaseImpl::EmplaceArc(state, std::forward<T>(ctor_args)...);
  332. UpdatePropertiesAfterAddArc(state);
  333. }
  334. void DeleteStates(const std::vector<StateId> &dstates) {
  335. BaseImpl::DeleteStates(dstates);
  336. SetProperties(DeleteStatesProperties(Properties()));
  337. }
  338. void DeleteStates() {
  339. BaseImpl::DeleteStates();
  340. SetProperties(DeleteAllStatesProperties(Properties(), kStaticProperties));
  341. }
  342. void DeleteArcs(StateId state, size_t n) {
  343. BaseImpl::DeleteArcs(state, n);
  344. SetProperties(DeleteArcsProperties(Properties()));
  345. }
  346. void DeleteArcs(StateId state) {
  347. BaseImpl::DeleteArcs(state);
  348. SetProperties(DeleteArcsProperties(Properties()));
  349. }
  350. // Properties always true of this FST class
  351. static constexpr uint64_t kStaticProperties = kExpanded | kMutable;
  352. private:
  353. void UpdatePropertiesAfterAddArc(StateId state) {
  354. auto *vstate = GetState(state);
  355. const size_t num_arcs{vstate->NumArcs()};
  356. if (num_arcs) {
  357. const auto &arc = vstate->GetArc(num_arcs - 1);
  358. const auto *parc =
  359. (num_arcs < 2) ? nullptr : &(vstate->GetArc(num_arcs - 2));
  360. SetProperties(AddArcProperties(Properties(), state, arc, parc));
  361. }
  362. }
  363. // Minimum file format version supported.
  364. static constexpr int kMinFileVersion = 2;
  365. };
  366. template <class S>
  367. VectorFstImpl<S>::VectorFstImpl(const Fst<Arc> &fst) {
  368. SetType("vector");
  369. SetInputSymbols(fst.InputSymbols());
  370. SetOutputSymbols(fst.OutputSymbols());
  371. BaseImpl::SetStart(fst.Start());
  372. if (std::optional<StateId> num_states = fst.NumStatesIfKnown()) {
  373. BaseImpl::ReserveStates(*num_states);
  374. }
  375. for (StateIterator<Fst<Arc>> siter(fst); !siter.Done(); siter.Next()) {
  376. const auto state = siter.Value();
  377. BaseImpl::AddState();
  378. BaseImpl::SetFinal(state, fst.Final(state));
  379. ReserveArcs(state, fst.NumArcs(state));
  380. for (ArcIterator<Fst<Arc>> aiter(fst, state); !aiter.Done(); aiter.Next()) {
  381. const auto &arc = aiter.Value();
  382. BaseImpl::AddArc(state, arc);
  383. }
  384. }
  385. SetProperties(fst.Properties(kCopyProperties, false) | kStaticProperties);
  386. }
  387. template <class S>
  388. VectorFstImpl<S> *VectorFstImpl<S>::Read(std::istream &strm,
  389. const FstReadOptions &opts) {
  390. auto impl = std::make_unique<VectorFstImpl>();
  391. FstHeader hdr;
  392. if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return nullptr;
  393. impl->BaseImpl::SetStart(hdr.Start());
  394. if (hdr.NumStates() != kNoStateId) impl->ReserveStates(hdr.NumStates());
  395. StateId state = 0;
  396. for (; hdr.NumStates() == kNoStateId || state < hdr.NumStates(); ++state) {
  397. Weight weight;
  398. if (!weight.Read(strm)) break;
  399. impl->BaseImpl::AddState();
  400. auto *vstate = impl->GetState(state);
  401. vstate->SetFinal(weight);
  402. int64_t narcs;
  403. ReadType(strm, &narcs);
  404. if (!strm) {
  405. LOG(ERROR) << "VectorFst::Read: Read failed: " << opts.source;
  406. return nullptr;
  407. }
  408. impl->ReserveArcs(state, narcs);
  409. for (int64_t i = 0; i < narcs; ++i) {
  410. Arc arc;
  411. ReadType(strm, &arc.ilabel);
  412. ReadType(strm, &arc.olabel);
  413. arc.weight.Read(strm);
  414. ReadType(strm, &arc.nextstate);
  415. if (!strm) {
  416. LOG(ERROR) << "VectorFst::Read: Read failed: " << opts.source;
  417. return nullptr;
  418. }
  419. impl->BaseImpl::AddArc(state, std::move(arc));
  420. }
  421. }
  422. if (hdr.NumStates() != kNoStateId && state != hdr.NumStates()) {
  423. LOG(ERROR) << "VectorFst::Read: Unexpected end of file: " << opts.source;
  424. return nullptr;
  425. }
  426. return impl.release();
  427. }
  428. } // namespace internal
  429. // Simple concrete, mutable FST. This class attaches interface to implementation
  430. // and handles reference counting, delegating most methods to ImplToMutableFst.
  431. // Also supports ReserveStates and ReserveArcs methods (cf. STL vector methods).
  432. // The second optional template argument gives the State definition.
  433. //
  434. // VectorFst is thread-compatible.
  435. template <class A, class S /* = VectorState<A> */>
  436. class VectorFst : public ImplToMutableFst<internal::VectorFstImpl<S>> {
  437. public:
  438. using Arc = A;
  439. using StateId = typename Arc::StateId;
  440. using State = S;
  441. using Impl = internal::VectorFstImpl<State>;
  442. friend class StateIterator<VectorFst<Arc, State>>;
  443. friend class ArcIterator<VectorFst<Arc, State>>;
  444. friend class MutableArcIterator<VectorFst<A, S>>;
  445. template <class F, class G>
  446. friend void Cast(const F &, G *);
  447. VectorFst() : ImplToMutableFst<Impl>(std::make_shared<Impl>()) {}
  448. explicit VectorFst(const Fst<Arc> &fst)
  449. : ImplToMutableFst<Impl>(std::make_shared<Impl>(fst)) {}
  450. VectorFst(const VectorFst &fst, bool unused_safe = false)
  451. : ImplToMutableFst<Impl>(fst.GetSharedImpl()) {}
  452. VectorFst(VectorFst &&) noexcept;
  453. // Get a copy of this VectorFst. See Fst<>::Copy() for further doc.
  454. VectorFst *Copy(bool safe = false) const override {
  455. return new VectorFst(*this, safe);
  456. }
  457. VectorFst &operator=(const VectorFst &) = default;
  458. VectorFst &operator=(VectorFst &&) noexcept;
  459. VectorFst &operator=(const Fst<Arc> &fst) override {
  460. if (this != &fst) SetImpl(std::make_shared<Impl>(fst));
  461. return *this;
  462. }
  463. template <class... T>
  464. void EmplaceArc(StateId state, T &&...ctor_args) {
  465. MutateCheck();
  466. GetMutableImpl()->EmplaceArc(state, std::forward<T>(ctor_args)...);
  467. }
  468. // Reads a VectorFst from an input stream, returning nullptr on error.
  469. static VectorFst *Read(std::istream &strm, const FstReadOptions &opts) {
  470. auto *impl = Impl::Read(strm, opts);
  471. return impl ? new VectorFst(std::shared_ptr<Impl>(impl)) : nullptr;
  472. }
  473. // Read a VectorFst from a file, returning nullptr on error; empty source
  474. // reads from standard input.
  475. static VectorFst *Read(std::string_view source) {
  476. auto *impl = ImplToExpandedFst<Impl, MutableFst<Arc>>::Read(source);
  477. return impl ? new VectorFst(std::shared_ptr<Impl>(impl)) : nullptr;
  478. }
  479. bool Write(std::ostream &strm, const FstWriteOptions &opts) const override {
  480. return WriteFst(*this, strm, opts);
  481. }
  482. bool Write(const std::string &source) const override {
  483. return Fst<Arc>::WriteFile(source);
  484. }
  485. template <class FST>
  486. static bool WriteFst(const FST &fst, std::ostream &strm,
  487. const FstWriteOptions &opts);
  488. void InitStateIterator(StateIteratorData<Arc> *data) const override {
  489. GetImpl()->InitStateIterator(data);
  490. }
  491. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
  492. GetImpl()->InitArcIterator(s, data);
  493. }
  494. inline void InitMutableArcIterator(StateId s,
  495. MutableArcIteratorData<Arc> *) override;
  496. using ImplToMutableFst<Impl, MutableFst<Arc>>::ReserveArcs;
  497. using ImplToMutableFst<Impl, MutableFst<Arc>>::ReserveStates;
  498. private:
  499. using ImplToMutableFst<Impl, MutableFst<Arc>>::GetImpl;
  500. using ImplToMutableFst<Impl, MutableFst<Arc>>::GetMutableImpl;
  501. using ImplToMutableFst<Impl, MutableFst<Arc>>::MutateCheck;
  502. using ImplToMutableFst<Impl, MutableFst<Arc>>::SetImpl;
  503. explicit VectorFst(std::shared_ptr<Impl> impl)
  504. : ImplToMutableFst<Impl>(impl) {}
  505. };
  506. template <class Arc, class State>
  507. inline VectorFst<Arc, State>::VectorFst(VectorFst &&fst) noexcept = default;
  508. template <class Arc, class State>
  509. inline VectorFst<Arc, State> &VectorFst<Arc, State>::operator=(
  510. VectorFst &&fst) noexcept = default;
  511. // Writes FST to file in Vector format, potentially with a pass over the machine
  512. // before writing to compute number of states.
  513. template <class Arc, class State>
  514. template <class FST>
  515. bool VectorFst<Arc, State>::WriteFst(const FST &fst, std::ostream &strm,
  516. const FstWriteOptions &opts) {
  517. static constexpr int file_version = 2;
  518. bool update_header = true;
  519. FstHeader hdr;
  520. hdr.SetStart(fst.Start());
  521. hdr.SetNumStates(kNoStateId);
  522. std::streampos start_offset = 0;
  523. if (fst.Properties(kExpanded, false) || opts.stream_write ||
  524. (start_offset = strm.tellp()) != -1) {
  525. hdr.SetNumStates(CountStates(fst));
  526. update_header = false;
  527. }
  528. const auto properties =
  529. fst.Properties(kCopyProperties, false) | Impl::kStaticProperties;
  530. internal::FstImpl<Arc>::WriteFstHeader(fst, strm, opts, file_version,
  531. "vector", properties, &hdr);
  532. StateId num_states = 0;
  533. for (StateIterator<FST> siter(fst); !siter.Done(); siter.Next()) {
  534. const auto s = siter.Value();
  535. fst.Final(s).Write(strm);
  536. const int64_t narcs = fst.NumArcs(s);
  537. WriteType(strm, narcs);
  538. for (ArcIterator<FST> aiter(fst, s); !aiter.Done(); aiter.Next()) {
  539. const auto &arc = aiter.Value();
  540. WriteType(strm, arc.ilabel);
  541. WriteType(strm, arc.olabel);
  542. arc.weight.Write(strm);
  543. WriteType(strm, arc.nextstate);
  544. }
  545. ++num_states;
  546. }
  547. strm.flush();
  548. if (!strm) {
  549. LOG(ERROR) << "VectorFst::Write: Write failed: " << opts.source;
  550. return false;
  551. }
  552. if (update_header) {
  553. hdr.SetNumStates(num_states);
  554. return internal::FstImpl<Arc>::UpdateFstHeader(
  555. fst, strm, opts, file_version, "vector", properties, &hdr,
  556. start_offset);
  557. } else {
  558. if (num_states != hdr.NumStates()) {
  559. LOG(ERROR) << "Inconsistent number of states observed during write";
  560. return false;
  561. }
  562. }
  563. return true;
  564. }
  565. // Specialization for VectorFst; see generic version in fst.h for sample usage
  566. // (but use the VectorFst type instead). This version should inline.
  567. template <class Arc, class State>
  568. class StateIterator<VectorFst<Arc, State>> {
  569. public:
  570. using StateId = typename Arc::StateId;
  571. explicit StateIterator(const VectorFst<Arc, State> &fst)
  572. : nstates_(fst.GetImpl()->NumStates()) {}
  573. bool Done() const { return s_ >= nstates_; }
  574. StateId Value() const { return s_; }
  575. void Next() { ++s_; }
  576. void Reset() { s_ = 0; }
  577. private:
  578. const StateId nstates_;
  579. StateId s_ = 0;
  580. };
  581. // Specialization for VectorFst; see generic version in fst.h for sample usage
  582. // (but use the VectorFst type instead). This version should inline.
  583. template <class Arc, class State>
  584. class ArcIterator<VectorFst<Arc, State>> {
  585. public:
  586. using StateId = typename Arc::StateId;
  587. ArcIterator(const VectorFst<Arc, State> &fst, StateId s)
  588. : arcs_(fst.GetImpl()->GetState(s)->Arcs()),
  589. narcs_(fst.GetImpl()->GetState(s)->NumArcs()) {}
  590. bool Done() const { return i_ >= narcs_; }
  591. const Arc &Value() const { return arcs_[i_]; }
  592. void Next() { ++i_; }
  593. void Reset() { i_ = 0; }
  594. void Seek(size_t a) { i_ = a; }
  595. size_t Position() const { return i_; }
  596. constexpr uint8_t Flags() const { return kArcValueFlags; }
  597. void SetFlags(uint8_t, uint8_t) {}
  598. private:
  599. const Arc *arcs_;
  600. size_t narcs_;
  601. size_t i_ = 0;
  602. };
  603. // Specialization for VectorFst; see generic version in mutable-fst.h for sample
  604. // usage (but use the VectorFst type instead). This version should inline.
  605. template <class Arc, class State>
  606. class MutableArcIterator<VectorFst<Arc, State>>
  607. : public MutableArcIteratorBase<Arc> {
  608. public:
  609. using StateId = typename Arc::StateId;
  610. using Weight = typename Arc::Weight;
  611. MutableArcIterator(VectorFst<Arc, State> *fst, StateId s) {
  612. fst->MutateCheck();
  613. state_ = fst->GetMutableImpl()->GetState(s);
  614. properties_ = &fst->GetImpl()->properties_;
  615. }
  616. bool Done() const final { return i_ >= state_->NumArcs(); }
  617. const Arc &Value() const final { return state_->GetArc(i_); }
  618. void Next() final { ++i_; }
  619. size_t Position() const final { return i_; }
  620. void Reset() final { i_ = 0; }
  621. void Seek(size_t a) final { i_ = a; }
  622. void SetValue(const Arc &arc) final {
  623. const auto &oarc = state_->GetArc(i_);
  624. uint64_t properties = properties_->load(std::memory_order_relaxed);
  625. if (oarc.ilabel != oarc.olabel) properties &= ~kNotAcceptor;
  626. if (oarc.ilabel == 0) {
  627. properties &= ~kIEpsilons;
  628. if (oarc.olabel == 0) properties &= ~kEpsilons;
  629. }
  630. if (oarc.olabel == 0) properties &= ~kOEpsilons;
  631. if (oarc.weight != Weight::Zero() && oarc.weight != Weight::One()) {
  632. properties &= ~kWeighted;
  633. }
  634. state_->SetArc(arc, i_);
  635. if (arc.ilabel != arc.olabel) {
  636. properties |= kNotAcceptor;
  637. properties &= ~kAcceptor;
  638. }
  639. if (arc.ilabel == 0) {
  640. properties |= kIEpsilons;
  641. properties &= ~kNoIEpsilons;
  642. if (arc.olabel == 0) {
  643. properties |= kEpsilons;
  644. properties &= ~kNoEpsilons;
  645. }
  646. }
  647. if (arc.olabel == 0) {
  648. properties |= kOEpsilons;
  649. properties &= ~kNoOEpsilons;
  650. }
  651. if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) {
  652. properties |= kWeighted;
  653. properties &= ~kUnweighted;
  654. }
  655. properties &= kSetArcProperties | kAcceptor | kNotAcceptor | kEpsilons |
  656. kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons |
  657. kNoOEpsilons | kWeighted | kUnweighted;
  658. properties_->store(properties, std::memory_order_relaxed);
  659. }
  660. uint8_t Flags() const final { return kArcValueFlags; }
  661. void SetFlags(uint8_t, uint8_t) final {}
  662. private:
  663. State *state_;
  664. std::atomic<uint64_t> *properties_;
  665. size_t i_ = 0;
  666. };
  667. // Provides information needed for the generic mutable arc iterator.
  668. template <class Arc, class State>
  669. inline void VectorFst<Arc, State>::InitMutableArcIterator(
  670. StateId s, MutableArcIteratorData<Arc> *data) {
  671. data->base =
  672. std::make_unique<MutableArcIterator<VectorFst<Arc, State>>>(this, s);
  673. }
  674. // A useful alias when using StdArc.
  675. using StdVectorFst = VectorFst<StdArc>;
  676. } // namespace fst
  677. #endif // FST_VECTOR_FST_H_