You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

516 lines
15 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // FST utility inline definitions.
  19. #ifndef FST_UTIL_H_
  20. #define FST_UTIL_H_
  21. #include <array>
  22. #include <cstddef>
  23. #include <cstdint>
  24. #include <ios>
  25. #include <iostream>
  26. #include <istream>
  27. #include <iterator>
  28. #include <list>
  29. #include <map>
  30. #include <optional>
  31. #include <ostream>
  32. #include <set>
  33. #include <sstream>
  34. #include <string>
  35. #include <type_traits>
  36. #include <unordered_map>
  37. #include <unordered_set>
  38. #include <utility>
  39. #include <vector>
  40. #include <fst/compat.h>
  41. #include <fst/flags.h>
  42. #include <fst/log.h>
  43. #include <fstream>
  44. #include <fst/mapped-file.h>
  45. #include <unordered_map>
  46. #include <string_view>
  47. #include <optional>
  48. // Utility for error handling.
  49. DECLARE_bool(fst_error_fatal);
  50. #define FSTERROR() \
  51. (FST_FLAGS_fst_error_fatal ? LOG(FATAL) : LOG(ERROR))
  52. namespace fst {
  53. // Utility for type I/O. For portability of serialized objects across
  54. // architectures, care must be taken so that only fixed-size types (like
  55. // `int32_t`) are used with `WriteType`/`ReadType`, not types that may differ in
  56. // size depending on the architecture, such as `int`. For `enum` types, a
  57. // fixed-size base (like `enum E : int32_t`) should be used. Objects are
  58. // written and read in the host byte order, so will not be portable across
  59. // different endiannesses.
  namespace internal {
  // True exactly for arithmetic and enumeration types, i.e. the scalar types
  // that `ReadType`/`WriteType` transfer as raw bytes.
  template <class T>
  inline constexpr bool IsScalarIOTypeV =
      std::disjunction_v<std::is_arithmetic<T>, std::is_enum<T>>;
  }  // namespace internal
  // Reads types from an input stream.
  // Generic case: a class type deserializes itself through its `Read` member,
  // and whatever that member returns is handed back to the caller.
  template <class T, typename std::enable_if_t<std::is_class_v<T>, T> * = nullptr>
  inline std::istream &ReadType(std::istream &strm, T *t) {
    T &object = *t;
    return object.Read(strm);
  }
  72. // Numeric (boolean, integral, floating-point) or enum case.
  73. template <class T, typename std::enable_if_t<internal::IsScalarIOTypeV<T>, T>
  74. * = nullptr>
  75. inline std::istream &ReadType(std::istream &strm, T *t) {
  76. return strm.read(reinterpret_cast<char *>(t), sizeof(T));
  77. }
  78. // Numeric (boolean, integral, floating-point) or enum case only.
  79. template <class T>
  80. inline std::istream &ReadType(std::istream &strm, size_t n, T *t) {
  81. static_assert(internal::IsScalarIOTypeV<T>,
  82. "Type not supported for batch read.");
  83. return strm.read(reinterpret_cast<char *>(t), sizeof(T) * n);
  84. }
  85. // String case.
  86. inline std::istream &ReadType(std::istream &strm, std::string *s) {
  87. s->clear();
  88. int32_t ns = 0;
  89. ReadType(strm, &ns);
  90. if (ns <= 0) return strm;
  91. s->resize(ns);
  92. ReadType(strm, ns, s->data());
  93. return strm;
  94. }
  // Forward declarations of the container overloads of `ReadType`, so that the
  // overloads defined in between can refer to them.
  template <class... T>
  std::istream &ReadType(std::istream &strm, std::vector<T...> *c);

  template <class... T>
  std::istream &ReadType(std::istream &strm, std::list<T...> *c);

  template <class... T>
  std::istream &ReadType(std::istream &strm, std::set<T...> *c);

  template <class... T>
  std::istream &ReadType(std::istream &strm, std::map<T...> *c);

  template <class... T>
  std::istream &ReadType(std::istream &strm, std::unordered_map<T...> *c);

  template <class... T>
  std::istream &ReadType(std::istream &strm, std::unordered_set<T...> *c);
  // Pair case: reads `first`, then `second`, each with its own `ReadType`
  // overload.
  template <typename S, typename T>
  inline std::istream &ReadType(std::istream &strm, std::pair<S, T> *p) {
    S &first = p->first;
    T &second = p->second;
    ReadType(strm, &first);
    ReadType(strm, &second);
    return strm;
  }
  // Pair case with a const first member: constness is cast away so that the
  // first member can be filled in place.
  template <typename S, typename T>
  inline std::istream &ReadType(std::istream &strm, std::pair<const S, T> *p) {
    S *const first = const_cast<S *>(&p->first);
    T *const second = &p->second;
    ReadType(strm, first);
    ReadType(strm, second);
    return strm;
  }
  121. namespace internal {
  122. template <class C, class ReserveFn>
  123. std::istream &ReadContainerType(std::istream &strm, C *c, ReserveFn reserve) {
  124. c->clear();
  125. int64_t n = 0;
  126. ReadType(strm, &n);
  127. reserve(c, n);
  128. auto insert = std::inserter(*c, c->begin());
  129. for (int64_t i = 0; i < n; ++i) {
  130. typename C::value_type value;
  131. ReadType(strm, &value);
  132. *insert = value;
  133. }
  134. return strm;
  135. }
  136. // Generic vector case.
  137. template <typename T, class A,
  138. typename std::enable_if_t<std::is_class_v<T>, T> * = nullptr>
  139. inline std::istream &ReadVectorType(std::istream &strm, std::vector<T, A> *c) {
  140. return internal::ReadContainerType(
  141. strm, c, [](decltype(c) v, int n) { v->reserve(n); });
  142. }
  143. // Vector of numerics (boolean, integral, floating-point, char) or enum case.
  144. template <
  145. typename T, class A,
  146. typename std::enable_if_t<internal::IsScalarIOTypeV<T>, T> * = nullptr>
  147. inline std::istream &ReadVectorType(std::istream &strm, std::vector<T, A> *c) {
  148. c->clear();
  149. int64_t n = 0;
  150. ReadType(strm, &n);
  151. if (n == 0) return strm;
  152. c->resize(n);
  153. ReadType(strm, n, c->data());
  154. return strm;
  155. }
  156. } // namespace internal
  157. template <class T, size_t N>
  158. std::istream &ReadType(std::istream &strm, std::array<T, N> *c) {
  159. if constexpr (internal::IsScalarIOTypeV<T>) {
  160. ReadType(strm, c->size(), c->data());
  161. } else {
  162. for (auto &v : *c) ReadType(strm, &v);
  163. }
  164. return strm;
  165. }
  166. template <class... T>
  167. std::istream &ReadType(std::istream &strm, std::vector<T...> *c) {
  168. return internal::ReadVectorType(strm, c);
  169. }
  170. template <class... T>
  171. std::istream &ReadType(std::istream &strm, std::list<T...> *c) {
  172. return internal::ReadContainerType(strm, c, [](decltype(c) v, int n) {});
  173. }
  174. template <class... T>
  175. std::istream &ReadType(std::istream &strm, std::set<T...> *c) {
  176. return internal::ReadContainerType(strm, c, [](decltype(c) v, int n) {});
  177. }
  178. template <class... T>
  179. std::istream &ReadType(std::istream &strm, std::map<T...> *c) {
  180. return internal::ReadContainerType(strm, c, [](decltype(c) v, int n) {});
  181. }
  182. template <class... T>
  183. std::istream &ReadType(std::istream &strm, std::unordered_set<T...> *c) {
  184. return internal::ReadContainerType(
  185. strm, c, [](decltype(c) v, int n) { v->reserve(n); });
  186. }
  187. template <class... T>
  188. std::istream &ReadType(std::istream &strm, std::unordered_map<T...> *c) {
  189. return internal::ReadContainerType(
  190. strm, c, [](decltype(c) v, int n) { v->reserve(n); });
  191. }
  // Writes types to an output stream.
  // Generic case: a class type serializes itself through its `Write` member.
  // The result of `Write` is ignored and the stream itself is returned.
  template <class T,
            typename std::enable_if_t<
                std::is_class_v<T> &&
                    // `string_view` is handled separately below.
                    !std::is_convertible_v<T, std::string_view>,
                T> * = nullptr>
  inline std::ostream &WriteType(std::ostream &strm, const T t) {
    const T &object = t;
    object.Write(strm);
    return strm;
  }
  203. // Numeric (boolean, integral, floating-point) or enum case.
  204. template <class T, typename std::enable_if_t<internal::IsScalarIOTypeV<T>, T>
  205. * = nullptr>
  206. inline std::ostream &WriteType(std::ostream &strm, const T t) {
  207. return strm.write(reinterpret_cast<const char *>(&t), sizeof(T));
  208. }
  209. // Numeric (boolean, integral, floating-point) or enum case only.
  210. template <class T>
  211. inline std::ostream &WriteType(std::ostream &strm, size_t n, const T *t) {
  212. static_assert(internal::IsScalarIOTypeV<T>,
  213. "Type not supported for batch write.");
  214. return strm.write(reinterpret_cast<const char *>(t), sizeof(T) * n);
  215. }
  216. inline std::ostream &WriteType(std::ostream &strm, std::string_view s) {
  217. int32_t ns = s.size();
  218. WriteType(strm, ns);
  219. return strm.write(s.data(), ns);
  220. }
  // Forward declarations of the container overloads of `WriteType`, so that
  // the overloads defined in between can refer to them.
  template <typename... T>
  std::ostream &WriteType(std::ostream &strm, const std::vector<T...> &c);

  template <typename... T>
  std::ostream &WriteType(std::ostream &strm, const std::list<T...> &c);

  template <typename... T>
  std::ostream &WriteType(std::ostream &strm, const std::set<T...> &c);

  template <typename... T>
  std::ostream &WriteType(std::ostream &strm, const std::map<T...> &c);

  template <typename... T>
  std::ostream &WriteType(std::ostream &strm,
                          const std::unordered_map<T...> &c);

  template <typename... T>
  std::ostream &WriteType(std::ostream &strm,
                          const std::unordered_set<T...> &c);
  // Pair case: writes `first`, then `second`, each with its own `WriteType`
  // overload.
  template <typename S, typename T>
  inline std::ostream &WriteType(std::ostream &strm, const std::pair<S, T> &p) {
    const auto &first = p.first;
    const auto &second = p.second;
    WriteType(strm, first);
    WriteType(strm, second);
    return strm;
  }
  241. namespace internal {
  // Writes every element of `c` in iteration order, without a count prefix.
  template <class C>
  std::ostream &WriteSequence(std::ostream &strm, const C &c) {
    for (auto it = std::begin(c); it != std::end(c); ++it) {
      WriteType(strm, *it);
    }
    return strm;
  }
  249. template <class C>
  250. std::ostream &WriteContainer(std::ostream &strm, const C &c) {
  251. const int64_t n = c.size();
  252. WriteType(strm, n);
  253. WriteSequence(strm, c);
  254. return strm;
  255. }
  256. } // namespace internal
  257. template <class T, size_t N>
  258. std::ostream &WriteType(std::ostream &strm, const std::array<T, N> &c) {
  259. return internal::WriteSequence(strm, c);
  260. }
  261. template <typename... T>
  262. std::ostream &WriteType(std::ostream &strm, const std::vector<T...> &c) {
  263. return internal::WriteContainer(strm, c);
  264. }
  265. template <typename... T>
  266. std::ostream &WriteType(std::ostream &strm, const std::list<T...> &c) {
  267. return internal::WriteContainer(strm, c);
  268. }
  269. template <typename... T>
  270. std::ostream &WriteType(std::ostream &strm, const std::set<T...> &c) {
  271. return internal::WriteContainer(strm, c);
  272. }
  273. template <typename... T>
  274. std::ostream &WriteType(std::ostream &strm, const std::map<T...> &c) {
  275. return internal::WriteContainer(strm, c);
  276. }
  277. template <typename... T>
  278. std::ostream &WriteType(std::ostream &strm, const std::unordered_map<T...> &c) {
  279. return internal::WriteContainer(strm, c);
  280. }
  281. template <typename... T>
  282. std::ostream &WriteType(std::ostream &strm, const std::unordered_set<T...> &c) {
  283. return internal::WriteContainer(strm, c);
  284. }
  // Utilities for converting between int64_t or Weight and string.

  // Parses a 64-bit signed integer in some base out of an input string. The
  // string should consist only of digits (no prefixes such as "0x") and an
  // optionally preceding minus. Returns a value iff the entirety of the string
  // is consumed during integer parsing, otherwise returns `std::nullopt`.
  std::optional<int64_t> ParseInt64(std::string_view s, int base = 10);

  // Declared here, defined elsewhere. `source` and `nline` are the file name
  // and line number that `ReadIntPairs` passes along with the text to convert;
  // `error` is an optional out-parameter.
  int64_t StrToInt64(std::string_view s, std::string_view source, size_t nline,
                     bool *error = nullptr);
  293. template <typename Weight>
  294. Weight StrToWeight(std::string_view s) {
  295. Weight w;
  296. std::istringstream strm(std::string{s});
  297. strm >> w;
  298. if (!strm) {
  299. FSTERROR() << "StrToWeight: Bad weight: " << s;
  300. return Weight::NoWeight();
  301. }
  302. return w;
  303. }
  // Formats a weight with its stream insertion operator, using a stream
  // precision of 9.
  template <typename Weight>
  std::string WeightToStr(Weight w) {
    std::ostringstream buffer;
    buffer.precision(9);
    buffer << w;
    std::string text = buffer.str();
    return text;
  }
  311. // Utilities for reading/writing integer pairs (typically labels).
  312. template <typename I>
  313. bool ReadIntPairs(std::string_view source,
  314. std::vector<std::pair<I, I>> *pairs) {
  315. std::ifstream strm(std::string(source), std::ios_base::in);
  316. if (!strm) {
  317. LOG(ERROR) << "ReadIntPairs: Can't open file: " << source;
  318. return false;
  319. }
  320. const int kLineLen = 8096;
  321. char line[kLineLen];
  322. size_t nline = 0;
  323. pairs->clear();
  324. while (strm.getline(line, kLineLen)) {
  325. ++nline;
  326. std::vector<std::string_view> col =
  327. StrSplit(line, ByAnyChar("\n\t "), SkipEmpty());
  328. // empty line or comment?
  329. if (col.empty() || col[0].empty() || col[0][0] == '#') continue;
  330. if (col.size() != 2) {
  331. LOG(ERROR) << "ReadIntPairs: Bad number of columns, "
  332. << "file = " << source << ", line = " << nline;
  333. return false;
  334. }
  335. bool err;
  336. I i1 = StrToInt64(col[0], source, nline, &err);
  337. if (err) return false;
  338. I i2 = StrToInt64(col[1], source, nline, &err);
  339. if (err) return false;
  340. pairs->emplace_back(i1, i2);
  341. }
  342. return true;
  343. }
  344. template <typename I>
  345. bool WriteIntPairs(std::string_view source,
  346. const std::vector<std::pair<I, I>> &pairs) {
  347. std::ofstream fstrm;
  348. if (!source.empty()) {
  349. fstrm.open(std::string(source));
  350. if (!fstrm) {
  351. LOG(ERROR) << "WriteIntPairs: Can't open file: " << source;
  352. return false;
  353. }
  354. }
  355. std::ostream &ostrm = fstrm.is_open() ? fstrm : std::cout;
  356. for (const auto &pair : pairs) {
  357. ostrm << pair.first << "\t" << pair.second << "\n";
  358. }
  359. return !!ostrm;
  360. }
  // Utilities for reading/writing label pairs.

  // Reads label pairs from the file named `source`; forwards to
  // `ReadIntPairs`.
  template <typename Label>
  bool ReadLabelPairs(std::string_view source,
                      std::vector<std::pair<Label, Label>> *pairs) {
    const bool ok = ReadIntPairs(source, pairs);
    return ok;
  }
  // Writes label pairs to the destination named by `source`; forwards to
  // `WriteIntPairs`.
  template <typename Label>
  bool WriteLabelPairs(std::string_view source,
                       const std::vector<std::pair<Label, Label>> &pairs) {
    const bool ok = WriteIntPairs(source, pairs);
    return ok;
  }
  372. // Utilities for converting a type name to a legal C symbol.
  373. void ConvertToLegalCSymbol(std::string *s);
  374. // Utilities for stream I/O.
  375. bool AlignInput(std::istream &strm, size_t align = MappedFile::kArchAlignment);
  376. bool AlignOutput(std::ostream &strm, size_t align = MappedFile::kArchAlignment);
  // An associative container for which testing membership is faster than an STL
  // set if members are restricted to an interval that excludes most non-members.
  // A Key must have ==, !=, and < operators defined (and `Member()` additionally
  // adds keys to integers for its dense-range test). Element NoKey should be a
  // key that marks an uninitialized key and is otherwise unused. Find() returns
  // an STL const_iterator to the match found, otherwise it equals End().
  //
  // The stored bounds are kept equal to the smallest and largest stored key
  // (or NoKey when the set is empty). In particular, erasing a key never
  // leaves a bound sitting on NoKey while other keys remain stored.
  template <class Key, Key NoKey>
  class CompactSet {
   public:
    using const_iterator = typename std::set<Key>::const_iterator;

    CompactSet() : min_key_(NoKey), max_key_(NoKey) {}

    CompactSet(const CompactSet &) = default;

    // Adds `key` and widens the bounds to include it.
    void Insert(Key key) {
      set_.insert(key);
      if (min_key_ == NoKey || key < min_key_) min_key_ = key;
      if (max_key_ == NoKey || max_key_ < key) max_key_ = key;
    }

    // Removes `key` (a no-op if absent) and re-derives the bounds from the
    // ordered set. Both ends are available in constant time, and reading them
    // back avoids stepping a bound by one, which could land on NoKey and make
    // a non-empty set look empty to Find() and Member().
    void Erase(Key key) {
      set_.erase(key);
      if (set_.empty()) {
        min_key_ = max_key_ = NoKey;
      } else {
        min_key_ = *set_.begin();
        max_key_ = *set_.rbegin();
      }
    }

    // Removes all keys and resets the bounds.
    void Clear() {
      set_.clear();
      min_key_ = max_key_ = NoKey;
    }

    // Returns an iterator to `key`, or End() if it is not stored. Keys outside
    // the bounds are rejected without searching the set.
    const_iterator Find(Key key) const {
      if (min_key_ == NoKey || key < min_key_ || max_key_ < key) {
        return set_.end();
      } else {
        return set_.find(key);
      }
    }

    // Returns true iff `key` is stored.
    bool Member(Key key) const {
      if (min_key_ == NoKey || key < min_key_ || max_key_ < key) {
        return false;  // out of range
      } else if (max_key_ + 1 == min_key_ + set_.size()) {
        return true;  // dense range
      } else {
        return set_.count(key) != 0;
      }
    }

    const_iterator Begin() const { return set_.begin(); }

    const_iterator End() const { return set_.end(); }

    // All stored keys are greater than or equal to this value.
    Key LowerBound() const { return min_key_; }

    // All stored keys are less than or equal to this value.
    Key UpperBound() const { return max_key_; }

   private:
    std::set<Key> set_;
    Key min_key_;
    Key max_key_;

    // Copy assignment is disabled; copy construction remains available.
    void operator=(const CompactSet &) = delete;
  };
  435. } // namespace fst
  436. #endif // FST_UTIL_H_