// (Repository web-viewer residue removed; the original listing reported
// 434 lines, 13 KiB.)

// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Sparse version of tuple-weight, based on tuple-weight.h.
// Internally stores sparse (key, value) pairs in a linked list. The default
// value is the assumed value of all unset keys. As a singleton optimization,
// the first (key, value) pair is stored in a member variable rather than in
// the list, avoiding a heap allocation when only one pair is set. Use
// SparseTupleWeightIterator to iterate through the (key, value) pairs. Note:
// this does NOT iterate through the default value.
//
// Sparse tuple weight set operation definitions.
#ifndef FST_SPARSE_TUPLE_WEIGHT_H_
#define FST_SPARSE_TUPLE_WEIGHT_H_
#include <algorithm>
#include <cstddef>
#include <functional>
#include <istream>
#include <list>
#include <ostream>
#include <stack>
#include <string>
#include <utility>
#include <fst/util.h>
#include <fst/weight.h>
namespace fst {
template <class W, class K>
class SparseTupleWeightIterator;
// Arbitrary dimension tuple weight, stored as a sorted linked-list of
// (key, value) pairs. W is any weight class, and K is the key value type.
// Keys without a stored pair take the value DefaultValue(). kNoKey (-1) is
// reserved for internal use: it marks the first slot as unoccupied.
template <class W, class K = int>
class SparseTupleWeight {
 public:
  using ReverseWeight = SparseTupleWeight<typename W::ReverseWeight, K>;
  using Iterator = SparseTupleWeightIterator<W, K>;
  using Pair = std::pair<K, W>;
  using Weight = W;
  using Index = K;

  // Reserved key; `first_.first == kNoKey` means no pair is stored.
  static constexpr K kNoKey = -1;

  // Creates an empty weight whose default value is W::Zero().
  SparseTupleWeight() { Init(); }

  ~SparseTupleWeight() noexcept = default;

  // Builds a weight (default value W::Zero()) from the pairs in [begin, end).
  // Pairs are appended in iteration order with no ordering check, so the
  // input is assumed to be sorted by key. Pairs whose value equals the
  // default value are skipped.
  template <class InputIt>
  SparseTupleWeight(InputIt begin, InputIt end) {
    Init();
    for (auto it = begin; it != end; ++it) PushBack(*it);
  }

  // Initialize component `key` to `weight`, with `default_weight` for all
  // other components. If `weight` equals `default_weight`, no pair is stored.
  SparseTupleWeight(const K &key, const W &weight, const W &default_weight)
      : default_(default_weight),
        first_(weight == default_weight ? kNoKey : key, weight) {}

  // Creates an empty weight whose default value is `weight`.
  explicit SparseTupleWeight(const W &weight) { Init(weight); }

  // Copies the default value and re-appends every stored pair of `weight`.
  SparseTupleWeight(const SparseTupleWeight &weight) {
    // Init() already installs the default value; no separate
    // SetDefaultValue() call is needed.
    Init(weight.DefaultValue());
    for (Iterator it(weight); !it.Done(); it.Next()) {
      PushBack(it.Value());
    }
  }

  // Takes over the stored pairs of `weight` and leaves it empty. The default
  // value is copied, not moved, so both objects keep a usable default.
  SparseTupleWeight(SparseTupleWeight &&weight) noexcept
      // Don't move the default, so weight.default_ is still valid.
      : default_(weight.default_),  // NOLINT
        first_(std::move(weight.first_)),
        rest_(std::move(weight.rest_)) {
    // move leaves the source in a valid but unspecified state.
    // Make sure the source weight is empty.
    weight.first_ = Pair(kNoKey, W::NoWeight());
    weight.rest_.clear();
  }

  // Empty weight with default value W::Zero().
  static const SparseTupleWeight &Zero() {
    static const SparseTupleWeight zero(W::Zero());
    return zero;
  }

  // Empty weight with default value W::One().
  static const SparseTupleWeight &One() {
    static const SparseTupleWeight one(W::One());
    return one;
  }

  // Empty weight with default value W::NoWeight().
  static const SparseTupleWeight &NoWeight() {
    static const SparseTupleWeight no_weight(W::NoWeight());
    return no_weight;
  }

  // Reads the default value, the first pair and the remaining pairs, in that
  // order, via ReadType. Returns the stream result of the last read.
  std::istream &Read(std::istream &strm) {
    ReadType(strm, &default_);
    ReadType(strm, &first_);
    return ReadType(strm, &rest_);
  }

