You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

484 lines
16 KiB

// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Functions and classes to relabel an FST (either on input or output).
#ifndef FST_RELABEL_H_
#define FST_RELABEL_H_
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <fst/log.h>
#include <fst/arc.h>
#include <fst/cache.h>
#include <fst/float-weight.h>
#include <fst/fst.h>
#include <fst/impl-to-fst.h>
#include <fst/mutable-fst.h>
#include <fst/properties.h>
#include <fst/symbol-table.h>
#include <fst/util.h>
#include <unordered_map>
namespace fst {
// Relabels, in place, the input and/or output labels of a mutable FST.
//
// Each vector holds (old label, new label) pairs for one side of the arcs. A
// label that is not the first element of any pair is left unchanged (identity
// mapping). The maps are built with std::unordered_map's range constructor.
//
// New labels must be valid (e.g., not kNoLabel). If an arc carries a label
// that is mapped to kNoLabel, an error is reported, kError is set on the FST
// and the function returns immediately; arcs visited before that point stay
// relabeled.
template <class Arc>
void Relabel(
    MutableFst<Arc> *fst,
    const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
        &ipairs,
    const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
        &opairs) {
  using Label = typename Arc::Label;
  using LabelMap = std::unordered_map<Label, Label>;
  // Properties are read before any arc is modified; they are translated by
  // RelabelProperties and written back once all arcs have been processed.
  const auto props = fst->Properties(kFstProperties, false);
  // Old-label -> new-label lookup tables, one per side.
  const LabelMap input_map(ipairs.begin(), ipairs.end());
  const LabelMap output_map(opairs.begin(), opairs.end());
  for (StateIterator<MutableFst<Arc>> states(*fst); !states.Done();
       states.Next()) {
    for (MutableArcIterator<MutableFst<Arc>> arcs(fst, states.Value());
         !arcs.Done(); arcs.Next()) {
      auto arc = arcs.Value();
      // kNoLabel should never appear on an arc of an FST.
      DCHECK_NE(arc.ilabel, kNoLabel);
      DCHECK_NE(arc.olabel, kNoLabel);
      // Input side.
      const auto ipos = input_map.find(arc.ilabel);
      if (ipos != input_map.end()) {
        if (ipos->second == kNoLabel) {
          FSTERROR() << "Input symbol ID " << arc.ilabel
                     << " missing from target vocabulary";
          fst->SetProperties(kError, kError);
          return;
        }
        arc.ilabel = ipos->second;
      }
      // Output side.
      const auto opos = output_map.find(arc.olabel);
      if (opos != output_map.end()) {
        if (opos->second == kNoLabel) {
          FSTERROR() << "Output symbol id " << arc.olabel
                     << " missing from target vocabulary";
          fst->SetProperties(kError, kError);
          return;
        }
        arc.olabel = opos->second;
      }
      arcs.SetValue(arc);
    }
  }
  fst->SetProperties(RelabelProperties(props), kFstProperties);
}
// Relabels either the input labels or output labels. The old to
// new labels are specified using pairs of old and new symbol tables.
// The tables must contain (at least) all labels on the appropriate side of the
// FST. If the 'unknown_i(o)symbol' is non-empty, it is used to label any
// missing symbol in new_i(o)symbols table.
template <class Arc>
void Relabel(MutableFst<Arc> *fst, const SymbolTable *old_isymbols,
const SymbolTable *new_isymbols,
const std::string &unknown_isymbol, bool attach_new_isymbols,
const SymbolTable *old_osymbols, const SymbolTable *new_osymbols,
const std::string &unknown_osymbol, bool attach_new_osymbols) {
using Label = typename Arc::Label;
// Constructs vectors of input-side label pairs.
std::vector<std::pair<Label, Label>> ipairs;
if (old_isymbols && new_isymbols) {
size_t num_missing_syms = 0;
Label unknown_ilabel = kNoLabel;
if (!unknown_isymbol.empty()) {
unknown_ilabel = new_isymbols->Find(unknown_isymbol);
if (unknown_ilabel == kNoLabel) {
VLOG(1) << "Input symbol '" << unknown_isymbol
<< "' missing from target symbol table";
++num_missing_syms;
}
}
for (const auto &sitem : *old_isymbols) {
const auto old_index = sitem.Label();
const auto symbol = sitem.Symbol();
auto new_index = new_isymbols->Find(symbol);
if (new_index == kNoLabel) {
if (unknown_ilabel != kNoLabel) {
new_index = unknown_ilabel;
} else {
VLOG(1) << "Input symbol ID " << old_index << " symbol '" << symbol
<< "' missing from target symbol table";
++num_missing_syms;
}
}
ipairs.emplace_back(old_index, new_index);
}
if (num_missing_syms > 0) {
LOG(WARNING) << "Target symbol table missing: " << num_missing_syms
<< " input symbols";
}
if (attach_new_isymbols) fst->SetInputSymbols(new_isymbols);
}
// Constructs vectors of output-side label pairs.
std::vector<std::pair<Label, Label>> opairs;
if (old_osymbols && new_osymbols) {
size_t num_missing_syms = 0;
Label unknown_olabel = kNoLabel;
if (!unknown_osymbol.empty()) {
unknown_olabel = new_osymbols->Find(unknown_osymbol);
if (unknown_olabel == kNoLabel) {
VLOG(1) << "Output symbol '" << unknown_osymbol
<< "' missing from target symbol table";
++num_missing_syms;
}
}
for (const auto &sitem : *old_osymbols) {
const auto old_index = sitem.Label();
const auto symbol = sitem.Symbol();
auto new_index = new_osymbols->Find(symbol);
if (new_index == kNoLabel) {
if (unknown_olabel != kNoLabel) {
new_index = unknown_olabel;
} else {
VLOG(1) << "Output symbol ID " << old_index << " symbol '" << symbol
<< "' missing from target symbol table";
++num_missing_syms;
}
}
opairs.emplace_back(old_index, new_index);
}
if (num_missing_syms > 0) {
LOG(WARNING) << "Target symbol table missing: " << num_missing_syms
<< " output symbols";
}
if (attach_new_osymbols) fst->SetOutputSymbols(new_osymbols);
}
// Calls relabel using vector of relabel pairs.
Relabel(fst, ipairs, opairs);
}
// Backward-compatible overload of the symbol-table based Relabel: identical to
// the overload taking unknown-symbol names, but with no special allowance for
// unknown symbols on either side.
template <class Arc>
void Relabel(MutableFst<Arc> *fst, const SymbolTable *old_isymbols,
             const SymbolTable *new_isymbols, bool attach_new_isymbols,
             const SymbolTable *old_osymbols, const SymbolTable *new_osymbols,
             bool attach_new_osymbols) {
  // An empty name means "no unknown symbol" for the side it is passed for.
  const std::string no_unknown_symbol;
  Relabel(fst, old_isymbols, new_isymbols, no_unknown_symbol,
          attach_new_isymbols, old_osymbols, new_osymbols, no_unknown_symbol,
          attach_new_osymbols);
}
// Relabels the input and/or output labels of an FST from the symbol tables
// currently attached to it to the given new symbol tables, and requests that
// the new tables be attached (both attach flags are passed as true). Any label
// associations not specified are assumed to be identity mapping.
template <class Arc>
void Relabel(MutableFst<Arc> *fst, const SymbolTable *new_isymbols,
             const SymbolTable *new_osymbols) {
  // The "old" tables are whatever the FST carries right now.
  const SymbolTable *const old_isymbols = fst->InputSymbols();
  const SymbolTable *const old_osymbols = fst->OutputSymbols();
  Relabel(fst, old_isymbols, new_isymbols, /*attach_new_isymbols=*/true,
          old_osymbols, new_osymbols, /*attach_new_osymbols=*/true);
}
// Options accepted by RelabelFst; an alias for the cache options type.
using RelabelFstOptions = CacheOptions;
// Forward declaration of the delayed relabeling FST, so that the
// implementation class below can name it in a friend declaration.
template <class Arc>
class RelabelFst;
namespace internal {
// Implementation of the delayed relabeling of an FST from one label set to
// another, on the input side, the output side, or both.
//
// The class keeps a copy of the source FST and rewrites the arcs of a state
// when that state is expanded. Expanded arcs are handed to the CacheImpl base
// class (PushArc/SetArcs), and a state is expanded only while HasArcs(s) is
// false for it.
template <class Arc>
class RelabelFstImpl : public CacheImpl<Arc> {
 public:
  using Label = typename Arc::Label;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  using Store = DefaultCacheStore<Arc>;
  using State = typename Store::State;

