|
|
// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Sparse version of tuple-weight, based on tuple-weight.h.
// Internally stores sparse key, value pairs in a linked list. The default value
// element is the assumed value of unset keys. Internal singleton
// implementation that stores the first key, value pair as an initialized member
// variable to avoid unnecessary allocation on the heap. Use
// SparseTupleWeightIterator to iterate through the key, value pairs. Note:
// this does NOT iterate through the default value.
//
// Sparse tuple weight set operation definitions.
#ifndef FST_SPARSE_TUPLE_WEIGHT_H_
#define FST_SPARSE_TUPLE_WEIGHT_H_
#include <algorithm>
#include <cstddef>
#include <functional>
#include <istream>
#include <list>
#include <ostream>
#include <stack>
#include <string>
#include <utility>
#include <fst/util.h>
#include <fst/weight.h>
namespace fst {
// Forward declaration: SparseTupleWeight names this iterator in its `Iterator`
// alias and befriends it; the definition follows the weight class.
template <class W, class K> class SparseTupleWeightIterator;
// Arbitrary dimension tuple weight, stored as a sorted linked-list.
// W is any weight class, and K is the key value type. kNoKey (-1) is reserved
// for internal use.
//
// Only components that differ from the default value are normally stored;
// every other key implicitly carries DefaultValue(). The first stored pair
// lives in the `first_` member and any further pairs in the `rest_` list.
template <class W, class K = int>
class SparseTupleWeight {
 public:
  using ReverseWeight = SparseTupleWeight<typename W::ReverseWeight, K>;
  using Iterator = SparseTupleWeightIterator<W, K>;
  using Pair = std::pair<K, W>;
  using Weight = W;
  using Index = K;

  // Reserved key; `first_` carrying this key means no pair is stored.
  static constexpr K kNoKey = -1;

  // Creates a weight with no stored pairs and default value W::Zero().
  SparseTupleWeight() { Init(); }

  ~SparseTupleWeight() noexcept = default;

  // Builds a weight (default value W::Zero()) from the pairs in [begin, end).
  // Assumes the input is sorted by key. Pairs whose value equals the default
  // are skipped by PushBack().
  template <class Iterator>
  SparseTupleWeight(Iterator begin, Iterator end) {
    Init();
    for (auto it = begin; it != end; ++it) PushBack(*it);
  }

  // Initialize component `key` to `weight`, with `default_weight` for all
  // other components. If `weight` equals `default_weight`, no pair is stored.
  SparseTupleWeight(const K &key, const W &weight, const W &default_weight)
      : default_(default_weight),
        first_(weight == default_weight ? kNoKey : key, weight) {}

  // Creates a weight with no stored pairs and `weight` as the default value.
  explicit SparseTupleWeight(const W &weight) { Init(weight); }

  // Copies the default value and re-appends every stored pair of `weight`.
  SparseTupleWeight(const SparseTupleWeight &weight) {
    Init(weight.DefaultValue());
    for (Iterator it(weight); !it.Done(); it.Next()) PushBack(it.Value());
  }

  SparseTupleWeight(SparseTupleWeight &&weight) noexcept
      // Don't move the default, so weight.default_ is still valid.
      : default_(weight.default_),  // NOLINT
        first_(std::move(weight.first_)),
        rest_(std::move(weight.rest_)) {
    // move leaves the source in a valid but unspecified state.
    // Make sure the source weight is empty.
    weight.first_ = Pair(kNoKey, W::NoWeight());
    weight.rest_.clear();
  }

  // Weight with no stored pairs whose default value is W::Zero().
  static const SparseTupleWeight &Zero() {
    static const SparseTupleWeight zero(W::Zero());
    return zero;
  }

  // Weight with no stored pairs whose default value is W::One().
  static const SparseTupleWeight &One() {
    static const SparseTupleWeight one(W::One());
    return one;
  }

  // Weight with no stored pairs whose default value is W::NoWeight().
  static const SparseTupleWeight &NoWeight() {
    static const SparseTupleWeight no_weight(W::NoWeight());
    return no_weight;
  }

