|
|
// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Regression test for various FST algorithms.
#ifndef FST_TEST_ALGO_TEST_H_
#define FST_TEST_ALGO_TEST_H_
#include <sys/types.h>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include <fst/log.h>
#include <fst/arc-map.h>
#include <fst/arc.h>
#include <fst/arcfilter.h>
#include <fst/arcsort.h>
#include <fst/cache.h>
#include <fst/closure.h>
#include <fst/compose-filter.h>
#include <fst/compose.h>
#include <fst/concat.h>
#include <fst/connect.h>
#include <fst/determinize.h>
#include <fst/dfs-visit.h>
#include <fst/difference.h>
#include <fst/disambiguate.h>
#include <fst/encode.h>
#include <fst/equivalent.h>
#include <fst/float-weight.h>
#include <fst/fst.h>
#include <fst/fstlib.h>
#include <fst/intersect.h>
#include <fst/invert.h>
#include <fst/lookahead-matcher.h>
#include <fst/matcher-fst.h>
#include <fst/matcher.h>
#include <fst/minimize.h>
#include <fst/mutable-fst.h>
#include <fst/pair-weight.h>
#include <fst/project.h>
#include <fst/properties.h>
#include <fst/prune.h>
#include <fst/push.h>
#include <fst/randequivalent.h>
#include <fst/randgen.h>
#include <fst/rational.h>
#include <fst/relabel.h>
#include <fst/reverse.h>
#include <fst/reweight.h>
#include <fst/rmepsilon.h>
#include <fst/shortest-distance.h>
#include <fst/shortest-path.h>
#include <fst/string-weight.h>
#include <fst/synchronize.h>
#include <fst/topsort.h>
#include <fst/union-weight.h>
#include <fst/union.h>
#include <fst/vector-fst.h>
#include <fst/verify.h>
#include <fst/weight.h>
#include <fst/test/rand-fst.h>
DECLARE_int32(repeat); // defined in ./algo_test.cc
namespace fst {
// Arc mapper that replaces the input and output label of every transition
// with epsilon (label 0), leaving the weight and destination state untouched.
template <class A>
class EpsMapper {
 public:
  EpsMapper() = default;

  // Returns the epsilon-labeled counterpart of `arc`.
  A operator()(const A &arc) const {
    return A(0, 0, arc.weight, arc.nextstate);
  }

  // Rewrites the property bits so they describe the mapped result: an
  // acceptor that has input/output epsilons and is sorted on both label sides.
  uint64_t Properties(uint64_t props) const {
    // Bits that are dropped because they contradict an all-epsilon labeling.
    const uint64_t invalidated = kNotAcceptor | kNoIEpsilons | kNoOEpsilons |
                                 kNoEpsilons | kNotILabelSorted |
                                 kNotOLabelSorted;
    // Bits that the mapper asserts on its output.
    const uint64_t established = kAcceptor | kIEpsilons | kOEpsilons |
                                 kEpsilons | kILabelSorted | kOLabelSorted;
    return (props & ~invalidated) | established;
  }

  // Reports MAP_NO_SUPERFINAL as the final-weight handling.
  MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }

