|
|
// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Utilities to convert strings into FSTs.
#ifndef FST_STRING_H_
#define FST_STRING_H_
#include <cstdint>
#include <memory>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include <fst/flags.h>
#include <fst/log.h>
#include <fst/arc.h>
#include <fst/compact-fst.h>
#include <fst/fst.h>
#include <fst/icu.h>
#include <fst/mutable-fst.h>
#include <fst/properties.h>
#include <fst/symbol-table.h>
#include <fst/util.h>
#include <fst/compat.h>
#include <string_view>
// Declares the string flag `fst_field_separator`. The functions below use
// FST_FLAGS_fst_field_separator as the default value of their `sep` argument.
// NOTE(review): presumably the flag itself is defined in another translation
// unit -- confirm against <fst/flags.h>.
DECLARE_string(fst_field_separator);
namespace fst {
// Selects how a string is split into labels and how labels are rendered back
// into a string. The numeric values are spelled out explicitly and the
// underlying type is a single byte.
enum class TokenType : uint8_t {
  // Separator-delimited tokens; each token is looked up in a symbol table or,
  // when no table is supplied, parsed as an integer label.
  SYMBOL = 1,
  // Converted with ByteStringToLabels / LabelsToByteString.
  BYTE = 2,
  // Converted with UTF8StringToLabels / LabelsToUTF8String.
  UTF8 = 3,
};
// Streams the lower-case name of a token type: "byte", "utf8" or "symbol".
// A value that is not one of the three enumerators writes nothing; the stream
// is returned in every case.
inline std::ostream &operator<<(std::ostream &strm,
                                const TokenType &token_type) {
  const char *name = nullptr;
  switch (token_type) {
    case TokenType::BYTE:
      name = "byte";
      break;
    case TokenType::UTF8:
      name = "utf8";
      break;
    case TokenType::SYMBOL:
      name = "symbol";
      break;
  }
  // Not reached with a null name for any valid enumerator.
  if (name != nullptr) strm << name;
  return strm;
}
namespace internal {
template <class Label> bool ConvertSymbolToLabel(std::string_view str, const SymbolTable *syms, Label unknown_label, Label *output) { int64_t n; if (syms) { n = syms->Find(str); if ((n == kNoSymbol) && (unknown_label != kNoLabel)) n = unknown_label; if (n == kNoSymbol) { LOG(ERROR) << "ConvertSymbolToLabel: Symbol \"" << str << "\" is not mapped to any integer label, symbol table = " << syms->Name(); return false; } } else { const auto maybe_n = ParseInt64(str); if (!maybe_n.has_value()) { LOG(ERROR) << "ConvertSymbolToLabel: Bad label integer " << "= \"" << str << "\""; return false; } n = *maybe_n; } *output = n; return true; }
template <class Label> bool ConvertStringToLabels( std::string_view str, TokenType token_type, const SymbolTable *syms, Label unknown_label, std::vector<Label> *labels, std::string_view sep = FST_FLAGS_fst_field_separator) { labels->clear(); switch (token_type) { case TokenType::BYTE: { labels->reserve(str.size()); return ByteStringToLabels(str, labels); } case TokenType::UTF8: { return UTF8StringToLabels(str, labels); } case TokenType::SYMBOL: { const std::string separator = fst::StrCat("\n", sep); for (std::string_view c : StrSplit(str, ByAnyChar(separator), SkipEmpty())) { Label label; if (!ConvertSymbolToLabel(c, syms, unknown_label, &label)) return false; labels->push_back(label); } return true; } } return false; // Unreachable.
}
// The last character of 'sep' is used as a separator between symbols.
// Additionally, epsilon symbols will be printed only if omit_epsilon
// is false.
template <class Label> bool LabelsToSymbolString(const std::vector<Label> &labels, std::string *str, const SymbolTable &syms, std::string_view sep, bool omit_epsilon) { std::stringstream ostrm; sep.remove_prefix(sep.size() - 1); // We only respect the final char of sep.
std::string_view delim = ""; for (auto label : labels) { if (omit_epsilon && !label) continue; ostrm << delim; const std::string &symbol = syms.Find(label); if (symbol.empty()) { LOG(ERROR) << "LabelsToSymbolString: Label " << label << " is not mapped onto any textual symbol in symbol table " << syms.Name(); return false; } ostrm << symbol; delim = sep; } *str = ostrm.str(); return !!ostrm; }
// Renders `labels` as their numeric values.
//
// labels:       the labels to print.
// str:          receives the result.
// sep:          only its last character is used, as the delimiter written
//               between consecutive labels. An empty `sep` means no delimiter
//               is written.
// omit_epsilon: when true, labels equal to zero (epsilon) are skipped.
//
// Returns whether the stream is still in a good state after formatting.
template <class Label>
bool LabelsToNumericString(const std::vector<Label> &labels, std::string *str,
                           std::string_view sep, bool omit_epsilon) {
  std::stringstream ostrm;
  // We only respect the final char of sep. The emptiness check matters:
  // `sep.size() - 1` would wrap around for an empty view, and remove_prefix
  // with a count larger than size() is undefined behaviour.
  if (!sep.empty()) sep.remove_prefix(sep.size() - 1);
  std::string_view delim = "";
  for (auto label : labels) {
    if (omit_epsilon && !label) continue;
    ostrm << delim;
    ostrm << label;
    // From the second printed label onwards, prefix with the delimiter.
    delim = sep;
  }
  *str = ostrm.str();
  return !!ostrm;
}
} // namespace internal
// Functor that compiles a string into a string FST.
//
// The input is first tokenized into labels according to the token type chosen
// at construction; the labels are then written into the destination FST as a
// chain with one arc per label.
//
// NOTE(review): the class carries a deprecation attribute whose message talks
// about `allow_negative`, which no constructor visible here accepts. Confirm
// whether deprecating the whole class is really intended.
template <class Arc>
class OPENFST_DEPRECATED("allow_negative is no-op") StringCompiler {
 public:
  using Label = typename Arc::Label;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;

