|
|
// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// An FST implementation and base interface for delayed unions, concatenations,
// and closures.
#ifndef FST_RATIONAL_H_
#define FST_RATIONAL_H_
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <fst/cache.h>
#include <fst/fst.h>
#include <fst/impl-to-fst.h>
#include <fst/mutable-fst.h>
#include <fst/properties.h>
#include <fst/replace.h>
#include <fst/vector-fst.h>
namespace fst {
// Options type accepted by the RationalFst constructor and by
// internal::RationalFstImpl; it is simply an alias for CacheOptions.
using RationalFstOptions = CacheOptions;
// Selects the flavor of closure to build, i.e., whether the empty string is
// added to the result.
enum ClosureType {
  // Add the empty string.
  CLOSURE_STAR = 0,
  // Don't add the empty string.
  CLOSURE_PLUS = 1
};
// Forward declarations. RationalFst names the destructive rational operations
// below as friends, so both the class template and the function templates must
// be declared before the class is defined.
template <class Arc>
class RationalFst;

template <class Arc>
void Union(RationalFst<Arc> *fst1, const Fst<Arc> &fst2);

template <class Arc>
void Concat(RationalFst<Arc> *fst1, const Fst<Arc> &fst2);

template <class Arc>
void Concat(const Fst<Arc> &fst1, RationalFst<Arc> *fst2);

template <class Arc>
void Closure(RationalFst<Arc> *fst, ClosureType closure_type);
namespace internal {
// Implementation class for delayed unions, concatenations and closures.
//
// The shape of the rational expression is stored in a small VectorFst
// (rfst_) whose arcs are labeled with negative non-terminals. Every
// non-terminal is paired in fst_tuples_ with a copy of an operand FST, and a
// ReplaceFst built lazily from those tuples (see Replace()) provides the
// states and arcs of the delayed result.
template <class A>
class RationalFstImpl : public FstImpl<A> {
 public:
  using Arc = A;
  using Label = typename Arc::Label;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  using FstImpl<Arc>::SetType;
  using FstImpl<Arc>::SetProperties;
  using FstImpl<Arc>::WriteHeader;
  using FstImpl<Arc>::SetInputSymbols;
  using FstImpl<Arc>::SetOutputSymbols;

  // Creates an empty implementation. One of the Init*() methods is expected to
  // be called afterwards to define the actual rational operation.
  explicit RationalFstImpl(const RationalFstOptions &opts)
      : nonterminals_(0), replace_options_(opts, 0) {
    SetType("rational");
    // Slot 0 (label 0) is reserved for the topology FST; its pointer is filled
    // in by Replace().
    fst_tuples_.emplace_back(0, nullptr);
  }

  // Copy constructor: copies the topology, obtains Copy(true) copies of every
  // operand FST and of the underlying ReplaceFst (if already built), and
  // carries over the base-class properties and symbol tables.
  RationalFstImpl(const RationalFstImpl<Arc> &impl)
      : rfst_(impl.rfst_),
        nonterminals_(impl.nonterminals_),
        replace_(impl.replace_ ? impl.replace_->Copy(true) : nullptr),
        replace_options_(impl.replace_options_) {
    SetType("rational");
    // The FstImpl base is default-constructed here, so the state the Init*()
    // and Add*() methods stored in it must be transferred explicitly. The
    // qualified call reads the stored property bits without going through the
    // Properties() override (which would force impl to build its ReplaceFst).
    SetProperties(impl.FstImpl<Arc>::Properties(), kCopyProperties);
    SetInputSymbols(impl.InputSymbols());
    SetOutputSymbols(impl.OutputSymbols());
    fst_tuples_.reserve(impl.fst_tuples_.size());
    for (const auto &pair : impl.fst_tuples_) {
      fst_tuples_.emplace_back(pair.first,
                               pair.second ? pair.second->Copy(true) : nullptr);
    }
  }

  // Releases every FST stored in fst_tuples_ (deleting a null entry is a
  // no-op).
  ~RationalFstImpl() override {
    for (auto &tuple : fst_tuples_) delete tuple.second;
  }

