You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

409 lines
13 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // General weight set and associated semiring operation definitions.
  19. #ifndef FST_WEIGHT_H_
  20. #define FST_WEIGHT_H_
  21. #include <cctype>
  22. #include <cmath>
  23. #include <cstddef>
  24. #include <cstdint>
  25. #include <ios>
  26. #include <iostream>
  27. #include <istream>
  28. #include <ostream>
  29. #include <sstream>
  30. #include <string>
  31. #include <type_traits>
  32. #include <utility>
  33. #include <fst/compat.h>
  34. #include <fst/log.h>
  35. #include <fst/util.h>
  36. DECLARE_string(fst_weight_parentheses);
  37. DECLARE_string(fst_weight_separator);
  38. namespace fst {
  39. // A semiring is specified by two binary operations Plus and Times and two
  40. // designated elements Zero and One with the following properties:
  41. //
  42. // Plus: associative, commutative, and has Zero as its identity.
  43. //
  44. // Times: associative and has identity One, distributes w.r.t. Plus, and
  45. // has Zero as an annihilator:
  46. // Times(Zero(), a) == Times(a, Zero()) = Zero().
  47. //
  48. // A left semiring distributes on the left; a right semiring is similarly
  49. // defined.
  50. //
  51. // A Weight class must have binary functions Plus and Times and static member
  52. // functions Zero() and One() and these must form (at least) a left or right
  53. // semiring.
  54. //
  55. // In addition, the following should be defined for a Weight:
  56. //
  57. // Member: predicate on set membership.
  58. //
  59. // NoWeight: static member function that returns an element that is
  60. // not a set member; used to signal an error.
  61. //
  62. // >>: reads textual representation of a weight.
  63. //
  64. // <<: prints textual representation of a weight.
  65. //
  66. // Read(istream &istrm): reads binary representation of a weight.
  67. //
  68. // Write(ostream &ostrm): writes binary representation of a weight.
  69. //
  70. // Hash: maps weight to size_t.
  71. //
  72. // ApproxEqual: approximate equality (for inexact weights)
  73. //
  74. // Quantize: quantizes w.r.t delta (for inexact weights)
  75. //
  76. // Divide:
  77. // - In a left semiring, for all a, b, b', c:
  78. // if Times(a, b) = c, Divide(c, a, DIVIDE_LEFT) = b' and b'.Member(),
  79. // then Times(a, b') = c.
  80. // - In a right semiring, for all a, a', b, c:
  81. // if Times(a, b) = c, Divide(c, b, DIVIDE_RIGHT) = a' and a'.Member(),
  82. // then Times(a', b) = c.
  83. // - In a commutative semiring,
  84. // * for all a, c:
  85. // Divide(c, a, DIVIDE_ANY) = Divide(c, a, DIVIDE_LEFT)
  86. // = Divide(c, a, DIVIDE_RIGHT)
  87. // * for all a, b, b', c:
  88. // if Times(a, b) = c, Divide(c, a, DIVIDE_ANY) = b' and b'.Member(),
  89. // then Times(a, b') = c
  90. // - In the case where there exist no b such that c = Times(a, b), the
  91. // return value of Divide(c, a, DIVIDE_LEFT) is unspecified. Returning
  92. // Weight::NoWeight() is recommended but not required in order to
  93. // allow the most efficient implementation.
  94. // - All algorithms in this library only call Divide(c, a) when it is
  95. // guaranteed that there exists a b such that c = Times(a, b).
  96. //
  97. // ReverseWeight: the type of the corresponding reverse weight.
  98. //
  99. // Typically the same type as Weight for a (both left and right) semiring.
  100. // For the left string semiring, it is the right string semiring.
  101. //
  102. // Reverse: a mapping from Weight to ReverseWeight s.t.
  103. //
  104. // --> Reverse(Reverse(a)) = a
  105. // --> Reverse(Plus(a, b)) = Plus(Reverse(a), Reverse(b))
  106. // --> Reverse(Times(a, b)) = Times(Reverse(b), Reverse(a))
  107. // Typically the identity mapping in a (both left and right) semiring.
  108. // In the left string semiring, it maps to the reverse string in the right
  109. // string semiring.
  110. //
  111. // Properties: specifies additional properties that hold:
  112. // LeftSemiring: indicates weights form a left semiring.
  113. // RightSemiring: indicates weights form a right semiring.
  114. // Commutative: for all a, b: Times(a,b) == Times(b, a)
  115. // Idempotent: for all a: Plus(a, a) == a.
  116. // Path: for all a, b: Plus(a, b) == a or Plus(a, b) == b.
  117. //
  118. // User-defined weights and their corresponding operations SHOULD be
  119. // defined in the same namespace, but SHOULD NOT be defined in the fst
  120. // namespace. Defining them in fst would make the user code fragile
  121. // to additions in fst. They will be found in another namespace
  122. // via argument-dependent lookup.
  123. // CONSTANT DEFINITIONS
  // A float close to .001 that is exactly representable in binary (2^-10).
  inline constexpr float kDelta = 1.0F / 1024.0F;

  // Weight property bits; each property occupies one bit of a uint64_t mask.

  // Left distributivity:
  // for all a, b, c: Times(c, Plus(a, b)) = Plus(Times(c, a), Times(c, b)).
  inline constexpr uint64_t kLeftSemiring = uint64_t{1} << 0;

  // Right distributivity:
  // for all a, b, c: Times(Plus(a, b), c) = Plus(Times(a, c), Times(b, c)).
  inline constexpr uint64_t kRightSemiring = uint64_t{1} << 1;

  // Both left and right distributivity.
  inline constexpr uint64_t kSemiring = kLeftSemiring | kRightSemiring;

  // Commutativity of Times: for all a, b: Times(a, b) = Times(b, a).
  inline constexpr uint64_t kCommutative = uint64_t{1} << 2;

  // Idempotency of Plus: for all a: Plus(a, a) = a.
  inline constexpr uint64_t kIdempotent = uint64_t{1} << 3;

  // Path property: for all a, b: Plus(a, b) = a or Plus(a, b) = b.
  inline constexpr uint64_t kPath = uint64_t{1} << 4;

  // Default number of distinct weights for random weight generation; also
  // used for a few other weight generation defaults.
  inline constexpr size_t kNumRandomWeights = 5;
  140. // Weight property boolean constants needed for SFINAE.
  141. template <class W>
  142. using IsIdempotent = std::bool_constant<(W::Properties() & kIdempotent) != 0>;
  143. template <class W>
  144. using IsPath = std::bool_constant<(W::Properties() & kPath) != 0>;
  // Selects the direction in which Divide operates.
  enum DivideType {
    DIVIDE_LEFT = 0,   // left division
    DIVIDE_RIGHT = 1,  // right division
    DIVIDE_ANY = 2     // division in a commutative semiring
  };
  151. // NATURAL ORDER
  152. //
  153. // By definition:
  154. //
  155. // a <= b iff a + b = a
  156. //
  157. // The natural order is a negative partial order iff the semiring is
  158. // idempotent. It is trivially monotonic for plus. It is left
  159. // (resp. right) monotonic for times iff the semiring is left
  160. // (resp. right) distributive. It is a total order iff the semiring
  161. // has the path property.
  162. //
  163. // For more information, see:
  164. //
  165. // Mohri, M. 2002. Semiring framework and algorithms for shortest-distance
  166. // problems, Journal of Automata, Languages and
  167. // Combinatorics 7(3): 321-350, 2002.
  168. //
  169. // We define the strict version of this order below.
  170. // Requires W is idempotent.
  171. template <class W>
  172. struct NaturalLess {
  173. using Weight = W;
  174. static_assert(IsIdempotent<W>::value, "W must be idempotent.");
  175. bool operator()(const Weight &w1, const Weight &w2) const {
  176. return w1 != w2 && Plus(w1, w2) == w1;
  177. }
  178. };
  // Iterated product for arbitrary semirings: Power(w, 0) is Weight::One() and
  // Power(w, n) = Times(Power(w, n - 1), w). Performs n calls to Times, always
  // multiplying the accumulated product on the left by `weight` on the right.
  template <class Weight>
  Weight Power(const Weight &weight, size_t n) {
    auto product = Weight::One();
    for (size_t remaining = n; remaining > 0; --remaining) {
      product = Times(product, weight);
    }
    return product;
  }
  // Default accumulator that keeps a running Plus-sum of weights.
  // Specializations might be more complex.
  template <class Weight>
  class Adder {
   public:
    // Starts the sum at Weight::Zero().
    Adder() : sum_(Weight::Zero()) {}

    // Starts the sum at the given weight.
    explicit Adder(Weight w) : sum_(std::move(w)) {}

    // Folds w into the running sum with Plus and returns the updated sum.
    Weight Add(const Weight &w) {
      sum_ = Plus(sum_, w);
      return Sum();
    }

    // Returns a copy of the current sum.
    Weight Sum() const { return sum_; }

    // Replaces the running sum with w (Weight::Zero() when omitted).
    void Reset(Weight w = Weight::Zero()) { sum_ = std::move(w); }

   private:
    Weight sum_;
  };
  202. // General weight converter: raises error.
  203. template <class W1, class W2>
  204. struct WeightConvert {
  205. W2 operator()(W1 w1) const {
  206. FSTERROR() << "WeightConvert: Can't convert weight from " << W1::Type()
  207. << " to " << W2::Type();
  208. return W2::NoWeight();
  209. }
  210. };
  211. // Specialized weight converter to self.
  212. template <class W>
  213. struct WeightConvert<W, W> {
  214. constexpr W operator()(W weight) const { return weight; }
  215. };
  216. // General random weight generator: raises error.
  217. //
  218. // The standard interface is roughly:
  219. //
  220. // class WeightGenerate<MyWeight> {
  221. // public:
  222. // explicit WeightGenerate(uint64_t seed = std::random_device()(),
  223. // bool allow_zero = true,
  224. // ...);
  225. //
  226. // MyWeight operator()() const;
  227. // };
  228. //
  229. // Many weight generators also take trailing constructor arguments specifying
  230. // the number of random (unique) weights, the length of weights (e.g., for
  231. // string-based weights), etc. with sensible defaults
  232. template <class W>
  233. struct WeightGenerate {
  234. W operator()() const {
  235. FSTERROR() << "WeightGenerate: No random generator for " << W::Type();
  236. return W::NoWeight();
  237. }
  238. };
  239. namespace internal {
  // Shared configuration for the textual composite weight reader and writer:
  // the element separator and the pair of parenthesis characters.
  class CompositeWeightIO {
   public:
    // Defined out of line. NOTE(review): presumably takes the separator and
    // parentheses from the fst_weight_separator / fst_weight_parentheses
    // flags -- confirm against the implementation.
    CompositeWeightIO();

    // Uses the given separator and {open, close} parenthesis characters.
    CompositeWeightIO(char separator, std::pair<char, char> parentheses);

    // Returns the {open, close} parenthesis characters.
    std::pair<char, char> parentheses() const {
      return std::make_pair(open_paren_, close_paren_);
    }

    // Returns the element separator character.
    char separator() const { return separator_; }

    // Returns the error flag; it is only written by code outside this class
    // body (the out-of-line constructors).
    bool error() const { return error_; }

   protected:
    const char separator_;
    const char open_paren_;
    const char close_paren_;

   private:
    bool error_;
  };
  256. } // namespace internal
  257. // Helper class for writing textual composite weights.
  258. class CompositeWeightWriter : public internal::CompositeWeightIO {
  259. public:
  260. // Uses configuration from flags (FST_FLAGS_fst_weight_separator,
  261. // FST_FLAGS_fst_weight_parentheses).
  262. explicit CompositeWeightWriter(std::ostream &ostrm);
  263. // parentheses defines the opening and closing parenthesis characters.
  264. // Set parentheses = {0, 0} to disable writing parenthesis.
  265. CompositeWeightWriter(std::ostream &ostrm, char separator,
  266. std::pair<char, char> parentheses);
  267. CompositeWeightWriter(const CompositeWeightWriter &) = delete;
  268. CompositeWeightWriter &operator=(const CompositeWeightWriter &) = delete;
  269. // Writes open parenthesis to a stream if option selected.
  270. void WriteBegin();
  271. // Writes element to a stream.
  272. template <class T>
  273. void WriteElement(const T &comp) {
  274. if (i_++ > 0) ostrm_ << separator_;
  275. ostrm_ << comp;
  276. }
  277. // Writes close parenthesis to a stream if option selected.
  278. void WriteEnd();
  279. private:
  280. std::ostream &ostrm_;
  281. int i_ = 0; // Element position.
  282. };
  283. // Helper class for reading textual composite weights. Elements are separated by
  284. // a separator character. There must be at least one element per textual
  285. // representation. Parentheses characters should be set if the composite
  286. // weights themselves contain composite weights to ensure proper parsing.
  287. class CompositeWeightReader : public internal::CompositeWeightIO {
  288. public:
  289. // Uses configuration from flags (FST_FLAGS_fst_weight_separator,
  290. // FST_FLAGS_fst_weight_parentheses).
  291. explicit CompositeWeightReader(std::istream &istrm);
  292. // parentheses defines the opening and closing parenthesis characters.
  293. // Set parentheses = {0, 0} to disable reading parenthesis.
  294. CompositeWeightReader(std::istream &istrm, char separator,
  295. std::pair<char, char> parentheses);
  296. CompositeWeightReader(const CompositeWeightReader &) = delete;
  297. CompositeWeightReader &operator=(const CompositeWeightReader &) = delete;
  298. // Reads open parenthesis from a stream if option selected.
  299. void ReadBegin();
  300. // Reads element from a stream. The second argument, when true, indicates that
  301. // this will be the last element (allowing more forgiving formatting of the
  302. // last element). Returns false when last element is read.
  303. template <class T>
  304. bool ReadElement(T *comp, bool last = false);
  305. // Finalizes reading.
  306. void ReadEnd();
  307. private:
  308. std::istream &istrm_; // Input stream.
  309. int c_ = 0; // Last character read, or EOF.
  310. int depth_ = 0; // Weight parentheses depth.
  311. };
  312. template <class T>
  313. inline bool CompositeWeightReader::ReadElement(T *comp, bool last) {
  314. std::string s;
  315. const bool has_parens = open_paren_ != 0;
  316. while ((c_ != std::istream::traits_type::eof()) && !std::isspace(c_) &&
  317. (c_ != separator_ || depth_ > 1 || last) &&
  318. (c_ != close_paren_ || depth_ != 1)) {
  319. s += c_;
  320. // If parentheses encountered before separator, they must be matched.
  321. if (has_parens && c_ == open_paren_) {
  322. ++depth_;
  323. } else if (has_parens && c_ == close_paren_) {
  324. // Failure on unmatched parentheses.
  325. if (depth_ == 0) {
  326. FSTERROR() << "CompositeWeightReader: Unmatched close paren: "
  327. << "Is the fst_weight_parentheses flag set correctly?";
  328. istrm_.clear(std::ios::badbit);
  329. return false;
  330. }
  331. --depth_;
  332. }
  333. c_ = istrm_.get();
  334. }
  335. if (s.empty()) {
  336. FSTERROR() << "CompositeWeightReader: Empty element: "
  337. << "Is the fst_weight_parentheses flag set correctly?";
  338. istrm_.clear(std::ios::badbit);
  339. return false;
  340. }
  341. std::istringstream istrm(s);
  342. istrm >> *comp;
  343. // Skips separator/close parenthesis.
  344. if (c_ != std::istream::traits_type::eof() && !std::isspace(c_)) {
  345. c_ = istrm_.get();
  346. }
  347. const bool is_eof = c_ == std::istream::traits_type::eof();
  348. // Clears fail bit if just EOF.
  349. if (is_eof && !istrm_.bad()) istrm_.clear(std::ios::eofbit);
  350. return !is_eof && !std::isspace(c_);
  351. }
  352. } // namespace fst
  353. #endif // FST_WEIGHT_H_