You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

650 lines
23 KiB

// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Utility classes for the recursive replacement of FSTs (RTNs).
#ifndef FST_REPLACE_UTIL_H_
#define FST_REPLACE_UTIL_H_
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <utility>
#include <vector>
#include <fst/log.h>
#include <fst/cc-visitors.h>
#include <fst/connect.h>
#include <fst/dfs-visit.h>
#include <fst/fst.h>
#include <fst/mutable-fst.h>
#include <fst/properties.h>
#include <fst/topsort.h>
#include <fst/util.h>
#include <fst/vector-fst.h>
#include <unordered_map>
#include <unordered_set>
namespace fst {
// Selects which labels are written on the call arc and on the return arc.
// REPLACE_LABEL_INPUT and REPLACE_LABEL_OUTPUT put a non-epsilon label on only
// one side of the arc, so applying them to acceptors yields transducers.
enum ReplaceLabelType {
  REPLACE_LABEL_NEITHER = 1,  // Epsilon on both the input and output side.
  REPLACE_LABEL_INPUT = 2,    // Non-epsilon input label, epsilon output label.
  REPLACE_LABEL_OUTPUT = 3,   // Epsilon input label, non-epsilon output label.
  REPLACE_LABEL_BOTH = 4      // Non-epsilon labels on both sides.
};
// Options controlling ReplaceUtil. With the defaults, the call arc copies the
// input label of the replaced (non-terminal) arc. call_label_type and
// return_label_type choose how the call and return arcs of the replace FST are
// labeled.
struct ReplaceUtilOptions {
  int64_t root;                        // Root rule for expansion.
  ReplaceLabelType call_label_type;    // How to label call arc.
  ReplaceLabelType return_label_type;  // How to label return arc.
  int64_t return_label;                // Label to put on return arc.

  explicit ReplaceUtilOptions(
      int64_t root = kNoLabel,
      ReplaceLabelType call_label_type = REPLACE_LABEL_INPUT,
      ReplaceLabelType return_label_type = REPLACE_LABEL_NEITHER,
      int64_t return_label = 0)
      : root(root),
        call_label_type(call_label_type),
        return_label_type(return_label_type),
        return_label(return_label) {}

