|
|
// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Synchronize an FST with bounded delay.
#ifndef FST_SYNCHRONIZE_H_
#define FST_SYNCHRONIZE_H_
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include <fst/cache.h>
#include <fst/fst.h>
#include <fst/impl-to-fst.h>
#include <fst/mutable-fst.h>
#include <fst/properties.h>
#include <unordered_map>
#include <unordered_set>
namespace fst {
// Options for SynchronizeFst. Synchronization has no options of its own, so
// this is simply the generic cache configuration (see CacheOptions).
using SynchronizeFstOptions = CacheOptions;
namespace internal {
// Implementation class for SynchronizeFst.
// TODO(kbg,sorenj): Refactor to guarantee thread-safety.
template <class Arc>
class SynchronizeFstImpl : public CacheImpl<Arc> {
 public:
  using Label = typename Arc::Label;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  using FstImpl<Arc>::SetType;
  using FstImpl<Arc>::SetProperties;
  using FstImpl<Arc>::SetInputSymbols;
  using FstImpl<Arc>::SetOutputSymbols;

  using CacheBaseImpl<CacheState<Arc>>::EmplaceArc;
  using CacheBaseImpl<CacheState<Arc>>::HasArcs;
  using CacheBaseImpl<CacheState<Arc>>::HasFinal;
  using CacheBaseImpl<CacheState<Arc>>::HasStart;
  using CacheBaseImpl<CacheState<Arc>>::SetArcs;
  using CacheBaseImpl<CacheState<Arc>>::SetFinal;
  using CacheBaseImpl<CacheState<Arc>>::SetStart;

  // To avoid using `std::char_traits<Label>`, which is not guaranteed to exist,
  // use `char32_t` for the backing strings instead of `Label`. We should
  // probably use our own traits type instead.
  static_assert(sizeof(Label) <= sizeof(char32_t),
                "Label must fit in 32 bits. This is a hack.");
  using String = std::basic_string<char32_t>;
  using StringView = std::basic_string_view<char32_t>;

  // A state of the synchronized FST: a state of the input FST together with
  // the input and output labels that have been read but not yet emitted.
  struct Element {
    Element() = default;

    Element(StateId state_, StringView i, StringView o)
        : state(state_), istring(i), ostring(o) {}

    // Input state ID, or kNoStateId for a state that only flushes residual
    // labels after a final state (see Expand() and Final()).
    StateId state = kNoStateId;
    StringView istring;  // Residual input labels.
    StringView ostring;  // Residual output labels.
    // Residual strings are represented by std::basic_string_view<char32_t>
    // whose storage is owned by the interning hash set string_set_; always
    // obtain them through FindString() so that ElementEqual stays valid.
  };

  SynchronizeFstImpl(const Fst<Arc> &fst, const SynchronizeFstOptions &opts)
      : CacheImpl<Arc>(opts), fst_(fst.Copy()) {
    SetType("synchronize");
    const auto props = fst.Properties(kFstProperties, false);
    SetProperties(SynchronizeProperties(props), kCopyProperties);
    SetInputSymbols(fst.InputSymbols());
    SetOutputSymbols(fst.OutputSymbols());
  }

  SynchronizeFstImpl(const SynchronizeFstImpl &impl)
      : CacheImpl<Arc>(impl), fst_(impl.fst_->Copy(true)) {
    SetType("synchronize");
    SetProperties(impl.Properties(), kCopyProperties);
    SetInputSymbols(impl.InputSymbols());
    SetOutputSymbols(impl.OutputSymbols());
  }

  // Returns the start state, creating it on first call. The start state pairs
  // the input FST's start state with empty residual strings. Returns
  // kNoStateId (without caching) if the input FST has no start state.
  StateId Start() {
    if (!HasStart()) {
      const auto start = fst_->Start();
      if (start == kNoStateId) return kNoStateId;
      const StringView empty = FindString(String());
      SetStart(FindState(Element(start, empty, empty)));
    }
    return CacheImpl<Arc>::Start();
  }

