|
|
// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Represents a generic weight in an FST; that is, represents a specific type
// of weight underneath while hiding that type from a client.
#ifndef FST_SCRIPT_WEIGHT_CLASS_H_
#define FST_SCRIPT_WEIGHT_CLASS_H_
#include <cstddef>
#include <memory>
#include <ostream>
#include <string>
#include <fst/arc.h>
#include <fst/generic-register.h>
#include <fst/util.h>
#include <fst/weight.h>
#include <string_view>
namespace fst { namespace script {
class WeightImplBase { public: virtual WeightImplBase *Copy() const = 0; virtual void Print(std::ostream *o) const = 0; virtual const std::string &Type() const = 0; virtual std::string ToString() const = 0; virtual bool Member() const = 0; virtual bool operator==(const WeightImplBase &other) const = 0; virtual bool operator!=(const WeightImplBase &other) const = 0; virtual WeightImplBase &PlusEq(const WeightImplBase &other) = 0; virtual WeightImplBase &TimesEq(const WeightImplBase &other) = 0; virtual WeightImplBase &DivideEq(const WeightImplBase &other) = 0; virtual WeightImplBase &PowerEq(size_t n) = 0; virtual ~WeightImplBase() = default; };
template <class W> class WeightClassImpl : public WeightImplBase { public: explicit WeightClassImpl(const W &weight) : weight_(weight) {}
WeightClassImpl<W> *Copy() const final { return new WeightClassImpl<W>(weight_); }
const std::string &Type() const final { return W::Type(); }
void Print(std::ostream *ostrm) const final { *ostrm << weight_; }
std::string ToString() const final { return WeightToStr(weight_); }
bool Member() const final { return weight_.Member(); }
bool operator==(const WeightImplBase &other) const final { const auto *typed_other = down_cast<const WeightClassImpl<W> *>(&other); return weight_ == typed_other->weight_; }
bool operator!=(const WeightImplBase &other) const final { return !(*this == other); }
WeightClassImpl<W> &PlusEq(const WeightImplBase &other) final { const auto *typed_other = down_cast<const WeightClassImpl<W> *>(&other); weight_ = Plus(weight_, typed_other->weight_); return *this; }
WeightClassImpl<W> &TimesEq(const WeightImplBase &other) final { const auto *typed_other = down_cast<const WeightClassImpl<W> *>(&other); weight_ = Times(weight_, typed_other->weight_); return *this; }
WeightClassImpl<W> &DivideEq(const WeightImplBase &other) final { const auto *typed_other = down_cast<const WeightClassImpl<W> *>(&other); weight_ = Divide(weight_, typed_other->weight_); return *this; }
WeightClassImpl<W> &PowerEq(size_t n) final { weight_ = Power(weight_, n); return *this; }
W *GetImpl() { return &weight_; }
private: W weight_; };
class WeightClass { public: WeightClass() = default;
template <class W> explicit WeightClass(const W &weight) : impl_(std::make_unique<WeightClassImpl<W>>(weight)) {}
template <class W> explicit WeightClass(const WeightClassImpl<W> &impl) : impl_(std::make_unique<WeightClassImpl<W>>(impl)) {}
WeightClass(std::string_view weight_type, std::string_view weight_str);
WeightClass(const WeightClass &other) : impl_(other.impl_ ? other.impl_->Copy() : nullptr) {}
WeightClass &operator=(const WeightClass &other) { impl_.reset(other.impl_ ? other.impl_->Copy() : nullptr); return *this; }
static constexpr std::string_view __ZERO__ = "__ZERO__"; // NOLINT
static constexpr std::string_view __ONE__ = "__ONE__"; // NOLINT
static constexpr std::string_view __NOWEIGHT__ = "__NOWEIGHT__"; // NOLINT
static WeightClass Zero(std::string_view weight_type);
static WeightClass One(std::string_view weight_type);
static WeightClass NoWeight(std::string_view weight_type);
template <class W> const W *GetWeight() const { if (W::Type() != impl_->Type()) { return nullptr; } else { auto *typed_impl = static_cast<WeightClassImpl<W> *>(impl_.get()); return typed_impl->GetImpl(); } }
std::string ToString() const { return (impl_) ? impl_->ToString() : "none"; }
const std::string &Type() const { if (impl_) return impl_->Type(); static const std::string *const no_type = new std::string("none"); return *no_type; }
bool Member() const { return impl_ && impl_->Member(); }
static bool WeightTypesMatch(const WeightClass &lhs, const WeightClass &rhs, std::string_view op_name);
friend bool operator==(const WeightClass &lhs, const WeightClass &rhs);
friend WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs);
friend WeightClass Times(const WeightClass &lhs, const WeightClass &rhs);
friend WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs);
friend WeightClass Power(const WeightClass &w, size_t n);
private: const WeightImplBase *GetImpl() const { return impl_.get(); }
WeightImplBase *GetImpl() { return impl_.get(); }
std::unique_ptr<WeightImplBase> impl_;
friend std::ostream &operator<<(std::ostream &o, const WeightClass &c); };
bool operator==(const WeightClass &lhs, const WeightClass &rhs);
bool operator!=(const WeightClass &lhs, const WeightClass &rhs);
WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs);
WeightClass Times(const WeightClass &lhs, const WeightClass &rhs);
WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs);
WeightClass Power(const WeightClass &w, size_t n);
std::ostream &operator<<(std::ostream &o, const WeightClass &c);
// Registration for generic weight types.
using StrToWeightImplBaseT = std::unique_ptr<WeightImplBase> (*)(std::string_view str);
template <class W> std::unique_ptr<WeightImplBase> StrToWeightImplBase(std::string_view str) { if (str == WeightClass::__ZERO__) { return std::make_unique<WeightClassImpl<W>>(W::Zero()); } else if (str == WeightClass::__ONE__) { return std::make_unique<WeightClassImpl<W>>(W::One()); } else if (str == WeightClass::__NOWEIGHT__) { return std::make_unique<WeightClassImpl<W>>(W::NoWeight()); } return std::make_unique<WeightClassImpl<W>>(StrToWeight<W>(str)); }
class WeightClassRegister : public GenericRegister<std::string, StrToWeightImplBaseT, WeightClassRegister> { protected: std::string ConvertKeyToSoFilename(std::string_view key) const final { std::string legal_type(key); ConvertToLegalCSymbol(&legal_type); legal_type.append(".so"); return legal_type; } };
using WeightClassRegisterer = GenericRegisterer<WeightClassRegister>;
// Internal version; needs to be called by wrapper in order for macro args to
// expand.
#define REGISTER_FST_WEIGHT__(Weight, line) \
static WeightClassRegisterer weight_registerer##_##line( \ Weight::Type(), StrToWeightImplBase<Weight>)
// This layer is where __FILE__ and __LINE__ are expanded.
#define REGISTER_FST_WEIGHT_EXPANDER(Weight, line) \
REGISTER_FST_WEIGHT__(Weight, line)
// Macro for registering new weight types; clients call this.
#define REGISTER_FST_WEIGHT(Weight) \
REGISTER_FST_WEIGHT_EXPANDER(Weight, __LINE__)
} // namespace script
} // namespace fst
#endif // FST_SCRIPT_WEIGHT_CLASS_H_
|