You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

243 lines
8.7 KiB

  1. // Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  #include <fstream>
  #include <iostream>
  #include <memory>
  #include <sstream>
  #include <string>
  #include <unordered_map>
  #include <vector>

  #include "decoder/params.h"
  #include "frontend/wav.h"
  #include "utils/flags.h"
  #include "utils/wn_string.h"
  22. #ifdef GFLAGS_GFLAGS_H_
  23. std::cout << "gflags已经进行了include" << std::endl;
  24. #else
  25. #include "gflags/gflags.h"
  26. std::cout << "gflags没有进行include" << std::endl;
  27. #endif
  // Command-line flags specific to this tool.
  //
  // --text: transcript file read line by line in main(); the first
  // whitespace-separated field of each line is used as the utterance key and
  // the remaining fields are joined back into the transcript text.
  28. DEFINE_string(text, "", "kaldi style text input file");
  // --wav_scp: file whose lines must each split into exactly two fields; they
  // are stored as key -> wav path.
  29. DEFINE_string(wav_scp, "", "kaldi style wav scp");
  // --is_penalty: weight put on every arc of the filler sub-graph built in
  // CompileAlignFst (the insertion/substitution path).
  30. DEFINE_double(is_penalty, 1.0,
  31. "insertion/substitution penalty for align insertion");
  // --del_penalty: weight put on the arcs of CompileAlignFst that skip a
  // reference label and emit kDeletion.
  32. DEFINE_double(del_penalty, 1.0, "deletion penalty for align insertion");
  // --result: file that receives one "<key> <decoded sentence>" line per
  // processed utterance.
  33. DEFINE_string(result, "", "result output file");
  // --timestamp: file that receives the per-piece timestamps; when empty the
  // timestamps are written to std::cout instead.
  34. DEFINE_string(timestamp, "", "timestamp output file");
  35. namespace wenet {
  // Special output symbols appended to the symbol table by
  // MakeSymbolTableForFst and emitted by the alignment FST.
  //
  // The pointers themselves are const: these are fixed names, not variables,
  // and a const pointer also gets internal linkage so it cannot clash with
  // symbols from other translation units.

  // Emitted on the arc that skips a reference label (deletion).
  const char* const kDeletion = "<del>";
  // "Is" stands for insertion and substitution: kIsStart is emitted when the
  // path enters the filler sub-graph and kIsEnd when it leaves it.
  const char* const kIsStart = "<is>";
  const char* const kIsEnd = "</is>";
  40. bool MapToLabel(const std::string& text,
  41. std::shared_ptr<fst::SymbolTable> symbol_table,
  42. std::vector<int>* labels) {
  43. labels->clear();
  44. // Split label to char sequence
  45. std::vector<std::string> chars;
  46. SplitUTF8StringToChars(text, &chars);
  47. for (size_t i = 0; i < chars.size(); i++) {
  48. // ▁ is special symbol for white space
  49. std::string label = chars[i] != " " ? chars[i] : "";
  50. int id = symbol_table->Find(label);
  51. if (id != -1) { // fst::kNoSymbol
  52. // LOG(INFO) << label << " " << id;
  53. labels->push_back(id);
  54. }
  55. }
  56. return true;
  57. }
  58. std::shared_ptr<fst::SymbolTable> MakeSymbolTableForFst(
  59. std::shared_ptr<fst::SymbolTable> isymbol_table) {
  60. LOG(INFO) << isymbol_table;
  61. CHECK(isymbol_table != nullptr);
  62. auto osymbol_table = std::make_shared<fst::SymbolTable>();
  63. osymbol_table->AddSymbol("<eps>", 0);
  64. CHECK_EQ(isymbol_table->Find("<blank>"), 0);
  65. osymbol_table->AddSymbol("<blank>", 1);
  66. for (int i = 1; i < isymbol_table->NumSymbols(); i++) {
  67. std::string symbol = isymbol_table->Find(i);
  68. osymbol_table->AddSymbol(symbol, i + 1);
  69. }
  70. osymbol_table->AddSymbol(kDeletion, isymbol_table->NumSymbols() + 1);
  71. osymbol_table->AddSymbol(kIsStart, isymbol_table->NumSymbols() + 2);
  72. osymbol_table->AddSymbol(kIsEnd, isymbol_table->NumSymbols() + 3);
  73. return osymbol_table;
  74. }
  75. void CompileCtcFst(std::shared_ptr<fst::SymbolTable> symbol_table,
  76. fst::StdVectorFst* ofst) {
  77. ofst->DeleteStates();
  78. int start = ofst->AddState();
  79. ofst->SetStart(start);
  80. CHECK_EQ(symbol_table->Find("<eps>"), 0);
  81. CHECK_EQ(symbol_table->Find("<blank>"), 1);
  82. ofst->AddArc(start, fst::StdArc(1, 0, 0.0, start));
  83. // Exclude kDeletion and kInsertion
  84. for (int i = 2; i < symbol_table->NumSymbols() - 3; i++) {
  85. int s = ofst->AddState();
  86. ofst->AddArc(start, fst::StdArc(i, i, 0.0, s));
  87. ofst->AddArc(s, fst::StdArc(i, 0, 0.0, s));
  88. ofst->AddArc(s, fst::StdArc(0, 0, 0.0, start));
  89. }
  90. ofst->SetFinal(start, fst::StdArc::Weight::One());
  91. fst::ArcSort(ofst, fst::StdOLabelCompare());
  92. }
  93. void CompileAlignFst(std::vector<int> labels,
  94. std::shared_ptr<fst::SymbolTable> symbol_table,
  95. fst::StdVectorFst* ofst) {
  96. ofst->DeleteStates();
  97. int deletion = symbol_table->Find(kDeletion);
  98. int insertion_start = symbol_table->Find(kIsStart);
  99. int insertion_end = symbol_table->Find(kIsEnd);
  100. int start = ofst->AddState();
  101. ofst->SetStart(start);
  102. // Filler State
  103. int filler_start = ofst->AddState();
  104. int filler_end = ofst->AddState();
  105. for (int i = 2; i < symbol_table->NumSymbols() - 3; i++) {
  106. ofst->AddArc(filler_start, fst::StdArc(i, i, FLAGS_is_penalty, filler_end));
  107. }
  108. ofst->AddArc(filler_end, fst::StdArc(0, 0, 0.0, filler_start));
  109. int prev = start;
  110. // Alignment path and optional filler
  111. for (size_t i = 0; i < labels.size(); i++) {
  112. int cur = ofst->AddState();
  113. // 1. Insertion or Substitution
  114. ofst->AddArc(prev, fst::StdArc(0, insertion_start, 0.0, filler_start));
  115. ofst->AddArc(filler_end, fst::StdArc(0, insertion_end, 0.0, prev));
  116. // 2. Correct
  117. ofst->AddArc(prev, fst::StdArc(labels[i], labels[i], 0.0, cur));
  118. // 3. Deletion
  119. ofst->AddArc(prev, fst::StdArc(0, deletion, FLAGS_del_penalty, cur));
  120. prev = cur;
  121. }
  122. // Optional add endding filler
  123. ofst->AddArc(prev, fst::StdArc(0, insertion_start, 0.0, filler_start));
  124. ofst->AddArc(filler_end, fst::StdArc(0, insertion_end, 0.0, prev));
  125. ofst->SetFinal(prev, fst::StdArc::Weight::One());
  126. fst::ArcSort(ofst, fst::StdILabelCompare());
  127. }
  128. } // namespace wenet
  129. int main(int argc, char* argv[]) {
  130. gflags::ParseCommandLineFlags(&argc, &argv, false);
  131. google::InitGoogleLogging(argv[0]);
  132. auto decode_config = wenet::InitDecodeOptionsFromFlags();
  133. auto feature_config = wenet::InitFeaturePipelineConfigFromFlags();
  134. auto decode_resource = wenet::InitDecodeResourceFromFlags();
  135. CHECK(decode_resource->unit_table != nullptr);
  136. auto wfst_symbol_table =
  137. wenet::MakeSymbolTableForFst(decode_resource->unit_table);
  138. // wfst_symbol_table->WriteText("fst.txt");
  139. // Reset symbol_table to on-the-fly generated wfst_symbol_table
  140. decode_resource->symbol_table = wfst_symbol_table;
  141. // Compile ctc FST
  142. fst::StdVectorFst ctc_fst;
  143. wenet::CompileCtcFst(wfst_symbol_table, &ctc_fst);
  144. // ctc_fst.Write("ctc.fst");
  145. std::unordered_map<std::string, std::string> wav_table;
  146. std::ifstream wav_is(FLAGS_wav_scp);
  147. std::string line;
  148. while (std::getline(wav_is, line)) {
  149. std::vector<std::string> strs;
  150. wenet::SplitString(line, &strs);
  151. CHECK_EQ(strs.size(), 2);
  152. wav_table[strs[0]] = strs[1];
  153. }
  154. std::ifstream text_is(FLAGS_text);
  155. std::ofstream result_os(FLAGS_result, std::ios::out);
  156. std::ofstream timestamp_out;
  157. if (!FLAGS_timestamp.empty()) {
  158. timestamp_out.open(FLAGS_timestamp, std::ios::out);
  159. }
  160. std::ostream& timestamp_os =
  161. FLAGS_timestamp.empty() ? std::cout : timestamp_out;
  162. while (std::getline(text_is, line)) {
  163. std::vector<std::string> strs;
  164. wenet::SplitString(line, &strs);
  165. if (strs.size() < 2) continue;
  166. std::string key = strs[0];
  167. LOG(INFO) << "Processing " << key;
  168. if (wav_table.find(key) != wav_table.end()) {
  169. strs.erase(strs.begin());
  170. std::string text = wenet::JoinString(" ", strs);
  171. std::vector<int> labels;
  172. wenet::MapToLabel(text, wfst_symbol_table, &labels);
  173. // Prepare FST for alignment decoding
  174. fst::StdVectorFst align_fst;
  175. wenet::CompileAlignFst(labels, wfst_symbol_table, &align_fst);
  176. // align_fst.Write("align.fst");
  177. auto decoding_fst = std::make_shared<fst::StdVectorFst>();
  178. fst::Compose(ctc_fst, align_fst, decoding_fst.get());
  179. // decoding_fst->Write("decoding.fst");
  180. // Preapre feature pipeline
  181. wenet::WavReader wav_reader;
  182. if (!wav_reader.Open(wav_table[key])) {
  183. LOG(WARNING) << "Error in reading " << wav_table[key];
  184. continue;
  185. }
  186. int num_samples = wav_reader.num_samples();
  187. CHECK_EQ(wav_reader.sample_rate(), FLAGS_sample_rate);
  188. auto feature_pipeline =
  189. std::make_shared<wenet::FeaturePipeline>(*feature_config);
  190. feature_pipeline->AcceptWaveform(wav_reader.data(), num_samples);
  191. feature_pipeline->set_input_finished();
  192. decode_resource->fst = decoding_fst;
  193. LOG(INFO) << "num frames " << feature_pipeline->num_frames();
  194. wenet::AsrDecoder decoder(feature_pipeline, decode_resource,
  195. *decode_config);
  196. while (true) {
  197. wenet::DecodeState state = decoder.Decode();
  198. if (state == wenet::DecodeState::kEndFeats) {
  199. decoder.Rescoring();
  200. break;
  201. }
  202. }
  203. std::string final_result;
  204. std::string timestamp_str;
  205. if (decoder.DecodedSomething()) {
  206. const wenet::DecodeResult& result = decoder.result()[0];
  207. final_result = result.sentence;
  208. std::stringstream ss;
  209. for (const auto& w : result.word_pieces) {
  210. ss << " " << w.word << " " << w.start << " " << w.end;
  211. }
  212. timestamp_str = ss.str();
  213. }
  214. result_os << key << " " << final_result << std::endl;
  215. timestamp_os << key << " " << timestamp_str << std::endl;
  216. LOG(INFO) << key << " " << final_result;
  217. } else {
  218. LOG(WARNING) << "No wav file for " << key;
  219. }
  220. }
  221. return 0;
  222. }