You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1452 lines
42 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Regression test for various FST algorithms.
  19. #ifndef FST_TEST_ALGO_TEST_H_
  20. #define FST_TEST_ALGO_TEST_H_
  21. #include <sys/types.h>
  22. #include <cstddef>
  23. #include <cstdint>
  24. #include <memory>
  25. #include <random>
  26. #include <string>
  27. #include <utility>
  28. #include <vector>
  29. #include <fst/log.h>
  30. #include <fst/arc-map.h>
  31. #include <fst/arc.h>
  32. #include <fst/arcfilter.h>
  33. #include <fst/arcsort.h>
  34. #include <fst/cache.h>
  35. #include <fst/closure.h>
  36. #include <fst/compose-filter.h>
  37. #include <fst/compose.h>
  38. #include <fst/concat.h>
  39. #include <fst/connect.h>
  40. #include <fst/determinize.h>
  41. #include <fst/dfs-visit.h>
  42. #include <fst/difference.h>
  43. #include <fst/disambiguate.h>
  44. #include <fst/encode.h>
  45. #include <fst/equivalent.h>
  46. #include <fst/float-weight.h>
  47. #include <fst/fst.h>
  48. #include <fst/fstlib.h>
  49. #include <fst/intersect.h>
  50. #include <fst/invert.h>
  51. #include <fst/lookahead-matcher.h>
  52. #include <fst/matcher-fst.h>
  53. #include <fst/matcher.h>
  54. #include <fst/minimize.h>
  55. #include <fst/mutable-fst.h>
  56. #include <fst/pair-weight.h>
  57. #include <fst/project.h>
  58. #include <fst/properties.h>
  59. #include <fst/prune.h>
  60. #include <fst/push.h>
  61. #include <fst/randequivalent.h>
  62. #include <fst/randgen.h>
  63. #include <fst/rational.h>
  64. #include <fst/relabel.h>
  65. #include <fst/reverse.h>
  66. #include <fst/reweight.h>
  67. #include <fst/rmepsilon.h>
  68. #include <fst/shortest-distance.h>
  69. #include <fst/shortest-path.h>
  70. #include <fst/string-weight.h>
  71. #include <fst/synchronize.h>
  72. #include <fst/topsort.h>
  73. #include <fst/union-weight.h>
  74. #include <fst/union.h>
  75. #include <fst/vector-fst.h>
  76. #include <fst/verify.h>
  77. #include <fst/weight.h>
  78. #include <fst/test/rand-fst.h>
  79. DECLARE_int32(repeat); // defined in ./algo_test.cc
  80. namespace fst {
  81. // Mapper to change input and output label of every transition into
  82. // epsilons.
  83. template <class A>
  84. class EpsMapper {
  85. public:
  86. EpsMapper() = default;
  87. A operator()(const A &arc) const {
  88. return A(0, 0, arc.weight, arc.nextstate);
  89. }
  90. uint64_t Properties(uint64_t props) const {
  91. props &= ~kNotAcceptor;
  92. props |= kAcceptor;
  93. props &= ~kNoIEpsilons & ~kNoOEpsilons & ~kNoEpsilons;
  94. props |= kIEpsilons | kOEpsilons | kEpsilons;
  95. props &= ~kNotILabelSorted & ~kNotOLabelSorted;
  96. props |= kILabelSorted | kOLabelSorted;
  97. return props;
  98. }
  99. MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
  100. MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
  101. MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
  102. };
  103. // Generic - no lookahead.
  104. template <class Arc>
  105. void LookAheadCompose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
  106. MutableFst<Arc> *ofst) {
  107. Compose(ifst1, ifst2, ofst);
  108. }
  109. // Specialized and epsilon olabel acyclic - lookahead.
  110. inline void LookAheadCompose(const Fst<StdArc> &ifst1, const Fst<StdArc> &ifst2,
  111. MutableFst<StdArc> *ofst) {
  112. std::vector<StdArc::StateId> order;
  113. bool acyclic;
  114. TopOrderVisitor<StdArc> visitor(&order, &acyclic);
  115. DfsVisit(ifst1, &visitor, OutputEpsilonArcFilter<StdArc>());
  116. if (acyclic) { // no ifst1 output epsilon cycles?
  117. StdOLabelLookAheadFst lfst1(ifst1);
  118. StdVectorFst lfst2(ifst2);
  119. LabelLookAheadRelabeler<StdArc>::Relabel(&lfst2, lfst1, true);
  120. Compose(lfst1, lfst2, ofst);
  121. } else {
  122. Compose(ifst1, ifst2, ofst);
  123. }
  124. }
  125. // This class tests a variety of identities and properties that must
  126. // hold for various algorithms on weighted FSTs.
  127. template <class Arc>
  128. class WeightedTester {
  129. public:
  130. using Label = typename Arc::Label;
  131. using StateId = typename Arc::StateId;
  132. using Weight = typename Arc::Weight;
  133. using WeightGenerator = WeightGenerate<Weight>;
  134. WeightedTester(uint64_t seed, const Fst<Arc> &zero_fst,
  135. const Fst<Arc> &one_fst, const Fst<Arc> &univ_fst,
  136. WeightGenerator weight_generator)
  137. : seed_(seed),
  138. rand_(seed),
  139. zero_fst_(zero_fst),
  140. one_fst_(one_fst),
  141. univ_fst_(univ_fst),
  142. generate_(std::move(weight_generator)) {}
  143. void Test(const Fst<Arc> &T1, const Fst<Arc> &T2, const Fst<Arc> &T3) {
  144. TestRational(T1, T2, T3);
  145. TestMap(T1);
  146. TestCompose(T1, T2, T3);
  147. TestSort(T1);
  148. TestOptimize(T1);
  149. TestSearch(T1);
  150. }
  151. private:
  152. // Tests rational operations with identities
  153. void TestRational(const Fst<Arc> &T1, const Fst<Arc> &T2,
  154. const Fst<Arc> &T3) {
  155. {
  156. VLOG(1) << "Check destructive and delayed union are equivalent.";
  157. VectorFst<Arc> U1(T1);
  158. Union(&U1, T2);
  159. UnionFst<Arc> U2(T1, T2);
  160. CHECK(Equiv(U1, U2));
  161. }
  162. {
  163. VLOG(1) << "Check destructive and delayed concatenation are equivalent.";
  164. VectorFst<Arc> C1(T1);
  165. Concat(&C1, T2);
  166. ConcatFst<Arc> C2(T1, T2);
  167. CHECK(Equiv(C1, C2));
  168. VectorFst<Arc> C3(T2);
  169. Concat(T1, &C3);
  170. CHECK(Equiv(C3, C2));
  171. }
  172. {
  173. VLOG(1) << "Check destructive and delayed closure* are equivalent.";
  174. VectorFst<Arc> C1(T1);
  175. Closure(&C1, CLOSURE_STAR);
  176. ClosureFst<Arc> C2(T1, CLOSURE_STAR);
  177. CHECK(Equiv(C1, C2));
  178. }
  179. {
  180. VLOG(1) << "Check destructive and delayed closure+ are equivalent.";
  181. VectorFst<Arc> C1(T1);
  182. Closure(&C1, CLOSURE_PLUS);
  183. ClosureFst<Arc> C2(T1, CLOSURE_PLUS);
  184. CHECK(Equiv(C1, C2));
  185. }
  186. {
  187. VLOG(1) << "Check union is associative (destructive).";
  188. VectorFst<Arc> U1(T1);
  189. Union(&U1, T2);
  190. Union(&U1, T3);
  191. VectorFst<Arc> U3(T2);
  192. Union(&U3, T3);
  193. VectorFst<Arc> U4(T1);
  194. Union(&U4, U3);
  195. CHECK(Equiv(U1, U4));
  196. }
  197. {
  198. VLOG(1) << "Check union is associative (delayed).";
  199. UnionFst<Arc> U1(T1, T2);
  200. UnionFst<Arc> U2(U1, T3);
  201. UnionFst<Arc> U3(T2, T3);
  202. UnionFst<Arc> U4(T1, U3);
  203. CHECK(Equiv(U2, U4));
  204. }
  205. {
  206. VLOG(1) << "Check union is associative (destructive delayed).";
  207. UnionFst<Arc> U1(T1, T2);
  208. Union(&U1, T3);
  209. UnionFst<Arc> U3(T2, T3);
  210. UnionFst<Arc> U4(T1, U3);
  211. CHECK(Equiv(U1, U4));
  212. }
  213. {
  214. VLOG(1) << "Check concatenation is associative (destructive).";
  215. VectorFst<Arc> C1(T1);
  216. Concat(&C1, T2);
  217. Concat(&C1, T3);
  218. VectorFst<Arc> C3(T2);
  219. Concat(&C3, T3);
  220. VectorFst<Arc> C4(T1);
  221. Concat(&C4, C3);
  222. CHECK(Equiv(C1, C4));
  223. }
  224. {
  225. VLOG(1) << "Check concatenation is associative (delayed).";
  226. ConcatFst<Arc> C1(T1, T2);
  227. ConcatFst<Arc> C2(C1, T3);
  228. ConcatFst<Arc> C3(T2, T3);
  229. ConcatFst<Arc> C4(T1, C3);
  230. CHECK(Equiv(C2, C4));
  231. }
  232. {
  233. VLOG(1) << "Check concatenation is associative (destructive delayed).";
  234. ConcatFst<Arc> C1(T1, T2);
  235. Concat(&C1, T3);
  236. ConcatFst<Arc> C3(T2, T3);
  237. ConcatFst<Arc> C4(T1, C3);
  238. CHECK(Equiv(C1, C4));
  239. }
  240. if (Weight::Properties() & kLeftSemiring) {
  241. VLOG(1) << "Check concatenation left distributes"
  242. << " over union (destructive).";
  243. VectorFst<Arc> U1(T1);
  244. Union(&U1, T2);
  245. VectorFst<Arc> C1(T3);
  246. Concat(&C1, U1);
  247. VectorFst<Arc> C2(T3);
  248. Concat(&C2, T1);
  249. VectorFst<Arc> C3(T3);
  250. Concat(&C3, T2);
  251. VectorFst<Arc> U2(C2);
  252. Union(&U2, C3);
  253. CHECK(Equiv(C1, U2));
  254. }
  255. if (Weight::Properties() & kRightSemiring) {
  256. VLOG(1) << "Check concatenation right distributes"
  257. << " over union (destructive).";
  258. VectorFst<Arc> U1(T1);
  259. Union(&U1, T2);
  260. VectorFst<Arc> C1(U1);
  261. Concat(&C1, T3);
  262. VectorFst<Arc> C2(T1);
  263. Concat(&C2, T3);
  264. VectorFst<Arc> C3(T2);
  265. Concat(&C3, T3);
  266. VectorFst<Arc> U2(C2);
  267. Union(&U2, C3);
  268. CHECK(Equiv(C1, U2));
  269. }
  270. if (Weight::Properties() & kLeftSemiring) {
  271. VLOG(1) << "Check concatenation left distributes over union (delayed).";
  272. UnionFst<Arc> U1(T1, T2);
  273. ConcatFst<Arc> C1(T3, U1);
  274. ConcatFst<Arc> C2(T3, T1);
  275. ConcatFst<Arc> C3(T3, T2);
  276. UnionFst<Arc> U2(C2, C3);
  277. CHECK(Equiv(C1, U2));
  278. }
  279. if (Weight::Properties() & kRightSemiring) {
  280. VLOG(1) << "Check concatenation right distributes over union (delayed).";
  281. UnionFst<Arc> U1(T1, T2);
  282. ConcatFst<Arc> C1(U1, T3);
  283. ConcatFst<Arc> C2(T1, T3);
  284. ConcatFst<Arc> C3(T2, T3);
  285. UnionFst<Arc> U2(C2, C3);
  286. CHECK(Equiv(C1, U2));
  287. }
  288. if (Weight::Properties() & kLeftSemiring) {
  289. VLOG(1) << "Check T T* == T+ (destructive).";
  290. VectorFst<Arc> S(T1);
  291. Closure(&S, CLOSURE_STAR);
  292. VectorFst<Arc> C(T1);
  293. Concat(&C, S);
  294. VectorFst<Arc> P(T1);
  295. Closure(&P, CLOSURE_PLUS);
  296. CHECK(Equiv(C, P));
  297. }
  298. if (Weight::Properties() & kRightSemiring) {
  299. VLOG(1) << "Check T* T == T+ (destructive).";
  300. VectorFst<Arc> S(T1);
  301. Closure(&S, CLOSURE_STAR);
  302. VectorFst<Arc> C(S);
  303. Concat(&C, T1);
  304. VectorFst<Arc> P(T1);
  305. Closure(&P, CLOSURE_PLUS);
  306. CHECK(Equiv(C, P));
  307. }
  308. if (Weight::Properties() & kLeftSemiring) {
  309. VLOG(1) << "Check T T* == T+ (delayed).";
  310. ClosureFst<Arc> S(T1, CLOSURE_STAR);
  311. ConcatFst<Arc> C(T1, S);
  312. ClosureFst<Arc> P(T1, CLOSURE_PLUS);
  313. CHECK(Equiv(C, P));
  314. }
  315. if (Weight::Properties() & kRightSemiring) {
  316. VLOG(1) << "Check T* T == T+ (delayed).";
  317. ClosureFst<Arc> S(T1, CLOSURE_STAR);
  318. ConcatFst<Arc> C(S, T1);
  319. ClosureFst<Arc> P(T1, CLOSURE_PLUS);
  320. CHECK(Equiv(C, P));
  321. }
  322. }
  323. // Tests map-based operations.
  324. void TestMap(const Fst<Arc> &T) {
  325. {
  326. VLOG(1) << "Check destructive and delayed projection are equivalent.";
  327. VectorFst<Arc> P1(T);
  328. Project(&P1, ProjectType::INPUT);
  329. ProjectFst<Arc> P2(T, ProjectType::INPUT);
  330. CHECK(Equiv(P1, P2));
  331. }
  332. {
  333. VLOG(1) << "Check destructive and delayed inversion are equivalent.";
  334. VectorFst<Arc> I1(T);
  335. Invert(&I1);
  336. InvertFst<Arc> I2(T);
  337. CHECK(Equiv(I1, I2));
  338. }
  339. {
  340. VLOG(1) << "Check Pi_1(T) = Pi_2(T^-1) (destructive).";
  341. VectorFst<Arc> P1(T);
  342. VectorFst<Arc> I1(T);
  343. Project(&P1, ProjectType::INPUT);
  344. Invert(&I1);
  345. Project(&I1, ProjectType::OUTPUT);
  346. CHECK(Equiv(P1, I1));
  347. }
  348. {
  349. VLOG(1) << "Check Pi_2(T) = Pi_1(T^-1) (destructive).";
  350. VectorFst<Arc> P1(T);
  351. VectorFst<Arc> I1(T);
  352. Project(&P1, ProjectType::OUTPUT);
  353. Invert(&I1);
  354. Project(&I1, ProjectType::INPUT);
  355. CHECK(Equiv(P1, I1));
  356. }
  357. {
  358. VLOG(1) << "Check Pi_1(T) = Pi_2(T^-1) (delayed).";
  359. ProjectFst<Arc> P1(T, ProjectType::INPUT);
  360. InvertFst<Arc> I1(T);
  361. ProjectFst<Arc> P2(I1, ProjectType::OUTPUT);
  362. CHECK(Equiv(P1, P2));
  363. }
  364. {
  365. VLOG(1) << "Check Pi_2(T) = Pi_1(T^-1) (delayed).";
  366. ProjectFst<Arc> P1(T, ProjectType::OUTPUT);
  367. InvertFst<Arc> I1(T);
  368. ProjectFst<Arc> P2(I1, ProjectType::INPUT);
  369. CHECK(Equiv(P1, P2));
  370. }
  371. {
  372. VLOG(1) << "Check destructive relabeling";
  373. static const int kNumLabels = 10;
  374. // set up relabeling pairs
  375. std::vector<Label> labelset(kNumLabels);
  376. for (size_t i = 0; i < kNumLabels; ++i) labelset[i] = i;
  377. for (size_t i = 0; i < kNumLabels; ++i) {
  378. using std::swap;
  379. const auto index =
  380. std::uniform_int_distribution<>(0, kNumLabels - 1)(rand_);
  381. swap(labelset[i], labelset[index]);
  382. }
  383. std::vector<std::pair<Label, Label>> ipairs1(kNumLabels);
  384. std::vector<std::pair<Label, Label>> opairs1(kNumLabels);
  385. for (size_t i = 0; i < kNumLabels; ++i) {
  386. ipairs1[i] = std::make_pair(i, labelset[i]);
  387. opairs1[i] = std::make_pair(labelset[i], i);
  388. }
  389. VectorFst<Arc> R(T);
  390. Relabel(&R, ipairs1, opairs1);
  391. std::vector<std::pair<Label, Label>> ipairs2(kNumLabels);
  392. std::vector<std::pair<Label, Label>> opairs2(kNumLabels);
  393. for (size_t i = 0; i < kNumLabels; ++i) {
  394. ipairs2[i] = std::make_pair(labelset[i], i);
  395. opairs2[i] = std::make_pair(i, labelset[i]);
  396. }
  397. Relabel(&R, ipairs2, opairs2);
  398. CHECK(Equiv(R, T));
  399. VLOG(1) << "Check on-the-fly relabeling";
  400. RelabelFst<Arc> Rdelay(T, ipairs1, opairs1);
  401. RelabelFst<Arc> RRdelay(Rdelay, ipairs2, opairs2);
  402. CHECK(Equiv(RRdelay, T));
  403. }
  404. {
  405. VLOG(1) << "Check encoding/decoding (destructive).";
  406. VectorFst<Arc> D(T);
  407. uint8_t encode_props = 0;
  408. if (std::bernoulli_distribution(.5)(rand_)) {
  409. encode_props |= kEncodeLabels;
  410. }
  411. if (std::bernoulli_distribution(.5)(rand_)) {
  412. encode_props |= kEncodeWeights;
  413. }
  414. EncodeMapper<Arc> encoder(encode_props, ENCODE);
  415. Encode(&D, &encoder);
  416. Decode(&D, encoder);
  417. CHECK(Equiv(D, T));
  418. }
  419. {
  420. VLOG(1) << "Check encoding/decoding (delayed).";
  421. uint8_t encode_props = 0;
  422. if (std::bernoulli_distribution(.5)(rand_)) {
  423. encode_props |= kEncodeLabels;
  424. }
  425. if (std::bernoulli_distribution(.5)(rand_)) {
  426. encode_props |= kEncodeWeights;
  427. }
  428. EncodeMapper<Arc> encoder(encode_props, ENCODE);
  429. EncodeFst<Arc> E(T, &encoder);
  430. VectorFst<Arc> Encoded(E);
  431. DecodeFst<Arc> D(Encoded, encoder);
  432. CHECK(Equiv(D, T));
  433. }
  434. {
  435. VLOG(1) << "Check gallic mappers (constructive).";
  436. ToGallicMapper<Arc> to_mapper;
  437. FromGallicMapper<Arc> from_mapper;
  438. VectorFst<GallicArc<Arc>> G;
  439. VectorFst<Arc> F;
  440. ArcMap(T, &G, to_mapper);
  441. ArcMap(G, &F, from_mapper);
  442. CHECK(Equiv(T, F));
  443. }
  444. {
  445. VLOG(1) << "Check gallic mappers (delayed).";
  446. ArcMapFst G(T, ToGallicMapper<Arc>());
  447. ArcMapFst F(G, FromGallicMapper<Arc>());
  448. CHECK(Equiv(T, F));
  449. }
  450. }
  451. // Tests compose-based operations.
  452. void TestCompose(const Fst<Arc> &T1, const Fst<Arc> &T2, const Fst<Arc> &T3) {
  453. if (!(Weight::Properties() & kCommutative)) return;
  454. VectorFst<Arc> S1(T1);
  455. VectorFst<Arc> S2(T2);
  456. VectorFst<Arc> S3(T3);
  457. ILabelCompare<Arc> icomp;
  458. OLabelCompare<Arc> ocomp;
  459. ArcSort(&S1, ocomp);
  460. ArcSort(&S2, ocomp);
  461. ArcSort(&S3, icomp);
  462. {
  463. VLOG(1) << "Check composition is associative.";
  464. ComposeFst<Arc> C1(S1, S2);
  465. ComposeFst<Arc> C2(C1, S3);
  466. ComposeFst<Arc> C3(S2, S3);
  467. ComposeFst<Arc> C4(S1, C3);
  468. CHECK(Equiv(C2, C4));
  469. }
  470. {
  471. VLOG(1) << "Check composition left distributes over union.";
  472. UnionFst<Arc> U1(S2, S3);
  473. ComposeFst<Arc> C1(S1, U1);
  474. ComposeFst<Arc> C2(S1, S2);
  475. ComposeFst<Arc> C3(S1, S3);
  476. UnionFst<Arc> U2(C2, C3);
  477. CHECK(Equiv(C1, U2));
  478. }
  479. {
  480. VLOG(1) << "Check composition right distributes over union.";
  481. UnionFst<Arc> U1(S1, S2);
  482. ComposeFst<Arc> C1(U1, S3);
  483. ComposeFst<Arc> C2(S1, S3);
  484. ComposeFst<Arc> C3(S2, S3);
  485. UnionFst<Arc> U2(C2, C3);
  486. CHECK(Equiv(C1, U2));
  487. }
  488. VectorFst<Arc> A1(S1);
  489. VectorFst<Arc> A2(S2);
  490. VectorFst<Arc> A3(S3);
  491. Project(&A1, ProjectType::OUTPUT);
  492. Project(&A2, ProjectType::INPUT);
  493. Project(&A3, ProjectType::INPUT);
  494. {
  495. VLOG(1) << "Check intersection is commutative.";
  496. IntersectFst<Arc> I1(A1, A2);
  497. IntersectFst<Arc> I2(A2, A1);
  498. CHECK(Equiv(I1, I2));
  499. }
  500. {
  501. VLOG(1) << "Check all epsilon filters leads to equivalent results.";
  502. using M = Matcher<Fst<Arc>>;
  503. ComposeFst<Arc> C1(S1, S2);
  504. ComposeFst<Arc> C2(
  505. S1, S2, ComposeFstOptions<Arc, M, AltSequenceComposeFilter<M>>());
  506. ComposeFst<Arc> C3(S1, S2,
  507. ComposeFstOptions<Arc, M, MatchComposeFilter<M>>());
  508. CHECK(Equiv(C1, C2));
  509. CHECK(Equiv(C1, C3));
  510. if ((Weight::Properties() & kIdempotent) ||
  511. S1.Properties(kNoOEpsilons, false) ||
  512. S2.Properties(kNoIEpsilons, false)) {
  513. ComposeFst<Arc> C4(
  514. S1, S2, ComposeFstOptions<Arc, M, TrivialComposeFilter<M>>());
  515. CHECK(Equiv(C1, C4));
  516. ComposeFst<Arc> C5(
  517. S1, S2, ComposeFstOptions<Arc, M, NoMatchComposeFilter<M>>());
  518. CHECK(Equiv(C1, C5));
  519. }
  520. if (S1.Properties(kNoOEpsilons, false) &&
  521. S2.Properties(kNoIEpsilons, false)) {
  522. ComposeFst<Arc> C6(S1, S2,
  523. ComposeFstOptions<Arc, M, NullComposeFilter<M>>());
  524. CHECK(Equiv(C1, C6));
  525. }
  526. }
  527. {
  528. VLOG(1) << "Check look-ahead filters lead to equivalent results.";
  529. VectorFst<Arc> C1, C2;
  530. Compose(S1, S2, &C1);
  531. LookAheadCompose(S1, S2, &C2);
  532. CHECK(Equiv(C1, C2));
  533. }
  534. }
  535. // Tests sorting operations
  536. void TestSort(const Fst<Arc> &T) {
  537. ILabelCompare<Arc> icomp;
  538. OLabelCompare<Arc> ocomp;
  539. {
  540. VLOG(1) << "Check arc sorted Fst is equivalent to its input.";
  541. VectorFst<Arc> S1(T);
  542. ArcSort(&S1, icomp);
  543. CHECK(Equiv(T, S1));
  544. }
  545. {
  546. VLOG(1) << "Check destructive and delayed arcsort are equivalent.";
  547. VectorFst<Arc> S1(T);
  548. ArcSort(&S1, icomp);
  549. ArcSortFst<Arc, ILabelCompare<Arc>> S2(T, icomp);
  550. CHECK(Equiv(S1, S2));
  551. }
  552. {
  553. VLOG(1) << "Check ilabel sorting vs. olabel sorting with inversions.";
  554. VectorFst<Arc> S1(T);
  555. VectorFst<Arc> S2(T);
  556. ArcSort(&S1, icomp);
  557. Invert(&S2);
  558. ArcSort(&S2, ocomp);
  559. Invert(&S2);
  560. CHECK(Equiv(S1, S2));
  561. }
  562. {
  563. VLOG(1) << "Check topologically sorted Fst is equivalent to its input.";
  564. VectorFst<Arc> S1(T);
  565. TopSort(&S1);
  566. CHECK(Equiv(T, S1));
  567. }
  568. {
  569. VLOG(1) << "Check reverse(reverse(T)) = T";
  570. for (int i = 0; i < 2; ++i) {
  571. VectorFst<ReverseArc<Arc>> R1;
  572. VectorFst<Arc> R2;
  573. bool require_superinitial = i == 1;
  574. Reverse(T, &R1, require_superinitial);
  575. Reverse(R1, &R2, require_superinitial);
  576. CHECK(Equiv(T, R2));
  577. }
  578. }
  579. }
  580. // Tests optimization operations
  581. void TestOptimize(const Fst<Arc> &T) {
  582. uint64_t tprops = T.Properties(kFstProperties, true);
  583. uint64_t wprops = Weight::Properties();
  584. VectorFst<Arc> A(T);
  585. Project(&A, ProjectType::INPUT);
  586. {
  587. VLOG(1) << "Check connected FST is equivalent to its input.";
  588. VectorFst<Arc> C1(T);
  589. Connect(&C1);
  590. CHECK(Equiv(T, C1));
  591. }
  592. if ((wprops & kSemiring) == kSemiring &&
  593. (tprops & kAcyclic || wprops & kIdempotent)) {
  594. VLOG(1) << "Check epsilon-removed FST is equivalent to its input.";
  595. VectorFst<Arc> R1(T);
  596. RmEpsilon(&R1);
  597. CHECK(Equiv(T, R1));
  598. VLOG(1) << "Check destructive and delayed epsilon removal"
  599. << "are equivalent.";
  600. RmEpsilonFst<Arc> R2(T);
  601. CHECK(Equiv(R1, R2));
  602. VLOG(1) << "Check an FST with a large proportion"
  603. << " of epsilon transitions:";
  604. // Maps all transitions of T to epsilon-transitions and append
  605. // a non-epsilon transition.
  606. VectorFst<Arc> U;
  607. ArcMap(T, &U, EpsMapper<Arc>());
  608. VectorFst<Arc> V;
  609. V.SetStart(V.AddState());
  610. Arc arc(1, 1, Weight::One(), V.AddState());
  611. V.AddArc(V.Start(), arc);
  612. V.SetFinal(arc.nextstate, Weight::One());
  613. Concat(&U, V);
  614. // Check that epsilon-removal preserves the shortest-distance
  615. // from the initial state to the final states.
  616. std::vector<Weight> d;
  617. ShortestDistance(U, &d, true);
  618. Weight w = U.Start() < d.size() ? d[U.Start()] : Weight::Zero();
  619. VectorFst<Arc> U1(U);
  620. RmEpsilon(&U1);
  621. ShortestDistance(U1, &d, true);
  622. Weight w1 = U1.Start() < d.size() ? d[U1.Start()] : Weight::Zero();
  623. CHECK(ApproxEqual(w, w1, kTestDelta));
  624. RmEpsilonFst<Arc> U2(U);
  625. ShortestDistance(U2, &d, true);
  626. Weight w2 = U2.Start() < d.size() ? d[U2.Start()] : Weight::Zero();
  627. CHECK(ApproxEqual(w, w2, kTestDelta));
  628. }
  629. if ((wprops & kSemiring) == kSemiring && tprops & kAcyclic) {
  630. VLOG(1) << "Check determinized FSA is equivalent to its input.";
  631. DeterminizeFst<Arc> D(A);
  632. CHECK(Equiv(A, D));
  633. {
  634. VLOG(1) << "Check determinized FST is equivalent to its input.";
  635. DeterminizeFstOptions<Arc> opts;
  636. opts.type = DETERMINIZE_NONFUNCTIONAL;
  637. DeterminizeFst<Arc> DT(T, opts);
  638. CHECK(Equiv(T, DT));
  639. }
  640. if ((wprops & (kPath | kCommutative)) == (kPath | kCommutative)) {
  641. VLOG(1) << "Check pruning in determinization";
  642. VectorFst<Arc> P;
  643. const Weight threshold = generate_();
  644. DeterminizeOptions<Arc> opts;
  645. opts.weight_threshold = threshold;
  646. Determinize(A, &P, opts);
  647. CHECK(P.Properties(kIDeterministic, true));
  648. CHECK(PruneEquiv(A, P, threshold));
  649. }
  650. if ((wprops & kPath) == kPath) {
  651. VLOG(1) << "Check min-determinization";
  652. // Ensures no input epsilons
  653. VectorFst<Arc> R(T);
  654. std::vector<std::pair<Label, Label>> ipairs, opairs;
  655. ipairs.push_back(std::pair<Label, Label>(0, 1));
  656. Relabel(&R, ipairs, opairs);
  657. VectorFst<Arc> M;
  658. DeterminizeOptions<Arc> opts;
  659. opts.type = DETERMINIZE_DISAMBIGUATE;
  660. Determinize(R, &M, opts);
  661. CHECK(M.Properties(kIDeterministic, true));
  662. CHECK(MinRelated(M, R));
  663. }
  664. int n;
  665. {
  666. VLOG(1) << "Check size(min(det(A))) <= size(det(A))"
  667. << " and min(det(A)) equiv det(A)";
  668. VectorFst<Arc> M(D);
  669. n = M.NumStates();
  670. Minimize(&M, static_cast<MutableFst<Arc> *>(nullptr), kDelta);
  671. CHECK(Equiv(D, M));
  672. CHECK(M.NumStates() <= n);
  673. n = M.NumStates();
  674. }
  675. if (n && (wprops & kIdempotent) == kIdempotent &&
  676. A.Properties(kNoEpsilons, true)) {
  677. VLOG(1) << "Check that Revuz's algorithm leads to the"
  678. << " same number of states as Brozozowski's algorithm";
  679. // Skip test if A is the empty machine or contains epsilons or
  680. // if the semiring is not idempotent (to avoid floating point
  681. // errors)
  682. VectorFst<ReverseArc<Arc>> R;
  683. Reverse(A, &R);
  684. RmEpsilon(&R);
  685. DeterminizeFst<ReverseArc<Arc>> DR(R);
  686. VectorFst<Arc> RD;
  687. Reverse(DR, &RD);
  688. DeterminizeFst<Arc> DRD(RD);
  689. VectorFst<Arc> M(DRD);
  690. CHECK_EQ(n + 1, M.NumStates()); // Accounts for the epsilon transition
  691. // to the initial state
  692. }
  693. }
  694. if ((wprops & kSemiring) == kSemiring && tprops & kAcyclic) {
  695. VLOG(1) << "Check disambiguated FSA is equivalent to its input.";
  696. VectorFst<Arc> R(A), D;
  697. RmEpsilon(&R);
  698. Disambiguate(R, &D);
  699. CHECK(Equiv(R, D));
  700. VLOG(1) << "Check disambiguated FSA is unambiguous";
  701. CHECK(Unambiguous(D));
  702. /* TODO(riley): find out why this fails
  703. if ((wprops & (kPath | kCommutative)) == (kPath | kCommutative)) {
  704. VLOG(1) << "Check pruning in disambiguation";
  705. VectorFst<Arc> P;
  706. const Weight threshold = generate_();
  707. DisambiguateOptions<Arc> opts;
  708. opts.weight_threshold = threshold;
  709. Disambiguate(R, &P, opts);
  710. CHECK(Unambiguous(P));
  711. CHECK(PruneEquiv(A, P, threshold));
  712. }
  713. */
  714. }
  715. if (Arc::Type() == LogArc::Type() || Arc::Type() == StdArc::Type()) {
  716. VLOG(1) << "Check reweight(T) equiv T";
  717. std::vector<Weight> potential;
  718. VectorFst<Arc> RI(T);
  719. VectorFst<Arc> RF(T);
  720. while (potential.size() < RI.NumStates()) {
  721. potential.push_back(generate_());
  722. }
  723. Reweight(&RI, potential, REWEIGHT_TO_INITIAL);
  724. CHECK(Equiv(T, RI));
  725. Reweight(&RF, potential, REWEIGHT_TO_FINAL);
  726. CHECK(Equiv(T, RF));
  727. }
  728. if ((wprops & kIdempotent) || (tprops & kAcyclic)) {
  729. VLOG(1) << "Check pushed FST is equivalent to input FST.";
  730. // Pushing towards the final state.
  731. if (wprops & kRightSemiring) {
  732. VectorFst<Arc> P1;
  733. Push<Arc, REWEIGHT_TO_FINAL>(T, &P1, kPushLabels);
  734. CHECK(Equiv(T, P1));
  735. VectorFst<Arc> P2;
  736. Push<Arc, REWEIGHT_TO_FINAL>(T, &P2, kPushWeights);
  737. CHECK(Equiv(T, P2));
  738. VectorFst<Arc> P3;
  739. Push<Arc, REWEIGHT_TO_FINAL>(T, &P3, kPushLabels | kPushWeights);
  740. CHECK(Equiv(T, P3));
  741. }
  742. // Pushing towards the initial state.
  743. if (wprops & kLeftSemiring) {
  744. VectorFst<Arc> P1;
  745. Push<Arc, REWEIGHT_TO_INITIAL>(T, &P1, kPushLabels);
  746. CHECK(Equiv(T, P1));
  747. VectorFst<Arc> P2;
  748. Push<Arc, REWEIGHT_TO_INITIAL>(T, &P2, kPushWeights);
  749. CHECK(Equiv(T, P2));
  750. VectorFst<Arc> P3;
  751. Push<Arc, REWEIGHT_TO_INITIAL>(T, &P3, kPushLabels | kPushWeights);
  752. CHECK(Equiv(T, P3));
  753. }
  754. }
  755. if constexpr (IsPath<Weight>::value) {
  756. if ((wprops & (kPath | kCommutative)) == (kPath | kCommutative)) {
  757. VLOG(1) << "Check pruning algorithm";
  758. {
  759. VLOG(1) << "Check equiv. of constructive and destructive algorithms";
  760. const Weight threshold = generate_();
  761. VectorFst<Arc> P1(T);
  762. Prune(&P1, threshold);
  763. VectorFst<Arc> P2;
  764. Prune(T, &P2, threshold);
  765. CHECK(Equiv(P1, P2));
  766. }
  767. {
  768. VLOG(1) << "Check prune(reverse) equiv reverse(prune)";
  769. const Weight threshold = generate_();
  770. VectorFst<ReverseArc<Arc>> R;
  771. VectorFst<Arc> P1(T);
  772. VectorFst<Arc> P2;
  773. Prune(&P1, threshold);
  774. Reverse(T, &R);
  775. Prune(&R, threshold.Reverse());
  776. Reverse(R, &P2);
  777. CHECK(Equiv(P1, P2));
  778. }
  779. {
  780. VLOG(1) << "Check: ShortestDistance(A - prune(A))"
  781. << " > ShortestDistance(A) times Threshold";
  782. const Weight threshold = generate_();
  783. VectorFst<Arc> P;
  784. Prune(A, &P, threshold);
  785. CHECK(PruneEquiv(A, P, threshold));
  786. }
  787. }
  788. }
  789. if (tprops & kAcyclic) {
  790. VLOG(1) << "Check synchronize(T) equiv T";
  791. SynchronizeFst<Arc> S(T);
  792. CHECK(Equiv(T, S));
  793. }
  794. }
  795. // Tests search operations
  796. void TestSearch(const Fst<Arc> &T) {
  797. if constexpr (IsPath<Weight>::value) {
  798. uint64_t wprops = Weight::Properties();
  799. VectorFst<Arc> A(T);
  800. Project(&A, ProjectType::INPUT);
  801. if ((wprops & (kPath | kRightSemiring)) == (kPath | kRightSemiring)) {
  802. VLOG(1) << "Check 1-best weight.";
  803. VectorFst<Arc> path;
  804. ShortestPath(T, &path);
  805. Weight tsum = ShortestDistance(T);
  806. Weight psum = ShortestDistance(path);
  807. CHECK(ApproxEqual(tsum, psum, kTestDelta));
  808. }
  809. if ((wprops & (kPath | kSemiring)) == (kPath | kSemiring)) {
  810. VLOG(1) << "Check n-best weights";
  811. VectorFst<Arc> R(A);
  812. RmEpsilon(&R, /*connect=*/true, Arc::Weight::Zero(), kNoStateId,
  813. kDelta);
  814. const int nshortest = std::uniform_int_distribution<>(
  815. 0, kNumRandomShortestPaths + 1)(rand_);
  816. VectorFst<Arc> paths;
  817. ShortestPath(R, &paths, nshortest, /*unique=*/true,
  818. /*first_path=*/false, Weight::Zero(), kNumShortestStates,
  819. kDelta);
  820. std::vector<Weight> distance;
  821. ShortestDistance(paths, &distance, true, kDelta);
  822. StateId pstart = paths.Start();
  823. if (pstart != kNoStateId) {
  824. ArcIterator<Fst<Arc>> piter(paths, pstart);
  825. for (; !piter.Done(); piter.Next()) {
  826. StateId s = piter.Value().nextstate;
  827. Weight nsum = s < distance.size()
  828. ? Times(piter.Value().weight, distance[s])
  829. : Weight::Zero();
  830. VectorFst<Arc> path;
  831. ShortestPath(R, &path, 1, false, false, Weight::Zero(), kNoStateId,
  832. kDelta);
  833. Weight dsum = ShortestDistance(path, kDelta);
  834. CHECK(ApproxEqual(nsum, dsum, kTestDelta));
  835. ArcMap(&path, RmWeightMapper<Arc>());
  836. VectorFst<Arc> S;
  837. Difference(R, path, &S);
  838. R = S;
  839. }
  840. }
  841. }
  842. }
  843. }
  844. // Tests if two FSTS are equivalent by checking if random
  845. // strings from one FST are transduced the same by both FSTs.
  846. template <class A>
  847. bool Equiv(const Fst<A> &fst1, const Fst<A> &fst2) {
  848. VLOG(1) << "Check FSTs for sanity (including property bits).";
  849. CHECK(Verify(fst1));
  850. CHECK(Verify(fst2));
  851. // Ensures seed used once per instantiation.
  852. static const UniformArcSelector<A> uniform_selector(seed_);
  853. const RandGenOptions<UniformArcSelector<A>> opts(uniform_selector,
  854. kRandomPathLength);
  855. return RandEquivalent(fst1, fst2, kNumRandomPaths, opts, kTestDelta, seed_);
  856. }
  857. // Tests FSA is unambiguous.
  858. bool Unambiguous(const Fst<Arc> &fst) {
  859. VectorFst<StdArc> sfst, dfst;
  860. VectorFst<LogArc> lfst1, lfst2;
  861. ArcMap(fst, &sfst, RmWeightMapper<Arc, StdArc>());
  862. Determinize(sfst, &dfst);
  863. ArcMap(fst, &lfst1, RmWeightMapper<Arc, LogArc>());
  864. ArcMap(dfst, &lfst2, RmWeightMapper<StdArc, LogArc>());
  865. return Equiv(lfst1, lfst2);
  866. }
  867. // Ensures input-epsilon free transducers fst1 and fst2 have the
  868. // same domain and that for each string pair '(is, os)' in fst1,
  869. // '(is, os)' is the minimum weight match to 'is' in fst2.
  870. template <class A>
  871. bool MinRelated(const Fst<A> &fst1, const Fst<A> &fst2) {
  872. // Same domain
  873. VectorFst<Arc> P1(fst1), P2(fst2);
  874. Project(&P1, ProjectType::INPUT);
  875. Project(&P2, ProjectType::INPUT);
  876. if (!Equiv(P1, P2)) {
  877. LOG(ERROR) << "Inputs not equivalent";
  878. return false;
  879. }
  880. // Ensures seed used once per instantiation.
  881. static const UniformArcSelector<A> uniform_selector(seed_);
  882. const RandGenOptions<UniformArcSelector<A>> opts(uniform_selector,
  883. kRandomPathLength);
  884. VectorFst<Arc> path, paths1, paths2;
  885. for (ssize_t n = 0; n < kNumRandomPaths; ++n) {
  886. RandGen(fst1, &path, opts);
  887. Invert(&path);
  888. ArcMap(&path, RmWeightMapper<Arc>());
  889. Compose(path, fst2, &paths1);
  890. Weight sum1 = ShortestDistance(paths1);
  891. Compose(paths1, path, &paths2);
  892. Weight sum2 = ShortestDistance(paths2);
  893. if (!ApproxEqual(Plus(sum1, sum2), sum2, kTestDelta)) {
  894. LOG(ERROR) << "Sums not equivalent: " << sum1 << " " << sum2;
  895. return false;
  896. }
  897. }
  898. return true;
  899. }
  900. // Tests ShortestDistance(A - P) >= ShortestDistance(A) times Threshold.
  901. template <class A>
  902. bool PruneEquiv(const Fst<A> &fst, const Fst<A> &pfst, Weight threshold) {
  903. VLOG(1) << "Check FSTs for sanity (including property bits).";
  904. CHECK(Verify(fst));
  905. CHECK(Verify(pfst));
  906. DifferenceFst<Arc> D(fst, DeterminizeFst<Arc>(RmEpsilonFst<Arc>(
  907. ArcMapFst(pfst, RmWeightMapper<Arc>()))));
  908. const Weight sum1 = Times(ShortestDistance(fst), threshold);
  909. const Weight sum2 = ShortestDistance(D);
  910. return ApproxEqual(Plus(sum1, sum2), sum1, kTestDelta);
  911. }
  912. // Random seed; the UniformArcSelector used for random path sampling is constructed from it.
  913. uint64_t seed_;
  914. // Random state (for randomness in this class).
  915. std::mt19937_64 rand_;
  916. // FST with no states.
  917. VectorFst<Arc> zero_fst_;
  918. // FST with one state that accepts epsilon.
  919. VectorFst<Arc> one_fst_;
  920. // FST with one state that accepts all strings.
  921. VectorFst<Arc> univ_fst_;
  922. // Generates weights used in testing.
  923. WeightGenerator generate_;
  924. // Maximum random path length (the length bound given to RandGenOptions).
  925. static constexpr int kRandomPathLength = 25;
  926. // Number of random paths to explore in each randomized check.
  927. static constexpr int kNumRandomPaths = 100;
  928. // Maximum number of n-shortest paths.
  929. static constexpr int kNumRandomShortestPaths = 100;
  930. // Maximum number of n-shortest states.
  931. static constexpr int kNumShortestStates = 10000;
  932. // Delta for equivalence tests (the tolerance passed to ApproxEqual and RandEquivalent).
  933. static constexpr float kTestDelta = .05;
  934. WeightedTester(const WeightedTester &) = delete;  // Not copy-constructible.
  935. WeightedTester &operator=(const WeightedTester &) = delete;  // Not copy-assignable.
  936. };
  937. // This class tests a variety of identities and properties that must
  938. // hold for various algorithms on unweighted FSAs and that are not tested
  939. // by WeightedTester. Only the specialization does anything interesting.
  940. template <class Arc>
  941. class UnweightedTester {
  942. public:
  943. UnweightedTester(const Fst<Arc> &zero_fsa, const Fst<Arc> &one_fsa,
  944. const Fst<Arc> &univ_fsa, uint64_t seed) {}
  945. void Test(const Fst<Arc> &A1, const Fst<Arc> &A2, const Fst<Arc> &A3) {}
  946. };
  947. // Specialization for StdArc. This should work for any commutative,
  948. // idempotent semiring when restricted to the unweighted case
  949. // (being isomorphic to the boolean semiring).
  950. template <>
  951. class UnweightedTester<StdArc> {
  952. public:
  953. using Arc = StdArc;
  954. using Label = Arc::Label;
  955. using StateId = Arc::StateId;
  956. using Weight = Arc::Weight;
  957. UnweightedTester(const Fst<Arc> &zero_fsa, const Fst<Arc> &one_fsa,
  958. const Fst<Arc> &univ_fsa, uint64_t seed)
  959. : zero_fsa_(zero_fsa),
  960. one_fsa_(one_fsa),
  961. univ_fsa_(univ_fsa),
  962. rand_(seed) {}
  963. void Test(const Fst<Arc> &A1, const Fst<Arc> &A2, const Fst<Arc> &A3) {
  964. TestRational(A1, A2, A3);
  965. TestIntersect(A1, A2, A3);
  966. TestOptimize(A1);
  967. }
  968. private:
  969. // Tests rational operations with identities.
  970. void TestRational(const Fst<Arc> &A1, const Fst<Arc> &A2,
  971. const Fst<Arc> &A3) {
  972. {
  973. VLOG(1) << "Check the union contains its arguments (destructive).";
  974. VectorFst<Arc> U(A1);
  975. Union(&U, A2);
  976. CHECK(Subset(A1, U));
  977. CHECK(Subset(A2, U));
  978. }
  979. {
  980. VLOG(1) << "Check the union contains its arguments (delayed).";
  981. UnionFst<Arc> U(A1, A2);
  982. CHECK(Subset(A1, U));
  983. CHECK(Subset(A2, U));
  984. }
  985. {
  986. VLOG(1) << "Check if A^n c A* (destructive).";
  987. VectorFst<Arc> C(one_fsa_);
  988. const int n = std::uniform_int_distribution<>(0, 4)(rand_);
  989. for (int i = 0; i < n; ++i) Concat(&C, A1);
  990. VectorFst<Arc> S(A1);
  991. Closure(&S, CLOSURE_STAR);
  992. CHECK(Subset(C, S));
  993. }
  994. {
  995. VLOG(1) << "Check if A^n c A* (delayed).";
  996. const int n = std::uniform_int_distribution<>(0, 4)(rand_);
  997. std::unique_ptr<Fst<Arc>> C = std::make_unique<VectorFst<Arc>>(one_fsa_);
  998. for (int i = 0; i < n; ++i) {
  999. C = std::make_unique<ConcatFst<Arc>>(*C, A1);
  1000. }
  1001. ClosureFst<Arc> S(A1, CLOSURE_STAR);
  1002. CHECK(Subset(*C, S));
  1003. }
  1004. }
  1005. // Tests intersect-based operations.
  1006. void TestIntersect(const Fst<Arc> &A1, const Fst<Arc> &A2,
  1007. const Fst<Arc> &A3) {
  1008. VectorFst<Arc> S1(A1);
  1009. VectorFst<Arc> S2(A2);
  1010. VectorFst<Arc> S3(A3);
  1011. ILabelCompare<Arc> comp;
  1012. ArcSort(&S1, comp);
  1013. ArcSort(&S2, comp);
  1014. ArcSort(&S3, comp);
  1015. {
  1016. VLOG(1) << "Check the intersection is contained in its arguments.";
  1017. IntersectFst<Arc> I1(S1, S2);
  1018. CHECK(Subset(I1, S1));
  1019. CHECK(Subset(I1, S2));
  1020. }
  1021. {
  1022. VLOG(1) << "Check union distributes over intersection.";
  1023. IntersectFst<Arc> I1(S1, S2);
  1024. UnionFst<Arc> U1(I1, S3);
  1025. UnionFst<Arc> U2(S1, S3);
  1026. UnionFst<Arc> U3(S2, S3);
  1027. ArcSortFst<Arc, ILabelCompare<Arc>> S4(U3, comp);
  1028. IntersectFst<Arc> I2(U2, S4);
  1029. CHECK(Equiv(U1, I2));
  1030. }
  1031. VectorFst<Arc> C1;
  1032. VectorFst<Arc> C2;
  1033. Complement(S1, &C1);
  1034. Complement(S2, &C2);
  1035. ArcSort(&C1, comp);
  1036. ArcSort(&C2, comp);
  1037. {
  1038. VLOG(1) << "Check S U S' = Sigma*";
  1039. UnionFst<Arc> U(S1, C1);
  1040. CHECK(Equiv(U, univ_fsa_));
  1041. }
  1042. {
  1043. VLOG(1) << "Check S n S' = {}";
  1044. IntersectFst<Arc> I(S1, C1);
  1045. CHECK(Equiv(I, zero_fsa_));
  1046. }
  1047. {
  1048. VLOG(1) << "Check (S1' U S2') == (S1 n S2)'";
  1049. UnionFst<Arc> U(C1, C2);
  1050. IntersectFst<Arc> I(S1, S2);
  1051. VectorFst<Arc> C3;
  1052. Complement(I, &C3);
  1053. CHECK(Equiv(U, C3));
  1054. }
  1055. {
  1056. VLOG(1) << "Check (S1' n S2') == (S1 U S2)'";
  1057. IntersectFst<Arc> I(C1, C2);
  1058. UnionFst<Arc> U(S1, S2);
  1059. VectorFst<Arc> C3;
  1060. Complement(U, &C3);
  1061. CHECK(Equiv(I, C3));
  1062. }
  1063. }
  1064. // Tests optimization operations.
  1065. void TestOptimize(const Fst<Arc> &A) {
  1066. {
  1067. VLOG(1) << "Check determinized FSA is equivalent to its input.";
  1068. DeterminizeFst<Arc> D(A);
  1069. CHECK(Equiv(A, D));
  1070. }
  1071. {
  1072. VLOG(1) << "Check disambiguated FSA is equivalent to its input.";
  1073. VectorFst<Arc> R(A), D;
  1074. RmEpsilon(&R);
  1075. Disambiguate(R, &D);
  1076. CHECK(Equiv(R, D));
  1077. }
  1078. {
  1079. VLOG(1) << "Check minimized FSA is equivalent to its input.";
  1080. int n;
  1081. {
  1082. RmEpsilonFst<Arc> R(A);
  1083. DeterminizeFst<Arc> D(R);
  1084. VectorFst<Arc> M(D);
  1085. Minimize(&M, static_cast<MutableFst<Arc> *>(nullptr), kDelta);
  1086. CHECK(Equiv(A, M));
  1087. n = M.NumStates();
  1088. }
  1089. if (n) { // Skips test if A is the empty machine.
  1090. VLOG(1) << "Check that Hopcroft's and Revuz's algorithms lead to the"
  1091. << " same number of states as Brozozowski's algorithm";
  1092. VectorFst<Arc> R;
  1093. Reverse(A, &R);
  1094. RmEpsilon(&R);
  1095. DeterminizeFst<Arc> DR(R);
  1096. VectorFst<Arc> RD;
  1097. Reverse(DR, &RD);
  1098. DeterminizeFst<Arc> DRD(RD);
  1099. VectorFst<Arc> M(DRD);
  1100. CHECK_EQ(n + 1, M.NumStates()); // Accounts for the epsilon transition
  1101. // to the initial state.
  1102. }
  1103. }
  1104. }
  1105. // Tests if two FSAS are equivalent.
  1106. bool Equiv(const Fst<Arc> &fsa1, const Fst<Arc> &fsa2) {
  1107. VLOG(1) << "Check FSAs for sanity (including property bits).";
  1108. CHECK(Verify(fsa1));
  1109. CHECK(Verify(fsa2));
  1110. VectorFst<Arc> vfsa1(fsa1);
  1111. VectorFst<Arc> vfsa2(fsa2);
  1112. RmEpsilon(&vfsa1);
  1113. RmEpsilon(&vfsa2);
  1114. DeterminizeFst<Arc> dfa1(vfsa1);
  1115. DeterminizeFst<Arc> dfa2(vfsa2);
  1116. // Test equivalence using union-find algorithm
  1117. bool equiv1 = Equivalent(dfa1, dfa2);
  1118. // Test equivalence by checking if (S1 - S2) U (S2 - S1) is empty
  1119. ILabelCompare<Arc> comp;
  1120. VectorFst<Arc> sdfa1(dfa1);
  1121. ArcSort(&sdfa1, comp);
  1122. VectorFst<Arc> sdfa2(dfa2);
  1123. ArcSort(&sdfa2, comp);
  1124. DifferenceFst<Arc> dfsa1(sdfa1, sdfa2);
  1125. DifferenceFst<Arc> dfsa2(sdfa2, sdfa1);
  1126. VectorFst<Arc> ufsa(dfsa1);
  1127. Union(&ufsa, dfsa2);
  1128. Connect(&ufsa);
  1129. bool equiv2 = ufsa.NumStates() == 0;
  1130. // Checks both equivalence tests match.
  1131. CHECK((equiv1 && equiv2) || (!equiv1 && !equiv2));
  1132. return equiv1;
  1133. }
  1134. // Tests if FSA1 is a subset of FSA2 (disregarding weights).
  1135. bool Subset(const Fst<Arc> &fsa1, const Fst<Arc> &fsa2) {
  1136. VLOG(1) << "Check FSAs (incl. property bits) for sanity";
  1137. CHECK(Verify(fsa1));
  1138. CHECK(Verify(fsa2));
  1139. VectorFst<StdArc> vfsa1;
  1140. VectorFst<StdArc> vfsa2;
  1141. RmEpsilon(&vfsa1);
  1142. RmEpsilon(&vfsa2);
  1143. ILabelCompare<StdArc> comp;
  1144. ArcSort(&vfsa1, comp);
  1145. ArcSort(&vfsa2, comp);
  1146. IntersectFst<StdArc> ifsa(vfsa1, vfsa2);
  1147. DeterminizeFst<StdArc> dfa1(vfsa1);
  1148. DeterminizeFst<StdArc> dfa2(ifsa);
  1149. return Equivalent(dfa1, dfa2);
  1150. }
  1151. // Returns complement FSA.
  1152. void Complement(const Fst<Arc> &ifsa, MutableFst<Arc> *ofsa) {
  1153. RmEpsilonFst<Arc> rfsa(ifsa);
  1154. DeterminizeFst<Arc> dfa(rfsa);
  1155. DifferenceFst<Arc> cfsa(univ_fsa_, dfa);
  1156. *ofsa = cfsa;
  1157. }
  1158. // FSA with no states.
  1159. VectorFst<Arc> zero_fsa_;
  1160. // FSA with one state that accepts epsilon.
  1161. VectorFst<Arc> one_fsa_;
  1162. // FSA with one state that accepts all strings.
  1163. VectorFst<Arc> univ_fsa_;
  1164. // Random state.
  1165. std::mt19937_64 rand_;
  1166. };
  1167. // This class tests a variety of identities and properties that must
  1168. // hold for various FST algorithms. It randomly generates FSTs, using
  1169. // function object 'weight_generator' to select weights. 'WeightTester'
  1170. // and 'UnweightedTester' are then called.
  1171. template <class Arc>
  1172. class AlgoTester {
  1173. public:
  1174. using Label = typename Arc::Label;
  1175. using StateId = typename Arc::StateId;
  1176. using Weight = typename Arc::Weight;
  1177. using WeightGenerator = WeightGenerate<Weight>;
  1178. AlgoTester(WeightGenerator generator, uint64_t seed)
  1179. : generate_(std::move(generator)), rand_(seed) {
  1180. one_fst_.AddState();
  1181. one_fst_.SetStart(0);
  1182. one_fst_.SetFinal(0);
  1183. univ_fst_.AddState();
  1184. univ_fst_.SetStart(0);
  1185. univ_fst_.SetFinal(0);
  1186. for (int i = 0; i < kNumRandomLabels; ++i) univ_fst_.EmplaceArc(0, i, i, 0);
  1187. weighted_tester_.reset(new WeightedTester<Arc>(seed, zero_fst_, one_fst_,
  1188. univ_fst_, generate_));
  1189. unweighted_tester_.reset(
  1190. new UnweightedTester<Arc>(zero_fst_, one_fst_, univ_fst_, seed));
  1191. }
  1192. void MakeRandFst(MutableFst<Arc> *fst) {
  1193. RandFst<Arc, WeightGenerator>(kNumRandomStates, kNumRandomArcs,
  1194. kNumRandomLabels, kAcyclicProb, generate_,
  1195. rand_(), fst);
  1196. }
  1197. void Test() {
  1198. VLOG(1) << "weight type = " << Weight::Type();
  1199. for (int i = 0; i < FST_FLAGS_repeat; ++i) {
  1200. // Random transducers
  1201. VectorFst<Arc> T1;
  1202. VectorFst<Arc> T2;
  1203. VectorFst<Arc> T3;
  1204. MakeRandFst(&T1);
  1205. MakeRandFst(&T2);
  1206. MakeRandFst(&T3);
  1207. weighted_tester_->Test(T1, T2, T3);
  1208. VectorFst<Arc> A1(T1);
  1209. VectorFst<Arc> A2(T2);
  1210. VectorFst<Arc> A3(T3);
  1211. Project(&A1, ProjectType::OUTPUT);
  1212. Project(&A2, ProjectType::INPUT);
  1213. Project(&A3, ProjectType::INPUT);
  1214. ArcMap(&A1, rm_weight_mapper_);
  1215. ArcMap(&A2, rm_weight_mapper_);
  1216. ArcMap(&A3, rm_weight_mapper_);
  1217. unweighted_tester_->Test(A1, A2, A3);
  1218. }
  1219. }
  1220. private:
  1221. // Generates weights used in testing.
  1222. WeightGenerator generate_;
  1223. // Random state used to seed RandFst.
  1224. std::mt19937_64 rand_;
  1225. // FST with no states
  1226. VectorFst<Arc> zero_fst_;
  1227. // FST with one state that accepts epsilon.
  1228. VectorFst<Arc> one_fst_;
  1229. // FST with one state that accepts all strings.
  1230. VectorFst<Arc> univ_fst_;
  1231. // Tests weighted FSTs
  1232. std::unique_ptr<WeightedTester<Arc>> weighted_tester_;
  1233. // Tests unweighted FSTs
  1234. std::unique_ptr<UnweightedTester<Arc>> unweighted_tester_;
  1235. // Mapper to remove weights from an Fst
  1236. RmWeightMapper<Arc> rm_weight_mapper_;
  1237. // Maximum number of states in random test Fst.
  1238. static constexpr int kNumRandomStates = 10;
  1239. // Maximum number of arcs in random test Fst.
  1240. static constexpr int kNumRandomArcs = 25;
  1241. // Number of alternative random labels.
  1242. static constexpr int kNumRandomLabels = 5;
  1243. // Probability to force an acyclic Fst
  1244. static constexpr float kAcyclicProb = .25;
  1245. // Maximum random path length.
  1246. static constexpr int kRandomPathLength = 25;
  1247. // Number of random paths to explore.
  1248. static constexpr int kNumRandomPaths = 100;
  1249. AlgoTester(const AlgoTester &) = delete;
  1250. AlgoTester &operator=(const AlgoTester &) = delete;
  1251. };
  1252. } // namespace fst
  1253. #endif // FST_TEST_ALGO_TEST_H_