  // token_type:    how input strings are tokenized.
  // syms:          symbol table, consulted when the token type is SYMBOL; not
  //                owned by the compiler.
  // unknown_label: label substituted for tokens missing from the symbol table
  //                (kNoLabel disables the substitution).
  explicit StringCompiler(TokenType token_type = TokenType::BYTE,
                          const SymbolTable *syms = nullptr,
                          Label unknown_label = kNoLabel)
      : token_type_(token_type), syms_(syms), unknown_label_(unknown_label) {}

  // Compiles `str` into `fst`. With the SYMBOL token type, `sep` gives the
  // set of separator characters between symbols, in addition to '\n', which
  // is always treated as a separator.
  // Returns true on success; on failure `fst` is left unmodified.
  template <class FST>
  bool operator()(
      std::string_view str, FST *fst,
      std::string_view sep = FST_FLAGS_fst_field_separator) const {
    std::vector<Label> labels;
    const bool tokenized = internal::ConvertStringToLabels(
        str, token_type_, syms_, unknown_label_, &labels, sep);
    if (tokenized) Compile(labels, fst);
    return tokenized;
  }

  // Same as above, but additionally lets the caller specify a weight for the
  // string.
  template <class FST>
  bool operator()(
      std::string_view str, FST *fst, Weight weight,
      std::string_view sep = FST_FLAGS_fst_field_separator) const {
    std::vector<Label> labels;
    const bool tokenized = internal::ConvertStringToLabels(
        str, token_type_, syms_, unknown_label_, &labels, sep);
    if (tokenized) Compile(labels, fst, std::move(weight));
    return tokenized;
  }

 private:
  // Rebuilds a mutable FST from scratch: existing states are deleted, a start
  // state plus one state per label is added, consecutive states are linked by
  // an arc carrying the label, and the last state is made final with `weight`.
  // NOTE(review): relies on the added states receiving consecutive ids
  // directly after the start state -- confirm against MutableFst::AddStates.
  void Compile(const std::vector<Label> &labels, MutableFst<Arc> *fst,
               Weight weight = Weight::One()) const {
    fst->DeleteStates();
    auto current = fst->AddState();
    fst->SetStart(current);
    fst->AddStates(labels.size());
    for (const auto label : labels) {
      const auto next = current + 1;
      fst->AddArc(current, Arc(label, label, next));
      current = next;
    }
    fst->SetFinal(current, std::move(weight));
    fst->SetProperties(kCompiledStringProperties, kCompiledStringProperties);
  }