  // Writes the default value, the first pair and the remaining pairs, in
  // that order, via WriteType. Returns the stream result of the last write.
  std::ostream &Write(std::ostream &strm) const {
    WriteType(strm, default_);
    WriteType(strm, first_);
    return WriteType(strm, rest_);
  }

  // Copy assignment: resets this weight, then re-appends the pairs of
  // `weight`.
  SparseTupleWeight &operator=(const SparseTupleWeight &weight) {
    if (this == &weight) return *this;  // Checks for identity.
    Init(weight.DefaultValue());
    for (Iterator it(weight); !it.Done(); it.Next()) {
      PushBack(it.Value());
    }
    return *this;
  }

  // Move assignment: takes over the stored pairs of `weight` and leaves it
  // empty.
  SparseTupleWeight &operator=(SparseTupleWeight &&weight) noexcept {
    if (this == &weight) return *this;  // Checks for identity.
    // Don't move the default, so weight.default_ is still valid.
    default_ = weight.default_;
    first_ = std::move(weight.first_);
    rest_ = std::move(weight.rest_);
    // move leaves the source in a valid but unspecified state.
    // Make sure the source weight is empty.
    weight.first_ = Pair(kNoKey, W::NoWeight());
    weight.rest_.clear();
    return *this;
  }

  // Returns true iff the default value and every stored value are members.
  bool Member() const {
    if (!DefaultValue().Member()) return false;
    for (Iterator it(*this); !it.Done(); it.Next()) {
      if (!it.Value().second.Member()) return false;
    }
    return true;
  }

  // Hashes the stored (key, value) pairs; the default value does not
  // contribute. Requires std::hash<K> to be usable for the key type.
  size_t Hash() const {
    size_t h = 0;
    static const std::hash<K> H;
    for (Iterator it(*this); !it.Done(); it.Next()) {
      h = 5 * h + H(it.Value().first);
      h = 13 * h + it.Value().second.Hash();
    }
    return h;
  }

  // Returns a copy with the default value and every stored value quantized
  // by `delta`.
  SparseTupleWeight Quantize(float delta = kDelta) const {
    // Start from the quantized default so that unset components keep their
    // value (mirrors Reverse()); a default-constructed result would reset
    // them to W::Zero().
    SparseTupleWeight weight(DefaultValue().Quantize(delta));
    for (Iterator it(*this); !it.Done(); it.Next()) {
      weight.PushBack(it.Value().first, it.Value().second.Quantize(delta));
    }
    return weight;
  }

  // Returns the weight with the default value and every stored value
  // reversed.
  ReverseWeight Reverse() const {
    ReverseWeight weight(DefaultValue().Reverse());
    for (Iterator it(*this); !it.Done(); it.Next()) {
      weight.PushBack(it.Value().first, it.Value().second.Reverse());
    }
    return weight;
  }

  // Resets to an empty weight with the given default value.
  void Init(const W &default_value = W::Zero()) {
    // Initialized to the reserved key value.
    first_ = Pair(kNoKey, W::NoWeight());
    default_ = default_value;
    rest_.clear();
  }

  // Number of stored pairs (the default value is not counted).
  size_t Size() const {
    return first_.first == kNoKey ? 0 : rest_.size() + 1;
  }

  // Appends (key, weight); see the Pair overload.
  inline void PushBack(const K &key, const W &weight,
                       bool default_value_check = true) {
    PushBack(std::make_pair(key, weight), default_value_check);
  }

  // Appends `pair` after all currently stored pairs. When
  // `default_value_check` is true, a pair whose value equals the default
  // value is dropped. No ordering check is performed, so the caller is
  // responsible for appending keys in increasing order.
  inline void PushBack(const Pair &pair, bool default_value_check = true) {
    if (default_value_check && pair.second == default_) return;
    if (first_.first == kNoKey) {
      first_ = pair;
    } else {
      rest_.push_back(pair);
    }
  }

  // Returns the `key`-th component, or the default value if not set.
  // Linear scan that relies on the pairs being sorted by key.
  const W &Value(const K &key) const {
    // TODO(rybach): Consider binary search.
    Iterator iter(*this);
    while (!iter.Done() && iter.Value().first < key) iter.Next();
    if (!iter.Done() && iter.Value().first == key) return iter.Value().second;
    return DefaultValue();
  }

  // Sets component `key` to `w`. Setting a component to the default value
  // removes its stored pair.
  void SetValue(const K &key, const W &w) {
    if (w == DefaultValue()) {
      ClearValue(key);
    } else {
      SetValueToNonDefault(key, w);
    }
  }

  // Replaces the default value; stored pairs are left untouched.
  void SetDefaultValue(const W &value) { default_ = value; }

