You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

834 lines
26 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // String weight set and associated semiring operation definitions.
  19. #ifndef FST_STRING_WEIGHT_H_
  20. #define FST_STRING_WEIGHT_H_
  21. #include <cstddef>
  22. #include <cstdint>
  23. #include <ios>
  24. #include <istream>
  25. #include <list>
  26. #include <optional>
  27. #include <ostream>
  28. #include <random>
  29. #include <string>
  30. #include <vector>
  31. #include <fst/log.h>
  32. #include <fst/product-weight.h>
  33. #include <fst/union-weight.h>
  34. #include <fst/util.h>
  35. #include <fst/weight.h>
  36. #include <string_view>
  37. namespace fst {
  38. inline constexpr int kStringInfinity = -1; // Label for the infinite string.
  39. inline constexpr int kStringBad = -2; // Label for a non-string.
  40. inline constexpr char kStringSeparator = '_'; // Label separator in strings.
  41. // Determines whether to use left or right string semiring. Includes a
  42. // 'restricted' version that signals an error if proper prefixes/suffixes
  43. // would otherwise be returned by Plus, useful with various
  44. // algorithms that require functional transducer input with the
  45. // string semirings.
  46. enum StringType { STRING_LEFT = 0, STRING_RIGHT = 1, STRING_RESTRICT = 2 };
  47. constexpr StringType ReverseStringType(StringType s) {
  48. return s == STRING_LEFT ? STRING_RIGHT
  49. : (s == STRING_RIGHT ? STRING_LEFT : STRING_RESTRICT);
  50. }
  51. template <class>
  52. class StringWeightIterator;
  53. template <class>
  54. class StringWeightReverseIterator;
  55. // String semiring: (longest_common_prefix/suffix, ., Infinity, Epsilon)
  56. template <typename L, StringType S = STRING_LEFT>
  57. class StringWeight {
  58. public:
  59. using Label = L;
  60. using ReverseWeight = StringWeight<Label, ReverseStringType(S)>;
  61. using Iterator = StringWeightIterator<StringWeight>;
  62. using ReverseIterator = StringWeightReverseIterator<StringWeight>;
  63. friend class StringWeightIterator<StringWeight>;
  64. friend class StringWeightReverseIterator<StringWeight>;
  65. StringWeight() = default;
  66. template <typename Iterator>
  67. StringWeight(const Iterator begin, const Iterator end) {
  68. for (auto iter = begin; iter != end; ++iter) PushBack(*iter);
  69. }
  70. explicit StringWeight(Label label) { PushBack(label); }
  71. static const StringWeight &Zero() {
  72. static const auto *const zero = new StringWeight(Label(kStringInfinity));
  73. return *zero;
  74. }
  75. static const StringWeight &One() {
  76. static const auto *const one = new StringWeight();
  77. return *one;
  78. }
  79. static const StringWeight &NoWeight() {
  80. static const auto *const no_weight = new StringWeight(Label(kStringBad));
  81. return *no_weight;
  82. }
  83. static const std::string &Type() {
  84. static const std::string *const type = new std::string(
  85. S == STRING_LEFT
  86. ? "left_string"
  87. : (S == STRING_RIGHT ? "right_string" : "restricted_string"));
  88. return *type;
  89. }
  90. bool Member() const;
  91. std::istream &Read(std::istream &strm);
  92. std::ostream &Write(std::ostream &strm) const;
  93. size_t Hash() const;
  94. StringWeight Quantize(float delta = kDelta) const { return *this; }
  95. ReverseWeight Reverse() const;
  96. static constexpr uint64_t Properties() {
  97. return kIdempotent |
  98. (S == STRING_LEFT ? kLeftSemiring
  99. : (S == STRING_RIGHT
  100. ? kRightSemiring
  101. : /* S == STRING_RESTRICT */ kLeftSemiring |
  102. kRightSemiring));
  103. }
  104. // These operations combined with the StringWeightIterator and
  105. // StringWeightReverseIterator provide the access and mutation of the string
  106. // internal elements.
  107. // Clear existing StringWeight.
  108. void Clear() {
  109. first_ = 0;
  110. rest_.clear();
  111. }
  112. size_t Size() const { return first_ ? rest_.size() + 1 : 0; }
  113. void PushFront(Label label) {
  114. if (first_) rest_.push_front(first_);
  115. first_ = label;
  116. }
  117. void PushBack(Label label) {
  118. if (!first_) {
  119. first_ = label;
  120. } else {
  121. rest_.push_back(label);
  122. }
  123. }
  124. private:
  125. Label first_ = 0; // First label in string (0 if empty).
  126. std::list<Label> rest_; // Remaining labels in string.
  127. };
  128. // Traverses string in forward direction.
  129. template <class StringWeight_>
  130. class StringWeightIterator {
  131. public:
  132. using Weight = StringWeight_;
  133. using Label = typename Weight::Label;
  134. explicit StringWeightIterator(const Weight &w)
  135. : first_(w.first_), rest_(w.rest_), init_(true), iter_(rest_.begin()) {}
  136. bool Done() const {
  137. if (init_) {
  138. return first_ == 0;
  139. } else {
  140. return iter_ == rest_.end();
  141. }
  142. }
  143. const Label &Value() const { return init_ ? first_ : *iter_; }
  144. void Next() {
  145. if (init_) {
  146. init_ = false;
  147. } else {
  148. ++iter_;
  149. }
  150. }
  151. void Reset() {
  152. init_ = true;
  153. iter_ = rest_.begin();
  154. }
  155. private:
  156. const Label &first_;
  157. const decltype(Weight::rest_) &rest_;
  158. bool init_; // In the initialized state?
  159. typename decltype(Weight::rest_)::const_iterator iter_;
  160. };
  161. // Traverses string in backward direction.
  162. template <class StringWeight_>
  163. class StringWeightReverseIterator {
  164. public:
  165. using Weight = StringWeight_;
  166. using Label = typename Weight::Label;
  167. explicit StringWeightReverseIterator(const Weight &w)
  168. : first_(w.first_),
  169. rest_(w.rest_),
  170. fin_(first_ == Label()),
  171. iter_(rest_.rbegin()) {}
  172. bool Done() const { return fin_; }
  173. const Label &Value() const { return iter_ == rest_.rend() ? first_ : *iter_; }
  174. void Next() {
  175. if (iter_ == rest_.rend()) {
  176. fin_ = true;
  177. } else {
  178. ++iter_;
  179. }
  180. }
  181. void Reset() {
  182. fin_ = false;
  183. iter_ = rest_.rbegin();
  184. }
  185. private:
  186. const Label &first_;
  187. const decltype(Weight::rest_) &rest_;
  188. bool fin_; // In the final state?
  189. typename decltype(Weight::rest_)::const_reverse_iterator iter_;
  190. };
  191. // StringWeight member functions follow that require
  192. // StringWeightIterator or StringWeightReverseIterator.
  193. template <typename Label, StringType S>
  194. inline std::istream &StringWeight<Label, S>::Read(std::istream &strm) {
  195. Clear();
  196. int32_t size;
  197. ReadType(strm, &size);
  198. for (int32_t i = 0; i < size; ++i) {
  199. Label label;
  200. ReadType(strm, &label);
  201. PushBack(label);
  202. }
  203. return strm;
  204. }
  205. template <typename Label, StringType S>
  206. inline std::ostream &StringWeight<Label, S>::Write(std::ostream &strm) const {
  207. const int32_t size = Size();
  208. WriteType(strm, size);
  209. for (Iterator iter(*this); !iter.Done(); iter.Next()) {
  210. WriteType(strm, iter.Value());
  211. }
  212. return strm;
  213. }
  214. template <typename Label, StringType S>
  215. inline bool StringWeight<Label, S>::Member() const {
  216. Iterator iter(*this);
  217. return iter.Value() != Label(kStringBad);
  218. }
  219. template <typename Label, StringType S>
  220. inline typename StringWeight<Label, S>::ReverseWeight
  221. StringWeight<Label, S>::Reverse() const {
  222. ReverseWeight rweight;
  223. for (Iterator iter(*this); !iter.Done(); iter.Next()) {
  224. rweight.PushFront(iter.Value());
  225. }
  226. return rweight;
  227. }
  228. template <typename Label, StringType S>
  229. inline size_t StringWeight<Label, S>::Hash() const {
  230. size_t h = 0;
  231. for (Iterator iter(*this); !iter.Done(); iter.Next()) {
  232. h ^= h << 1 ^ iter.Value();
  233. }
  234. return h;
  235. }
  236. template <typename Label, StringType S>
  237. inline bool operator==(const StringWeight<Label, S> &w1,
  238. const StringWeight<Label, S> &w2) {
  239. if (w1.Size() != w2.Size()) return false;
  240. using Iterator = typename StringWeight<Label, S>::Iterator;
  241. Iterator iter1(w1);
  242. Iterator iter2(w2);
  243. for (; !iter1.Done(); iter1.Next(), iter2.Next()) {
  244. if (iter1.Value() != iter2.Value()) return false;
  245. }
  246. return true;
  247. }
  248. template <typename Label, StringType S>
  249. inline bool operator!=(const StringWeight<Label, S> &w1,
  250. const StringWeight<Label, S> &w2) {
  251. return !(w1 == w2);
  252. }
  253. template <typename Label, StringType S>
  254. inline bool ApproxEqual(const StringWeight<Label, S> &w1,
  255. const StringWeight<Label, S> &w2,
  256. float delta = kDelta) {
  257. return w1 == w2;
  258. }
  259. template <typename Label, StringType S>
  260. inline std::ostream &operator<<(std::ostream &strm,
  261. const StringWeight<Label, S> &weight) {
  262. typename StringWeight<Label, S>::Iterator iter(weight);
  263. if (iter.Done()) {
  264. return strm << "Epsilon";
  265. } else if (iter.Value() == Label(kStringInfinity)) {
  266. return strm << "Infinity";
  267. } else if (iter.Value() == Label(kStringBad)) {
  268. return strm << "BadString";
  269. } else {
  270. for (size_t i = 0; !iter.Done(); ++i, iter.Next()) {
  271. if (i > 0) strm << kStringSeparator;
  272. strm << iter.Value();
  273. }
  274. }
  275. return strm;
  276. }
  277. template <typename Label, StringType S>
  278. inline std::istream &operator>>(std::istream &strm,
  279. StringWeight<Label, S> &weight) {
  280. std::string str;
  281. strm >> str;
  282. using Weight = StringWeight<Label, S>;
  283. if (str == "Infinity") {
  284. weight = Weight::Zero();
  285. } else if (str == "Epsilon") {
  286. weight = Weight::One();
  287. } else {
  288. weight.Clear();
  289. for (std::string_view sv : StrSplit(str, kStringSeparator)) {
  290. auto maybe_label = ParseInt64(sv);
  291. if (!maybe_label.has_value()) {
  292. strm.clear(std::ios::badbit);
  293. break;
  294. }
  295. weight.PushBack(*maybe_label);
  296. }
  297. }
  298. return strm;
  299. }
  300. // Default is for the restricted string semiring. String equality is required
  301. // (for non-Zero() input). The restriction is used (e.g., in determinization)
  302. // to ensure the input is functional.
  303. template <typename Label, StringType S>
  304. inline StringWeight<Label, S> Plus(const StringWeight<Label, S> &w1,
  305. const StringWeight<Label, S> &w2) {
  306. using Weight = StringWeight<Label, S>;
  307. if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  308. if (w1 == Weight::Zero()) return w2;
  309. if (w2 == Weight::Zero()) return w1;
  310. if (w1 != w2) {
  311. FSTERROR() << "StringWeight::Plus: Unequal arguments "
  312. << "(non-functional FST?)"
  313. << " w1 = " << w1 << " w2 = " << w2;
  314. return Weight::NoWeight();
  315. }
  316. return w1;
  317. }
  318. // Longest common prefix for left string semiring.
  319. template <typename Label>
  320. inline StringWeight<Label, STRING_LEFT> Plus(
  321. const StringWeight<Label, STRING_LEFT> &w1,
  322. const StringWeight<Label, STRING_LEFT> &w2) {
  323. using Weight = StringWeight<Label, STRING_LEFT>;
  324. if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  325. if (w1 == Weight::Zero()) return w2;
  326. if (w2 == Weight::Zero()) return w1;
  327. Weight sum;
  328. typename Weight::Iterator iter1(w1);
  329. typename Weight::Iterator iter2(w2);
  330. for (; !iter1.Done() && !iter2.Done() && iter1.Value() == iter2.Value();
  331. iter1.Next(), iter2.Next()) {
  332. sum.PushBack(iter1.Value());
  333. }
  334. return sum;
  335. }
  336. // Longest common suffix for right string semiring.
  337. template <typename Label>
  338. inline StringWeight<Label, STRING_RIGHT> Plus(
  339. const StringWeight<Label, STRING_RIGHT> &w1,
  340. const StringWeight<Label, STRING_RIGHT> &w2) {
  341. using Weight = StringWeight<Label, STRING_RIGHT>;
  342. if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  343. if (w1 == Weight::Zero()) return w2;
  344. if (w2 == Weight::Zero()) return w1;
  345. Weight sum;
  346. typename Weight::ReverseIterator iter1(w1);
  347. typename Weight::ReverseIterator iter2(w2);
  348. for (; !iter1.Done() && !iter2.Done() && iter1.Value() == iter2.Value();
  349. iter1.Next(), iter2.Next()) {
  350. sum.PushFront(iter1.Value());
  351. }
  352. return sum;
  353. }
  354. template <typename Label, StringType S>
  355. inline StringWeight<Label, S> Times(const StringWeight<Label, S> &w1,
  356. const StringWeight<Label, S> &w2) {
  357. using Weight = StringWeight<Label, S>;
  358. if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  359. if (w1 == Weight::Zero() || w2 == Weight::Zero()) return Weight::Zero();
  360. Weight product(w1);
  361. for (typename Weight::Iterator iter(w2); !iter.Done(); iter.Next()) {
  362. product.PushBack(iter.Value());
  363. }
  364. return product;
  365. }
  366. // Left division in a left string semiring.
  367. template <typename Label, StringType S>
  368. inline StringWeight<Label, S> DivideLeft(const StringWeight<Label, S> &w1,
  369. const StringWeight<Label, S> &w2) {
  370. using Weight = StringWeight<Label, S>;
  371. if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  372. if (w2 == Weight::Zero()) {
  373. return Weight(Label(kStringBad));
  374. } else if (w1 == Weight::Zero()) {
  375. return Weight::Zero();
  376. }
  377. Weight result;
  378. typename Weight::Iterator iter(w1);
  379. size_t i = 0;
  380. for (; !iter.Done() && i < w2.Size(); iter.Next(), ++i) {
  381. }
  382. for (; !iter.Done(); iter.Next()) result.PushBack(iter.Value());
  383. return result;
  384. }
  385. // Right division in a right string semiring.
  386. template <typename Label, StringType S>
  387. inline StringWeight<Label, S> DivideRight(const StringWeight<Label, S> &w1,
  388. const StringWeight<Label, S> &w2) {
  389. using Weight = StringWeight<Label, S>;
  390. if (!w1.Member() || !w2.Member()) return Weight::NoWeight();
  391. if (w2 == Weight::Zero()) {
  392. return Weight(Label(kStringBad));
  393. } else if (w1 == Weight::Zero()) {
  394. return Weight::Zero();
  395. }
  396. Weight result;
  397. typename Weight::ReverseIterator iter(w1);
  398. size_t i = 0;
  399. for (; !iter.Done() && i < w2.Size(); iter.Next(), ++i) {
  400. }
  401. for (; !iter.Done(); iter.Next()) result.PushFront(iter.Value());
  402. return result;
  403. }
  404. // Default is the restricted string semiring.
  405. template <typename Label, StringType S>
  406. inline StringWeight<Label, S> Divide(const StringWeight<Label, S> &w1,
  407. const StringWeight<Label, S> &w2,
  408. DivideType divide_type) {
  409. using Weight = StringWeight<Label, S>;
  410. if (divide_type == DIVIDE_LEFT) {
  411. return DivideLeft(w1, w2);
  412. } else if (divide_type == DIVIDE_RIGHT) {
  413. return DivideRight(w1, w2);
  414. } else {
  415. FSTERROR() << "StringWeight::Divide: "
  416. << "Only explicit left or right division is defined "
  417. << "for the " << Weight::Type() << " semiring";
  418. return Weight::NoWeight();
  419. }
  420. }
  421. // Left division in the left string semiring.
  422. template <typename Label>
  423. inline StringWeight<Label, STRING_LEFT> Divide(
  424. const StringWeight<Label, STRING_LEFT> &w1,
  425. const StringWeight<Label, STRING_LEFT> &w2, DivideType divide_type) {
  426. if (divide_type != DIVIDE_LEFT) {
  427. FSTERROR() << "StringWeight::Divide: Only left division is defined "
  428. << "for the left string semiring";
  429. return StringWeight<Label, STRING_LEFT>::NoWeight();
  430. }
  431. return DivideLeft(w1, w2);
  432. }
  433. // Right division in the right string semiring.
  434. template <typename Label>
  435. inline StringWeight<Label, STRING_RIGHT> Divide(
  436. const StringWeight<Label, STRING_RIGHT> &w1,
  437. const StringWeight<Label, STRING_RIGHT> &w2, DivideType divide_type) {
  438. if (divide_type != DIVIDE_RIGHT) {
  439. FSTERROR() << "StringWeight::Divide: Only right division is defined "
  440. << "for the right string semiring";
  441. return StringWeight<Label, STRING_RIGHT>::NoWeight();
  442. }
  443. return DivideRight(w1, w2);
  444. }
  445. // This function object generates StringWeights that are random integer strings
  446. // from {1, ... , alphabet_size)^{0, max_string_length} U { Zero }. This is
  447. // intended primarily for testing.
  448. template <class Label, StringType S>
  449. class WeightGenerate<StringWeight<Label, S>> {
  450. public:
  451. using Weight = StringWeight<Label, S>;
  452. explicit WeightGenerate(uint64_t seed = std::random_device()(),
  453. bool allow_zero = true,
  454. size_t alphabet_size = kNumRandomWeights,
  455. size_t max_string_length = kNumRandomWeights)
  456. : rand_(seed),
  457. allow_zero_(allow_zero),
  458. alphabet_size_(alphabet_size),
  459. max_string_length_(max_string_length) {}
  460. Weight operator()() const {
  461. const int n = std::uniform_int_distribution<>(
  462. 0, max_string_length_ + allow_zero_)(rand_);
  463. if (allow_zero_ && n == max_string_length_) return Weight::Zero();
  464. std::vector<Label> labels;
  465. labels.reserve(n);
  466. for (int i = 0; i < n; ++i) {
  467. labels.push_back(
  468. std::uniform_int_distribution<>(1, alphabet_size_)(rand_));
  469. }
  470. return Weight(labels.begin(), labels.end());
  471. }
  472. private:
  473. mutable std::mt19937_64 rand_;
  474. const bool allow_zero_;
  475. const size_t alphabet_size_;
  476. const size_t max_string_length_;
  477. };
  478. // Determines whether to use left, right, or (general) gallic semiring. Includes
  479. // a restricted version that signals an error if proper string prefixes or
  480. // suffixes would otherwise be returned by string Plus. This is useful with
  481. // algorithms that require functional transducer input. Also includes min
  482. // version that changes the Plus to keep only the lowest W weight string.
  483. enum GallicType {
  484. GALLIC_LEFT = 0,
  485. GALLIC_RIGHT = 1,
  486. GALLIC_RESTRICT = 2,
  487. GALLIC_MIN = 3,
  488. GALLIC = 4
  489. };
  490. constexpr StringType GallicStringType(GallicType g) {
  491. return g == GALLIC_LEFT
  492. ? STRING_LEFT
  493. : (g == GALLIC_RIGHT ? STRING_RIGHT : STRING_RESTRICT);
  494. }
  495. constexpr GallicType ReverseGallicType(GallicType g) {
  496. return g == GALLIC_LEFT
  497. ? GALLIC_RIGHT
  498. : (g == GALLIC_RIGHT
  499. ? GALLIC_LEFT
  500. : (g == GALLIC_RESTRICT
  501. ? GALLIC_RESTRICT
  502. : (g == GALLIC_MIN ? GALLIC_MIN : GALLIC)));
  503. }
  504. // Product of string weight and an arbitraryy weight.
  505. template <class Label, class W, GallicType G = GALLIC_LEFT>
  506. struct GallicWeight
  507. : public ProductWeight<StringWeight<Label, GallicStringType(G)>, W> {
  508. using ReverseWeight =
  509. GallicWeight<Label, typename W::ReverseWeight, ReverseGallicType(G)>;
  510. using SW = StringWeight<Label, GallicStringType(G)>;
  511. using ProductWeight<SW, W>::Properties;
  512. GallicWeight() = default;
  513. GallicWeight(SW w1, W w2) : ProductWeight<SW, W>(w1, w2) {}
  514. explicit GallicWeight(std::string_view s, int *nread = nullptr)
  515. : ProductWeight<SW, W>(s, nread) {}
  516. explicit GallicWeight(const ProductWeight<SW, W> &w)
  517. : ProductWeight<SW, W>(w) {}
  518. static const GallicWeight &Zero() {
  519. static const GallicWeight zero(ProductWeight<SW, W>::Zero());
  520. return zero;
  521. }
  522. static const GallicWeight &One() {
  523. static const GallicWeight one(ProductWeight<SW, W>::One());
  524. return one;
  525. }
  526. static const GallicWeight &NoWeight() {
  527. static const GallicWeight no_weight(ProductWeight<SW, W>::NoWeight());
  528. return no_weight;
  529. }
  530. static const std::string &Type() {
  531. static const std::string *const type = new std::string(
  532. G == GALLIC_LEFT
  533. ? "left_gallic"
  534. : (G == GALLIC_RIGHT
  535. ? "right_gallic"
  536. : (G == GALLIC_RESTRICT
  537. ? "restricted_gallic"
  538. : (G == GALLIC_MIN ? "min_gallic" : "gallic"))));
  539. return *type;
  540. }
  541. GallicWeight Quantize(float delta = kDelta) const {
  542. return GallicWeight(ProductWeight<SW, W>::Quantize(delta));
  543. }
  544. ReverseWeight Reverse() const {
  545. return ReverseWeight(ProductWeight<SW, W>::Reverse());
  546. }
  547. };
  548. // Default plus.
  549. template <class Label, class W, GallicType G>
  550. inline GallicWeight<Label, W, G> Plus(const GallicWeight<Label, W, G> &w,
  551. const GallicWeight<Label, W, G> &v) {
  552. return GallicWeight<Label, W, G>(Plus(w.Value1(), v.Value1()),
  553. Plus(w.Value2(), v.Value2()));
  554. }
  555. // Min gallic plus.
  556. template <class Label, class W>
  557. inline GallicWeight<Label, W, GALLIC_MIN> Plus(
  558. const GallicWeight<Label, W, GALLIC_MIN> &w1,
  559. const GallicWeight<Label, W, GALLIC_MIN> &w2) {
  560. static const NaturalLess<W> less;
  561. return less(w1.Value2(), w2.Value2()) ? w1 : w2;
  562. }
  563. template <class Label, class W, GallicType G>
  564. inline GallicWeight<Label, W, G> Times(const GallicWeight<Label, W, G> &w,
  565. const GallicWeight<Label, W, G> &v) {
  566. return GallicWeight<Label, W, G>(Times(w.Value1(), v.Value1()),
  567. Times(w.Value2(), v.Value2()));
  568. }
  569. template <class Label, class W, GallicType G>
  570. inline GallicWeight<Label, W, G> Divide(const GallicWeight<Label, W, G> &w,
  571. const GallicWeight<Label, W, G> &v,
  572. DivideType divide_type = DIVIDE_ANY) {
  573. return GallicWeight<Label, W, G>(Divide(w.Value1(), v.Value1(), divide_type),
  574. Divide(w.Value2(), v.Value2(), divide_type));
  575. }
  576. // This function object generates gallic weights by calling an underlying
  577. // product weight generator. This is intended primarily for testing.
  578. template <class Label, class W, GallicType G>
  579. class WeightGenerate<GallicWeight<Label, W, G>>
  580. : public WeightGenerate<
  581. ProductWeight<StringWeight<Label, GallicStringType(G)>, W>> {
  582. public:
  583. using Weight = GallicWeight<Label, W, G>;
  584. using Generate = WeightGenerate<
  585. ProductWeight<StringWeight<Label, GallicStringType(G)>, W>>;
  586. explicit WeightGenerate(uint64_t seed = std::random_device()(),
  587. bool allow_zero = true)
  588. : generate_(seed, allow_zero) {}
  589. Weight operator()() const { return Weight(generate_()); }
  590. private:
  591. const Generate generate_;
  592. };
  593. // Union weight options for (general) GALLIC type.
  594. template <class Label, class W>
  595. struct GallicUnionWeightOptions {
  596. using ReverseOptions = GallicUnionWeightOptions<Label, W>;
  597. using GW = GallicWeight<Label, W, GALLIC_RESTRICT>;
  598. using SW = StringWeight<Label, GallicStringType(GALLIC_RESTRICT)>;
  599. using SI = StringWeightIterator<SW>;
  600. // Military order.
  601. struct Compare {
  602. bool operator()(const GW &w1, const GW &w2) const {
  603. const SW &s1 = w1.Value1();
  604. const SW &s2 = w2.Value1();
  605. if (s1.Size() < s2.Size()) return true;
  606. if (s1.Size() > s2.Size()) return false;
  607. SI iter1(s1);
  608. SI iter2(s2);
  609. while (!iter1.Done()) {
  610. const auto l1 = iter1.Value();
  611. const auto l2 = iter2.Value();
  612. if (l1 < l2) return true;
  613. if (l1 > l2) return false;
  614. iter1.Next();
  615. iter2.Next();
  616. }
  617. return false;
  618. }
  619. };
  620. // Adds W weights when string part equal.
  621. struct Merge {
  622. GW operator()(const GW &w1, const GW &w2) const {
  623. return GW(w1.Value1(), Plus(w1.Value2(), w2.Value2()));
  624. }
  625. };
  626. };
  627. // Specialization for the (general) GALLIC type.
  628. template <class Label, class W>
  629. struct GallicWeight<Label, W, GALLIC>
  630. : public UnionWeight<GallicWeight<Label, W, GALLIC_RESTRICT>,
  631. GallicUnionWeightOptions<Label, W>> {
  632. using GW = GallicWeight<Label, W, GALLIC_RESTRICT>;
  633. using SW = StringWeight<Label, GallicStringType(GALLIC_RESTRICT)>;
  634. using SI = StringWeightIterator<SW>;
  635. using UW = UnionWeight<GW, GallicUnionWeightOptions<Label, W>>;
  636. using UI = UnionWeightIterator<GW, GallicUnionWeightOptions<Label, W>>;
  637. using ReverseWeight = GallicWeight<Label, W, GALLIC>;
  638. using UW::Properties;
  639. GallicWeight() = default;
  640. // Copy constructor.
  641. // NOLINTNEXTLINE(google-explicit-constructor)
  642. GallicWeight(const UW &weight) : UW(weight) {}
  643. // Singleton constructors: create a GALLIC weight containing a single
  644. // GALLIC_RESTRICT weight. Takes as argument (1) a GALLIC_RESTRICT weight or
  645. // (2) the two components of a GALLIC_RESTRICT weight.
  646. explicit GallicWeight(const GW &weight) : UW(weight) {}
  647. GallicWeight(SW w1, W w2) : UW(GW(w1, w2)) {}
  648. explicit GallicWeight(std::string_view str, int *nread = nullptr)
  649. : UW(str, nread) {}
  650. static const GallicWeight<Label, W, GALLIC> &Zero() {
  651. static const GallicWeight<Label, W, GALLIC> zero(UW::Zero());
  652. return zero;
  653. }
  654. static const GallicWeight<Label, W, GALLIC> &One() {
  655. static const GallicWeight<Label, W, GALLIC> one(UW::One());
  656. return one;
  657. }
  658. static const GallicWeight<Label, W, GALLIC> &NoWeight() {
  659. static const GallicWeight<Label, W, GALLIC> no_weight(UW::NoWeight());
  660. return no_weight;
  661. }
  662. static const std::string &Type() {
  663. static const std::string *const type = new std::string("gallic");
  664. return *type;
  665. }
  666. GallicWeight<Label, W, GALLIC> Quantize(float delta = kDelta) const {
  667. return UW::Quantize(delta);
  668. }
  669. ReverseWeight Reverse() const { return UW::Reverse(); }
  670. };
  671. // (General) gallic plus.
  672. template <class Label, class W>
  673. inline GallicWeight<Label, W, GALLIC> Plus(
  674. const GallicWeight<Label, W, GALLIC> &w1,
  675. const GallicWeight<Label, W, GALLIC> &w2) {
  676. using GW = GallicWeight<Label, W, GALLIC_RESTRICT>;
  677. using UW = UnionWeight<GW, GallicUnionWeightOptions<Label, W>>;
  678. return Plus(static_cast<UW>(w1), static_cast<UW>(w2));
  679. }
  680. // (General) gallic times.
  681. template <class Label, class W>
  682. inline GallicWeight<Label, W, GALLIC> Times(
  683. const GallicWeight<Label, W, GALLIC> &w1,
  684. const GallicWeight<Label, W, GALLIC> &w2) {
  685. using GW = GallicWeight<Label, W, GALLIC_RESTRICT>;
  686. using UW = UnionWeight<GW, GallicUnionWeightOptions<Label, W>>;
  687. return Times(static_cast<UW>(w1), static_cast<UW>(w2));
  688. }
  689. // (General) gallic divide.
  690. template <class Label, class W>
  691. inline GallicWeight<Label, W, GALLIC> Divide(
  692. const GallicWeight<Label, W, GALLIC> &w1,
  693. const GallicWeight<Label, W, GALLIC> &w2,
  694. DivideType divide_type = DIVIDE_ANY) {
  695. using GW = GallicWeight<Label, W, GALLIC_RESTRICT>;
  696. using UW = UnionWeight<GW, GallicUnionWeightOptions<Label, W>>;
  697. return Divide(static_cast<UW>(w1), static_cast<UW>(w2), divide_type);
  698. }
  699. // This function object generates gallic weights by calling an underlying
  700. // union weight generator. This is intended primarily for testing.
  701. template <class Label, class W>
  702. class WeightGenerate<GallicWeight<Label, W, GALLIC>>
  703. : public WeightGenerate<UnionWeight<GallicWeight<Label, W, GALLIC_RESTRICT>,
  704. GallicUnionWeightOptions<Label, W>>> {
  705. public:
  706. using Weight = GallicWeight<Label, W, GALLIC>;
  707. using Generate =
  708. WeightGenerate<UnionWeight<GallicWeight<Label, W, GALLIC_RESTRICT>,
  709. GallicUnionWeightOptions<Label, W>>>;
  710. explicit WeightGenerate(uint64_t seed = std::random_device()(),
  711. bool allow_zero = true)
  712. : generate_(seed, allow_zero) {}
  713. Weight operator()() const { return Weight(generate_()); }
  714. private:
  715. const Generate generate_;
  716. };
  717. } // namespace fst
  718. #endif // FST_STRING_WEIGHT_H_