  // Installs a compactor built directly from the label sequence.
  template <class Unsigned>
  void Compile(const std::vector<Label> &labels,
               CompactStringFst<Arc, Unsigned> *fst) const {
    using Compactor = typename CompactStringFst<Arc, Unsigned>::Compactor;
    auto compactor = std::make_shared<Compactor>(labels.begin(), labels.end());
    fst->SetCompactor(std::move(compactor));
  }

  // Installs a compactor built from (label, weight) pairs: every label except
  // the last is paired with Weight::One(), and the last label (or kNoLabel
  // when there are no labels) is paired with `weight`.
  template <class Unsigned>
  void Compile(const std::vector<Label> &labels,
               CompactWeightedStringFst<Arc, Unsigned> *fst,
               Weight weight = Weight::One()) const {
    std::vector<std::pair<Label, Weight>> compacts;
    compacts.reserve(labels.size() + 1);
    const auto last_index = static_cast<StateId>(labels.size()) - 1;
    for (StateId i = 0; i < last_index; ++i) {
      compacts.emplace_back(labels[i], Weight::One());
    }
    compacts.emplace_back(labels.empty() ? kNoLabel : labels.back(), weight);
    using Compactor =
        typename CompactWeightedStringFst<Arc, Unsigned>::Compactor;
    auto compactor =
        std::make_shared<Compactor>(compacts.begin(), compacts.end());
    fst->SetCompactor(std::move(compactor));
  }

  const TokenType token_type_;  // How input strings are tokenized.
  const SymbolTable *syms_;     // Symbol table (used when token type is symbol).
  const Label unknown_label_;   // Label for token missing from symbol table.