  // The following accessors simply forward to the underlying ReplaceFst.

  StateId Start() { return Replace()->Start(); }

  Weight Final(StateId s) { return Replace()->Final(s); }

  size_t NumArcs(StateId s) { return Replace()->NumArcs(s); }

  size_t NumInputEpsilons(StateId s) { return Replace()->NumInputEpsilons(s); }

  size_t NumOutputEpsilons(StateId s) {
    return Replace()->NumOutputEpsilons(s);
  }

  uint64_t Properties() const override { return Properties(kFstProperties); }

  // Sets error if found, and returns other FST impl properties.
  uint64_t Properties(uint64_t mask) const override {
    if ((mask & kError) && Replace()->Properties(kError, false)) {
      SetProperties(kError, kError);
    }
    return FstImpl<Arc>::Properties(mask);
  }

  // Implementation of UnionFst(fst1, fst2).
  //
  // Topology: states 0 (start) and 1 (final) joined by two parallel arcs, one
  // per operand, labeled with non-terminals -1 and -2. Symbol tables are taken
  // from fst1.
  void InitUnion(const Fst<Arc> &fst1, const Fst<Arc> &fst2) {
    replace_.reset();
    const auto props1 = fst1.Properties(kFstProperties, false);
    const auto props2 = fst2.Properties(kFstProperties, false);
    SetInputSymbols(fst1.InputSymbols());
    SetOutputSymbols(fst1.OutputSymbols());
    rfst_.AddState();
    rfst_.AddState();
    rfst_.SetStart(0);
    rfst_.SetFinal(1);
    rfst_.SetInputSymbols(fst1.InputSymbols());
    rfst_.SetOutputSymbols(fst1.OutputSymbols());
    nonterminals_ = 2;
    rfst_.EmplaceArc(0, 0, -1, Weight::One(), 1);
    rfst_.EmplaceArc(0, 0, -2, Weight::One(), 1);
    fst_tuples_.emplace_back(-1, fst1.Copy());
    fst_tuples_.emplace_back(-2, fst2.Copy());
    SetProperties(UnionProperties(props1, props2, true), kCopyProperties);
  }

  // Implementation of ConcatFst(fst1, fst2).
  //
  // Topology: a chain 0 (start) -> 1 -> 2 (final) whose two arcs are labeled
  // with non-terminals -1 and -2. Symbol tables are taken from fst1.
  void InitConcat(const Fst<Arc> &fst1, const Fst<Arc> &fst2) {
    replace_.reset();
    const auto props1 = fst1.Properties(kFstProperties, false);
    const auto props2 = fst2.Properties(kFstProperties, false);
    SetInputSymbols(fst1.InputSymbols());
    SetOutputSymbols(fst1.OutputSymbols());
    rfst_.AddState();
    rfst_.AddState();
    rfst_.AddState();
    rfst_.SetStart(0);
    rfst_.SetFinal(2);
    rfst_.SetInputSymbols(fst1.InputSymbols());
    rfst_.SetOutputSymbols(fst1.OutputSymbols());
    nonterminals_ = 2;
    rfst_.EmplaceArc(0, 0, -1, Weight::One(), 1);
    rfst_.EmplaceArc(1, 0, -2, Weight::One(), 2);
    fst_tuples_.emplace_back(-1, fst1.Copy());
    fst_tuples_.emplace_back(-2, fst2.Copy());
    SetProperties(ConcatProperties(props1, props2, true), kCopyProperties);
  }

