// Copyright 2005-2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the 'License'); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an 'AS IS' BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Product weight set and associated semiring operation definitions. #ifndef FST_PRODUCT_WEIGHT_H_ #define FST_PRODUCT_WEIGHT_H_ #include #include #include #include #include #include namespace fst { // Product semiring: W1 * W2. template class ProductWeight : public PairWeight { public: using ReverseWeight = ProductWeight; ProductWeight() = default; explicit ProductWeight(const PairWeight &weight) : PairWeight(weight) {} ProductWeight(W1 w1, W2 w2) : PairWeight(std::move(w1), std::move(w2)) {} static const ProductWeight &Zero() { static const ProductWeight zero(PairWeight::Zero()); return zero; } static const ProductWeight &One() { static const ProductWeight one(PairWeight::One()); return one; } static const ProductWeight &NoWeight() { static const ProductWeight no_weight(PairWeight::NoWeight()); return no_weight; } static const std::string &Type() { static const std::string *const type = new std::string(W1::Type() + "_X_" + W2::Type()); return *type; } static constexpr uint64_t Properties() { return W1::Properties() & W2::Properties() & (kLeftSemiring | kRightSemiring | kCommutative | kIdempotent); } ProductWeight Quantize(float delta = kDelta) const { return ProductWeight(PairWeight::Quantize(delta)); } ReverseWeight Reverse() const { return ReverseWeight(PairWeight::Reverse()); } }; template inline ProductWeight Plus(const ProductWeight &w1, const ProductWeight &w2) { return ProductWeight(Plus(w1.Value1(), w2.Value1()), Plus(w1.Value2(), w2.Value2())); } template inline ProductWeight Times(const ProductWeight &w1, const ProductWeight &w2) { return ProductWeight(Times(w1.Value1(), w2.Value1()), Times(w1.Value2(), w2.Value2())); } template inline ProductWeight Divide(const ProductWeight &w1, const ProductWeight &w2, DivideType typ = DIVIDE_ANY) { return ProductWeight(Divide(w1.Value1(), w2.Value1(), typ), Divide(w1.Value2(), w2.Value2(), typ)); } // Specialization for product weight template class Adder> { public: using Weight = ProductWeight; Adder() = default; explicit Adder(Weight w) : adder1_(w.Value1()), adder2_(w.Value2()) {} Weight Add(const Weight &w) { adder1_.Add(w.Value1()); adder2_.Add(w.Value2()); return Sum(); } Weight Sum() const { return Weight(adder1_.Sum(), adder2_.Sum()); } void Reset(Weight w = Weight::Zero()) { adder1_.Reset(w.Value1()); adder2_.Reset(w.Value2()); } private: Adder adder1_; Adder adder2_; }; // This function object generates weights by calling the underlying generators // for the template weight types, like all other pair weight types. This is // intended primarily for testing. template class WeightGenerate> { public: using Weight = ProductWeight; using Generate = WeightGenerate>; explicit WeightGenerate(uint64_t seed = std::random_device()(), bool allow_zero = true) : generate_(seed, allow_zero) {} Weight operator()() const { return Weight(generate_()); } private: const Generate generate_; }; } // namespace fst #endif // FST_PRODUCT_WEIGHT_H_