|
|
// Copyright (c) 2021 Mobvoi Inc (Zhendong Peng)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/context_graph.h"
#include <fstream>
#include <queue>
#include <utility>
#include "fst/determinize.h"
#include "utils/wn_string.h"
#include "../utils/wn_utils.h"
namespace wenet {
// Split the UTF-8 string into unit ids according to unit_table.
//
// The context is first split into UTF-8 characters. At every position the
// longest run of characters that is present in `unit_table` is taken as the
// next unit (greedy longest match). When a candidate unit satisfies IsAlpha()
// and sits at the beginning of a word, it is looked up with the '▁' prefix;
// if not even its first character matches with the prefix, '▁' is emitted on
// its own and the same position is retried without the prefix. A plain space
// emits nothing and marks the beginning of the next word.
//
// Matched ids are appended to `units`. Returns true when no OOV was met and
// false otherwise (the OOV characters are skipped and logged).
bool SplitContextToUnits(const std::string& context,
                         const std::shared_ptr<fst::SymbolTable>& unit_table,
                         std::vector<int>* units) {
  std::vector<std::string> chars;
  SplitUTF8StringToChars(context, &chars);

  bool no_oov = true;
  bool beginning = true;
  for (size_t start = 0; start < chars.size();) {
    for (size_t end = chars.size(); end > start; --end) {
      std::string unit;
      for (size_t i = start; i < end; i++) {
        unit += chars[i];
      }
      // Add '▁' at the beginning of English word.
      // TODO(zhendong.peng): Support bpe model
      const bool prefixed = IsAlpha(unit) && beginning;
      if (prefixed) {
        unit = kSpaceSymbol + unit;
      }

      int unit_id = unit_table->Find(unit);
      if (unit_id != -1) {
        units->emplace_back(unit_id);
        start = end;
        beginning = false;
        // `end` is decremented below `start` next, which ends the inner loop.
        continue;
      }

      if (end == start + 1) {
        // Matching using '▁' separately for English.
        // Only take this path when the prefix was really added above. Testing
        // the first byte of `unit` instead would also accept any other OOV
        // character that shares its leading UTF-8 byte with '▁'; because
        // `start` is not advanced here, such a character would be retried
        // forever while `units` keeps growing.
        if (prefixed) {
          const int space_id = unit_table->Find(kSpaceSymbol);
          if (space_id != -1) {
            units->emplace_back(space_id);
          } else {
            // Never hand an invalid (-1) id to the caller.
            no_oov = false;
            LOG(WARNING) << kSpaceSymbol << " is oov.";
          }
          // Retry the same position without the prefix.
          beginning = false;
          break;
        }
        ++start;
        if (unit == " ") {
          // A space is a word separator, not an OOV.
          beginning = true;
          continue;
        }
        no_oov = false;
        LOG(WARNING) << unit << " is oov.";
      }
    }
  }
  return no_oov;
}
// Keeps a copy of the biasing configuration; its `context_score` and
// `incremental_context_score` fields are read when the graph is built.
ContextGraph::ContextGraph(ContextConfig config)
    : config_(config) {
}
int ContextGraph::TraceContext(int cur_state, int unit_id, int* final_state) { CHECK_GE(cur_state, 0); int next_state = 0; Matcher matcher(*graph_, fst::MATCH_INPUT); matcher.SetState(cur_state); if (matcher.Find(unit_id)) { next_state = matcher.Value().nextstate; if (graph_->Final(next_state) != Weight::Zero()) { *final_state = next_state; } return next_state; } LOG(FATAL) << "Trace context failed."; }
void ContextGraph::BuildContextGraph( const std::vector<std::string>& contexts, const std::shared_ptr<fst::SymbolTable>& unit_table) { // Split context phrase into unit ids according to the `unit_table`
std::unordered_map<std::string, std::vector<int>> context_units; for (const auto& context : contexts) { std::vector<int> units; bool no_oov = SplitContextToUnits(context, unit_table, &units); if (!no_oov) { LOG(WARNING) << "Ignore unknown unit found during compilation."; continue; } context_units[context] = units; }
// Build the context graph
std::unique_ptr<fst::StdVectorFst> ofst(new fst::StdVectorFst()); int start_state = ofst->AddState(); ofst->SetStart(start_state); for (const auto& context : contexts) { if (context_units.count(context) == 0) continue; std::vector<int> units = context_units[context]; int state = start_state; int next_state = state; for (size_t i = 0; i < units.size(); ++i) { next_state = ofst->AddState(); if (i == units.size() - 1) { ofst->SetFinal(next_state, Weight::One()); } float score = i * config_.incremental_context_score + config_.context_score; ofst->AddArc(state, fst::StdArc(units[i], units[i], score, next_state)); state = next_state; } } graph_ = std::unique_ptr<fst::StdVectorFst>(new fst::StdVectorFst()); // input/output label are sorted after Determinize
fst::Determinize(*ofst, graph_.get());
// Determinize will change the final state id
for (const auto& context : contexts) { if (context_units.count(context) == 0) continue; std::vector<int> units = context_units[context]; int final_state = -1; int cur_state = 0; for (int unit : units) { cur_state = TraceContext(cur_state, unit, &final_state); } CHECK_GT(final_state, 0); context_table_[final_state] = context; }
// Convert context graph to AC automaton
ConvertToAC(); }
void ContextGraph::ConvertToAC() { CHECK(graph_ != nullptr) << "Context graph should not be nullptr!"; int num_states = graph_->NumStates(); std::vector<int> fail_states(num_states, 0); std::vector<float> total_weights(num_states, 0); Matcher matcher(*graph_, fst::MATCH_INPUT); // start state
fail_states[0] = -1; total_weights[0] = 0;
// Please see:
// https://web.stanford.edu/group/cslipublications/cslipublications/koskenniemi-festschrift/9-mohri.pdf
std::queue<int> states_queue; states_queue.push(0); while (!states_queue.empty()) { int state = states_queue.front(); states_queue.pop();
for (ArcIterator aiter(*graph_, state); !aiter.Done(); aiter.Next()) { const fst::StdArc& arc = aiter.Value(); int next_state = arc.nextstate; total_weights[next_state] = total_weights[state] + arc.weight.Value(); // Backtracking the failure state for next_state
for (int fail_state = fail_states[state]; fail_state != -1; fail_state = fail_states[fail_state]) { matcher.SetState(fail_state); if (matcher.Find(arc.ilabel)) { fail_states[next_state] = matcher.Value().nextstate; break; } } states_queue.push(next_state); } }
// Compute fail weight, add fail arc
for (int state = 0; state < num_states; state++) { int fail_state = fail_states[state]; if (fail_state < 0) continue; if (graph_->Final(fail_state) != Weight::Zero()) { fallback_finals_[state] = fail_state; if (graph_->NumArcs(fail_state) == 0) continue; } if (graph_->Final(state) != Weight::Zero() && fail_state == 0) continue;
float fail_weight = total_weights[fail_state] - total_weights[state]; if (graph_->Final(state) != Weight::Zero()) { fail_weight = 0; } graph_->AddArc(state, fst::StdArc(0, 0, fail_weight, fail_state)); } // Sort arcs by ilabel, means move the fallback arc from last to first for the
// matcher
fst::ArcSort(graph_.get(), fst::ILabelCompare<fst::StdArc>()); }
int ContextGraph::GetNextState(int cur_state, int unit_id, float* score, std::unordered_set<std::string>* contexts) { CHECK_GE(cur_state, 0); // Find(0) matches any epsilons on the underlying FST explicitly
CHECK_NE(unit_id, 0); int next_state = 0;
Matcher matcher(*graph_, fst::MATCH_INPUT); matcher.SetState(cur_state); if (matcher.Find(unit_id)) { const fst::StdArc& arc = matcher.Value(); next_state = arc.nextstate; *score += arc.weight.Value(); // Collect all contexts in the decode result
if (contexts != nullptr) { if (graph_->Final(next_state) != Weight::Zero()) { contexts->insert(context_table_[next_state]); } int fallback_final = next_state; while (fallback_finals_.count(fallback_final) > 0) { fallback_final = fallback_finals_[fallback_final]; contexts->insert(context_table_[fallback_final]); } }
// Leaves go back to the start state
if (graph_->NumArcs(next_state) == 0) { return 0; } return next_state; }
// Check whether the first arc is fallback arc
ArcIterator aiter(*graph_, cur_state); const fst::StdArc& arc = aiter.Value(); // The start state has no fallback arc
if (arc.ilabel == 0) { next_state = arc.nextstate; *score += arc.weight.Value(); // fallback
return GetNextState(next_state, unit_id, score); }
return 0; }
} // namespace wenet
|