  // Implementation of ClosureFst(fst, closure_type).
  //
  // CLOSURE_STAR: a single state that is both start and final, with a
  // self-loop labeled with non-terminal -1.
  // Otherwise: state 0 (start) reaches state 1 (final) through non-terminal
  // -1, and an arc with both labels 0 leads from state 1 back to state 0.
  void InitClosure(const Fst<Arc> &fst, ClosureType closure_type) {
    replace_.reset();
    const auto props = fst.Properties(kFstProperties, false);
    SetInputSymbols(fst.InputSymbols());
    SetOutputSymbols(fst.OutputSymbols());
    if (closure_type == CLOSURE_STAR) {
      rfst_.AddState();
      rfst_.SetStart(0);
      rfst_.SetFinal(0);
      rfst_.EmplaceArc(0, 0, -1, Weight::One(), 0);
    } else {
      rfst_.AddState();
      rfst_.AddState();
      rfst_.SetStart(0);
      rfst_.SetFinal(1);
      rfst_.EmplaceArc(0, 0, -1, Weight::One(), 1);
      rfst_.EmplaceArc(1, 0, 0, Weight::One(), 0);
    }
    rfst_.SetInputSymbols(fst.InputSymbols());
    rfst_.SetOutputSymbols(fst.OutputSymbols());
    fst_tuples_.emplace_back(-1, fst.Copy());
    nonterminals_ = 1;
    SetProperties(ClosureProperties(props, closure_type == CLOSURE_STAR, true),
                  kCopyProperties);
  }

  // Implementation of Union(Fst &, RationalFst *).
  //
  // Allocates a fresh non-terminal for fst, wraps it in a two-state FST and
  // unions that into the existing topology.
  void AddUnion(const Fst<Arc> &fst) {
    replace_.reset();
    const auto props1 = FstImpl<A>::Properties();
    const auto props2 = fst.Properties(kFstProperties, false);
    VectorFst<Arc> afst;
    afst.AddState();
    afst.AddState();
    afst.SetStart(0);
    afst.SetFinal(1);
    ++nonterminals_;
    afst.EmplaceArc(0, 0, -nonterminals_, Weight::One(), 1);
    Union(&rfst_, afst);
    fst_tuples_.emplace_back(-nonterminals_, fst.Copy());
    SetProperties(UnionProperties(props1, props2, true), kCopyProperties);
  }

  // Implementation of Concat(Fst &, RationalFst *).
  //
  // Allocates a fresh non-terminal for fst and wraps it in a two-state FST.
  // When append is true that FST is concatenated after the existing topology,
  // otherwise before it.
  void AddConcat(const Fst<Arc> &fst, bool append) {
    replace_.reset();
    const auto props1 = FstImpl<A>::Properties();
    const auto props2 = fst.Properties(kFstProperties, false);
    VectorFst<Arc> afst;
    afst.AddState();
    afst.AddState();
    afst.SetStart(0);
    afst.SetFinal(1);
    ++nonterminals_;
    afst.EmplaceArc(0, 0, -nonterminals_, Weight::One(), 1);
    if (append) {
      Concat(&rfst_, afst);
    } else {
      Concat(afst, &rfst_);
    }
    fst_tuples_.emplace_back(-nonterminals_, fst.Copy());
    SetProperties(ConcatProperties(props1, props2, true), kCopyProperties);
  }

  // Implementation of Closure(RationalFst *, closure_type).
  //
  // Applies the closure directly to the topology; no operand FST is added.
  void AddClosure(ClosureType closure_type) {
    replace_.reset();
    const auto props = FstImpl<A>::Properties();
    Closure(&rfst_, closure_type);
    SetProperties(ClosureProperties(props, closure_type == CLOSURE_STAR, true),
                  kCopyProperties);
  }

  // Returns the underlying ReplaceFst, preserving ownership of the underlying
  // object. The ReplaceFst is (re)built on demand: every Init*() / Add*() call
  // resets replace_, and the next call here constructs it again from the
  // current topology and operand tuples.
  ReplaceFst<Arc> *Replace() const {
    if (!replace_) {
      // Slot 0 owns a copy of the topology FST. Release the copy left over
      // from a previous build before storing the new one; otherwise every
      // rebuild would leak it (deleting the initial nullptr is a no-op).
      auto &root = fst_tuples_[0].second;
      delete root;
      root = rfst_.Copy();
      replace_ =
          std::make_unique<ReplaceFst<Arc>>(fst_tuples_, replace_options_);
    }
    return replace_.get();
  }

