You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

524 lines
15 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Union weight set and associated semiring operation definitions.
  19. //
  20. // TODO(riley): add in normalizer functor.
  21. #ifndef FST_UNION_WEIGHT_H_
  22. #define FST_UNION_WEIGHT_H_
  23. #include <climits>
  24. #include <cstddef>
  25. #include <cstdint>
  26. #include <iostream>
  27. #include <istream>
  28. #include <list>
  29. #include <ostream>
  30. #include <random>
  31. #include <sstream>
  32. #include <string>
  33. #include <utility>
  34. #include <fst/util.h>
  35. #include <fst/weight.h>
  36. namespace fst {
  37. // Example UnionWeightOptions for UnionWeight template below. The Merge
  38. // operation is used to collapse elements of the set and the Compare function
  39. // to efficiently implement the merge. In the simplest case, merge would just
  40. // apply with equality of set elements so the result is a set (and not a
  41. // multiset). More generally, this can be used to maintain the multiplicity or
  42. // other such weight associated with the set elements (cf. Gallic weights).
  43. // template <class W>
  44. // struct UnionWeightOptions {
  45. // // Comparison function C is a total order on W that is monotonic w.r.t.
  46. // // Times: for all a, b, c != Zero(): C(a, b) => C(ca, cb) and is
  47. // // anti-monotonic w.r.t. Divide: C(a, b) => C(c/b, c/a).
  48. // //
  49. // // For all a, b: exactly one of C(a, b), C(b, a) or a ~ b must be true, where
  50. // // ~ is an equivalence relation on W. Also we require a ~ b iff
  51. // // a.Reverse() ~ b.Reverse().
  52. // using Compare = NaturalLess<W>;
  53. //
  54. // // How to combine two weights if a ~ b as above. For all a, b: a ~ b =>
  55. // // merge(a, b) ~ a. Merge must define a semiring endomorphism from the
  56. // // unmerged weight sets to the merged weight sets.
  57. // struct Merge {
  58. // W operator()(const W &w1, const W &w2) const { return w1; }
  59. // };
  60. //
  61. // // For ReverseWeight.
  62. // using ReverseOptions = UnionWeightOptions<ReverseWeight>;
  63. // };
  // Forward declarations of the union weight class, the two element iterators
  // it befriends, and its equality operator.
  template <class W, class O>
  class UnionWeight;

  template <class W, class O>
  class UnionWeightIterator;

  template <class W, class O>
  class UnionWeightReverseIterator;

  template <class W, class O>
  bool operator==(const UnionWeight<W, O> &w1, const UnionWeight<W, O> &w2);
  72. // Semiring that uses Times() and One() from W and union and the empty set
  73. // for Plus() and Zero(), respectively. Template argument O specifies the union
  74. // weight options as above.
  75. template <class W, class O>
  76. class UnionWeight {
  77. public:
  78. using Weight = W;
  79. using Compare = typename O::Compare;
  80. using Merge = typename O::Merge;
  81. using ReverseWeight =
  82. UnionWeight<typename W::ReverseWeight, typename O::ReverseOptions>;
  83. friend class UnionWeightIterator<W, O>;
  84. friend class UnionWeightReverseIterator<W, O>;
  85. // Sets represented as first_ weight + rest_ weights. Uses first_ as
  86. // NoWeight() to indicate the union weight Zero() as the empty set. Uses
  87. // rest_ containing NoWeight() to indicate the union weight NoWeight().
  88. UnionWeight() : first_(W::NoWeight()) {}
  89. explicit UnionWeight(W weight) : first_(weight) {
  90. if (!weight.Member()) rest_.push_back(W::NoWeight());
  91. }
  92. static const UnionWeight &Zero() {
  93. static const auto *const zero = new UnionWeight;
  94. return *zero;
  95. }
  96. static const UnionWeight &One() {
  97. static const auto *const one = new UnionWeight(W::One());
  98. return *one;
  99. }
  100. static const UnionWeight &NoWeight() {
  101. static const auto *const no_weight =
  102. new UnionWeight(W::Zero(), W::NoWeight());
  103. return *no_weight;
  104. }
  105. static const std::string &Type() {
  106. static const std::string *const type =
  107. new std::string(W::Type() + "_union");
  108. return *type;
  109. }
  110. static constexpr uint64_t Properties() {
  111. return W::Properties() &
  112. (kLeftSemiring | kRightSemiring | kCommutative | kIdempotent);
  113. }
  114. bool Member() const;
  115. std::istream &Read(std::istream &strm);
  116. std::ostream &Write(std::ostream &strm) const;
  117. size_t Hash() const;
  118. UnionWeight Quantize(float delta = kDelta) const;
  119. ReverseWeight Reverse() const;
  120. // These operations combined with the UnionWeightIterator and
  121. // UnionWeightReverseIterator provide the access and mutation of the union
  122. // weight internal elements.
  123. // Common initializer among constructors; clears existing UnionWeight.
  124. void Clear() {
  125. first_ = W::NoWeight();
  126. rest_.clear();
  127. }
  128. size_t Size() const { return first_.Member() ? rest_.size() + 1 : 0; }
  129. const W &Back() const { return rest_.empty() ? first_ : rest_.back(); }
  130. // When srt is true, assumes elements added sorted w.r.t Compare and merging
  131. // of weights performed as needed. Otherwise, just ensures first_ is the
  132. // least element wrt Compare.
  133. void PushBack(W weight, bool srt);
  134. // Sorts the elements of the set. Assumes that first_, if present, is the
  135. // least element.
  136. void Sort() { rest_.sort(comp_); }
  137. private:
  138. W &Back() {
  139. if (rest_.empty()) {
  140. return first_;
  141. } else {
  142. return rest_.back();
  143. }
  144. }
  145. UnionWeight(W w1, W w2) : first_(std::move(w1)), rest_(1, std::move(w2)) {}
  146. W first_; // First weight in set.
  147. std::list<W> rest_; // Remaining weights in set.
  148. Compare comp_;
  149. Merge merge_;
  150. };
  151. template <class W, class O>
  152. void UnionWeight<W, O>::PushBack(W weight, bool srt) {
  153. if (!weight.Member()) {
  154. rest_.push_back(std::move(weight));
  155. } else if (!first_.Member()) {
  156. first_ = std::move(weight);
  157. } else if (srt) {
  158. auto &back = Back();
  159. if (comp_(back, weight)) {
  160. rest_.push_back(std::move(weight));
  161. } else {
  162. back = merge_(back, std::move(weight));
  163. }
  164. } else {
  165. if (comp_(first_, weight)) {
  166. rest_.push_back(std::move(weight));
  167. } else {
  168. rest_.push_back(first_);
  169. first_ = std::move(weight);
  170. }
  171. }
  172. }
  173. // Traverses union weight in the forward direction.
  174. template <class W, class O>
  175. class UnionWeightIterator {
  176. public:
  177. explicit UnionWeightIterator(const UnionWeight<W, O> &weight)
  178. : first_(weight.first_),
  179. rest_(weight.rest_),
  180. init_(true),
  181. it_(rest_.begin()) {}
  182. bool Done() const { return init_ ? !first_.Member() : it_ == rest_.end(); }
  183. const W &Value() const { return init_ ? first_ : *it_; }
  184. void Next() {
  185. if (init_) {
  186. init_ = false;
  187. } else {
  188. ++it_;
  189. }
  190. }
  191. void Reset() {
  192. init_ = true;
  193. it_ = rest_.begin();
  194. }
  195. private:
  196. const W &first_;
  197. const std::list<W> &rest_;
  198. bool init_; // in the initialized state?
  199. typename std::list<W>::const_iterator it_;
  200. };
  201. // Traverses union weight in backward direction.
  202. template <typename L, class O>
  203. class UnionWeightReverseIterator {
  204. public:
  205. explicit UnionWeightReverseIterator(const UnionWeight<L, O> &weight)
  206. : first_(weight.first_),
  207. rest_(weight.rest_),
  208. fin_(!first_.Member()),
  209. it_(rest_.rbegin()) {}
  210. bool Done() const { return fin_; }
  211. const L &Value() const { return it_ == rest_.rend() ? first_ : *it_; }
  212. void Next() {
  213. if (it_ == rest_.rend()) {
  214. fin_ = true;
  215. } else {
  216. ++it_;
  217. }
  218. }
  219. void Reset() {
  220. fin_ = !first_.Member();
  221. it_ = rest_.rbegin();
  222. }
  223. private:
  224. const L &first_;
  225. const std::list<L> &rest_;
  226. bool fin_; // in the final state?
  227. typename std::list<L>::const_reverse_iterator it_;
  228. };
  229. // UnionWeight member functions follow that require UnionWeightIterator.
  230. template <class W, class O>
  231. inline std::istream &UnionWeight<W, O>::Read(std::istream &istrm) {
  232. Clear();
  233. int32_t size;
  234. ReadType(istrm, &size);
  235. for (int i = 0; i < size; ++i) {
  236. W weight;
  237. ReadType(istrm, &weight);
  238. PushBack(weight, true);
  239. }
  240. return istrm;
  241. }
  242. template <class W, class O>
  243. inline std::ostream &UnionWeight<W, O>::Write(std::ostream &ostrm) const {
  244. const int32_t size = Size();
  245. WriteType(ostrm, size);
  246. for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
  247. WriteType(ostrm, it.Value());
  248. }
  249. return ostrm;
  250. }
  251. template <class W, class O>
  252. inline bool UnionWeight<W, O>::Member() const {
  253. if (Size() <= 1) return true;
  254. for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
  255. if (!it.Value().Member()) return false;
  256. }
  257. return true;
  258. }
  259. template <class W, class O>
  260. inline UnionWeight<W, O> UnionWeight<W, O>::Quantize(float delta) const {
  261. UnionWeight weight;
  262. for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
  263. weight.PushBack(it.Value().Quantize(delta), true);
  264. }
  265. return weight;
  266. }
  267. template <class W, class O>
  268. inline typename UnionWeight<W, O>::ReverseWeight UnionWeight<W, O>::Reverse()
  269. const {
  270. ReverseWeight weight;
  271. for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
  272. weight.PushBack(it.Value().Reverse(), false);
  273. }
  274. weight.Sort();
  275. return weight;
  276. }
  277. template <class W, class O>
  278. inline size_t UnionWeight<W, O>::Hash() const {
  279. size_t h = 0;
  280. static constexpr int lshift = 5;
  281. static constexpr int rshift = CHAR_BIT * sizeof(size_t) - lshift;
  282. for (UnionWeightIterator<W, O> it(*this); !it.Done(); it.Next()) {
  283. h = h << lshift ^ h >> rshift ^ it.Value().Hash();
  284. }
  285. return h;
  286. }
  287. // Requires union weight has been canonicalized.
  288. template <class W, class O>
  289. inline bool operator==(const UnionWeight<W, O> &w1,
  290. const UnionWeight<W, O> &w2) {
  291. if (w1.Size() != w2.Size()) return false;
  292. UnionWeightIterator<W, O> it1(w1);
  293. UnionWeightIterator<W, O> it2(w2);
  294. for (; !it1.Done(); it1.Next(), it2.Next()) {
  295. if (it1.Value() != it2.Value()) return false;
  296. }
  297. return true;
  298. }
  299. // Requires union weight has been canonicalized.
  300. template <class W, class O>
  301. inline bool operator!=(const UnionWeight<W, O> &w1,
  302. const UnionWeight<W, O> &w2) {
  303. return !(w1 == w2);
  304. }
  305. // Requires union weight has been canonicalized.
  306. template <class W, class O>
  307. inline bool ApproxEqual(const UnionWeight<W, O> &w1,
  308. const UnionWeight<W, O> &w2, float delta = kDelta) {
  309. if (w1.Size() != w2.Size()) return false;
  310. UnionWeightIterator<W, O> it1(w1);
  311. UnionWeightIterator<W, O> it2(w2);
  312. for (; !it1.Done(); it1.Next(), it2.Next()) {
  313. if (!ApproxEqual(it1.Value(), it2.Value(), delta)) return false;
  314. }
  315. return true;
  316. }
  317. template <class W, class O>
  318. inline std::ostream &operator<<(std::ostream &ostrm,
  319. const UnionWeight<W, O> &weight) {
  320. UnionWeightIterator<W, O> it(weight);
  321. if (it.Done()) {
  322. return ostrm << "EmptySet";
  323. } else if (!weight.Member()) {
  324. return ostrm << "BadSet";
  325. } else {
  326. CompositeWeightWriter writer(ostrm);
  327. writer.WriteBegin();
  328. for (; !it.Done(); it.Next()) writer.WriteElement(it.Value());
  329. writer.WriteEnd();
  330. }
  331. return ostrm;
  332. }
  333. template <class W, class O>
  334. inline std::istream &operator>>(std::istream &istrm,
  335. UnionWeight<W, O> &weight) {
  336. std::string s;
  337. istrm >> s;
  338. if (s == "EmptySet") {
  339. weight = UnionWeight<W, O>::Zero();
  340. } else if (s == "BadSet") {
  341. weight = UnionWeight<W, O>::NoWeight();
  342. } else {
  343. weight = UnionWeight<W, O>::Zero();
  344. std::istringstream sstrm(s);
  345. CompositeWeightReader reader(sstrm);
  346. reader.ReadBegin();
  347. bool more = true;
  348. while (more) {
  349. W v;
  350. more = reader.ReadElement(&v);
  351. weight.PushBack(v, true);
  352. }
  353. reader.ReadEnd();
  354. }
  355. return istrm;
  356. }
  357. template <class W, class O>
  358. inline UnionWeight<W, O> Plus(const UnionWeight<W, O> &w1,
  359. const UnionWeight<W, O> &w2) {
  360. if (!w1.Member() || !w2.Member()) return UnionWeight<W, O>::NoWeight();
  361. if (w1 == UnionWeight<W, O>::Zero()) return w2;
  362. if (w2 == UnionWeight<W, O>::Zero()) return w1;
  363. UnionWeightIterator<W, O> it1(w1);
  364. UnionWeightIterator<W, O> it2(w2);
  365. UnionWeight<W, O> sum;
  366. typename O::Compare comp;
  367. while (!it1.Done() && !it2.Done()) {
  368. const auto v1 = it1.Value();
  369. const auto v2 = it2.Value();
  370. if (comp(v1, v2)) {
  371. sum.PushBack(v1, true);
  372. it1.Next();
  373. } else {
  374. sum.PushBack(v2, true);
  375. it2.Next();
  376. }
  377. }
  378. for (; !it1.Done(); it1.Next()) sum.PushBack(it1.Value(), true);
  379. for (; !it2.Done(); it2.Next()) sum.PushBack(it2.Value(), true);
  380. return sum;
  381. }
  382. template <class W, class O>
  383. inline UnionWeight<W, O> Times(const UnionWeight<W, O> &w1,
  384. const UnionWeight<W, O> &w2) {
  385. if (!w1.Member() || !w2.Member()) return UnionWeight<W, O>::NoWeight();
  386. if (w1 == UnionWeight<W, O>::Zero() || w2 == UnionWeight<W, O>::Zero()) {
  387. return UnionWeight<W, O>::Zero();
  388. }
  389. UnionWeightIterator<W, O> it1(w1);
  390. UnionWeightIterator<W, O> it2(w2);
  391. UnionWeight<W, O> prod1;
  392. for (; !it1.Done(); it1.Next()) {
  393. UnionWeight<W, O> prod2;
  394. for (; !it2.Done(); it2.Next()) {
  395. prod2.PushBack(Times(it1.Value(), it2.Value()), true);
  396. }
  397. prod1 = Plus(prod1, prod2);
  398. it2.Reset();
  399. }
  400. return prod1;
  401. }
  402. template <class W, class O>
  403. inline UnionWeight<W, O> Divide(const UnionWeight<W, O> &w1,
  404. const UnionWeight<W, O> &w2, DivideType typ) {
  405. if (!w1.Member() || !w2.Member()) return UnionWeight<W, O>::NoWeight();
  406. if (w1 == UnionWeight<W, O>::Zero() || w2 == UnionWeight<W, O>::Zero()) {
  407. return UnionWeight<W, O>::Zero();
  408. }
  409. UnionWeightIterator<W, O> it1(w1);
  410. UnionWeightReverseIterator<W, O> it2(w2);
  411. UnionWeight<W, O> quot;
  412. if (w1.Size() == 1) {
  413. for (; !it2.Done(); it2.Next()) {
  414. quot.PushBack(Divide(it1.Value(), it2.Value(), typ), true);
  415. }
  416. } else if (w2.Size() == 1) {
  417. for (; !it1.Done(); it1.Next()) {
  418. quot.PushBack(Divide(it1.Value(), it2.Value(), typ), true);
  419. }
  420. } else {
  421. quot = UnionWeight<W, O>::NoWeight();
  422. }
  423. return quot;
  424. }
  425. // This function object generates weights over the union of weights for the
  426. // underlying generators for the template weight types. This is intended
  427. // primarily for testing.
  428. template <class W, class O>
  429. class WeightGenerate<UnionWeight<W, O>> {
  430. public:
  431. using Weight = UnionWeight<W, O>;
  432. using Generate = WeightGenerate<W>;
  433. explicit WeightGenerate(uint64_t seed = std::random_device()(),
  434. bool allow_zero = true,
  435. size_t num_random_weights = kNumRandomWeights)
  436. : rand_(seed),
  437. allow_zero_(allow_zero),
  438. num_random_weights_(num_random_weights),
  439. generate_(seed, false) {}
  440. Weight operator()() const {
  441. const int sample = std::uniform_int_distribution<>(
  442. 0, num_random_weights_ + allow_zero_ - 1)(rand_);
  443. if (allow_zero_ && sample == num_random_weights_) {
  444. return Weight::Zero();
  445. } else if (std::bernoulli_distribution(.5)(rand_)) {
  446. return Weight(generate_());
  447. } else {
  448. return Plus(Weight(generate_()), Weight(generate_()));
  449. }
  450. }
  451. private:
  452. mutable std::mt19937_64 rand_;
  453. const bool allow_zero_;
  454. const size_t num_random_weights_;
  455. const Generate generate_;
  456. };
  457. } // namespace fst
  458. #endif // FST_UNION_WEIGHT_H_