// Copyright 2005-2024 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the 'License');
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an 'AS IS' BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
// See www.openfst.org for extensive documentation on this weighted
|
|
// finite-state transducer library.
|
|
//
|
|
// Class to to compile a binary FST from textual input.
|
|
|
|
#ifndef FST_SCRIPT_COMPILE_IMPL_H_
|
|
#define FST_SCRIPT_COMPILE_IMPL_H_
|
|
|
|
#include <cstddef>
|
|
#include <iostream>
|
|
#include <istream>
|
|
#include <memory>
|
|
#include <optional>
|
|
#include <sstream>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include <fst/log.h>
|
|
#include <fst/fst.h>
|
|
#include <fst/properties.h>
|
|
#include <fst/symbol-table.h>
|
|
#include <fst/util.h>
|
|
#include <fst/vector-fst.h>
|
|
#include <unordered_map>
|
|
#include <string_view>
|
|
|
|
DECLARE_string(fst_field_separator);
|
|
|
|
namespace fst {
|
|
|
|
// Compile a binary FST from textual input, helper class for fstcompile.cc.
|
|
// WARNING: Stand-alone use of this class not recommended, most code should
|
|
// read/write using the binary format which is much more efficient.
|
|
template <class Arc>
|
|
class FstCompiler {
|
|
public:
|
|
using Label = typename Arc::Label;
|
|
using StateId = typename Arc::StateId;
|
|
using Weight = typename Arc::Weight;
|
|
|
|
// If add_symbols_ is true, then the symbols will be dynamically added to the
|
|
// symbol tables. This is only useful if you set the (i/o)keep flag to attach
|
|
// the final symbol table, or use the accessors. (The input symbol tables are
|
|
// const and therefore not changed.)
|
|
FstCompiler(std::istream &istrm, std::string_view source,
|
|
const SymbolTable *isyms, const SymbolTable *osyms,
|
|
const SymbolTable *ssyms, bool accep, bool ikeep, bool okeep,
|
|
bool nkeep) {
|
|
std::unique_ptr<SymbolTable> misyms(isyms ? isyms->Copy() : nullptr);
|
|
std::unique_ptr<SymbolTable> mosyms(osyms ? osyms->Copy() : nullptr);
|
|
std::unique_ptr<SymbolTable> mssyms(ssyms ? ssyms->Copy() : nullptr);
|
|
Init(istrm, source, misyms.get(), mosyms.get(), mssyms.get(), accep, ikeep,
|
|
okeep, nkeep, false);
|
|
}
|
|
|
|
FstCompiler(std::istream &istrm, std::string_view source, SymbolTable *isyms,
|
|
SymbolTable *osyms, SymbolTable *ssyms, bool accep, bool ikeep,
|
|
bool okeep, bool nkeep, bool add_symbols) {
|
|
Init(istrm, source, isyms, osyms, ssyms, accep, ikeep, okeep, nkeep,
|
|
add_symbols);
|
|
}
|
|
|
|
void Init(std::istream &istrm, std::string_view source, SymbolTable *isyms,
|
|
SymbolTable *osyms, SymbolTable *ssyms, bool accep, bool ikeep,
|
|
bool okeep, bool nkeep, bool add_symbols) {
|
|
nline_ = 0;
|
|
source_ = std::string(source);
|
|
isyms_ = isyms;
|
|
osyms_ = osyms;
|
|
ssyms_ = ssyms;
|
|
nstates_ = 0;
|
|
keep_state_numbering_ = nkeep;
|
|
add_symbols_ = add_symbols;
|
|
bool start_state_populated = false;
|
|
char line[kLineLen];
|
|
const std::string separator =
|
|
FST_FLAGS_fst_field_separator + "\n";
|
|
while (istrm.getline(line, kLineLen)) {
|
|
++nline_;
|
|
const std::vector<std::string_view> col =
|
|
StrSplit(line, ByAnyChar(separator), SkipEmpty());
|
|
if (col.empty() || col[0].empty()) continue;
|
|
if (col.size() > 5 || (col.size() > 4 && accep) ||
|
|
(col.size() == 3 && !accep)) {
|
|
FSTERROR() << "FstCompiler: Bad number of columns, source = " << source_
|
|
<< ", line = " << nline_;
|
|
fst_.SetProperties(kError, kError);
|
|
return;
|
|
}
|
|
StateId s = StrToStateId(col[0]);
|
|
while (s >= fst_.NumStates()) fst_.AddState();
|
|
if (!start_state_populated) {
|
|
fst_.SetStart(s);
|
|
start_state_populated = true;
|
|
}
|
|
Arc arc;
|
|
StateId d = s;
|
|
switch (col.size()) {
|
|
case 1:
|
|
fst_.SetFinal(s, Weight::One());
|
|
break;
|
|
case 2:
|
|
fst_.SetFinal(s, StrToWeight(col[1], true));
|
|
break;
|
|
case 3:
|
|
arc.nextstate = d = StrToStateId(col[1]);
|
|
arc.ilabel = StrToILabel(col[2]);
|
|
arc.olabel = arc.ilabel;
|
|
arc.weight = Weight::One();
|
|
fst_.AddArc(s, arc);
|
|
break;
|
|
case 4:
|
|
arc.nextstate = d = StrToStateId(col[1]);
|
|
arc.ilabel = StrToILabel(col[2]);
|
|
if (accep) {
|
|
arc.olabel = arc.ilabel;
|
|
arc.weight = StrToWeight(col[3], true);
|
|
} else {
|
|
arc.olabel = StrToOLabel(col[3]);
|
|
arc.weight = Weight::One();
|
|
}
|
|
fst_.AddArc(s, arc);
|
|
break;
|
|
case 5:
|
|
arc.nextstate = d = StrToStateId(col[1]);
|
|
arc.ilabel = StrToILabel(col[2]);
|
|
arc.olabel = StrToOLabel(col[3]);
|
|
arc.weight = StrToWeight(col[4], true);
|
|
fst_.AddArc(s, arc);
|
|
}
|
|
while (d >= fst_.NumStates()) fst_.AddState();
|
|
}
|
|
if (ikeep) fst_.SetInputSymbols(isyms);
|
|
if (okeep) fst_.SetOutputSymbols(osyms);
|
|
}
|
|
|
|
const VectorFst<Arc> &Fst() const { return fst_; }
|
|
|
|
private:
|
|
// Maximum line length in text file.
|
|
static constexpr int kLineLen = 8096;
|
|
|
|
StateId StrToId(std::string_view s, SymbolTable *syms,
|
|
std::string_view name) const {
|
|
StateId n = 0;
|
|
if (syms) {
|
|
// n = (add_symbols_) ? syms->AddSymbol(s) : syms->Find(s);
|
|
if (add_symbols_) {
|
|
n = syms->AddSymbol(s);
|
|
} else {
|
|
n = syms->Find(s);
|
|
}
|
|
if (n == kNoSymbol) {
|
|
FSTERROR() << "FstCompiler: Symbol \"" << s
|
|
<< "\" is not mapped to any integer " << name
|
|
<< ", symbol table = " << syms->Name()
|
|
<< ", source = " << source_ << ", line = " << nline_;
|
|
fst_.SetProperties(kError, kError);
|
|
}
|
|
} else {
|
|
auto maybe_n = ParseInt64(s);
|
|
if (!maybe_n.has_value()) {
|
|
FSTERROR() << "FstCompiler: Bad " << name << " integer = \"" << s
|
|
<< "\", source = " << source_ << ", line = " << nline_;
|
|
fst_.SetProperties(kError, kError);
|
|
}
|
|
n = *maybe_n;
|
|
}
|
|
return n;
|
|
}
|
|
|
|
StateId StrToStateId(std::string_view s) {
|
|
StateId n = StrToId(s, ssyms_, "state ID");
|
|
if (keep_state_numbering_) return n;
|
|
// Remaps state IDs to make dense set.
|
|
const auto it = states_.find(n);
|
|
if (it == states_.end()) {
|
|
states_[n] = nstates_;
|
|
return nstates_++;
|
|
} else {
|
|
return it->second;
|
|
}
|
|
}
|
|
|
|
StateId StrToILabel(std::string_view s) const {
|
|
return StrToId(s, isyms_, "arc ilabel");
|
|
}
|
|
|
|
StateId StrToOLabel(std::string_view s) const {
|
|
return StrToId(s, osyms_, "arc olabel");
|
|
}
|
|
|
|
Weight StrToWeight(std::string_view s, bool allow_zero) const {
|
|
Weight w;
|
|
std::istringstream strm(std::string{s});
|
|
strm >> w;
|
|
if (!strm || (!allow_zero && w == Weight::Zero())) {
|
|
FSTERROR() << "FstCompiler: Bad weight = \"" << s
|
|
<< "\", source = " << source_ << ", line = " << nline_;
|
|
fst_.SetProperties(kError, kError);
|
|
w = Weight::NoWeight();
|
|
}
|
|
return w;
|
|
}
|
|
|
|
mutable VectorFst<Arc> fst_;
|
|
size_t nline_;
|
|
std::string source_; // Text FST source name.
|
|
SymbolTable *isyms_; // ilabel symbol table (not owned).
|
|
SymbolTable *osyms_; // olabel symbol table (not owned).
|
|
SymbolTable *ssyms_; // slabel symbol table (not owned).
|
|
std::unordered_map<StateId, StateId> states_; // State ID map.
|
|
StateId nstates_; // Number of seen states.
|
|
bool keep_state_numbering_;
|
|
bool add_symbols_; // Add to symbol tables on-the fly.
|
|
|
|
FstCompiler(const FstCompiler &) = delete;
|
|
FstCompiler &operator=(const FstCompiler &) = delete;
|
|
};
|
|
|
|
} // namespace fst
|
|
|
|
#endif // FST_SCRIPT_COMPILE_IMPL_H_
|