You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

248 lines
7.8 KiB

// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Represents a generic weight in an FST; that is, represents a specific type
// of weight underneath while hiding that type from a client.
#ifndef FST_SCRIPT_WEIGHT_CLASS_H_
#define FST_SCRIPT_WEIGHT_CLASS_H_
#include <cstddef>
#include <memory>
#include <ostream>
#include <string>
#include <fst/arc.h>
#include <fst/generic-register.h>
#include <fst/util.h>
#include <fst/weight.h>
#include <string_view>
namespace fst {
namespace script {
// Abstract, type-erased view of a single weight. A concrete subclass (such as
// WeightClassImpl<W>) stores a value of some weight type W; these virtual
// operations let callers print, compare and combine weights without naming W.
class WeightImplBase {
 public:
  // Returns a newly allocated copy; the caller owns the result.
  virtual WeightImplBase *Copy() const = 0;

  // Writes a textual form of the weight to *strm.
  virtual void Print(std::ostream *strm) const = 0;

  // Name of the wrapped weight type.
  virtual const std::string &Type() const = 0;

  // Textual form of the weight.
  virtual std::string ToString() const = 0;

  // Membership check forwarded to the wrapped weight.
  virtual bool Member() const = 0;

  // Comparisons. `rhs` is expected to wrap the same weight type as *this.
  virtual bool operator==(const WeightImplBase &rhs) const = 0;
  virtual bool operator!=(const WeightImplBase &rhs) const = 0;

  // In-place arithmetic; each returns *this. `rhs` is expected to wrap the
  // same weight type as *this.
  virtual WeightImplBase &PlusEq(const WeightImplBase &rhs) = 0;
  virtual WeightImplBase &TimesEq(const WeightImplBase &rhs) = 0;
  virtual WeightImplBase &DivideEq(const WeightImplBase &rhs) = 0;
  virtual WeightImplBase &PowerEq(size_t exponent) = 0;

  // Virtual so owners may delete concrete impls through a base pointer.
  virtual ~WeightImplBase() = default;
};
template <class W>
class WeightClassImpl : public WeightImplBase {
public:
explicit WeightClassImpl(const W &weight) : weight_(weight) {}
WeightClassImpl<W> *Copy() const final {
return new WeightClassImpl<W>(weight_);
}
const std::string &Type() const final { return W::Type(); }
void Print(std::ostream *ostrm) const final { *ostrm << weight_; }
std::string ToString() const final {
return WeightToStr(weight_);
}
bool Member() const final { return weight_.Member(); }
bool operator==(const WeightImplBase &other) const final {
const auto *typed_other = down_cast<const WeightClassImpl<W> *>(&other);
return weight_ == typed_other->weight_;
}
bool operator!=(const WeightImplBase &other) const final {
return !(*this == other);
}
WeightClassImpl<W> &PlusEq(const WeightImplBase &other) final {
const auto *typed_other = down_cast<const WeightClassImpl<W> *>(&other);
weight_ = Plus(weight_, typed_other->weight_);
return *this;
}
WeightClassImpl<W> &TimesEq(const WeightImplBase &other) final {
const auto *typed_other = down_cast<const WeightClassImpl<W> *>(&other);
weight_ = Times(weight_, typed_other->weight_);
return *this;
}
WeightClassImpl<W> &DivideEq(const WeightImplBase &other) final {
const auto *typed_other = down_cast<const WeightClassImpl<W> *>(&other);
weight_ = Divide(weight_, typed_other->weight_);
return *this;
}
WeightClassImpl<W> &PowerEq(size_t n) final {
weight_ = Power(weight_, n);
return *this;
}
W *GetImpl() { return &weight_; }
private:
W weight_;
};
// Generic, type-erased weight: owns a WeightImplBase that wraps a concrete
// weight type W, hiding W from clients. A default-constructed WeightClass holds
// no weight; in that state Type() and ToString() report "none", Member() is
// false, and GetWeight<W>() returns nullptr.
class WeightClass {
 public:
  WeightClass() = default;

  // Wraps a copy of `weight`.
  template <class W>
  explicit WeightClass(const W &weight)
      : impl_(std::make_unique<WeightClassImpl<W>>(weight)) {}

  // Wraps a copy of an already type-erased implementation.
  template <class W>
  explicit WeightClass(const WeightClassImpl<W> &impl)
      : impl_(std::make_unique<WeightClassImpl<W>>(impl)) {}

  // Builds a weight of type `weight_type` from `weight_str`. Defined out of
  // line; presumably dispatches through WeightClassRegister (verify in .cc).
  WeightClass(std::string_view weight_type, std::string_view weight_str);

  // Copies deep-clone the wrapped weight via WeightImplBase::Copy().
  WeightClass(const WeightClass &other)
      : impl_(other.impl_ ? other.impl_->Copy() : nullptr) {}

  WeightClass &operator=(const WeightClass &other) {
    // The clone is made before reset(), so self-assignment is safe.
    impl_.reset(other.impl_ ? other.impl_->Copy() : nullptr);
    return *this;
  }

  // Without these, the user-declared copy operations above would suppress the
  // implicit moves and every move would fall back to a heap clone. A move just
  // transfers ownership; the moved-from object is left empty, i.e. in the same
  // state as a default-constructed WeightClass.
  WeightClass(WeightClass &&) noexcept = default;
  WeightClass &operator=(WeightClass &&) noexcept = default;

  // Sentinel strings recognized by StrToWeightImplBase<W>() in place of a
  // parsed weight string.
  static constexpr std::string_view __ZERO__ = "__ZERO__";          // NOLINT
  static constexpr std::string_view __ONE__ = "__ONE__";            // NOLINT
  static constexpr std::string_view __NOWEIGHT__ = "__NOWEIGHT__";  // NOLINT

