// Copyright 2005-2024 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the 'License');
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an 'AS IS' BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//
|
|
// See www.openfst.org for extensive documentation on this weighted
|
|
// finite-state transducer library.
|
|
//
|
|
// Utilities to convert strings into FSTs.
|
|
|
|
#ifndef FST_STRING_H_
|
|
#define FST_STRING_H_
|
|
|
|
#include <cstdint>
|
|
#include <memory>
|
|
#include <optional>
|
|
#include <ostream>
|
|
#include <sstream>
|
|
#include <string>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include <fst/flags.h>
|
|
#include <fst/log.h>
|
|
#include <fst/arc.h>
|
|
#include <fst/compact-fst.h>
|
|
#include <fst/fst.h>
|
|
#include <fst/icu.h>
|
|
#include <fst/mutable-fst.h>
|
|
#include <fst/properties.h>
|
|
#include <fst/symbol-table.h>
|
|
#include <fst/util.h>
|
|
#include <fst/compat.h>
|
|
#include <string_view>
|
|
|
|
DECLARE_string(fst_field_separator);
|
|
|
|
namespace fst {
|
|
|
|
// Selects how a string is tokenized into labels (and how labels are rendered
// back into a string): as separator-delimited symbols, as raw bytes, or as
// UTF-8 encoded text.
enum class TokenType : uint8_t {
  SYMBOL = 1,
  BYTE = 2,
  UTF8 = 3,
};

// Streams the lowercase name of token_type ("symbol", "byte" or "utf8") and
// returns strm. A value that is not one of the enumerators writes nothing.
inline std::ostream &operator<<(std::ostream &strm,
                                const TokenType &token_type) {
  const char *name = nullptr;
  switch (token_type) {
    case TokenType::SYMBOL:
      name = "symbol";
      break;
    case TokenType::BYTE:
      name = "byte";
      break;
    case TokenType::UTF8:
      name = "utf8";
      break;
  }
  // Only reachable with a null name if the value was cast from an integer
  // outside the enumeration.
  if (name == nullptr) return strm;
  return strm << name;
}
|
|
|
|
namespace internal {
|
|
|
|
template <class Label>
|
|
bool ConvertSymbolToLabel(std::string_view str, const SymbolTable *syms,
|
|
Label unknown_label, Label *output) {
|
|
int64_t n;
|
|
if (syms) {
|
|
n = syms->Find(str);
|
|
if ((n == kNoSymbol) && (unknown_label != kNoLabel)) n = unknown_label;
|
|
if (n == kNoSymbol) {
|
|
LOG(ERROR) << "ConvertSymbolToLabel: Symbol \"" << str
|
|
<< "\" is not mapped to any integer label, symbol table = "
|
|
<< syms->Name();
|
|
return false;
|
|
}
|
|
} else {
|
|
const auto maybe_n = ParseInt64(str);
|
|
if (!maybe_n.has_value()) {
|
|
LOG(ERROR) << "ConvertSymbolToLabel: Bad label integer "
|
|
<< "= \"" << str << "\"";
|
|
return false;
|
|
}
|
|
n = *maybe_n;
|
|
}
|
|
*output = n;
|
|
return true;
|
|
}
|
|
|
|
template <class Label>
|
|
bool ConvertStringToLabels(
|
|
std::string_view str, TokenType token_type, const SymbolTable *syms,
|
|
Label unknown_label, std::vector<Label> *labels,
|
|
std::string_view sep = FST_FLAGS_fst_field_separator) {
|
|
labels->clear();
|
|
switch (token_type) {
|
|
case TokenType::BYTE: {
|
|
labels->reserve(str.size());
|
|
return ByteStringToLabels(str, labels);
|
|
}
|
|
case TokenType::UTF8: {
|
|
return UTF8StringToLabels(str, labels);
|
|
}
|
|
case TokenType::SYMBOL: {
|
|
const std::string separator = fst::StrCat("\n", sep);
|
|
for (std::string_view c :
|
|
StrSplit(str, ByAnyChar(separator), SkipEmpty())) {
|
|
Label label;
|
|
if (!ConvertSymbolToLabel(c, syms, unknown_label, &label)) return false;
|
|
labels->push_back(label);
|
|
}
|
|
return true;
|
|
}
|
|
}
|
|
return false; // Unreachable.
|
|
}
|
|
|
|
// The last character of 'sep' is used as a separator between symbols.
|
|
// Additionally, epsilon symbols will be printed only if omit_epsilon
|
|
// is false.
|
|
template <class Label>
|
|
bool LabelsToSymbolString(const std::vector<Label> &labels, std::string *str,
|
|
const SymbolTable &syms, std::string_view sep,
|
|
bool omit_epsilon) {
|
|
std::stringstream ostrm;
|
|
sep.remove_prefix(sep.size() - 1); // We only respect the final char of sep.
|
|
std::string_view delim = "";
|
|
for (auto label : labels) {
|
|
if (omit_epsilon && !label) continue;
|
|
ostrm << delim;
|
|
const std::string &symbol = syms.Find(label);
|
|
if (symbol.empty()) {
|
|
LOG(ERROR) << "LabelsToSymbolString: Label " << label
|
|
<< " is not mapped onto any textual symbol in symbol table "
|
|
<< syms.Name();
|
|
return false;
|
|
}
|
|
ostrm << symbol;
|
|
delim = sep;
|
|
}
|
|
*str = ostrm.str();
|
|
return !!ostrm;
|
|
}
|
|
|
|
// Renders labels as a string of their numeric values.
//
// The last character of 'sep' is used as a separator between labels; an empty
// 'sep' yields no separator at all. Additionally, epsilon (zero) labels will
// be printed only if omit_epsilon is false.
//
// Assigns the result to *str and returns whether the underlying stream is
// still in a good state.
template <class Label>
bool LabelsToNumericString(const std::vector<Label> &labels, std::string *str,
                           std::string_view sep, bool omit_epsilon) {
  std::stringstream ostrm;
  // We only respect the final char of sep. The emptiness check matters:
  // sep.size() - 1 would wrap around for an empty view, and remove_prefix with
  // a count larger than size() is undefined behaviour.
  if (!sep.empty()) sep.remove_prefix(sep.size() - 1);
  std::string_view delim = "";
  for (auto label : labels) {
    if (omit_epsilon && !label) continue;
    ostrm << delim;
    ostrm << label;
    // From the second emitted label onwards, prefix with the separator.
    delim = sep;
  }
  *str = ostrm.str();
  return !!ostrm;
}
|
|
|
|
} // namespace internal
|
|
|
|
// Functor for compiling a string in an FST.
|
|
template <class Arc>
|
|
class OPENFST_DEPRECATED("allow_negative is no-op") StringCompiler {
|
|
public:
|
|
using Label = typename Arc::Label;
|
|
using StateId = typename Arc::StateId;
|
|
using Weight = typename Arc::Weight;
|
|
|
|
explicit StringCompiler(TokenType token_type = TokenType::BYTE,
|
|
const SymbolTable *syms = nullptr,
|
|
Label unknown_label = kNoLabel)
|
|
: token_type_(token_type), syms_(syms), unknown_label_(unknown_label) {}
|
|
|
|
// Compiles string into an FST. With SYMBOL token type, sep is used to
|
|
// specify the set of char separators between symbols, in addition
|
|
// of '\n' which is always treated as a separator.
|
|
// Returns true on success.
|
|
template <class FST>
|
|
bool operator()(
|
|
std::string_view str, FST *fst,
|
|
std::string_view sep = FST_FLAGS_fst_field_separator) const {
|
|
std::vector<Label> labels;
|
|
if (!internal::ConvertStringToLabels(str, token_type_, syms_,
|
|
unknown_label_, &labels, sep)) {
|
|
return false;
|
|
}
|
|
Compile(labels, fst);
|
|
return true;
|
|
}
|
|
|
|
// Same as above but allows to specify a weight for the string.
|
|
template <class FST>
|
|
bool operator()(
|
|
std::string_view str, FST *fst, Weight weight,
|
|
std::string_view sep = FST_FLAGS_fst_field_separator) const {
|
|
std::vector<Label> labels;
|
|
if (!internal::ConvertStringToLabels(str, token_type_, syms_,
|
|
unknown_label_, &labels, sep)) {
|
|
return false;
|
|
}
|
|
Compile(labels, fst, std::move(weight));
|
|
return true;
|
|
}
|
|
|
|
private:
|
|
void Compile(const std::vector<Label> &labels, MutableFst<Arc> *fst,
|
|
Weight weight = Weight::One()) const {
|
|
fst->DeleteStates();
|
|
auto state = fst->AddState();
|
|
fst->SetStart(state);
|
|
fst->AddStates(labels.size());
|
|
for (auto label : labels) {
|
|
fst->AddArc(state, Arc(label, label, state + 1));
|
|
++state;
|
|
}
|
|
fst->SetFinal(state, std::move(weight));
|
|
fst->SetProperties(kCompiledStringProperties, kCompiledStringProperties);
|
|
}
|
|
|
|
template <class Unsigned>
|
|
void Compile(const std::vector<Label> &labels,
|
|
CompactStringFst<Arc, Unsigned> *fst) const {
|
|
using Compactor = typename CompactStringFst<Arc, Unsigned>::Compactor;
|
|
fst->SetCompactor(
|
|
std::make_shared<Compactor>(labels.begin(), labels.end()));
|
|
}
|
|
|
|
template <class Unsigned>
|
|
void Compile(const std::vector<Label> &labels,
|
|
CompactWeightedStringFst<Arc, Unsigned> *fst,
|
|
Weight weight = Weight::One()) const {
|
|
std::vector<std::pair<Label, Weight>> compacts;
|
|
compacts.reserve(labels.size() + 1);
|
|
for (StateId i = 0; i < static_cast<StateId>(labels.size()) - 1; ++i) {
|
|
compacts.emplace_back(labels[i], Weight::One());
|
|
}
|
|
compacts.emplace_back(!labels.empty() ? labels.back() : kNoLabel, weight);
|
|
using Compactor =
|
|
typename CompactWeightedStringFst<Arc, Unsigned>::Compactor;
|
|
fst->SetCompactor(
|
|
std::make_shared<Compactor>(compacts.begin(), compacts.end()));
|
|
}
|
|
|
|
const TokenType token_type_;
|
|
const SymbolTable *syms_; // Symbol table (used when token type is symbol).
|
|
const Label unknown_label_; // Label for token missing from symbol table.
|
|
|
|
StringCompiler(const StringCompiler &) = delete;
|
|
StringCompiler &operator=(const StringCompiler &) = delete;
|
|
};
|
|
|
|
// A useful alias when using StdArc.
|
|
using StdStringCompiler = StringCompiler<StdArc>;
|
|
|
|
// Helpers for StringPrinter.
|
|
|
|
// Converts an FST to a vector of output labels. To get input labels, use
|
|
// Project or Invert. Returns true on success. Use only with string FSTs; may
|
|
// loop for non-string FSTs.
|
|
template <class Arc>
|
|
bool StringFstToOutputLabels(const Fst<Arc> &fst,
|
|
std::vector<typename Arc::Label> *labels) {
|
|
labels->clear();
|
|
auto s = fst.Start();
|
|
if (s == kNoStateId) {
|
|
LOG(ERROR) << "StringFstToOutputLabels: Invalid start state";
|
|
return false;
|
|
}
|
|
while (fst.Final(s) == Arc::Weight::Zero()) {
|
|
ArcIterator<Fst<Arc>> aiter(fst, s);
|
|
if (aiter.Done()) {
|
|
LOG(ERROR) << "StringFstToOutputLabels: Does not reach final state";
|
|
return false;
|
|
}
|
|
const auto &arc = aiter.Value();
|
|
labels->push_back(arc.olabel);
|
|
s = arc.nextstate;
|
|
aiter.Next();
|
|
if (!aiter.Done()) {
|
|
LOG(ERROR) << "StringFstToOutputLabels: State " << s
|
|
<< " has multiple outgoing arcs";
|
|
return false;
|
|
}
|
|
}
|
|
if (fst.NumArcs(s) != 0) {
|
|
LOG(ERROR) << "StringFstToOutputLabels: Final state " << s
|
|
<< " has outgoing arc(s)";
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// Same as above but also computes the path weight. The output weight parameter
|
|
// is only set if labels extraction is successful.
|
|
template <class Arc>
|
|
bool StringFstToOutputLabels(const Fst<Arc> &fst,
|
|
std::vector<typename Arc::Label> *labels,
|
|
typename Arc::Weight *weight) {
|
|
labels->clear();
|
|
auto path_weight = Arc::Weight::One();
|
|
auto s = fst.Start();
|
|
if (s == kNoStateId) {
|
|
LOG(ERROR) << "StringFstToOutputLabels: Invalid start state";
|
|
return false;
|
|
}
|
|
auto final_weight = fst.Final(s);
|
|
while (final_weight == Arc::Weight::Zero()) {
|
|
ArcIterator<Fst<Arc>> aiter(fst, s);
|
|
if (aiter.Done()) {
|
|
LOG(ERROR) << "StringFstToOutputLabels: Does not reach final state";
|
|
return false;
|
|
}
|
|
const auto &arc = aiter.Value();
|
|
labels->push_back(arc.olabel);
|
|
path_weight = Times(path_weight, arc.weight);
|
|
s = arc.nextstate;
|
|
aiter.Next();
|
|
if (!aiter.Done()) {
|
|
LOG(ERROR) << "StringFstToOutputLabels: State " << s
|
|
<< " has multiple outgoing arcs";
|
|
return false;
|
|
}
|
|
final_weight = fst.Final(s);
|
|
}
|
|
if (fst.NumArcs(s) != 0) {
|
|
LOG(ERROR) << "StringFstToOutputLabels: Final state " << s
|
|
<< " has outgoing arc(s)";
|
|
return false;
|
|
}
|
|
*weight = Times(path_weight, final_weight);
|
|
return true;
|
|
}
|
|
|
|
// Converts a list of symbols to a string. If the token type is SYMBOL, the last
|
|
// character of sep is used to separate textual symbols. Additionally, if the
|
|
// token type is SYMBOL, epsilon symbols will be printed only if omit_epsilon
|
|
// is false. Returns true on success.
|
|
template <class Label>
|
|
bool LabelsToString(
|
|
const std::vector<Label> &labels, std::string *str,
|
|
TokenType ttype = TokenType::BYTE, const SymbolTable *syms = nullptr,
|
|
std::string_view sep = FST_FLAGS_fst_field_separator,
|
|
bool omit_epsilon = true) {
|
|
switch (ttype) {
|
|
case TokenType::BYTE: {
|
|
return LabelsToByteString(labels, str);
|
|
}
|
|
case TokenType::UTF8: {
|
|
return LabelsToUTF8String(labels, str);
|
|
}
|
|
case TokenType::SYMBOL: {
|
|
return syms ? internal::LabelsToSymbolString(labels, str, *syms, sep,
|
|
omit_epsilon)
|
|
: internal::LabelsToNumericString(labels, str, sep,
|
|
omit_epsilon);
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Functor for printing a string FST as a string.
|
|
template <class Arc>
|
|
class StringPrinter {
|
|
public:
|
|
using Label = typename Arc::Label;
|
|
using Weight = typename Arc::Weight;
|
|
|
|
explicit StringPrinter(TokenType token_type = TokenType::BYTE,
|
|
const SymbolTable *syms = nullptr,
|
|
bool omit_epsilon = true)
|
|
: token_type_(token_type), syms_(syms), omit_epsilon_(omit_epsilon) {}
|
|
|
|
// Converts the FST into a string. With SYMBOL token type, the last character
|
|
// of sep is used as a separator between symbols. Returns true on success.
|
|
bool operator()(
|
|
const Fst<Arc> &fst, std::string *str,
|
|
std::string_view sep = FST_FLAGS_fst_field_separator) const {
|
|
std::vector<Label> labels;
|
|
return StringFstToOutputLabels(fst, &labels) &&
|
|
LabelsToString(labels, str, token_type_, syms_, sep, omit_epsilon_);
|
|
}
|
|
|
|
// Same as above but also computes the path weight. The output weight
|
|
// parameter is only set if labels extraction is successful.
|
|
bool operator()(
|
|
const Fst<Arc> &fst, std::string *str, Weight *weight,
|
|
std::string_view sep = FST_FLAGS_fst_field_separator) const {
|
|
std::vector<Label> labels;
|
|
return StringFstToOutputLabels(fst, &labels, weight) &&
|
|
LabelsToString(labels, str, token_type_, syms_, sep, omit_epsilon_);
|
|
}
|
|
|
|
private:
|
|
const TokenType token_type_;
|
|
const SymbolTable *syms_;
|
|
const bool omit_epsilon_;
|
|
|
|
StringPrinter(const StringPrinter &) = delete;
|
|
StringPrinter &operator=(const StringPrinter &) = delete;
|
|
};
|
|
|
|
// A useful alias when using StdArc.
|
|
using StdStringPrinter = StringPrinter<StdArc>;
|
|
|
|
} // namespace fst
|
|
|
|
#endif // FST_STRING_H_
|