  const W &DefaultValue() const { return default_; }

 private:
  // Inserts or overwrites the pair for `key`, keeping the pairs sorted.
  void SetValueToNonDefault(const K &key, const W &w) {
    // Don't use SparseTupleWeightIterator, since that's const.
    if (first_.first == kNoKey) {
      first_ = Pair(key, w);
    } else if (key < first_.first) {
      // New smallest key: the old first pair moves to the head of the list.
      rest_.push_front(first_);
      first_ = Pair(key, w);
    } else if (key == first_.first) {
      first_.second = w;
    } else {
      // First list entry whose key is not less than `key`.
      const auto i =
          std::find_if(rest_.begin(), rest_.end(),
                       [key](const Pair &p) { return p.first >= key; });
      if (i != rest_.end() && i->first == key) {
        i->second = w;
      } else {
        rest_.insert(i, Pair(key, w));
      }
    }
  }

  // Removes the weight value for `key`, having the effect of setting
  // it to `DefaultValue()`. Does nothing if `key` is not stored.
  void ClearValue(const K &key) {
    if (key == first_.first) {
      if (!rest_.empty()) {
        // Promote the next pair into the first slot.
        first_ = rest_.front();
        rest_.pop_front();
      } else {
        first_.first = kNoKey;
      }
    } else if (key > first_.first) {
      const auto i =
          std::find_if(rest_.begin(), rest_.end(),
                       [key](const Pair &p) { return p.first >= key; });
      if (i != rest_.end() && i->first == key) {
        rest_.erase(i);
      }
    }
  }

  // Assumed default value of uninitialized keys, by default W::Zero().
  W default_;

  // Key values pairs are first stored in first_, then fill rest_ this way we
  // can avoid dynamic allocation in the common case where the weight is a
  // single key/value pair.
  Pair first_;
  std::list<Pair> rest_;

  friend class SparseTupleWeightIterator<W, K>;
};
// Read-only forward iterator over the stored (key, value) pairs of a
// SparseTupleWeight. It visits the weight's first slot and then its list of
// remaining pairs; the default value is never visited. The iterator keeps
// references into the weight, so the weight must outlive it.
template <class W, class K>
class SparseTupleWeightIterator {
 public:
  using Pair = typename SparseTupleWeight<W, K>::Pair;
  using const_iterator = typename std::list<Pair>::const_iterator;
  using iterator = typename std::list<Pair>::iterator;

  // Positions the iterator on the first slot of `weight`.
  explicit SparseTupleWeightIterator(const SparseTupleWeight<W, K> &weight)
      : first_(weight.first_),
        rest_(weight.rest_),
        on_first_(true),
        iter_(rest_.begin()) {}

  // True when there is no current pair: either the first slot is unoccupied
  // (key is kNoKey) or the list has been exhausted.
  bool Done() const {
    return on_first_ ? first_.first == SparseTupleWeight<W, K>::kNoKey
                     : iter_ == rest_.end();
  }

  // Current pair; only meaningful while !Done().
  const Pair &Value() const {
    if (on_first_) return first_;
    return *iter_;
  }

  // Advances to the next pair. Leaving the first slot does not move the list
  // cursor, which already points at the head of the list.
  void Next() {
    if (!on_first_) {
      ++iter_;
      return;
    }
    on_first_ = false;
  }

  // Returns to the first slot.
  void Reset() {
    iter_ = rest_.begin();
    on_first_ = true;
  }