  using FstImpl<Arc>::SetType;
  using FstImpl<Arc>::SetProperties;
  using FstImpl<Arc>::WriteHeader;
  using FstImpl<Arc>::SetInputSymbols;
  using FstImpl<Arc>::SetOutputSymbols;

  using CacheImpl<Arc>::PushArc;
  using CacheImpl<Arc>::HasArcs;
  using CacheImpl<Arc>::HasFinal;
  using CacheImpl<Arc>::HasStart;
  using CacheImpl<Arc>::SetArcs;
  using CacheImpl<Arc>::SetFinal;
  using CacheImpl<Arc>::SetStart;

  // The state iterator reads fst_ directly.
  friend class StateIterator<RelabelFst<Arc>>;

  // Relabels according to explicit (old label, new label) pairs. A side is
  // relabeled only if its pair vector is non-empty. No symbol tables are set
  // on the result by this constructor.
  RelabelFstImpl(const Fst<Arc> &fst,
                 const std::vector<std::pair<Label, Label>> &ipairs,
                 const std::vector<std::pair<Label, Label>> &opairs,
                 const RelabelFstOptions &opts)
      : CacheImpl<Arc>(opts),
        fst_(fst.Copy()),
        input_map_(ipairs.begin(), ipairs.end()),
        output_map_(opairs.begin(), opairs.end()),
        relabel_input_(!ipairs.empty()),
        relabel_output_(!opairs.empty()) {
    SetProperties(RelabelProperties(fst.Properties(kCopyProperties, false)));
    SetType("relabel");
  }