  // A state is final only when no residual labels remain; its weight is the
  // input state's final weight, or One for a flush-only (kNoStateId) state.
  Weight Final(StateId s) {
    if (!HasFinal(s)) {
      const auto &element = elements_[s];
      const auto weight = element.state == kNoStateId
                              ? Weight::One()
                              : fst_->Final(element.state);
      if ((weight != Weight::Zero()) && element.istring.empty() &&
          element.ostring.empty()) {
        SetFinal(s, weight);
      } else {
        SetFinal(s, Weight::Zero());
      }
    }
    return CacheImpl<Arc>::Final(s);
  }

  size_t NumArcs(StateId s) {
    if (!HasArcs(s)) Expand(s);
    return CacheImpl<Arc>::NumArcs(s);
  }

  size_t NumInputEpsilons(StateId s) {
    if (!HasArcs(s)) Expand(s);
    return CacheImpl<Arc>::NumInputEpsilons(s);
  }

  size_t NumOutputEpsilons(StateId s) {
    if (!HasArcs(s)) Expand(s);
    return CacheImpl<Arc>::NumOutputEpsilons(s);
  }

  uint64_t Properties() const override { return Properties(kFstProperties); }

  // Sets error if found, returning other FST impl properties.
  uint64_t Properties(uint64_t mask) const override {
    if ((mask & kError) && fst_->Properties(kError, false)) {
      SetProperties(kError, kError);
    }
    return FstImpl<Arc>::Properties(mask);
  }

  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
    if (!HasArcs(s)) Expand(s);
    CacheImpl<Arc>::InitArcIterator(s, data);
  }

  // Returns the first character of the string obtained by concatenating the
  // string and the label (a label of 0 contributes nothing).
  Label Car(StringView str, Label label = 0) const {
    if (!str.empty()) {
      return str[0];
    } else {
      return label;
    }
  }

  // Computes the residual string obtained by removing the first
  // character in the concatenation of the string and the label.
  StringView Cdr(StringView str, Label label = 0) {
    if (str.empty()) return FindString(String());
    return Concat(str.substr(1), label);
  }

  // Computes the concatenation of the string and the label; a label of 0 is
  // not appended.
  StringView Concat(StringView str, Label label = 0) {
    String r(str);
    if (label) r.push_back(label);
    return FindString(std::move(r));
  }

  // Tests if the concatenation of the string and label is empty.
  bool Empty(StringView str, Label label = 0) const {
    if (str.empty()) {
      return label == 0;
    } else {
      return false;
    }
  }

  // Interns `str` in string_set_ and returns a view of the stored copy. Equal
  // strings always yield views with the same data() pointer; the views stay
  // valid for the lifetime of this object because unordered_set never moves
  // its elements.
  StringView FindString(String &&str) {
    const auto [str_it, unused] = string_set_.insert(std::move(str));
    return *str_it;
  }

  // Finds state corresponding to an element. Creates new state if element
  // is not found.
  StateId FindState(const Element &element) {
    const auto [iter, inserted] =
        element_map_.emplace(element, elements_.size());
    if (inserted) elements_.push_back(element);
    return iter->second;
  }