  // Both symbol tables are reported as MAP_COPY_SYMBOLS.
  MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
  MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
};
// Generic version (no look-ahead): forwards straight to Compose() and writes
// the composition of `ifst1` and `ifst2` into `ofst`.
template <class Arc>
void LookAheadCompose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
                      MutableFst<Arc> *ofst) {
  Compose(ifst1, ifst2, ofst);
}
// Specialized and epsilon olabel acyclic - lookahead.
inline void LookAheadCompose(const Fst<StdArc> &ifst1, const Fst<StdArc> &ifst2, MutableFst<StdArc> *ofst) { std::vector<StdArc::StateId> order; bool acyclic; TopOrderVisitor<StdArc> visitor(&order, &acyclic); DfsVisit(ifst1, &visitor, OutputEpsilonArcFilter<StdArc>()); if (acyclic) { // no ifst1 output epsilon cycles?
StdOLabelLookAheadFst lfst1(ifst1); StdVectorFst lfst2(ifst2); LabelLookAheadRelabeler<StdArc>::Relabel(&lfst2, lfst1, true); Compose(lfst1, lfst2, ofst); } else { Compose(ifst1, ifst2, ofst); } }
// This class tests a variety of identities and properties that must
// hold for various algorithms on weighted FSTs.
//
// The Arc template parameter supplies the Label, StateId and Weight types.
// Random weights needed by the checks come from a WeightGenerate<Weight>
// functor handed to the constructor.
template <class Arc> class WeightedTester { public: using Label = typename Arc::Label; using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; using WeightGenerator = WeightGenerate<Weight>;
// Records `seed`, seeds the class-local random engine with it, copies the
// three reference FSTs into members and moves `weight_generator` into place.
WeightedTester(uint64_t seed, const Fst<Arc> &zero_fst, const Fst<Arc> &one_fst, const Fst<Arc> &univ_fst, WeightGenerator weight_generator) : seed_(seed), rand_(seed), zero_fst_(zero_fst), one_fst_(one_fst), univ_fst_(univ_fst), generate_(std::move(weight_generator)) {}
// Entry point: runs the rational, map, compose, sort, optimize and search
// test groups in that order. T2 and T3 are only used by the rational and
// compose groups; every other group receives T1 alone.
void Test(const Fst<Arc> &T1, const Fst<Arc> &T2, const Fst<Arc> &T3) { TestRational(T1, T2, T3); TestMap(T1); TestCompose(T1, T2, T3); TestSort(T1); TestOptimize(T1); TestSearch(T1); }
private:
// Tests rational operations with identities.
//
// Covers union, concatenation and closure: agreement between the destructive
// and delayed implementations, associativity, distributivity of
// concatenation over union, and the T T* = T* T = T+ identities. The
// distributivity and closure identities run only when the weight type
// reports the corresponding left/right semiring property.
void TestRational(const Fst<Arc> &T1, const Fst<Arc> &T2,
                  const Fst<Arc> &T3) {
  const uint64_t weight_props = Weight::Properties();
  const bool left_semiring = (weight_props & kLeftSemiring) != 0;
  const bool right_semiring = (weight_props & kRightSemiring) != 0;

  // --- Destructive vs. delayed implementations. ---
  {
    VLOG(1) << "Check destructive and delayed union are equivalent.";
    VectorFst<Arc> U1(T1);
    Union(&U1, T2);
    UnionFst<Arc> U2(T1, T2);
    CHECK(Equiv(U1, U2));
  }

  {
    VLOG(1) << "Check destructive and delayed concatenation are equivalent.";
    VectorFst<Arc> C1(T1);
    Concat(&C1, T2);
    ConcatFst<Arc> C2(T1, T2);
    CHECK(Equiv(C1, C2));
    // Same product, built by prepending T1 to a copy of T2.
    VectorFst<Arc> C3(T2);
    Concat(T1, &C3);
    CHECK(Equiv(C3, C2));
  }

  {
    VLOG(1) << "Check destructive and delayed closure* are equivalent.";
    VectorFst<Arc> C1(T1);
    Closure(&C1, CLOSURE_STAR);
    ClosureFst<Arc> C2(T1, CLOSURE_STAR);
    CHECK(Equiv(C1, C2));
  }

  {
    VLOG(1) << "Check destructive and delayed closure+ are equivalent.";
    VectorFst<Arc> C1(T1);
    Closure(&C1, CLOSURE_PLUS);
    ClosureFst<Arc> C2(T1, CLOSURE_PLUS);
    CHECK(Equiv(C1, C2));
  }

  // --- Associativity of union. ---
  {
    VLOG(1) << "Check union is associative (destructive).";
    // (T1 U T2) U T3
    VectorFst<Arc> U1(T1);
    Union(&U1, T2);
    Union(&U1, T3);

    // T1 U (T2 U T3)
    VectorFst<Arc> U3(T2);
    Union(&U3, T3);
    VectorFst<Arc> U4(T1);
    Union(&U4, U3);

    CHECK(Equiv(U1, U4));
  }

  {
    VLOG(1) << "Check union is associative (delayed).";
    UnionFst<Arc> U1(T1, T2);
    UnionFst<Arc> U2(U1, T3);

    UnionFst<Arc> U3(T2, T3);
    UnionFst<Arc> U4(T1, U3);

    CHECK(Equiv(U2, U4));
  }

  {
    VLOG(1) << "Check union is associative (destructive delayed).";
    UnionFst<Arc> U1(T1, T2);
    Union(&U1, T3);

    UnionFst<Arc> U3(T2, T3);
    UnionFst<Arc> U4(T1, U3);

    CHECK(Equiv(U1, U4));
  }

  // --- Associativity of concatenation. ---
  {
    VLOG(1) << "Check concatenation is associative (destructive).";
    // (T1 T2) T3
    VectorFst<Arc> C1(T1);
    Concat(&C1, T2);
    Concat(&C1, T3);

    // T1 (T2 T3)
    VectorFst<Arc> C3(T2);
    Concat(&C3, T3);
    VectorFst<Arc> C4(T1);
    Concat(&C4, C3);

    CHECK(Equiv(C1, C4));
  }

  {
    VLOG(1) << "Check concatenation is associative (delayed).";
    ConcatFst<Arc> C1(T1, T2);
    ConcatFst<Arc> C2(C1, T3);

    ConcatFst<Arc> C3(T2, T3);
    ConcatFst<Arc> C4(T1, C3);

    CHECK(Equiv(C2, C4));
  }

  {
    VLOG(1) << "Check concatenation is associative (destructive delayed).";
    ConcatFst<Arc> C1(T1, T2);
    Concat(&C1, T3);

    ConcatFst<Arc> C3(T2, T3);
    ConcatFst<Arc> C4(T1, C3);

    CHECK(Equiv(C1, C4));
  }

  // --- Distributivity of concatenation over union. ---
  if (left_semiring) {
    VLOG(1) << "Check concatenation left distributes"
            << " over union (destructive).";

    // T3 (T1 U T2)
    VectorFst<Arc> U1(T1);
    Union(&U1, T2);
    VectorFst<Arc> C1(T3);
    Concat(&C1, U1);

    // (T3 T1) U (T3 T2)
    VectorFst<Arc> C2(T3);
    Concat(&C2, T1);
    VectorFst<Arc> C3(T3);
    Concat(&C3, T2);
    VectorFst<Arc> U2(C2);
    Union(&U2, C3);

    CHECK(Equiv(C1, U2));
  }

  if (right_semiring) {
    VLOG(1) << "Check concatenation right distributes"
            << " over union (destructive).";
    // (T1 U T2) T3
    VectorFst<Arc> U1(T1);
    Union(&U1, T2);
    VectorFst<Arc> C1(U1);
    Concat(&C1, T3);

    // (T1 T3) U (T2 T3)
    VectorFst<Arc> C2(T1);
    Concat(&C2, T3);
    VectorFst<Arc> C3(T2);
    Concat(&C3, T3);
    VectorFst<Arc> U2(C2);
    Union(&U2, C3);

    CHECK(Equiv(C1, U2));
  }

  if (left_semiring) {
    VLOG(1) << "Check concatenation left distributes over union (delayed).";
    UnionFst<Arc> U1(T1, T2);
    ConcatFst<Arc> C1(T3, U1);

    ConcatFst<Arc> C2(T3, T1);
    ConcatFst<Arc> C3(T3, T2);
    UnionFst<Arc> U2(C2, C3);

    CHECK(Equiv(C1, U2));
  }

  if (right_semiring) {
    VLOG(1) << "Check concatenation right distributes over union (delayed).";
    UnionFst<Arc> U1(T1, T2);
    ConcatFst<Arc> C1(U1, T3);

    ConcatFst<Arc> C2(T1, T3);
    ConcatFst<Arc> C3(T2, T3);
    UnionFst<Arc> U2(C2, C3);

    CHECK(Equiv(C1, U2));
  }

  // --- Closure identities. ---
  if (left_semiring) {
    VLOG(1) << "Check T T* == T+ (destructive).";
    VectorFst<Arc> S(T1);
    Closure(&S, CLOSURE_STAR);
    VectorFst<Arc> C(T1);
    Concat(&C, S);

    VectorFst<Arc> P(T1);
    Closure(&P, CLOSURE_PLUS);

    CHECK(Equiv(C, P));
  }

  if (right_semiring) {
    VLOG(1) << "Check T* T == T+ (destructive).";
    VectorFst<Arc> S(T1);
    Closure(&S, CLOSURE_STAR);
    VectorFst<Arc> C(S);
    Concat(&C, T1);

    VectorFst<Arc> P(T1);
    Closure(&P, CLOSURE_PLUS);

    CHECK(Equiv(C, P));
  }

  if (left_semiring) {
    VLOG(1) << "Check T T* == T+ (delayed).";
    ClosureFst<Arc> S(T1, CLOSURE_STAR);
    ConcatFst<Arc> C(T1, S);

    ClosureFst<Arc> P(T1, CLOSURE_PLUS);

    CHECK(Equiv(C, P));
  }

  if (right_semiring) {
    VLOG(1) << "Check T* T == T+ (delayed).";
    ClosureFst<Arc> S(T1, CLOSURE_STAR);
    ConcatFst<Arc> C(S, T1);

    ClosureFst<Arc> P(T1, CLOSURE_PLUS);

    CHECK(Equiv(C, P));
  }
}
// Tests map-based operations.
void TestMap(const Fst<Arc> &T) { { VLOG(1) << "Check destructive and delayed projection are equivalent."; VectorFst<Arc> P1(T); Project(&P1, ProjectType::INPUT); ProjectFst<Arc> P2(T, ProjectType::INPUT); CHECK(Equiv(P1, P2)); }
{ VLOG(1) << "Check destructive and delayed inversion are equivalent."; VectorFst<Arc> I1(T); Invert(&I1); InvertFst<Arc> I2(T); CHECK(Equiv(I1, I2)); }
{ VLOG(1) << "Check Pi_1(T) = Pi_2(T^-1) (destructive)."; VectorFst<Arc> P1(T); VectorFst<Arc> I1(T); Project(&P1, ProjectType::INPUT); Invert(&I1); Project(&I1, ProjectType::OUTPUT); CHECK(Equiv(P1, I1)); }
{ VLOG(1) << "Check Pi_2(T) = Pi_1(T^-1) (destructive)."; VectorFst<Arc> P1(T); VectorFst<Arc> I1(T); Project(&P1, ProjectType::OUTPUT); Invert(&I1); Project(&I1, ProjectType::INPUT); CHECK(Equiv(P1, I1)); }
{ VLOG(1) << "Check Pi_1(T) = Pi_2(T^-1) (delayed)."; ProjectFst<Arc> P1(T, ProjectType::INPUT); InvertFst<Arc> I1(T); ProjectFst<Arc> P2(I1, ProjectType::OUTPUT); CHECK(Equiv(P1, P2)); }
{ VLOG(1) << "Check Pi_2(T) = Pi_1(T^-1) (delayed)."; ProjectFst<Arc> P1(T, ProjectType::OUTPUT); InvertFst<Arc> I1(T); ProjectFst<Arc> P2(I1, ProjectType::INPUT); CHECK(Equiv(P1, P2)); }
{ VLOG(1) << "Check destructive relabeling"; static const int kNumLabels = 10; // set up relabeling pairs
std::vector<Label> labelset(kNumLabels); for (size_t i = 0; i < kNumLabels; ++i) labelset[i] = i; for (size_t i = 0; i < kNumLabels; ++i) { using std::swap; const auto index = std::uniform_int_distribution<>(0, kNumLabels - 1)(rand_); swap(labelset[i], labelset[index]); }
std::vector<std::pair<Label, Label>> ipairs1(kNumLabels); std::vector<std::pair<Label, Label>> opairs1(kNumLabels); for (size_t i = 0; i < kNumLabels; ++i) { ipairs1[i] = std::make_pair(i, labelset[i]); opairs1[i] = std::make_pair(labelset[i], i); } VectorFst<Arc> R(T); Relabel(&R, ipairs1, opairs1);
std::vector<std::pair<Label, Label>> ipairs2(kNumLabels); std::vector<std::pair<Label, Label>> opairs2(kNumLabels); for (size_t i = 0; i < kNumLabels; ++i) { ipairs2[i] = std::make_pair(labelset[i], i); opairs2[i] = std::make_pair(i, labelset[i]); } Relabel(&R, ipairs2, opairs2); CHECK(Equiv(R, T));
VLOG(1) << "Check on-the-fly relabeling"; RelabelFst<Arc> Rdelay(T, ipairs1, opairs1);
RelabelFst<Arc> RRdelay(Rdelay, ipairs2, opairs2); CHECK(Equiv(RRdelay, T)); }
{ VLOG(1) << "Check encoding/decoding (destructive)."; VectorFst<Arc> D(T); uint8_t encode_props = 0; if (std::bernoulli_distribution(.5)(rand_)) { encode_props |= kEncodeLabels; } if (std::bernoulli_distribution(.5)(rand_)) { encode_props |= kEncodeWeights; } EncodeMapper<Arc> encoder(encode_props, ENCODE); Encode(&D, &encoder); Decode(&D, encoder); CHECK(Equiv(D, T)); }
{ VLOG(1) << "Check encoding/decoding (delayed)."; uint8_t encode_props = 0; if (std::bernoulli_distribution(.5)(rand_)) { encode_props |= kEncodeLabels; } if (std::bernoulli_distribution(.5)(rand_)) { encode_props |= kEncodeWeights; } EncodeMapper<Arc> encoder(encode_props, ENCODE); EncodeFst<Arc> E(T, &encoder); VectorFst<Arc> Encoded(E); DecodeFst<Arc> D(Encoded, encoder); CHECK(Equiv(D, T)); }
{ VLOG(1) << "Check gallic mappers (constructive)."; ToGallicMapper<Arc> to_mapper; FromGallicMapper<Arc> from_mapper; VectorFst<GallicArc<Arc>> G; VectorFst<Arc> F; ArcMap(T, &G, to_mapper); ArcMap(G, &F, from_mapper); CHECK(Equiv(T, F)); }
{ VLOG(1) << "Check gallic mappers (delayed)."; ArcMapFst G(T, ToGallicMapper<Arc>()); ArcMapFst F(G, FromGallicMapper<Arc>()); CHECK(Equiv(T, F)); } }
// Tests compose-based operations.
void TestCompose(const Fst<Arc> &T1, const Fst<Arc> &T2, const Fst<Arc> &T3) { if (!(Weight::Properties() & kCommutative)) return;
VectorFst<Arc> S1(T1); VectorFst<Arc> S2(T2); VectorFst<Arc> S3(T3);
ILabelCompare<Arc> icomp; OLabelCompare<Arc> ocomp;
ArcSort(&S1, ocomp); ArcSort(&S2, ocomp); ArcSort(&S3, icomp);
{ VLOG(1) << "Check composition is associative."; ComposeFst<Arc> C1(S1, S2); ComposeFst<Arc> C2(C1, S3); ComposeFst<Arc> C3(S2, S3); ComposeFst<Arc> C4(S1, C3);
CHECK(Equiv(C2, C4)); }
{ VLOG(1) << "Check composition left distributes over union."; UnionFst<Arc> U1(S2, S3); ComposeFst<Arc> C1(S1, U1);
ComposeFst<Arc> C2(S1, S2); ComposeFst<Arc> C3(S1, S3); UnionFst<Arc> U2(C2, C3);
CHECK(Equiv(C1, U2)); }
{ VLOG(1) << "Check composition right distributes over union."; UnionFst<Arc> U1(S1, S2); ComposeFst<Arc> C1(U1, S3);
ComposeFst<Arc> C2(S1, S3); ComposeFst<Arc> C3(S2, S3); UnionFst<Arc> U2(C2, C3);
CHECK(Equiv(C1, U2)); }
VectorFst<Arc> A1(S1); VectorFst<Arc> A2(S2); VectorFst<Arc> A3(S3); Project(&A1, ProjectType::OUTPUT); Project(&A2, ProjectType::INPUT); Project(&A3, ProjectType::INPUT);
{ VLOG(1) << "Check intersection is commutative."; IntersectFst<Arc> I1(A1, A2); IntersectFst<Arc> I2(A2, A1); CHECK(Equiv(I1, I2)); }
{ VLOG(1) << "Check all epsilon filters leads to equivalent results."; using M = Matcher<Fst<Arc>>; ComposeFst<Arc> C1(S1, S2); ComposeFst<Arc> C2( S1, S2, ComposeFstOptions<Arc, M, AltSequenceComposeFilter<M>>()); ComposeFst<Arc> C3(S1, S2, ComposeFstOptions<Arc, M, MatchComposeFilter<M>>());
CHECK(Equiv(C1, C2)); CHECK(Equiv(C1, C3));
if ((Weight::Properties() & kIdempotent) || S1.Properties(kNoOEpsilons, false) || S2.Properties(kNoIEpsilons, false)) { ComposeFst<Arc> C4( S1, S2, ComposeFstOptions<Arc, M, TrivialComposeFilter<M>>()); CHECK(Equiv(C1, C4)); ComposeFst<Arc> C5( S1, S2, ComposeFstOptions<Arc, M, NoMatchComposeFilter<M>>()); CHECK(Equiv(C1, C5)); }
if (S1.Properties(kNoOEpsilons, false) && S2.Properties(kNoIEpsilons, false)) { ComposeFst<Arc> C6(S1, S2, ComposeFstOptions<Arc, M, NullComposeFilter<M>>()); CHECK(Equiv(C1, C6)); } }
{ VLOG(1) << "Check look-ahead filters lead to equivalent results."; VectorFst<Arc> C1, C2; Compose(S1, S2, &C1); LookAheadCompose(S1, S2, &C2); CHECK(Equiv(C1, C2)); } }
// Tests sorting operations
void TestSort(const Fst<Arc> &T) { ILabelCompare<Arc> icomp; OLabelCompare<Arc> ocomp;
{ VLOG(1) << "Check arc sorted Fst is equivalent to its input."; VectorFst<Arc> S1(T); ArcSort(&S1, icomp); CHECK(Equiv(T, S1)); }
{ VLOG(1) << "Check destructive and delayed arcsort are equivalent."; VectorFst<Arc> S1(T); ArcSort(&S1, icomp); ArcSortFst<Arc, ILabelCompare<Arc>> S2(T, icomp); CHECK(Equiv(S1, S2)); }
{ VLOG(1) << "Check ilabel sorting vs. olabel sorting with inversions."; VectorFst<Arc> S1(T); VectorFst<Arc> S2(T); ArcSort(&S1, icomp); Invert(&S2); ArcSort(&S2, ocomp); Invert(&S2); CHECK(Equiv(S1, S2)); }
{ VLOG(1) << "Check topologically sorted Fst is equivalent to its input."; VectorFst<Arc> S1(T); TopSort(&S1); CHECK(Equiv(T, S1)); }
{ VLOG(1) << "Check reverse(reverse(T)) = T"; for (int i = 0; i < 2; ++i) { VectorFst<ReverseArc<Arc>> R1; VectorFst<Arc> R2; bool require_superinitial = i == 1; Reverse(T, &R1, require_superinitial); Reverse(R1, &R2, require_superinitial); CHECK(Equiv(T, R2)); } } }
// Tests optimization operations
void TestOptimize(const Fst<Arc> &T) { uint64_t tprops = T.Properties(kFstProperties, true); uint64_t wprops = Weight::Properties();
VectorFst<Arc> A(T); Project(&A, ProjectType::INPUT); { VLOG(1) << "Check connected FST is equivalent to its input."; VectorFst<Arc> C1(T); Connect(&C1); CHECK(Equiv(T, C1)); }
if ((wprops & kSemiring) == kSemiring && (tprops & kAcyclic || wprops & kIdempotent)) { VLOG(1) << "Check epsilon-removed FST is equivalent to its input."; VectorFst<Arc> R1(T); RmEpsilon(&R1); CHECK(Equiv(T, R1));
VLOG(1) << "Check destructive and delayed epsilon removal" << "are equivalent."; RmEpsilonFst<Arc> R2(T); CHECK(Equiv(R1, R2));
VLOG(1) << "Check an FST with a large proportion" << " of epsilon transitions:"; // Maps all transitions of T to epsilon-transitions and append
// a non-epsilon transition.
VectorFst<Arc> U; ArcMap(T, &U, EpsMapper<Arc>()); VectorFst<Arc> V; V.SetStart(V.AddState()); Arc arc(1, 1, Weight::One(), V.AddState()); V.AddArc(V.Start(), arc); V.SetFinal(arc.nextstate, Weight::One()); Concat(&U, V); // Check that epsilon-removal preserves the shortest-distance
// from the initial state to the final states.
std::vector<Weight> d; ShortestDistance(U, &d, true); Weight w = U.Start() < d.size() ? d[U.Start()] : Weight::Zero(); VectorFst<Arc> U1(U); RmEpsilon(&U1); ShortestDistance(U1, &d, true); Weight w1 = U1.Start() < d.size() ? d[U1.Start()] : Weight::Zero(); CHECK(ApproxEqual(w, w1, kTestDelta)); RmEpsilonFst<Arc> U2(U); ShortestDistance(U2, &d, true); Weight w2 = U2.Start() < d.size() ? d[U2.Start()] : Weight::Zero(); CHECK(ApproxEqual(w, w2, kTestDelta)); }
if ((wprops & kSemiring) == kSemiring && tprops & kAcyclic) { VLOG(1) << "Check determinized FSA is equivalent to its input."; DeterminizeFst<Arc> D(A); CHECK(Equiv(A, D));
{ VLOG(1) << "Check determinized FST is equivalent to its input."; DeterminizeFstOptions<Arc> opts; opts.type = DETERMINIZE_NONFUNCTIONAL; DeterminizeFst<Arc> DT(T, opts); CHECK(Equiv(T, DT)); }
if ((wprops & (kPath | kCommutative)) == (kPath | kCommutative)) { VLOG(1) << "Check pruning in determinization"; VectorFst<Arc> P; const Weight threshold = generate_(); DeterminizeOptions<Arc> opts; opts.weight_threshold = threshold; Determinize(A, &P, opts); CHECK(P.Properties(kIDeterministic, true)); CHECK(PruneEquiv(A, P, threshold)); }
if ((wprops & kPath) == kPath) { VLOG(1) << "Check min-determinization";
// Ensures no input epsilons
VectorFst<Arc> R(T); std::vector<std::pair<Label, Label>> ipairs, opairs; ipairs.push_back(std::pair<Label, Label>(0, 1)); Relabel(&R, ipairs, opairs);
VectorFst<Arc> M; DeterminizeOptions<Arc> opts; opts.type = DETERMINIZE_DISAMBIGUATE; Determinize(R, &M, opts); CHECK(M.Properties(kIDeterministic, true)); CHECK(MinRelated(M, R)); }
int n; { VLOG(1) << "Check size(min(det(A))) <= size(det(A))" << " and min(det(A)) equiv det(A)"; VectorFst<Arc> M(D); n = M.NumStates(); Minimize(&M, static_cast<MutableFst<Arc> *>(nullptr), kDelta); CHECK(Equiv(D, M)); CHECK(M.NumStates() <= n); n = M.NumStates(); }
if (n && (wprops & kIdempotent) == kIdempotent && A.Properties(kNoEpsilons, true)) { VLOG(1) << "Check that Revuz's algorithm leads to the" << " same number of states as Brozozowski's algorithm";
// Skip test if A is the empty machine or contains epsilons or
// if the semiring is not idempotent (to avoid floating point
// errors)
VectorFst<ReverseArc<Arc>> R; Reverse(A, &R); RmEpsilon(&R); DeterminizeFst<ReverseArc<Arc>> DR(R); VectorFst<Arc> RD; Reverse(DR, &RD); DeterminizeFst<Arc> DRD(RD); VectorFst<Arc> M(DRD); CHECK_EQ(n + 1, M.NumStates()); // Accounts for the epsilon transition
// to the initial state
} }
if ((wprops & kSemiring) == kSemiring && tprops & kAcyclic) { VLOG(1) << "Check disambiguated FSA is equivalent to its input."; VectorFst<Arc> R(A), D; RmEpsilon(&R); Disambiguate(R, &D); CHECK(Equiv(R, D)); VLOG(1) << "Check disambiguated FSA is unambiguous"; CHECK(Unambiguous(D));
/* TODO(riley): find out why this fails
if ((wprops & (kPath | kCommutative)) == (kPath | kCommutative)) { VLOG(1) << "Check pruning in disambiguation"; VectorFst<Arc> P; const Weight threshold = generate_(); DisambiguateOptions<Arc> opts; opts.weight_threshold = threshold; Disambiguate(R, &P, opts); CHECK(Unambiguous(P)); CHECK(PruneEquiv(A, P, threshold)); } */ }
if (Arc::Type() == LogArc::Type() || Arc::Type() == StdArc::Type()) { VLOG(1) << "Check reweight(T) equiv T"; std::vector<Weight> potential; VectorFst<Arc> RI(T); VectorFst<Arc> RF(T); while (potential.size() < RI.NumStates()) { potential.push_back(generate_()); }
Reweight(&RI, potential, REWEIGHT_TO_INITIAL); CHECK(Equiv(T, RI));
Reweight(&RF, potential, REWEIGHT_TO_FINAL); CHECK(Equiv(T, RF)); }
if ((wprops & kIdempotent) || (tprops & kAcyclic)) { VLOG(1) << "Check pushed FST is equivalent to input FST."; // Pushing towards the final state.
if (wprops & kRightSemiring) { VectorFst<Arc> P1; Push<Arc, REWEIGHT_TO_FINAL>(T, &P1, kPushLabels); CHECK(Equiv(T, P1));
VectorFst<Arc> P2; Push<Arc, REWEIGHT_TO_FINAL>(T, &P2, kPushWeights); CHECK(Equiv(T, P2));
VectorFst<Arc> P3; Push<Arc, REWEIGHT_TO_FINAL>(T, &P3, kPushLabels | kPushWeights); CHECK(Equiv(T, P3)); }
// Pushing towards the initial state.
if (wprops & kLeftSemiring) { VectorFst<Arc> P1; Push<Arc, REWEIGHT_TO_INITIAL>(T, &P1, kPushLabels); CHECK(Equiv(T, P1));
VectorFst<Arc> P2; Push<Arc, REWEIGHT_TO_INITIAL>(T, &P2, kPushWeights); CHECK(Equiv(T, P2)); VectorFst<Arc> P3; Push<Arc, REWEIGHT_TO_INITIAL>(T, &P3, kPushLabels | kPushWeights); CHECK(Equiv(T, P3)); } }
if constexpr (IsPath<Weight>::value) { if ((wprops & (kPath | kCommutative)) == (kPath | kCommutative)) { VLOG(1) << "Check pruning algorithm"; { VLOG(1) << "Check equiv. of constructive and destructive algorithms"; const Weight threshold = generate_(); VectorFst<Arc> P1(T); Prune(&P1, threshold); VectorFst<Arc> P2; Prune(T, &P2, threshold); CHECK(Equiv(P1, P2)); }
{ VLOG(1) << "Check prune(reverse) equiv reverse(prune)"; const Weight threshold = generate_(); VectorFst<ReverseArc<Arc>> R; VectorFst<Arc> P1(T); VectorFst<Arc> P2; Prune(&P1, threshold); Reverse(T, &R); Prune(&R, threshold.Reverse()); Reverse(R, &P2); CHECK(Equiv(P1, P2)); } { VLOG(1) << "Check: ShortestDistance(A - prune(A))" << " > ShortestDistance(A) times Threshold"; const Weight threshold = generate_(); VectorFst<Arc> P; Prune(A, &P, threshold); CHECK(PruneEquiv(A, P, threshold)); } } } if (tprops & kAcyclic) { VLOG(1) << "Check synchronize(T) equiv T"; SynchronizeFst<Arc> S(T); CHECK(Equiv(T, S)); } }
// Tests search operations
void TestSearch(const Fst<Arc> &T) { if constexpr (IsPath<Weight>::value) { uint64_t wprops = Weight::Properties();
VectorFst<Arc> A(T); Project(&A, ProjectType::INPUT);
if ((wprops & (kPath | kRightSemiring)) == (kPath | kRightSemiring)) { VLOG(1) << "Check 1-best weight."; VectorFst<Arc> path; ShortestPath(T, &path); Weight tsum = ShortestDistance(T); Weight psum = ShortestDistance(path); CHECK(ApproxEqual(tsum, psum, kTestDelta)); }
if ((wprops & (kPath | kSemiring)) == (kPath | kSemiring)) { VLOG(1) << "Check n-best weights"; VectorFst<Arc> R(A); RmEpsilon(&R, /*connect=*/true, Arc::Weight::Zero(), kNoStateId, kDelta); const int nshortest = std::uniform_int_distribution<>( 0, kNumRandomShortestPaths + 1)(rand_); VectorFst<Arc> paths; ShortestPath(R, &paths, nshortest, /*unique=*/true, /*first_path=*/false, Weight::Zero(), kNumShortestStates, kDelta); std::vector<Weight> distance; ShortestDistance(paths, &distance, true, kDelta); StateId pstart = paths.Start(); if (pstart != kNoStateId) { ArcIterator<Fst<Arc>> piter(paths, pstart); for (; !piter.Done(); piter.Next()) { StateId s = piter.Value().nextstate; Weight nsum = s < distance.size() ? Times(piter.Value().weight, distance[s]) : Weight::Zero(); VectorFst<Arc> path; ShortestPath(R, &path, 1, false, false, Weight::Zero(), kNoStateId, kDelta); Weight dsum = ShortestDistance(path, kDelta); CHECK(ApproxEqual(nsum, dsum, kTestDelta)); ArcMap(&path, RmWeightMapper<Arc>()); VectorFst<Arc> S; Difference(R, path, &S); R = S; } } } } }
// Tests if two FSTS are equivalent by checking if random
// strings from one FST are transduced the same by both FSTs.
template <class A> bool Equiv(const Fst<A> &fst1, const Fst<A> &fst2) { VLOG(1) << "Check FSTs for sanity (including property bits)."; CHECK(Verify(fst1)); CHECK(Verify(fst2));
// Ensures seed used once per instantiation.
static const UniformArcSelector<A> uniform_selector(seed_); const RandGenOptions<UniformArcSelector<A>> opts(uniform_selector, kRandomPathLength); return RandEquivalent(fst1, fst2, kNumRandomPaths, opts, kTestDelta, seed_); }
// Tests FSA is unambiguous.
bool Unambiguous(const Fst<Arc> &fst) { VectorFst<StdArc> sfst, dfst; VectorFst<LogArc> lfst1, lfst2; ArcMap(fst, &sfst, RmWeightMapper<Arc, StdArc>()); Determinize(sfst, &dfst); ArcMap(fst, &lfst1, RmWeightMapper<Arc, LogArc>()); ArcMap(dfst, &lfst2, RmWeightMapper<StdArc, LogArc>()); return Equiv(lfst1, lfst2); }
// Ensures input-epsilon free transducers fst1 and fst2 have the
// same domain and that for each string pair '(is, os)' in fst1,
// '(is, os)' is the minimum weight match to 'is' in fst2.
template <class A> bool MinRelated(const Fst<A> &fst1, const Fst<A> &fst2) { // Same domain
VectorFst<Arc> P1(fst1), P2(fst2); Project(&P1, ProjectType::INPUT); Project(&P2, ProjectType::INPUT); if (!Equiv(P1, P2)) { LOG(ERROR) << "Inputs not equivalent"; return false; }
// Ensures seed used once per instantiation.
static const UniformArcSelector<A> uniform_selector(seed_); const RandGenOptions<UniformArcSelector<A>> opts(uniform_selector, kRandomPathLength);
VectorFst<Arc> path, paths1, paths2; for (ssize_t n = 0; n < kNumRandomPaths; ++n) { RandGen(fst1, &path, opts); Invert(&path); ArcMap(&path, RmWeightMapper<Arc>()); Compose(path, fst2, &paths1); Weight sum1 = ShortestDistance(paths1); Compose(paths1, path, &paths2); Weight sum2 = ShortestDistance(paths2); if (!ApproxEqual(Plus(sum1, sum2), sum2, kTestDelta)) { LOG(ERROR) << "Sums not equivalent: " << sum1 << " " << sum2; return false; } } return true; }
// Tests ShortestDistance(A - P) >= ShortestDistance(A) times Threshold.
template <class A> bool PruneEquiv(const Fst<A> &fst, const Fst<A> &pfst, Weight threshold) { VLOG(1) << "Check FSTs for sanity (including property bits)."; CHECK(Verify(fst)); CHECK(Verify(pfst));
DifferenceFst<Arc> D(fst, DeterminizeFst<Arc>(RmEpsilonFst<Arc>( ArcMapFst(pfst, RmWeightMapper<Arc>())))); const Weight sum1 = Times(ShortestDistance(fst), threshold); const Weight sum2 = ShortestDistance(D); return ApproxEqual(Plus(sum1, sum2), sum1, kTestDelta); }
// Random seed. Besides seeding rand_, it is passed to the arc selector and
// to RandEquivalent() in Equiv().
uint64_t seed_;
// Random state (for randomness in this class).
std::mt19937_64 rand_;
// FST with no states
VectorFst<Arc> zero_fst_;
// FST with one state that accepts epsilon.
VectorFst<Arc> one_fst_;
// FST with one state that accepts all strings.
VectorFst<Arc> univ_fst_;
// Generates weights used in testing.
WeightGenerator generate_;
// Maximum random path length.
static constexpr int kRandomPathLength = 25;
// Number of random paths to explore.
static constexpr int kNumRandomPaths = 100;
// Maximum number of nshortest paths.
static constexpr int kNumRandomShortestPaths = 100;
// Maximum number of nshortest states.
static constexpr int kNumShortestStates = 10000;
// Delta for equivalence tests.
static constexpr float kTestDelta = .05;
// Not copyable: both copy operations are deleted.
WeightedTester(const WeightedTester &) = delete; WeightedTester &operator=(const WeightedTester &) = delete; };
// Tester for identities and properties of algorithms on unweighted FSAs
// that WeightedTester does not already cover. This primary template is an
// intentional no-op; only the StdArc specialization does anything
// interesting.
template <class Arc>
class UnweightedTester {
 public:
  // Accepts the reference FSAs and the seed but ignores all of them.
  UnweightedTester(const Fst<Arc> &zero_fsa, const Fst<Arc> &one_fsa,
                   const Fst<Arc> &univ_fsa, uint64_t seed) {}