 private:
  const Pair &first_;
  const std::list<Pair> &rest_;
  bool on_first_;  // Still positioned on the first slot?
  const_iterator iter_;
};
// Combines `w1` and `w2` component-wise into `*result`.
// M must be callable as a function W(K, W, W).
// K will be kNoKey when mapping the default value.
// The two inputs are merged in key order; a key stored in only one input is
// paired with the other input's default value. Mapped pairs are appended to
// `*result` with PushBack; entries already in `*result` are not cleared.
template <class W, class K, class M>
inline void SparseTupleWeightMap(SparseTupleWeight<W, K> *result,
                                 const SparseTupleWeight<W, K> &w1,
                                 const SparseTupleWeight<W, K> &w2,
                                 const M &operator_mapper) {
  SparseTupleWeightIterator<W, K> lhs(w1);
  SparseTupleWeightIterator<W, K> rhs(w2);
  const W &lhs_default = w1.DefaultValue();
  const W &rhs_default = w2.DefaultValue();
  result->SetDefaultValue(operator_mapper(SparseTupleWeight<W, K>::kNoKey,
                                          lhs_default, rhs_default));
  for (;;) {
    const bool lhs_done = lhs.Done();
    const bool rhs_done = rhs.Done();
    if (lhs_done && rhs_done) break;
    if (rhs_done) {
      // Only w1 has entries left.
      const auto &entry = lhs.Value();
      result->PushBack(
          entry.first, operator_mapper(entry.first, entry.second, rhs_default));
      lhs.Next();
    } else if (lhs_done) {
      // Only w2 has entries left.
      const auto &entry = rhs.Value();
      result->PushBack(
          entry.first, operator_mapper(entry.first, lhs_default, entry.second));
      rhs.Next();
    } else {
      const auto &a = lhs.Value();
      const auto &b = rhs.Value();
      if (a.first == b.first) {
        // Key stored on both sides.
        result->PushBack(a.first,
                         operator_mapper(a.first, a.second, b.second));
        lhs.Next();
        rhs.Next();
      } else if (a.first < b.first) {
        // Key stored in w1 only.
        result->PushBack(a.first,
                         operator_mapper(a.first, a.second, rhs_default));
        lhs.Next();
      } else {
        // Key stored in w2 only.
        result->PushBack(b.first,
                         operator_mapper(b.first, lhs_default, b.second));
        rhs.Next();
      }
    }
  }
}
// Two sparse tuple weights are equal when their default values are equal and
// every component agrees, where a key stored on one side only is compared
// against the other side's default value.
template <class W, class K>
inline bool operator==(const SparseTupleWeight<W, K> &w1,
                       const SparseTupleWeight<W, K> &w2) {
  const W &lhs_default = w1.DefaultValue();
  const W &rhs_default = w2.DefaultValue();
  if (lhs_default != rhs_default) return false;
  SparseTupleWeightIterator<W, K> lhs(w1);
  SparseTupleWeightIterator<W, K> rhs(w2);
  // Merge-walk both key-sorted sequences.
  for (;;) {
    const bool lhs_done = lhs.Done();
    const bool rhs_done = rhs.Done();
    if (lhs_done && rhs_done) return true;
    if (rhs_done) {
      // Only w1 has entries left.
      if (lhs.Value().second != rhs_default) return false;
      lhs.Next();
    } else if (lhs_done) {
      // Only w2 has entries left.
      if (lhs_default != rhs.Value().second) return false;
      rhs.Next();
    } else {
      const auto &a = lhs.Value();
      const auto &b = rhs.Value();
      if (a.first == b.first) {
        if (a.second != b.second) return false;
        lhs.Next();
        rhs.Next();
      } else if (a.first < b.first) {
        // Key stored in w1 only.
        if (a.second != rhs_default) return false;
        lhs.Next();
      } else {
        // Key stored in w2 only.
        if (lhs_default != b.second) return false;
        rhs.Next();
      }
    }
  }
}
// Negation of operator== for sparse tuple weights.
template <class W, class K>
inline bool operator!=(const SparseTupleWeight<W, K> &w1,
                       const SparseTupleWeight<W, K> &w2) {
  const bool equal = (w1 == w2);
  return !equal;
}
// Writes `weight` as a composite weight: the default value first, then the
// key and value of each stored pair in iteration order.
template <class W, class K>
inline std::ostream &operator<<(std::ostream &strm,
                                const SparseTupleWeight<W, K> &weight) {
  CompositeWeightWriter writer(strm);
  writer.WriteBegin();
  writer.WriteElement(weight.DefaultValue());
  SparseTupleWeightIterator<W, K> it(weight);
  while (!it.Done()) {
    const auto &entry = it.Value();
    writer.WriteElement(entry.first);
    writer.WriteElement(entry.second);
    it.Next();
  }
  writer.WriteEnd();
  return strm;
}
// Reads a composite weight into `weight`: the default value first, then
// alternating keys and values, which are appended with PushBack. Reading
// continues for as long as the value read reports that more elements follow.
template <class W, class K>
inline std::istream &operator>>(std::istream &strm,
                                SparseTupleWeight<W, K> &weight) {
  CompositeWeightReader reader(strm);
  reader.ReadBegin();
  W default_value;
  bool has_more = reader.ReadElement(&default_value);
  // Discard any previous contents before appending the parsed pairs.
  weight.Init(default_value);
  while (has_more) {
    K key;
    reader.ReadElement(&key);
    W value;
    has_more = reader.ReadElement(&value);
    weight.PushBack(key, value);
  }
  reader.ReadEnd();
  return strm;
}
} // namespace fst
#endif // FST_SPARSE_TUPLE_WEIGHT_H_