  // Not copyable.
  StringCompiler(const StringCompiler &) = delete;
  StringCompiler &operator=(const StringCompiler &) = delete;
};
// Convenience alias: StringCompiler instantiated for StdArc.
using StdStringCompiler = StringCompiler<StdArc>;
// Helpers for StringPrinter.
// Converts an FST to a vector of output labels. To get input labels, use
// Project or Invert. Returns true on success. Use only with string FSTs; may
// loop for non-string FSTs.
template <class Arc> bool StringFstToOutputLabels(const Fst<Arc> &fst, std::vector<typename Arc::Label> *labels) { labels->clear(); auto s = fst.Start(); if (s == kNoStateId) { LOG(ERROR) << "StringFstToOutputLabels: Invalid start state"; return false; } while (fst.Final(s) == Arc::Weight::Zero()) { ArcIterator<Fst<Arc>> aiter(fst, s); if (aiter.Done()) { LOG(ERROR) << "StringFstToOutputLabels: Does not reach final state"; return false; } const auto &arc = aiter.Value(); labels->push_back(arc.olabel); s = arc.nextstate; aiter.Next(); if (!aiter.Done()) { LOG(ERROR) << "StringFstToOutputLabels: State " << s << " has multiple outgoing arcs"; return false; } } if (fst.NumArcs(s) != 0) { LOG(ERROR) << "StringFstToOutputLabels: Final state " << s << " has outgoing arc(s)"; return false; } return true; }
// Same as above but also computes the path weight. The output weight parameter
// is only set if labels extraction is successful.
template <class Arc> bool StringFstToOutputLabels(const Fst<Arc> &fst, std::vector<typename Arc::Label> *labels, typename Arc::Weight *weight) { labels->clear(); auto path_weight = Arc::Weight::One(); auto s = fst.Start(); if (s == kNoStateId) { LOG(ERROR) << "StringFstToOutputLabels: Invalid start state"; return false; } auto final_weight = fst.Final(s); while (final_weight == Arc::Weight::Zero()) { ArcIterator<Fst<Arc>> aiter(fst, s); if (aiter.Done()) { LOG(ERROR) << "StringFstToOutputLabels: Does not reach final state"; return false; } const auto &arc = aiter.Value(); labels->push_back(arc.olabel); path_weight = Times(path_weight, arc.weight); s = arc.nextstate; aiter.Next(); if (!aiter.Done()) { LOG(ERROR) << "StringFstToOutputLabels: State " << s << " has multiple outgoing arcs"; return false; } final_weight = fst.Final(s); } if (fst.NumArcs(s) != 0) { LOG(ERROR) << "StringFstToOutputLabels: Final state " << s << " has outgoing arc(s)"; return false; } *weight = Times(path_weight, final_weight); return true; }
// Converts a sequence of labels to a string, dispatching on the token type:
//   BYTE   -> LabelsToByteString
//   UTF8   -> LabelsToUTF8String
//   SYMBOL -> symbol-table lookup when `syms` is non-null, otherwise the
//             labels are printed as numbers.
// `sep` and `omit_epsilon` only matter for SYMBOL: the last character of
// `sep` separates the printed symbols, and epsilon symbols are printed only
// if `omit_epsilon` is false. Returns true on success.
template <class Label>
bool LabelsToString(
    const std::vector<Label> &labels, std::string *str,
    TokenType ttype = TokenType::BYTE, const SymbolTable *syms = nullptr,
    std::string_view sep = FST_FLAGS_fst_field_separator,
    bool omit_epsilon = true) {
  if (ttype == TokenType::BYTE) {
    return LabelsToByteString(labels, str);
  }
  if (ttype == TokenType::UTF8) {
    return LabelsToUTF8String(labels, str);
  }
  if (ttype != TokenType::SYMBOL) {
    return false;  // Not a known token type.
  }
  if (syms == nullptr) {
    return internal::LabelsToNumericString(labels, str, sep, omit_epsilon);
  }
  return internal::LabelsToSymbolString(labels, str, *syms, sep, omit_epsilon);
}
// Functor that prints a string FST as a string: the FST's output labels are
// extracted with StringFstToOutputLabels and rendered with LabelsToString.
template <class Arc>
class StringPrinter {
 public:
  using Label = typename Arc::Label;
  using Weight = typename Arc::Weight;

  // token_type:   how labels are rendered.
  // syms:         symbol table forwarded to LabelsToString; not owned by the
  //               printer.
  // omit_epsilon: forwarded to LabelsToString.
  explicit StringPrinter(TokenType token_type = TokenType::BYTE,
                         const SymbolTable *syms = nullptr,
                         bool omit_epsilon = true)
      : token_type_(token_type), syms_(syms), omit_epsilon_(omit_epsilon) {}

  // Converts the FST into a string. With the SYMBOL token type, the last
  // character of `sep` is used as a separator between symbols.
  // Returns true on success.
  bool operator()(
      const Fst<Arc> &fst, std::string *str,
      std::string_view sep = FST_FLAGS_fst_field_separator) const {
    std::vector<Label> labels;
    if (!StringFstToOutputLabels(fst, &labels)) return false;
    return LabelsToString(labels, str, token_type_, syms_, sep, omit_epsilon_);
  }

  // Same as above, but also computes the path weight. `*weight` is only set
  // if label extraction succeeds.
  bool operator()(
      const Fst<Arc> &fst, std::string *str, Weight *weight,
      std::string_view sep = FST_FLAGS_fst_field_separator) const {
    std::vector<Label> labels;
    if (!StringFstToOutputLabels(fst, &labels, weight)) return false;
    return LabelsToString(labels, str, token_type_, syms_, sep, omit_epsilon_);
  }

 private:
  const TokenType token_type_;  // How labels are rendered.
  const SymbolTable *syms_;     // Symbol table passed to LabelsToString.
  const bool omit_epsilon_;     // Passed to LabelsToString.

  // Not copyable.
  StringPrinter(const StringPrinter &) = delete;
  StringPrinter &operator=(const StringPrinter &) = delete;
};
// Convenience alias: StringPrinter instantiated for StdArc.
using StdStringPrinter = StringPrinter<StdArc>;
} // namespace fst
#endif // FST_STRING_H_
|