  // Constants of the named weight type; defined out of line.
  static WeightClass Zero(std::string_view weight_type);
  static WeightClass One(std::string_view weight_type);
  static WeightClass NoWeight(std::string_view weight_type);

  // Returns a pointer to the wrapped weight if it has type W, and nullptr
  // otherwise, including when this object holds no weight at all.
  template <class W>
  const W *GetWeight() const {
    if (!impl_ || W::Type() != impl_->Type()) return nullptr;
    // Assumes Type() strings uniquely identify weight types, so a name match
    // means impl_ really is a WeightClassImpl<W>.
    auto *typed_impl = static_cast<WeightClassImpl<W> *>(impl_.get());
    return typed_impl->GetImpl();
  }

  std::string ToString() const { return (impl_) ? impl_->ToString() : "none"; }

  const std::string &Type() const {
    if (impl_) return impl_->Type();
    // Intentionally never freed, so the reference stays valid for the whole
    // program.
    static const std::string *const no_type = new std::string("none");
    return *no_type;
  }

  bool Member() const { return impl_ && impl_->Member(); }

  // Defined out of line; presumably reports whether `lhs` and `rhs` share a
  // weight type, with `op_name` used in diagnostics (verify in .cc).
  static bool WeightTypesMatch(const WeightClass &lhs, const WeightClass &rhs,
                               std::string_view op_name);

  friend bool operator==(const WeightClass &lhs, const WeightClass &rhs);
  friend WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs);
  friend WeightClass Times(const WeightClass &lhs, const WeightClass &rhs);
  friend WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs);
  friend WeightClass Power(const WeightClass &w, size_t n);

 private:
  const WeightImplBase *GetImpl() const { return impl_.get(); }
  WeightImplBase *GetImpl() { return impl_.get(); }

  // Null for an empty WeightClass.
  std::unique_ptr<WeightImplBase> impl_;

  friend std::ostream &operator<<(std::ostream &o, const WeightClass &c);
};
// Comparison and arithmetic on type-erased weights. All of these are defined
// out of line (weight-class.cc); the operands are presumably required to share
// a weight type (see WeightClass::WeightTypesMatch) — verify there.
bool operator==(const WeightClass &w1, const WeightClass &w2);
bool operator!=(const WeightClass &w1, const WeightClass &w2);
WeightClass Plus(const WeightClass &w1, const WeightClass &w2);
WeightClass Times(const WeightClass &w1, const WeightClass &w2);
WeightClass Divide(const WeightClass &w1, const WeightClass &w2);
WeightClass Power(const WeightClass &weight, size_t exponent);

// Stream insertion for WeightClass; defined out of line.
std::ostream &operator<<(std::ostream &strm, const WeightClass &weight);
// Registration for generic weight types.

// Signature of the per-weight-type parser stored in WeightClassRegister.
using StrToWeightImplBaseT =
    std::unique_ptr<WeightImplBase> (*)(std::string_view str);

// Builds a type-erased weight of type W from its textual form. The sentinels
// WeightClass::__ZERO__, __ONE__ and __NOWEIGHT__ select W::Zero(), W::One()
// and W::NoWeight(); any other string is parsed with StrToWeight<W>().
template <class W>
std::unique_ptr<WeightImplBase> StrToWeightImplBase(std::string_view str) {
  const auto to_weight = [str]() -> W {
    if (str == WeightClass::__ZERO__) return W::Zero();
    if (str == WeightClass::__ONE__) return W::One();
    if (str == WeightClass::__NOWEIGHT__) return W::NoWeight();
    return StrToWeight<W>(str);
  };
  return std::make_unique<WeightClassImpl<W>>(to_weight());
}
// Registry keyed by weight type name whose values are the
// StrToWeightImplBaseT parsers for those types.
class WeightClassRegister
    : public GenericRegister<std::string, StrToWeightImplBaseT,
                             WeightClassRegister> {
 protected:
  // Turns a weight type name into a shared-object file name: the name is made
  // a legal C symbol and ".so" is appended. Presumably consulted by
  // GenericRegister when a key is not yet registered (verify there).
  std::string ConvertKeyToSoFilename(std::string_view key) const final {
    std::string so_name(key.data(), key.size());
    ConvertToLegalCSymbol(&so_name);
    return so_name + ".so";
  }
};

// Registerer type instantiated by REGISTER_FST_WEIGHT to add an entry to
// WeightClassRegister.
using WeightClassRegisterer = GenericRegisterer<WeightClassRegister>;
// Internal version; must be invoked through the wrapper below so that macro
// arguments (in particular __LINE__) are expanded before token pasting.
// Names are fully qualified so the macro also works when invoked outside
// namespace fst::script.
#define REGISTER_FST_WEIGHT__(Weight, line)                      \
  static ::fst::script::WeightClassRegisterer                     \
      weight_registerer##_##line(                                 \
          Weight::Type(), ::fst::script::StrToWeightImplBase<Weight>)

// This layer is where __LINE__ is expanded, giving the registerer object a
// name that is unique per line of the translation unit.
#define REGISTER_FST_WEIGHT_EXPANDER(Weight, line) \
  REGISTER_FST_WEIGHT__(Weight, line)

// Macro for registering new weight types; clients call this.
#define REGISTER_FST_WEIGHT(Weight) \
  REGISTER_FST_WEIGHT_EXPANDER(Weight, __LINE__)
} // namespace script
} // namespace fst
#endif // FST_SCRIPT_WEIGHT_CLASS_H_