// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Functions and classes to relabel an FST (either on input or output).
|
|
|
|
#ifndef FST_RELABEL_H_
|
|
#define FST_RELABEL_H_
|
|
|
|
#include <cstddef>
|
|
#include <cstdint>
|
|
#include <memory>
|
|
#include <string>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include <fst/log.h>
|
|
#include <fst/arc.h>
|
|
#include <fst/cache.h>
|
|
#include <fst/float-weight.h>
|
|
#include <fst/fst.h>
|
|
#include <fst/impl-to-fst.h>
|
|
#include <fst/mutable-fst.h>
|
|
#include <fst/properties.h>
|
|
#include <fst/symbol-table.h>
|
|
#include <fst/util.h>
|
|
#include <unordered_map>
|
|
|
|
namespace fst {
|
|
|
|
// Relabels either the input labels or output labels. The old to
|
|
// new labels are specified using a vector of std::pair<Label, Label>.
|
|
// Any label associations not specified are assumed to be identity
|
|
// mapping. The destination labels must be valid labels (e.g., not kNoLabel).
|
|
template <class Arc>
|
|
void Relabel(
|
|
MutableFst<Arc> *fst,
|
|
const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
|
|
&ipairs,
|
|
const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
|
|
&opairs) {
|
|
using Label = typename Arc::Label;
|
|
const auto props = fst->Properties(kFstProperties, false);
|
|
// Constructs label-to-label maps.
|
|
const std::unordered_map<Label, Label> input_map(
|
|
ipairs.begin(), ipairs.end());
|
|
const std::unordered_map<Label, Label> output_map(
|
|
opairs.begin(), opairs.end());
|
|
for (StateIterator<MutableFst<Arc>> siter(*fst); !siter.Done();
|
|
siter.Next()) {
|
|
for (MutableArcIterator<MutableFst<Arc>> aiter(fst, siter.Value());
|
|
!aiter.Done(); aiter.Next()) {
|
|
auto arc = aiter.Value();
|
|
// dense_hash_map does not support find on the empty_key_val.
|
|
// These labels should never be in an FST anyway.
|
|
DCHECK_NE(arc.ilabel, kNoLabel);
|
|
DCHECK_NE(arc.olabel, kNoLabel);
|
|
// Relabels input.
|
|
if (auto it = input_map.find(arc.ilabel); it != input_map.end()) {
|
|
if (it->second == kNoLabel) {
|
|
FSTERROR() << "Input symbol ID " << arc.ilabel
|
|
<< " missing from target vocabulary";
|
|
fst->SetProperties(kError, kError);
|
|
return;
|
|
}
|
|
arc.ilabel = it->second;
|
|
}
|
|
// Relabels output.
|
|
if (auto it = output_map.find(arc.olabel); it != output_map.end()) {
|
|
if (it->second == kNoLabel) {
|
|
FSTERROR() << "Output symbol id " << arc.olabel
|
|
<< " missing from target vocabulary";
|
|
fst->SetProperties(kError, kError);
|
|
return;
|
|
}
|
|
arc.olabel = it->second;
|
|
}
|
|
aiter.SetValue(arc);
|
|
}
|
|
}
|
|
fst->SetProperties(RelabelProperties(props), kFstProperties);
|
|
}
|
|
|
|
// Relabels either the input labels or output labels. The old to
|
|
// new labels are specified using pairs of old and new symbol tables.
|
|
// The tables must contain (at least) all labels on the appropriate side of the
|
|
// FST. If the 'unknown_i(o)symbol' is non-empty, it is used to label any
|
|
// missing symbol in new_i(o)symbols table.
|
|
template <class Arc>
|
|
void Relabel(MutableFst<Arc> *fst, const SymbolTable *old_isymbols,
|
|
const SymbolTable *new_isymbols,
|
|
const std::string &unknown_isymbol, bool attach_new_isymbols,
|
|
const SymbolTable *old_osymbols, const SymbolTable *new_osymbols,
|
|
const std::string &unknown_osymbol, bool attach_new_osymbols) {
|
|
using Label = typename Arc::Label;
|
|
// Constructs vectors of input-side label pairs.
|
|
std::vector<std::pair<Label, Label>> ipairs;
|
|
if (old_isymbols && new_isymbols) {
|
|
size_t num_missing_syms = 0;
|
|
Label unknown_ilabel = kNoLabel;
|
|
if (!unknown_isymbol.empty()) {
|
|
unknown_ilabel = new_isymbols->Find(unknown_isymbol);
|
|
if (unknown_ilabel == kNoLabel) {
|
|
VLOG(1) << "Input symbol '" << unknown_isymbol
|
|
<< "' missing from target symbol table";
|
|
++num_missing_syms;
|
|
}
|
|
}
|
|
|
|
for (const auto &sitem : *old_isymbols) {
|
|
const auto old_index = sitem.Label();
|
|
const auto symbol = sitem.Symbol();
|
|
auto new_index = new_isymbols->Find(symbol);
|
|
if (new_index == kNoLabel) {
|
|
if (unknown_ilabel != kNoLabel) {
|
|
new_index = unknown_ilabel;
|
|
} else {
|
|
VLOG(1) << "Input symbol ID " << old_index << " symbol '" << symbol
|
|
<< "' missing from target symbol table";
|
|
++num_missing_syms;
|
|
}
|
|
}
|
|
ipairs.emplace_back(old_index, new_index);
|
|
}
|
|
if (num_missing_syms > 0) {
|
|
LOG(WARNING) << "Target symbol table missing: " << num_missing_syms
|
|
<< " input symbols";
|
|
}
|
|
if (attach_new_isymbols) fst->SetInputSymbols(new_isymbols);
|
|
}
|
|
// Constructs vectors of output-side label pairs.
|
|
std::vector<std::pair<Label, Label>> opairs;
|
|
if (old_osymbols && new_osymbols) {
|
|
size_t num_missing_syms = 0;
|
|
Label unknown_olabel = kNoLabel;
|
|
if (!unknown_osymbol.empty()) {
|
|
unknown_olabel = new_osymbols->Find(unknown_osymbol);
|
|
if (unknown_olabel == kNoLabel) {
|
|
VLOG(1) << "Output symbol '" << unknown_osymbol
|
|
<< "' missing from target symbol table";
|
|
++num_missing_syms;
|
|
}
|
|
}
|
|
for (const auto &sitem : *old_osymbols) {
|
|
const auto old_index = sitem.Label();
|
|
const auto symbol = sitem.Symbol();
|
|
auto new_index = new_osymbols->Find(symbol);
|
|
if (new_index == kNoLabel) {
|
|
if (unknown_olabel != kNoLabel) {
|
|
new_index = unknown_olabel;
|
|
} else {
|
|
VLOG(1) << "Output symbol ID " << old_index << " symbol '" << symbol
|
|
<< "' missing from target symbol table";
|
|
++num_missing_syms;
|
|
}
|
|
}
|
|
opairs.emplace_back(old_index, new_index);
|
|
}
|
|
if (num_missing_syms > 0) {
|
|
LOG(WARNING) << "Target symbol table missing: " << num_missing_syms
|
|
<< " output symbols";
|
|
}
|
|
if (attach_new_osymbols) fst->SetOutputSymbols(new_osymbols);
|
|
}
|
|
// Calls relabel using vector of relabel pairs.
|
|
Relabel(fst, ipairs, opairs);
|
|
}
|
|
|
|
// Same as previous but no special allowance for unknown symbols. Kept
|
|
// for backward compat.
|
|
template <class Arc>
|
|
void Relabel(MutableFst<Arc> *fst, const SymbolTable *old_isymbols,
|
|
const SymbolTable *new_isymbols, bool attach_new_isymbols,
|
|
const SymbolTable *old_osymbols, const SymbolTable *new_osymbols,
|
|
bool attach_new_osymbols) {
|
|
Relabel(fst, old_isymbols, new_isymbols, "" /* no unknown isymbol */,
|
|
attach_new_isymbols, old_osymbols, new_osymbols,
|
|
"" /* no unknown osymbol */, attach_new_osymbols);
|
|
}
|
|
|
|
// Relabels either the input labels or output labels. The old to
|
|
// new labels are specified using symbol tables. Any label associations not
|
|
// specified are assumed to be identity mapping.
|
|
template <class Arc>
|
|
void Relabel(MutableFst<Arc> *fst, const SymbolTable *new_isymbols,
|
|
const SymbolTable *new_osymbols) {
|
|
Relabel(fst, fst->InputSymbols(), new_isymbols, true, fst->OutputSymbols(),
|
|
new_osymbols, true);
|
|
}
|
|
|
|
// Options accepted by RelabelFst. This is an alias of CacheOptions; the
// RelabelFst implementation passes the value to its CacheImpl base class.
using RelabelFstOptions = CacheOptions;
|
|
|
|
// Forward declaration of the delayed relabeling FST. It is required here
// because internal::RelabelFstImpl names StateIterator<RelabelFst<Arc>> as a
// friend before the class itself is defined.
template <class A>
class RelabelFst;
|
|
|
|
namespace internal {
|
|
|
|
// Relabels an FST from one symbol set to another. Relabeling can either be on
|
|
// input or output space. RelabelFst implements a delayed version of the
|
|
// relabel. Arcs are relabeled on the fly and not cached; i.e., each request is
|
|
// recomputed.
|
|
template <class Arc>
|
|
class RelabelFstImpl : public CacheImpl<Arc> {
|
|
public:
|
|
using Label = typename Arc::Label;
|
|
using StateId = typename Arc::StateId;
|
|
using Weight = typename Arc::Weight;
|
|
|
|
using Store = DefaultCacheStore<Arc>;
|
|
using State = typename Store::State;
|
|
|
|
using FstImpl<Arc>::SetType;
|
|
using FstImpl<Arc>::SetProperties;
|
|
using FstImpl<Arc>::WriteHeader;
|
|
using FstImpl<Arc>::SetInputSymbols;
|
|
using FstImpl<Arc>::SetOutputSymbols;
|
|
|
|
using CacheImpl<Arc>::PushArc;
|
|
using CacheImpl<Arc>::HasArcs;
|
|
using CacheImpl<Arc>::HasFinal;
|
|
using CacheImpl<Arc>::HasStart;
|
|
using CacheImpl<Arc>::SetArcs;
|
|
using CacheImpl<Arc>::SetFinal;
|
|
using CacheImpl<Arc>::SetStart;
|
|
|
|
friend class StateIterator<RelabelFst<Arc>>;
|
|
|
|
RelabelFstImpl(const Fst<Arc> &fst,
|
|
const std::vector<std::pair<Label, Label>> &ipairs,
|
|
const std::vector<std::pair<Label, Label>> &opairs,
|
|
const RelabelFstOptions &opts)
|
|
: CacheImpl<Arc>(opts),
|
|
fst_(fst.Copy()),
|
|
input_map_(ipairs.begin(), ipairs.end()),
|
|
output_map_(opairs.begin(), opairs.end()),
|
|
relabel_input_(!ipairs.empty()),
|
|
relabel_output_(!opairs.empty()) {
|
|
SetProperties(RelabelProperties(fst.Properties(kCopyProperties, false)));
|
|
SetType("relabel");
|
|
}
|
|
|
|
RelabelFstImpl(const Fst<Arc> &fst, const SymbolTable *old_isymbols,
|
|
const SymbolTable *new_isymbols,
|
|
const SymbolTable *old_osymbols,
|
|
const SymbolTable *new_osymbols, const RelabelFstOptions &opts)
|
|
: CacheImpl<Arc>(opts),
|
|
fst_(fst.Copy()),
|
|
relabel_input_(false),
|
|
relabel_output_(false) {
|
|
SetType("relabel");
|
|
SetProperties(RelabelProperties(fst.Properties(kCopyProperties, false)));
|
|
SetInputSymbols(old_isymbols);
|
|
SetOutputSymbols(old_osymbols);
|
|
if (old_isymbols && new_isymbols &&
|
|
old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) {
|
|
for (const auto &sitem : *old_isymbols) {
|
|
input_map_[sitem.Label()] = new_isymbols->Find(sitem.Symbol());
|
|
}
|
|
SetInputSymbols(new_isymbols);
|
|
relabel_input_ = true;
|
|
}
|
|
if (old_osymbols && new_osymbols &&
|
|
old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) {
|
|
for (const auto &sitem : *old_osymbols) {
|
|
output_map_[sitem.Label()] = new_osymbols->Find(sitem.Symbol());
|
|
}
|
|
SetOutputSymbols(new_osymbols);
|
|
relabel_output_ = true;
|
|
}
|
|
}
|
|
|
|
RelabelFstImpl(const RelabelFstImpl<Arc> &impl)
|
|
: CacheImpl<Arc>(impl),
|
|
fst_(impl.fst_->Copy(true)),
|
|
input_map_(impl.input_map_),
|
|
output_map_(impl.output_map_),
|
|
relabel_input_(impl.relabel_input_),
|
|
relabel_output_(impl.relabel_output_) {
|
|
SetType("relabel");
|
|
SetProperties(impl.Properties(), kCopyProperties);
|
|
SetInputSymbols(impl.InputSymbols());
|
|
SetOutputSymbols(impl.OutputSymbols());
|
|
}
|
|
|
|
StateId Start() {
|
|
if (!HasStart()) SetStart(fst_->Start());
|
|
return CacheImpl<Arc>::Start();
|
|
}
|
|
|
|
Weight Final(StateId s) {
|
|
if (!HasFinal(s)) SetFinal(s, fst_->Final(s));
|
|
return CacheImpl<Arc>::Final(s);
|
|
}
|
|
|
|
size_t NumArcs(StateId s) {
|
|
if (!HasArcs(s)) Expand(s);
|
|
return CacheImpl<Arc>::NumArcs(s);
|
|
}
|
|
|
|
size_t NumInputEpsilons(StateId s) {
|
|
if (!HasArcs(s)) Expand(s);
|
|
return CacheImpl<Arc>::NumInputEpsilons(s);
|
|
}
|
|
|
|
size_t NumOutputEpsilons(StateId s) {
|
|
if (!HasArcs(s)) Expand(s);
|
|
return CacheImpl<Arc>::NumOutputEpsilons(s);
|
|
}
|
|
|
|
uint64_t Properties() const override { return Properties(kFstProperties); }
|
|
|
|
// Sets error if found, and returns other FST impl properties.
|
|
uint64_t Properties(uint64_t mask) const override {
|
|
if ((mask & kError) && fst_->Properties(kError, false)) {
|
|
SetProperties(kError, kError);
|
|
}
|
|
return FstImpl<Arc>::Properties(mask);
|
|
}
|
|
|
|
void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
|
|
if (!HasArcs(s)) Expand(s);
|
|
CacheImpl<Arc>::InitArcIterator(s, data);
|
|
}
|
|
|
|
void Expand(StateId s) {
|
|
for (ArcIterator<Fst<Arc>> aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
|
|
auto arc = aiter.Value();
|
|
if (relabel_input_) {
|
|
if (auto it = input_map_.find(arc.ilabel); it != input_map_.end()) {
|
|
arc.ilabel = it->second;
|
|
}
|
|
}
|
|
if (relabel_output_) {
|
|
if (auto it = output_map_.find(arc.olabel); it != output_map_.end()) {
|
|
arc.olabel = it->second;
|
|
}
|
|
}
|
|
PushArc(s, std::move(arc));
|
|
}
|
|
SetArcs(s);
|
|
}
|
|
|
|
private:
|
|
std::unique_ptr<const Fst<Arc>> fst_;
|
|
|
|
std::unordered_map<Label, Label> input_map_;
|
|
std::unordered_map<Label, Label> output_map_;
|
|
bool relabel_input_;
|
|
bool relabel_output_;
|
|
};
|
|
|
|
} // namespace internal
|
|
|
|
// This class attaches interface to implementation and handles
|
|
// reference counting, delegating most methods to ImplToFst.
|
|
template <class A>
|
|
class RelabelFst : public ImplToFst<internal::RelabelFstImpl<A>> {
|
|
public:
|
|
using Arc = A;
|
|
using Label = typename Arc::Label;
|
|
using StateId = typename Arc::StateId;
|
|
using Weight = typename Arc::Weight;
|
|
|
|
using Store = DefaultCacheStore<Arc>;
|
|
using State = typename Store::State;
|
|
using Impl = internal::RelabelFstImpl<Arc>;
|
|
|
|
friend class ArcIterator<RelabelFst<A>>;
|
|
friend class StateIterator<RelabelFst<A>>;
|
|
|
|
RelabelFst(const Fst<Arc> &fst,
|
|
const std::vector<std::pair<Label, Label>> &ipairs,
|
|
const std::vector<std::pair<Label, Label>> &opairs,
|
|
const RelabelFstOptions &opts = RelabelFstOptions())
|
|
: ImplToFst<Impl>(std::make_shared<Impl>(fst, ipairs, opairs, opts)) {}
|
|
|
|
RelabelFst(const Fst<Arc> &fst, const SymbolTable *new_isymbols,
|
|
const SymbolTable *new_osymbols,
|
|
const RelabelFstOptions &opts = RelabelFstOptions())
|
|
: ImplToFst<Impl>(
|
|
std::make_shared<Impl>(fst, fst.InputSymbols(), new_isymbols,
|
|
fst.OutputSymbols(), new_osymbols, opts)) {}
|
|
|
|
RelabelFst(const Fst<Arc> &fst, const SymbolTable *old_isymbols,
|
|
const SymbolTable *new_isymbols, const SymbolTable *old_osymbols,
|
|
const SymbolTable *new_osymbols,
|
|
const RelabelFstOptions &opts = RelabelFstOptions())
|
|
: ImplToFst<Impl>(std::make_shared<Impl>(fst, old_isymbols, new_isymbols,
|
|
old_osymbols, new_osymbols,
|
|
opts)) {}
|
|
|
|
// See Fst<>::Copy() for doc.
|
|
RelabelFst(const RelabelFst &fst, bool safe = false)
|
|
: ImplToFst<Impl>(fst, safe) {}
|
|
|
|
// Gets a copy of this RelabelFst. See Fst<>::Copy() for further doc.
|
|
RelabelFst *Copy(bool safe = false) const override {
|
|
return new RelabelFst(*this, safe);
|
|
}
|
|
|
|
void InitStateIterator(StateIteratorData<Arc> *data) const override;
|
|
|
|
void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
|
|
return GetMutableImpl()->InitArcIterator(s, data);
|
|
}
|
|
|
|
private:
|
|
using ImplToFst<Impl>::GetImpl;
|
|
using ImplToFst<Impl>::GetMutableImpl;
|
|
|
|
RelabelFst &operator=(const RelabelFst &) = delete;
|
|
};
|
|
|
|
// Specialization for RelabelFst.
|
|
template <class Arc>
|
|
class StateIterator<RelabelFst<Arc>> : public StateIteratorBase<Arc> {
|
|
public:
|
|
using StateId = typename Arc::StateId;
|
|
|
|
explicit StateIterator(const RelabelFst<Arc> &fst)
|
|
: impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {}
|
|
|
|
bool Done() const final { return siter_.Done(); }
|
|
|
|
StateId Value() const final { return s_; }
|
|
|
|
void Next() final {
|
|
if (!siter_.Done()) {
|
|
++s_;
|
|
siter_.Next();
|
|
}
|
|
}
|
|
|
|
void Reset() final {
|
|
s_ = 0;
|
|
siter_.Reset();
|
|
}
|
|
|
|
private:
|
|
const internal::RelabelFstImpl<Arc> *impl_;
|
|
StateIterator<Fst<Arc>> siter_;
|
|
StateId s_;
|
|
|
|
StateIterator(const StateIterator &) = delete;
|
|
StateIterator &operator=(const StateIterator &) = delete;
|
|
};
|
|
|
|
// Specialization for RelabelFst.
|
|
template <class Arc>
|
|
class ArcIterator<RelabelFst<Arc>> : public CacheArcIterator<RelabelFst<Arc>> {
|
|
public:
|
|
using StateId = typename Arc::StateId;
|
|
|
|
ArcIterator(const RelabelFst<Arc> &fst, StateId s)
|
|
: CacheArcIterator<RelabelFst<Arc>>(fst.GetMutableImpl(), s) {
|
|
if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
|
|
}
|
|
};
|
|
|
|
template <class Arc>
|
|
inline void RelabelFst<Arc>::InitStateIterator(
|
|
StateIteratorData<Arc> *data) const {
|
|
data->base = std::make_unique<StateIterator<RelabelFst<Arc>>>(*this);
|
|
}
|
|
|
|
// Useful alias when using StdArc.
|
|
// StdRelabelFst names the RelabelFst instantiation for StdArc.
using StdRelabelFst = RelabelFst<StdArc>;
|
|
|
|
} // namespace fst
|
|
|
|
#endif // FST_RELABEL_H_
|