  // Does nothing.
  void Test(const Fst<Arc> &A1, const Fst<Arc> &A2, const Fst<Arc> &A3) {}
};
// Specialization for StdArc. This should work for any commutative,
// idempotent semiring when restricted to the unweighted case
// (being isomorphic to the boolean semiring).
template <> class UnweightedTester<StdArc> { public: using Arc = StdArc; using Label = Arc::Label; using StateId = Arc::StateId; using Weight = Arc::Weight;
// Copies the three reference FSAs into members and seeds the random engine
// with `seed`.
UnweightedTester(const Fst<Arc> &zero_fsa, const Fst<Arc> &one_fsa, const Fst<Arc> &univ_fsa, uint64_t seed) : zero_fsa_(zero_fsa), one_fsa_(one_fsa), univ_fsa_(univ_fsa), rand_(seed) {}
// Entry point: runs the rational and intersection test groups on all three
// FSAs, then the optimization group on A1 only.
void Test(const Fst<Arc> &A1, const Fst<Arc> &A2, const Fst<Arc> &A3) { TestRational(A1, A2, A3); TestIntersect(A1, A2, A3); TestOptimize(A1); }
private:
// Tests rational operations with identities.
//
// Checks that a union contains both of its arguments and that A1^n is
// contained in A1* for a random n in [0, 4], each in a destructive and a
// delayed variant. A3 is not used.
void TestRational(const Fst<Arc> &A1, const Fst<Arc> &A2,
                  const Fst<Arc> &A3) {
  {
    VLOG(1) << "Check the union contains its arguments (destructive).";
    VectorFst<Arc> U(A1);
    Union(&U, A2);

    CHECK(Subset(A1, U));
    CHECK(Subset(A2, U));
  }

  {
    VLOG(1) << "Check the union contains its arguments (delayed).";
    UnionFst<Arc> U(A1, A2);

    CHECK(Subset(A1, U));
    CHECK(Subset(A2, U));
  }

  {
    VLOG(1) << "Check if A^n c A* (destructive).";
    // Start from the epsilon-accepting FSA and append A1 n times.
    VectorFst<Arc> C(one_fsa_);
    const int n = std::uniform_int_distribution<>(0, 4)(rand_);
    for (int remaining = n; remaining > 0; --remaining) {
      Concat(&C, A1);
    }

    VectorFst<Arc> S(A1);
    Closure(&S, CLOSURE_STAR);
    CHECK(Subset(C, S));
  }

  {
    VLOG(1) << "Check if A^n c A* (delayed).";
    const int n = std::uniform_int_distribution<>(0, 4)(rand_);
    // Each step wraps the previous FST in a new delayed concatenation.
    std::unique_ptr<Fst<Arc>> C =
        std::make_unique<VectorFst<Arc>>(one_fsa_);
    for (int remaining = n; remaining > 0; --remaining) {
      C = std::make_unique<ConcatFst<Arc>>(*C, A1);
    }
    ClosureFst<Arc> S(A1, CLOSURE_STAR);
    CHECK(Subset(*C, S));
  }
}
// Tests intersect-based operations.
void TestIntersect(const Fst<Arc> &A1, const Fst<Arc> &A2, const Fst<Arc> &A3) { VectorFst<Arc> S1(A1); VectorFst<Arc> S2(A2); VectorFst<Arc> S3(A3);
ILabelCompare<Arc> comp;
ArcSort(&S1, comp); ArcSort(&S2, comp); ArcSort(&S3, comp);
{ VLOG(1) << "Check the intersection is contained in its arguments."; IntersectFst<Arc> I1(S1, S2); CHECK(Subset(I1, S1)); CHECK(Subset(I1, S2)); }
{ VLOG(1) << "Check union distributes over intersection."; IntersectFst<Arc> I1(S1, S2); UnionFst<Arc> U1(I1, S3);
UnionFst<Arc> U2(S1, S3); UnionFst<Arc> U3(S2, S3); ArcSortFst<Arc, ILabelCompare<Arc>> S4(U3, comp); IntersectFst<Arc> I2(U2, S4);
CHECK(Equiv(U1, I2)); }
VectorFst<Arc> C1; VectorFst<Arc> C2; Complement(S1, &C1); Complement(S2, &C2); ArcSort(&C1, comp); ArcSort(&C2, comp);
{ VLOG(1) << "Check S U S' = Sigma*"; UnionFst<Arc> U(S1, C1); CHECK(Equiv(U, univ_fsa_)); }
{ VLOG(1) << "Check S n S' = {}"; IntersectFst<Arc> I(S1, C1); CHECK(Equiv(I, zero_fsa_)); }
{ VLOG(1) << "Check (S1' U S2') == (S1 n S2)'"; UnionFst<Arc> U(C1, C2);
IntersectFst<Arc> I(S1, S2); VectorFst<Arc> C3; Complement(I, &C3); CHECK(Equiv(U, C3)); }
{ VLOG(1) << "Check (S1' n S2') == (S1 U S2)'"; IntersectFst<Arc> I(C1, C2);
UnionFst<Arc> U(S1, S2); VectorFst<Arc> C3; Complement(U, &C3); CHECK(Equiv(I, C3)); } }
// Tests optimization operations.
void TestOptimize(const Fst<Arc> &A) { { VLOG(1) << "Check determinized FSA is equivalent to its input."; DeterminizeFst<Arc> D(A); CHECK(Equiv(A, D)); }
{ VLOG(1) << "Check disambiguated FSA is equivalent to its input."; VectorFst<Arc> R(A), D; RmEpsilon(&R);
Disambiguate(R, &D); CHECK(Equiv(R, D)); }
{ VLOG(1) << "Check minimized FSA is equivalent to its input."; int n; { RmEpsilonFst<Arc> R(A); DeterminizeFst<Arc> D(R); VectorFst<Arc> M(D); Minimize(&M, static_cast<MutableFst<Arc> *>(nullptr), kDelta); CHECK(Equiv(A, M)); n = M.NumStates(); }
if (n) { // Skips test if A is the empty machine.
VLOG(1) << "Check that Hopcroft's and Revuz's algorithms lead to the" << " same number of states as Brozozowski's algorithm"; VectorFst<Arc> R; Reverse(A, &R); RmEpsilon(&R); DeterminizeFst<Arc> DR(R); VectorFst<Arc> RD; Reverse(DR, &RD); DeterminizeFst<Arc> DRD(RD); VectorFst<Arc> M(DRD); CHECK_EQ(n + 1, M.NumStates()); // Accounts for the epsilon transition
// to the initial state.
} } }
// Tests if two FSAS are equivalent.
bool Equiv(const Fst<Arc> &fsa1, const Fst<Arc> &fsa2) { VLOG(1) << "Check FSAs for sanity (including property bits)."; CHECK(Verify(fsa1)); CHECK(Verify(fsa2));
VectorFst<Arc> vfsa1(fsa1); VectorFst<Arc> vfsa2(fsa2); RmEpsilon(&vfsa1); RmEpsilon(&vfsa2); DeterminizeFst<Arc> dfa1(vfsa1); DeterminizeFst<Arc> dfa2(vfsa2);
// Test equivalence using union-find algorithm
bool equiv1 = Equivalent(dfa1, dfa2);
// Test equivalence by checking if (S1 - S2) U (S2 - S1) is empty
ILabelCompare<Arc> comp; VectorFst<Arc> sdfa1(dfa1); ArcSort(&sdfa1, comp); VectorFst<Arc> sdfa2(dfa2); ArcSort(&sdfa2, comp);
DifferenceFst<Arc> dfsa1(sdfa1, sdfa2); DifferenceFst<Arc> dfsa2(sdfa2, sdfa1);
VectorFst<Arc> ufsa(dfsa1); Union(&ufsa, dfsa2); Connect(&ufsa); bool equiv2 = ufsa.NumStates() == 0;
// Checks both equivalence tests match.
CHECK((equiv1 && equiv2) || (!equiv1 && !equiv2));
return equiv1; }
// Tests if FSA1 is a subset of FSA2 (disregarding weights).
bool Subset(const Fst<Arc> &fsa1, const Fst<Arc> &fsa2) { VLOG(1) << "Check FSAs (incl. property bits) for sanity"; CHECK(Verify(fsa1)); CHECK(Verify(fsa2));
VectorFst<StdArc> vfsa1; VectorFst<StdArc> vfsa2; RmEpsilon(&vfsa1); RmEpsilon(&vfsa2); ILabelCompare<StdArc> comp; ArcSort(&vfsa1, comp); ArcSort(&vfsa2, comp); IntersectFst<StdArc> ifsa(vfsa1, vfsa2); DeterminizeFst<StdArc> dfa1(vfsa1); DeterminizeFst<StdArc> dfa2(ifsa); return Equivalent(dfa1, dfa2); }
// Returns complement FSA.
void Complement(const Fst<Arc> &ifsa, MutableFst<Arc> *ofsa) { RmEpsilonFst<Arc> rfsa(ifsa); DeterminizeFst<Arc> dfa(rfsa); DifferenceFst<Arc> cfsa(univ_fsa_, dfa); *ofsa = cfsa; }
// FSA with no states.
VectorFst<Arc> zero_fsa_;
// FSA with one state that accepts epsilon.
VectorFst<Arc> one_fsa_;
// FSA with one state that accepts all strings.
VectorFst<Arc> univ_fsa_;
// Random state.
std::mt19937_64 rand_; };
// This class tests a variety of identities and properties that must
// hold for various FST algorithms. It randomly generates FSTs, using
// the function object passed as 'generator' to select weights.
// 'WeightedTester' and 'UnweightedTester' are then called.
template <class Arc> class AlgoTester { public: using Label = typename Arc::Label; using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; using WeightGenerator = WeightGenerate<Weight>;
AlgoTester(WeightGenerator generator, uint64_t seed) : generate_(std::move(generator)), rand_(seed) { one_fst_.AddState(); one_fst_.SetStart(0); one_fst_.SetFinal(0);
univ_fst_.AddState(); univ_fst_.SetStart(0); univ_fst_.SetFinal(0); for (int i = 0; i < kNumRandomLabels; ++i) univ_fst_.EmplaceArc(0, i, i, 0);
weighted_tester_.reset(new WeightedTester<Arc>(seed, zero_fst_, one_fst_, univ_fst_, generate_));
unweighted_tester_.reset( new UnweightedTester<Arc>(zero_fst_, one_fst_, univ_fst_, seed)); }
void MakeRandFst(MutableFst<Arc> *fst) { RandFst<Arc, WeightGenerator>(kNumRandomStates, kNumRandomArcs, kNumRandomLabels, kAcyclicProb, generate_, rand_(), fst); }
void Test() { VLOG(1) << "weight type = " << Weight::Type();
for (int i = 0; i < FST_FLAGS_repeat; ++i) { // Random transducers
VectorFst<Arc> T1; VectorFst<Arc> T2; VectorFst<Arc> T3; MakeRandFst(&T1); MakeRandFst(&T2); MakeRandFst(&T3); weighted_tester_->Test(T1, T2, T3);
VectorFst<Arc> A1(T1); VectorFst<Arc> A2(T2); VectorFst<Arc> A3(T3); Project(&A1, ProjectType::OUTPUT); Project(&A2, ProjectType::INPUT); Project(&A3, ProjectType::INPUT); ArcMap(&A1, rm_weight_mapper_); ArcMap(&A2, rm_weight_mapper_); ArcMap(&A3, rm_weight_mapper_); unweighted_tester_->Test(A1, A2, A3); } }
private: // Generates weights used in testing.
WeightGenerator generate_; // Random state used to seed RandFst.
std::mt19937_64 rand_; // FST with no states
VectorFst<Arc> zero_fst_; // FST with one state that accepts epsilon.
VectorFst<Arc> one_fst_; // FST with one state that accepts all strings.
VectorFst<Arc> univ_fst_; // Tests weighted FSTs
std::unique_ptr<WeightedTester<Arc>> weighted_tester_; // Tests unweighted FSTs
std::unique_ptr<UnweightedTester<Arc>> unweighted_tester_; // Mapper to remove weights from an Fst
RmWeightMapper<Arc> rm_weight_mapper_; // Maximum number of states in random test Fst.
static constexpr int kNumRandomStates = 10; // Maximum number of arcs in random test Fst.
static constexpr int kNumRandomArcs = 25; // Number of alternative random labels.
static constexpr int kNumRandomLabels = 5; // Probability to force an acyclic Fst
static constexpr float kAcyclicProb = .25; // Maximum random path length.
static constexpr int kRandomPathLength = 25; // Number of random paths to explore.
static constexpr int kNumRandomPaths = 100;
AlgoTester(const AlgoTester &) = delete; AlgoTester &operator=(const AlgoTester &) = delete; }; } // namespace fst
#endif // FST_TEST_ALGO_TEST_H_
|