|
|
// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
#ifndef FST_SCRIPT_ENCODEMAPPER_CLASS_H_
#define FST_SCRIPT_ENCODEMAPPER_CLASS_H_
#include <cstdint>
#include <iostream>
#include <istream>
#include <memory>
#include <ostream>
#include <string>
#include <utility>
#include <fst/encode.h>
#include <fst/generic-register.h>
#include <fst/symbol-table.h>
#include <fst/util.h>
#include <fst/script/arc-class.h>
#include <fst/script/fst-class.h>
#include <string_view>
// Scripting API support for EncodeMapper.
namespace fst { namespace script {
// Abstract interface that hides the arc type of an encode mapper. Every
// concrete EncodeMapperClassImpl<Arc> implements it.
class EncodeMapperImplBase {
 public:
  // Returns an encoded ArcClass.
  virtual ArcClass operator()(const ArcClass &) = 0;

  // Names of the arc type and of its weight type.
  virtual const std::string &ArcType() const = 0;
  virtual const std::string &WeightType() const = 0;

  // Returns a separately allocated copy of this object as a raw pointer; the
  // caller is responsible for taking ownership of it.
  virtual EncodeMapperImplBase *Copy() const = 0;

  // Accessors mirroring the wrapped mapper's flags, properties and type.
  virtual uint8_t Flags() const = 0;
  virtual uint64_t Properties(uint64_t) = 0;
  virtual EncodeType Type() const = 0;

  // Serialization, either to a named destination or to an open stream.
  virtual bool Write(const std::string &) const = 0;
  virtual bool Write(std::ostream &, const std::string &) const = 0;

  // Symbol table access.
  virtual const SymbolTable *InputSymbols() const = 0;
  virtual const SymbolTable *OutputSymbols() const = 0;
  virtual void SetInputSymbols(const SymbolTable *) = 0;
  virtual void SetOutputSymbols(const SymbolTable *) = 0;

  // Virtual so that implementations can be destroyed through this interface.
  virtual ~EncodeMapperImplBase() = default;
};
// Templated implementation of EncodeMapperImplBase: owns an EncodeMapper<Arc>
// by value and forwards every interface call to it.
template <class Arc>
class EncodeMapperClassImpl : public EncodeMapperImplBase {
 public:
  // Stores a copy of `mapper`.
  explicit EncodeMapperClassImpl(const EncodeMapper<Arc> &mapper)
      : mapper_(mapper) {}

  ~EncodeMapperClassImpl() override = default;

  // Encodes an ArcClass; the definition follows the class body.
  ArcClass operator()(const ArcClass &a) final;

  // Type names are taken from the Arc template argument, not from the mapper.
  const std::string &ArcType() const final { return Arc::Type(); }

  const std::string &WeightType() const final { return Arc::Weight::Type(); }

  // Allocates a new implementation wrapping a copy of the stored mapper. The
  // result is a raw pointer that the caller must take ownership of.
  EncodeMapperClassImpl *Copy() const final {
    return new EncodeMapperClassImpl(mapper_);
  }

  // The remaining overrides delegate directly to the stored mapper.
  uint8_t Flags() const final { return mapper_.Flags(); }

  uint64_t Properties(uint64_t inprops) final {
    return mapper_.Properties(inprops);
  }

  EncodeType Type() const final { return mapper_.Type(); }

  bool Write(const std::string &source) const final {
    return mapper_.Write(source);
  }

  bool Write(std::ostream &strm, const std::string &source) const final {
    return mapper_.Write(strm, source);
  }

  const SymbolTable *InputSymbols() const final {
    return mapper_.InputSymbols();
  }

  const SymbolTable *OutputSymbols() const final {
    return mapper_.OutputSymbols();
  }

  void SetInputSymbols(const SymbolTable *syms) final {
    mapper_.SetInputSymbols(syms);
  }

