You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

419 lines
14 KiB

// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Synchronize an FST with bounded delay.
#ifndef FST_SYNCHRONIZE_H_
#define FST_SYNCHRONIZE_H_
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include <fst/cache.h>
#include <fst/fst.h>
#include <fst/impl-to-fst.h>
#include <fst/mutable-fst.h>
#include <fst/properties.h>
#include <unordered_map>
#include <unordered_set>
namespace fst {
// Options accepted by SynchronizeFst; this is an alias for CacheOptions.
using SynchronizeFstOptions = CacheOptions;
namespace internal {
// Implementation class for SynchronizeFst.
// TODO(kbg,sorenj): Refactor to guarantee thread-safety.
template <class Arc>
class SynchronizeFstImpl : public CacheImpl<Arc> {
public:
using Label = typename Arc::Label;
using StateId = typename Arc::StateId;
using Weight = typename Arc::Weight;
using FstImpl<Arc>::SetType;
using FstImpl<Arc>::SetProperties;
using FstImpl<Arc>::SetInputSymbols;
using FstImpl<Arc>::SetOutputSymbols;
using CacheBaseImpl<CacheState<Arc>>::EmplaceArc;
using CacheBaseImpl<CacheState<Arc>>::HasArcs;
using CacheBaseImpl<CacheState<Arc>>::HasFinal;
using CacheBaseImpl<CacheState<Arc>>::HasStart;
using CacheBaseImpl<CacheState<Arc>>::SetArcs;
using CacheBaseImpl<CacheState<Arc>>::SetFinal;
using CacheBaseImpl<CacheState<Arc>>::SetStart;
// To avoid using `std::char_traits<Label>`, which is not guaranteed to exist,
// use `char32_t` for the backing strings instead of `Label`. We should
// probably use our own traits type instead.
static_assert(sizeof(Label) <= sizeof(char32_t),
"Label must fit in 32 bits. This is a hack.");
using String = std::basic_string<char32_t>;
using StringView = std::basic_string_view<char32_t>;
// A state of the synchronized FST: a state of the input FST paired with the
// input and output labels that are still pending (the "residual" strings).
struct Element {
  Element() = default;
  Element(StateId state_, StringView i, StringView o)
      : state(state_), istring(i), ostring(o) {}