  // Reads the default value, the first pair and the remaining pairs, in that
  // order, via ReadType(); returns the stream from the final read.
  std::istream &Read(std::istream &strm) {
    ReadType(strm, &default_);
    ReadType(strm, &first_);
    return ReadType(strm, &rest_);
  }

  // Writes the default value, the first pair and the remaining pairs, in that
  // order, via WriteType(); returns the stream from the final write.
  std::ostream &Write(std::ostream &strm) const {
    WriteType(strm, default_);
    WriteType(strm, first_);
    return WriteType(strm, rest_);
  }

  SparseTupleWeight &operator=(const SparseTupleWeight &weight) {
    if (this == &weight) return *this;  // Checks for identity.
    Init(weight.DefaultValue());
    for (Iterator it(weight); !it.Done(); it.Next()) PushBack(it.Value());
    return *this;
  }

  SparseTupleWeight &operator=(SparseTupleWeight &&weight) noexcept {
    if (this == &weight) return *this;  // Checks for identity.
    // Don't move the default, so weight.default_ is still valid.
    default_ = weight.default_;
    first_ = std::move(weight.first_);
    rest_ = std::move(weight.rest_);
    // move leaves the source in a valid but unspecified state.
    // Make sure the source weight is empty.
    weight.first_ = Pair(kNoKey, W::NoWeight());
    weight.rest_.clear();
    return *this;
  }

  // True iff the default value and every stored value are members of W.
  bool Member() const {
    if (!DefaultValue().Member()) return false;
    for (Iterator it(*this); !it.Done(); it.Next()) {
      if (!it.Value().second.Member()) return false;
    }
    return true;
  }

  // Combines the hashes of all stored keys and values; the default value does
  // not contribute. Requires std::hash<K> to be usable for the key type.
  size_t Hash() const {
    size_t h = 0;
    static const std::hash<K> H;
    for (Iterator it(*this); !it.Done(); it.Next()) {
      h = 5 * h + H(it.Value().first);
      h = 13 * h + it.Value().second.Hash();
    }
    return h;
  }

  // Returns a copy in which the default value and every stored value have
  // been quantized with `delta`.
  SparseTupleWeight Quantize(float delta = kDelta) const {
    // Seed the result with the quantized default so that the implicit
    // components keep their value (mirrors Reverse()); starting from a
    // default-constructed weight would silently reset the default to
    // W::Zero().
    SparseTupleWeight weight(DefaultValue().Quantize(delta));
    for (Iterator it(*this); !it.Done(); it.Next()) {
      weight.PushBack(it.Value().first, it.Value().second.Quantize(delta));
    }
    return weight;
  }

  // Returns the weight whose default and stored values are the reverses of
  // this weight's.
  ReverseWeight Reverse() const {
    ReverseWeight weight(DefaultValue().Reverse());
    for (Iterator it(*this); !it.Done(); it.Next()) {
      weight.PushBack(it.Value().first, it.Value().second.Reverse());
    }
    return weight;
  }

  // Drops all stored pairs and sets the default value to `default_value`.
  void Init(const W &default_value = W::Zero()) {
    // Initialized to the reserved key value.
    first_ = Pair(kNoKey, W::NoWeight());
    default_ = default_value;
    rest_.clear();
  }

  // Number of stored pairs (the default value is not counted).
  size_t Size() const {
    return first_.first == kNoKey ? 0 : rest_.size() + 1;
  }

  // Appends (key, weight); see the Pair overload.
  inline void PushBack(const K &key, const W &weight,
                       bool default_value_check = true) {
    PushBack(std::make_pair(key, weight), default_value_check);
  }

  // Appends `pair` after all currently stored pairs. No ordering check is
  // performed here. When `default_value_check` is true, a pair whose value
  // equals the default value is dropped instead of stored.
  inline void PushBack(const Pair &pair, bool default_value_check = true) {
    if (default_value_check && pair.second == default_) return;
    if (first_.first == kNoKey) {
      first_ = pair;
    } else {
      rest_.push_back(pair);
    }
  }

