// Copyright 2005-2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the 'License'); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an 'AS IS' BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. #ifndef FST_SCRIPT_ENCODEMAPPER_CLASS_H_ #define FST_SCRIPT_ENCODEMAPPER_CLASS_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include // Scripting API support for EncodeMapper. namespace fst { namespace script { // Virtual interface implemented by each concrete EncodeMapperClassImpl. class EncodeMapperImplBase { public: // Returns an encoded ArcClass. virtual ArcClass operator()(const ArcClass &) = 0; virtual const std::string &ArcType() const = 0; virtual const std::string &WeightType() const = 0; virtual EncodeMapperImplBase *Copy() const = 0; virtual uint8_t Flags() const = 0; virtual uint64_t Properties(uint64_t) = 0; virtual EncodeType Type() const = 0; virtual bool Write(const std::string &) const = 0; virtual bool Write(std::ostream &, const std::string &) const = 0; virtual const SymbolTable *InputSymbols() const = 0; virtual const SymbolTable *OutputSymbols() const = 0; virtual void SetInputSymbols(const SymbolTable *) = 0; virtual void SetOutputSymbols(const SymbolTable *) = 0; virtual ~EncodeMapperImplBase() = default; }; // Templated implementation. template class EncodeMapperClassImpl : public EncodeMapperImplBase { public: explicit EncodeMapperClassImpl(const EncodeMapper &mapper) : mapper_(mapper) {} ArcClass operator()(const ArcClass &a) final; const std::string &ArcType() const final { return Arc::Type(); } const std::string &WeightType() const final { return Arc::Weight::Type(); } EncodeMapperClassImpl *Copy() const final { return new EncodeMapperClassImpl(mapper_); } uint8_t Flags() const final { return mapper_.Flags(); } uint64_t Properties(uint64_t inprops) final { return mapper_.Properties(inprops); } EncodeType Type() const final { return mapper_.Type(); } bool Write(const std::string &source) const final { return mapper_.Write(source); } bool Write(std::ostream &strm, const std::string &source) const final { return mapper_.Write(strm, source); } const SymbolTable *InputSymbols() const final { return mapper_.InputSymbols(); } const SymbolTable *OutputSymbols() const final { return mapper_.OutputSymbols(); } void SetInputSymbols(const SymbolTable *syms) final { mapper_.SetInputSymbols(syms); } void SetOutputSymbols(const SymbolTable *syms) final { mapper_.SetOutputSymbols(syms); } ~EncodeMapperClassImpl() override = default; const EncodeMapper *GetImpl() const { return &mapper_; } EncodeMapper *GetImpl() { return &mapper_; } private: EncodeMapper mapper_; }; template inline ArcClass EncodeMapperClassImpl::operator()(const ArcClass &a) { const Arc arc(a.ilabel, a.olabel, *(a.weight.GetWeight()), a.nextstate); return ArcClass(mapper_(arc)); } class EncodeMapperClass { public: EncodeMapperClass() : impl_(nullptr) {} EncodeMapperClass(std::string_view arc_type, uint8_t flags, EncodeType type = ENCODE); template explicit EncodeMapperClass(const EncodeMapper &mapper) : impl_(std::make_unique>(mapper)) {} EncodeMapperClass(const EncodeMapperClass &other) : impl_(other.impl_ == nullptr ? nullptr : other.impl_->Copy()) {} EncodeMapperClass &operator=(const EncodeMapperClass &other) { impl_.reset(other.impl_ == nullptr ? nullptr : other.impl_->Copy()); return *this; } ArcClass operator()(const ArcClass &arc) { return (*impl_)(arc); } const std::string &ArcType() const { return impl_->ArcType(); } const std::string &WeightType() const { return impl_->WeightType(); } uint8_t Flags() const { return impl_->Flags(); } uint64_t Properties(uint64_t inprops) { return impl_->Properties(inprops); } EncodeType Type() const { return impl_->Type(); } static std::unique_ptr Read( const std::string &source); static std::unique_ptr Read( std::istream &strm, const std::string &source); bool Write(const std::string &source) const { return impl_->Write(source); } bool Write(std::ostream &strm, const std::string &source) const { return impl_->Write(strm, source); } const SymbolTable *InputSymbols() const { return impl_->InputSymbols(); } const SymbolTable *OutputSymbols() const { return impl_->OutputSymbols(); } void SetInputSymbols(const SymbolTable *syms) { impl_->SetInputSymbols(syms); } void SetOutputSymbols(const SymbolTable *syms) { impl_->SetOutputSymbols(syms); } // Implementation stuff. template EncodeMapper *GetEncodeMapper() { if (Arc::Type() != ArcType()) { return nullptr; } else { auto *typed_impl = down_cast *>(impl_.get()); return typed_impl->GetImpl(); } } template const EncodeMapper *GetEncodeMapper() const { if (Arc::Type() != ArcType()) { return nullptr; } else { auto *typed_impl = down_cast *>(impl_.get()); return typed_impl->GetImpl(); } } // Required for registration. template static std::unique_ptr Read(std::istream &strm, std::string_view source) { std::unique_ptr> mapper( EncodeMapper::Read(strm, source)); return mapper ? std::make_unique(*mapper) : nullptr; } template static std::unique_ptr Create( uint8_t flags, EncodeType type = ENCODE) { return std::make_unique>( EncodeMapper(flags, type)); } private: explicit EncodeMapperClass(std::unique_ptr impl) : impl_(std::move(impl)) {} const EncodeMapperImplBase *GetImpl() const { return impl_.get(); } EncodeMapperImplBase *GetImpl() { return impl_.get(); } std::unique_ptr impl_; }; // Registration for EncodeMapper types. // This class definition is to avoid a nested class definition inside the // EncodeMapperIORegistration struct. template struct EncodeMapperClassRegEntry { Reader reader; Creator creator; EncodeMapperClassRegEntry(Reader reader, Creator creator) : reader(reader), creator(creator) {} EncodeMapperClassRegEntry() : reader(nullptr), creator(nullptr) {} }; template class EncodeMapperClassIORegister : public GenericRegister, EncodeMapperClassIORegister> { public: Reader GetReader(std::string_view arc_type) const { return this->GetEntry(arc_type).reader; } Creator GetCreator(std::string_view arc_type) const { return this->GetEntry(arc_type).creator; } protected: std::string ConvertKeyToSoFilename(std::string_view key) const final { std::string legal_type(key); ConvertToLegalCSymbol(&legal_type); legal_type.append("-arc.so"); return legal_type; } }; // Struct containing everything needed to register a particular type struct EncodeMapperClassIORegistration { using Reader = std::unique_ptr (*)( std::istream &stream, std::string_view source); using Creator = std::unique_ptr (*)(uint8_t flags, EncodeType type); using Entry = EncodeMapperClassRegEntry; // EncodeMapper register. using Register = EncodeMapperClassIORegister; // EncodeMapper register-er. using Registerer = GenericRegisterer>; }; #define REGISTER_ENCODEMAPPER_CLASS(Arc) \ static EncodeMapperClassIORegistration::Registerer \ EncodeMapperClass_##Arc##_registerer( \ Arc::Type(), \ EncodeMapperClassIORegistration::Entry( \ EncodeMapperClass::Read, EncodeMapperClass::Create)); } // namespace script } // namespace fst #endif // FST_SCRIPT_ENCODEMAPPER_CLASS_H_