You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

252 lines
8.0 KiB

  1. // Copyright (c) 2021 Mobvoi Inc (Zhendong Peng)
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "decoder/context_graph.h"
  15. #include <fstream>
  16. #include <queue>
  17. #include <utility>
  18. #include "fst/determinize.h"
  19. #include "utils/wn_string.h"
  20. #include "../utils/wn_utils.h"
  21. namespace wenet {
  22. // Split the UTF-8 string into unit ids according to unit_table
  23. bool SplitContextToUnits(const std::string& context,
  24. const std::shared_ptr<fst::SymbolTable>& unit_table,
  25. std::vector<int>* units) {
  26. std::vector<std::string> chars;
  27. SplitUTF8StringToChars(context, &chars);
  28. bool no_oov = true;
  29. bool beginning = true;
  30. for (size_t start = 0; start < chars.size();) {
  31. for (size_t end = chars.size(); end > start; --end) {
  32. std::string unit;
  33. for (size_t i = start; i < end; i++) {
  34. unit += chars[i];
  35. }
  36. // Add '▁' at the beginning of English word.
  37. // TODO(zhendong.peng): Support bpe model
  38. if (IsAlpha(unit) && beginning) {
  39. unit = kSpaceSymbol + unit;
  40. }
  41. int unit_id = unit_table->Find(unit);
  42. if (unit_id != -1) {
  43. units->emplace_back(unit_id);
  44. start = end;
  45. beginning = false;
  46. continue;
  47. }
  48. if (end == start + 1) {
  49. // Matching using '▁' separately for English
  50. if (unit[0] == kSpaceSymbol[0]) {
  51. units->emplace_back(unit_table->Find(kSpaceSymbol));
  52. beginning = false;
  53. break;
  54. }
  55. ++start;
  56. if (unit == " ") {
  57. beginning = true;
  58. continue;
  59. }
  60. no_oov = false;
  61. LOG(WARNING) << unit << " is oov.";
  62. }
  63. }
  64. }
  65. return no_oov;
  66. }
  67. ContextGraph::ContextGraph(ContextConfig config) : config_(config) {}
  68. int ContextGraph::TraceContext(int cur_state, int unit_id, int* final_state) {
  69. CHECK_GE(cur_state, 0);
  70. int next_state = 0;
  71. Matcher matcher(*graph_, fst::MATCH_INPUT);
  72. matcher.SetState(cur_state);
  73. if (matcher.Find(unit_id)) {
  74. next_state = matcher.Value().nextstate;
  75. if (graph_->Final(next_state) != Weight::Zero()) {
  76. *final_state = next_state;
  77. }
  78. return next_state;
  79. }
  80. LOG(FATAL) << "Trace context failed.";
  81. }
  82. void ContextGraph::BuildContextGraph(
  83. const std::vector<std::string>& contexts,
  84. const std::shared_ptr<fst::SymbolTable>& unit_table) {
  85. // Split context phrase into unit ids according to the `unit_table`
  86. std::unordered_map<std::string, std::vector<int>> context_units;
  87. for (const auto& context : contexts) {
  88. std::vector<int> units;
  89. bool no_oov = SplitContextToUnits(context, unit_table, &units);
  90. if (!no_oov) {
  91. LOG(WARNING) << "Ignore unknown unit found during compilation.";
  92. continue;
  93. }
  94. context_units[context] = units;
  95. }
  96. // Build the context graph
  97. std::unique_ptr<fst::StdVectorFst> ofst(new fst::StdVectorFst());
  98. int start_state = ofst->AddState();
  99. ofst->SetStart(start_state);
  100. for (const auto& context : contexts) {
  101. if (context_units.count(context) == 0) continue;
  102. std::vector<int> units = context_units[context];
  103. int state = start_state;
  104. int next_state = state;
  105. for (size_t i = 0; i < units.size(); ++i) {
  106. next_state = ofst->AddState();
  107. if (i == units.size() - 1) {
  108. ofst->SetFinal(next_state, Weight::One());
  109. }
  110. float score =
  111. i * config_.incremental_context_score + config_.context_score;
  112. ofst->AddArc(state, fst::StdArc(units[i], units[i], score, next_state));
  113. state = next_state;
  114. }
  115. }
  116. graph_ = std::unique_ptr<fst::StdVectorFst>(new fst::StdVectorFst());
  117. // input/output label are sorted after Determinize
  118. fst::Determinize(*ofst, graph_.get());
  119. // Determinize will change the final state id
  120. for (const auto& context : contexts) {
  121. if (context_units.count(context) == 0) continue;
  122. std::vector<int> units = context_units[context];
  123. int final_state = -1;
  124. int cur_state = 0;
  125. for (int unit : units) {
  126. cur_state = TraceContext(cur_state, unit, &final_state);
  127. }
  128. CHECK_GT(final_state, 0);
  129. context_table_[final_state] = context;
  130. }
  131. // Convert context graph to AC automaton
  132. ConvertToAC();
  133. }
  134. void ContextGraph::ConvertToAC() {
  135. CHECK(graph_ != nullptr) << "Context graph should not be nullptr!";
  136. int num_states = graph_->NumStates();
  137. std::vector<int> fail_states(num_states, 0);
  138. std::vector<float> total_weights(num_states, 0);
  139. Matcher matcher(*graph_, fst::MATCH_INPUT);
  140. // start state
  141. fail_states[0] = -1;
  142. total_weights[0] = 0;
  143. // Please see:
  144. // https://web.stanford.edu/group/cslipublications/cslipublications/koskenniemi-festschrift/9-mohri.pdf
  145. std::queue<int> states_queue;
  146. states_queue.push(0);
  147. while (!states_queue.empty()) {
  148. int state = states_queue.front();
  149. states_queue.pop();
  150. for (ArcIterator aiter(*graph_, state); !aiter.Done(); aiter.Next()) {
  151. const fst::StdArc& arc = aiter.Value();
  152. int next_state = arc.nextstate;
  153. total_weights[next_state] = total_weights[state] + arc.weight.Value();
  154. // Backtracking the failure state for next_state
  155. for (int fail_state = fail_states[state]; fail_state != -1;
  156. fail_state = fail_states[fail_state]) {
  157. matcher.SetState(fail_state);
  158. if (matcher.Find(arc.ilabel)) {
  159. fail_states[next_state] = matcher.Value().nextstate;
  160. break;
  161. }
  162. }
  163. states_queue.push(next_state);
  164. }
  165. }
  166. // Compute fail weight, add fail arc
  167. for (int state = 0; state < num_states; state++) {
  168. int fail_state = fail_states[state];
  169. if (fail_state < 0) continue;
  170. if (graph_->Final(fail_state) != Weight::Zero()) {
  171. fallback_finals_[state] = fail_state;
  172. if (graph_->NumArcs(fail_state) == 0) continue;
  173. }
  174. if (graph_->Final(state) != Weight::Zero() && fail_state == 0) continue;
  175. float fail_weight = total_weights[fail_state] - total_weights[state];
  176. if (graph_->Final(state) != Weight::Zero()) {
  177. fail_weight = 0;
  178. }
  179. graph_->AddArc(state, fst::StdArc(0, 0, fail_weight, fail_state));
  180. }
  181. // Sort arcs by ilabel, means move the fallback arc from last to first for the
  182. // matcher
  183. fst::ArcSort(graph_.get(), fst::ILabelCompare<fst::StdArc>());
  184. }
  185. int ContextGraph::GetNextState(int cur_state, int unit_id, float* score,
  186. std::unordered_set<std::string>* contexts) {
  187. CHECK_GE(cur_state, 0);
  188. // Find(0) matches any epsilons on the underlying FST explicitly
  189. CHECK_NE(unit_id, 0);
  190. int next_state = 0;
  191. Matcher matcher(*graph_, fst::MATCH_INPUT);
  192. matcher.SetState(cur_state);
  193. if (matcher.Find(unit_id)) {
  194. const fst::StdArc& arc = matcher.Value();
  195. next_state = arc.nextstate;
  196. *score += arc.weight.Value();
  197. // Collect all contexts in the decode result
  198. if (contexts != nullptr) {
  199. if (graph_->Final(next_state) != Weight::Zero()) {
  200. contexts->insert(context_table_[next_state]);
  201. }
  202. int fallback_final = next_state;
  203. while (fallback_finals_.count(fallback_final) > 0) {
  204. fallback_final = fallback_finals_[fallback_final];
  205. contexts->insert(context_table_[fallback_final]);
  206. }
  207. }
  208. // Leaves go back to the start state
  209. if (graph_->NumArcs(next_state) == 0) {
  210. return 0;
  211. }
  212. return next_state;
  213. }
  214. // Check whether the first arc is fallback arc
  215. ArcIterator aiter(*graph_, cur_state);
  216. const fst::StdArc& arc = aiter.Value();
  217. // The start state has no fallback arc
  218. if (arc.ilabel == 0) {
  219. next_state = arc.nextstate;
  220. *score += arc.weight.Value();
  221. // fallback
  222. return GetNextState(next_state, unit_id, score);
  223. }
  224. return 0;
  225. }
  226. } // namespace wenet