You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

236 lines
8.5 KiB

  1. // Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include "decoder/params.h"
#include "frontend/wav.h"
#include "utils/flags.h"
#include "utils/wn_string.h"
  22. DEFINE_string(text, "", "kaldi style text input file");
  23. DEFINE_string(wav_scp, "", "kaldi style wav scp");
  24. DEFINE_double(is_penalty, 1.0,
  25. "insertion/substitution penalty for align insertion");
  26. DEFINE_double(del_penalty, 1.0, "deletion penalty for align insertion");
  27. DEFINE_string(result, "", "result output file");
  28. DEFINE_string(timestamp, "", "timestamp output file");
  29. namespace wenet {
  30. const char* kDeletion = "<del>";
  31. // Is: Insertion and substitution
  32. const char* kIsStart = "<is>";
  33. const char* kIsEnd = "</is>";
  34. bool MapToLabel(const std::string& text,
  35. std::shared_ptr<fst::SymbolTable> symbol_table,
  36. std::vector<int>* labels) {
  37. labels->clear();
  38. // Split label to char sequence
  39. std::vector<std::string> chars;
  40. SplitUTF8StringToChars(text, &chars);
  41. for (size_t i = 0; i < chars.size(); i++) {
  42. // ▁ is special symbol for white space
  43. std::string label = chars[i] != " " ? chars[i] : "";
  44. int id = symbol_table->Find(label);
  45. if (id != -1) { // fst::kNoSymbol
  46. // LOG(INFO) << label << " " << id;
  47. labels->push_back(id);
  48. }
  49. }
  50. return true;
  51. }
  52. std::shared_ptr<fst::SymbolTable> MakeSymbolTableForFst(
  53. std::shared_ptr<fst::SymbolTable> isymbol_table) {
  54. LOG(INFO) << isymbol_table;
  55. CHECK(isymbol_table != nullptr);
  56. auto osymbol_table = std::make_shared<fst::SymbolTable>();
  57. osymbol_table->AddSymbol("<eps>", 0);
  58. CHECK_EQ(isymbol_table->Find("<blank>"), 0);
  59. osymbol_table->AddSymbol("<blank>", 1);
  60. for (int i = 1; i < isymbol_table->NumSymbols(); i++) {
  61. std::string symbol = isymbol_table->Find(i);
  62. osymbol_table->AddSymbol(symbol, i + 1);
  63. }
  64. osymbol_table->AddSymbol(kDeletion, isymbol_table->NumSymbols() + 1);
  65. osymbol_table->AddSymbol(kIsStart, isymbol_table->NumSymbols() + 2);
  66. osymbol_table->AddSymbol(kIsEnd, isymbol_table->NumSymbols() + 3);
  67. return osymbol_table;
  68. }
  69. void CompileCtcFst(std::shared_ptr<fst::SymbolTable> symbol_table,
  70. fst::StdVectorFst* ofst) {
  71. ofst->DeleteStates();
  72. int start = ofst->AddState();
  73. ofst->SetStart(start);
  74. CHECK_EQ(symbol_table->Find("<eps>"), 0);
  75. CHECK_EQ(symbol_table->Find("<blank>"), 1);
  76. ofst->AddArc(start, fst::StdArc(1, 0, 0.0, start));
  77. // Exclude kDeletion and kInsertion
  78. for (int i = 2; i < symbol_table->NumSymbols() - 3; i++) {
  79. int s = ofst->AddState();
  80. ofst->AddArc(start, fst::StdArc(i, i, 0.0, s));
  81. ofst->AddArc(s, fst::StdArc(i, 0, 0.0, s));
  82. ofst->AddArc(s, fst::StdArc(0, 0, 0.0, start));
  83. }
  84. ofst->SetFinal(start, fst::StdArc::Weight::One());
  85. fst::ArcSort(ofst, fst::StdOLabelCompare());
  86. }
  87. void CompileAlignFst(std::vector<int> labels,
  88. std::shared_ptr<fst::SymbolTable> symbol_table,
  89. fst::StdVectorFst* ofst) {
  90. ofst->DeleteStates();
  91. int deletion = symbol_table->Find(kDeletion);
  92. int insertion_start = symbol_table->Find(kIsStart);
  93. int insertion_end = symbol_table->Find(kIsEnd);
  94. int start = ofst->AddState();
  95. ofst->SetStart(start);
  96. // Filler State
  97. int filler_start = ofst->AddState();
  98. int filler_end = ofst->AddState();
  99. for (int i = 2; i < symbol_table->NumSymbols() - 3; i++) {
  100. ofst->AddArc(filler_start, fst::StdArc(i, i, FLAGS_is_penalty, filler_end));
  101. }
  102. ofst->AddArc(filler_end, fst::StdArc(0, 0, 0.0, filler_start));
  103. int prev = start;
  104. // Alignment path and optional filler
  105. for (size_t i = 0; i < labels.size(); i++) {
  106. int cur = ofst->AddState();
  107. // 1. Insertion or Substitution
  108. ofst->AddArc(prev, fst::StdArc(0, insertion_start, 0.0, filler_start));
  109. ofst->AddArc(filler_end, fst::StdArc(0, insertion_end, 0.0, prev));
  110. // 2. Correct
  111. ofst->AddArc(prev, fst::StdArc(labels[i], labels[i], 0.0, cur));
  112. // 3. Deletion
  113. ofst->AddArc(prev, fst::StdArc(0, deletion, FLAGS_del_penalty, cur));
  114. prev = cur;
  115. }
  116. // Optional add endding filler
  117. ofst->AddArc(prev, fst::StdArc(0, insertion_start, 0.0, filler_start));
  118. ofst->AddArc(filler_end, fst::StdArc(0, insertion_end, 0.0, prev));
  119. ofst->SetFinal(prev, fst::StdArc::Weight::One());
  120. fst::ArcSort(ofst, fst::StdILabelCompare());
  121. }
  122. } // namespace wenet
  123. int main(int argc, char* argv[]) {
  124. gflags::ParseCommandLineFlags(&argc, &argv, false);
  125. google::InitGoogleLogging(argv[0]);
  126. auto decode_config = wenet::InitDecodeOptionsFromFlags();
  127. auto feature_config = wenet::InitFeaturePipelineConfigFromFlags();
  128. auto decode_resource = wenet::InitDecodeResourceFromFlags();
  129. CHECK(decode_resource->unit_table != nullptr);
  130. auto wfst_symbol_table =
  131. wenet::MakeSymbolTableForFst(decode_resource->unit_table);
  132. // wfst_symbol_table->WriteText("fst.txt");
  133. // Reset symbol_table to on-the-fly generated wfst_symbol_table
  134. decode_resource->symbol_table = wfst_symbol_table;
  135. // Compile ctc FST
  136. fst::StdVectorFst ctc_fst;
  137. wenet::CompileCtcFst(wfst_symbol_table, &ctc_fst);
  138. // ctc_fst.Write("ctc.fst");
  139. std::unordered_map<std::string, std::string> wav_table;
  140. std::ifstream wav_is(FLAGS_wav_scp);
  141. std::string line;
  142. while (std::getline(wav_is, line)) {
  143. std::vector<std::string> strs;
  144. wenet::SplitString(line, &strs);
  145. CHECK_EQ(strs.size(), 2);
  146. wav_table[strs[0]] = strs[1];
  147. }
  148. std::ifstream text_is(FLAGS_text);
  149. std::ofstream result_os(FLAGS_result, std::ios::out);
  150. std::ofstream timestamp_out;
  151. if (!FLAGS_timestamp.empty()) {
  152. timestamp_out.open(FLAGS_timestamp, std::ios::out);
  153. }
  154. std::ostream& timestamp_os =
  155. FLAGS_timestamp.empty() ? std::cout : timestamp_out;
  156. while (std::getline(text_is, line)) {
  157. std::vector<std::string> strs;
  158. wenet::SplitString(line, &strs);
  159. if (strs.size() < 2) continue;
  160. std::string key = strs[0];
  161. LOG(INFO) << "Processing " << key;
  162. if (wav_table.find(key) != wav_table.end()) {
  163. strs.erase(strs.begin());
  164. std::string text = wenet::JoinString(" ", strs);
  165. std::vector<int> labels;
  166. wenet::MapToLabel(text, wfst_symbol_table, &labels);
  167. // Prepare FST for alignment decoding
  168. fst::StdVectorFst align_fst;
  169. wenet::CompileAlignFst(labels, wfst_symbol_table, &align_fst);
  170. // align_fst.Write("align.fst");
  171. auto decoding_fst = std::make_shared<fst::StdVectorFst>();
  172. fst::Compose(ctc_fst, align_fst, decoding_fst.get());
  173. // decoding_fst->Write("decoding.fst");
  174. // Preapre feature pipeline
  175. wenet::WavReader wav_reader;
  176. if (!wav_reader.Open(wav_table[key])) {
  177. LOG(WARNING) << "Error in reading " << wav_table[key];
  178. continue;
  179. }
  180. int num_samples = wav_reader.num_samples();
  181. CHECK_EQ(wav_reader.sample_rate(), FLAGS_sample_rate);
  182. auto feature_pipeline =
  183. std::make_shared<wenet::FeaturePipeline>(*feature_config);
  184. feature_pipeline->AcceptWaveform(wav_reader.data(), num_samples);
  185. feature_pipeline->set_input_finished();
  186. decode_resource->fst = decoding_fst;
  187. LOG(INFO) << "num frames " << feature_pipeline->num_frames();
  188. wenet::AsrDecoder decoder(feature_pipeline, decode_resource,
  189. *decode_config);
  190. while (true) {
  191. wenet::DecodeState state = decoder.Decode();
  192. if (state == wenet::DecodeState::kEndFeats) {
  193. decoder.Rescoring();
  194. break;
  195. }
  196. }
  197. std::string final_result;
  198. std::string timestamp_str;
  199. if (decoder.DecodedSomething()) {
  200. const wenet::DecodeResult& result = decoder.result()[0];
  201. final_result = result.sentence;
  202. std::stringstream ss;
  203. for (const auto& w : result.word_pieces) {
  204. ss << " " << w.word << " " << w.start << " " << w.end;
  205. }
  206. timestamp_str = ss.str();
  207. }
  208. result_os << key << " " << final_result << std::endl;
  209. timestamp_os << key << " " << timestamp_str << std::endl;
  210. LOG(INFO) << key << " " << final_result;
  211. } else {
  212. LOG(WARNING) << "No wav file for " << key;
  213. }
  214. }
  215. return 0;
  216. }