  // Legacy form kept for backwards compatibility. A true epsilon_replace_arc
  // puts epsilons on the call arc; otherwise the input label is copied. The
  // return-arc settings take the same values as the defaults above.
  ReplaceUtilOptions(int64_t root, bool epsilon_replace_arc)
      : root(root),
        call_label_type(epsilon_replace_arc ? REPLACE_LABEL_NEITHER
                                            : REPLACE_LABEL_INPUT),
        return_label_type(REPLACE_LABEL_NEITHER),
        return_label(0) {}
};
// Bit flags describing one SCC of the replace dependency graph, as reported by
// ReplaceUtil::SCCProperties().
//
// Left-linear: in every FST associated with the SCC, each non-terminal on a
// path is the first label on that path (as produced by left-linear grammar
// rules).
inline constexpr uint8_t kReplaceSCCLeftLinear = 1u << 0;
// Right-linear: in every FST associated with the SCC, each non-terminal on a
// path is the final label on that path (as produced by right-linear grammar
// rules).
inline constexpr uint8_t kReplaceSCCRightLinear = 1u << 1;
// Non-trivial: the SCC has more than one state or a self-loop.
inline constexpr uint8_t kReplaceSCCNonTrivial = 1u << 2;
// Forward declaration so ReplaceUtil::ReplaceLabels() can call it; the
// definition lives in replace.h.
template <class Arc>
void Replace(
    const std::vector<std::pair<typename Arc::Label, const Fst<Arc> *>>
        &ifst_array,
    MutableFst<Arc> *ofst, const ReplaceUtilOptions &opts);
// Utility class for the recursive replacement of FSTs (RTNs). The user provides
// a set of label/FST pairs at construction. These are used by methods for
// testing cyclic dependencies and connectedness and doing RTN connection and
// specific FST replacement by label or for various optimization properties. The
// modified results can be obtained with the GetFstPairs() or
// GetMutableFstPairs() methods.
template <class Arc>
class ReplaceUtil {
public:
using Label = typename Arc::Label;
using StateId = typename Arc::StateId;
using Weight = typename Arc::Weight;
using FstPair = std::pair<Label, const Fst<Arc> *>;
using MutableFstPair = std::pair<Label, MutableFst<Arc> *>;
using NonTerminalHash = std::unordered_map<Label, Label>;
// Constructs from mutable FSTs; FST ownership is given to ReplaceUtil.
ReplaceUtil(const std::vector<MutableFstPair> &fst_pairs,
const ReplaceUtilOptions &opts);
// Constructs from FSTs; FST ownership is retained by caller.
ReplaceUtil(const std::vector<FstPair> &fst_pairs,
const ReplaceUtilOptions &opts);
// Constructs from ReplaceFst internals; FST ownership is retained by caller.
ReplaceUtil(const std::vector<std::unique_ptr<const Fst<Arc>>> &fst_array,
const NonTerminalHash &nonterminal_hash,
const ReplaceUtilOptions &opts);
~ReplaceUtil() {
for (Label i = 0; i < fst_array_.size(); ++i) delete fst_array_[i];
}
// True if the non-terminal dependencies are cyclic. Cyclic dependencies will
// result in an unexpandable FST.
bool CyclicDependencies() const {
GetDependencies(false);
return depprops_ & kCyclic;
}
// Returns the strongly-connected component ID in the dependency graph of the
// replace FSTS.
StateId SCC(Label label) const {
GetDependencies(false);
if (const auto it = nonterminal_hash_.find(label);
it != nonterminal_hash_.end()) {
return depscc_[it->second];
}
return kNoStateId;
}
// Returns properties for the strongly-connected component in the dependency
// graph of the replace FSTs. If the SCC is kReplaceSCCLeftLinear or
// kReplaceSCCRightLinear, that SCC can be represented as finite-state despite
// any cyclic dependencies, but not by the usual replacement operation (see
// fst/extensions/pdt/replace.h).
uint8_t SCCProperties(StateId scc_id) {
GetSCCProperties();
return depsccprops_[scc_id];
}
// Returns true if no useless FSTs, states or transitions are present in the
// RTN.
bool Connected() const {
GetDependencies(false);
uint64_t props = kAccessible | kCoAccessible;
for (Label i = 0; i < fst_array_.size(); ++i) {
if (!fst_array_[i]) continue;
if (fst_array_[i]->Properties(props, true) != props || !depaccess_[i]) {
return false;
}
}
return true;
}
// Removes useless FSTs, states and transitions from the RTN.
void Connect();
// Replaces FSTs specified by labels, unless there are cyclic dependencies.
void ReplaceLabels(const std::vector<Label> &labels);
// Replaces FSTs that have at most nstates states, narcs arcs and nnonterm
// non-terminals (updating in reverse dependency order), unless there are
// cyclic dependencies.
void ReplaceBySize(size_t nstates, size_t narcs, size_t nnonterms);
// Replaces singleton FSTS, unless there are cyclic dependencies.
void ReplaceTrivial() { ReplaceBySize(2, 1, 1); }
// Replaces non-terminals that have at most ninstances instances (updating in
// dependency order), unless there are cyclic dependencies.
void ReplaceByInstances(size_t ninstances);
// Replaces non-terminals that have only one instance, unless there are cyclic
// dependencies.
void ReplaceUnique() { ReplaceByInstances(1); }
// Returns label/FST pairs, retaining FST ownership.
void GetFstPairs(std::vector<FstPair> *fst_pairs);
// Returns label/mutable FST pairs, giving FST ownership over to the caller.
void GetMutableFstPairs(std::vector<MutableFstPair> *mutable_fst_pairs);
private:
// FST statistics.
struct ReplaceStats {
StateId nstates; // Number of states.
StateId nfinal; // Number of final states.
size_t narcs; // Number of arcs.
Label nnonterms; // Number of non-terminals in FST.
size_t nref; // Number of non-terminal instances referring to this FST.
// Number of times that ith FST references this FST
std::map<Label, size_t> inref;
// Number of times that this FST references the ith FST
std::map<Label, size_t> outref;
ReplaceStats() : nstates(0), nfinal(0), narcs(0), nnonterms(0), nref(0) {}
};
// Checks that Mutable FSTs exists, creating them if necessary.
void CheckMutableFsts();
// Computes the dependency graph for the RTN, computing dependency statistics
// if stats is true.
void GetDependencies(bool stats) const;
void ClearDependencies() const {
depfst_.DeleteStates();
stats_.clear();
depprops_ = 0;
depsccprops_.clear();
have_stats_ = false;
}
// Gets topological order of dependencies, returning false with cyclic input.
bool GetTopOrder(const Fst<Arc> &fst, std::vector<Label> *toporder) const;
// Updates statistics to reflect the replacement of the jth FST.
void UpdateStats(Label j);
// Computes the properties for the strongly-connected component in the
// dependency graph of the replace FSTs.
void GetSCCProperties() const;
Label root_label_; // Root non-terminal.
Label root_fst_; // Root FST ID.
ReplaceLabelType call_label_type_; // See Replace().
ReplaceLabelType return_label_type_; // See Replace().
int64_t return_label_; // See Replace().
std::vector<const Fst<Arc> *> fst_array_; // FST per ID.
std::vector<MutableFst<Arc> *> mutable_fst_array_; // Mutable FST per ID.
std::vector<Label> nonterminal_array_; // FST ID to non-terminal.
NonTerminalHash nonterminal_hash_; // Non-terminal to FST ID.
mutable VectorFst<Arc> depfst_; // FST ID dependencies.
mutable std::vector<StateId> depscc_; // FST SCC ID.
mutable std::vector<bool> depaccess_; // FST ID accessibility.
mutable uint64_t depprops_; // Dependency FST props.
mutable bool have_stats_; // Have dependency statistics?
mutable std::vector<ReplaceStats> stats_; // Per-FST statistics.
mutable std::vector<uint8_t> depsccprops_; // SCC properties.
ReplaceUtil(const ReplaceUtil &) = delete;
ReplaceUtil &operator=(const ReplaceUtil &) = delete;
};
// Constructs from label/mutable-FST pairs. ReplaceUtil takes ownership of every
// FST (they end up in fst_array_, which the destructor deletes). FST IDs are
// assigned in input order starting at 1; ID 0 is a reserved null slot. Reports
// an error if no FST carries the root label opts.root.
template <class Arc>
ReplaceUtil<Arc>::ReplaceUtil(const std::vector<MutableFstPair> &fst_pairs,
                              const ReplaceUtilOptions &opts)
    : root_label_(opts.root),
      root_fst_(0),
      call_label_type_(opts.call_label_type),
      return_label_type_(opts.return_label_type),
      return_label_(opts.return_label),
      depprops_(0),
      have_stats_(false) {
  fst_array_.push_back(nullptr);
  mutable_fst_array_.push_back(nullptr);
  nonterminal_array_.push_back(kNoLabel);
  for (const auto &[label, fst] : fst_pairs) {
    nonterminal_hash_[label] = fst_array_.size();
    nonterminal_array_.push_back(label);
    fst_array_.push_back(fst);
    mutable_fst_array_.push_back(fst);
  }
  // Use find() rather than operator[]: the latter would insert a bogus
  // root_label_ -> 0 entry when the root is missing, making the reserved null
  // slot look like a dependency target.
  if (const auto it = nonterminal_hash_.find(root_label_);
      it != nonterminal_hash_.end()) {
    root_fst_ = it->second;
  } else {
    root_fst_ = 0;
    FSTERROR() << "ReplaceUtil: No root FST for label: " << root_label_;
  }
}
// Constructs from label/FST pairs. Each FST is copied with Copy(), so the
// caller keeps ownership of the originals. FST IDs are assigned in input order
// starting at 1; ID 0 is a reserved null slot. Reports an error if no FST
// carries the root label opts.root.
template <class Arc>
ReplaceUtil<Arc>::ReplaceUtil(const std::vector<FstPair> &fst_pairs,
                              const ReplaceUtilOptions &opts)
    : root_label_(opts.root),
      root_fst_(0),
      call_label_type_(opts.call_label_type),
      return_label_type_(opts.return_label_type),
      return_label_(opts.return_label),
      depprops_(0),
      have_stats_(false) {
  fst_array_.push_back(nullptr);
  nonterminal_array_.push_back(kNoLabel);
  for (const auto &[label, fst] : fst_pairs) {
    nonterminal_hash_[label] = fst_array_.size();
    nonterminal_array_.push_back(label);
    fst_array_.push_back(fst->Copy());
  }
  // Use find() rather than operator[]: the latter would insert a bogus
  // root_label_ -> 0 entry when the root is missing, making the reserved null
  // slot look like a dependency target.
  if (const auto it = nonterminal_hash_.find(root_label_);
      it != nonterminal_hash_.end()) {
    root_fst_ = it->second;
  } else {
    root_fst_ = 0;
    FSTERROR() << "ReplaceUtil: No root FST for label: " << root_label_;
  }
}
// Constructs from ReplaceFst internals: fst_array is indexed by FST ID (entry 0
// is not read) and nonterminal_hash maps each non-terminal label to its FST ID.
// Here opts.root is interpreted as the root FST *ID*, not a label. Every FST is
// copied with Copy(), so the caller keeps ownership of fst_array.
template <class Arc>
ReplaceUtil<Arc>::ReplaceUtil(
    const std::vector<std::unique_ptr<const Fst<Arc>>> &fst_array,
    const NonTerminalHash &nonterminal_hash, const ReplaceUtilOptions &opts)
    : root_label_(kNoLabel),
      root_fst_(opts.root),
      call_label_type_(opts.call_label_type),
      return_label_type_(opts.return_label_type),
      return_label_(opts.return_label),
      nonterminal_hash_(nonterminal_hash),
      depprops_(0),
      have_stats_(false) {
  // ID 0 is the reserved null slot, as in the other constructors.
  fst_array_.push_back(nullptr);
  for (size_t i = 1; i < fst_array.size(); ++i) {
    fst_array_.push_back(fst_array[i]->Copy());
  }
  // Exactly one label per FST ID (so it stays in step with fst_array_ even when
  // fst_array is empty). The reserved slot gets kNoLabel, matching the other
  // constructors, rather than label 0 (epsilon).
  nonterminal_array_.assign(fst_array_.size(), kNoLabel);
  for (const auto &[label, fst_id] : nonterminal_hash) {
    nonterminal_array_[fst_id] = label;
  }
  if (root_fst_ < 0 ||
      static_cast<size_t>(root_fst_) >= nonterminal_array_.size()) {
    FSTERROR() << "ReplaceUtil: Invalid root FST ID: " << root_fst_;
    root_fst_ = 0;  // Same fallback as a missing root in the other ctors.
    return;
  }
  root_label_ = nonterminal_array_[root_fst_];
}
// Builds the dependency graph depfst_. State i stands for FST ID i, every state
// is final, and the start state is the root FST. For every arc of FST i whose
// output label is the non-terminal of FST j, there is one arc i -> j labeled
// with that non-terminal. When stats is true, per-FST statistics are also
// accumulated into stats_. A previously built graph is reused, unless
// statistics are requested but were not gathered with it. Finally, an SCC
// visitor fills depscc_, depaccess_ and depprops_.
template <class Arc>
void ReplaceUtil<Arc>::GetDependencies(bool stats) const {
  if (depfst_.NumStates() > 0) {
    // The cached graph suffices unless statistics are wanted but missing.
    if (!stats || have_stats_) return;
    ClearDependencies();
  }
  have_stats_ = stats;
  if (have_stats_) stats_.reserve(fst_array_.size());
  // One final state per FST ID.
  for (Label id = 0; id < fst_array_.size(); ++id) {
    depfst_.AddState();
    depfst_.SetFinal(id);
    if (have_stats_) stats_.push_back(ReplaceStats());
  }
  depfst_.SetStart(root_fst_);
  // Arcs are added in FST-ID / state / arc order. This keeps the DFS below,
  // and therefore the SCC numbering, deterministic.
  for (Label id = 0; id < fst_array_.size(); ++id) {
    const Fst<Arc> *const fst = fst_array_[id];
    if (fst == nullptr) continue;
    for (StateIterator<Fst<Arc>> siter(*fst); !siter.Done(); siter.Next()) {
      const StateId state = siter.Value();
      if (have_stats_) {
        ReplaceStats &fst_stats = stats_[id];
        ++fst_stats.nstates;
        if (fst->Final(state) != Weight::Zero()) ++fst_stats.nfinal;
      }
      for (ArcIterator<Fst<Arc>> aiter(*fst, state); !aiter.Done();
           aiter.Next()) {
        if (have_stats_) ++stats_[id].narcs;
        const Arc &arc = aiter.Value();
        const auto found = nonterminal_hash_.find(arc.olabel);
        if (found == nonterminal_hash_.end()) continue;  // Terminal label.
        const auto target = found->second;
        depfst_.EmplaceArc(id, arc.olabel, arc.olabel, target);
        if (!have_stats_) continue;
        ++stats_[id].nnonterms;
        ++stats_[target].nref;
        ++stats_[target].inref[id];
        ++stats_[id].outref[target];
      }
    }
  }
  // Computes SCC IDs, accessibility and properties of the dependency graph.
  SccVisitor<Arc> visitor(&depscc_, &depaccess_, nullptr, &depprops_);
  DfsVisit(depfst_, &visitor);
}
// Adjusts stats_ as if FST j had been substituted for each of its instances:
// - Every FST i that references j absorbs j's states, arcs and non-terminal
//   references, scaled by how often i references j.
// - Every FST k referenced by j has its incoming references redirected from j
//   to j's referrers.
// The root FST is never replaced. Requires statistics from
// GetDependencies(true).
template <class Arc>
void ReplaceUtil<Arc>::UpdateStats(Label j) {
  if (!have_stats_) {
    FSTERROR() << "ReplaceUtil::UpdateStats: Stats not available";
    return;
  }
  if (j == root_fst_) return;  // Can't replace root.
  const ReplaceStats &replaced = stats_[j];
  // Inline j into each FST i that refers to it (ni times).
  for (const auto &[i, ni] : replaced.inref) {
    ReplaceStats &referrer = stats_[i];
    referrer.nstates += replaced.nstates * ni;
    referrer.narcs += (replaced.narcs + 1) * ni;
    referrer.nnonterms += (replaced.nnonterms - 1) * ni;
    referrer.outref.erase(j);
    for (const auto &[k, nk] : replaced.outref) {
      referrer.outref[k] += ni * nk;
    }
  }
  // Each FST k that j referred to (nk times) is now referred to by j's
  // referrers instead.
  for (const auto &[k, nk] : replaced.outref) {
    ReplaceStats &referee = stats_[k];
    referee.nref -= nk;
    referee.inref.erase(j);
    for (const auto &[i, ni] : replaced.inref) {
      referee.inref[i] += ni * nk;
      referee.nref += ni * nk;
    }
  }
}
// Makes sure mutable_fst_array_ is populated. If it is empty (the object was
// built from immutable FSTs), the following happens for each FST:
// - a VectorFst copy is made,
// - the original is deleted,
// - the fst_array_ slot is repointed at the copy, so both arrays alias the
//   same objects.
// Null slots stay null.
template <class Arc>
void ReplaceUtil<Arc>::CheckMutableFsts() {
  if (!mutable_fst_array_.empty()) return;  // Already mutable.
  for (auto &fst : fst_array_) {
    if (fst == nullptr) {
      mutable_fst_array_.push_back(nullptr);
      continue;
    }
    auto *const mutable_copy = new VectorFst<Arc>(*fst);
    mutable_fst_array_.push_back(mutable_copy);
    delete fst;
    fst = mutable_copy;
  }
}
// Removes useless FSTs, states and transitions from the RTN. It works in two
// steps:
// 1. Trims each FST to its accessible and coaccessible part.
// 2. Deletes every FST that is then unreachable from the root in a freshly
//    recomputed dependency graph.
// Leaves the dependency cache cleared.
template <class Arc>
void ReplaceUtil<Arc>::Connect() {
  CheckMutableFsts();
  static constexpr auto props = kAccessible | kCoAccessible;
  for (auto *mutable_fst : mutable_fst_array_) {
    if (!mutable_fst) continue;
    if (mutable_fst->Properties(props, false) != props) {
      fst::Connect(mutable_fst);
    }
  }
  // Trimming can remove non-terminal arcs, so a dependency graph cached by an
  // earlier query (e.g. Connected() or CyclicDependencies()) may now be stale.
  // GetDependencies() would silently reuse it, so drop it first; otherwise FSTs
  // that just became unreachable would survive.
  ClearDependencies();
  GetDependencies(false);
  for (Label i = 0; i < mutable_fst_array_.size(); ++i) {
    auto *fst = mutable_fst_array_[i];
    if (fst && !depaccess_[i]) {
      delete fst;
      fst_array_[i] = nullptr;
      mutable_fst_array_[i] = nullptr;
    }
  }
  ClearDependencies();
}
// Fills *toporder with the states of fst listed in topological order
// ((*toporder)[p] is the state at position p). Logs a warning and returns
// false, leaving *toporder untouched, when fst is cyclic.
template <class Arc>
bool ReplaceUtil<Arc>::GetTopOrder(const Fst<Arc> &fst,
                                   std::vector<Label> *toporder) const {
  // position[s] is the topological position of state s.
  std::vector<StateId> position;
  bool is_acyclic = false;
  TopOrderVisitor<Arc> visitor(&position, &is_acyclic);
  DfsVisit(fst, &visitor);
  if (is_acyclic) {
    // Invert the state -> position map into position -> state.
    toporder->resize(position.size());
    for (Label state = 0; state < position.size(); ++state) {
      (*toporder)[position[state]] = state;
    }
    return true;
  }
  LOG(WARNING) << "ReplaceUtil::GetTopOrder: Cyclical label dependencies";
  return false;
}
// Substitutes the FSTs whose non-terminal labels appear in labels into the
// FSTs that reference them; the root label is ignored.
// - Uses the dependency graph restricted to arcs that target requested labels.
// - If that restricted graph is cyclic, nothing is replaced.
// - Otherwise FSTs are processed in reverse topological order, so each
//   referenced FST is already expanded before being substituted.
// Dependency information is cleared afterwards.
template <class Arc>
void ReplaceUtil<Arc>::ReplaceLabels(const std::vector<Label> &labels) {
  CheckMutableFsts();
  std::unordered_set<Label> requested;
  for (const auto label : labels) {
    if (label == root_label_) continue;  // Can't replace root.
    requested.insert(label);
  }
  // Copy of the dependency graph keeping only arcs into requested FSTs.
  GetDependencies(false);
  VectorFst<Arc> restricted(depfst_);
  for (StateId state = 0; state < restricted.NumStates(); ++state) {
    std::vector<Arc> kept;
    for (ArcIterator<VectorFst<Arc>> aiter(restricted, state); !aiter.Done();
         aiter.Next()) {
      const auto &arc = aiter.Value();
      if (requested.count(nonterminal_array_[arc.nextstate]) > 0) {
        kept.push_back(arc);
      }
    }
    restricted.DeleteArcs(state);
    for (auto &arc : kept) restricted.AddArc(state, std::move(arc));
  }
  std::vector<Label> toporder;
  if (!GetTopOrder(restricted, &toporder)) {
    ClearDependencies();
    return;
  }
  for (auto it = toporder.rbegin(); it != toporder.rend(); ++it) {
    const auto s = *it;
    std::vector<FstPair> fst_pairs;
    for (ArcIterator<VectorFst<Arc>> aiter(restricted, s); !aiter.Done();
         aiter.Next()) {
      const auto target = aiter.Value().nextstate;
      fst_pairs.emplace_back(nonterminal_array_[target], fst_array_[target]);
    }
    if (fst_pairs.empty()) continue;  // s references no requested FST.
    // The FST being rewritten goes last and acts as the root of the replace.
    const auto host_label = nonterminal_array_[s];
    fst_pairs.emplace_back(host_label, fst_array_[s]);
    const ReplaceUtilOptions opts(host_label, call_label_type_,
                                  return_label_type_, return_label_);
    Replace(fst_pairs, mutable_fst_array_[s], opts);
  }
  ClearDependencies();
}
// Replaces every FST whose statistics are within the given limits on states,
// arcs and non-terminal references. The root is never actually replaced.
// FSTs are examined in reverse topological order of dependencies, and stats_
// is updated after each selection. Later decisions therefore see the sizes
// the earlier replacements will produce. Does nothing if the dependencies are
// cyclic.
template <class Arc>
void ReplaceUtil<Arc>::ReplaceBySize(size_t nstates, size_t narcs,
                                     size_t nnonterms) {
  std::vector<Label> labels;
  GetDependencies(true);
  std::vector<Label> toporder;
  if (!GetTopOrder(depfst_, &toporder)) {
    ClearDependencies();
    return;
  }
  for (auto it = toporder.rbegin(); it != toporder.rend(); ++it) {
    const auto j = *it;
    const auto &candidate = stats_[j];
    const bool small_enough = candidate.nstates <= nstates &&
                              candidate.narcs <= narcs &&
                              candidate.nnonterms <= nnonterms;
    if (!small_enough) continue;
    labels.push_back(nonterminal_array_[j]);
    UpdateStats(j);
  }
  ReplaceLabels(labels);
}
// Replaces every non-terminal that is referenced at most ninstances times.
// FSTs are examined in topological order of dependencies, and stats_ is
// updated after each selection. Does nothing if the dependencies are cyclic.
template <class Arc>
void ReplaceUtil<Arc>::ReplaceByInstances(size_t ninstances) {
  std::vector<Label> labels;
  GetDependencies(true);
  std::vector<Label> toporder;
  if (!GetTopOrder(depfst_, &toporder)) {
    ClearDependencies();
    return;
  }
  for (const auto j : toporder) {
    if (stats_[j].nref > ninstances) continue;
    labels.push_back(nonterminal_array_[j]);
    UpdateStats(j);
  }
  ReplaceLabels(labels);
}
// Replaces the contents of *fst_pairs with a (label, FST) pair for each FST
// still present, in FST-ID order. ReplaceUtil keeps ownership of the FSTs.
// The internal FSTs are converted to mutable form first (CheckMutableFsts()).
template <class Arc>
void ReplaceUtil<Arc>::GetFstPairs(std::vector<FstPair> *fst_pairs) {
  CheckMutableFsts();
  fst_pairs->clear();
  for (size_t id = 0; id < fst_array_.size(); ++id) {
    if (const auto *fst = fst_array_[id]; fst != nullptr) {
      fst_pairs->emplace_back(nonterminal_array_[id], fst);
    }
  }
}
// Replaces the contents of *mutable_fst_pairs with a (label, FST) pair for
// each FST still present, in FST-ID order. Each FST is a fresh Copy() that the
// caller owns; ReplaceUtil keeps its own instances.
template <class Arc>
void ReplaceUtil<Arc>::GetMutableFstPairs(
    std::vector<MutableFstPair> *mutable_fst_pairs) {
  CheckMutableFsts();
  mutable_fst_pairs->clear();
  for (size_t id = 0; id < mutable_fst_array_.size(); ++id) {
    if (const auto *fst = mutable_fst_array_[id]; fst != nullptr) {
      mutable_fst_pairs->emplace_back(nonterminal_array_[id], fst->Copy());
    }
  }
}
// Computes, once, depsccprops_: the kReplaceSCC* flags of each SCC of the
// dependency graph, indexed by SCC ID. Every SCC starts as both left- and
// right-linear. When the dependency graph is cyclic:
// - An SCC is marked non-trivial if it contains a self-loop or more than one
//   FST.
// - An SCC loses left-linearity when an FST in it has an arc labeled with a
//   non-terminal of the same SCC that does not leave the start state or lies
//   on a cycle of that FST.
// - It loses right-linearity when such an arc does not enter a final state or
//   lies on a cycle.
template <class Arc>
void ReplaceUtil<Arc>::GetSCCProperties() const {
  if (!depsccprops_.empty()) return;
  GetDependencies(false);
  if (depscc_.empty()) return;
  // depscc_ has one entry per dependency-graph state, an upper bound on the
  // number of SCCs, so every SCC ID is a valid index.
  for (size_t i = 0; i < depscc_.size(); ++i) {
    depsccprops_.push_back(kReplaceSCCLeftLinear | kReplaceSCCRightLinear);
  }
  if (!(depprops_ & kCyclic)) return;  // No cyclic dependencies.
  // Checks for self-loops in the dependency graph. The loop variable is a
  // depfst_ state (FST ID); the flag belongs to that state's SCC, depscc_[s],
  // not to the SCC whose ID happens to equal s.
  for (StateId s = 0; s < depscc_.size(); ++s) {
    for (ArcIterator<Fst<Arc>> aiter(depfst_, s); !aiter.Done();
         aiter.Next()) {
      if (aiter.Value().nextstate == s) {  // Self-loop.
        depsccprops_[depscc_[s]] |= kReplaceSCCNonTrivial;
      }
    }
  }
  std::vector<bool> depscc_visited(depscc_.size(), false);
  for (Label i = 0; i < fst_array_.size(); ++i) {
    const auto *fst = fst_array_[i];
    if (!fst) continue;
    const auto depscc = depscc_[i];
    if (depscc_visited[depscc]) {  // SCC has more than one state.
      depsccprops_[depscc] |= kReplaceSCCNonTrivial;
    }
    depscc_visited[depscc] = true;
    std::vector<StateId> fstscc;  // SCCs of the current FST.
    // SccVisitor::InitVisit updates *props with |= and &=, so it must start
    // from a defined value; reading an uninitialized uint64_t is UB.
    uint64_t fstprops = 0;
    SccVisitor<Arc> scc_visitor(&fstscc, nullptr, nullptr, &fstprops);
    DfsVisit(*fst, &scc_visitor);
    for (StateIterator<Fst<Arc>> siter(*fst); !siter.Done(); siter.Next()) {
      const auto s = siter.Value();
      for (ArcIterator<Fst<Arc>> aiter(*fst, s); !aiter.Done(); aiter.Next()) {
        const auto &arc = aiter.Value();
        auto it = nonterminal_hash_.find(arc.olabel);
        if (it == nonterminal_hash_.end() || depscc_[it->second] != depscc) {
          continue;  // Skips if a terminal or a non-terminal not in SCC.
        }
        const bool arc_in_cycle = fstscc[s] == fstscc[arc.nextstate];
        // Left linear iff all non-terminals are initial.
        if (s != fst->Start() || arc_in_cycle) {
          depsccprops_[depscc] &= ~kReplaceSCCLeftLinear;
        }
        // Right linear iff all non-terminals are final.
        if (fst->Final(arc.nextstate) == Weight::Zero() || arc_in_cycle) {
          depsccprops_[depscc] &= ~kReplaceSCCRightLinear;
        }
      }
    }
  }
}
} // namespace fst
#endif // FST_REPLACE_UTIL_H_