You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

348 lines
10 KiB

  1. // Copyright 2005-2024 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the 'License');
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an 'AS IS' BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // See www.openfst.org for extensive documentation on this weighted
  16. // finite-state transducer library.
  17. //
  18. // Regression test for FST classes.
  19. #ifndef FST_TEST_FST_TEST_H_
  20. #define FST_TEST_FST_TEST_H_
  21. #include <cstddef>
  22. #include <memory>
  23. #include <string>
  24. #include <fst/log.h>
  25. #include <fst/equal.h>
  26. #include <fst/expanded-fst.h>
  27. #include <fstream>
  28. #include <fst/fst.h>
  29. #include <fst/matcher.h>
  30. #include <fst/mutable-fst.h>
  31. #include <fst/properties.h>
  32. #include <fst/vector-fst.h>
  33. #include <fst/verify.h>
  34. namespace fst {
  35. // This tests an Fst F that is assumed to have a copy method from an
  36. // arbitrary Fst. Some test functions make further assumptions mostly
  37. // obvious from their name. These tests are written as member template
  38. // functions that take a test FST as their argument so that different
  39. // Fsts in the interface hierarchy can be tested separately and so
  40. // that we can instantiate only those tests that make sense for a
  41. // particular Fst.
  42. template <class F>
  43. class FstTester {
  44. public:
  45. using Arc = typename F::Arc;
  46. using StateId = typename Arc::StateId;
  47. using Weight = typename Arc::Weight;
  48. using Label = typename Arc::Label;
  49. explicit FstTester(size_t num_states = 128, bool weighted = true)
  50. : num_states_(num_states), weighted_(weighted) {
  51. VectorFst<Arc> vfst;
  52. InitFst(&vfst, num_states);
  53. testfst_ = std::make_unique<F>(vfst);
  54. }
  55. // This verifies the contents described in InitFst() using
  56. // methods defined in a generic Fst.
  57. template <class G>
  58. void TestBase(const G &fst) const {
  59. StateId ns = 0;
  60. StateIterator<G> siter(fst);
  61. Matcher<G> matcher(fst, MATCH_INPUT);
  62. MatchType match_type = matcher.Type(true);
  63. bool has_states = false;
  64. for (; !siter.Done(); siter.Next()) {
  65. has_states = true;
  66. }
  67. CHECK_EQ(fst.Start(), has_states ? 0 : kNoStateId);
  68. for (siter.Reset(); !siter.Done(); siter.Next()) {
  69. StateId s = siter.Value();
  70. matcher.SetState(s);
  71. CHECK_EQ(fst.Final(s), NthWeight(s));
  72. size_t na = 0;
  73. ArcIterator<G> aiter(fst, s);
  74. for (; !aiter.Done(); aiter.Next()) {
  75. }
  76. for (aiter.Reset(); !aiter.Done(); aiter.Next()) {
  77. ++na;
  78. const Arc &arc = aiter.Value();
  79. CHECK_EQ(arc.ilabel, na);
  80. CHECK_EQ(arc.olabel, 0);
  81. CHECK_EQ(arc.weight, NthWeight(na));
  82. if (na == ns + 1) {
  83. CHECK_EQ(arc.nextstate, s == num_states_ - 1 ? 0 : s + 1);
  84. } else {
  85. CHECK_EQ(arc.nextstate, s);
  86. }
  87. if (match_type == MATCH_INPUT) {
  88. CHECK(matcher.Find(arc.ilabel));
  89. CHECK_EQ(matcher.Value().ilabel, arc.ilabel);
  90. }
  91. }
  92. CHECK_EQ(na, s + 1);
  93. CHECK_EQ(na, aiter.Position());
  94. CHECK_EQ(fst.NumArcs(s), s + 1);
  95. CHECK_EQ(fst.NumInputEpsilons(s), 0);
  96. CHECK_EQ(fst.NumOutputEpsilons(s), s + 1);
  97. CHECK(!matcher.Find(s + 2)); // out-of-range
  98. CHECK(!matcher.Find(kNoLabel)); // no explicit input epsilons
  99. CHECK(matcher.Find(0));
  100. CHECK_EQ(matcher.Value().ilabel, kNoLabel); // implicit epsilon loop
  101. ++ns;
  102. }
  103. CHECK_EQ(num_states_, ns);
  104. CHECK(Verify(fst));
  105. CHECK(fst.Properties(ns > 0 ? kNotAcceptor : kAcceptor, true));
  106. CHECK(fst.Properties(ns > 0 ? kOEpsilons : kNoOEpsilons, true));
  107. }
  108. void TestBase() const { TestBase(*testfst_); }
  109. // This verifies methods specfic to an ExpandedFst.
  110. template <class G>
  111. void TestExpanded(const G &fst) const {
  112. CHECK_EQ(fst.NumStates(), num_states_);
  113. StateId ns = 0;
  114. for (StateIterator<G> siter(fst); !siter.Done(); siter.Next()) {
  115. ++ns;
  116. }
  117. CHECK_EQ(fst.NumStates(), ns);
  118. CHECK(fst.Properties(kExpanded, false));
  119. }
  120. void TestExpanded() const { TestExpanded(*testfst_); }
  121. // This verifies methods specific to a MutableFst.
  122. template <class G>
  123. void TestMutable(G *fst) const {
  124. for (StateIterator<G> siter(*fst); !siter.Done(); siter.Next()) {
  125. StateId s = siter.Value();
  126. size_t na = 0;
  127. size_t ni = fst->NumInputEpsilons(s);
  128. MutableArcIterator<G> aiter(fst, s);
  129. for (; !aiter.Done(); aiter.Next()) {
  130. }
  131. for (aiter.Reset(); !aiter.Done(); aiter.Next()) {
  132. ++na;
  133. Arc arc = aiter.Value();
  134. arc.ilabel = 0;
  135. aiter.SetValue(arc);
  136. arc = aiter.Value();
  137. CHECK_EQ(arc.ilabel, 0);
  138. CHECK_EQ(fst->NumInputEpsilons(s), ni + 1);
  139. arc.ilabel = na;
  140. aiter.SetValue(arc);
  141. CHECK_EQ(fst->NumInputEpsilons(s), ni);
  142. }
  143. }
  144. {
  145. std::unique_ptr<G> cfst1(fst->Copy());
  146. cfst1->DeleteStates();
  147. CHECK_EQ(cfst1->NumStates(), 0);
  148. }
  149. std::unique_ptr<G> cfst2(fst->Copy());
  150. for (StateIterator<G> siter(*cfst2); !siter.Done(); siter.Next()) {
  151. StateId s = siter.Value();
  152. cfst2->DeleteArcs(s);
  153. CHECK_EQ(cfst2->NumArcs(s), 0);
  154. CHECK_EQ(cfst2->NumInputEpsilons(s), 0);
  155. CHECK_EQ(cfst2->NumOutputEpsilons(s), 0);
  156. }
  157. }
  158. void TestMutable() { TestMutable(testfst_.get()); }
  159. // This verifies operator=
  160. template <class G>
  161. void TestAssign(const G &fst) const {
  162. // Assignment from G
  163. G afst1;
  164. afst1 = fst;
  165. CHECK(Equal(fst, afst1));
  166. // Assignment from Fst
  167. G afst2;
  168. afst2 = static_cast<const Fst<Arc> &>(fst);
  169. CHECK(Equal(fst, afst2));
  170. // Assignment from self
  171. afst2.operator=(afst2);
  172. CHECK(Equal(fst, afst2));
  173. }
  174. void TestAssign() { TestAssign(*testfst_); }
  175. // This verifies the copy constructor and Copy method.
  176. template <class G>
  177. void TestCopy(const G &fst) const {
  178. // Copy from G
  179. G c1fst(fst);
  180. TestBase(c1fst);
  181. // Copy from Fst
  182. const G c2fst(static_cast<const Fst<Arc> &>(fst));
  183. TestBase(c2fst);
  184. // Copy from self
  185. std::unique_ptr<const G> c3fst(fst.Copy());
  186. TestBase(*c3fst);
  187. }
  188. void TestCopy() const { TestCopy(*testfst_); }
  189. // This verifies the read/write methods.
  190. template <class G>
  191. void TestIO(const G &fst) const {
  192. const std::string filename = FST_FLAGS_tmpdir + "/test.fst";
  193. const std::string aligned =
  194. FST_FLAGS_tmpdir + "/aligned.fst";
  195. {
  196. // write/read
  197. CHECK(fst.Write(filename));
  198. auto ffst = fst::WrapUnique(G::Read(filename));
  199. CHECK(ffst);
  200. TestBase(*ffst);
  201. }
  202. {
  203. // generic read/cast/test
  204. auto gfst = fst::WrapUnique(Fst<Arc>::Read(filename));
  205. CHECK(gfst);
  206. G *dfst = down_cast<G *>(gfst.get());
  207. TestBase(*dfst);
  208. // generic write/read/test
  209. CHECK(gfst->Write(filename));
  210. auto hfst = fst::WrapUnique(Fst<Arc>::Read(filename));
  211. CHECK(hfst);
  212. TestBase(*hfst);
  213. }
  214. {
  215. // check mmaping by first writing the file with the aligned attribute set
  216. {
  217. std::ofstream ostr(aligned);
  218. FstWriteOptions opts;
  219. opts.source = aligned;
  220. opts.align = true;
  221. CHECK(fst.Write(ostr, opts));
  222. }
  223. std::ifstream istr(aligned);
  224. FstReadOptions opts;
  225. opts.mode = FstReadOptions::ReadMode("map");
  226. opts.source = aligned;
  227. auto gfst = fst::WrapUnique(G::Read(istr, opts));
  228. CHECK(gfst);
  229. TestBase(*gfst);
  230. }
  231. // check mmaping of unaligned files to make sure it does not fail.
  232. {
  233. {
  234. std::ofstream ostr(aligned);
  235. FstWriteOptions opts;
  236. opts.source = aligned;
  237. opts.align = false;
  238. CHECK(fst.Write(ostr, opts));
  239. }
  240. std::ifstream istr(aligned);
  241. FstReadOptions opts;
  242. opts.mode = FstReadOptions::ReadMode("map");
  243. opts.source = aligned;
  244. auto gfst = fst::WrapUnique(G::Read(istr, opts));
  245. CHECK(gfst);
  246. TestBase(*gfst);
  247. }
  248. // expanded write/read/test
  249. if (fst.Properties(kExpanded, false)) {
  250. auto efst = fst::WrapUnique(ExpandedFst<Arc>::Read(filename));
  251. CHECK(efst);
  252. TestBase(*efst);
  253. TestExpanded(*efst);
  254. }
  255. // mutable write/read/test
  256. if (fst.Properties(kMutable, false)) {
  257. auto mfst = fst::WrapUnique(MutableFst<Arc>::Read(filename));
  258. CHECK(mfst);
  259. TestBase(*mfst);
  260. TestExpanded(*mfst);
  261. TestMutable(mfst.get());
  262. }
  263. }
  264. void TestIO() const { TestIO(*testfst_); }
  265. private:
  266. // This constructs test FSTs. Given a mutable FST, will leave
  267. // the FST as follows:
  268. // (I) NumStates() = nstates
  269. // (II) Start() = 0
  270. // (III) Final(s) = NthWeight(s)
  271. // (IV) For state s:
  272. // (a) NumArcs(s) == s + 1
  273. // (b) For ith arc (i: 1 to s) of s:
  274. // (1) ilabel = i
  275. // (2) olabel = 0
  276. // (3) weight = NthWeight(i)
  277. // (4) nextstate = s
  278. // (c) s+1st arc of s:
  279. // (1) ilabel = s + 1
  280. // (2) olabel = 0
  281. // (3) weight = NthWeight(s + 1)
  282. // (4) nextstate = s + 1 if s < nstates - 1
  283. // 0 if s == nstates - 1
  284. void InitFst(MutableFst<Arc> *fst, size_t nstates) const {
  285. fst->DeleteStates();
  286. for (StateId s = 0; s < nstates; ++s) {
  287. fst->AddState();
  288. fst->SetFinal(s, NthWeight(s));
  289. for (size_t i = 1; i <= s; ++i) {
  290. Arc arc(i, 0, NthWeight(i), s);
  291. fst->AddArc(s, arc);
  292. }
  293. fst->AddArc(
  294. s, Arc(s + 1, 0, NthWeight(s + 1), s == nstates - 1 ? 0 : s + 1));
  295. }
  296. if (nstates > 0) fst->SetStart(0);
  297. }
  298. // Generates One() + ... + One() (n times) if weighted_,
  299. // otherwise One().
  300. Weight NthWeight(int n) const {
  301. if (!weighted_) return Weight::One();
  302. Weight w = Weight::Zero();
  303. for (int i = 0; i < n; ++i) w = Plus(w, Weight::One());
  304. return w;
  305. }
  306. size_t num_states_ = 0;
  307. bool weighted_ = true;
  308. std::unique_ptr<F> testfst_; // what we're testing
  309. };
  310. } // namespace fst
  311. #endif // FST_TEST_FST_TEST_H_