  // Relabels according to pairs of old and new symbol tables. A side is
  // relabeled only if both of its tables are non-null and their labeled
  // checksums differ; in that case the new table becomes the symbol table of
  // that side, otherwise the old one is kept.
  //
  // NOTE(review): a symbol of the old table that is absent from the new table
  // is mapped to whatever SymbolTable::Find returns for it (presumably
  // kNoLabel); unlike the in-place Relabel, nothing is reported here.
  RelabelFstImpl(const Fst<Arc> &fst, const SymbolTable *old_isymbols,
                 const SymbolTable *new_isymbols,
                 const SymbolTable *old_osymbols,
                 const SymbolTable *new_osymbols, const RelabelFstOptions &opts)
      : CacheImpl<Arc>(opts),
        fst_(fst.Copy()),
        relabel_input_(false),
        relabel_output_(false) {
    SetType("relabel");
    SetProperties(RelabelProperties(fst.Properties(kCopyProperties, false)));
    SetInputSymbols(old_isymbols);
    SetOutputSymbols(old_osymbols);
    // Input side.
    if (old_isymbols && new_isymbols &&
        old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) {
      for (const auto &item : *old_isymbols) {
        const auto new_label = new_isymbols->Find(item.Symbol());
        input_map_[item.Label()] = new_label;
      }
      SetInputSymbols(new_isymbols);
      relabel_input_ = true;
    }
    // Output side.
    if (old_osymbols && new_osymbols &&
        old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) {
      for (const auto &item : *old_osymbols) {
        const auto new_label = new_osymbols->Find(item.Symbol());
        output_map_[item.Label()] = new_label;
      }
      SetOutputSymbols(new_osymbols);
      relabel_output_ = true;
    }
  }

  // Copy constructor; the wrapped FST is copied with Copy(true).
  RelabelFstImpl(const RelabelFstImpl<Arc> &impl)
      : CacheImpl<Arc>(impl),
        fst_(impl.fst_->Copy(true)),
        input_map_(impl.input_map_),
        output_map_(impl.output_map_),
        relabel_input_(impl.relabel_input_),
        relabel_output_(impl.relabel_output_) {
    SetType("relabel");
    SetProperties(impl.Properties(), kCopyProperties);
    SetInputSymbols(impl.InputSymbols());
    SetOutputSymbols(impl.OutputSymbols());
  }