 private:
  // Rational topology machine, using negative non-terminals.
  VectorFst<Arc> rfst_;
  // Number of nonterminals used.
  Label nonterminals_;
  // Contains the nonterminals and their corresponding FSTs. The pointers are
  // owned by this object (released in the destructor and in Replace()).
  mutable std::vector<std::pair<Label, const Fst<Arc> *>> fst_tuples_;
  // Underlying ReplaceFst; null until Replace() is called.
  mutable std::unique_ptr<ReplaceFst<Arc>> replace_;
  const ReplaceFstOptions<Arc> replace_options_;
};
} // namespace internal
// Common base of the delayed rational operations (union, concatenation, and
// closure). It binds the FST interface to internal::RationalFstImpl through
// ImplToFst, to which most methods are delegated; state and arc iteration are
// forwarded to the implementation's underlying ReplaceFst.
//
// Constructors are protected, so instances are created through derived
// classes only. Assignment is disabled.
template <class A>
class RationalFst : public ImplToFst<internal::RationalFstImpl<A>> {
 public:
  using Arc = A;
  using StateId = typename Arc::StateId;

  using Impl = internal::RationalFstImpl<Arc>;

  // The iterator specializations and the destructive rational operations need
  // access to the (protected) implementation accessor.
  friend class StateIterator<RationalFst<Arc>>;
  friend class ArcIterator<RationalFst<Arc>>;
  friend void Union<>(RationalFst<Arc> *fst1, const Fst<Arc> &fst2);
  friend void Concat<>(RationalFst<Arc> *fst1, const Fst<Arc> &fst2);
  friend void Concat<>(const Fst<Arc> &fst1, RationalFst<Arc> *fst2);
  friend void Closure<>(RationalFst<Arc> *fst, ClosureType closure_type);

  // Initializes a state iterator over the underlying ReplaceFst.
  void InitStateIterator(StateIteratorData<Arc> *data) const override {
    ReplaceFst<Arc> *replace_fst = GetImpl()->Replace();
    replace_fst->InitStateIterator(data);
  }

  // Initializes an arc iterator for state s over the underlying ReplaceFst.
  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
    ReplaceFst<Arc> *replace_fst = GetImpl()->Replace();
    replace_fst->InitArcIterator(s, data);
  }

 protected:
  using ImplToFst<Impl>::GetImpl;

  // Builds an FST backed by a fresh implementation configured with opts.
  explicit RationalFst(const RationalFstOptions &opts = RationalFstOptions())
      : ImplToFst<Impl>(std::make_shared<Impl>(opts)) {}

  // See Fst<>::Copy() for doc.
  RationalFst(const RationalFst &fst, bool safe = false)
      : ImplToFst<Impl>(fst, safe) {}

 private:
  RationalFst &operator=(const RationalFst &) = delete;
};
// Specialization for RationalFst: a state iterator over the ReplaceFst that
// underlies the given rational FST.
template <class Arc>
class StateIterator<RationalFst<Arc>> : public StateIterator<ReplaceFst<Arc>> {
 public:
  explicit StateIterator(const RationalFst<Arc> &fst)
      : StateIterator<ReplaceFst<Arc>>(*fst.GetImpl()->Replace()) {}
};
// Specialization for RationalFst: an arc iterator over state s of the
// ReplaceFst that underlies the given rational FST.
//
// The base class is ArcIterator<ReplaceFst<Arc>>, the type the constructor's
// mem-initializer names. A mem-initializer must name a direct base class, so
// deriving from any other type makes the constructor ill-formed as soon as it
// is instantiated. This also mirrors the StateIterator specialization, which
// derives from StateIterator<ReplaceFst<Arc>>.
template <class Arc>
class ArcIterator<RationalFst<Arc>> : public ArcIterator<ReplaceFst<Arc>> {
 public:
  using StateId = typename Arc::StateId;

  ArcIterator(const RationalFst<Arc> &fst, StateId s)
      : ArcIterator<ReplaceFst<Arc>>(*(fst.GetImpl()->Replace()), s) {}
};
} // namespace fst
#endif // FST_RATIONAL_H_
|