You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

236 lines
8.6 KiB

  1. // Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include <gflags/gflags.h>

#include "decoder/params.h"
#include "frontend/wav.h"
#include "utils/flags.h"
#include "utils/wn_string.h"
  23. DEFINE_string(text, "", "kaldi style text input file");
  24. DEFINE_string(wav_scp, "", "kaldi style wav scp");
  25. DEFINE_double(is_penalty, 1.0,
  26. "insertion/substitution penalty for align insertion");
  27. DEFINE_double(del_penalty, 1.0, "deletion penalty for align insertion");
  28. DEFINE_string(result, "", "result output file");
  29. DEFINE_string(timestamp, "", "timestamp output file");
  30. namespace wenet {
  31. const char* kDeletion = "<del>";
  32. // Is: Insertion and substitution
  33. const char* kIsStart = "<is>";
  34. const char* kIsEnd = "</is>";
  35. bool MapToLabel(const std::string& text,
  36. std::shared_ptr<fst::SymbolTable> symbol_table,
  37. std::vector<int>* labels) {
  38. labels->clear();
  39. // Split label to char sequence
  40. std::vector<std::string> chars;
  41. SplitUTF8StringToChars(text, &chars);
  42. for (size_t i = 0; i < chars.size(); i++) {
  43. // ▁ is special symbol for white space
  44. std::string label = chars[i] != " " ? chars[i] : "";
  45. int id = symbol_table->Find(label);
  46. if (id != -1) { // fst::kNoSymbol
  47. // LOG(INFO) << label << " " << id;
  48. labels->push_back(id);
  49. }
  50. }
  51. return true;
  52. }
  53. std::shared_ptr<fst::SymbolTable> MakeSymbolTableForFst(
  54. std::shared_ptr<fst::SymbolTable> isymbol_table) {
  55. LOG(INFO) << isymbol_table;
  56. CHECK(isymbol_table != nullptr);
  57. auto osymbol_table = std::make_shared<fst::SymbolTable>();
  58. osymbol_table->AddSymbol("<eps>", 0);
  59. CHECK_EQ(isymbol_table->Find("<blank>"), 0);
  60. osymbol_table->AddSymbol("<blank>", 1);
  61. for (int i = 1; i < isymbol_table->NumSymbols(); i++) {
  62. std::string symbol = isymbol_table->Find(i);
  63. osymbol_table->AddSymbol(symbol, i + 1);
  64. }
  65. osymbol_table->AddSymbol(kDeletion, isymbol_table->NumSymbols() + 1);
  66. osymbol_table->AddSymbol(kIsStart, isymbol_table->NumSymbols() + 2);
  67. osymbol_table->AddSymbol(kIsEnd, isymbol_table->NumSymbols() + 3);
  68. return osymbol_table;
  69. }
  70. void CompileCtcFst(std::shared_ptr<fst::SymbolTable> symbol_table,
  71. fst::StdVectorFst* ofst) {
  72. ofst->DeleteStates();
  73. int start = ofst->AddState();
  74. ofst->SetStart(start);
  75. CHECK_EQ(symbol_table->Find("<eps>"), 0);
  76. CHECK_EQ(symbol_table->Find("<blank>"), 1);
  77. ofst->AddArc(start, fst::StdArc(1, 0, 0.0, start));
  78. // Exclude kDeletion and kInsertion
  79. for (int i = 2; i < symbol_table->NumSymbols() - 3; i++) {
  80. int s = ofst->AddState();
  81. ofst->AddArc(start, fst::StdArc(i, i, 0.0, s));
  82. ofst->AddArc(s, fst::StdArc(i, 0, 0.0, s));
  83. ofst->AddArc(s, fst::StdArc(0, 0, 0.0, start));
  84. }
  85. ofst->SetFinal(start, fst::StdArc::Weight::One());
  86. fst::ArcSort(ofst, fst::StdOLabelCompare());
  87. }
  88. void CompileAlignFst(std::vector<int> labels,
  89. std::shared_ptr<fst::SymbolTable> symbol_table,
  90. fst::StdVectorFst* ofst) {
  91. ofst->DeleteStates();
  92. int deletion = symbol_table->Find(kDeletion);
  93. int insertion_start = symbol_table->Find(kIsStart);
  94. int insertion_end = symbol_table->Find(kIsEnd);
  95. int start = ofst->AddState();
  96. ofst->SetStart(start);
  97. // Filler State
  98. int filler_start = ofst->AddState();
  99. int filler_end = ofst->AddState();
  100. for (int i = 2; i < symbol_table->NumSymbols() - 3; i++) {
  101. ofst->AddArc(filler_start, fst::StdArc(i, i, FLAGS_is_penalty, filler_end));
  102. }
  103. ofst->AddArc(filler_end, fst::StdArc(0, 0, 0.0, filler_start));
  104. int prev = start;
  105. // Alignment path and optional filler
  106. for (size_t i = 0; i < labels.size(); i++) {
  107. int cur = ofst->AddState();
  108. // 1. Insertion or Substitution
  109. ofst->AddArc(prev, fst::StdArc(0, insertion_start, 0.0, filler_start));
  110. ofst->AddArc(filler_end, fst::StdArc(0, insertion_end, 0.0, prev));
  111. // 2. Correct
  112. ofst->AddArc(prev, fst::StdArc(labels[i], labels[i], 0.0, cur));
  113. // 3. Deletion
  114. ofst->AddArc(prev, fst::StdArc(0, deletion, FLAGS_del_penalty, cur));
  115. prev = cur;
  116. }
  117. // Optional add endding filler
  118. ofst->AddArc(prev, fst::StdArc(0, insertion_start, 0.0, filler_start));
  119. ofst->AddArc(filler_end, fst::StdArc(0, insertion_end, 0.0, prev));
  120. ofst->SetFinal(prev, fst::StdArc::Weight::One());
  121. fst::ArcSort(ofst, fst::StdILabelCompare());
  122. }
  123. } // namespace wenet
  124. int main(int argc, char* argv[]) {
  125. gflags::ParseCommandLineFlags(&argc, &argv, false);
  126. google::InitGoogleLogging(argv[0]);
  127. auto decode_config = wenet::InitDecodeOptionsFromFlags();
  128. auto feature_config = wenet::InitFeaturePipelineConfigFromFlags();
  129. auto decode_resource = wenet::InitDecodeResourceFromFlags();
  130. CHECK(decode_resource->unit_table != nullptr);
  131. auto wfst_symbol_table =
  132. wenet::MakeSymbolTableForFst(decode_resource->unit_table);
  133. // wfst_symbol_table->WriteText("fst.txt");
  134. // Reset symbol_table to on-the-fly generated wfst_symbol_table
  135. decode_resource->symbol_table = wfst_symbol_table;
  136. // Compile ctc FST
  137. fst::StdVectorFst ctc_fst;
  138. wenet::CompileCtcFst(wfst_symbol_table, &ctc_fst);
  139. // ctc_fst.Write("ctc.fst");
  140. std::unordered_map<std::string, std::string> wav_table;
  141. std::ifstream wav_is(FLAGS_wav_scp);
  142. std::string line;
  143. while (std::getline(wav_is, line)) {
  144. std::vector<std::string> strs;
  145. wenet::SplitString(line, &strs);
  146. CHECK_EQ(strs.size(), 2);
  147. wav_table[strs[0]] = strs[1];
  148. }
  149. std::ifstream text_is(FLAGS_text);
  150. std::ofstream result_os(FLAGS_result, std::ios::out);
  151. std::ofstream timestamp_out;
  152. if (!FLAGS_timestamp.empty()) {
  153. timestamp_out.open(FLAGS_timestamp, std::ios::out);
  154. }
  155. std::ostream& timestamp_os =
  156. FLAGS_timestamp.empty() ? std::cout : timestamp_out;
  157. while (std::getline(text_is, line)) {
  158. std::vector<std::string> strs;
  159. wenet::SplitString(line, &strs);
  160. if (strs.size() < 2) continue;
  161. std::string key = strs[0];
  162. LOG(INFO) << "Processing " << key;
  163. if (wav_table.find(key) != wav_table.end()) {
  164. strs.erase(strs.begin());
  165. std::string text = wenet::JoinString(" ", strs);
  166. std::vector<int> labels;
  167. wenet::MapToLabel(text, wfst_symbol_table, &labels);
  168. // Prepare FST for alignment decoding
  169. fst::StdVectorFst align_fst;
  170. wenet::CompileAlignFst(labels, wfst_symbol_table, &align_fst);
  171. // align_fst.Write("align.fst");
  172. auto decoding_fst = std::make_shared<fst::StdVectorFst>();
  173. fst::Compose(ctc_fst, align_fst, decoding_fst.get());
  174. // decoding_fst->Write("decoding.fst");
  175. // Preapre feature pipeline
  176. wenet::WavReader wav_reader;
  177. if (!wav_reader.Open(wav_table[key])) {
  178. LOG(WARNING) << "Error in reading " << wav_table[key];
  179. continue;
  180. }
  181. int num_samples = wav_reader.num_samples();
  182. CHECK_EQ(wav_reader.sample_rate(), FLAGS_sample_rate);
  183. auto feature_pipeline =
  184. std::make_shared<wenet::FeaturePipeline>(*feature_config);
  185. feature_pipeline->AcceptWaveform(wav_reader.data(), num_samples);
  186. feature_pipeline->set_input_finished();
  187. decode_resource->fst = decoding_fst;
  188. LOG(INFO) << "num frames " << feature_pipeline->num_frames();
  189. wenet::AsrDecoder decoder(feature_pipeline, decode_resource,
  190. *decode_config);
  191. while (true) {
  192. wenet::DecodeState state = decoder.Decode();
  193. if (state == wenet::DecodeState::kEndFeats) {
  194. decoder.Rescoring();
  195. break;
  196. }
  197. }
  198. std::string final_result;
  199. std::string timestamp_str;
  200. if (decoder.DecodedSomething()) {
  201. const wenet::DecodeResult& result = decoder.result()[0];
  202. final_result = result.sentence;
  203. std::stringstream ss;
  204. for (const auto& w : result.word_pieces) {
  205. ss << " " << w.word << " " << w.start << " " << w.end;
  206. }
  207. timestamp_str = ss.str();
  208. }
  209. result_os << key << " " << final_result << std::endl;
  210. timestamp_os << key << " " << timestamp_str << std::endl;
  211. LOG(INFO) << key << " " << final_result;
  212. } else {
  213. LOG(WARNING) << "No wav file for " << key;
  214. }
  215. }
  216. return 0;
  217. }