You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

248 lines
7.8 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Represents a generic weight in an FST; that is, represents a specific type
  19. // of weight underneath while hiding that type from a client.
  20. #ifndef FST_SCRIPT_WEIGHT_CLASS_H_
  21. #define FST_SCRIPT_WEIGHT_CLASS_H_
  22. #include <cstddef>
  23. #include <memory>
  24. #include <ostream>
  25. #include <string>
  26. #include <fst/arc.h>
  27. #include <fst/generic-register.h>
  28. #include <fst/util.h>
  29. #include <fst/weight.h>
  30. #include <string_view>
  31. namespace fst {
  32. namespace script {
  33. class WeightImplBase {
  34. public:
  35. virtual WeightImplBase *Copy() const = 0;
  36. virtual void Print(std::ostream *o) const = 0;
  37. virtual const std::string &Type() const = 0;
  38. virtual std::string ToString() const = 0;
  39. virtual bool Member() const = 0;
  40. virtual bool operator==(const WeightImplBase &other) const = 0;
  41. virtual bool operator!=(const WeightImplBase &other) const = 0;
  42. virtual WeightImplBase &PlusEq(const WeightImplBase &other) = 0;
  43. virtual WeightImplBase &TimesEq(const WeightImplBase &other) = 0;
  44. virtual WeightImplBase &DivideEq(const WeightImplBase &other) = 0;
  45. virtual WeightImplBase &PowerEq(size_t n) = 0;
  46. virtual ~WeightImplBase() = default;
  47. };
  48. template <class W>
  49. class WeightClassImpl : public WeightImplBase {
  50. public:
  51. explicit WeightClassImpl(const W &weight) : weight_(weight) {}
  52. WeightClassImpl<W> *Copy() const final {
  53. return new WeightClassImpl<W>(weight_);
  54. }
  55. const std::string &Type() const final { return W::Type(); }
  56. void Print(std::ostream *ostrm) const final { *ostrm << weight_; }
  57. std::string ToString() const final {
  58. return WeightToStr(weight_);
  59. }
  60. bool Member() const final { return weight_.Member(); }
  61. bool operator==(const WeightImplBase &other) const final {
  62. const auto *typed_other = down_cast<const WeightClassImpl<W> *>(&other);
  63. return weight_ == typed_other->weight_;
  64. }
  65. bool operator!=(const WeightImplBase &other) const final {
  66. return !(*this == other);
  67. }
  68. WeightClassImpl<W> &PlusEq(const WeightImplBase &other) final {
  69. const auto *typed_other = down_cast<const WeightClassImpl<W> *>(&other);
  70. weight_ = Plus(weight_, typed_other->weight_);
  71. return *this;
  72. }
  73. WeightClassImpl<W> &TimesEq(const WeightImplBase &other) final {
  74. const auto *typed_other = down_cast<const WeightClassImpl<W> *>(&other);
  75. weight_ = Times(weight_, typed_other->weight_);
  76. return *this;
  77. }
  78. WeightClassImpl<W> &DivideEq(const WeightImplBase &other) final {
  79. const auto *typed_other = down_cast<const WeightClassImpl<W> *>(&other);
  80. weight_ = Divide(weight_, typed_other->weight_);
  81. return *this;
  82. }
  83. WeightClassImpl<W> &PowerEq(size_t n) final {
  84. weight_ = Power(weight_, n);
  85. return *this;
  86. }
  87. W *GetImpl() { return &weight_; }
  88. private:
  89. W weight_;
  90. };
  91. class WeightClass {
  92. public:
  93. WeightClass() = default;
  94. template <class W>
  95. explicit WeightClass(const W &weight)
  96. : impl_(std::make_unique<WeightClassImpl<W>>(weight)) {}
  97. template <class W>
  98. explicit WeightClass(const WeightClassImpl<W> &impl)
  99. : impl_(std::make_unique<WeightClassImpl<W>>(impl)) {}
  100. WeightClass(std::string_view weight_type, std::string_view weight_str);
  101. WeightClass(const WeightClass &other)
  102. : impl_(other.impl_ ? other.impl_->Copy() : nullptr) {}
  103. WeightClass &operator=(const WeightClass &other) {
  104. impl_.reset(other.impl_ ? other.impl_->Copy() : nullptr);
  105. return *this;
  106. }
  107. static constexpr std::string_view __ZERO__ = "__ZERO__"; // NOLINT
  108. static constexpr std::string_view __ONE__ = "__ONE__"; // NOLINT
  109. static constexpr std::string_view __NOWEIGHT__ = "__NOWEIGHT__"; // NOLINT
  110. static WeightClass Zero(std::string_view weight_type);
  111. static WeightClass One(std::string_view weight_type);
  112. static WeightClass NoWeight(std::string_view weight_type);
  113. template <class W>
  114. const W *GetWeight() const {
  115. if (W::Type() != impl_->Type()) {
  116. return nullptr;
  117. } else {
  118. auto *typed_impl = static_cast<WeightClassImpl<W> *>(impl_.get());
  119. return typed_impl->GetImpl();
  120. }
  121. }
  122. std::string ToString() const { return (impl_) ? impl_->ToString() : "none"; }
  123. const std::string &Type() const {
  124. if (impl_) return impl_->Type();
  125. static const std::string *const no_type = new std::string("none");
  126. return *no_type;
  127. }
  128. bool Member() const { return impl_ && impl_->Member(); }
  129. static bool WeightTypesMatch(const WeightClass &lhs, const WeightClass &rhs,
  130. std::string_view op_name);
  131. friend bool operator==(const WeightClass &lhs, const WeightClass &rhs);
  132. friend WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs);
  133. friend WeightClass Times(const WeightClass &lhs, const WeightClass &rhs);
  134. friend WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs);
  135. friend WeightClass Power(const WeightClass &w, size_t n);
  136. private:
  137. const WeightImplBase *GetImpl() const { return impl_.get(); }
  138. WeightImplBase *GetImpl() { return impl_.get(); }
  139. std::unique_ptr<WeightImplBase> impl_;
  140. friend std::ostream &operator<<(std::ostream &o, const WeightClass &c);
  141. };
  142. bool operator==(const WeightClass &lhs, const WeightClass &rhs);
  143. bool operator!=(const WeightClass &lhs, const WeightClass &rhs);
  144. WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs);
  145. WeightClass Times(const WeightClass &lhs, const WeightClass &rhs);
  146. WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs);
  147. WeightClass Power(const WeightClass &w, size_t n);
  148. std::ostream &operator<<(std::ostream &o, const WeightClass &c);
  149. // Registration for generic weight types.
  150. using StrToWeightImplBaseT =
  151. std::unique_ptr<WeightImplBase> (*)(std::string_view str);
  152. template <class W>
  153. std::unique_ptr<WeightImplBase> StrToWeightImplBase(std::string_view str) {
  154. if (str == WeightClass::__ZERO__) {
  155. return std::make_unique<WeightClassImpl<W>>(W::Zero());
  156. } else if (str == WeightClass::__ONE__) {
  157. return std::make_unique<WeightClassImpl<W>>(W::One());
  158. } else if (str == WeightClass::__NOWEIGHT__) {
  159. return std::make_unique<WeightClassImpl<W>>(W::NoWeight());
  160. }
  161. return std::make_unique<WeightClassImpl<W>>(StrToWeight<W>(str));
  162. }
  163. class WeightClassRegister
  164. : public GenericRegister<std::string, StrToWeightImplBaseT,
  165. WeightClassRegister> {
  166. protected:
  167. std::string ConvertKeyToSoFilename(std::string_view key) const final {
  168. std::string legal_type(key);
  169. ConvertToLegalCSymbol(&legal_type);
  170. legal_type.append(".so");
  171. return legal_type;
  172. }
  173. };
  174. using WeightClassRegisterer = GenericRegisterer<WeightClassRegister>;
  175. // Internal version; needs to be called by wrapper in order for macro args to
  176. // expand.
  177. #define REGISTER_FST_WEIGHT__(Weight, line) \
  178. static WeightClassRegisterer weight_registerer##_##line( \
  179. Weight::Type(), StrToWeightImplBase<Weight>)
  180. // This layer is where __FILE__ and __LINE__ are expanded.
  181. #define REGISTER_FST_WEIGHT_EXPANDER(Weight, line) \
  182. REGISTER_FST_WEIGHT__(Weight, line)
  183. // Macro for registering new weight types; clients call this.
  184. #define REGISTER_FST_WEIGHT(Weight) \
  185. REGISTER_FST_WEIGHT_EXPANDER(Weight, __LINE__)
  186. } // namespace script
  187. } // namespace fst
  188. #endif // FST_SCRIPT_WEIGHT_CLASS_H_