// Copyright 2005-2024 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the 'License');
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an 'AS IS' BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
// See www.openfst.org for extensive documentation on this weighted
|
|
// finite-state transducer library.
|
|
//
|
|
// Expanded FST augmented with mutators; interface class definition and
|
|
// mutable arc iterator interface.
|
|
|
|
#ifndef FST_MUTABLE_FST_H_
|
|
#define FST_MUTABLE_FST_H_
|
|
|
|
#include <sys/types.h>
|
|
|
|
#include <cstddef>
|
|
#include <cstdint>
|
|
#include <ios>
|
|
#include <iostream>
|
|
#include <istream>
|
|
#include <memory>
|
|
#include <string>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include <fst/log.h>
|
|
#include <fst/arc.h>
|
|
#include <fst/expanded-fst.h>
|
|
#include <fstream>
|
|
#include <fst/fst.h>
|
|
#include <fst/properties.h>
|
|
#include <fst/register.h>
|
|
#include <fst/symbol-table.h>
|
|
#include <string_view>
|
|
|
|
namespace fst {
|
|
|
|
// Forward declaration. The definition appears later in this file; it is
// declared here so that MutableFst::InitMutableArcIterator can take a pointer
// to it.
template <class Arc>
struct MutableArcIteratorData;
|
|
|
|
// Abstract interface for an expanded FST which also supports mutation
|
|
// operations. To modify arcs, use MutableArcIterator.
|
|
template <class A>
|
|
class MutableFst : public ExpandedFst<A> {
|
|
public:
|
|
using Arc = A;
|
|
using StateId = typename Arc::StateId;
|
|
using Weight = typename Arc::Weight;
|
|
|
|
virtual MutableFst<Arc> &operator=(const Fst<Arc> &fst) = 0;
|
|
|
|
MutableFst &operator=(const MutableFst &fst) {
|
|
return operator=(static_cast<const Fst<Arc> &>(fst));
|
|
}
|
|
|
|
// Sets the initial state.
|
|
virtual void SetStart(StateId) = 0;
|
|
|
|
// Sets a state's final weight.
|
|
virtual void SetFinal(StateId s, Weight weight = Weight::One()) = 0;
|
|
|
|
// Sets property bits w.r.t. mask.
|
|
virtual void SetProperties(uint64_t props, uint64_t mask) = 0;
|
|
|
|
// Adds a state and returns its ID.
|
|
virtual StateId AddState() = 0;
|
|
|
|
// Adds multiple states.
|
|
virtual void AddStates(size_t) = 0;
|
|
|
|
// Adds an arc to state.
|
|
virtual void AddArc(StateId, const Arc &) = 0;
|
|
|
|
// Adds an arc (passed by rvalue reference) to state. Allows subclasses
|
|
// to optionally implement move semantics. Defaults to lvalue overload.
|
|
virtual void AddArc(StateId state, Arc &&arc) { AddArc(state, arc); }
|
|
|
|
// Deletes some states, preserving original StateId ordering.
|
|
virtual void DeleteStates(const std::vector<StateId> &) = 0;
|
|
|
|
// Delete all states.
|
|
virtual void DeleteStates() = 0;
|
|
|
|
// Delete some arcs at a given state.
|
|
virtual void DeleteArcs(StateId, size_t) = 0;
|
|
|
|
// Delete all arcs at a given state.
|
|
virtual void DeleteArcs(StateId) = 0;
|
|
|
|
// Optional, best effort only.
|
|
virtual void ReserveStates(size_t) {}
|
|
|
|
// Optional, best effort only.
|
|
virtual void ReserveArcs(StateId, size_t) {}
|
|
|
|
// Returns input label symbol table or nullptr if not specified.
|
|
const SymbolTable *InputSymbols() const override = 0;
|
|
|
|
// Returns output label symbol table or nullptr if not specified.
|
|
const SymbolTable *OutputSymbols() const override = 0;
|
|
|
|
// Returns input label symbol table or nullptr if not specified.
|
|
virtual SymbolTable *MutableInputSymbols() = 0;
|
|
|
|
// Returns output label symbol table or nullptr if not specified.
|
|
virtual SymbolTable *MutableOutputSymbols() = 0;
|
|
|
|
// Sets input label symbol table; pass nullptr to delete table.
|
|
virtual void SetInputSymbols(const SymbolTable *isyms) = 0;
|
|
|
|
// Sets output label symbol table; pass nullptr to delete table.
|
|
virtual void SetOutputSymbols(const SymbolTable *osyms) = 0;
|
|
|
|
// Gets a copy of this MutableFst. See Fst<>::Copy() for further doc.
|
|
MutableFst *Copy(bool safe = false) const override = 0;
|
|
|
|
// Reads a MutableFst from an input stream, returning nullptr on error.
|
|
static MutableFst *Read(std::istream &strm, const FstReadOptions &opts) {
|
|
FstReadOptions ropts(opts);
|
|
FstHeader hdr;
|
|
if (ropts.header) {
|
|
hdr = *opts.header;
|
|
} else {
|
|
if (!hdr.Read(strm, opts.source)) return nullptr;
|
|
ropts.header = &hdr;
|
|
}
|
|
if (!(hdr.Properties() & kMutable)) {
|
|
LOG(ERROR) << "MutableFst::Read: Not a MutableFst: " << ropts.source;
|
|
return nullptr;
|
|
}
|
|
const auto &fst_type = hdr.FstType();
|
|
const auto reader = FstRegister<Arc>::GetRegister()->GetReader(fst_type);
|
|
if (!reader) {
|
|
LOG(ERROR) << "MutableFst::Read: Unknown FST type \"" << fst_type
|
|
<< "\" (arc type = \"" << A::Type() << "\"): " << ropts.source;
|
|
return nullptr;
|
|
}
|
|
auto *fst = reader(strm, ropts);
|
|
if (!fst) return nullptr;
|
|
return down_cast<MutableFst *>(fst);
|
|
}
|
|
|
|
// Reads a MutableFst from a file; returns nullptr on error. An empty
|
|
// source results in reading from standard input. If convert is true,
|
|
// convert to a mutable FST subclass (given by convert_type) in the case
|
|
// that the input FST is non-mutable.
|
|
static MutableFst *Read(const std::string &source, bool convert = false,
|
|
std::string_view convert_type = "vector") {
|
|
if (convert == false) {
|
|
if (!source.empty()) {
|
|
std::ifstream strm(source,
|
|
std::ios_base::in | std::ios_base::binary);
|
|
if (!strm) {
|
|
LOG(ERROR) << "MutableFst::Read: Can't open file: " << source;
|
|
return nullptr;
|
|
}
|
|
return Read(strm, FstReadOptions(source));
|
|
} else {
|
|
return Read(std::cin, FstReadOptions("standard input"));
|
|
}
|
|
} else { // Converts to 'convert_type' if not mutable.
|
|
std::unique_ptr<Fst<Arc>> ifst(Fst<Arc>::Read(source));
|
|
if (!ifst) return nullptr;
|
|
if (ifst->Properties(kMutable, false)) {
|
|
return down_cast<MutableFst *>(ifst.release());
|
|
} else {
|
|
std::unique_ptr<Fst<Arc>> ofst(Convert(*ifst, convert_type));
|
|
ifst.reset();
|
|
if (!ofst) return nullptr;
|
|
if (!ofst->Properties(kMutable, false)) {
|
|
LOG(ERROR) << "MutableFst: Bad convert type: " << convert_type;
|
|
}
|
|
return down_cast<MutableFst *>(ofst.release());
|
|
}
|
|
}
|
|
}
|
|
|
|
// For generic mutuble arc iterator construction; not normally called
|
|
// directly by users.
|
|
virtual void InitMutableArcIterator(StateId s,
|
|
MutableArcIteratorData<Arc> *data) = 0;
|
|
};
|
|
|
|
// Mutable arc iterator interface, templated on the Arc definition. This is
|
|
// used by mutable arc iterator specializations that are returned by the
|
|
// InitMutableArcIterator MutableFst method.
|
|
template <class Arc>
|
|
class MutableArcIteratorBase : public ArcIteratorBase<Arc> {
|
|
public:
|
|
// Sets current arc.
|
|
virtual void SetValue(const Arc &) = 0;
|
|
};
|
|
|
|
template <class Arc>
|
|
struct MutableArcIteratorData {
|
|
std::unique_ptr<MutableArcIteratorBase<Arc>> base; // Specific iterator.
|
|
};
|
|
|
|
// Generic mutable arc iterator, templated on the FST definition; a wrapper
|
|
// around a pointer to a more specific one.
|
|
//
|
|
// Here is a typical use:
|
|
//
|
|
// for (MutableArcIterator<StdFst> aiter(&fst, s);
|
|
// !aiter.Done();
|
|
// aiter.Next()) {
|
|
// StdArc arc = aiter.Value();
|
|
// arc.ilabel = 7;
|
|
// aiter.SetValue(arc);
|
|
// ...
|
|
// }
|
|
//
|
|
// This version requires function calls.
|
|
template <class FST>
|
|
class MutableArcIterator {
|
|
public:
|
|
using Arc = typename FST::Arc;
|
|
using StateId = typename Arc::StateId;
|
|
|
|
MutableArcIterator(FST *fst, StateId s) {
|
|
fst->InitMutableArcIterator(s, &data_);
|
|
}
|
|
|
|
bool Done() const { return data_.base->Done(); }
|
|
|
|
const Arc &Value() const { return data_.base->Value(); }
|
|
|
|
void Next() { data_.base->Next(); }
|
|
|
|
size_t Position() const { return data_.base->Position(); }
|
|
|
|
void Reset() { data_.base->Reset(); }
|
|
|
|
void Seek(size_t a) { data_.base->Seek(a); }
|
|
|
|
void SetValue(const Arc &arc) { data_.base->SetValue(arc); }
|
|
|
|
uint8_t Flags() const { return data_.base->Flags(); }
|
|
|
|
void SetFlags(uint8_t flags, uint8_t mask) {
|
|
return data_.base->SetFlags(flags, mask);
|
|
}
|
|
|
|
private:
|
|
MutableArcIteratorData<Arc> data_;
|
|
|
|
MutableArcIterator(const MutableArcIterator &) = delete;
|
|
MutableArcIterator &operator=(const MutableArcIterator &) = delete;
|
|
};
|
|
|
|
namespace internal {
|
|
|
|
// MutableFst<A> case: abstract methods.
|
|
template <class Arc>
|
|
inline typename Arc::Weight Final(const MutableFst<Arc> &fst,
|
|
typename Arc::StateId s) {
|
|
return fst.Final(s);
|
|
}
|
|
|
|
template <class Arc>
|
|
inline ssize_t NumArcs(const MutableFst<Arc> &fst, typename Arc::StateId s) {
|
|
return fst.NumArcs(s);
|
|
}
|
|
|
|
template <class Arc>
|
|
inline ssize_t NumInputEpsilons(const MutableFst<Arc> &fst,
|
|
typename Arc::StateId s) {
|
|
return fst.NumInputEpsilons(s);
|
|
}
|
|
|
|
template <class Arc>
|
|
inline ssize_t NumOutputEpsilons(const MutableFst<Arc> &fst,
|
|
typename Arc::StateId s) {
|
|
return fst.NumOutputEpsilons(s);
|
|
}
|
|
|
|
} // namespace internal
|
|
|
|
// A useful alias when using StdArc.
|
|
using StdMutableFst = MutableFst<StdArc>;
|
|
|
|
// This is a helper class template useful for attaching a MutableFst interface
|
|
// to its implementation, handling reference counting and COW semantics.
|
|
template <class Impl, class FST = MutableFst<typename Impl::Arc>>
|
|
class ImplToMutableFst : public ImplToExpandedFst<Impl, FST> {
|
|
public:
|
|
using Arc = typename Impl::Arc;
|
|
using StateId = typename Arc::StateId;
|
|
using Weight = typename Arc::Weight;
|
|
|
|
using ImplToExpandedFst<Impl, FST>::operator=;
|
|
|
|
void SetStart(StateId s) override {
|
|
MutateCheck();
|
|
GetMutableImpl()->SetStart(s);
|
|
}
|
|
|
|
void SetFinal(StateId s, Weight weight = Weight::One()) override {
|
|
MutateCheck();
|
|
GetMutableImpl()->SetFinal(s, std::move(weight));
|
|
}
|
|
|
|
void SetProperties(uint64_t props, uint64_t mask) override {
|
|
// Can skip mutate check if extrinsic properties don't change,
|
|
// since it is then safe to update all (shallow) copies
|
|
const auto exprops = kExtrinsicProperties & mask;
|
|
if (GetImpl()->Properties(exprops) != (props & exprops)) MutateCheck();
|
|
GetMutableImpl()->SetProperties(props, mask);
|
|
}
|
|
|
|
StateId AddState() override {
|
|
MutateCheck();
|
|
return GetMutableImpl()->AddState();
|
|
}
|
|
|
|
void AddStates(size_t n) override {
|
|
MutateCheck();
|
|
return GetMutableImpl()->AddStates(n);
|
|
}
|
|
|
|
void AddArc(StateId s, const Arc &arc) override {
|
|
MutateCheck();
|
|
GetMutableImpl()->AddArc(s, arc);
|
|
}
|
|
|
|
void AddArc(StateId s, Arc &&arc) override {
|
|
MutateCheck();
|
|
GetMutableImpl()->AddArc(s, std::forward<Arc>(arc));
|
|
}
|
|
|
|
void DeleteStates(const std::vector<StateId> &dstates) override {
|
|
MutateCheck();
|
|
GetMutableImpl()->DeleteStates(dstates);
|
|
}
|
|
|
|
void DeleteStates() override {
|
|
if (!Unique()) {
|
|
const auto *isymbols = GetImpl()->InputSymbols();
|
|
const auto *osymbols = GetImpl()->OutputSymbols();
|
|
SetImpl(std::make_shared<Impl>());
|
|
GetMutableImpl()->SetInputSymbols(isymbols);
|
|
GetMutableImpl()->SetOutputSymbols(osymbols);
|
|
} else {
|
|
GetMutableImpl()->DeleteStates();
|
|
}
|
|
}
|
|
|
|
void DeleteArcs(StateId s, size_t n) override {
|
|
MutateCheck();
|
|
GetMutableImpl()->DeleteArcs(s, n);
|
|
}
|
|
|
|
void DeleteArcs(StateId s) override {
|
|
MutateCheck();
|
|
GetMutableImpl()->DeleteArcs(s);
|
|
}
|
|
|
|
void ReserveStates(size_t n) override {
|
|
MutateCheck();
|
|
GetMutableImpl()->ReserveStates(n);
|
|
}
|
|
|
|
void ReserveArcs(StateId s, size_t n) override {
|
|
MutateCheck();
|
|
GetMutableImpl()->ReserveArcs(s, n);
|
|
}
|
|
|
|
const SymbolTable *InputSymbols() const override {
|
|
return GetImpl()->InputSymbols();
|
|
}
|
|
|
|
const SymbolTable *OutputSymbols() const override {
|
|
return GetImpl()->OutputSymbols();
|
|
}
|
|
|
|
SymbolTable *MutableInputSymbols() override {
|
|
MutateCheck();
|
|
return GetMutableImpl()->InputSymbols();
|
|
}
|
|
|
|
SymbolTable *MutableOutputSymbols() override {
|
|
MutateCheck();
|
|
return GetMutableImpl()->OutputSymbols();
|
|
}
|
|
|
|
void SetInputSymbols(const SymbolTable *isyms) override {
|
|
MutateCheck();
|
|
GetMutableImpl()->SetInputSymbols(isyms);
|
|
}
|
|
|
|
void SetOutputSymbols(const SymbolTable *osyms) override {
|
|
MutateCheck();
|
|
GetMutableImpl()->SetOutputSymbols(osyms);
|
|
}
|
|
|
|
protected:
|
|
using ImplToExpandedFst<Impl, FST>::GetImpl;
|
|
using ImplToExpandedFst<Impl, FST>::GetMutableImpl;
|
|
using ImplToExpandedFst<Impl, FST>::Unique;
|
|
using ImplToExpandedFst<Impl, FST>::SetImpl;
|
|
using ImplToExpandedFst<Impl, FST>::InputSymbols;
|
|
|
|
explicit ImplToMutableFst(std::shared_ptr<Impl> impl)
|
|
: ImplToExpandedFst<Impl, FST>(impl) {}
|
|
|
|
ImplToMutableFst(const ImplToMutableFst &fst, bool safe)
|
|
: ImplToExpandedFst<Impl, FST>(fst, safe) {}
|
|
|
|
void MutateCheck() {
|
|
if (!Unique()) SetImpl(std::make_shared<Impl>(*this));
|
|
}
|
|
};
|
|
|
|
} // namespace fst
|
|
|
|
#endif // FST_MUTABLE_FST_H_
|