You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

232 lines
6.5 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. #ifndef FST_SCRIPT_ARCITERATOR_CLASS_H_
  18. #define FST_SCRIPT_ARCITERATOR_CLASS_H_
  19. #include <cstddef>
  20. #include <cstdint>
  21. #include <memory>
  22. #include <tuple>
  23. #include <utility>
  24. #include <fst/fst.h>
  25. #include <fst/fstlib.h>
  26. #include <fst/mutable-fst.h>
  27. #include <fst/script/arc-class.h>
  28. #include <fst/script/fst-class.h>
  29. // Scripting API support for ArcIterator.
  30. //
  31. // A call to Value() causes the underlying arc to be used to construct the
  32. // associated ArcClass.
  33. namespace fst {
  34. namespace script {
// Non-mutable arc iterators.

// Abstract, arc-type-independent interface for read-only arc iteration. The
// concrete implementation in this header is ArcIteratorClassImpl<Arc>, which
// answers every call below by forwarding to the like-named method of a typed
// ArcIterator<Fst<Arc>>.
class ArcIteratorImplBase {
 public:
  // The following operations map one-to-one onto the ArcIterator methods of
  // the same name; arguments are passed through unchanged.
  virtual bool Done() const = 0;
  virtual uint8_t Flags() const = 0;
  virtual void Next() = 0;
  virtual size_t Position() const = 0;
  virtual void Reset() = 0;
  virtual void Seek(size_t a) = 0;
  virtual void SetFlags(uint8_t flags, uint8_t mask) = 0;
  // Returns the iterator's current arc wrapped in an ArcClass, by value.
  virtual ArcClass Value() const = 0;
  // Virtual so an implementation can be destroyed through a pointer to this
  // base; ArcIteratorClass owns its implementation via
  // std::unique_ptr<ArcIteratorImplBase>.
  virtual ~ArcIteratorImplBase() = default;
};
  49. // Templated implementation.
  50. template <class Arc>
  51. class ArcIteratorClassImpl : public ArcIteratorImplBase {
  52. public:
  53. explicit ArcIteratorClassImpl(const Fst<Arc> &fst, int64_t s)
  54. : aiter_(fst, s) {}
  55. bool Done() const final { return aiter_.Done(); }
  56. uint8_t Flags() const final { return aiter_.Flags(); }
  57. void Next() final { aiter_.Next(); }
  58. size_t Position() const final { return aiter_.Position(); }
  59. void Reset() final { aiter_.Reset(); }
  60. void Seek(size_t a) final { aiter_.Seek(a); }
  61. void SetFlags(uint8_t flags, uint8_t mask) final {
  62. aiter_.SetFlags(flags, mask);
  63. }
  64. // This is returned by value because it has not yet been constructed, and
  65. // is likely to participate in return-value optimization.
  66. ArcClass Value() const final { return ArcClass(aiter_.Value()); }
  67. ~ArcIteratorClassImpl() override = default;
  68. private:
  69. ArcIterator<Fst<Arc>> aiter_;
  70. };
// Forward declaration: the argument tuple below names ArcIteratorClass * before
// the class itself is defined.
class ArcIteratorClass;

// Argument pack consumed by InitArcIteratorClass<Arc>():
//   0: the FstClass whose typed Fst<Arc> is iterated over;
//   1: the int64_t forwarded as the iterator's second constructor argument
//      (presumably the state ID whose arcs are visited);
//   2: the ArcIteratorClass whose impl_ is to be filled in.
using InitArcIteratorClassArgs =
    std::tuple<const FstClass &, int64_t, ArcIteratorClass *>;
  74. // Untemplated user-facing class holding a templated pimpl.
  75. class ArcIteratorClass {
  76. public:
  77. ArcIteratorClass(const FstClass &fst, int64_t s);
  78. template <class Arc>
  79. ArcIteratorClass(const Fst<Arc> &fst, int64_t s)
  80. : impl_(std::make_unique<ArcIteratorClassImpl<Arc>>(fst, s)) {}
  81. bool Done() const { return impl_->Done(); }
  82. uint8_t Flags() const { return impl_->Flags(); }
  83. void Next() { impl_->Next(); }
  84. size_t Position() const { return impl_->Position(); }
  85. void Reset() { impl_->Reset(); }
  86. void Seek(size_t a) { impl_->Seek(a); }
  87. void SetFlags(uint8_t flags, uint8_t mask) { impl_->SetFlags(flags, mask); }
  88. ArcClass Value() const { return impl_->Value(); }
  89. template <class Arc>
  90. friend void InitArcIteratorClass(InitArcIteratorClassArgs *args);
  91. private:
  92. std::unique_ptr<ArcIteratorImplBase> impl_;
  93. };
  94. template <class Arc>
  95. void InitArcIteratorClass(InitArcIteratorClassArgs *args) {
  96. const Fst<Arc> &fst = *std::get<0>(*args).GetFst<Arc>();
  97. std::get<2>(*args)->impl_ =
  98. std::make_unique<ArcIteratorClassImpl<Arc>>(fst, std::get<1>(*args));
  99. }
// Mutable arc iterators.

// Abstract interface for mutable arc iteration: everything in
// ArcIteratorImplBase plus SetValue(). The concrete implementation in this
// header is MutableArcIteratorClassImpl<Arc>.
class MutableArcIteratorImplBase : public ArcIteratorImplBase {
 public:
  // Passes an arc, given in its arc-type-agnostic ArcClass form, to the
  // underlying iterator's SetValue().
  virtual void SetValue(const ArcClass &) = 0;
  ~MutableArcIteratorImplBase() override = default;
};
  107. // Templated implementation.
  108. template <class Arc>
  109. class MutableArcIteratorClassImpl : public MutableArcIteratorImplBase {
  110. public:
  111. explicit MutableArcIteratorClassImpl(MutableFst<Arc> *fst, int64_t s)
  112. : aiter_(fst, s) {}
  113. bool Done() const final { return aiter_.Done(); }
  114. uint8_t Flags() const final { return aiter_.Flags(); }
  115. void Next() final { aiter_.Next(); }
  116. size_t Position() const final { return aiter_.Position(); }
  117. void Reset() final { aiter_.Reset(); }
  118. void Seek(size_t a) final { aiter_.Seek(a); }
  119. void SetFlags(uint8_t flags, uint8_t mask) final {
  120. aiter_.SetFlags(flags, mask);
  121. }
  122. void SetValue(const ArcClass &ac) final { SetValue(ac.GetArc<Arc>()); }
  123. // This is returned by value because it has not yet been constructed, and
  124. // is likely to participate in return-value optimization.
  125. ArcClass Value() const final { return ArcClass(aiter_.Value()); }
  126. ~MutableArcIteratorClassImpl() override = default;
  127. private:
  128. void SetValue(const Arc &arc) { aiter_.SetValue(arc); }
  129. MutableArcIterator<MutableFst<Arc>> aiter_;
  130. };
// Forward declaration: the argument tuple below names
// MutableArcIteratorClass * before the class itself is defined.
class MutableArcIteratorClass;

// Argument pack consumed by InitMutableArcIteratorClass<Arc>():
//   0: the MutableFstClass whose typed MutableFst<Arc> is iterated over;
//   1: the int64_t forwarded as the iterator's second constructor argument
//      (presumably the state ID whose arcs are visited);
//   2: the MutableArcIteratorClass whose impl_ is to be filled in.
using InitMutableArcIteratorClassArgs =
    std::tuple<MutableFstClass *, int64_t, MutableArcIteratorClass *>;
  134. // Untemplated user-facing class holding a templated pimpl.
  135. class MutableArcIteratorClass {
  136. public:
  137. MutableArcIteratorClass(MutableFstClass *fst, int64_t s);
  138. template <class Arc>
  139. MutableArcIteratorClass(MutableFst<Arc> *fst, int64_t s)
  140. : impl_(std::make_unique<MutableArcIteratorClassImpl<Arc>>(fst, s)) {}
  141. bool Done() const { return impl_->Done(); }
  142. uint8_t Flags() const { return impl_->Flags(); }
  143. void Next() { impl_->Next(); }
  144. size_t Position() const { return impl_->Position(); }
  145. void Reset() { impl_->Reset(); }
  146. void Seek(size_t a) { impl_->Seek(a); }
  147. void SetFlags(uint8_t flags, uint8_t mask) { impl_->SetFlags(flags, mask); }
  148. void SetValue(const ArcClass &ac) { impl_->SetValue(ac); }
  149. ArcClass Value() const { return impl_->Value(); }
  150. template <class Arc>
  151. friend void InitMutableArcIteratorClass(
  152. InitMutableArcIteratorClassArgs *args);
  153. private:
  154. std::unique_ptr<MutableArcIteratorImplBase> impl_;
  155. };
  156. template <class Arc>
  157. void InitMutableArcIteratorClass(InitMutableArcIteratorClassArgs *args) {
  158. MutableFst<Arc> *fst = std::get<0>(*args)->GetMutableFst<Arc>();
  159. std::get<2>(*args)->impl_ =
  160. std::make_unique<MutableArcIteratorClassImpl<Arc>>(fst,
  161. std::get<1>(*args));
  162. }
  163. } // namespace script
  164. } // namespace fst
  165. #endif // FST_SCRIPT_ARCITERATOR_CLASS_H_