  // Input state ID. Defaults to kNoStateId so that a default-constructed
  // Element never carries an indeterminate value.
  StateId state = kNoStateId;
  // Residual input labels.
  StringView istring;
  // Residual output labels.
  StringView ostring;
  // Residual strings are StringViews, i.e. std::basic_string_view<char32_t>
  // (not std::basic_string_view<Label>), whose values are owned by the hash
  // set string_set_.
};
// Builds the implementation on top of a copy of `fst`. The type string is set
// to "synchronize", properties are derived through SynchronizeProperties(),
// and both symbol tables are taken from `fst`.
SynchronizeFstImpl(const Fst<Arc> &fst, const SynchronizeFstOptions &opts)
    : CacheImpl<Arc>(opts), fst_(fst.Copy()) {
  SetType("synchronize");

  // Map the properties reported by the input onto those of the result.
  const auto input_props = fst.Properties(kFstProperties, false);
  const auto sync_props = SynchronizeProperties(input_props);
  SetProperties(sync_props, kCopyProperties);

  SetInputSymbols(fst.InputSymbols());
  SetOutputSymbols(fst.OutputSymbols());
}
// Copy constructor. The wrapped FST is duplicated with Copy(true); the type
// string, properties and symbol tables are carried over from `impl`. The
// element and string tables are not copied and start out empty.
SynchronizeFstImpl(const SynchronizeFstImpl &impl)
    : CacheImpl<Arc>(impl), fst_(impl.fst_->Copy(true)) {
  SetType("synchronize");

  const auto copied_props = impl.Properties();
  SetProperties(copied_props, kCopyProperties);

  SetInputSymbols(impl.InputSymbols());
  SetOutputSymbols(impl.OutputSymbols());
}
// Returns the start state of the synchronized FST, computing and caching it
// on first use. Returns kNoStateId (without caching anything) when the input
// FST has no start state.
StateId Start() {
  if (!HasStart()) {
    // Query the input FST once; the value is reused to build the start
    // element instead of calling fst_->Start() a second time.
    const StateId input_start = fst_->Start();
    if (input_start == kNoStateId) return kNoStateId;
    // The start element pairs the input start state with empty residuals.
    const StringView empty = FindString(String());
    const StateId sync_start = FindState(Element(input_start, empty, empty));
    SetStart(sync_start);
  }
  return CacheImpl<Arc>::Start();
}
// Returns the final weight of state s, computing and caching it on first use.
// A state is final only when both residual strings are empty and the
// underlying weight is not Weight::Zero(); an element whose input state is
// kNoStateId uses Weight::One() as its underlying weight.
Weight Final(StateId s) {
  if (HasFinal(s)) return CacheImpl<Arc>::Final(s);

  const auto &elem = elements_[s];
  const auto weight = elem.state == kNoStateId ? Weight::One()
                                               : fst_->Final(elem.state);
  const bool accepting = (weight != Weight::Zero()) &&
                         elem.istring.empty() && elem.ostring.empty();
  if (accepting) {
    SetFinal(s, weight);
  } else {
    SetFinal(s, Weight::Zero());
  }
  return CacheImpl<Arc>::Final(s);
}
// Returns the number of arcs leaving state s, expanding the state first if
// its arcs are not cached yet.
size_t NumArcs(StateId s) {
  const bool cached = HasArcs(s);
  if (!cached) {
    Expand(s);
  }
  return CacheImpl<Arc>::NumArcs(s);
}
// Returns the cached input-epsilon count of state s, expanding the state
// first if its arcs are not cached yet.
size_t NumInputEpsilons(StateId s) {
  const bool cached = HasArcs(s);
  if (!cached) {
    Expand(s);
  }
  return CacheImpl<Arc>::NumInputEpsilons(s);
}
// Returns the cached output-epsilon count of state s, expanding the state
// first if its arcs are not cached yet.
size_t NumOutputEpsilons(StateId s) {
  const bool cached = HasArcs(s);
  if (!cached) {
    Expand(s);
  }
  return CacheImpl<Arc>::NumOutputEpsilons(s);
}
// Returns all FST properties; forwards to the masked overload below.
uint64_t Properties() const override {
  return Properties(kFstProperties);
}

// Returns the properties selected by `mask`. When the mask asks about kError
// and the wrapped FST reports kError, the error bit is recorded on this
// implementation before the properties are read.
uint64_t Properties(uint64_t mask) const override {
  const bool wants_error = (mask & kError) != 0;
  if (wants_error && fst_->Properties(kError, false)) {
    SetProperties(kError, kError);
  }
  return FstImpl<Arc>::Properties(mask);
}
// Initializes `data` to iterate over the cached arcs of state s, expanding
// the state first if its arcs are not cached yet.
void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
  const bool cached = HasArcs(s);
  if (!cached) {
    Expand(s);
  }
  CacheImpl<Arc>::InitArcIterator(s, data);
}
// Returns the first symbol of `str` followed by `label`: the first character
// of `str` when it is non-empty, otherwise `label` itself.
Label Car(StringView str, Label label = 0) const {
  if (str.empty()) return label;
  return str[0];
}
// Returns what is left of `str` followed by `label` after dropping the first
// symbol. With an empty `str` the label is the symbol that gets dropped, so
// the result is the empty string.
StringView Cdr(StringView str, Label label = 0) {
  if (!str.empty()) {
    StringView tail = str;
    tail.remove_prefix(1);
    return Concat(tail, label);
  }
  return FindString(String());
}
// Returns `str` with `label` appended; a label of 0 is not appended. The
// result is looked up in (or added to) string_set_ via FindString().
StringView Concat(StringView str, Label label = 0) {
  String joined;
  joined.append(str);
  if (label != 0) {
    joined.push_back(label);
  }
  return FindString(std::move(joined));
}
// Returns true when `str` followed by `label` contains no symbols, i.e. when
// `str` is empty and `label` is 0.
bool Empty(StringView str, Label label = 0) const {
  return str.empty() && label == 0;
}
// Inserts `str` into string_set_ (a no-op if an equal string is already
// present) and returns a view of the string held by the set.
StringView FindString(String &&str) {
  return *string_set_.insert(std::move(str)).first;
}
// Returns the state ID associated with `element`. An element seen for the
// first time is assigned the next free ID (the current size of elements_)
// and appended to elements_.
StateId FindState(const Element &element) {
  const auto result = element_map_.emplace(element, elements_.size());
  const bool is_new = result.second;
  if (is_new) {
    elements_.push_back(element);
  }
  return result.first->second;
}
// Builds and caches the arcs leaving state s, creating destination states on
// demand.
//
// For each arc of the underlying input state: if both the pending input and
// the pending output (residual plus the arc's label) are non-empty, one
// symbol of each is emitted and the rest stays pending; otherwise the arc's
// labels are queued and a 0:0 arc is emitted. In addition, when the
// underlying weight is not Weight::Zero() and some residual is left, one arc
// is added that emits the next pending symbols and leads to an element whose
// input state is kNoStateId.
void Expand(StateId s) {
  // Take a copy rather than a reference: FindState() below may append to
  // elements_ and thereby invalidate references into it.
  const auto element = elements_[s];
  const bool has_input_state = element.state != kNoStateId;

  if (has_input_state) {
    ArcIterator<Fst<Arc>> aiter(*fst_, element.state);
    while (!aiter.Done()) {
      const auto &arc = aiter.Value();
      const bool can_emit = !Empty(element.istring, arc.ilabel) &&
                            !Empty(element.ostring, arc.olabel);
      Label ilabel = 0;
      Label olabel = 0;
      StringView iresidual;
      StringView oresidual;
      if (can_emit) {
        iresidual = Cdr(element.istring, arc.ilabel);
        oresidual = Cdr(element.ostring, arc.olabel);
        ilabel = Car(element.istring, arc.ilabel);
        olabel = Car(element.ostring, arc.olabel);
      } else {
        iresidual = Concat(element.istring, arc.ilabel);
        oresidual = Concat(element.ostring, arc.olabel);
      }
      const StateId dest =
          FindState(Element(arc.nextstate, iresidual, oresidual));
      EmplaceArc(s, ilabel, olabel, arc.weight, dest);
      aiter.Next();
    }
  }

  const auto weight =
      has_input_state ? fst_->Final(element.state) : Weight::One();
  const bool has_residual =
      !element.istring.empty() || !element.ostring.empty();
  if ((weight != Weight::Zero()) && has_residual) {
    const StringView iresidual = Cdr(element.istring);
    const StringView oresidual = Cdr(element.ostring);
    const StateId dest =
        FindState(Element(kNoStateId, iresidual, oresidual));
    EmplaceArc(s, Car(element.istring), Car(element.ostring), weight, dest);
  }

  SetArcs(s);
}
private:
// Equality predicate for Elements. Residual strings are compared by the
// address of their data rather than by content, which relies on both views
// having been obtained from string_set_.
class ElementEqual {
 public:
  bool operator()(const Element &x, const Element &y) const {
    if (x.state != y.state) return false;
    if (x.istring.data() != y.istring.data()) return false;
    return x.ostring.data() == y.ostring.data();
  }
};
// Hash function for Elements to FST states.
class ElementKey {
public:
size_t operator()(const Element &x) const {
size_t key = x.state;
key = (key << 1) ^ x.istring.size();
for (Label label : x.istring) {
key = (key << 1) ^ label;
}
key = (key << 1) ^ x.ostring.size();
for (Label label : x.ostring) {
key = (key << 1) ^ label;
}
return key;
}
};
// Hash function for set of strings. This only has to be specified since
// `std::hash<std::basic_string<T>>` is only guaranteed to be defined for
// certain values of `T`. Not defining this works fine on clang, but fails
// under GCC.
class StringKey {
public:
size_t operator()(StringView x) const {
size_t key = x.size();
for (Label label : x) key = (key << 1) ^ label;
return key;
}
};
using ElementMap =
std::unordered_map<Element, StateId, ElementKey, ElementEqual>;
using StringSet = std::unordered_set<String, StringKey>;
std::unique_ptr<const Fst<Arc>> fst_;
std::vector<Element> elements_; // Maps FST state to Elements.
ElementMap element_map_; // Maps Elements to FST state.
StringSet string_set_;
};
} // namespace internal
// Synchronizes a transducer. This version is a delayed FST. The result is an
// equivalent FST that has the property that during the traversal of a path,
// the delay is either zero or strictly increasing, where the delay is the
// difference between the number of non-epsilon output labels and input labels
// along the path.
//
// For the algorithm to terminate, the input transducer must have bounded
// delay, i.e., the delay of every cycle must be zero.
//
// Complexity:
//
// - A has bounded delay: exponential.
// - A does not have bounded delay: does not terminate.
//
// For more information, see:
//
// Mohri, M. 2003. Edit-distance of weighted automata: General definitions and
// algorithms. International Journal of Computer Science 14(6): 957-982.
//
// This class attaches interface to implementation and handles reference
// counting, delegating most methods to ImplToFst.
template <class A>
class SynchronizeFst : public ImplToFst<internal::SynchronizeFstImpl<A>> {
 public:
  using Arc = A;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  using Store = DefaultCacheStore<Arc>;
  using State = typename Store::State;
  using Impl = internal::SynchronizeFstImpl<Arc>;