  // Start state; taken from the wrapped FST the first time it is needed.
  StateId Start() {
    if (!HasStart()) {
      SetStart(fst_->Start());
    }
    return CacheImpl<Arc>::Start();
  }

  // Final weight of state s; taken from the wrapped FST when not yet set.
  Weight Final(StateId s) {
    if (!HasFinal(s)) {
      SetFinal(s, fst_->Final(s));
    }
    return CacheImpl<Arc>::Final(s);
  }

  size_t NumArcs(StateId s) {
    EnsureArcs(s);
    return CacheImpl<Arc>::NumArcs(s);
  }

  size_t NumInputEpsilons(StateId s) {
    EnsureArcs(s);
    return CacheImpl<Arc>::NumInputEpsilons(s);
  }

  size_t NumOutputEpsilons(StateId s) {
    EnsureArcs(s);
    return CacheImpl<Arc>::NumOutputEpsilons(s);
  }

  uint64_t Properties() const override { return Properties(kFstProperties); }

  // Sets error if found, and returns other FST impl properties.
  uint64_t Properties(uint64_t mask) const override {
    const bool error_requested = (mask & kError) != 0;
    if (error_requested && fst_->Properties(kError, false)) {
      SetProperties(kError, kError);
    }
    return FstImpl<Arc>::Properties(mask);
  }

  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
    EnsureArcs(s);
    CacheImpl<Arc>::InitArcIterator(s, data);
  }

  // Copies the arcs of state s from the wrapped FST, rewriting the labels of
  // each side that is being relabeled; labels without an entry in the map are
  // kept as they are.
  void Expand(StateId s) {
    for (ArcIterator<Fst<Arc>> arcs(*fst_, s); !arcs.Done(); arcs.Next()) {
      auto arc = arcs.Value();
      if (relabel_input_) arc.ilabel = MapLabel(input_map_, arc.ilabel);
      if (relabel_output_) arc.olabel = MapLabel(output_map_, arc.olabel);
      PushArc(s, std::move(arc));
    }
    SetArcs(s);
  }

 private:
  using LabelMap = std::unordered_map<Label, Label>;

  // Returns the label that `label` maps to, or `label` itself when the map has
  // no entry for it.
  static Label MapLabel(const LabelMap &map, Label label) {
    const auto pos = map.find(label);
    return pos == map.end() ? label : pos->second;
  }

  // Expands state s unless its arcs are already available.
  void EnsureArcs(StateId s) {
    if (!HasArcs(s)) Expand(s);
  }

  std::unique_ptr<const Fst<Arc>> fst_;  // Copy of the FST being relabeled.
  LabelMap input_map_;                   // Old -> new input labels.
  LabelMap output_map_;                  // Old -> new output labels.
  bool relabel_input_;                   // True if input labels are rewritten.
  bool relabel_output_;                  // True if output labels are rewritten.
};
} // namespace internal
// Delayed relabeling FST. This class attaches the FST interface to
// internal::RelabelFstImpl, which it holds through ImplToFst; most methods are
// delegated to ImplToFst.
template <class A>
class RelabelFst : public ImplToFst<internal::RelabelFstImpl<A>> {
 public:
  using Arc = A;
  using Label = typename Arc::Label;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  using Store = DefaultCacheStore<Arc>;
  using State = typename Store::State;
  using Impl = internal::RelabelFstImpl<Arc>;

  // The iterator specializations need access to the implementation.
  friend class ArcIterator<RelabelFst<A>>;
  friend class StateIterator<RelabelFst<A>>;

  // Relabels using explicit (old label, new label) pairs for each side.
  RelabelFst(const Fst<Arc> &fst,
             const std::vector<std::pair<Label, Label>> &ipairs,
             const std::vector<std::pair<Label, Label>> &opairs,
             const RelabelFstOptions &opts = RelabelFstOptions())
      : ImplToFst<Impl>(std::make_shared<Impl>(fst, ipairs, opairs, opts)) {}

