You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

434 lines
13 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Sparse version of tuple-weight, based on tuple-weight.h.
  19. // Internally stores sparse key, value pairs in a linked list. The default value
  20. // element is the assumed value of unset keys. Internal singleton
  21. // implementation that stores the first key, value pair as an initialized member
  22. // variable to avoid unnecessary allocation on heap. Use
  23. // SparseTupleWeightIterator to iterate through the key,value pairs. Note:
  24. // this does NOT iterate through the default value.
  25. //
  26. // Sparse tuple weight set operation definitions.
  27. #ifndef FST_SPARSE_TUPLE_WEIGHT_H_
  28. #define FST_SPARSE_TUPLE_WEIGHT_H_
  29. #include <algorithm>
  30. #include <cstddef>
  31. #include <functional>
  32. #include <istream>
  33. #include <list>
  34. #include <ostream>
  35. #include <stack>
  36. #include <string>
  37. #include <utility>
  38. #include <fst/util.h>
  39. #include <fst/weight.h>
  40. namespace fst {
  41. template <class W, class K>
  42. class SparseTupleWeightIterator;
  43. // Arbitrary dimension tuple weight, stored as a sorted linked-list.
  44. // W is any weight class, and K is the key value type. kNoKey (-1) is reserved
  45. // for internal use.
  46. template <class W, class K = int>
  47. class SparseTupleWeight {
  48. public:
  49. using ReverseWeight = SparseTupleWeight<typename W::ReverseWeight, K>;
  50. using Iterator = SparseTupleWeightIterator<W, K>;
  51. using Pair = std::pair<K, W>;
  52. using Weight = W;
  53. using Index = K;
  54. static constexpr K kNoKey = -1;
  55. SparseTupleWeight() { Init(); }
  56. ~SparseTupleWeight() noexcept = default;
  57. template <class Iterator>
  58. SparseTupleWeight(Iterator begin, Iterator end) {
  59. Init();
  60. // Assumes input iterator is sorted.
  61. for (auto it = begin; it != end; ++it) PushBack(*it);
  62. }
  63. // Initialize component `key` to `weight`, with `default_weight` for all
  64. // other components.
  65. SparseTupleWeight(const K &key, const W &weight, const W &default_weight)
  66. : default_(default_weight),
  67. first_(weight == default_weight ? kNoKey : key, weight) {}
  68. explicit SparseTupleWeight(const W &weight) { Init(weight); }
  69. SparseTupleWeight(const SparseTupleWeight &weight) {
  70. Init(weight.DefaultValue());
  71. SetDefaultValue(weight.DefaultValue());
  72. for (Iterator it(weight); !it.Done(); it.Next()) {
  73. PushBack(it.Value());
  74. }
  75. }
  76. SparseTupleWeight(SparseTupleWeight &&weight) noexcept
  77. // Don't move the default, so weight.default_ is still valid.
  78. : default_(weight.default_), // NOLINT
  79. first_(std::move(weight.first_)),
  80. rest_(std::move(weight.rest_)) {
  81. // move leaves the source in a valid but unspecified state.
  82. // Make sure the source weight is empty.
  83. weight.first_ = Pair(kNoKey, W::NoWeight());
  84. weight.rest_.clear();
  85. }
  86. static const SparseTupleWeight &Zero() {
  87. static const SparseTupleWeight zero(W::Zero());
  88. return zero;
  89. }
  90. static const SparseTupleWeight &One() {
  91. static const SparseTupleWeight one(W::One());
  92. return one;
  93. }
  94. static const SparseTupleWeight &NoWeight() {
  95. static const SparseTupleWeight no_weight(W::NoWeight());
  96. return no_weight;
  97. }
  98. std::istream &Read(std::istream &strm) {
  99. ReadType(strm, &default_);
  100. ReadType(strm, &first_);
  101. return ReadType(strm, &rest_);
  102. }
  103. std::ostream &Write(std::ostream &strm) const {
  104. WriteType(strm, default_);
  105. WriteType(strm, first_);
  106. return WriteType(strm, rest_);
  107. }
  108. SparseTupleWeight &operator=(const SparseTupleWeight &weight) {
  109. if (this == &weight) return *this; // Checks for identity.
  110. Init(weight.DefaultValue());
  111. for (Iterator it(weight); !it.Done(); it.Next()) {
  112. PushBack(it.Value());
  113. }
  114. return *this;
  115. }
  116. SparseTupleWeight &operator=(SparseTupleWeight &&weight) noexcept {
  117. if (this == &weight) return *this; // Checks for identity.
  118. // Don't move the default, so weight.default_ is still valid.
  119. default_ = weight.default_;
  120. first_ = std::move(weight.first_);
  121. rest_ = std::move(weight.rest_);
  122. // move leaves the source in a valid but unspecified state.
  123. // Make sure the source weight is empty.
  124. weight.first_ = Pair(kNoKey, W::NoWeight());
  125. weight.rest_.clear();
  126. return *this;
  127. }
  128. bool Member() const {
  129. if (!DefaultValue().Member()) return false;
  130. for (Iterator it(*this); !it.Done(); it.Next()) {
  131. if (!it.Value().second.Member()) return false;
  132. }
  133. return true;
  134. }
  135. // Assumes H() function exists for the hash of the key value.
  136. size_t Hash() const {
  137. size_t h = 0;
  138. static const std::hash<K> H;
  139. for (Iterator it(*this); !it.Done(); it.Next()) {
  140. h = 5 * h + H(it.Value().first);
  141. h = 13 * h + it.Value().second.Hash();
  142. }
  143. return h;
  144. }
  145. SparseTupleWeight Quantize(float delta = kDelta) const {
  146. SparseTupleWeight weight;
  147. for (Iterator it(*this); !it.Done(); it.Next()) {
  148. weight.PushBack(it.Value().first, it.Value().second.Quantize(delta));
  149. }
  150. return weight;
  151. }
  152. ReverseWeight Reverse() const {
  153. ReverseWeight weight(DefaultValue().Reverse());
  154. for (Iterator it(*this); !it.Done(); it.Next()) {
  155. weight.PushBack(it.Value().first, it.Value().second.Reverse());
  156. }
  157. return weight;
  158. }
  159. void Init(const W &default_value = W::Zero()) {
  160. first_ = Pair(kNoKey, W::NoWeight());
  161. // Initialized to the reserved key value.
  162. default_ = default_value;
  163. rest_.clear();
  164. }
  165. size_t Size() const {
  166. if (first_.first == kNoKey) {
  167. return 0;
  168. } else {
  169. return rest_.size() + 1;
  170. }
  171. }
  172. inline void PushBack(const K &key, const W &weight,
  173. bool default_value_check = true) {
  174. PushBack(std::make_pair(key, weight), default_value_check);
  175. }
  176. inline void PushBack(const Pair &pair, bool default_value_check = true) {
  177. if (default_value_check && pair.second == default_) return;
  178. if (first_.first == kNoKey) {
  179. first_ = pair;
  180. } else {
  181. rest_.push_back(pair);
  182. }
  183. }
  184. // Returns the `key`-th component, or the default value if not set.
  185. const W &Value(const K &key) const {
  186. // TODO(rybach): Consider binary search.
  187. Iterator iter(*this);
  188. for (; !iter.Done() && iter.Value().first < key; iter.Next()) continue;
  189. return !iter.Done() && iter.Value().first == key ? iter.Value().second
  190. : DefaultValue();
  191. }
  192. void SetValue(const K &key, const W &w) {
  193. if (w == DefaultValue()) {
  194. ClearValue(key);
  195. } else {
  196. SetValueToNonDefault(key, w);
  197. }
  198. }
  199. void SetDefaultValue(const W &value) { default_ = value; }
  200. const W &DefaultValue() const { return default_; }
  201. private:
  202. void SetValueToNonDefault(const K &key, const W &w) {
  203. // Don't use SparseTupleWeightIterator, since that's const.
  204. if (first_.first == kNoKey) {
  205. first_ = Pair(key, w);
  206. } else if (key < first_.first) {
  207. rest_.push_front(first_);
  208. first_ = Pair(key, w);
  209. } else if (key == first_.first) {
  210. first_.second = w;
  211. } else {
  212. const auto i =
  213. std::find_if(rest_.begin(), rest_.end(),
  214. [key](const Pair &p) { return p.first >= key; });
  215. if (i != rest_.end() && i->first == key) {
  216. i->second = w;
  217. } else {
  218. rest_.insert(i, Pair(key, w));
  219. }
  220. }
  221. }
  222. // Removes the weight value for `key`, having the effect of setting
  223. // it to `DefaultValue()`.
  224. void ClearValue(const K &key) {
  225. if (key == first_.first) {
  226. if (!rest_.empty()) {
  227. first_ = rest_.front();
  228. rest_.pop_front();
  229. } else {
  230. first_.first = kNoKey;
  231. }
  232. } else if (key > first_.first) {
  233. const auto i =
  234. std::find_if(rest_.begin(), rest_.end(),
  235. [key](const Pair &p) { return p.first >= key; });
  236. if (i != rest_.end() && i->first == key) {
  237. rest_.erase(i);
  238. }
  239. }
  240. }
  241. // Assumed default value of uninitialized keys, by default W::Zero().
  242. W default_;
  243. // Key values pairs are first stored in first_, then fill rest_ this way we
  244. // can avoid dynamic allocation in the common case where the weight is a
  245. // single key/value pair.
  246. Pair first_;
  247. std::list<Pair> rest_;
  248. friend class SparseTupleWeightIterator<W, K>;
  249. };
  250. template <class W, class K>
  251. class SparseTupleWeightIterator {
  252. public:
  253. using Pair = typename SparseTupleWeight<W, K>::Pair;
  254. using const_iterator = typename std::list<Pair>::const_iterator;
  255. using iterator = typename std::list<Pair>::iterator;
  256. explicit SparseTupleWeightIterator(const SparseTupleWeight<W, K> &weight)
  257. : first_(weight.first_),
  258. rest_(weight.rest_),
  259. init_(true),
  260. iter_(rest_.begin()) {}
  261. bool Done() const {
  262. if (init_) {
  263. return first_.first == SparseTupleWeight<W, K>::kNoKey;
  264. } else {
  265. return iter_ == rest_.end();
  266. }
  267. }
  268. const Pair &Value() const { return init_ ? first_ : *iter_; }
  269. void Next() {
  270. if (init_) {
  271. init_ = false;
  272. } else {
  273. ++iter_;
  274. }
  275. }
  276. void Reset() {
  277. init_ = true;
  278. iter_ = rest_.begin();
  279. }
  280. private:
  281. const Pair &first_;
  282. const std::list<Pair> &rest_;
  283. bool init_; // In the initialized state?
  284. const_iterator iter_;
  285. };
  286. // M must be callable as a function W(K, W, W).
  287. // K will be kNoKey when mapping the default value.
  288. template <class W, class K, class M>
  289. inline void SparseTupleWeightMap(SparseTupleWeight<W, K> *result,
  290. const SparseTupleWeight<W, K> &w1,
  291. const SparseTupleWeight<W, K> &w2,
  292. const M &operator_mapper) {
  293. SparseTupleWeightIterator<W, K> w1_it(w1);
  294. SparseTupleWeightIterator<W, K> w2_it(w2);
  295. const auto &v1_def = w1.DefaultValue();
  296. const auto &v2_def = w2.DefaultValue();
  297. result->SetDefaultValue(
  298. operator_mapper(SparseTupleWeight<W, K>::kNoKey, v1_def, v2_def));
  299. while (!w1_it.Done() || !w2_it.Done()) {
  300. const auto &k1 = (w1_it.Done()) ? w2_it.Value().first : w1_it.Value().first;
  301. const auto &k2 = (w2_it.Done()) ? w1_it.Value().first : w2_it.Value().first;
  302. const auto &v1 = (w1_it.Done()) ? v1_def : w1_it.Value().second;
  303. const auto &v2 = (w2_it.Done()) ? v2_def : w2_it.Value().second;
  304. if (k1 == k2) {
  305. result->PushBack(k1, operator_mapper(k1, v1, v2));
  306. if (!w1_it.Done()) w1_it.Next();
  307. if (!w2_it.Done()) w2_it.Next();
  308. } else if (k1 < k2) {
  309. result->PushBack(k1, operator_mapper(k1, v1, v2_def));
  310. w1_it.Next();
  311. } else {
  312. result->PushBack(k2, operator_mapper(k2, v1_def, v2));
  313. w2_it.Next();
  314. }
  315. }
  316. }
  317. template <class W, class K>
  318. inline bool operator==(const SparseTupleWeight<W, K> &w1,
  319. const SparseTupleWeight<W, K> &w2) {
  320. const auto &v1_def = w1.DefaultValue();
  321. const auto &v2_def = w2.DefaultValue();
  322. if (v1_def != v2_def) return false;
  323. SparseTupleWeightIterator<W, K> w1_it(w1);
  324. SparseTupleWeightIterator<W, K> w2_it(w2);
  325. while (!w1_it.Done() || !w2_it.Done()) {
  326. const auto &k1 = (w1_it.Done()) ? w2_it.Value().first : w1_it.Value().first;
  327. const auto &k2 = (w2_it.Done()) ? w1_it.Value().first : w2_it.Value().first;
  328. const auto &v1 = (w1_it.Done()) ? v1_def : w1_it.Value().second;
  329. const auto &v2 = (w2_it.Done()) ? v2_def : w2_it.Value().second;
  330. if (k1 == k2) {
  331. if (v1 != v2) return false;
  332. if (!w1_it.Done()) w1_it.Next();
  333. if (!w2_it.Done()) w2_it.Next();
  334. } else if (k1 < k2) {
  335. if (v1 != v2_def) return false;
  336. w1_it.Next();
  337. } else {
  338. if (v1_def != v2) return false;
  339. w2_it.Next();
  340. }
  341. }
  342. return true;
  343. }
  344. template <class W, class K>
  345. inline bool operator!=(const SparseTupleWeight<W, K> &w1,
  346. const SparseTupleWeight<W, K> &w2) {
  347. return !(w1 == w2);
  348. }
  349. template <class W, class K>
  350. inline std::ostream &operator<<(std::ostream &strm,
  351. const SparseTupleWeight<W, K> &weight) {
  352. CompositeWeightWriter writer(strm);
  353. writer.WriteBegin();
  354. writer.WriteElement(weight.DefaultValue());
  355. for (SparseTupleWeightIterator<W, K> it(weight); !it.Done(); it.Next()) {
  356. writer.WriteElement(it.Value().first);
  357. writer.WriteElement(it.Value().second);
  358. }
  359. writer.WriteEnd();
  360. return strm;
  361. }
  362. template <class W, class K>
  363. inline std::istream &operator>>(std::istream &strm,
  364. SparseTupleWeight<W, K> &weight) {
  365. CompositeWeightReader reader(strm);
  366. reader.ReadBegin();
  367. W def;
  368. bool more = reader.ReadElement(&def);
  369. weight.Init(def);
  370. while (more) {
  371. K key;
  372. reader.ReadElement(&key);
  373. W v;
  374. more = reader.ReadElement(&v);
  375. weight.PushBack(key, v);
  376. }
  377. reader.ReadEnd();
  378. return strm;
  379. }
  380. } // namespace fst
  381. #endif // FST_SPARSE_TUPLE_WEIGHT_H_