  // Computes the outgoing transitions from a state, creating new destination
  // states as needed.
  void Expand(StateId s) {
    // Copied rather than referenced: FindState() below may grow elements_ and
    // invalidate references into it.
    const auto element = elements_[s];
    if (element.state != kNoStateId) {
      for (ArcIterator<Fst<Arc>> aiter(*fst_, element.state); !aiter.Done();
           aiter.Next()) {
        const auto &arc = aiter.Value();
        if (!Empty(element.istring, arc.ilabel) &&
            !Empty(element.ostring, arc.olabel)) {
          // Both sides have at least one label available: emit the first
          // label of each side and keep the remainders as residuals.
          const StringView istring = Cdr(element.istring, arc.ilabel);
          const StringView ostring = Cdr(element.ostring, arc.olabel);
          EmplaceArc(s, Car(element.istring, arc.ilabel),
                     Car(element.ostring, arc.olabel), arc.weight,
                     FindState(Element(arc.nextstate, istring, ostring)));
        } else {
          // At least one side has nothing to emit: buffer both labels and
          // take a 0:0 (epsilon) transition.
          const StringView istring = Concat(element.istring, arc.ilabel);
          const StringView ostring = Concat(element.ostring, arc.olabel);
          EmplaceArc(s, 0, 0, arc.weight,
                     FindState(Element(arc.nextstate, istring, ostring)));
        }
      }
    }
    // A final state that still has buffered labels flushes them one pair at
    // a time through flush-only (kNoStateId) states.
    const auto weight = element.state == kNoStateId
                            ? Weight::One()
                            : fst_->Final(element.state);
    if ((weight != Weight::Zero()) &&
        (element.istring.size() + element.ostring.size() > 0)) {
      const StringView istring = Cdr(element.istring);
      const StringView ostring = Cdr(element.ostring);
      EmplaceArc(s, Car(element.istring), Car(element.ostring), weight,
                 FindState(Element(kNoStateId, istring, ostring)));
    }
    SetArcs(s);
  }

 private:
  // Equality function for Elements. Assumes the strings were interned through
  // FindString(), so equal strings share storage and can be compared by
  // address.
  class ElementEqual {
   public:
    bool operator()(const Element &x, const Element &y) const {
      return x.state == y.state && x.istring.data() == y.istring.data() &&
             x.ostring.data() == y.ostring.data();
    }
  };

  // Hash function for Elements to FST states. Hashes string contents, which
  // is consistent with ElementEqual because strings are interned.
  class ElementKey {
   public:
    size_t operator()(const Element &x) const {
      size_t key = x.state;
      key = (key << 1) ^ x.istring.size();
      for (Label label : x.istring) {
        key = (key << 1) ^ label;
      }
      key = (key << 1) ^ x.ostring.size();
      for (Label label : x.ostring) {
        key = (key << 1) ^ label;
      }
      return key;
    }
  };

  // Hash function for set of strings. This only has to be specified since
  // `std::hash<std::basic_string<T>>` is only guaranteed to be defined for
  // certain values of `T`. Not defining this works fine on clang, but fails
  // under GCC.
  class StringKey {
   public:
    size_t operator()(StringView x) const {
      size_t key = x.size();
      for (Label label : x) key = (key << 1) ^ label;
      return key;
    }
  };

  using ElementMap =
      std::unordered_map<Element, StateId, ElementKey, ElementEqual>;
  using StringSet = std::unordered_set<String, StringKey>;

  std::unique_ptr<const Fst<Arc>> fst_;
  std::vector<Element> elements_;  // Maps FST state to Elements.
  ElementMap element_map_;         // Maps Elements to FST state.
  StringSet string_set_;           // Owns the interned residual strings.
};
} // namespace internal
// Synchronizes a transducer. This version is a delayed FST. The result is an
// equivalent FST that has the property that during the traversal of a path,
// the delay is either zero or strictly increasing, where the delay is the
// difference between the number of non-epsilon output labels and input labels
// along the path.
//
// For the algorithm to terminate, the input transducer must have bounded
// delay, i.e., the delay of every cycle must be zero.
//
// Complexity:
//
// - A has bounded delay: exponential.
// - A does not have bounded delay: does not terminate.
//
// For more information, see:
//
// Mohri, M. 2003. Edit-distance of weighted automata: General definitions and
// algorithms. International Journal of Foundations of Computer Science
// 14(6): 957-982.
//
// This class attaches interface to implementation and handles reference
// counting, delegating most methods to ImplToFst.
template <class A>
class SynchronizeFst : public ImplToFst<internal::SynchronizeFstImpl<A>> {
  // Shorthand for the reference-counting base that owns the implementation.
  using Base = ImplToFst<internal::SynchronizeFstImpl<A>>;