  void SetOutputSymbols(const SymbolTable *syms) final {
    mapper_.SetOutputSymbols(syms);
  }

  // Direct access to the stored mapper.
  const EncodeMapper<Arc> *GetImpl() const { return &mapper_; }

  EncodeMapper<Arc> *GetImpl() { return &mapper_; }

 private:
  EncodeMapper<Arc> mapper_;
};
// Rebuilds a typed Arc from the fields of `a`, runs it through the stored
// mapper, and wraps the result back into an ArcClass.
template <class Arc>
inline ArcClass EncodeMapperClassImpl<Arc>::operator()(const ArcClass &a) {
  using Weight = typename Arc::Weight;
  // The pointer returned by GetWeight is dereferenced without a check.
  // NOTE(review): presumably the caller guarantees that `a` carries a weight
  // of type Arc::Weight -- confirm.
  const auto &typed_weight = *(a.weight.GetWeight<Weight>());
  const Arc typed_arc(a.ilabel, a.olabel, typed_weight, a.nextstate);
  return ArcClass(mapper_(typed_arc));
}
class EncodeMapperClass { public: EncodeMapperClass() : impl_(nullptr) {}
EncodeMapperClass(std::string_view arc_type, uint8_t flags, EncodeType type = ENCODE);
template <class Arc> explicit EncodeMapperClass(const EncodeMapper<Arc> &mapper) : impl_(std::make_unique<EncodeMapperClassImpl<Arc>>(mapper)) {}
EncodeMapperClass(const EncodeMapperClass &other) : impl_(other.impl_ == nullptr ? nullptr : other.impl_->Copy()) {}
EncodeMapperClass &operator=(const EncodeMapperClass &other) { impl_.reset(other.impl_ == nullptr ? nullptr : other.impl_->Copy()); return *this; }
ArcClass operator()(const ArcClass &arc) { return (*impl_)(arc); }
const std::string &ArcType() const { return impl_->ArcType(); }
const std::string &WeightType() const { return impl_->WeightType(); }
uint8_t Flags() const { return impl_->Flags(); }
uint64_t Properties(uint64_t inprops) { return impl_->Properties(inprops); }
EncodeType Type() const { return impl_->Type(); }
static std::unique_ptr<EncodeMapperClass> Read( const std::string &source);
static std::unique_ptr<EncodeMapperClass> Read( std::istream &strm, const std::string &source);
bool Write(const std::string &source) const { return impl_->Write(source); }
bool Write(std::ostream &strm, const std::string &source) const { return impl_->Write(strm, source); }
const SymbolTable *InputSymbols() const { return impl_->InputSymbols(); }
const SymbolTable *OutputSymbols() const { return impl_->OutputSymbols(); }
void SetInputSymbols(const SymbolTable *syms) { impl_->SetInputSymbols(syms); }
void SetOutputSymbols(const SymbolTable *syms) { impl_->SetOutputSymbols(syms); }
// Implementation stuff.
template <class Arc> EncodeMapper<Arc> *GetEncodeMapper() { if (Arc::Type() != ArcType()) { return nullptr; } else { auto *typed_impl = down_cast<EncodeMapperClassImpl<Arc> *>(impl_.get()); return typed_impl->GetImpl(); } }
template <class Arc> const EncodeMapper<Arc> *GetEncodeMapper() const { if (Arc::Type() != ArcType()) { return nullptr; } else { auto *typed_impl = down_cast<EncodeMapperClassImpl<Arc> *>(impl_.get()); return typed_impl->GetImpl(); } }
// Required for registration.
template <class Arc> static std::unique_ptr<EncodeMapperClass> Read(std::istream &strm, std::string_view source) { std::unique_ptr<EncodeMapper<Arc>> mapper( EncodeMapper<Arc>::Read(strm, source)); return mapper ? std::make_unique<EncodeMapperClass>(*mapper) : nullptr; }
template <class Arc> static std::unique_ptr<EncodeMapperImplBase> Create( uint8_t flags, EncodeType type = ENCODE) { return std::make_unique<EncodeMapperClassImpl<Arc>>( EncodeMapper<Arc>(flags, type)); }
private: explicit EncodeMapperClass(std::unique_ptr<EncodeMapperImplBase> impl) : impl_(std::move(impl)) {}
const EncodeMapperImplBase *GetImpl() const { return impl_.get(); }
EncodeMapperImplBase *GetImpl() { return impl_.get(); }
std::unique_ptr<EncodeMapperImplBase> impl_; };
// Registration for EncodeMapper types.
// This class definition is to avoid a nested class definition inside the
// EncodeMapperClassIORegistration struct.
template <class Reader, class Creator>
struct EncodeMapperClassRegEntry {
  // An empty entry: both members are initialized from nullptr.
  EncodeMapperClassRegEntry() : reader(nullptr), creator(nullptr) {}

