You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

389 lines
13 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Class to add a matcher to an FST.
  19. #ifndef FST_MATCHER_FST_H_
  20. #define FST_MATCHER_FST_H_
  21. #include <cstdint>
  22. #include <istream>
  23. #include <memory>
  24. #include <ostream>
  25. #include <string>
  26. #include <fst/accumulator.h>
  27. #include <fst/add-on.h>
  28. #include <fst/arc.h>
  29. #include <fst/const-fst.h>
  30. #include <fst/expanded-fst.h>
  31. #include <fst/float-weight.h>
  32. #include <fst/fst.h>
  33. #include <fst/impl-to-fst.h>
  34. #include <fst/lookahead-matcher.h>
  35. #include <fst/matcher.h>
  36. #include <string_view>
  37. namespace fst {
  38. // Writeable matchers have the same interface as Matchers (as defined in
  39. // matcher.h) along with the following additional methods:
  40. //
  41. // template <class F>
  42. // class Matcher {
  43. // public:
  44. // using FST = F;
  45. // ...
  46. // using MatcherData = ...; // Initialization data.
  47. //
  48. // // Constructor with additional argument for external initialization data;
  49. // // matcher increments its reference count on construction and decrements
  50. // // the reference count, and deletes once the reference count has reached
  51. // // zero.
  52. // Matcher(const FST &fst, MatchType type, MatcherData *data);
  53. //
  54. // // Returns pointer to initialization data that can be passed to a Matcher
  55. // // constructor.
  56. // MatcherData *GetData() const;
  57. // };
  58. // The matcher initialization data class must also provide the following
  59. // interface:
  60. //
  61. // class MatcherData {
  62. // public:
  63. // // Required copy constructor.
  64. // MatcherData(const MatcherData &);
  65. //
  66. // // Required I/O methods.
  67. // static MatcherData *Read(std::istream &istrm, const FstReadOptions &opts);
  68. // bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const;
  69. // };
  70. // Trivial (no-op) MatcherFst initializer functor.
  71. template <class M>
  72. class NullMatcherFstInit {
  73. public:
  74. using MatcherData = typename M::MatcherData;
  75. using Data = AddOnPair<MatcherData, MatcherData>;
  76. using Impl = internal::AddOnImpl<typename M::FST, Data>;
  77. explicit NullMatcherFstInit(std::shared_ptr<Impl> *) {}
  78. };
  79. // Class adding a matcher to an FST type. Creates a new FST whose name is given
  80. // by N. An optional functor Init can be used to initialize the FST. The Data
  81. // template parameter allows the user to select the type of the add-on.
  82. template <
  83. class F, class M, const char *Name, class Init = NullMatcherFstInit<M>,
  84. class Data = AddOnPair<typename M::MatcherData, typename M::MatcherData>>
  85. class MatcherFst : public ImplToExpandedFst<internal::AddOnImpl<F, Data>> {
  86. public:
  87. using FST = F;
  88. using Arc = typename FST::Arc;
  89. using StateId = typename Arc::StateId;
  90. using FstMatcher = M;
  91. using MatcherData = typename FstMatcher::MatcherData;
  92. using Impl = internal::AddOnImpl<FST, Data>;
  93. using D = Data;
  94. friend class StateIterator<MatcherFst<FST, FstMatcher, Name, Init, Data>>;
  95. friend class ArcIterator<MatcherFst<FST, FstMatcher, Name, Init, Data>>;
  96. MatcherFst() : ImplToExpandedFst<Impl>(std::make_shared<Impl>(FST(), Name)) {}
  97. // Constructs a MatcherFst from an FST, which is the underlying FST type used
  98. // by this class. Uses the existing Data if present, and runs Init on it.
  99. // Stores fst internally, making a thread-safe copy of it.
  100. explicit MatcherFst(const FST &fst, std::shared_ptr<Data> data = nullptr)
  101. : ImplToExpandedFst<Impl>(data ? CreateImpl(fst, Name, data)
  102. : CreateDataAndImpl(fst, Name)) {}
  103. // Constructs a MatcherFst from an Fst<Arc>, which is *not* the underlying
  104. // FST type used by this class. Uses the existing Data if present, and
  105. // runs Init on it. Stores fst internally, converting Fst<Arc> to FST and
  106. // therefore making a deep copy.
  107. explicit MatcherFst(const Fst<Arc> &fst, std::shared_ptr<Data> data = nullptr)
  108. : ImplToExpandedFst<Impl>(data ? CreateImpl(fst, Name, data)
  109. : CreateDataAndImpl(fst, Name)) {}
  110. // See Fst<>::Copy() for doc.
  111. MatcherFst(const MatcherFst &fst, bool safe = false)
  112. : ImplToExpandedFst<Impl>(fst, safe) {}
  113. // Get a copy of this MatcherFst. See Fst<>::Copy() for further doc.
  114. MatcherFst *Copy(bool safe = false) const override {
  115. return new MatcherFst(*this, safe);
  116. }
  117. // Read a MatcherFst from an input stream; return nullptr on error
  118. static MatcherFst *Read(std::istream &strm, const FstReadOptions &opts) {
  119. auto *impl = Impl::Read(strm, opts);
  120. return impl ? new MatcherFst(std::shared_ptr<Impl>(impl)) : nullptr;
  121. }
  122. // Read a MatcherFst from a file; return nullptr on error
  123. // Empty source reads from standard input
  124. static MatcherFst *Read(std::string_view source) {
  125. auto *impl = ImplToExpandedFst<Impl>::Read(source);
  126. return impl ? new MatcherFst(std::shared_ptr<Impl>(impl)) : nullptr;
  127. }
  128. bool Write(std::ostream &strm, const FstWriteOptions &opts) const override {
  129. return GetImpl()->Write(strm, opts);
  130. }
  131. bool Write(const std::string &source) const override {
  132. return Fst<Arc>::WriteFile(source);
  133. }
  134. void InitStateIterator(StateIteratorData<Arc> *data) const override {
  135. return GetImpl()->InitStateIterator(data);
  136. }
  137. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
  138. return GetImpl()->InitArcIterator(s, data);
  139. }
  140. FstMatcher *InitMatcher(MatchType match_type) const override {
  141. return new FstMatcher(&GetFst(), match_type, GetSharedData(match_type));
  142. }
  143. const FST &GetFst() const { return GetImpl()->GetFst(); }
  144. const Data *GetAddOn() const { return GetImpl()->GetAddOn(); }
  145. std::shared_ptr<Data> GetSharedAddOn() const {
  146. return GetImpl()->GetSharedAddOn();
  147. }
  148. const MatcherData *GetData(MatchType match_type) const {
  149. const auto *data = GetAddOn();
  150. return match_type == MATCH_INPUT ? data->First() : data->Second();
  151. }
  152. std::shared_ptr<MatcherData> GetSharedData(MatchType match_type) const {
  153. const auto *data = GetAddOn();
  154. return match_type == MATCH_INPUT ? data->SharedFirst()
  155. : data->SharedSecond();
  156. }
  157. protected:
  158. using ImplToFst<Impl, ExpandedFst<Arc>>::GetImpl;
  159. // Makes a thread-safe copy of fst.
  160. static std::shared_ptr<Impl> CreateDataAndImpl(const FST &fst,
  161. std::string_view name) {
  162. FstMatcher imatcher(fst, MATCH_INPUT);
  163. FstMatcher omatcher(fst, MATCH_OUTPUT);
  164. return CreateImpl(fst, name,
  165. std::make_shared<Data>(imatcher.GetSharedData(),
  166. omatcher.GetSharedData()));
  167. }
  168. // Makes a deep copy of fst.
  169. static std::shared_ptr<Impl> CreateDataAndImpl(const Fst<Arc> &fst,
  170. std::string_view name) {
  171. FST result(fst);
  172. return CreateDataAndImpl(result, name);
  173. }
  174. // Makes a thread-safe copy of fst.
  175. static std::shared_ptr<Impl> CreateImpl(const FST &fst,
  176. std::string_view name,
  177. std::shared_ptr<Data> data) {
  178. auto impl = std::make_shared<Impl>(fst, name);
  179. impl->SetAddOn(data);
  180. Init init(&impl);
  181. return impl;
  182. }
  183. // Makes a deep copy of fst.
  184. static std::shared_ptr<Impl> CreateImpl(const Fst<Arc> &fst,
  185. std::string_view name,
  186. std::shared_ptr<Data> data) {
  187. auto impl = std::make_shared<Impl>(fst, name);
  188. impl->SetAddOn(data);
  189. Init init(&impl);
  190. return impl;
  191. }
  192. explicit MatcherFst(std::shared_ptr<Impl> impl)
  193. : ImplToExpandedFst<Impl>(impl) {}
  194. private:
  195. MatcherFst &operator=(const MatcherFst &) = delete;
  196. };
  197. // Specialization for MatcherFst.
  198. template <class FST, class M, const char *Name, class Init>
  199. class StateIterator<MatcherFst<FST, M, Name, Init>>
  200. : public StateIterator<FST> {
  201. public:
  202. explicit StateIterator(const MatcherFst<FST, M, Name, Init> &fst)
  203. : StateIterator<FST>(fst.GetImpl()->GetFst()) {}
  204. };
  205. // Specialization for MatcherFst.
  206. template <class FST, class M, const char *Name, class Init>
  207. class ArcIterator<MatcherFst<FST, M, Name, Init>> : public ArcIterator<FST> {
  208. public:
  209. using StateId = typename FST::Arc::StateId;
  210. ArcIterator(const MatcherFst<FST, M, Name, Init> &fst,
  211. typename FST::Arc::StateId s)
  212. : ArcIterator<FST>(fst.GetImpl()->GetFst(), s) {}
  213. };
  214. // Specialization for MatcherFst.
  215. template <class F, class M, const char *Name, class Init>
  216. class Matcher<MatcherFst<F, M, Name, Init>> {
  217. public:
  218. using FST = MatcherFst<F, M, Name, Init>;
  219. using Arc = typename F::Arc;
  220. using Label = typename Arc::Label;
  221. using StateId = typename Arc::StateId;
  222. Matcher(const FST &fst, MatchType match_type)
  223. : matcher_(fst.InitMatcher(match_type)) {}
  224. Matcher(const Matcher &matcher) : matcher_(matcher.matcher_->Copy()) {}
  225. Matcher *Copy() const { return new Matcher(*this); }
  226. MatchType Type(bool test) const { return matcher_->Type(test); }
  227. void SetState(StateId s) { matcher_->SetState(s); }
  228. bool Find(Label label) { return matcher_->Find(label); }
  229. bool Done() const { return matcher_->Done(); }
  230. const Arc &Value() const { return matcher_->Value(); }
  231. void Next() { matcher_->Next(); }
  232. uint64_t Properties(uint64_t props) const {
  233. return matcher_->Properties(props);
  234. }
  235. uint32_t Flags() const { return matcher_->Flags(); }
  236. private:
  237. std::unique_ptr<M> matcher_;
  238. };
  239. // Specialization for MatcherFst.
  240. template <class F, class M, const char *Name, class Init>
  241. class LookAheadMatcher<MatcherFst<F, M, Name, Init>> {
  242. public:
  243. using FST = MatcherFst<F, M, Name, Init>;
  244. using Arc = typename F::Arc;
  245. using Label = typename Arc::Label;
  246. using StateId = typename Arc::StateId;
  247. using Weight = typename Arc::Weight;
  248. LookAheadMatcher(const FST &fst, MatchType match_type)
  249. : matcher_(fst.InitMatcher(match_type)) {}
  250. LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false)
  251. : matcher_(matcher.matcher_->Copy(safe)) {}
  252. // General matcher methods.
  253. LookAheadMatcher *Copy(bool safe = false) const {
  254. return new LookAheadMatcher(*this, safe);
  255. }
  256. MatchType Type(bool test) const { return matcher_->Type(test); }
  257. void SetState(StateId s) { matcher_->SetState(s); }
  258. bool Find(Label label) { return matcher_->Find(label); }
  259. bool Done() const { return matcher_->Done(); }
  260. const Arc &Value() const { return matcher_->Value(); }
  261. void Next() { matcher_->Next(); }
  262. const FST &GetFst() const { return matcher_->GetFst(); }
  263. uint64_t Properties(uint64_t props) const {
  264. return matcher_->Properties(props);
  265. }
  266. uint32_t Flags() const { return matcher_->Flags(); }
  267. bool LookAheadLabel(Label label) const {
  268. return matcher_->LookAheadLabel(label);
  269. }
  270. bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
  271. return matcher_->LookAheadFst(fst, s);
  272. }
  273. Weight LookAheadWeight() const { return matcher_->LookAheadWeight(); }
  274. bool LookAheadPrefix(Arc *arc) const {
  275. return matcher_->LookAheadPrefix(arc);
  276. }
  277. void InitLookAheadFst(const Fst<Arc> &fst, bool copy = false) {
  278. matcher_->InitLookAheadFst(fst, copy);
  279. }
  280. private:
  281. std::unique_ptr<M> matcher_;
  282. };
  283. // Useful aliases when using StdArc.
  284. inline constexpr char arc_lookahead_fst_type[] = "arc_lookahead";
  285. using StdArcLookAheadFst =
  286. MatcherFst<ConstFst<StdArc>,
  287. ArcLookAheadMatcher<SortedMatcher<ConstFst<StdArc>>>,
  288. arc_lookahead_fst_type>;
  289. inline constexpr char ilabel_lookahead_fst_type[] = "ilabel_lookahead";
  290. inline constexpr char olabel_lookahead_fst_type[] = "olabel_lookahead";
  291. constexpr auto ilabel_lookahead_flags =
  292. kInputLookAheadMatcher | kLookAheadWeight | kLookAheadPrefix |
  293. kLookAheadEpsilons | kLookAheadNonEpsilonPrefix;
  294. constexpr auto olabel_lookahead_flags =
  295. kOutputLookAheadMatcher | kLookAheadWeight | kLookAheadPrefix |
  296. kLookAheadEpsilons | kLookAheadNonEpsilonPrefix;
  297. using StdILabelLookAheadFst = MatcherFst<
  298. ConstFst<StdArc>,
  299. LabelLookAheadMatcher<SortedMatcher<ConstFst<StdArc>>,
  300. ilabel_lookahead_flags, FastLogAccumulator<StdArc>>,
  301. ilabel_lookahead_fst_type, LabelLookAheadRelabeler<StdArc>>;
  302. using StdOLabelLookAheadFst = MatcherFst<
  303. ConstFst<StdArc>,
  304. LabelLookAheadMatcher<SortedMatcher<ConstFst<StdArc>>,
  305. olabel_lookahead_flags, FastLogAccumulator<StdArc>>,
  306. olabel_lookahead_fst_type, LabelLookAheadRelabeler<StdArc>>;
  307. } // namespace fst
  308. #endif // FST_MATCHER_FST_H_