 public:
  using Arc = A;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  using Store = DefaultCacheStore<Arc>;
  using State = typename Store::State;
  using Impl = internal::SynchronizeFstImpl<A>;

  // The iterator specializations need access to the (private) impl getters.
  friend class ArcIterator<SynchronizeFst<A>>;
  friend class StateIterator<SynchronizeFst<A>>;

  // Wraps `fst` in a lazily-expanded synchronized view; `opts` controls the
  // state cache.
  explicit SynchronizeFst(
      const Fst<A> &fst,
      const SynchronizeFstOptions &opts = SynchronizeFstOptions())
      : Base(std::make_shared<Impl>(fst, opts)) {}

  // See Fst<>::Copy() for doc.
  SynchronizeFst(const SynchronizeFst &fst, bool safe = false)
      : Base(fst, safe) {}

  // Gets a copy of this SynchronizeFst. See Fst<>::Copy() for further doc.
  SynchronizeFst *Copy(bool safe = false) const override {
    return new SynchronizeFst(*this, safe);
  }

  inline void InitStateIterator(StateIteratorData<Arc> *data) const override;

  // Delegates to the implementation, which expands state `s` on demand.
  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
    GetMutableImpl()->InitArcIterator(s, data);
  }

 private:
  using Base::GetImpl;
  using Base::GetMutableImpl;

  SynchronizeFst &operator=(const SynchronizeFst &) = delete;
};
// Specialization for SynchronizeFst.
template <class Arc>
class StateIterator<SynchronizeFst<Arc>>
    : public CacheStateIterator<SynchronizeFst<Arc>> {
  using Base = CacheStateIterator<SynchronizeFst<Arc>>;

 public:
  // Iterates over the states of `fst`, driving expansion through its impl.
  explicit StateIterator(const SynchronizeFst<Arc> &fst)
      : Base(fst, fst.GetMutableImpl()) {}
};
// Specialization for SynchronizeFst.
template <class Arc>
class ArcIterator<SynchronizeFst<Arc>>
    : public CacheArcIterator<SynchronizeFst<Arc>> {
  using Base = CacheArcIterator<SynchronizeFst<Arc>>;

 public:
  using StateId = typename Arc::StateId;

  // Attaches to the cached state `s`, expanding it first if its arcs have not
  // been computed yet.
  ArcIterator(const SynchronizeFst<Arc> &fst, StateId s)
      : Base(fst.GetMutableImpl(), s) {
    if (fst.GetImpl()->HasArcs(s)) return;
    fst.GetMutableImpl()->Expand(s);
  }
};
// Installs a StateIterator specialized for SynchronizeFst into `data`.
template <class Arc>
inline void SynchronizeFst<Arc>::InitStateIterator(
    StateIteratorData<Arc> *data) const {
  using Iterator = StateIterator<SynchronizeFst<Arc>>;
  data->base = std::make_unique<Iterator>(*this);
}
// Synchronizes a transducer. This version writes the synchronized result to a
// MutableFst. The result will be an equivalent FST that has the property that
// during the traversal of a path, the delay is either zero or strictly
// increasing, where the delay is the difference between the number of
// non-epsilon output labels and input labels along the path.
//
// For the algorithm to terminate, the input transducer must have bounded
// delay, i.e., the delay of every cycle must be zero.
//
// Complexity:
//
// - A has bounded delay: exponential.
// - A does not have bounded delay: does not terminate.
//
// For more information, see:
//
// Mohri, M. 2003. Edit-distance of weighted automata: General definitions and
// algorithms. International Journal of Foundations of Computer Science
// 14(6): 957-982.
template <class Arc> void Synchronize(const Fst<Arc> &ifst, MutableFst<Arc> *ofst) { // Caches only the last state for fastest copy.
const SynchronizeFstOptions opts(FST_FLAGS_fst_default_cache_gc, 0); *ofst = SynchronizeFst<Arc>(ifst, opts); }
} // namespace fst
#endif // FST_SYNCHRONIZE_H_
|