You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

237 lines
7.9 KiB

// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Class to compile a binary FST from textual input.
#ifndef FST_SCRIPT_COMPILE_IMPL_H_
#define FST_SCRIPT_COMPILE_IMPL_H_
#include <cstddef>
#include <iostream>
#include <istream>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
#include <vector>
#include <fst/log.h>
#include <fst/fst.h>
#include <fst/properties.h>
#include <fst/symbol-table.h>
#include <fst/util.h>
#include <fst/vector-fst.h>
#include <unordered_map>
#include <string_view>
DECLARE_string(fst_field_separator);
namespace fst {
// Compiles a binary FST from textual input; helper class for fstcompile.cc.
//
// Every input line is split on the characters of the fst_field_separator flag
// (empty fields are skipped) and interpreted by its number of columns:
//
//   1: state                              final state, weight One()
//   2: state weight                       final state
//   3: src dest label                     arc, weight One() (acceptor only)
//   4: src dest label weight              arc (acceptor)
//      src dest ilabel olabel             arc, weight One() (transducer)
//   5: src dest ilabel olabel weight      arc (transducer only)
//
// Lines without any column are ignored. The state named on the first
// non-ignored line becomes the start state.
//
// Problems are reported through FSTERROR() and by setting the kError property
// on the resulting FST.
//
// WARNING: Stand-alone use of this class not recommended, most code should
// read/write using the binary format which is much more efficient.
template <class Arc>
class FstCompiler {
 public:
  using Label = typename Arc::Label;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  // Compiles from `istrm` using read-only symbol tables. Init() takes mutable
  // tables, so private copies are made here; symbols are never added
  // (add_symbols is false) and the caller's tables are left untouched.
  FstCompiler(std::istream &istrm, std::string_view source,
              const SymbolTable *isyms, const SymbolTable *osyms,
              const SymbolTable *ssyms, bool accep, bool ikeep, bool okeep,
              bool nkeep) {
    std::unique_ptr<SymbolTable> misyms(isyms ? isyms->Copy() : nullptr);
    std::unique_ptr<SymbolTable> mosyms(osyms ? osyms->Copy() : nullptr);
    std::unique_ptr<SymbolTable> mssyms(ssyms ? ssyms->Copy() : nullptr);
    Init(istrm, source, misyms.get(), mosyms.get(), mssyms.get(), accep, ikeep,
         okeep, nkeep, false);
    // The copies are destroyed when this constructor returns; clear the
    // members so that no dangling pointer is left behind.
    isyms_ = nullptr;
    osyms_ = nullptr;
    ssyms_ = nullptr;
  }

  // Compiles from `istrm` using caller-owned, mutable symbol tables. If
  // `add_symbols` is true, symbols missing from a table are added to it while
  // parsing instead of being reported as errors; this is only useful if the
  // caller inspects the tables afterwards or attaches them with ikeep/okeep.
  FstCompiler(std::istream &istrm, std::string_view source, SymbolTable *isyms,
              SymbolTable *osyms, SymbolTable *ssyms, bool accep, bool ikeep,
              bool okeep, bool nkeep, bool add_symbols) {
    Init(istrm, source, isyms, osyms, ssyms, accep, ikeep, okeep, nkeep,
         add_symbols);
  }

