// Copyright 2005-2024 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the 'License');
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an 'AS IS' BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
// See www.openfst.org for extensive documentation on this weighted
|
|
// finite-state transducer library.
|
|
//
|
|
// Tuple weight set operation definitions.
|
|
|
|
#ifndef FST_TUPLE_WEIGHT_H_
|
|
#define FST_TUPLE_WEIGHT_H_
|
|
|
|
#include <algorithm>
|
|
#include <array>
|
|
#include <cstddef>
|
|
#include <cstdint>
|
|
#include <functional>
|
|
#include <istream>
|
|
#include <ostream>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include <fst/flags.h>
|
|
#include <fst/log.h>
|
|
#include <fst/weight.h>
|
|
|
|
namespace fst {
|
|
|
|
// n-tuple weight, element of the n-th Cartesian power of W.
|
|
template <class W, size_t n>
|
|
class TupleWeight {
|
|
public:
|
|
using ReverseWeight = TupleWeight<typename W::ReverseWeight, n>;
|
|
|
|
using Weight = W;
|
|
using Index = size_t;
|
|
|
|
template <class Iterator>
|
|
TupleWeight(Iterator begin, Iterator end) {
|
|
std::copy(begin, end, values_.begin());
|
|
}
|
|
|
|
explicit TupleWeight(const W &weight = W::Zero()) { values_.fill(weight); }
|
|
|
|
// Initialize component `index` to `weight`; initialize all other components
|
|
// to `default_weight`
|
|
TupleWeight(Index index, const W &weight, const W &default_weight)
|
|
: TupleWeight(default_weight) {
|
|
values_[index] = weight;
|
|
}
|
|
|
|
static const TupleWeight<W, n> &Zero() {
|
|
static const TupleWeight<W, n> zero(W::Zero());
|
|
return zero;
|
|
}
|
|
|
|
static const TupleWeight<W, n> &One() {
|
|
static const TupleWeight<W, n> one(W::One());
|
|
return one;
|
|
}
|
|
|
|
static const TupleWeight<W, n> &NoWeight() {
|
|
static const TupleWeight<W, n> no_weight(W::NoWeight());
|
|
return no_weight;
|
|
}
|
|
|
|
constexpr static size_t Length() { return n; }
|
|
|
|
std::istream &Read(std::istream &istrm) {
|
|
for (size_t i = 0; i < n; ++i) values_[i].Read(istrm);
|
|
return istrm;
|
|
}
|
|
|
|
std::ostream &Write(std::ostream &ostrm) const {
|
|
for (size_t i = 0; i < n; ++i) values_[i].Write(ostrm);
|
|
return ostrm;
|
|
}
|
|
|
|
bool Member() const {
|
|
return std::all_of(values_.begin(), values_.end(), std::mem_fn(&W::Member));
|
|
}
|
|
|
|
size_t Hash() const {
|
|
uint64_t hash = 0;
|
|
for (size_t i = 0; i < n; ++i) hash = 5 * hash + values_[i].Hash();
|
|
return size_t(hash);
|
|
}
|
|
|
|
TupleWeight<W, n> Quantize(float delta = kDelta) const {
|
|
TupleWeight<W, n> weight;
|
|
for (size_t i = 0; i < n; ++i) {
|
|
weight.values_[i] = values_[i].Quantize(delta);
|
|
}
|
|
return weight;
|
|
}
|
|
|
|
ReverseWeight Reverse() const {
|
|
TupleWeight<W, n> w;
|
|
for (size_t i = 0; i < n; ++i) w.values_[i] = values_[i].Reverse();
|
|
return w;
|
|
}
|
|
|
|
const W &Value(size_t i) const { return values_[i]; }
|
|
|
|
void SetValue(size_t i, const W &w) { values_[i] = w; }
|
|
|
|
private:
|
|
std::array<W, n> values_;
|
|
};
|
|
|
|
template <class W, size_t n>
|
|
inline bool operator==(const TupleWeight<W, n> &w1,
|
|
const TupleWeight<W, n> &w2) {
|
|
for (size_t i = 0; i < n; ++i) {
|
|
if (w1.Value(i) != w2.Value(i)) return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
template <class W, size_t n>
|
|
inline bool operator!=(const TupleWeight<W, n> &w1,
|
|
const TupleWeight<W, n> &w2) {
|
|
for (size_t i = 0; i < n; ++i) {
|
|
if (w1.Value(i) != w2.Value(i)) return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
template <class W, size_t n>
|
|
inline bool ApproxEqual(const TupleWeight<W, n> &w1,
|
|
const TupleWeight<W, n> &w2, float delta = kDelta) {
|
|
for (size_t i = 0; i < n; ++i) {
|
|
if (!ApproxEqual(w1.Value(i), w2.Value(i), delta)) return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
template <class W, size_t n>
|
|
inline std::ostream &operator<<(std::ostream &strm,
|
|
const TupleWeight<W, n> &w) {
|
|
CompositeWeightWriter writer(strm);
|
|
writer.WriteBegin();
|
|
for (size_t i = 0; i < n; ++i) writer.WriteElement(w.Value(i));
|
|
writer.WriteEnd();
|
|
return strm;
|
|
}
|
|
|
|
template <class W, size_t n>
|
|
inline std::istream &operator>>(std::istream &strm, TupleWeight<W, n> &w) {
|
|
CompositeWeightReader reader(strm);
|
|
reader.ReadBegin();
|
|
W v;
|
|
// Reads first n-1 elements.
|
|
static_assert(n > 0, "Size must be positive.");
|
|
for (size_t i = 0; i < n - 1; ++i) {
|
|
reader.ReadElement(&v);
|
|
w.SetValue(i, v);
|
|
}
|
|
// Reads n-th element.
|
|
reader.ReadElement(&v, true);
|
|
w.SetValue(n - 1, v);
|
|
reader.ReadEnd();
|
|
return strm;
|
|
}
|
|
|
|
} // namespace fst
|
|
|
|
#endif // FST_TUPLE_WEIGHT_H_
|