|
|
// fstext/lattice-utils-inl.h
// Copyright 2009-2012 Microsoft Corporation Johns Hopkins University (Author:
// Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_LATTICE_UTILS_INL_H_
#define KALDI_FSTEXT_LATTICE_UTILS_INL_H_
// Do not include this file directly. It is included by lattice-utils.h
#include <utility>
#include <vector>
namespace fst {
/* Convert from FST with arc-type Weight, to one with arc-type
CompactLatticeWeight. Uses FactorFst to identify chains of states which can be turned into a single output arc. */
template <class Weight, class Int> void ConvertLattice( const ExpandedFst<ArcTpl<Weight> >& ifst, MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > >* ofst, bool invert) { typedef ArcTpl<Weight> Arc; typedef typename Arc::StateId StateId; typedef CompactLatticeWeightTpl<Weight, Int> CompactWeight; typedef ArcTpl<CompactWeight> CompactArc;
VectorFst<ArcTpl<Weight> > ffst; std::vector<std::vector<Int> > labels; if (invert) { // normal case: want the ilabels as sequences on the arcs of
Factor(ifst, &ffst, &labels); // the output... Factor makes seqs of
// ilabels.
} else { VectorFst<ArcTpl<Weight> > invfst(ifst); Invert(&invfst); Factor(invfst, &ffst, &labels); }
TopSort(&ffst); // Put the states in ffst in topological order, which is
// easier on the eye when reading the text-form lattices and corresponds to
// what we get when we generate the lattices in the decoder.
ofst->DeleteStates();
// The states will be numbered exactly the same as the original FST.
// Add the states to the new FST.
StateId num_states = ffst.NumStates(); for (StateId s = 0; s < num_states; s++) { StateId news = ofst->AddState(); assert(news == s); } ofst->SetStart(ffst.Start()); for (StateId s = 0; s < num_states; s++) { Weight final_weight = ffst.Final(s); if (final_weight != Weight::Zero()) { CompactWeight final_compact_weight(final_weight, std::vector<Int>()); ofst->SetFinal(s, final_compact_weight); } for (ArcIterator<ExpandedFst<Arc> > iter(ffst, s); !iter.Done(); iter.Next()) { const Arc& arc = iter.Value(); KALDI_PARANOID_ASSERT(arc.weight != Weight::Zero()); // note: zero-weight arcs not allowed anyway so weight should not be zero,
// but no harm in checking.
CompactArc compact_arc(arc.olabel, arc.olabel, CompactWeight(arc.weight, labels[arc.ilabel]), arc.nextstate); ofst->AddArc(s, compact_arc); } } }
template <class Weight, class Int> void ConvertLattice( const ExpandedFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > >& ifst, MutableFst<ArcTpl<Weight> >* ofst, bool invert) { typedef ArcTpl<Weight> Arc; typedef typename Arc::StateId StateId; typedef typename Arc::Label Label; typedef CompactLatticeWeightTpl<Weight, Int> CompactWeight; typedef ArcTpl<CompactWeight> CompactArc; ofst->DeleteStates(); // make the states in the new FST have the same numbers as
// the original ones, and add chains of states as necessary
// to encode the string-valued weights.
StateId num_states = ifst.NumStates(); for (StateId s = 0; s < num_states; s++) { StateId news = ofst->AddState(); assert(news == s); } ofst->SetStart(ifst.Start()); for (StateId s = 0; s < num_states; s++) { CompactWeight final_weight = ifst.Final(s); if (final_weight != CompactWeight::Zero()) { StateId cur_state = s; size_t string_length = final_weight.String().size(); for (size_t n = 0; n < string_length; n++) { StateId next_state = ofst->AddState(); Label ilabel = 0; Arc arc(ilabel, final_weight.String()[n], (n == 0 ? final_weight.Weight() : Weight::One()), next_state); if (invert) std::swap(arc.ilabel, arc.olabel); ofst->AddArc(cur_state, arc); cur_state = next_state; } ofst->SetFinal(cur_state, string_length > 0 ? Weight::One() : final_weight.Weight()); } for (ArcIterator<ExpandedFst<CompactArc> > iter(ifst, s); !iter.Done(); iter.Next()) { const CompactArc& arc = iter.Value(); size_t string_length = arc.weight.String().size(); StateId cur_state = s; // for all but the last element in the string--
// add a temporary state.
for (size_t n = 0; n + 1 < string_length; n++) { StateId next_state = ofst->AddState(); Label ilabel = (n == 0 ? arc.ilabel : 0), olabel = static_cast<Label>(arc.weight.String()[n]); Weight weight = (n == 0 ? arc.weight.Weight() : Weight::One()); Arc new_arc(ilabel, olabel, weight, next_state); if (invert) std::swap(new_arc.ilabel, new_arc.olabel); ofst->AddArc(cur_state, new_arc); cur_state = next_state; } Label ilabel = (string_length <= 1 ? arc.ilabel : 0), olabel = (string_length > 0 ? arc.weight.String()[string_length - 1] : 0); Weight weight = (string_length <= 1 ? arc.weight.Weight() : Weight::One()); Arc new_arc(ilabel, olabel, weight, arc.nextstate); if (invert) std::swap(new_arc.ilabel, new_arc.olabel); ofst->AddArc(cur_state, new_arc); } } }
// This function converts lattices between float and double;
// it works for both CompactLatticeWeight and LatticeWeight.
template <class WeightIn, class WeightOut> void ConvertLattice(const ExpandedFst<ArcTpl<WeightIn> >& ifst, MutableFst<ArcTpl<WeightOut> >* ofst) { typedef ArcTpl<WeightIn> ArcIn; typedef ArcTpl<WeightOut> ArcOut; typedef typename ArcIn::StateId StateId; ofst->DeleteStates(); // The states will be numbered exactly the same as the original FST.
// Add the states to the new FST.
StateId num_states = ifst.NumStates(); for (StateId s = 0; s < num_states; s++) { StateId news = ofst->AddState(); assert(news == s); } ofst->SetStart(ifst.Start()); for (StateId s = 0; s < num_states; s++) { WeightIn final_iweight = ifst.Final(s); if (final_iweight != WeightIn::Zero()) { WeightOut final_oweight; ConvertLatticeWeight(final_iweight, &final_oweight); ofst->SetFinal(s, final_oweight); } for (ArcIterator<ExpandedFst<ArcIn> > iter(ifst, s); !iter.Done(); iter.Next()) { ArcIn arc = iter.Value(); KALDI_PARANOID_ASSERT(arc.weight != WeightIn::Zero()); ArcOut oarc; ConvertLatticeWeight(arc.weight, &oarc.weight); oarc.ilabel = arc.ilabel; oarc.olabel = arc.olabel; oarc.nextstate = arc.nextstate; ofst->AddArc(s, oarc); } } }
template <class Weight, class ScaleFloat> void ScaleLattice(const std::vector<std::vector<ScaleFloat> >& scale, MutableFst<ArcTpl<Weight> >* fst) { assert(scale.size() == 2 && scale[0].size() == 2 && scale[1].size() == 2); if (scale == DefaultLatticeScale()) // nothing to do.
return; typedef ArcTpl<Weight> Arc; typedef MutableFst<Arc> Fst; typedef typename Arc::StateId StateId; StateId num_states = fst->NumStates(); for (StateId s = 0; s < num_states; s++) { for (MutableArcIterator<Fst> aiter(fst, s); !aiter.Done(); aiter.Next()) { Arc arc = aiter.Value(); arc.weight = Weight(ScaleTupleWeight(arc.weight, scale)); aiter.SetValue(arc); } Weight final_weight = fst->Final(s); if (final_weight != Weight::Zero()) fst->SetFinal(s, Weight(ScaleTupleWeight(final_weight, scale))); } }
// Strips the string part from every weight of a compact lattice, in place.
// Each arc weight and each non-Zero final weight is replaced by a compact
// weight with the same inner Weight and an empty string.
//
//   fst  The compact lattice to modify.
template <class Weight, class Int>
void RemoveAlignmentsFromCompactLattice(
    MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > >* fst) {
  typedef CompactLatticeWeightTpl<Weight, Int> W;
  typedef ArcTpl<W> Arc;
  typedef MutableFst<Arc> Fst;
  typedef typename Arc::StateId StateId;

  const std::vector<Int> no_alignment;  // the empty string put on each weight
  const StateId state_count = fst->NumStates();
  for (StateId state = 0; state < state_count; ++state) {
    for (MutableArcIterator<Fst> arc_it(fst, state); !arc_it.Done();
         arc_it.Next()) {
      Arc stripped = arc_it.Value();
      stripped.weight = W(stripped.weight.Weight(), no_alignment);
      arc_it.SetValue(stripped);
    }
    // Zero final weights mark non-final states and are left alone.
    const W old_final = fst->Final(state);
    if (old_final != W::Zero()) {
      fst->SetFinal(state, W(old_final.Weight(), no_alignment));
    }
  }
}
template <class Weight, class Int> bool CompactLatticeHasAlignment( const ExpandedFst<ArcTpl<CompactLatticeWeightTpl<Weight, Int> > >& fst) { typedef CompactLatticeWeightTpl<Weight, Int> W; typedef ArcTpl<W> Arc; typedef ExpandedFst<Arc> Fst; typedef typename Arc::StateId StateId; StateId num_states = fst.NumStates(); for (StateId s = 0; s < num_states; s++) { for (ArcIterator<Fst> aiter(fst, s); !aiter.Done(); aiter.Next()) { const Arc& arc = aiter.Value(); if (!arc.weight.String().empty()) return true; } W final_weight = fst.Final(s); if (!final_weight.String().empty()) return true; } return false; }
template <class Real> void ConvertFstToLattice(const ExpandedFst<ArcTpl<TropicalWeight> >& ifst, MutableFst<ArcTpl<LatticeWeightTpl<Real> > >* ofst) { int32 num_states_cache = 50000; fst::CacheOptions cache_opts(true, num_states_cache); fst::MapFstOptions mapfst_opts(cache_opts); StdToLatticeMapper<Real> mapper; MapFst<StdArc, ArcTpl<LatticeWeightTpl<Real> >, StdToLatticeMapper<Real> > map_fst(ifst, mapper, mapfst_opts); *ofst = map_fst; }
} // namespace fst
#endif // KALDI_FSTEXT_LATTICE_UTILS_INL_H_
|