  // Parses the textual FST in `istrm` into the FST returned by Fst().
  //
  //   istrm        text input, read line by line until the stream fails.
  //   source       name of the input, used only in error messages.
  //   isyms/osyms  symbol tables for arc input/output labels; when null the
  //                corresponding fields are parsed as integers.
  //   ssyms        symbol table for state names; when null, states are
  //                parsed as integers.
  //   accep        true if the text describes an acceptor (one label per arc).
  //   ikeep/okeep  attach isyms/osyms to the result once parsing succeeds.
  //   nkeep        use the state IDs as written instead of renumbering them
  //                densely in order of first appearance.
  //   add_symbols  add unknown symbols to the tables instead of failing.
  //
  // A line with an invalid number of columns, or a state ID that is negative
  // after parsing, marks the FST with kError and stops parsing. Other field
  // errors mark the FST with kError but parsing continues.
  void Init(std::istream &istrm, std::string_view source, SymbolTable *isyms,
            SymbolTable *osyms, SymbolTable *ssyms, bool accep, bool ikeep,
            bool okeep, bool nkeep, bool add_symbols) {
    nline_ = 0;
    source_ = std::string(source);
    isyms_ = isyms;
    osyms_ = osyms;
    ssyms_ = ssyms;
    nstates_ = 0;
    keep_state_numbering_ = nkeep;
    add_symbols_ = add_symbols;
    bool start_state_populated = false;
    const std::string separator =
        FST_FLAGS_fst_field_separator + "\n";
    // Lines are read into a growable string: a fixed-size buffer would put
    // the stream into a failed state on an over-long line and silently drop
    // the rest of the input.
    std::string line;
    while (std::getline(istrm, line)) {
      ++nline_;
      const std::vector<std::string_view> col =
          StrSplit(line, ByAnyChar(separator), SkipEmpty());
      if (col.empty() || col[0].empty()) continue;
      if (col.size() > 5 || (col.size() > 4 && accep) ||
          (col.size() == 3 && !accep)) {
        FSTERROR() << "FstCompiler: Bad number of columns, source = " << source_
                   << ", line = " << nline_;
        fst_.SetProperties(kError, kError);
        return;
      }
      const StateId s = StrToStateId(col[0]);
      if (!CheckStateId(s)) return;
      while (s >= fst_.NumStates()) fst_.AddState();
      if (!start_state_populated) {
        fst_.SetStart(s);
        start_state_populated = true;
      }
      Arc arc;
      StateId d = s;
      // Three or more columns describe an arc; its destination is always the
      // second column, so parse and validate it once before dispatching.
      if (col.size() >= 3) {
        d = StrToStateId(col[1]);
        if (!CheckStateId(d)) return;
        arc.nextstate = d;
      }
      switch (col.size()) {
        case 1:
          fst_.SetFinal(s, Weight::One());
          break;
        case 2:
          fst_.SetFinal(s, StrToWeight(col[1], true));
          break;
        case 3:
          arc.ilabel = StrToILabel(col[2]);
          arc.olabel = arc.ilabel;
          arc.weight = Weight::One();
          fst_.AddArc(s, arc);
          break;
        case 4:
          arc.ilabel = StrToILabel(col[2]);
          if (accep) {
            arc.olabel = arc.ilabel;
            arc.weight = StrToWeight(col[3], true);
          } else {
            arc.olabel = StrToOLabel(col[3]);
            arc.weight = Weight::One();
          }
          fst_.AddArc(s, arc);
          break;
        case 5:
          arc.ilabel = StrToILabel(col[2]);
          arc.olabel = StrToOLabel(col[3]);
          arc.weight = StrToWeight(col[4], true);
          fst_.AddArc(s, arc);
          break;
      }
      // Make sure the destination state exists even if it never appears as
      // a source state.
      while (d >= fst_.NumStates()) fst_.AddState();
    }
    if (ikeep) fst_.SetInputSymbols(isyms);
    if (okeep) fst_.SetOutputSymbols(osyms);
  }

  // Returns the compiled FST; it carries the kError property if compilation
  // reported a problem.
  const VectorFst<Arc> &Fst() const { return fst_; }

 private:
  // Returns true if `s` is non-negative. Otherwise reports the problem, marks
  // the FST with kError and returns false. Negative IDs can reach this point
  // when state numbering is kept and the text holds a negative integer or a
  // state symbol lookup failed.
  // NOTE(review): assumes a negative ID is never a valid state index for
  // VectorFst -- confirm against its implementation.
  bool CheckStateId(StateId s) const {
    if (s >= 0) return true;
    FSTERROR() << "FstCompiler: Bad state ID = " << s
               << ", source = " << source_ << ", line = " << nline_;
    fst_.SetProperties(kError, kError);
    return false;
  }

