// Copyright 2005-2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the 'License'); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an 'AS IS' BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. #ifndef FST_SCRIPT_FST_CLASS_H_ #define FST_SCRIPT_FST_CLASS_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include // Classes to support "boxing" all existing types of FST arcs in a single // FstClass which hides the arc types. This allows clients to load // and work with FSTs without knowing the arc type. These classes are only // recommended for use in high-level scripting applications. Most users should // use the lower-level templated versions corresponding to these classes. namespace fst { namespace script { // Abstract base class defining the set of functionalities implemented in all // impls and passed through by all bases. Below FstClassBase the class // hierarchy bifurcates; FstClassImplBase serves as the base class for all // implementations (of which FstClassImpl is currently the only one) and // FstClass serves as the base class for all interfaces. class FstClassBase { public: virtual const std::string &ArcType() const = 0; virtual WeightClass Final(int64_t) const = 0; virtual const std::string &FstType() const = 0; virtual const SymbolTable *InputSymbols() const = 0; virtual size_t NumArcs(int64_t) const = 0; virtual size_t NumInputEpsilons(int64_t) const = 0; virtual size_t NumOutputEpsilons(int64_t) const = 0; virtual const SymbolTable *OutputSymbols() const = 0; virtual uint64_t Properties(uint64_t, bool) const = 0; virtual int64_t Start() const = 0; virtual const std::string &WeightType() const = 0; virtual bool ValidStateId(int64_t) const = 0; virtual bool Write(const std::string &) const = 0; virtual bool Write(std::ostream &, const std::string &) const = 0; virtual ~FstClassBase() = default; }; // Adds all the MutableFst methods. class FstClassImplBase : public FstClassBase { public: virtual bool AddArc(int64_t, const ArcClass &) = 0; virtual int64_t AddState() = 0; virtual void AddStates(size_t) = 0; virtual FstClassImplBase *Copy() = 0; virtual bool DeleteArcs(int64_t, size_t) = 0; virtual bool DeleteArcs(int64_t) = 0; virtual bool DeleteStates(const std::vector &) = 0; virtual void DeleteStates() = 0; virtual SymbolTable *MutableInputSymbols() = 0; virtual SymbolTable *MutableOutputSymbols() = 0; virtual int64_t NumStates() const = 0; virtual bool ReserveArcs(int64_t, size_t) = 0; virtual void ReserveStates(int64_t) = 0; virtual void SetInputSymbols(const SymbolTable *) = 0; virtual bool SetFinal(int64_t, const WeightClass &) = 0; virtual void SetOutputSymbols(const SymbolTable *) = 0; virtual void SetProperties(uint64_t, uint64_t) = 0; virtual bool SetStart(int64_t) = 0; ~FstClassImplBase() override = default; }; // Containiner class wrapping an Fst, hiding its arc type. Whether this // Fst pointer refers to a special kind of FST (e.g. a MutableFst) is // known by the type of interface class that owns the pointer to this // container. template class FstClassImpl : public FstClassImplBase { public: explicit FstClassImpl(std::unique_ptr> impl) : impl_(std::move(impl)) {} explicit FstClassImpl(const Fst &impl) : impl_(impl.Copy()) {} // Warning: calling this method casts the FST to a mutable FST. bool AddArc(int64_t s, const ArcClass &ac) final { if (!ValidStateId(s)) return false; // Note that we do not check that the destination state is valid, so users // can add arcs before they add the corresponding states. Verify can be // used to determine whether any arc has a nonexisting destination. Arc arc(ac.ilabel, ac.olabel, *ac.weight.GetWeight(), ac.nextstate); down_cast *>(impl_.get())->AddArc(s, arc); return true; } // Warning: calling this method casts the FST to a mutable FST. int64_t AddState() final { return down_cast *>(impl_.get())->AddState(); } // Warning: calling this method casts the FST to a mutable FST. void AddStates(size_t n) final { return down_cast *>(impl_.get())->AddStates(n); } const std::string &ArcType() const final { return Arc::Type(); } FstClassImpl *Copy() final { return new FstClassImpl(*impl_); } // Warning: calling this method casts the FST to a mutable FST. bool DeleteArcs(int64_t s, size_t n) final { if (!ValidStateId(s)) return false; down_cast *>(impl_.get())->DeleteArcs(s, n); return true; } // Warning: calling this method casts the FST to a mutable FST. bool DeleteArcs(int64_t s) final { if (!ValidStateId(s)) return false; down_cast *>(impl_.get())->DeleteArcs(s); return true; } // Warning: calling this method casts the FST to a mutable FST. bool DeleteStates(const std::vector &dstates) final { for (const auto &state : dstates) if (!ValidStateId(state)) return false; // Warning: calling this method with any integers beyond the precision of // the underlying FST will result in truncation. std::vector typed_dstates(dstates.size()); std::copy(dstates.begin(), dstates.end(), typed_dstates.begin()); down_cast *>(impl_.get())->DeleteStates(typed_dstates); return true; } // Warning: calling this method casts the FST to a mutable FST. void DeleteStates() final { down_cast *>(impl_.get())->DeleteStates(); } WeightClass Final(int64_t s) const final { if (!ValidStateId(s)) return WeightClass::NoWeight(WeightType()); WeightClass w(impl_->Final(s)); return w; } const std::string &FstType() const final { return impl_->Type(); } const SymbolTable *InputSymbols() const final { return impl_->InputSymbols(); } // Warning: calling this method casts the FST to a mutable FST. SymbolTable *MutableInputSymbols() final { return down_cast *>(impl_.get())->MutableInputSymbols(); } // Warning: calling this method casts the FST to a mutable FST. SymbolTable *MutableOutputSymbols() final { return down_cast *>(impl_.get())->MutableOutputSymbols(); } // Signals failure by returning size_t max. size_t NumArcs(int64_t s) const final { return ValidStateId(s) ? impl_->NumArcs(s) : std::numeric_limits::max(); } // Signals failure by returning size_t max. size_t NumInputEpsilons(int64_t s) const final { return ValidStateId(s) ? impl_->NumInputEpsilons(s) : std::numeric_limits::max(); } // Signals failure by returning size_t max. size_t NumOutputEpsilons(int64_t s) const final { return ValidStateId(s) ? impl_->NumOutputEpsilons(s) : std::numeric_limits::max(); } // Warning: calling this method casts the FST to a mutable FST. int64_t NumStates() const final { return down_cast *>(impl_.get())->NumStates(); } uint64_t Properties(uint64_t mask, bool test) const final { return impl_->Properties(mask, test); } // Warning: calling this method casts the FST to a mutable FST. bool ReserveArcs(int64_t s, size_t n) final { if (!ValidStateId(s)) return false; down_cast *>(impl_.get())->ReserveArcs(s, n); return true; } // Warning: calling this method casts the FST to a mutable FST. void ReserveStates(int64_t n) final { down_cast *>(impl_.get())->ReserveStates(n); } const SymbolTable *OutputSymbols() const final { return impl_->OutputSymbols(); } // Warning: calling this method casts the FST to a mutable FST. void SetInputSymbols(const SymbolTable *isyms) final { down_cast *>(impl_.get())->SetInputSymbols(isyms); } // Warning: calling this method casts the FST to a mutable FST. bool SetFinal(int64_t s, const WeightClass &weight) final { if (!ValidStateId(s)) return false; down_cast *>(impl_.get()) ->SetFinal(s, *weight.GetWeight()); return true; } // Warning: calling this method casts the FST to a mutable FST. void SetOutputSymbols(const SymbolTable *osyms) final { down_cast *>(impl_.get())->SetOutputSymbols(osyms); } // Warning: calling this method casts the FST to a mutable FST. void SetProperties(uint64_t props, uint64_t mask) final { down_cast *>(impl_.get())->SetProperties(props, mask); } // Warning: calling this method casts the FST to a mutable FST. bool SetStart(int64_t s) final { if (!ValidStateId(s)) return false; down_cast *>(impl_.get())->SetStart(s); return true; } int64_t Start() const final { return impl_->Start(); } bool ValidStateId(int64_t s) const final { // This cowardly refuses to count states if the FST is not yet expanded. const auto num_states = impl_->NumStatesIfKnown(); if (!num_states.has_value()) { FSTERROR() << "Cannot get number of states for unexpanded FST"; return false; } if (s < 0 || s >= *num_states) { FSTERROR() << "State ID " << s << " not valid"; return false; } return true; } const std::string &WeightType() const final { return Arc::Weight::Type(); } bool Write(const std::string &source) const final { return impl_->Write(source); } bool Write(std::ostream &ostr, const std::string &source) const final { const FstWriteOptions opts(source); return impl_->Write(ostr, opts); } ~FstClassImpl() override = default; Fst *GetImpl() const { return impl_.get(); } private: std::unique_ptr> impl_; }; // BASE CLASS DEFINITIONS class MutableFstClass; class FstClass : public FstClassBase { public: FstClass() : impl_(nullptr) {} template explicit FstClass(std::unique_ptr> fst) : impl_(std::make_unique>(std::move(fst))) {} template explicit FstClass(const Fst &fst) : impl_(std::make_unique>(fst)) {} FstClass(const FstClass &other) : impl_(other.impl_ == nullptr ? nullptr : other.impl_->Copy()) {} FstClass &operator=(const FstClass &other) { impl_.reset(other.impl_ == nullptr ? nullptr : other.impl_->Copy()); return *this; } WeightClass Final(int64_t s) const final { return impl_->Final(s); } const std::string &ArcType() const final { return impl_->ArcType(); } const std::string &FstType() const final { return impl_->FstType(); } const SymbolTable *InputSymbols() const final { return impl_->InputSymbols(); } size_t NumArcs(int64_t s) const final { return impl_->NumArcs(s); } size_t NumInputEpsilons(int64_t s) const final { return impl_->NumInputEpsilons(s); } size_t NumOutputEpsilons(int64_t s) const final { return impl_->NumOutputEpsilons(s); } const SymbolTable *OutputSymbols() const final { return impl_->OutputSymbols(); } uint64_t Properties(uint64_t mask, bool test) const final { // Special handling for FSTs with a null impl. if (!impl_) return kError & mask; return impl_->Properties(mask, test); } static std::unique_ptr Read( const std::string &source); static std::unique_ptr Read( std::istream &istrm, const std::string &source); int64_t Start() const final { return impl_->Start(); } bool ValidStateId(int64_t s) const final { return impl_->ValidStateId(s); } const std::string &WeightType() const final { return impl_->WeightType(); } // Helper that logs an ERROR if the weight type of an FST and a WeightClass // don't match. bool WeightTypesMatch(const WeightClass &weight, std::string_view op_name) const; bool Write(const std::string &source) const final { return impl_->Write(source); } bool Write(std::ostream &ostr, const std::string &source) const final { return impl_->Write(ostr, source); } ~FstClass() override = default; // These methods are required by IO registration. template static std::unique_ptr Convert(const FstClass &other) { FSTERROR() << "Doesn't make sense to convert any class to type FstClass"; return nullptr; } template static std::unique_ptr Create() { FSTERROR() << "Doesn't make sense to create an FstClass with a " << "particular arc type"; return nullptr; } template const Fst *GetFst() const { if (Arc::Type() != ArcType()) { return nullptr; } else { FstClassImpl *typed_impl = down_cast *>(impl_.get()); return typed_impl->GetImpl(); } } template static std::unique_ptr Read(std::istream &stream, const FstReadOptions &opts) { if (!opts.header) { LOG(ERROR) << "FstClass::Read: Options header not specified"; return nullptr; } const FstHeader &hdr = *opts.header; if (hdr.Properties() & kMutable) { return ReadTypedFst>(stream, opts); } else { return ReadTypedFst>(stream, opts); } } protected: explicit FstClass(std::unique_ptr impl) : impl_(std::move(impl)) {} const FstClassImplBase *GetImpl() const { return impl_.get(); } FstClassImplBase *GetImpl() { return impl_.get(); } // Generic template method for reading an arc-templated FST of type // UnderlyingT, and returning it wrapped as FstClassT, with appropriate // error checking. Called from arc-templated Read() static methods. template static std::unique_ptr ReadTypedFst(std::istream &stream, const FstReadOptions &opts) { std::unique_ptr u(UnderlyingT::Read(stream, opts)); return u ? std::make_unique(std::move(u)) : nullptr; } private: std::unique_ptr impl_; }; // Specific types of FstClass with special properties class MutableFstClass : public FstClass { public: bool AddArc(int64_t s, const ArcClass &ac) { if (!WeightTypesMatch(ac.weight, "AddArc")) return false; return GetImpl()->AddArc(s, ac); } int64_t AddState() { return GetImpl()->AddState(); } void AddStates(size_t n) { return GetImpl()->AddStates(n); } bool DeleteArcs(int64_t s, size_t n) { return GetImpl()->DeleteArcs(s, n); } bool DeleteArcs(int64_t s) { return GetImpl()->DeleteArcs(s); } bool DeleteStates(const std::vector &dstates) { return GetImpl()->DeleteStates(dstates); } void DeleteStates() { GetImpl()->DeleteStates(); } SymbolTable *MutableInputSymbols() { return GetImpl()->MutableInputSymbols(); } SymbolTable *MutableOutputSymbols() { return GetImpl()->MutableOutputSymbols(); } int64_t NumStates() const { return GetImpl()->NumStates(); } bool ReserveArcs(int64_t s, size_t n) { return GetImpl()->ReserveArcs(s, n); } void ReserveStates(int64_t n) { GetImpl()->ReserveStates(n); } static std::unique_ptr Read( const std::string &source, bool convert = false); void SetInputSymbols(const SymbolTable *isyms) { GetImpl()->SetInputSymbols(isyms); } bool SetFinal(int64_t s, const WeightClass &weight) { if (!WeightTypesMatch(weight, "SetFinal")) return false; return GetImpl()->SetFinal(s, weight); } void SetOutputSymbols(const SymbolTable *osyms) { GetImpl()->SetOutputSymbols(osyms); } void SetProperties(uint64_t props, uint64_t mask) { GetImpl()->SetProperties(props, mask); } bool SetStart(int64_t s) { return GetImpl()->SetStart(s); } template explicit MutableFstClass(std::unique_ptr> fst) // NB: The natural cast-less way to do this doesn't compile for some // arcane reason. : FstClass( fst::implicit_cast>>(std::move(fst))) {} template explicit MutableFstClass(const MutableFst &fst) : FstClass(fst) {} // These methods are required by IO registration. template static std::unique_ptr Convert(const FstClass &other) { FSTERROR() << "Doesn't make sense to convert any class to type " << "MutableFstClass"; return nullptr; } template static std::unique_ptr Create() { FSTERROR() << "Doesn't make sense to create a MutableFstClass with a " << "particular arc type"; return nullptr; } template MutableFst *GetMutableFst() { Fst *fst = const_cast *>(this->GetFst()); MutableFst *mfst = down_cast *>(fst); return mfst; } template static std::unique_ptr Read(std::istream &stream, const FstReadOptions &opts) { std::unique_ptr> mfst(MutableFst::Read(stream, opts)); return mfst ? std::make_unique(std::move(mfst)) : nullptr; } protected: explicit MutableFstClass(std::unique_ptr impl) : FstClass(std::move(impl)) {} }; class VectorFstClass : public MutableFstClass { public: explicit VectorFstClass(std::unique_ptr impl) : MutableFstClass(std::move(impl)) {} explicit VectorFstClass(const FstClass &other); explicit VectorFstClass(std::string_view arc_type); static std::unique_ptr Read( const std::string &source); template static std::unique_ptr Read(std::istream &stream, const FstReadOptions &opts) { std::unique_ptr> vfst(VectorFst::Read(stream, opts)); return vfst ? std::make_unique(std::move(vfst)) : nullptr; } template explicit VectorFstClass(std::unique_ptr> fst) // NB: The natural cast-less way to do this doesn't compile for some // arcane reason. : MutableFstClass(fst::implicit_cast>>( std::move(fst))) {} template explicit VectorFstClass(const VectorFst &fst) : MutableFstClass(fst) {} template static std::unique_ptr Convert(const FstClass &other) { return std::make_unique>( std::make_unique>(*other.GetFst())); } template static std::unique_ptr Create() { return std::make_unique>( std::make_unique>()); } }; // Registration stuff. // This class definition is to avoid a nested class definition inside the // FstClassIORegistration struct. template struct FstClassRegEntry { Reader reader; Creator creator; Converter converter; FstClassRegEntry(Reader r, Creator cr, Converter co) : reader(r), creator(cr), converter(co) {} FstClassRegEntry() : reader(nullptr), creator(nullptr), converter(nullptr) {} }; // Actual FST IO method register. template class FstClassIORegister : public GenericRegister, FstClassIORegister> { public: Reader GetReader(std::string_view arc_type) const { return this->GetEntry(arc_type).reader; } Creator GetCreator(std::string_view arc_type) const { return this->GetEntry(arc_type).creator; } Converter GetConverter(std::string_view arc_type) const { return this->GetEntry(arc_type).converter; } protected: std::string ConvertKeyToSoFilename(std::string_view key) const final { std::string legal_type(key); ConvertToLegalCSymbol(&legal_type); legal_type.append("-arc.so"); return legal_type; } }; // Struct containing everything needed to register a particular type // of FST class (e.g., a plain FstClass, or a MutableFstClass, etc.). template struct FstClassIORegistration { using Reader = std::unique_ptr (*)(std::istream &stream, const FstReadOptions &opts); using Creator = std::unique_ptr (*)(); using Converter = std::unique_ptr (*)(const FstClass &other); using Entry = FstClassRegEntry; // FST class Register. using Register = FstClassIORegister; // FST class Register-er. using Registerer = GenericRegisterer>; }; // Macros for registering other arc types. #define REGISTER_FST_CLASS(Class, Arc) \ static FstClassIORegistration::Registerer Class##_##Arc##_registerer( \ Arc::Type(), \ FstClassIORegistration::Entry( \ Class::Read, Class::Create, Class::Convert)) #define REGISTER_FST_CLASSES(Arc) \ REGISTER_FST_CLASS(FstClass, Arc); \ REGISTER_FST_CLASS(MutableFstClass, Arc); \ REGISTER_FST_CLASS(VectorFstClass, Arc); } // namespace script } // namespace fst #endif // FST_SCRIPT_FST_CLASS_H_