  // Returns the `key`-th component, or the default value if not set.
  const W &Value(const K &key) const {
    // TODO(rybach): Consider binary search.
    Iterator iter(*this);
    for (; !iter.Done() && iter.Value().first < key; iter.Next()) continue;
    return !iter.Done() && iter.Value().first == key ? iter.Value().second
                                                     : DefaultValue();
  }

  // Sets component `key` to `w`; setting it to the default value removes any
  // stored pair for `key`.
  void SetValue(const K &key, const W &w) {
    if (w == DefaultValue()) {
      ClearValue(key);
    } else {
      SetValueToNonDefault(key, w);
    }
  }

  // Replaces the default value; stored pairs are left untouched.
  void SetDefaultValue(const W &value) { default_ = value; }

  const W &DefaultValue() const { return default_; }

 private:
  // Inserts or overwrites the pair for `key`, keeping keys in ascending
  // order. Callers are expected to pass a `w` that differs from the default.
  void SetValueToNonDefault(const K &key, const W &w) {
    // Don't use SparseTupleWeightIterator, since that's const.
    if (first_.first == kNoKey) {
      first_ = Pair(key, w);
    } else if (key < first_.first) {
      // The new key becomes the head; the old head moves into the list.
      rest_.push_front(first_);
      first_ = Pair(key, w);
    } else if (key == first_.first) {
      first_.second = w;
    } else {
      const auto i =
          std::find_if(rest_.begin(), rest_.end(),
                       [key](const Pair &p) { return p.first >= key; });
      if (i != rest_.end() && i->first == key) {
        i->second = w;
      } else {
        rest_.insert(i, Pair(key, w));
      }
    }
  }

  // Removes the weight value for `key`, having the effect of setting
  // it to `DefaultValue()`.
  void ClearValue(const K &key) {
    if (key == first_.first) {
      if (!rest_.empty()) {
        // Promote the next pair to the head position.
        first_ = rest_.front();
        rest_.pop_front();
      } else {
        first_.first = kNoKey;
      }
    } else if (key > first_.first) {
      const auto i =
          std::find_if(rest_.begin(), rest_.end(),
                       [key](const Pair &p) { return p.first >= key; });
      if (i != rest_.end() && i->first == key) {
        rest_.erase(i);
      }
    }
  }

  // Assumed default value of uninitialized keys, by default W::Zero().
  W default_;

  // Key/value pairs are stored in first_ first and only then in rest_; this
  // way we can avoid dynamic allocation in the common case where the weight
  // is a single key/value pair.
  Pair first_;
  std::list<Pair> rest_;