  // Converts the field `s` to an integer ID. With a symbol table, `s` is
  // looked up (or added when add_symbols_ is set) and kNoSymbol is returned,
  // after flagging the error, if it has no mapping. Without a table, `s` is
  // parsed as an integer and 0 is returned, after flagging the error, if it
  // is not one. `name` describes the field in error messages.
  StateId StrToId(std::string_view s, SymbolTable *syms,
                  std::string_view name) const {
    StateId n = 0;
    if (syms) {
      if (add_symbols_) {
        n = syms->AddSymbol(s);
      } else {
        n = syms->Find(s);
      }
      if (n == kNoSymbol) {
        FSTERROR() << "FstCompiler: Symbol \"" << s
                   << "\" is not mapped to any integer " << name
                   << ", symbol table = " << syms->Name()
                   << ", source = " << source_ << ", line = " << nline_;
        fst_.SetProperties(kError, kError);
      }
    } else {
      const auto maybe_n = ParseInt64(s);
      if (maybe_n.has_value()) {
        n = static_cast<StateId>(*maybe_n);
      } else {
        // Never dereference the empty optional; `n` keeps its initial value.
        FSTERROR() << "FstCompiler: Bad " << name << " integer = \"" << s
                   << "\", source = " << source_ << ", line = " << nline_;
        fst_.SetProperties(kError, kError);
      }
    }
    return n;
  }

  // Converts a state field to a state ID. Unless state numbering is kept,
  // IDs are remapped to 0, 1, 2, ... in order of first appearance.
  StateId StrToStateId(std::string_view s) {
    StateId n = StrToId(s, ssyms_, "state ID");
    if (keep_state_numbering_) return n;
    // Remaps state IDs to make dense set.
    const auto it = states_.find(n);
    if (it == states_.end()) {
      states_[n] = nstates_;
      return nstates_++;
    } else {
      return it->second;
    }
  }

  // Converts an arc input-label field to its integer label.
  StateId StrToILabel(std::string_view s) const {
    return StrToId(s, isyms_, "arc ilabel");
  }

  // Converts an arc output-label field to its integer label.
  StateId StrToOLabel(std::string_view s) const {
    return StrToId(s, osyms_, "arc olabel");
  }

  // Parses a weight field with the weight type's stream extraction operator.
  // If extraction fails, or the weight is Zero() while `allow_zero` is false,
  // the error is flagged and Weight::NoWeight() is returned.
  Weight StrToWeight(std::string_view s, bool allow_zero) const {
    Weight w;
    std::istringstream strm(std::string{s});
    strm >> w;
    if (!strm || (!allow_zero && w == Weight::Zero())) {
      FSTERROR() << "FstCompiler: Bad weight = \"" << s
                 << "\", source = " << source_ << ", line = " << nline_;
      fst_.SetProperties(kError, kError);
      w = Weight::NoWeight();
    }
    return w;
  }

  // Mutable so that const parsing helpers can flag errors on it.
  mutable VectorFst<Arc> fst_;
  size_t nline_ = 0;               // Number of the line being parsed.
  std::string source_;             // Text FST source name.
  SymbolTable *isyms_ = nullptr;   // ilabel symbol table (not owned).
  SymbolTable *osyms_ = nullptr;   // olabel symbol table (not owned).
  SymbolTable *ssyms_ = nullptr;   // slabel symbol table (not owned).
  std::unordered_map<StateId, StateId> states_;  // State ID map.
  StateId nstates_ = 0;            // Number of seen states.
  bool keep_state_numbering_ = false;
  bool add_symbols_ = false;       // Add to symbol tables on-the fly.

  FstCompiler(const FstCompiler &) = delete;
  FstCompiler &operator=(const FstCompiler &) = delete;
};
} // namespace fst
#endif // FST_SCRIPT_COMPILE_IMPL_H_