|
|
// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Utility classes for the recursive replacement of FSTs (RTNs).
#ifndef FST_REPLACE_UTIL_H_
#define FST_REPLACE_UTIL_H_
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <utility>
#include <vector>
#include <fst/log.h>
#include <fst/cc-visitors.h>
#include <fst/connect.h>
#include <fst/dfs-visit.h>
#include <fst/fst.h>
#include <fst/mutable-fst.h>
#include <fst/properties.h>
#include <fst/topsort.h>
#include <fst/util.h>
#include <fst/vector-fst.h>
#include <unordered_map>
#include <unordered_set>
namespace fst {
// This specifies what labels to output on the call or return arc. Note that
// REPLACE_LABEL_INPUT and REPLACE_LABEL_OUTPUT will produce transducers when
// applied to acceptors.
enum ReplaceLabelType { // Epsilon labels on both input and output.
REPLACE_LABEL_NEITHER = 1, // Non-epsilon labels on input and epsilon on output.
REPLACE_LABEL_INPUT = 2, // Epsilon on input and non-epsilon on output.
REPLACE_LABEL_OUTPUT = 3, // Non-epsilon labels on both input and output.
REPLACE_LABEL_BOTH = 4 };
// By default ReplaceUtil will copy the input label of the replace arc.
// The call_label_type and return_label_type options specify how to manage
// the labels of the call arc and the return arc of the replace FST
struct ReplaceUtilOptions { int64_t root; // Root rule for expansion.
ReplaceLabelType call_label_type; // How to label call arc.
ReplaceLabelType return_label_type; // How to label return arc.
int64_t return_label; // Label to put on return arc.
explicit ReplaceUtilOptions( int64_t root = kNoLabel, ReplaceLabelType call_label_type = REPLACE_LABEL_INPUT, ReplaceLabelType return_label_type = REPLACE_LABEL_NEITHER, int64_t return_label = 0) : root(root), call_label_type(call_label_type), return_label_type(return_label_type), return_label(return_label) {}
// For backwards compatibility.
ReplaceUtilOptions(int64_t root, bool epsilon_replace_arc) : ReplaceUtilOptions(root, epsilon_replace_arc ? REPLACE_LABEL_NEITHER : REPLACE_LABEL_INPUT) {} };
// Every non-terminal on a path appears as the first label on that path in every
// FST associated with a given SCC of the replace dependency graph. This would
// be true if the SCC were formed from left-linear grammar rules.
inline constexpr uint8_t kReplaceSCCLeftLinear = 0x01; // Every non-terminal on a path appears as the final label on that path in every
// FST associated with a given SCC of the replace dependency graph. This would
// be true if the SCC were formed from right-linear grammar rules.
inline constexpr uint8_t kReplaceSCCRightLinear = 0x02; // The SCC in the replace dependency graph has more than one state or a
// self-loop.
inline constexpr uint8_t kReplaceSCCNonTrivial = 0x04;
// Defined in replace.h.
template <class Arc> void Replace( const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>> &, MutableFst<Arc> *, const ReplaceUtilOptions &);
// Utility class for the recursive replacement of FSTs (RTNs). The user provides
// a set of label/FST pairs at construction. These are used by methods for
// testing cyclic dependencies and connectedness and doing RTN connection and
// specific FST replacement by label or for various optimization properties. The
// modified results can be obtained with the GetFstPairs() or
// GetMutableFstPairs() methods.
template <class Arc> class ReplaceUtil { public: using Label = typename Arc::Label; using StateId = typename Arc::StateId; using Weight = typename Arc::Weight;
using FstPair = std::pair<Label, const Fst<Arc> *>; using MutableFstPair = std::pair<Label, MutableFst<Arc> *>; using NonTerminalHash = std::unordered_map<Label, Label>;
// Constructs from mutable FSTs; FST ownership is given to ReplaceUtil.
ReplaceUtil(const std::vector<MutableFstPair> &fst_pairs, const ReplaceUtilOptions &opts);
// Constructs from FSTs; FST ownership is retained by caller.
ReplaceUtil(const std::vector<FstPair> &fst_pairs, const ReplaceUtilOptions &opts);
// Constructs from ReplaceFst internals; FST ownership is retained by caller.
ReplaceUtil(const std::vector<std::unique_ptr<const Fst<Arc>>> &fst_array, const NonTerminalHash &nonterminal_hash, const ReplaceUtilOptions &opts);
~ReplaceUtil() { for (Label i = 0; i < fst_array_.size(); ++i) delete fst_array_[i]; }
// True if the non-terminal dependencies are cyclic. Cyclic dependencies will
// result in an unexpandable FST.
bool CyclicDependencies() const { GetDependencies(false); return depprops_ & kCyclic; }
// Returns the strongly-connected component ID in the dependency graph of the
// replace FSTS.
StateId SCC(Label label) const { GetDependencies(false); if (const auto it = nonterminal_hash_.find(label); it != nonterminal_hash_.end()) { return depscc_[it->second]; } return kNoStateId; }
// Returns properties for the strongly-connected component in the dependency
// graph of the replace FSTs. If the SCC is kReplaceSCCLeftLinear or
// kReplaceSCCRightLinear, that SCC can be represented as finite-state despite
// any cyclic dependencies, but not by the usual replacement operation (see
// fst/extensions/pdt/replace.h).
uint8_t SCCProperties(StateId scc_id) { GetSCCProperties(); return depsccprops_[scc_id]; }
// Returns true if no useless FSTs, states or transitions are present in the
// RTN.
bool Connected() const { GetDependencies(false); uint64_t props = kAccessible | kCoAccessible; for (Label i = 0; i < fst_array_.size(); ++i) { if (!fst_array_[i]) continue; if (fst_array_[i]->Properties(props, true) != props || !depaccess_[i]) { return false; } } return true; }
// Removes useless FSTs, states and transitions from the RTN.
void Connect();
// Replaces FSTs specified by labels, unless there are cyclic dependencies.
void ReplaceLabels(const std::vector<Label> &labels);
// Replaces FSTs that have at most nstates states, narcs arcs and nnonterm
// non-terminals (updating in reverse dependency order), unless there are
// cyclic dependencies.
void ReplaceBySize(size_t nstates, size_t narcs, size_t nnonterms);
// Replaces singleton FSTS, unless there are cyclic dependencies.
void ReplaceTrivial() { ReplaceBySize(2, 1, 1); }
// Replaces non-terminals that have at most ninstances instances (updating in
// dependency order), unless there are cyclic dependencies.
void ReplaceByInstances(size_t ninstances);
// Replaces non-terminals that have only one instance, unless there are cyclic
// dependencies.
void ReplaceUnique() { ReplaceByInstances(1); }
// Returns label/FST pairs, retaining FST ownership.
void GetFstPairs(std::vector<FstPair> *fst_pairs);
// Returns label/mutable FST pairs, giving FST ownership over to the caller.
void GetMutableFstPairs(std::vector<MutableFstPair> *mutable_fst_pairs);
private: // FST statistics.
struct ReplaceStats { StateId nstates; // Number of states.
StateId nfinal; // Number of final states.
size_t narcs; // Number of arcs.
Label nnonterms; // Number of non-terminals in FST.
size_t nref; // Number of non-terminal instances referring to this FST.
// Number of times that ith FST references this FST
std::map<Label, size_t> inref; // Number of times that this FST references the ith FST
std::map<Label, size_t> outref;
ReplaceStats() : nstates(0), nfinal(0), narcs(0), nnonterms(0), nref(0) {} };
// Checks that Mutable FSTs exists, creating them if necessary.
void CheckMutableFsts();
// Computes the dependency graph for the RTN, computing dependency statistics
// if stats is true.
void GetDependencies(bool stats) const;
void ClearDependencies() const { depfst_.DeleteStates(); stats_.clear(); depprops_ = 0; depsccprops_.clear(); have_stats_ = false; }
// Gets topological order of dependencies, returning false with cyclic input.
bool GetTopOrder(const Fst<Arc> &fst, std::vector<Label> *toporder) const;
// Updates statistics to reflect the replacement of the jth FST.
void UpdateStats(Label j);
// Computes the properties for the strongly-connected component in the
// dependency graph of the replace FSTs.
void GetSCCProperties() const;
Label root_label_; // Root non-terminal.
Label root_fst_; // Root FST ID.
ReplaceLabelType call_label_type_; // See Replace().
ReplaceLabelType return_label_type_; // See Replace().
int64_t return_label_; // See Replace().
std::vector<const Fst<Arc> *> fst_array_; // FST per ID.
std::vector<MutableFst<Arc> *> mutable_fst_array_; // Mutable FST per ID.
std::vector<Label> nonterminal_array_; // FST ID to non-terminal.
NonTerminalHash nonterminal_hash_; // Non-terminal to FST ID.
mutable VectorFst<Arc> depfst_; // FST ID dependencies.
mutable std::vector<StateId> depscc_; // FST SCC ID.
mutable std::vector<bool> depaccess_; // FST ID accessibility.
mutable uint64_t depprops_; // Dependency FST props.
mutable bool have_stats_; // Have dependency statistics?
mutable std::vector<ReplaceStats> stats_; // Per-FST statistics.
mutable std::vector<uint8_t> depsccprops_; // SCC properties.
ReplaceUtil(const ReplaceUtil &) = delete; ReplaceUtil &operator=(const ReplaceUtil &) = delete; };
template <class Arc> ReplaceUtil<Arc>::ReplaceUtil(const std::vector<MutableFstPair> &fst_pairs, const ReplaceUtilOptions &opts) : root_label_(opts.root), call_label_type_(opts.call_label_type), return_label_type_(opts.return_label_type), return_label_(opts.return_label), depprops_(0), have_stats_(false) { fst_array_.push_back(nullptr); mutable_fst_array_.push_back(nullptr); nonterminal_array_.push_back(kNoLabel); for (const auto &fst_pair : fst_pairs) { const auto label = fst_pair.first; auto *fst = fst_pair.second; nonterminal_hash_[label] = fst_array_.size(); nonterminal_array_.push_back(label); fst_array_.push_back(fst); mutable_fst_array_.push_back(fst); } root_fst_ = nonterminal_hash_[root_label_]; if (!root_fst_) { FSTERROR() << "ReplaceUtil: No root FST for label: " << root_label_; } }
template <class Arc> ReplaceUtil<Arc>::ReplaceUtil(const std::vector<FstPair> &fst_pairs, const ReplaceUtilOptions &opts) : root_label_(opts.root), call_label_type_(opts.call_label_type), return_label_type_(opts.return_label_type), return_label_(opts.return_label), depprops_(0), have_stats_(false) { fst_array_.push_back(nullptr); nonterminal_array_.push_back(kNoLabel); for (const auto &fst_pair : fst_pairs) { const auto label = fst_pair.first; const auto *fst = fst_pair.second; nonterminal_hash_[label] = fst_array_.size(); nonterminal_array_.push_back(label); fst_array_.push_back(fst->Copy()); } root_fst_ = nonterminal_hash_[root_label_]; if (!root_fst_) { FSTERROR() << "ReplaceUtil: No root FST for label: " << root_label_; } }
template <class Arc> ReplaceUtil<Arc>::ReplaceUtil( const std::vector<std::unique_ptr<const Fst<Arc>>> &fst_array, const NonTerminalHash &nonterminal_hash, const ReplaceUtilOptions &opts) : root_fst_(opts.root), call_label_type_(opts.call_label_type), return_label_type_(opts.return_label_type), return_label_(opts.return_label), nonterminal_array_(fst_array.size()), nonterminal_hash_(nonterminal_hash), depprops_(0), have_stats_(false) { fst_array_.push_back(nullptr); for (size_t i = 1; i < fst_array.size(); ++i) { fst_array_.push_back(fst_array[i]->Copy()); } for (auto it = nonterminal_hash.begin(); it != nonterminal_hash.end(); ++it) { nonterminal_array_[it->second] = it->first; } root_label_ = nonterminal_array_[root_fst_]; }
template <class Arc> void ReplaceUtil<Arc>::GetDependencies(bool stats) const { if (depfst_.NumStates() > 0) { if (stats && !have_stats_) { ClearDependencies(); } else { return; } } have_stats_ = stats; if (have_stats_) stats_.reserve(fst_array_.size()); for (Label ilabel = 0; ilabel < fst_array_.size(); ++ilabel) { depfst_.AddState(); depfst_.SetFinal(ilabel); if (have_stats_) stats_.push_back(ReplaceStats()); } depfst_.SetStart(root_fst_); // An arc from each state (representing the FST) to the state representing the
// FST being replaced
for (Label ilabel = 0; ilabel < fst_array_.size(); ++ilabel) { const auto *ifst = fst_array_[ilabel]; if (!ifst) continue; for (StateIterator<Fst<Arc>> siter(*ifst); !siter.Done(); siter.Next()) { const auto s = siter.Value(); if (have_stats_) { ++stats_[ilabel].nstates; if (ifst->Final(s) != Weight::Zero()) ++stats_[ilabel].nfinal; } for (ArcIterator<Fst<Arc>> aiter(*ifst, s); !aiter.Done(); aiter.Next()) { if (have_stats_) ++stats_[ilabel].narcs; const auto &arc = aiter.Value(); if (auto it = nonterminal_hash_.find(arc.olabel); it != nonterminal_hash_.end()) { const auto nextstate = it->second; depfst_.EmplaceArc(ilabel, arc.olabel, arc.olabel, nextstate); if (have_stats_) { ++stats_[ilabel].nnonterms; ++stats_[nextstate].nref; ++stats_[nextstate].inref[ilabel]; ++stats_[ilabel].outref[nextstate]; } } } } } // Computes accessibility info.
SccVisitor<Arc> scc_visitor(&depscc_, &depaccess_, nullptr, &depprops_); DfsVisit(depfst_, &scc_visitor); }
template <class Arc> void ReplaceUtil<Arc>::UpdateStats(Label j) { if (!have_stats_) { FSTERROR() << "ReplaceUtil::UpdateStats: Stats not available"; return; } if (j == root_fst_) return; // Can't replace root.
for (auto in = stats_[j].inref.begin(); in != stats_[j].inref.end(); ++in) { const auto i = in->first; const auto ni = in->second; stats_[i].nstates += stats_[j].nstates * ni; stats_[i].narcs += (stats_[j].narcs + 1) * ni; stats_[i].nnonterms += (stats_[j].nnonterms - 1) * ni; stats_[i].outref.erase(j); for (auto out = stats_[j].outref.begin(); out != stats_[j].outref.end(); ++out) { const auto k = out->first; const auto nk = out->second; stats_[i].outref[k] += ni * nk; } } for (auto out = stats_[j].outref.begin(); out != stats_[j].outref.end(); ++out) { const auto k = out->first; const auto nk = out->second; stats_[k].nref -= nk; stats_[k].inref.erase(j); for (auto in = stats_[j].inref.begin(); in != stats_[j].inref.end(); ++in) { const auto i = in->first; const auto ni = in->second; stats_[k].inref[i] += ni * nk; stats_[k].nref += ni * nk; } } }
template <class Arc> void ReplaceUtil<Arc>::CheckMutableFsts() { if (mutable_fst_array_.empty()) { for (Label i = 0; i < fst_array_.size(); ++i) { if (!fst_array_[i]) { mutable_fst_array_.push_back(nullptr); } else { mutable_fst_array_.push_back(new VectorFst<Arc>(*fst_array_[i])); delete fst_array_[i]; fst_array_[i] = mutable_fst_array_[i]; } } } }
template <class Arc> void ReplaceUtil<Arc>::Connect() { CheckMutableFsts(); static constexpr auto props = kAccessible | kCoAccessible; for (auto *mutable_fst : mutable_fst_array_) { if (!mutable_fst) continue; if (mutable_fst->Properties(props, false) != props) { fst::Connect(mutable_fst); } } GetDependencies(false); for (Label i = 0; i < mutable_fst_array_.size(); ++i) { auto *fst = mutable_fst_array_[i]; if (fst && !depaccess_[i]) { delete fst; fst_array_[i] = nullptr; mutable_fst_array_[i] = nullptr; } } ClearDependencies(); }
template <class Arc> bool ReplaceUtil<Arc>::GetTopOrder(const Fst<Arc> &fst, std::vector<Label> *toporder) const { // Finds topological order of dependencies.
std::vector<StateId> order; bool acyclic = false; TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic); DfsVisit(fst, &top_order_visitor); if (!acyclic) { LOG(WARNING) << "ReplaceUtil::GetTopOrder: Cyclical label dependencies"; return false; } toporder->resize(order.size()); for (Label i = 0; i < order.size(); ++i) (*toporder)[order[i]] = i; return true; }
template <class Arc> void ReplaceUtil<Arc>::ReplaceLabels(const std::vector<Label> &labels) { CheckMutableFsts(); std::unordered_set<Label> label_set; for (const auto label : labels) { // Can't replace root.
if (label != root_label_) label_set.insert(label); } // Finds FST dependencies restricted to the labels requested.
GetDependencies(false); VectorFst<Arc> pfst(depfst_); for (StateId i = 0; i < pfst.NumStates(); ++i) { std::vector<Arc> arcs; for (ArcIterator<VectorFst<Arc>> aiter(pfst, i); !aiter.Done(); aiter.Next()) { const auto &arc = aiter.Value(); const auto label = nonterminal_array_[arc.nextstate]; if (label_set.count(label) > 0) arcs.push_back(arc); } pfst.DeleteArcs(i); for (auto &arc : arcs) pfst.AddArc(i, std::move(arc)); } std::vector<Label> toporder; if (!GetTopOrder(pfst, &toporder)) { ClearDependencies(); return; } // Visits FSTs in reverse topological order of dependencies and performs
// replacements.
for (Label o = toporder.size() - 1; o >= 0; --o) { std::vector<FstPair> fst_pairs; auto s = toporder[o]; for (ArcIterator<VectorFst<Arc>> aiter(pfst, s); !aiter.Done(); aiter.Next()) { const auto &arc = aiter.Value(); const auto label = nonterminal_array_[arc.nextstate]; const auto *fst = fst_array_[arc.nextstate]; fst_pairs.emplace_back(label, fst); } if (fst_pairs.empty()) continue; const auto label = nonterminal_array_[s]; const auto *fst = fst_array_[s]; fst_pairs.emplace_back(label, fst); const ReplaceUtilOptions opts(label, call_label_type_, return_label_type_, return_label_); Replace(fst_pairs, mutable_fst_array_[s], opts); } ClearDependencies(); }
template <class Arc> void ReplaceUtil<Arc>::ReplaceBySize(size_t nstates, size_t narcs, size_t nnonterms) { std::vector<Label> labels; GetDependencies(true); std::vector<Label> toporder; if (!GetTopOrder(depfst_, &toporder)) { ClearDependencies(); return; } for (Label o = toporder.size() - 1; o >= 0; --o) { const auto j = toporder[o]; if (stats_[j].nstates <= nstates && stats_[j].narcs <= narcs && stats_[j].nnonterms <= nnonterms) { labels.push_back(nonterminal_array_[j]); UpdateStats(j); } } ReplaceLabels(labels); }
template <class Arc> void ReplaceUtil<Arc>::ReplaceByInstances(size_t ninstances) { std::vector<Label> labels; GetDependencies(true); std::vector<Label> toporder; if (!GetTopOrder(depfst_, &toporder)) { ClearDependencies(); return; } for (Label o = 0; o < toporder.size(); ++o) { const auto j = toporder[o]; if (stats_[j].nref <= ninstances) { labels.push_back(nonterminal_array_[j]); UpdateStats(j); } } ReplaceLabels(labels); }
template <class Arc> void ReplaceUtil<Arc>::GetFstPairs(std::vector<FstPair> *fst_pairs) { CheckMutableFsts(); fst_pairs->clear(); for (Label i = 0; i < fst_array_.size(); ++i) { const auto label = nonterminal_array_[i]; const auto *fst = fst_array_[i]; if (!fst) continue; fst_pairs->emplace_back(label, fst); } }
template <class Arc> void ReplaceUtil<Arc>::GetMutableFstPairs( std::vector<MutableFstPair> *mutable_fst_pairs) { CheckMutableFsts(); mutable_fst_pairs->clear(); for (Label i = 0; i < mutable_fst_array_.size(); ++i) { const auto label = nonterminal_array_[i]; const auto *fst = mutable_fst_array_[i]; if (!fst) continue; mutable_fst_pairs->emplace_back(label, fst->Copy()); } }
template <class Arc> void ReplaceUtil<Arc>::GetSCCProperties() const { if (!depsccprops_.empty()) return; GetDependencies(false); if (depscc_.empty()) return; for (StateId scc = 0; scc < depscc_.size(); ++scc) { depsccprops_.push_back(kReplaceSCCLeftLinear | kReplaceSCCRightLinear); } if (!(depprops_ & kCyclic)) return; // No cyclic dependencies.
// Checks for self-loops in the dependency graph.
for (StateId scc = 0; scc < depscc_.size(); ++scc) { for (ArcIterator<Fst<Arc>> aiter(depfst_, scc); !aiter.Done(); aiter.Next()) { const auto &arc = aiter.Value(); if (arc.nextstate == scc) { // SCC has a self loop.
depsccprops_[scc] |= kReplaceSCCNonTrivial; } } } std::vector<bool> depscc_visited(depscc_.size(), false); for (Label i = 0; i < fst_array_.size(); ++i) { const auto *fst = fst_array_[i]; if (!fst) continue; const auto depscc = depscc_[i]; if (depscc_visited[depscc]) { // SCC has more than one state.
depsccprops_[depscc] |= kReplaceSCCNonTrivial; } depscc_visited[depscc] = true; std::vector<StateId> fstscc; // SCCs of the current FST.
uint64_t fstprops; SccVisitor<Arc> scc_visitor(&fstscc, nullptr, nullptr, &fstprops); DfsVisit(*fst, &scc_visitor); for (StateIterator<Fst<Arc>> siter(*fst); !siter.Done(); siter.Next()) { const auto s = siter.Value(); for (ArcIterator<Fst<Arc>> aiter(*fst, s); !aiter.Done(); aiter.Next()) { const auto &arc = aiter.Value(); auto it = nonterminal_hash_.find(arc.olabel); if (it == nonterminal_hash_.end() || depscc_[it->second] != depscc) { continue; // Skips if a terminal or a non-terminal not in SCC.
} const bool arc_in_cycle = fstscc[s] == fstscc[arc.nextstate]; // Left linear iff all non-terminals are initial.
if (s != fst->Start() || arc_in_cycle) { depsccprops_[depscc] &= ~kReplaceSCCLeftLinear; } // Right linear iff all non-terminals are final.
if (fst->Final(arc.nextstate) == Weight::Zero() || arc_in_cycle) { depsccprops_[depscc] &= ~kReplaceSCCRightLinear; } } } } }
} // namespace fst
#endif // FST_REPLACE_UTIL_H_
|