You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

419 lines
14 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Synchronize an FST with bounded delay.
  19. #ifndef FST_SYNCHRONIZE_H_
  20. #define FST_SYNCHRONIZE_H_
  21. #include <algorithm>
  22. #include <cstddef>
  23. #include <cstdint>
  24. #include <functional>
  25. #include <memory>
  26. #include <string>
  27. #include <string_view>
  28. #include <utility>
  29. #include <vector>
  30. #include <fst/cache.h>
  31. #include <fst/fst.h>
  32. #include <fst/impl-to-fst.h>
  33. #include <fst/mutable-fst.h>
  34. #include <fst/properties.h>
  35. #include <unordered_map>
  36. #include <unordered_set>
  37. namespace fst {
  38. using SynchronizeFstOptions = CacheOptions;
  39. namespace internal {
  40. // Implementation class for SynchronizeFst.
  41. // TODO(kbg,sorenj): Refactor to guarantee thread-safety.
  42. template <class Arc>
  43. class SynchronizeFstImpl : public CacheImpl<Arc> {
  44. public:
  45. using Label = typename Arc::Label;
  46. using StateId = typename Arc::StateId;
  47. using Weight = typename Arc::Weight;
  48. using FstImpl<Arc>::SetType;
  49. using FstImpl<Arc>::SetProperties;
  50. using FstImpl<Arc>::SetInputSymbols;
  51. using FstImpl<Arc>::SetOutputSymbols;
  52. using CacheBaseImpl<CacheState<Arc>>::EmplaceArc;
  53. using CacheBaseImpl<CacheState<Arc>>::HasArcs;
  54. using CacheBaseImpl<CacheState<Arc>>::HasFinal;
  55. using CacheBaseImpl<CacheState<Arc>>::HasStart;
  56. using CacheBaseImpl<CacheState<Arc>>::SetArcs;
  57. using CacheBaseImpl<CacheState<Arc>>::SetFinal;
  58. using CacheBaseImpl<CacheState<Arc>>::SetStart;
  59. // To avoid using `std::char_traits<Label>`, which is not guaranteed to exist,
  60. // use `char32_t` for the backing strings instead of `Label`. We should
  61. // probably use our own traits type instead.
  62. static_assert(sizeof(Label) <= sizeof(char32_t),
  63. "Label must fit in 32 bits. This is a hack.");
  64. using String = std::basic_string<char32_t>;
  65. using StringView = std::basic_string_view<char32_t>;
  66. struct Element {
  67. Element() = default;
  68. Element(StateId state_, StringView i, StringView o)
  69. : state(state_), istring(i), ostring(o) {}
  70. StateId state; // Input state ID.
  71. StringView istring; // Residual input labels.
  72. StringView ostring; // Residual output labels.
  73. // Residual strings are represented by std::basic_string_view<Label> whose
  74. // values are owned by the hash set string_set_.
  75. };
  76. SynchronizeFstImpl(const Fst<Arc> &fst, const SynchronizeFstOptions &opts)
  77. : CacheImpl<Arc>(opts), fst_(fst.Copy()) {
  78. SetType("synchronize");
  79. const auto props = fst.Properties(kFstProperties, false);
  80. SetProperties(SynchronizeProperties(props), kCopyProperties);
  81. SetInputSymbols(fst.InputSymbols());
  82. SetOutputSymbols(fst.OutputSymbols());
  83. }
  84. SynchronizeFstImpl(const SynchronizeFstImpl &impl)
  85. : CacheImpl<Arc>(impl), fst_(impl.fst_->Copy(true)) {
  86. SetType("synchronize");
  87. SetProperties(impl.Properties(), kCopyProperties);
  88. SetInputSymbols(impl.InputSymbols());
  89. SetOutputSymbols(impl.OutputSymbols());
  90. }
  91. StateId Start() {
  92. if (!HasStart()) {
  93. auto start = fst_->Start();
  94. if (start == kNoStateId) return kNoStateId;
  95. const StringView empty = FindString(String());
  96. start = FindState(Element(fst_->Start(), empty, empty));
  97. SetStart(start);
  98. }
  99. return CacheImpl<Arc>::Start();
  100. }
  101. Weight Final(StateId s) {
  102. if (!HasFinal(s)) {
  103. const auto &element = elements_[s];
  104. const auto weight = element.state == kNoStateId
  105. ? Weight::One()
  106. : fst_->Final(element.state);
  107. if ((weight != Weight::Zero()) && element.istring.empty() &&
  108. element.ostring.empty()) {
  109. SetFinal(s, weight);
  110. } else {
  111. SetFinal(s, Weight::Zero());
  112. }
  113. }
  114. return CacheImpl<Arc>::Final(s);
  115. }
  116. size_t NumArcs(StateId s) {
  117. if (!HasArcs(s)) Expand(s);
  118. return CacheImpl<Arc>::NumArcs(s);
  119. }
  120. size_t NumInputEpsilons(StateId s) {
  121. if (!HasArcs(s)) Expand(s);
  122. return CacheImpl<Arc>::NumInputEpsilons(s);
  123. }
  124. size_t NumOutputEpsilons(StateId s) {
  125. if (!HasArcs(s)) Expand(s);
  126. return CacheImpl<Arc>::NumOutputEpsilons(s);
  127. }
  128. uint64_t Properties() const override { return Properties(kFstProperties); }
  129. // Sets error if found, returning other FST impl properties.
  130. uint64_t Properties(uint64_t mask) const override {
  131. if ((mask & kError) && fst_->Properties(kError, false)) {
  132. SetProperties(kError, kError);
  133. }
  134. return FstImpl<Arc>::Properties(mask);
  135. }
  136. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
  137. if (!HasArcs(s)) Expand(s);
  138. CacheImpl<Arc>::InitArcIterator(s, data);
  139. }
  140. // Returns the first character of the string obtained by concatenating the
  141. // string and the label.
  142. Label Car(StringView str, Label label = 0) const {
  143. if (!str.empty()) {
  144. return str[0];
  145. } else {
  146. return label;
  147. }
  148. }
  149. // Computes the residual string obtained by removing the first
  150. // character in the concatenation of the string and the label.
  151. StringView Cdr(StringView str, Label label = 0) {
  152. if (str.empty()) return FindString(String());
  153. return Concat(str.substr(1), label);
  154. }
  155. // Computes the concatenation of the string and the label.
  156. StringView Concat(StringView str, Label label = 0) {
  157. String r(str);
  158. if (label) r.push_back(label);
  159. return FindString(std::move(r));
  160. }
  161. // Tests if the concatenation of the string and label is empty.
  162. bool Empty(StringView str, Label label = 0) const {
  163. if (str.empty()) {
  164. return label == 0;
  165. } else {
  166. return false;
  167. }
  168. }
  169. StringView FindString(String &&str) {
  170. const auto [str_it, unused] = string_set_.insert(std::forward<String>(str));
  171. return *str_it;
  172. }
  173. // Finds state corresponding to an element. Creates new state if element
  174. // is not found.
  175. StateId FindState(const Element &element) {
  176. const auto &[iter, inserted] =
  177. element_map_.emplace(element, elements_.size());
  178. if (inserted) {
  179. elements_.push_back(element);
  180. }
  181. return iter->second;
  182. }
  183. // Computes the outgoing transitions from a state, creating new destination
  184. // states as needed.
  185. void Expand(StateId s) {
  186. const auto element = elements_[s];
  187. if (element.state != kNoStateId) {
  188. for (ArcIterator<Fst<Arc>> aiter(*fst_, element.state); !aiter.Done();
  189. aiter.Next()) {
  190. const auto &arc = aiter.Value();
  191. if (!Empty(element.istring, arc.ilabel) &&
  192. !Empty(element.ostring, arc.olabel)) {
  193. StringView istring = Cdr(element.istring, arc.ilabel);
  194. StringView ostring = Cdr(element.ostring, arc.olabel);
  195. EmplaceArc(s, Car(element.istring, arc.ilabel),
  196. Car(element.ostring, arc.olabel), arc.weight,
  197. FindState(Element(arc.nextstate, istring, ostring)));
  198. } else {
  199. StringView istring = Concat(element.istring, arc.ilabel);
  200. StringView ostring = Concat(element.ostring, arc.olabel);
  201. EmplaceArc(s, 0, 0, arc.weight,
  202. FindState(Element(arc.nextstate, istring, ostring)));
  203. }
  204. }
  205. }
  206. const auto weight = element.state == kNoStateId
  207. ? Weight::One()
  208. : fst_->Final(element.state);
  209. if ((weight != Weight::Zero()) &&
  210. (element.istring.size() + element.ostring.size() > 0)) {
  211. StringView istring = Cdr(element.istring);
  212. StringView ostring = Cdr(element.ostring);
  213. EmplaceArc(s, Car(element.istring), Car(element.ostring), weight,
  214. FindState(Element(kNoStateId, istring, ostring)));
  215. }
  216. SetArcs(s);
  217. }
  218. private:
  219. // Equality function for Elements; assumes strings have been hashed.
  220. class ElementEqual {
  221. public:
  222. bool operator()(const Element &x, const Element &y) const {
  223. return x.state == y.state && x.istring.data() == y.istring.data() &&
  224. x.ostring.data() == y.ostring.data();
  225. }
  226. };
  227. // Hash function for Elements to FST states.
  228. class ElementKey {
  229. public:
  230. size_t operator()(const Element &x) const {
  231. size_t key = x.state;
  232. key = (key << 1) ^ x.istring.size();
  233. for (Label label : x.istring) {
  234. key = (key << 1) ^ label;
  235. }
  236. key = (key << 1) ^ x.ostring.size();
  237. for (Label label : x.ostring) {
  238. key = (key << 1) ^ label;
  239. }
  240. return key;
  241. }
  242. };
  243. // Hash function for set of strings. This only has to be specified since
  244. // `std::hash<std::basic_string<T>>` is only guaranteed to be defined for
  245. // certain values of `T`. Not defining this works fine on clang, but fails
  246. // under GCC.
  247. class StringKey {
  248. public:
  249. size_t operator()(StringView x) const {
  250. size_t key = x.size();
  251. for (Label label : x) key = (key << 1) ^ label;
  252. return key;
  253. }
  254. };
  255. using ElementMap =
  256. std::unordered_map<Element, StateId, ElementKey, ElementEqual>;
  257. using StringSet = std::unordered_set<String, StringKey>;
  258. std::unique_ptr<const Fst<Arc>> fst_;
  259. std::vector<Element> elements_; // Maps FST state to Elements.
  260. ElementMap element_map_; // Maps Elements to FST state.
  261. StringSet string_set_;
  262. };
  263. } // namespace internal
  264. // Synchronizes a transducer. This version is a delayed FST. The result is an
  265. // equivalent FST that has the property that during the traversal of a path,
  266. // the delay is either zero or strictly increasing, where the delay is the
  267. // difference between the number of non-epsilon output labels and input labels
  268. // along the path.
  269. //
  270. // For the algorithm to terminate, the input transducer must have bounded
  271. // delay, i.e., the delay of every cycle must be zero.
  272. //
  273. // Complexity:
  274. //
  275. // - A has bounded delay: exponential.
  276. // - A does not have bounded delay: does not terminate.
  277. //
  278. // For more information, see:
  279. //
  280. // Mohri, M. 2003. Edit-distance of weighted automata: General definitions and
  281. // algorithms. International Journal of Computer Science 14(6): 957-982.
  282. //
  283. // This class attaches interface to implementation and handles reference
  284. // counting, delegating most methods to ImplToFst.
  285. template <class A>
  286. class SynchronizeFst : public ImplToFst<internal::SynchronizeFstImpl<A>> {
  287. public:
  288. using Arc = A;
  289. using StateId = typename Arc::StateId;
  290. using Weight = typename Arc::Weight;
  291. using Store = DefaultCacheStore<Arc>;
  292. using State = typename Store::State;
  293. using Impl = internal::SynchronizeFstImpl<A>;
  294. friend class ArcIterator<SynchronizeFst<A>>;
  295. friend class StateIterator<SynchronizeFst<A>>;
  296. explicit SynchronizeFst(const Fst<A> &fst, const SynchronizeFstOptions &opts =
  297. SynchronizeFstOptions())
  298. : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}
  299. // See Fst<>::Copy() for doc.
  300. SynchronizeFst(const SynchronizeFst &fst, bool safe = false)
  301. : ImplToFst<Impl>(fst, safe) {}
  302. // Gets a copy of this SynchronizeFst. See Fst<>::Copy() for further doc.
  303. SynchronizeFst *Copy(bool safe = false) const override {
  304. return new SynchronizeFst(*this, safe);
  305. }
  306. inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
  307. void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
  308. GetMutableImpl()->InitArcIterator(s, data);
  309. }
  310. private:
  311. using ImplToFst<Impl>::GetImpl;
  312. using ImplToFst<Impl>::GetMutableImpl;
  313. SynchronizeFst &operator=(const SynchronizeFst &) = delete;
  314. };
  315. // Specialization for SynchronizeFst.
  316. template <class Arc>
  317. class StateIterator<SynchronizeFst<Arc>>
  318. : public CacheStateIterator<SynchronizeFst<Arc>> {
  319. public:
  320. explicit StateIterator(const SynchronizeFst<Arc> &fst)
  321. : CacheStateIterator<SynchronizeFst<Arc>>(fst, fst.GetMutableImpl()) {}
  322. };
  323. // Specialization for SynchronizeFst.
  324. template <class Arc>
  325. class ArcIterator<SynchronizeFst<Arc>>
  326. : public CacheArcIterator<SynchronizeFst<Arc>> {
  327. public:
  328. using StateId = typename Arc::StateId;
  329. ArcIterator(const SynchronizeFst<Arc> &fst, StateId s)
  330. : CacheArcIterator<SynchronizeFst<Arc>>(fst.GetMutableImpl(), s) {
  331. if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
  332. }
  333. };
  334. template <class Arc>
  335. inline void SynchronizeFst<Arc>::InitStateIterator(
  336. StateIteratorData<Arc> *data) const {
  337. data->base = std::make_unique<StateIterator<SynchronizeFst<Arc>>>(*this);
  338. }
  339. // Synchronizes a transducer. This version writes the synchronized result to a
  340. // MutableFst. The result will be an equivalent FST that has the property that
  341. // during the traversal of a path, the delay is either zero or strictly
  342. // increasing, where the delay is the difference between the number of
  343. // non-epsilon output labels and input labels along the path.
  344. //
  345. // For the algorithm to terminate, the input transducer must have bounded
  346. // delay, i.e., the delay of every cycle must be zero.
  347. //
  348. // Complexity:
  349. //
  350. // - A has bounded delay: exponential.
  351. // - A does not have bounded delay: does not terminate.
  352. //
  353. // For more information, see:
  354. //
  355. // Mohri, M. 2003. Edit-distance of weighted automata: General definitions and
  356. // algorithms. International Journal of Computer Science 14(6): 957-982.
  357. template <class Arc>
  358. void Synchronize(const Fst<Arc> &ifst, MutableFst<Arc> *ofst) {
  359. // Caches only the last state for fastest copy.
  360. const SynchronizeFstOptions opts(FST_FLAGS_fst_default_cache_gc,
  361. 0);
  362. *ofst = SynchronizeFst<Arc>(ifst, opts);
  363. }
  364. } // namespace fst
  365. #endif // FST_SYNCHRONIZE_H_