  // The iterator specializations need access to the implementation object.
  friend class ArcIterator<SynchronizeFst<Arc>>;
  friend class StateIterator<SynchronizeFst<Arc>>;

  // Wraps `fst` in a delayed synchronization; `opts` configures the cache.
  explicit SynchronizeFst(
      const Fst<Arc> &fst,
      const SynchronizeFstOptions &opts = SynchronizeFstOptions())
      : ImplToFst<Impl>(std::make_shared<Impl>(fst, opts)) {}

  // Copy constructor; see Fst<>::Copy() for the meaning of `safe`.
  SynchronizeFst(const SynchronizeFst &fst, bool safe = false)
      : ImplToFst<Impl>(fst, safe) {}

  // Returns a newly allocated copy of this SynchronizeFst. See Fst<>::Copy()
  // for further documentation.
  SynchronizeFst *Copy(bool safe = false) const override {
    return new SynchronizeFst(*this, safe);
  }

  // Defined out of line, after the StateIterator specialization.
  inline void InitStateIterator(StateIteratorData<Arc> *data) const override;

  // Forwards to the implementation, which expands state s on demand.
  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
    GetMutableImpl()->InitArcIterator(s, data);
  }

 private:
  using ImplToFst<Impl>::GetImpl;
  using ImplToFst<Impl>::GetMutableImpl;

  // Assignment is disabled.
  SynchronizeFst &operator=(const SynchronizeFst &) = delete;
};
// StateIterator specialization for SynchronizeFst; all of the iteration
// logic is inherited from CacheStateIterator.
template <class Arc>
class StateIterator<SynchronizeFst<Arc>>
    : public CacheStateIterator<SynchronizeFst<Arc>> {
  using Base = CacheStateIterator<SynchronizeFst<Arc>>;

 public:
  explicit StateIterator(const SynchronizeFst<Arc> &fst)
      : Base(fst, fst.GetMutableImpl()) {}
};
// ArcIterator specialization for SynchronizeFst. The constructor makes sure
// state s has been expanded before the cached arcs are iterated.
template <class Arc>
class ArcIterator<SynchronizeFst<Arc>>
    : public CacheArcIterator<SynchronizeFst<Arc>> {
  using Base = CacheArcIterator<SynchronizeFst<Arc>>;

 public:
  using StateId = typename Arc::StateId;

  ArcIterator(const SynchronizeFst<Arc> &fst, StateId s)
      : Base(fst.GetMutableImpl(), s) {
    const bool expanded = fst.GetImpl()->HasArcs(s);
    if (!expanded) {
      fst.GetMutableImpl()->Expand(s);
    }
  }
};
// Installs a StateIterator specialized for SynchronizeFst as the base
// iterator of `data`.
template <class Arc>
inline void SynchronizeFst<Arc>::InitStateIterator(
    StateIteratorData<Arc> *data) const {
  using Iterator = StateIterator<SynchronizeFst<Arc>>;
  data->base = std::make_unique<Iterator>(*this);
}
// Synchronizes a transducer. This version writes the synchronized result to a
// MutableFst. The result will be an equivalent FST that has the property that
// during the traversal of a path, the delay is either zero or strictly
// increasing, where the delay is the difference between the number of
// non-epsilon output labels and input labels along the path.
//
// For the algorithm to terminate, the input transducer must have bounded
// delay, i.e., the delay of every cycle must be zero.
//
// Complexity:
//
// - A has bounded delay: exponential.
// - A does not have bounded delay: does not terminate.
//
// For more information, see:
//
// Mohri, M. 2003. Edit-distance of weighted automata: General definitions and
// algorithms. International Journal of Computer Science 14(6): 957-982.
template <class Arc>
void Synchronize(const Fst<Arc> &ifst, MutableFst<Arc> *ofst) {
// Caches only the last state for fastest copy.
const SynchronizeFstOptions opts(FST_FLAGS_fst_default_cache_gc,
0);
*ofst = SynchronizeFst<Arc>(ifst, opts);
}
} // namespace fst
#endif // FST_SYNCHRONIZE_H_