  // Relabels from the symbol tables attached to `fst` to the given new ones.
  RelabelFst(const Fst<Arc> &fst, const SymbolTable *new_isymbols,
             const SymbolTable *new_osymbols,
             const RelabelFstOptions &opts = RelabelFstOptions())
      : ImplToFst<Impl>(
            std::make_shared<Impl>(fst, fst.InputSymbols(), new_isymbols,
                                   fst.OutputSymbols(), new_osymbols, opts)) {}

  // Relabels from explicitly given old symbol tables to the given new ones.
  RelabelFst(const Fst<Arc> &fst, const SymbolTable *old_isymbols,
             const SymbolTable *new_isymbols, const SymbolTable *old_osymbols,
             const SymbolTable *new_osymbols,
             const RelabelFstOptions &opts = RelabelFstOptions())
      : ImplToFst<Impl>(std::make_shared<Impl>(fst, old_isymbols, new_isymbols,
                                               old_osymbols, new_osymbols,
                                               opts)) {}

  // See Fst<>::Copy() for doc.
  RelabelFst(const RelabelFst &fst, bool safe = false)
      : ImplToFst<Impl>(fst, safe) {}

  // Gets a copy of this RelabelFst. See Fst<>::Copy() for further doc.
  RelabelFst *Copy(bool safe = false) const override {
    return new RelabelFst(*this, safe);
  }

  // Defined out of line, after the StateIterator specialization.
  void InitStateIterator(StateIteratorData<Arc> *data) const override;

  // Forwards to the implementation, which expands state s if needed.
  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
    GetMutableImpl()->InitArcIterator(s, data);
  }

 private:
  using ImplToFst<Impl>::GetImpl;
  using ImplToFst<Impl>::GetMutableImpl;

  // Assignment is disabled.
  RelabelFst &operator=(const RelabelFst &) = delete;
};
// StateIterator specialization for RelabelFst. It walks the states of the FST
// wrapped by the implementation; Value() reports a counter that starts at 0
// and advances in step with the wrapped iterator (presumably relying on state
// IDs being consecutive from 0 -- confirm against the Fst contract).
template <class Arc>
class StateIterator<RelabelFst<Arc>> : public StateIteratorBase<Arc> {
 public:
  using StateId = typename Arc::StateId;

  explicit StateIterator(const RelabelFst<Arc> &fst)
      : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {}

  bool Done() const final { return siter_.Done(); }

  StateId Value() const final { return s_; }

  // Advances to the next state; does nothing once the iterator is done.
  void Next() final {
    if (siter_.Done()) return;
    ++s_;
    siter_.Next();
  }

  // Rewinds to the first state.
  void Reset() final {
    siter_.Reset();
    s_ = 0;
  }

 private:
  // Declaration order matters: siter_ is initialized from impl_.
  const internal::RelabelFstImpl<Arc> *impl_;
  StateIterator<Fst<Arc>> siter_;
  StateId s_;

  // Not copyable.
  StateIterator(const StateIterator &) = delete;
  StateIterator &operator=(const StateIterator &) = delete;
};
// ArcIterator specialization for RelabelFst, built on CacheArcIterator. After
// the base iterator is constructed, the state is expanded if the
// implementation does not yet report arcs for it.
template <class Arc>
class ArcIterator<RelabelFst<Arc>> : public CacheArcIterator<RelabelFst<Arc>> {
 public:
  using StateId = typename Arc::StateId;

  ArcIterator(const RelabelFst<Arc> &fst, StateId s)
      : CacheArcIterator<RelabelFst<Arc>>(fst.GetMutableImpl(), s) {
    const bool already_expanded = fst.GetImpl()->HasArcs(s);
    if (already_expanded) return;
    fst.GetMutableImpl()->Expand(s);
  }
};
// Installs a newly allocated RelabelFst state iterator over this FST as the
// base of `data`.
template <class Arc>
inline void RelabelFst<Arc>::InitStateIterator(
    StateIteratorData<Arc> *data) const {
  using Iterator = StateIterator<RelabelFst<Arc>>;
  data->base = std::make_unique<Iterator>(*this);
}
// Useful alias when using StdArc: the delayed relabeling FST instantiated for
// the StdArc arc type.
using StdRelabelFst = RelabelFst<StdArc>;
} // namespace fst
#endif // FST_RELABEL_H_