|
|
// fstext/determinize-star-inl.h
// Copyright 2009-2011 Microsoft Corporation; Jan Silovsky
// 2015 Hainan Xu
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_DETERMINIZE_STAR_INL_H_
#define KALDI_FSTEXT_DETERMINIZE_STAR_INL_H_
// Do not include this file directly. It is included by determinize-star.h
#include <algorithm>
#include <climits>
#include <deque>
#include <limits>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
using std::unordered_map;
#include "base/kaldi-error.h"
namespace fst {
// This class maps back and forth from/to integer id's to sequences of strings.
// used in determinization algorithm.
template <class Label, class StringId> class StringRepository { // Label and StringId are both integer types, possibly the same.
// This is a utility that maps back and forth between a vector<Label> and
// StringId representation of sequences of Labels. It is to save memory, and
// to save compute. We treat sequences of length zero and one separately, for
// efficiency.
public: class VectorKey { // Hash function object.
public: size_t operator()(const std::vector<Label>* vec) const { assert(vec != NULL); size_t hash = 0, factor = 1; for (typename std::vector<Label>::const_iterator it = vec->begin(); it != vec->end(); it++) { hash += factor * (*it); factor *= 103333; // just an arbitrary prime number.
} return hash; } }; class VectorEqual { // Equality-operator function object.
public: size_t operator()(const std::vector<Label>* vec1, const std::vector<Label>* vec2) const { return (*vec1 == *vec2); } };
typedef unordered_map<const std::vector<Label>*, StringId, VectorKey, VectorEqual> MapType;
StringId IdOfEmpty() { return no_symbol; }
StringId IdOfLabel(Label l) { if (l >= 0 && l <= (Label)single_symbol_range) { return l + single_symbol_start; } else { // l is out of the allowed range so we have to treat it as a sequence of
// length one. Should be v. rare.
std::vector<Label> v; v.push_back(l); return IdOfSeqInternal(v); } }
StringId IdOfSeq( const std::vector<Label>& v) { // also works for sizes 0 and 1.
size_t sz = v.size(); if (sz == 0) return no_symbol; else if (v.size() == 1) return IdOfLabel(v[0]); else return IdOfSeqInternal(v); }
inline bool IsEmptyString(StringId id) { return id == no_symbol; } void SeqOfId(StringId id, std::vector<Label>* v) { if (id == no_symbol) { v->clear(); } else if (id >= single_symbol_start) { v->resize(1); (*v)[0] = id - single_symbol_start; } else { assert(static_cast<size_t>(id) < vec_.size()); *v = *(vec_[id]); } } StringId RemovePrefix(StringId id, size_t prefix_len) { if (prefix_len == 0) { return id; } else { std::vector<Label> v; SeqOfId(id, &v); size_t sz = v.size(); assert(sz >= prefix_len); std::vector<Label> v_noprefix(sz - prefix_len); for (size_t i = 0; i < sz - prefix_len; i++) v_noprefix[i] = v[i + prefix_len]; return IdOfSeq(v_noprefix); } }
StringRepository() { // The following are really just constants but don't want to complicate
// compilation so make them class variables. Due to the brokenness of
// <limits>, they can't be accessed as constants.
string_end = (std::numeric_limits<StringId>::max() / 2) - 1; // all hash values must be <= this.
no_symbol = (std::numeric_limits<StringId>::max() / 2); // reserved for empty sequence.
single_symbol_start = (std::numeric_limits<StringId>::max() / 2) + 1; single_symbol_range = std::numeric_limits<StringId>::max() - single_symbol_start; } void Destroy() { for (typename std::vector<std::vector<Label>*>::iterator iter = vec_.begin(); iter != vec_.end(); ++iter) delete *iter; std::vector<std::vector<Label>*> tmp_vec; tmp_vec.swap(vec_); MapType tmp_map; tmp_map.swap(map_); } ~StringRepository() { Destroy(); }
private: KALDI_DISALLOW_COPY_AND_ASSIGN(StringRepository);
StringId IdOfSeqInternal(const std::vector<Label>& v) { typename MapType::iterator iter = map_.find(&v); if (iter != map_.end()) { return iter->second; } else { // must add it to map.
StringId this_id = (StringId)vec_.size(); std::vector<Label>* v_new = new std::vector<Label>(v); vec_.push_back(v_new); map_[v_new] = this_id; assert(this_id < string_end); // or we used up the labels.
return this_id; } }
std::vector<std::vector<Label>*> vec_; MapType map_;
static const StringId string_start = (StringId)0; // This must not change. It's assumed.
StringId string_end; // = (numeric_limits<StringId>::max() / 2) - 1; // all
// hash values must be <= this.
StringId no_symbol; // = (numeric_limits<StringId>::max() / 2); // reserved
// for empty sequence.
StringId single_symbol_start; // = (numeric_limits<StringId>::max() / 2) + 1;
StringId single_symbol_range; // = numeric_limits<StringId>::max() -
// single_symbol_start;
};
template <class F> class DeterminizerStar { typedef typename F::Arc Arc;
public: // Output to Gallic acceptor (so the strings go on weights, and there is a 1-1
// correspondence between our states and the states in ofst. If destroy ==
// true, release memory as we go (but we cannot output again).
void Output(MutableFst<GallicArc<Arc> >* ofst, bool destroy = true);
// Output to standard FST. We will create extra states to handle sequences of
// symbols on the output. If destroy == true, release memory as we go (but we
// cannot output again).
void Output(MutableFst<Arc>* ofst, bool destroy = true);
// Initializer. After initializing the object you will typically call
// Determinize() and then one of the Output functions.
DeterminizerStar(const Fst<Arc>& ifst, float delta = kDelta, int max_states = -1, bool allow_partial = false) : ifst_(ifst.Copy()), delta_(delta), max_states_(max_states), determinized_(false), allow_partial_(allow_partial), is_partial_(false), equal_(delta), hash_(ifst.Properties(kExpanded, false) ? down_cast<const ExpandedFst<Arc>*, const Fst<Arc> >(&ifst) ->NumStates() / 2 + 3 : 20, hasher_, equal_), epsilon_closure_(ifst_, max_states, &repository_, delta) {}
void Determinize(bool* debug_ptr) { assert(!determinized_); // This determinizes the input fst but leaves it in the "special format"
// in "output_arcs_".
InputStateId start_id = ifst_->Start(); if (start_id == kNoStateId) { determinized_ = true; return; // Nothing to do.
} else { // Insert start state into hash and queue.
Element elem; elem.state = start_id; elem.weight = Weight::One(); elem.string = repository_.IdOfEmpty(); // Id of empty sequence.
std::vector<Element> vec; vec.push_back(elem); OutputStateId cur_id = SubsetToStateId(vec); assert(cur_id == 0 && "Do not call Determinize twice."); } while (!Q_.empty()) { std::pair<std::vector<Element>*, OutputStateId> cur_pair = Q_.front(); Q_.pop_front(); ProcessSubset(cur_pair); if (debug_ptr && *debug_ptr) Debug(); // will exit.
if (max_states_ > 0 && output_arcs_.size() > max_states_) { if (allow_partial_ == false) { KALDI_ERR << "Determinization aborted since passed " << max_states_ << " states"; } else { KALDI_WARN << "Determinization terminated since passed " << max_states_ << " states, partial results will be generated"; is_partial_ = true; break; } } } determinized_ = true; }
bool IsPartial() { return is_partial_; }
// frees all except output_arcs_, which contains the important info
// we need to output.
void FreeMostMemory() { if (ifst_) { delete ifst_; ifst_ = NULL; } for (typename SubsetHash::iterator iter = hash_.begin(); iter != hash_.end(); ++iter) delete iter->first; SubsetHash tmp; tmp.swap(hash_); }
~DeterminizerStar() { FreeMostMemory(); }
private: typedef typename Arc::Label Label; typedef typename Arc::Weight Weight; typedef typename Arc::StateId InputStateId; typedef typename Arc::StateId OutputStateId; // same as above but distinguish states in output Fst.
typedef typename Arc::Label StringId; // Id type used in the StringRepository
typedef StringRepository<Label, StringId> StringRepositoryType;
// Element of a subset [of original states]
struct Element { InputStateId state; StringId string; Weight weight; bool operator!=(const Element& other) const { return (state != other.state || string != other.string || weight != other.weight); } };
// Arcs in the format we temporarily create in this class (a representation,
// essentially of a Gallic Fst).
struct TempArc { Label ilabel; StringId ostring; // Look it up in the StringRepository, it's a sequence of
// Labels.
OutputStateId nextstate; // or kNoState for final weights.
Weight weight; };
// Hashing function used in hash of subsets.
// A subset is a pointer to vector<Element>.
// The Elements are in sorted order on state id, and without repeated states.
// Because the order of Elements is fixed, we can use a hashing function that
// is order-dependent. However the weights are not included in the hashing
// function-- we hash subsets that differ only in weight to the same key. This
// is not optimal in terms of the O(N) performance but typically if we have a
// lot of determinized states that differ only in weight then the input
// probably was pathological in some way, or even non-determinizable.
// We don't quantize the weights, in order to avoid inexactness in simple
// cases.
// Instead we apply the delta when comparing subsets for equality, and allow a
// small difference.
class SubsetKey { public: size_t operator()(const std::vector<Element>* subset) const { // hashes only the state and string.
size_t hash = 0, factor = 1; for (typename std::vector<Element>::const_iterator iter = subset->begin(); iter != subset->end(); ++iter) { hash *= factor; hash += iter->state + 103333 * iter->string; factor *= 23531; // these numbers are primes.
} return hash; } };
// This is the equality operator on subsets. It checks for exact match on
// state-id and string, and approximate match on weights.
class SubsetEqual { public: bool operator()(const std::vector<Element>* s1, const std::vector<Element>* s2) const { size_t sz = s1->size(); assert(sz >= 0); if (sz != s2->size()) return false; typename std::vector<Element>::const_iterator iter1 = s1->begin(), iter1_end = s1->end(), iter2 = s2->begin(); for (; iter1 < iter1_end; ++iter1, ++iter2) { if (iter1->state != iter2->state || iter1->string != iter2->string || !ApproxEqual(iter1->weight, iter2->weight, delta_)) return false; } return true; } float delta_; explicit SubsetEqual(float delta) : delta_(delta) {} SubsetEqual() : delta_(kDelta) {} };
// Operator that says whether two Elements have the same states.
// Used only for debug.
class SubsetEqualStates { public: bool operator()(const std::vector<Element>* s1, const std::vector<Element>* s2) const { size_t sz = s1->size(); assert(sz >= 0); if (sz != s2->size()) return false; typename std::vector<Element>::const_iterator iter1 = s1->begin(), iter1_end = s1->end(), iter2 = s2->begin(); for (; iter1 < iter1_end; ++iter1, ++iter2) { if (iter1->state != iter2->state) return false; } return true; } };
// Define the hash type we use to store subsets.
typedef unordered_map<const std::vector<Element>*, OutputStateId, SubsetKey, SubsetEqual> SubsetHash;
class EpsilonClosure { public: EpsilonClosure(const Fst<Arc>* ifst, int max_states, StringRepository<Label, StringId>* repository, float delta) : ifst_(ifst), max_states_(max_states), repository_(repository), delta_(delta) {}
// This function computes epsilon closure of subset of states by following
// epsilon links. Called by ProcessSubset. Has no side effects except on the
// repository.
void GetEpsilonClosure(const std::vector<Element>& input_subset, std::vector<Element>* output_subset);
private: struct EpsilonClosureInfo { EpsilonClosureInfo() {} EpsilonClosureInfo(const Element& e, const Weight& w, bool i) : element(e), weight_to_process(w), in_queue(i) {} // the weight in the Element struct is the total current weight
// that has been processed already
Element element; // this stores the weight that we haven't processed (propagated)
Weight weight_to_process; // whether "this" struct is in the queue
// we store the info here so that we don't have to look it up every time
bool in_queue; bool operator<(const EpsilonClosureInfo& other) const { return this->element.state < other.element.state; } };
// to further speed up EpsilonClosure() computation, we have 2 queues
// the 2nd queue is used when we first iterate over the input set -
// if queue_2_.empty() then we directly set output_set equal to input_set
// and return immediately
// Since Epsilon arcs are relatively rare, this way we could efficiently
// detect the epsilon-free case, without having to waste our computation
// e.g. allocating the EpsilonClosureInfo structure; this also lets us do a
// level-by-level traversal, which could avoid some (unfortunately not all)
// duplicate computation if epsilons form a DAG that is not a tree
//
// We put the queues here for better efficiency for memory allocation
std::deque<typename Arc::StateId> queue_; std::vector<Element> queue_2_;
// the following 2 structures together form our *virtual "map"*
// basically we need a map from state_id to EpsilonClosureInfo that operates
// in O(1) time, while still takes relatively small mem, and this does it
// well for efficiency we don't clear id_to_index_ of its outdated
// information As a result each time we do a look-up, we need to check if
// (ecinfo_[id_to_index_[id]].element.state == id) Yet this is still faster
// than using a std::map<StateId, EpsilonClosureInfo>
std::vector<int> id_to_index_; // unlike id_to_index_, we clear the content of ecinfo_ each time we call
// EpsilonClosure(). This needed because we need an efficient way to
// traverse the virtual map - it is just too costly to traverse the
// id_to_index_ vector.
std::vector<EpsilonClosureInfo> ecinfo_;
// Add one element (elem) into cur_subset
// it also adds the necessary stuff to queue_, set the correct weight
void AddOneElement(const Element& elem, const Weight& unprocessed_weight);
// Sub-routine that we call in EpsilonClosure()
// It takes the current "unprocessed_weight" and propagate it to the
// states accessible from elem.state by an epsilon arc
// and add the results to cur_subset.
// save_to_queue_2 is set true when we iterate over the initial subset
// - then we save it to queue_2 s.t. if it's empty, we directly return
// the input set
void ExpandOneElement(const Element& elem, bool sorted, const Weight& unprocessed_weight, bool save_to_queue_2 = false);
// no pointers below would take the ownership
const Fst<Arc>* ifst_; int max_states_; StringRepository<Label, StringId>* repository_; float delta_; };
// This function works out the final-weight of the determinized state.
// called by ProcessSubset.
// Has no side effects except on the variable repository_, and output_arcs_.
void ProcessFinal(const std::vector<Element>& closed_subset, OutputStateId state) { // processes final-weights for this subset.
bool is_final = false; StringId final_string = 0; // = 0 to keep compiler happy.
Weight final_weight = Weight::One(); // This value will never be accessed, and
// we just set it to avoid spurious compiler warnings. We avoid setting it
// to Zero() because floating-point infinities can sometimes generate
// interrupts and slow things down.
typename std::vector<Element>::const_iterator iter = closed_subset.begin(), end = closed_subset.end(); for (; iter != end; ++iter) { const Element& elem = *iter; Weight this_final_weight = ifst_->Final(elem.state); if (this_final_weight != Weight::Zero()) { if (!is_final) { // first final-weight
final_string = elem.string; final_weight = Times(elem.weight, this_final_weight); is_final = true; } else { // already have one.
if (final_string != elem.string) { KALDI_ERR << "FST was not functional -> not determinizable"; } final_weight = Plus(final_weight, Times(elem.weight, this_final_weight)); } } } if (is_final) { // store final weights in TempArc structure, just like a transition.
TempArc temp_arc; temp_arc.ilabel = 0; temp_arc.nextstate = kNoStateId; // special marker meaning "final weight".
temp_arc.ostring = final_string; temp_arc.weight = final_weight; output_arcs_[state].push_back(temp_arc); } }
// ProcessTransition is called from "ProcessTransitions". Broken out for
// clarity. Has side effects on output_arcs_, and (via SubsetToStateId), Q_
// and hash_.
void ProcessTransition(OutputStateId state, Label ilabel, std::vector<Element>* subset);
// "less than" operator for pair<Label, Element>. Used in
// ProcessTransitions. Lexicographical order, with comparing the state only
// for "Element".
class PairComparator { public: inline bool operator()(const std::pair<Label, Element>& p1, const std::pair<Label, Element>& p2) { if (p1.first < p2.first) { return true; } else if (p1.first > p2.first) { return false; } else { return p1.second.state < p2.second.state; } } };
// ProcessTransitions handles transitions out of this subset of states.
// Ignores epsilon transitions (epsilon closure already handled that).
// Does not consider final states. Breaks the transitions up by ilabel,
// and creates a new transition in determinized FST, for each ilabel.
// Does this by creating a big vector of pairs <Label, Element> and then
// sorting them using a lexicographical ordering, and calling
// ProcessTransition for each range with the same ilabel. Side effects on
// repository, and (via ProcessTransition) on Q_, hash_, and output_arcs_.
void ProcessTransitions(const std::vector<Element>& closed_subset, OutputStateId state) { std::vector<std::pair<Label, Element> > all_elems; { // Push back into "all_elems", elements corresponding to all
// non-epsilon-input transitions
// out of all states in "closed_subset".
typename std::vector<Element>::const_iterator iter = closed_subset.begin(), end = closed_subset.end(); for (; iter != end; ++iter) { const Element& elem = *iter; for (ArcIterator<Fst<Arc> > aiter(*ifst_, elem.state); !aiter.Done(); aiter.Next()) { const Arc& arc = aiter.Value(); if (arc.ilabel != 0) { // Non-epsilon transition -- ignore epsilons here.
std::pair<Label, Element> this_pr; this_pr.first = arc.ilabel; Element& next_elem(this_pr.second); next_elem.state = arc.nextstate; next_elem.weight = Times(elem.weight, arc.weight); if (arc.olabel == 0) { // output epsilon-- this is simple case so
// handle separately for efficiency
next_elem.string = elem.string; } else { std::vector<Label> seq; repository_.SeqOfId(elem.string, &seq); seq.push_back(arc.olabel); next_elem.string = repository_.IdOfSeq(seq); } all_elems.push_back(this_pr); } } } } PairComparator pc; std::sort(all_elems.begin(), all_elems.end(), pc); // now sorted first on input label, then on state.
typedef typename std::vector<std::pair<Label, Element> >::const_iterator PairIter; PairIter cur = all_elems.begin(), end = all_elems.end(); std::vector<Element> this_subset; while (cur != end) { // Process ranges that share the same input symbol.
Label ilabel = cur->first; this_subset.clear(); while (cur != end && cur->first == ilabel) { this_subset.push_back(cur->second); cur++; } // We now have a subset for this ilabel.
ProcessTransition(state, ilabel, &this_subset); } }
// SubsetToStateId converts a subset (vector of Elements) to a StateId in the
// output fst. This is a hash lookup; if no such state exists, it adds a new
// state to the hash and adds a new pair to the queue. Side effects on hash_
// and Q_, and on output_arcs_ [just affects the size].
OutputStateId SubsetToStateId( const std::vector<Element>& subset) { // may add the subset to the queue.
typedef typename SubsetHash::iterator IterType; IterType iter = hash_.find(&subset); if (iter == hash_.end()) { // was not there.
std::vector<Element>* new_subset = new std::vector<Element>(subset); OutputStateId new_state_id = (OutputStateId)output_arcs_.size(); bool ans = hash_ .insert(std::pair<const std::vector<Element>*, OutputStateId>( new_subset, new_state_id)) .second; assert(ans); output_arcs_.push_back(std::vector<TempArc>()); if (allow_partial_ == false) { // If --allow-partial is not requested, we do the old way.
Q_.push_front(std::pair<std::vector<Element>*, OutputStateId>( new_subset, new_state_id)); } else { // If --allow-partial is requested, we do breadth first search. This
// ensures that when we return partial results, we return the states
// that are reachable by the fewest steps from the start state.
Q_.push_back(std::pair<std::vector<Element>*, OutputStateId>( new_subset, new_state_id)); } return new_state_id; } else { return iter->second; // the OutputStateId.
} }
// ProcessSubset does the processing of a determinized state, i.e. it creates
// transitions out of it and adds new determinized states to the queue if
// necessary. The first stage is "EpsilonClosure" (follow epsilons to get a
// possibly larger set of (states, weights)). After that we ignore epsilons.
// We process the final-weight of the state, and then handle transitions out
// (this may add more determinized states to the queue).
void ProcessSubset( const std::pair<std::vector<Element>*, OutputStateId>& pair) { const std::vector<Element>* subset = pair.first; OutputStateId state = pair.second;
std::vector<Element> closed_subset; // subset after epsilon closure.
epsilon_closure_.GetEpsilonClosure(*subset, &closed_subset);
// Now follow non-epsilon arcs [and also process final states]
ProcessFinal(closed_subset, state);
// Now handle transitions out of these states.
ProcessTransitions(closed_subset, state); }
void Debug();
KALDI_DISALLOW_COPY_AND_ASSIGN(DeterminizerStar); std::deque<std::pair<std::vector<Element>*, OutputStateId> > Q_; // queue of subsets to be processed.
std::vector<std::vector<TempArc> > output_arcs_; // essentially an FST in our format.
const Fst<Arc>* ifst_; float delta_; int max_states_; bool determinized_; // used to check usage.
bool allow_partial_; // output paritial results or not
bool is_partial_; // if we get partial results or not
SubsetKey hasher_; // object that computes keys-- has no data members.
SubsetEqual equal_; // object that compares subsets-- only data member is delta_.
SubsetHash hash_; // hash from Subset to StateId in final Fst.
StringRepository<Label, StringId> repository_; // associate integer id's with sequences of labels.
EpsilonClosure epsilon_closure_; };
template <class F> bool DeterminizeStar(F& ifst, // NOLINT
MutableFst<typename F::Arc>* ofst, float delta, bool* debug_ptr, int max_states, bool allow_partial) { ofst->SetOutputSymbols(ifst.OutputSymbols()); ofst->SetInputSymbols(ifst.InputSymbols()); DeterminizerStar<F> det(ifst, delta, max_states, allow_partial); det.Determinize(debug_ptr); det.Output(ofst); return det.IsPartial(); }
template <class F> bool DeterminizeStar(F& ifst, // NOLINT
MutableFst<GallicArc<typename F::Arc> >* ofst, float delta, bool* debug_ptr, int max_states, bool allow_partial) { ofst->SetOutputSymbols(ifst.InputSymbols()); ofst->SetInputSymbols(ifst.InputSymbols()); DeterminizerStar<F> det(ifst, delta, max_states, allow_partial); det.Determinize(debug_ptr); det.Output(ofst); return det.IsPartial(); }
template <class F> void DeterminizerStar<F>::EpsilonClosure::GetEpsilonClosure( const std::vector<Element>& input_subset, std::vector<Element>* output_subset) { ecinfo_.resize(0); size_t size = input_subset.size(); // find whether input fst is known to be sorted in input label.
bool sorted = ((ifst_->Properties(kILabelSorted, false) & kILabelSorted) != 0);
// size is still the input_subset.size()
for (size_t i = 0; i < size; i++) { ExpandOneElement(input_subset[i], sorted, input_subset[i].weight, true); }
size_t s = queue_2_.size(); if (s == 0) { *output_subset = input_subset; return; } else { // queue_2 not empty. Need to create the vector<info>
for (size_t i = 0; i < size; i++) { // the weight has not been processed yet,
// so put all of them in the "weight_to_process"
ecinfo_.push_back( EpsilonClosureInfo(input_subset[i], input_subset[i].weight, false)); ecinfo_.back().element.weight = Weight::Zero(); // clear the weight
if (id_to_index_.size() < input_subset[i].state + 1) { id_to_index_.resize(2 * input_subset[i].state + 1, -1); } id_to_index_[input_subset[i].state] = ecinfo_.size() - 1; } }
{ Element elem; elem.weight = Weight::Zero(); for (size_t i = 0; i < s; i++) { elem.state = queue_2_[i].state; elem.string = queue_2_[i].string; AddOneElement(elem, queue_2_[i].weight); } queue_2_.resize(0); }
int counter = 0; // relates to max-states option, used for test.
while (!queue_.empty()) { InputStateId id = queue_.front();
// no need to check validity of the index
// since anything in the queue we are sure they're in the "virtual set"
int index = id_to_index_[id]; EpsilonClosureInfo& info = ecinfo_[index]; Element& elem = info.element; Weight unprocessed_weight = info.weight_to_process;
elem.weight = Plus(elem.weight, unprocessed_weight); info.weight_to_process = Weight::Zero();
info.in_queue = false; queue_.pop_front();
if (max_states_ > 0 && counter++ > max_states_) { KALDI_ERR << "Determinization aborted since looped more than " << max_states_ << " times during epsilon closure"; }
// generally we need to be careful about iterator-invalidation problem
// here we pass a reference (elem), which could be an issue.
// In the beginning of ExpandOneElement, we make a copy of elem.string
// to avoid that issue
ExpandOneElement(elem, sorted, unprocessed_weight); }
{ // this sorting is based on StateId
sort(ecinfo_.begin(), ecinfo_.end());
output_subset->clear();
size = ecinfo_.size(); output_subset->reserve(size); for (size_t i = 0; i < size; i++) { EpsilonClosureInfo& info = ecinfo_[i]; if (info.weight_to_process != Weight::Zero()) { info.element.weight = Plus(info.element.weight, info.weight_to_process); } output_subset->push_back(info.element); } } }
template <class F> void DeterminizerStar<F>::EpsilonClosure::AddOneElement( const Element& elem, const Weight& unprocessed_weight) { // first we try to find the element info in the ecinfo_ vector
int index = -1; if (elem.state < id_to_index_.size()) { index = id_to_index_[elem.state]; } if (index != -1) { if (index >= ecinfo_.size()) { index = -1; } else if (ecinfo_[index].element.state != elem.state) { // since ecinfo_ might store outdated information, we need to check
index = -1; } }
if (index == -1) { // was no such StateId: insert and add to queue.
ecinfo_.push_back(EpsilonClosureInfo(elem, unprocessed_weight, true)); size_t size = id_to_index_.size(); if (size < elem.state + 1) { // double the size to reduce memory operations
id_to_index_.resize(2 * elem.state + 1, -1); } id_to_index_[elem.state] = ecinfo_.size() - 1; queue_.push_back(elem.state);
} else { // one is already there. Add weights.
EpsilonClosureInfo& info = ecinfo_[index]; if (info.element.string != elem.string) { // Non-functional FST.
std::ostringstream ss; ss << "FST was not functional -> not determinizable."; { // Print some debugging information. Can be helpful to debug
// the inputs when FSTs are mysteriously non-functional.
std::vector<Label> tmp_seq; repository_->SeqOfId(info.element.string, &tmp_seq); ss << "\nFirst string:"; for (size_t i = 0; i < tmp_seq.size(); i++) ss << ' ' << tmp_seq[i]; ss << "\nSecond string:"; repository_->SeqOfId(elem.string, &tmp_seq); for (size_t i = 0; i < tmp_seq.size(); i++) ss << ' ' << tmp_seq[i]; } KALDI_ERR << ss.str(); }
info.weight_to_process = Plus(info.weight_to_process, unprocessed_weight);
if (!info.in_queue) { // this is because the code in "else" below: the
// iter->second.weight_to_process might not be Zero()
Weight weight = Plus(info.element.weight, info.weight_to_process);
// What is done below is, we propagate the weight (by adding them
// to the queue only when the change is big enough;
// otherwise we just store the weight, until before returning
// we add the element.weight and weight_to_process together
if (!ApproxEqual(weight, info.element.weight, delta_)) { // add extra part of weight to queue.
info.in_queue = true; queue_.push_back(elem.state); } } } }
template <class F> void DeterminizerStar<F>::EpsilonClosure::ExpandOneElement( const Element& elem, bool sorted, const Weight& unprocessed_weight, bool save_to_queue_2) { StringId str = elem.string; // copy it here because there is an iterator-
// - invalidation problem (it really happens for some FSTs)
// now we are going to propagate the "unprocessed_weight"
for (ArcIterator<Fst<Arc> > aiter(*ifst_, elem.state); !aiter.Done(); aiter.Next()) { const Arc& arc = aiter.Value(); if (sorted && arc.ilabel > 0) { break; // Break from the loop: due to sorting there will be no
// more transitions with epsilons as input labels.
} if (arc.ilabel != 0) { continue; // we only process epsilons here
} Element next_elem; next_elem.state = arc.nextstate; next_elem.weight = Weight::Zero(); Weight next_unprocessed_weight = Times(unprocessed_weight, arc.weight);
// now must append strings
if (arc.olabel == 0) { next_elem.string = str; } else { std::vector<Label> seq; repository_->SeqOfId(str, &seq); if (arc.olabel != 0) seq.push_back(arc.olabel); next_elem.string = repository_->IdOfSeq(seq); } if (save_to_queue_2) { next_elem.weight = next_unprocessed_weight; queue_2_.push_back(next_elem); } else { AddOneElement(next_elem, next_unprocessed_weight); } } }
template <class F> void DeterminizerStar<F>::Output(MutableFst<GallicArc<Arc> >* ofst, bool destroy) { assert(determinized_); if (destroy) determinized_ = false; typedef GallicWeight<Label, Weight> ThisGallicWeight; typedef typename Arc::StateId StateId; if (destroy) FreeMostMemory(); StateId nStates = static_cast<StateId>(output_arcs_.size()); ofst->DeleteStates(); ofst->SetStart(kNoStateId); if (nStates == 0) { return; } for (StateId s = 0; s < nStates; s++) { OutputStateId news = ofst->AddState(); assert(news == s); } ofst->SetStart(0); // now process transitions.
for (StateId this_state = 0; this_state < nStates; this_state++) { std::vector<TempArc>& this_vec(output_arcs_[this_state]); typename std::vector<TempArc>::const_iterator iter = this_vec.begin(), end = this_vec.end(); for (; iter != end; ++iter) { const TempArc& temp_arc(*iter); GallicArc<Arc> new_arc; std::vector<Label> seq; repository_.SeqOfId(temp_arc.ostring, &seq); StringWeight<Label, STRING_LEFT> string_weight; for (size_t i = 0; i < seq.size(); i++) string_weight.PushBack(seq[i]); ThisGallicWeight gallic_weight(string_weight, temp_arc.weight);
if (temp_arc.nextstate == kNoStateId) { // is really final weight.
ofst->SetFinal(this_state, gallic_weight); } else { // is really an arc.
new_arc.nextstate = temp_arc.nextstate; new_arc.ilabel = temp_arc.ilabel; new_arc.olabel = temp_arc.ilabel; // acceptor. input == output.
new_arc.weight = gallic_weight; // includes string and weight.
ofst->AddArc(this_state, new_arc); } } // Free up memory. Do this inside the loop as ofst is also allocating
// memory
if (destroy) { std::vector<TempArc> temp; temp.swap(this_vec); } } if (destroy) { std::vector<std::vector<TempArc> > temp; temp.swap(output_arcs_); } }
template <class F> void DeterminizerStar<F>::Output(MutableFst<Arc>* ofst, bool destroy) { assert(determinized_); if (destroy) determinized_ = false; // Outputs to standard fst.
OutputStateId num_states = static_cast<OutputStateId>(output_arcs_.size()); if (destroy) FreeMostMemory(); ofst->DeleteStates(); if (num_states == 0) { ofst->SetStart(kNoStateId); return; } // Add basic states-- but will add extra ones to account for strings on
// output.
for (OutputStateId s = 0; s < num_states; s++) { OutputStateId news = ofst->AddState(); assert(news == s); } ofst->SetStart(0); for (OutputStateId this_state = 0; this_state < num_states; this_state++) { std::vector<TempArc>& this_vec(output_arcs_[this_state]);
typename std::vector<TempArc>::const_iterator iter = this_vec.begin(), end = this_vec.end(); for (; iter != end; ++iter) { const TempArc& temp_arc(*iter); std::vector<Label> seq; repository_.SeqOfId(temp_arc.ostring, &seq); if (temp_arc.nextstate == kNoStateId) { // Really a final weight.
// Make a sequence of states going to a final state, with the strings as
// labels. Put the weight on the first arc.
OutputStateId cur_state = this_state; for (size_t i = 0; i < seq.size(); i++) { OutputStateId next_state = ofst->AddState(); Arc arc; arc.nextstate = next_state; arc.weight = (i == 0 ? temp_arc.weight : Weight::One()); arc.ilabel = 0; // epsilon.
arc.olabel = seq[i]; ofst->AddArc(cur_state, arc); cur_state = next_state; } ofst->SetFinal(cur_state, (seq.size() == 0 ? temp_arc.weight : Weight::One())); } else { // Really an arc.
OutputStateId cur_state = this_state; // Have to be careful with this integer comparison (i+1 < seq.size())
// because unsigned. i < seq.size()-1 could fail for zero-length
// sequences.
for (size_t i = 0; i + 1 < seq.size(); i++) { // for all but the last element of seq, create new state.
OutputStateId next_state = ofst->AddState(); Arc arc; arc.nextstate = next_state; arc.weight = (i == 0 ? temp_arc.weight : Weight::One()); arc.ilabel = (i == 0 ? temp_arc.ilabel : 0); // put ilabel on first element of seq.
arc.olabel = seq[i]; ofst->AddArc(cur_state, arc); cur_state = next_state; } // Add the final arc in the sequence.
Arc arc; arc.nextstate = temp_arc.nextstate; arc.weight = (seq.size() <= 1 ? temp_arc.weight : Weight::One()); arc.ilabel = (seq.size() <= 1 ? temp_arc.ilabel : 0); arc.olabel = (seq.size() > 0 ? seq.back() : 0); ofst->AddArc(cur_state, arc); } } // Free up memory. Do this inside the loop as ofst is also allocating
// memory
if (destroy) { std::vector<TempArc> temp; temp.swap(this_vec); } } if (destroy) { std::vector<std::vector<TempArc> > temp; temp.swap(output_arcs_); repository_.Destroy(); } }
template <class F> void DeterminizerStar<F>::ProcessTransition(OutputStateId state, Label ilabel, std::vector<Element>* subset) { // At input, "subset" may contain duplicates for a given dest state (but in
// sorted order). This function removes duplicates from "subset", normalizes
// it, and adds a transition to the dest. state (possibly affecting Q_ and
// hash_, if state did not exist).
typedef typename std::vector<Element>::iterator IterType; { // This block makes the subset have one unique Element per state, adding
// the weights.
IterType cur_in = subset->begin(), cur_out = cur_in, end = subset->end(); size_t num_out = 0; // Merge elements with same state-id
while (cur_in != end) { // while we have more elements to process.
// At this point, cur_out points to location of next place we want to put
// an element, cur_in points to location of next element we want to
// process.
if (cur_in != cur_out) *cur_out = *cur_in; cur_in++; while (cur_in != end && cur_in->state == cur_out->state) { // merge elements.
if (cur_in->string != cur_out->string) { KALDI_ERR << "FST was not functional -> not determinizable"; } cur_out->weight = Plus(cur_out->weight, cur_in->weight); cur_in++; } cur_out++; num_out++; } subset->resize(num_out); }
StringId common_str; Weight tot_weight; { // This block computes common_str and tot_weight (essentially: the common
// divisor)
// and removes them from the elements.
std::vector<Label> seq;
IterType begin = subset->begin(), iter, end = subset->end(); { // This block computes "seq", which is the common prefix, and
// "common_str",
// which is the StringId version of "seq".
std::vector<Label> tmp_seq; for (iter = begin; iter != end; ++iter) { if (iter == begin) { repository_.SeqOfId(iter->string, &seq); } else { repository_.SeqOfId(iter->string, &tmp_seq); if (tmp_seq.size() < seq.size()) seq.resize(tmp_seq.size()); // size of shortest one.
for (size_t i = 0; i < seq.size(); i++) // seq.size() is the shorter one at this point.
if (tmp_seq[i] != seq[i]) seq.resize(i); } if (seq.size() == 0) break; // will not get any prefix.
} common_str = repository_.IdOfSeq(seq); }
{ // This block computes "tot_weight".
iter = begin; tot_weight = iter->weight; for (++iter; iter != end; ++iter) tot_weight = Plus(tot_weight, iter->weight); }
// Now divide out common stuff from elements.
size_t prefix_len = seq.size(); for (iter = begin; iter != end; ++iter) { iter->weight = Divide(iter->weight, tot_weight); iter->string = repository_.RemovePrefix(iter->string, prefix_len); } }
// Now add an arc to the state that the subset represents.
// We may create a new state id for this (in SubsetToStateId).
TempArc temp_arc; temp_arc.ilabel = ilabel; temp_arc.nextstate = SubsetToStateId(*subset); // may or may not really add the subset.
temp_arc.ostring = common_str; temp_arc.weight = tot_weight; output_arcs_[state].push_back(temp_arc); // record the arc.
}
template <class F> void DeterminizerStar<F>::Debug() { // this function called if you send a signal
// SIGUSR1 to the process (and it's caught by the handler in
// fstdeterminizestar). It prints out some traceback
// info and exits.
KALDI_WARN << "Debug function called (probably SIGUSR1 caught)"; // free up memory from the hash as we need a little memory
{ SubsetHash hash_tmp; std::swap(hash_tmp, hash_); }
if (output_arcs_.size() <= 2) { KALDI_ERR << "Nothing to trace back"; } size_t max_state = output_arcs_.size() - 2; // don't take the last
// one as we might be halfway into constructing it.
std::vector<OutputStateId> predecessor(max_state + 1, kNoStateId); for (size_t i = 0; i < max_state; i++) { for (size_t j = 0; j < output_arcs_[i].size(); j++) { OutputStateId nextstate = output_arcs_[i][j].nextstate; // Always find an earlier-numbered predecessor; this
// is always possible because of the way the algorithm
// works.
if (nextstate <= max_state && nextstate > i) predecessor[nextstate] = i; } } std::vector<std::pair<Label, StringId> > traceback; // 'traceback' is a pair of (ilabel, olabel-seq).
OutputStateId cur_state = max_state; // A recently constructed state.
while (cur_state != 0 && cur_state != kNoStateId) { OutputStateId last_state = predecessor[cur_state]; std::pair<Label, StringId> p; size_t i; for (i = 0; i < output_arcs_[last_state].size(); i++) { if (output_arcs_[last_state][i].nextstate == cur_state) { p.first = output_arcs_[last_state][i].ilabel; p.second = output_arcs_[last_state][i].ostring; traceback.push_back(p); break; } } KALDI_ASSERT(i != output_arcs_[last_state].size()); // Or fell off loop.
cur_state = last_state; } if (cur_state == kNoStateId) KALDI_WARN << "Traceback did not reach start state " << "(possibly debug-code error)";
std::stringstream ss; ss << "Traceback follows in format " << "ilabel (olabel olabel) ilabel (olabel) ... :"; for (ssize_t i = traceback.size() - 1; i >= 0; i--) { ss << ' ' << traceback[i].first << " ( "; std::vector<Label> seq; repository_.SeqOfId(traceback[i].second, &seq); for (size_t j = 0; j < seq.size(); j++) ss << seq[j] << ' '; ss << ')'; } KALDI_ERR << ss.str(); }
} // namespace fst
#endif // KALDI_FSTEXT_DETERMINIZE_STAR_INL_H_
|