You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

186 lines
6.2 KiB

  1. // Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  #include <chrono>
  #include <cstdint>
  #include <fstream>
  #include <iomanip>
  #include <iostream>
  #include <memory>
  #include <mutex>
  #include <string>
  #include <thread>
  #include <utility>
  #include <vector>

  #include <gflags/gflags.h>

  #include "decoder/params.h"
  #include "frontend/wav.h"
  #include "utils/flags.h"
  #include "utils/wn_string.h"
  #include "utils/thread_pool.h"
  #include "utils/timer.h"
  #include "utils/wn_utils.h"
  25. DEFINE_bool(simulate_streaming, false, "simulate streaming input");
  26. DEFINE_bool(output_nbest, false, "output n-best of decode result");
  27. DEFINE_string(wav_path, "", "single wave path");
  28. DEFINE_string(wav_scp, "", "input wav scp");
  29. DEFINE_string(result, "", "result output file");
  30. DEFINE_bool(continuous_decoding, false, "continuous decoding mode");
  31. DEFINE_int32(thread_num, 1, "num of decode thread");
  32. DEFINE_int32(warmup, 0, "num of warmup decode, 0 means no warmup");
  33. std::shared_ptr<wenet::DecodeOptions> g_decode_config;
  34. std::shared_ptr<wenet::FeaturePipelineConfig> g_feature_config;
  35. std::shared_ptr<wenet::DecodeResource> g_decode_resource;
  36. std::ofstream g_result;
  37. std::mutex g_mutex;
  38. int g_total_waves_dur = 0;
  39. int g_total_decode_time = 0;
  40. void Decode(std::pair<std::string, std::string> wav, bool warmup = false) {
  41. wenet::WavReader wav_reader(wav.second);
  42. int num_samples = wav_reader.num_samples();
  43. CHECK_EQ(wav_reader.sample_rate(), FLAGS_sample_rate);
  44. auto feature_pipeline =
  45. std::make_shared<wenet::FeaturePipeline>(*g_feature_config);
  46. feature_pipeline->AcceptWaveform(wav_reader.data(), num_samples);
  47. feature_pipeline->set_input_finished();
  48. LOG(INFO) << "num frames " << feature_pipeline->num_frames();
  49. wenet::AsrDecoder decoder(feature_pipeline, g_decode_resource,
  50. *g_decode_config);
  51. int wave_dur = static_cast<int>(static_cast<float>(num_samples) /
  52. wav_reader.sample_rate() * 1000);
  53. int decode_time = 0;
  54. std::string final_result;
  55. while (true) {
  56. wenet::Timer timer;
  57. wenet::DecodeState state = decoder.Decode();
  58. if (state == wenet::DecodeState::kEndFeats) {
  59. decoder.Rescoring();
  60. }
  61. int chunk_decode_time = timer.Elapsed();
  62. decode_time += chunk_decode_time;
  63. if (decoder.DecodedSomething()) {
  64. LOG(INFO) << "Partial result: " << decoder.result()[0].sentence;
  65. }
  66. if (FLAGS_continuous_decoding && state == wenet::DecodeState::kEndpoint) {
  67. if (decoder.DecodedSomething()) {
  68. decoder.Rescoring();
  69. LOG(INFO) << "Final result (continuous decoding): "
  70. << decoder.result()[0].sentence;
  71. final_result.append(decoder.result()[0].sentence);
  72. }
  73. decoder.ResetContinuousDecoding();
  74. }
  75. if (state == wenet::DecodeState::kEndFeats) {
  76. break;
  77. } else if (FLAGS_chunk_size > 0 && FLAGS_simulate_streaming) {
  78. float frame_shift_in_ms =
  79. static_cast<float>(g_feature_config->frame_shift) /
  80. wav_reader.sample_rate() * 1000;
  81. auto wait_time =
  82. decoder.num_frames_in_current_chunk() * frame_shift_in_ms -
  83. chunk_decode_time;
  84. if (wait_time > 0) {
  85. LOG(INFO) << "Simulate streaming, waiting for " << wait_time << "ms";
  86. std::this_thread::sleep_for(
  87. std::chrono::milliseconds(static_cast<int>(wait_time)));
  88. }
  89. }
  90. }
  91. if (decoder.DecodedSomething()) {
  92. final_result.append(decoder.result()[0].sentence);
  93. }
  94. LOG(INFO) << wav.first << " Final result: " << final_result << std::endl;
  95. LOG(INFO) << "Decoded " << wave_dur << "ms audio taken " << decode_time
  96. << "ms.";
  97. if (!warmup) {
  98. g_mutex.lock();
  99. std::ostream& buffer = FLAGS_result.empty() ? std::cout : g_result;
  100. if (!FLAGS_output_nbest) {
  101. buffer << wav.first << " " << final_result << std::endl;
  102. } else {
  103. buffer << "wav " << wav.first << std::endl;
  104. auto& results = decoder.result();
  105. for (auto& r : results) {
  106. if (r.sentence.empty()) continue;
  107. buffer << "candidate " << r.score << " " << r.sentence << std::endl;
  108. }
  109. }
  110. g_total_waves_dur += wave_dur;
  111. g_total_decode_time += decode_time;
  112. g_mutex.unlock();
  113. }
  114. }
  115. int main(int argc, char* argv[]) {
  116. gflags::ParseCommandLineFlags(&argc, &argv, false);
  117. google::InitGoogleLogging(argv[0]);
  118. g_decode_config = wenet::InitDecodeOptionsFromFlags();
  119. g_feature_config = wenet::InitFeaturePipelineConfigFromFlags();
  120. g_decode_resource = wenet::InitDecodeResourceFromFlags();
  121. if (FLAGS_wav_path.empty() && FLAGS_wav_scp.empty()) {
  122. LOG(FATAL) << "Please provide the wave path or the wav scp.";
  123. }
  124. std::vector<std::pair<std::string, std::string>> waves;
  125. if (!FLAGS_wav_path.empty()) {
  126. waves.emplace_back(make_pair("test", FLAGS_wav_path));
  127. } else {
  128. std::ifstream wav_scp(FLAGS_wav_scp);
  129. std::string line;
  130. while (getline(wav_scp, line)) {
  131. std::vector<std::string> strs;
  132. wenet::SplitString(line, &strs);
  133. CHECK_GE(strs.size(), 2);
  134. waves.emplace_back(make_pair(strs[0], strs[1]));
  135. }
  136. if (waves.empty()) {
  137. LOG(FATAL) << "Please provide non-empty wav scp.";
  138. }
  139. }
  140. if (!FLAGS_result.empty()) {
  141. g_result.open(FLAGS_result, std::ios::out);
  142. }
  143. // Warmup
  144. if (FLAGS_warmup > 0) {
  145. LOG(INFO) << "Warming up...";
  146. {
  147. ThreadPool pool(FLAGS_thread_num);
  148. auto wav = waves[0];
  149. for (int i = 0; i < FLAGS_warmup; i++) {
  150. pool.enqueue(Decode, wav, true);
  151. }
  152. }
  153. LOG(INFO) << "Warmup done.";
  154. }
  155. {
  156. ThreadPool pool(FLAGS_thread_num);
  157. for (auto& wav : waves) {
  158. pool.enqueue(Decode, wav, false);
  159. }
  160. }
  161. LOG(INFO) << "Total: decoded " << g_total_waves_dur << "ms audio taken "
  162. << g_total_decode_time << "ms.";
  163. LOG(INFO) << "RTF: " << std::setprecision(4)
  164. << static_cast<float>(g_total_decode_time) / g_total_waves_dur;
  165. return 0;
  166. }