  // An entry holding copies of the given reader and creator.
  EncodeMapperClassRegEntry(Reader reader, Creator creator)
      : reader(reader), creator(creator) {}

  Reader reader;
  Creator creator;
};
// Register built on GenericRegister with std::string keys and
// EncodeMapperClassRegEntry<Reader, Creator> entries; lookups are done by arc
// type name.
template <class Reader, class Creator>
class EncodeMapperClassIORegister
    : public GenericRegister<std::string,
                             EncodeMapperClassRegEntry<Reader, Creator>,
                             EncodeMapperClassIORegister<Reader, Creator>> {
 public:
  // Returns the `reader` member of the entry found for `arc_type`.
  Reader GetReader(std::string_view arc_type) const {
    const auto &entry = this->GetEntry(arc_type);
    return entry.reader;
  }

  // Returns the `creator` member of the entry found for `arc_type`.
  Creator GetCreator(std::string_view arc_type) const {
    const auto &entry = this->GetEntry(arc_type);
    return entry.creator;
  }

 protected:
  // Builds a shared-object file name from a key: the key is passed through
  // ConvertToLegalCSymbol and "-arc.so" is appended.
  std::string ConvertKeyToSoFilename(std::string_view key) const final {
    std::string so_name(key);
    ConvertToLegalCSymbol(&so_name);
    so_name += "-arc.so";
    return so_name;
  }
};
// Bundle of the type aliases needed to register a particular arc type.
struct EncodeMapperClassIORegistration {
  // Function pointer that reads an EncodeMapperClass from a stream.
  using Reader = std::unique_ptr<EncodeMapperClass> (*)(
      std::istream &stream, std::string_view source);

  // Function pointer that creates an implementation from flags and a type.
  using Creator = std::unique_ptr<EncodeMapperImplBase> (*)(uint8_t flags,
                                                            EncodeType type);

  // Reader/creator pair stored in the register.
  using Entry = EncodeMapperClassRegEntry<Reader, Creator>;

  // EncodeMapper register.
  using Register = EncodeMapperClassIORegister<Reader, Creator>;

  // EncodeMapper register-er.
  using Registerer = GenericRegisterer<Register>;
};
// Defines a static registerer object named EncodeMapperClass_<Arc>_registerer
// that is constructed from Arc::Type() and an entry pairing
// EncodeMapperClass::Read<Arc> with EncodeMapperClass::Create<Arc>.
// Each continuation backslash must be the last character on its line. The
// expansion already ends with a semicolon.
#define REGISTER_ENCODEMAPPER_CLASS(Arc)                 \
  static EncodeMapperClassIORegistration::Registerer     \
      EncodeMapperClass_##Arc##_registerer(              \
          Arc::Type(),                                   \
          EncodeMapperClassIORegistration::Entry(        \
              EncodeMapperClass::Read<Arc>, EncodeMapperClass::Create<Arc>));
} // namespace script
} // namespace fst
#endif // FST_SCRIPT_ENCODEMAPPER_CLASS_H_
|