You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

638 lines
19 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Weights consisting of sets (of integral Labels) and
  19. // associated semiring operation definitions using intersect
  20. // and union.
  21. #ifndef FST_SET_WEIGHT_H_
  22. #define FST_SET_WEIGHT_H_
  23. #include <algorithm>
  24. #include <cstddef>
  25. #include <cstdint>
  26. #include <ios>
  27. #include <istream>
  28. #include <list>
  29. #include <optional>
  30. #include <ostream>
  31. #include <random>
  32. #include <string>
  33. #include <utility>
  34. #include <vector>
  35. #include <fst/log.h>
  36. #include <fst/union-weight.h>
  37. #include <fst/util.h>
  38. #include <fst/weight.h>
  39. #include <string_view>
  40. namespace fst {
  // Reserved label values with special meaning inside a set weight.

  // Label for the empty set.
  inline constexpr int kSetEmpty{0};
  // Label for the universal set.
  inline constexpr int kSetUniv{-1};
  // Label for a non-set.
  inline constexpr int kSetBad{-2};
  // Label separator in sets.
  inline constexpr char kSetSeparator{'_'};
  // Selects which pair of set operations serves as (+, *) in the semiring.
  enum SetType {
    // (+, *) = (intersect, union).
    SET_INTERSECT_UNION = 0,
    // (+, *) = (union, intersect).
    SET_UNION_INTERSECT = 1,
    // Restricted version of (intersect, union) that requires summed
    // arguments to be equal (or an error is signalled); useful for
    // algorithms that require a unique labelled path weight.
    SET_INTERSECT_UNION_RESTRICT = 2,
    // Treats all non-Zero() elements as equivalent (with Zero() ==
    // UnivSet()); useful for algorithms that don't really depend on the
    // detailed sets.
    SET_BOOLEAN = 3
  };
  59. template <class>
  60. class SetWeightIterator;
  61. // Set semiring of integral labels.
  62. template <typename L, SetType S = SET_INTERSECT_UNION>
  63. class SetWeight {
  64. public:
  65. using Label = L;
  66. using ReverseWeight = SetWeight<Label, S>;
  67. using Iterator = SetWeightIterator<SetWeight>;
  68. friend class SetWeightIterator<SetWeight>;
  69. // Allow type-converting copy and move constructors private access.
  70. template <typename L2, SetType S2>
  71. friend class SetWeight;
  72. SetWeight() = default;
  73. // Input should be positive, sorted and unique.
  74. template <typename Iterator>
  75. SetWeight(const Iterator begin, const Iterator end) {
  76. for (auto iter = begin; iter != end; ++iter) PushBack(*iter);
  77. }
  78. // Input should be positive. (Non-positive value has
  79. // special internal meaning w.r.t. integral constants above.)
  80. explicit SetWeight(Label label) { PushBack(label); }
  81. template <SetType S2>
  82. explicit SetWeight(const SetWeight<Label, S2> &w)
  83. : first_(w.first_), rest_(w.rest_) {}
  84. template <SetType S2>
  85. explicit SetWeight(SetWeight<Label, S2> &&w)
  86. : first_(w.first_), rest_(std::move(w.rest_)) {
  87. w.Clear();
  88. }
  89. template <SetType S2>
  90. SetWeight &operator=(const SetWeight<Label, S2> &w) {
  91. first_ = w.first_;
  92. rest_ = w.rest_;
  93. return *this;
  94. }
  95. template <SetType S2>
  96. SetWeight &operator=(SetWeight<Label, S2> &&w) {
  97. first_ = w.first_;
  98. rest_ = std::move(w.rest_);
  99. w.Clear();
  100. return *this;
  101. }
  102. static const SetWeight &Zero() {
  103. return S == SET_UNION_INTERSECT ? EmptySet() : UnivSet();
  104. }
  105. static const SetWeight &One() {
  106. return S == SET_UNION_INTERSECT ? UnivSet() : EmptySet();
  107. }
  108. static const SetWeight &NoWeight() {
  109. static const auto *const no_weight = new SetWeight(Label(kSetBad));
  110. return *no_weight;
  111. }
  112. static const std::string &Type() {
  113. static const std::string *const type =
  114. new std::string(S == SET_UNION_INTERSECT
  115. ? "union_intersect_set"
  116. : (S == SET_INTERSECT_UNION
  117. ? "intersect_union_set"
  118. : (S == SET_INTERSECT_UNION_RESTRICT
  119. ? "restricted_set_intersect_union"
  120. : "boolean_set")));
  121. return *type;
  122. }
  123. bool Member() const;
  124. std::istream &Read(std::istream &strm);
  125. std::ostream &Write(std::ostream &strm) const;
  126. size_t Hash() const;
  127. SetWeight Quantize(float delta = kDelta) const { return *this; }
  128. ReverseWeight Reverse() const;
  129. static constexpr uint64_t Properties() {
  130. return kIdempotent | kLeftSemiring | kRightSemiring | kCommutative;
  131. }
  132. // These operations combined with the SetWeightIterator
  133. // provide the access and mutation of the set internal elements.
  134. // The empty set.
  135. static const SetWeight &EmptySet() {
  136. static const auto *const empty = new SetWeight(Label(kSetEmpty));
  137. return *empty;
  138. }
  139. // The univeral set.
  140. static const SetWeight &UnivSet() {
  141. static const auto *const univ = new SetWeight(Label(kSetUniv));
  142. return *univ;
  143. }
  144. // Clear existing SetWeight.
  145. void Clear() {
  146. first_ = kSetEmpty;
  147. rest_.clear();
  148. }
  149. size_t Size() const { return first_ == kSetEmpty ? 0 : rest_.size() + 1; }
  150. Label Back() {
  151. if (rest_.empty()) {
  152. return first_;
  153. } else {
  154. return rest_.back();
  155. }
  156. }
  157. // Caller must add in sort order and be unique (or error signalled).
  158. // Input should also be positive. Non-positive value (for the first
  159. // push) has special internal meaning w.r.t. integral constants above.
  160. void PushBack(Label label) {
  161. if (first_ == kSetEmpty) {
  162. first_ = label;
  163. } else {
  164. if (label <= Back() || label <= 0) {
  165. FSTERROR() << "SetWeight: labels must be positive, added"
  166. << " in sort order and be unique.";
  167. rest_.push_back(Label(kSetBad));
  168. }
  169. rest_.push_back(label);
  170. }
  171. }
  172. private:
  173. Label first_ = kSetEmpty; // First label in set (kSetEmpty if empty).
  174. std::list<Label> rest_; // Remaining labels in set.
  175. };
  176. // Traverses set in forward direction.
  177. template <class SetWeight_>
  178. class SetWeightIterator {
  179. public:
  180. using Weight = SetWeight_;
  181. using Label = typename Weight::Label;
  182. explicit SetWeightIterator(const Weight &w)
  183. : first_(w.first_), rest_(w.rest_), init_(true), iter_(rest_.begin()) {}
  184. bool Done() const {
  185. if (init_) {
  186. return first_ == kSetEmpty;
  187. } else {
  188. return iter_ == rest_.end();
  189. }
  190. }
  191. const Label &Value() const { return init_ ? first_ : *iter_; }
  192. void Next() {
  193. if (init_) {
  194. init_ = false;
  195. } else {
  196. ++iter_;
  197. }
  198. }
  199. void Reset() {
  200. init_ = true;
  201. iter_ = rest_.begin();
  202. }
  203. private:
  204. const Label &first_;
  205. const decltype(Weight::rest_) &rest_;
  206. bool init_; // In the initialized state?
  207. typename decltype(Weight::rest_)::const_iterator iter_;
  208. };
  209. // SetWeight member functions follow that require SetWeightIterator
  210. template <typename Label, SetType S>
  211. inline std::istream &SetWeight<Label, S>::Read(std::istream &strm) {
  212. Clear();
  213. int32_t size;
  214. ReadType(strm, &size);
  215. for (int32_t i = 0; i < size; ++i) {
  216. Label label;
  217. ReadType(strm, &label);
  218. PushBack(label);
  219. }
  220. return strm;
  221. }
  222. template <typename Label, SetType S>
  223. inline std::ostream &SetWeight<Label, S>::Write(std::ostream &strm) const {
  224. const int32_t size = Size();
  225. WriteType(strm, size);
  226. for (Iterator iter(*this); !iter.Done(); iter.Next()) {
  227. WriteType(strm, iter.Value());
  228. }
  229. return strm;
  230. }
  231. template <typename Label, SetType S>
  232. inline bool SetWeight<Label, S>::Member() const {
  233. Iterator iter(*this);
  234. return iter.Value() != Label(kSetBad);
  235. }
  236. template <typename Label, SetType S>
  237. inline typename SetWeight<Label, S>::ReverseWeight
  238. SetWeight<Label, S>::Reverse() const {
  239. return *this;
  240. }
  241. template <typename Label, SetType S>
  242. inline size_t SetWeight<Label, S>::Hash() const {
  243. using Weight = SetWeight<Label, S>;
  244. if (S == SET_BOOLEAN) {
  245. return *this == Weight::Zero() ? 0 : 1;
  246. } else {
  247. size_t h = 0;
  248. for (Iterator iter(*this); !iter.Done(); iter.Next()) {
  249. h ^= h << 1 ^ iter.Value();
  250. }
  251. return h;
  252. }
  253. }
  254. // Default ==
  255. template <typename Label, SetType S>
  256. inline bool operator==(const SetWeight<Label, S> &w1,
  257. const SetWeight<Label, S> &w2) {
  258. if (w1.Size() != w2.Size()) return false;
  259. using Iterator = typename SetWeight<Label, S>::Iterator;
  260. Iterator iter1(w1);
  261. Iterator iter2(w2);
  262. for (; !iter1.Done(); iter1.Next(), iter2.Next()) {
  263. if (iter1.Value() != iter2.Value()) return false;
  264. }
  265. return true;
  266. }
  267. // Boolean ==
  268. template <typename Label>
  269. inline bool operator==(const SetWeight<Label, SET_BOOLEAN> &w1,
  270. const SetWeight<Label, SET_BOOLEAN> &w2) {
  271. // x == kSetEmpty if x \nin {kUnivSet, kSetBad}
  272. if (!w1.Member() || !w2.Member()) return false;
  273. using Iterator = typename SetWeight<Label, SET_BOOLEAN>::Iterator;
  274. Iterator iter1(w1);
  275. Iterator iter2(w2);
  276. Label label1 = iter1.Done() ? kSetEmpty : iter1.Value();
  277. Label label2 = iter2.Done() ? kSetEmpty : iter2.Value();
  278. if (label1 == kSetUniv) return label2 == kSetUniv;
  279. if (label2 == kSetUniv) return label1 == kSetUniv;
  280. return true;
  281. }
  282. template <typename Label, SetType S>
  283. inline bool operator!=(const SetWeight<Label, S> &w1,
  284. const SetWeight<Label, S> &w2) {
  285. return !(w1 == w2);
  286. }
  287. template <typename Label, SetType S>
  288. inline bool ApproxEqual(const SetWeight<Label, S> &w1,
  289. const SetWeight<Label, S> &w2, float delta = kDelta) {
  290. return w1 == w2;
  291. }
  292. template <typename Label, SetType S>
  293. inline std::ostream &operator<<(std::ostream &strm,
  294. const SetWeight<Label, S> &weight) {
  295. typename SetWeight<Label, S>::Iterator iter(weight);
  296. if (iter.Done()) {
  297. return strm << "EmptySet";
  298. } else if (iter.Value() == Label(kSetUniv)) {
  299. return strm << "UnivSet";
  300. } else if (iter.Value() == Label(kSetBad)) {
  301. return strm << "BadSet";
  302. } else {
  303. for (size_t i = 0; !iter.Done(); ++i, iter.Next()) {
  304. if (i > 0) strm << kSetSeparator;
  305. strm << iter.Value();
  306. }
  307. }
  308. return strm;
  309. }
  310. template <typename Label, SetType S>
  311. inline std::istream &operator>>(std::istream &strm,
  312. SetWeight<Label, S> &weight) {
  313. std::string str;
  314. strm >> str;
  315. using Weight = SetWeight<Label, S>;
  316. if (str == "EmptySet") {
  317. weight = Weight(Label(kSetEmpty));
  318. } else if (str == "UnivSet") {
  319. weight = Weight(Label(kSetUniv));
  320. } else {
  321. weight.Clear();
  322. for (std::string_view sv : StrSplit(str, kSetSeparator)) {
  323. auto maybe_label = ParseInt64(sv);
  324. if (!maybe_label.has_value()) {
  325. strm.clear(std::ios::badbit);
  326. break;
  327. }
  328. weight.PushBack(*maybe_label);
  329. }
  330. }
  331. return strm;
  332. }
  333. template <typename Label, SetType S>
  334. inline SetWeight<Label, S> Union(const SetWeight<Label, S> &w1,
  335. const SetWeight<Label, S> &w2) {
  336. using Weight = SetWeight<Label, S>;
  337. using Iterator = typename SetWeight<Label, S>::Iterator;
  338. if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  339. if (w1 == Weight::EmptySet()) return w2;
  340. if (w2 == Weight::EmptySet()) return w1;
  341. if (w1 == Weight::UnivSet()) return w1;
  342. if (w2 == Weight::UnivSet()) return w2;
  343. Iterator it1(w1);
  344. Iterator it2(w2);
  345. Weight result;
  346. while (!it1.Done() && !it2.Done()) {
  347. const auto v1 = it1.Value();
  348. const auto v2 = it2.Value();
  349. if (v1 < v2) {
  350. result.PushBack(v1);
  351. it1.Next();
  352. } else if (v1 > v2) {
  353. result.PushBack(v2);
  354. it2.Next();
  355. } else {
  356. result.PushBack(v1);
  357. it1.Next();
  358. it2.Next();
  359. }
  360. }
  361. for (; !it1.Done(); it1.Next()) result.PushBack(it1.Value());
  362. for (; !it2.Done(); it2.Next()) result.PushBack(it2.Value());
  363. return result;
  364. }
  365. template <typename Label, SetType S>
  366. inline SetWeight<Label, S> Intersect(const SetWeight<Label, S> &w1,
  367. const SetWeight<Label, S> &w2) {
  368. using Weight = SetWeight<Label, S>;
  369. using Iterator = typename SetWeight<Label, S>::Iterator;
  370. if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  371. if (w1 == Weight::EmptySet()) return w1;
  372. if (w2 == Weight::EmptySet()) return w2;
  373. if (w1 == Weight::UnivSet()) return w2;
  374. if (w2 == Weight::UnivSet()) return w1;
  375. Iterator it1(w1);
  376. Iterator it2(w2);
  377. Weight result;
  378. while (!it1.Done() && !it2.Done()) {
  379. const auto v1 = it1.Value();
  380. const auto v2 = it2.Value();
  381. if (v1 < v2) {
  382. it1.Next();
  383. } else if (v1 > v2) {
  384. it2.Next();
  385. } else {
  386. result.PushBack(v1);
  387. it1.Next();
  388. it2.Next();
  389. }
  390. }
  391. return result;
  392. }
  393. template <typename Label, SetType S>
  394. inline SetWeight<Label, S> Difference(const SetWeight<Label, S> &w1,
  395. const SetWeight<Label, S> &w2) {
  396. using Weight = SetWeight<Label, S>;
  397. using Iterator = typename SetWeight<Label, S>::Iterator;
  398. if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  399. if (w1 == Weight::EmptySet()) return w1;
  400. if (w2 == Weight::EmptySet()) return w1;
  401. if (w2 == Weight::UnivSet()) return Weight::EmptySet();
  402. Iterator it1(w1);
  403. Iterator it2(w2);
  404. Weight result;
  405. while (!it1.Done() && !it2.Done()) {
  406. const auto v1 = it1.Value();
  407. const auto v2 = it2.Value();
  408. if (v1 < v2) {
  409. result.PushBack(v1);
  410. it1.Next();
  411. } else if (v1 > v2) {
  412. it2.Next();
  413. } else {
  414. it1.Next();
  415. it2.Next();
  416. }
  417. }
  418. for (; !it1.Done(); it1.Next()) result.PushBack(it1.Value());
  419. return result;
  420. }
  421. // Default: Plus = Intersect.
  422. template <typename Label, SetType S>
  423. inline SetWeight<Label, S> Plus(const SetWeight<Label, S> &w1,
  424. const SetWeight<Label, S> &w2) {
  425. return Intersect(w1, w2);
  426. }
  427. // Plus = Union.
  428. template <typename Label>
  429. inline SetWeight<Label, SET_UNION_INTERSECT> Plus(
  430. const SetWeight<Label, SET_UNION_INTERSECT> &w1,
  431. const SetWeight<Label, SET_UNION_INTERSECT> &w2) {
  432. return Union(w1, w2);
  433. }
  434. // Plus = Set equality is required (for non-Zero() input). The
  435. // restriction is useful (e.g., in determinization) to ensure the input
  436. // has a unique labelled path weight.
  437. template <typename Label>
  438. inline SetWeight<Label, SET_INTERSECT_UNION_RESTRICT> Plus(
  439. const SetWeight<Label, SET_INTERSECT_UNION_RESTRICT> &w1,
  440. const SetWeight<Label, SET_INTERSECT_UNION_RESTRICT> &w2) {
  441. using Weight = SetWeight<Label, SET_INTERSECT_UNION_RESTRICT>;
  442. if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  443. if (w1 == Weight::Zero()) return w2;
  444. if (w2 == Weight::Zero()) return w1;
  445. if (w1 != w2) {
  446. FSTERROR() << "SetWeight::Plus: Unequal arguments "
  447. << "(non-unique labelled path weights?)"
  448. << " w1 = " << w1 << " w2 = " << w2;
  449. return Weight::NoWeight();
  450. }
  451. return w1;
  452. }
  453. // Plus = Or.
  454. template <typename Label>
  455. inline SetWeight<Label, SET_BOOLEAN> Plus(
  456. const SetWeight<Label, SET_BOOLEAN> &w1,
  457. const SetWeight<Label, SET_BOOLEAN> &w2) {
  458. using Weight = SetWeight<Label, SET_BOOLEAN>;
  459. if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  460. if (w1 == Weight::One()) return w1;
  461. if (w2 == Weight::One()) return w2;
  462. return Weight::Zero();
  463. }
  464. // Default: Times = Union.
  465. template <typename Label, SetType S>
  466. inline SetWeight<Label, S> Times(const SetWeight<Label, S> &w1,
  467. const SetWeight<Label, S> &w2) {
  468. return Union(w1, w2);
  469. }
  470. // Times = Intersect.
  471. template <typename Label>
  472. inline SetWeight<Label, SET_UNION_INTERSECT> Times(
  473. const SetWeight<Label, SET_UNION_INTERSECT> &w1,
  474. const SetWeight<Label, SET_UNION_INTERSECT> &w2) {
  475. return Intersect(w1, w2);
  476. }
  477. // Times = And.
  478. template <typename Label>
  479. inline SetWeight<Label, SET_BOOLEAN> Times(
  480. const SetWeight<Label, SET_BOOLEAN> &w1,
  481. const SetWeight<Label, SET_BOOLEAN> &w2) {
  482. using Weight = SetWeight<Label, SET_BOOLEAN>;
  483. if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  484. if (w1 == Weight::One()) return w2;
  485. return w1;
  486. }
  487. // Divide = Difference.
  488. template <typename Label, SetType S>
  489. inline SetWeight<Label, S> Divide(const SetWeight<Label, S> &w1,
  490. const SetWeight<Label, S> &w2,
  491. DivideType divide_type = DIVIDE_ANY) {
  492. return Difference(w1, w2);
  493. }
  494. // Divide = dividend (or the universal set if the
  495. // dividend == divisor).
  496. template <typename Label>
  497. inline SetWeight<Label, SET_UNION_INTERSECT> Divide(
  498. const SetWeight<Label, SET_UNION_INTERSECT> &w1,
  499. const SetWeight<Label, SET_UNION_INTERSECT> &w2,
  500. DivideType divide_type = DIVIDE_ANY) {
  501. using Weight = SetWeight<Label, SET_UNION_INTERSECT>;
  502. if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  503. if (w1 == w2) return Weight::UnivSet();
  504. return w1;
  505. }
  506. // Divide = Or Not.
  507. template <typename Label>
  508. inline SetWeight<Label, SET_BOOLEAN> Divide(
  509. const SetWeight<Label, SET_BOOLEAN> &w1,
  510. const SetWeight<Label, SET_BOOLEAN> &w2,
  511. DivideType divide_type = DIVIDE_ANY) {
  512. using Weight = SetWeight<Label, SET_BOOLEAN>;
  513. if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  514. if (w1 == Weight::One()) return w1;
  515. if (w2 == Weight::Zero()) return Weight::One();
  516. return Weight::Zero();
  517. }
  518. // Converts between different set types.
  519. template <typename Label, SetType S1, SetType S2>
  520. struct WeightConvert<SetWeight<Label, S1>, SetWeight<Label, S2>> {
  521. SetWeight<Label, S2> operator()(const SetWeight<Label, S1> &w1) const {
  522. using Iterator = SetWeightIterator<SetWeight<Label, S1>>;
  523. SetWeight<Label, S2> w2;
  524. for (Iterator iter(w1); !iter.Done(); iter.Next())
  525. w2.PushBack(iter.Value());
  526. return w2;
  527. }
  528. };
  529. // This function object generates SetWeights that are random integer sets
  530. // from {1, ... , alphabet_size}^{0, max_set_length} U { Zero }. This is
  531. // intended primarily for testing.
  532. template <class Label, SetType S>
  533. class WeightGenerate<SetWeight<Label, S>> {
  534. public:
  535. using Weight = SetWeight<Label, S>;
  536. explicit WeightGenerate(uint64_t seed = std::random_device()(),
  537. bool allow_zero = true,
  538. size_t alphabet_size = kNumRandomWeights,
  539. size_t max_set_length = kNumRandomWeights)
  540. : allow_zero_(allow_zero),
  541. alphabet_size_(alphabet_size),
  542. max_set_length_(max_set_length) {}
  543. Weight operator()() const {
  544. const int n = std::uniform_int_distribution<>(
  545. 0, max_set_length_ + allow_zero_ - 1)(rand_);
  546. if (allow_zero_ && n == max_set_length_) return Weight::Zero();
  547. std::vector<Label> labels;
  548. labels.reserve(n);
  549. for (int i = 0; i < n; ++i) {
  550. labels.push_back(
  551. std::uniform_int_distribution<>(0, alphabet_size_)(rand_));
  552. }
  553. std::sort(labels.begin(), labels.end());
  554. const auto labels_end = std::unique(labels.begin(), labels.end());
  555. labels.resize(labels_end - labels.begin());
  556. return Weight(labels.begin(), labels.end());
  557. }
  558. private:
  559. mutable std::mt19937_64 rand_;
  560. const bool allow_zero_;
  561. const size_t alphabet_size_;
  562. const size_t max_set_length_;
  563. };
  564. } // namespace fst
  565. #endif // FST_SET_WEIGHT_H_