  friend class SparseTupleWeightIterator<W, K>;
};
template <class W, class K> class SparseTupleWeightIterator { public: using Pair = typename SparseTupleWeight<W, K>::Pair; using const_iterator = typename std::list<Pair>::const_iterator; using iterator = typename std::list<Pair>::iterator;
explicit SparseTupleWeightIterator(const SparseTupleWeight<W, K> &weight) : first_(weight.first_), rest_(weight.rest_), init_(true), iter_(rest_.begin()) {}
bool Done() const { if (init_) { return first_.first == SparseTupleWeight<W, K>::kNoKey; } else { return iter_ == rest_.end(); } }
const Pair &Value() const { return init_ ? first_ : *iter_; }
void Next() { if (init_) { init_ = false; } else { ++iter_; } }
void Reset() { init_ = true; iter_ = rest_.begin(); }
private: const Pair &first_; const std::list<Pair> &rest_; bool init_; // In the initialized state?
const_iterator iter_; };
// M must be callable as a function W(K, W, W).
// K will be kNoKey when mapping the default value.
template <class W, class K, class M> inline void SparseTupleWeightMap(SparseTupleWeight<W, K> *result, const SparseTupleWeight<W, K> &w1, const SparseTupleWeight<W, K> &w2, const M &operator_mapper) { SparseTupleWeightIterator<W, K> w1_it(w1); SparseTupleWeightIterator<W, K> w2_it(w2); const auto &v1_def = w1.DefaultValue(); const auto &v2_def = w2.DefaultValue(); result->SetDefaultValue( operator_mapper(SparseTupleWeight<W, K>::kNoKey, v1_def, v2_def)); while (!w1_it.Done() || !w2_it.Done()) { const auto &k1 = (w1_it.Done()) ? w2_it.Value().first : w1_it.Value().first; const auto &k2 = (w2_it.Done()) ? w1_it.Value().first : w2_it.Value().first; const auto &v1 = (w1_it.Done()) ? v1_def : w1_it.Value().second; const auto &v2 = (w2_it.Done()) ? v2_def : w2_it.Value().second; if (k1 == k2) { result->PushBack(k1, operator_mapper(k1, v1, v2)); if (!w1_it.Done()) w1_it.Next(); if (!w2_it.Done()) w2_it.Next(); } else if (k1 < k2) { result->PushBack(k1, operator_mapper(k1, v1, v2_def)); w1_it.Next(); } else { result->PushBack(k2, operator_mapper(k2, v1_def, v2)); w2_it.Next(); } } }
// Two sparse tuple weights are equal when their default values are equal and
// every key maps to equal values, where a key stored in only one operand is
// compared against the other operand's default value.
template <class W, class K>
inline bool operator==(const SparseTupleWeight<W, K> &w1,
                       const SparseTupleWeight<W, K> &w2) {
  const W &default1 = w1.DefaultValue();
  const W &default2 = w2.DefaultValue();
  if (default1 != default2) return false;
  SparseTupleWeightIterator<W, K> it1(w1);
  SparseTupleWeightIterator<W, K> it2(w2);
  for (;;) {
    const bool done1 = it1.Done();
    const bool done2 = it2.Done();
    if (done1 && done2) break;
    // An exhausted side borrows the other side's key and contributes its
    // own default value.
    const K &key1 = done1 ? it2.Value().first : it1.Value().first;
    const K &key2 = done2 ? it1.Value().first : it2.Value().first;
    const W &value1 = done1 ? default1 : it1.Value().second;
    const W &value2 = done2 ? default2 : it2.Value().second;
    if (key1 == key2) {
      if (value1 != value2) return false;
      if (!done1) it1.Next();
      if (!done2) it2.Next();
    } else if (key1 < key2) {
      // Key only stored in w1.
      if (value1 != default2) return false;
      it1.Next();
    } else {
      // Key only stored in w2.
      if (default1 != value2) return false;
      it2.Next();
    }
  }
  return true;
}
// Negation of operator== for sparse tuple weights.
template <class W, class K>
inline bool operator!=(const SparseTupleWeight<W, K> &w1,
                       const SparseTupleWeight<W, K> &w2) {
  const bool equal = (w1 == w2);
  return !equal;
}
// Writes `weight` through a CompositeWeightWriter: the default value first,
// then the key and the value of every stored pair in iteration order.
template <class W, class K>
inline std::ostream &operator<<(std::ostream &strm,
                                const SparseTupleWeight<W, K> &weight) {
  CompositeWeightWriter writer(strm);
  writer.WriteBegin();
  writer.WriteElement(weight.DefaultValue());
  SparseTupleWeightIterator<W, K> it(weight);
  while (!it.Done()) {
    const auto &entry = it.Value();
    writer.WriteElement(entry.first);
    writer.WriteElement(entry.second);
    it.Next();
  }
  writer.WriteEnd();
  return strm;
}
// Reads `weight` through a CompositeWeightReader: the default value first,
// then alternating keys and values. The boolean returned by ReadElement() for
// the default and for each value is used as the "more elements follow" flag;
// the result of reading a key is not inspected. Pairs are appended with
// PushBack(), so values equal to the default are not stored.
template <class W, class K>
inline std::istream &operator>>(std::istream &strm,
                                SparseTupleWeight<W, K> &weight) {
  CompositeWeightReader reader(strm);
  reader.ReadBegin();
  W default_value;
  bool has_more = reader.ReadElement(&default_value);
  weight.Init(default_value);
  for (; has_more;) {
    K key;
    reader.ReadElement(&key);
    W value;
    has_more = reader.ReadElement(&value);
    weight.PushBack(key, value);
  }
  reader.ReadEnd();
  return strm;
}
} // namespace fst
#endif // FST_SPARSE_TUPLE_WEIGHT_H_
|