// Copyright (c) 2017 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FRONTEND_FEATURE_PIPELINE_H_
#define FRONTEND_FEATURE_PIPELINE_H_

#include <condition_variable>
#include <limits>
#include <mutex>
#include <queue>
#include <string>
#include <vector>

#include "fbank.h"
#include "utils/blocking_queue.h"

namespace wenet {

enum class FeatureType {
  kKaldi = 0,
  kWhisper,
};
struct FeaturePipelineConfig {
  int num_bins;
  int sample_rate;
  int frame_length;
  int frame_shift;
  float low_freq;
  bool pre_emphasis;
  bool scale_input_to_unit;
  float log_floor;
  LogBase log_base;
  WindowType window_type;
  MelType mel_type;
  NormalizationType norm_type;

  FeaturePipelineConfig(int num_bins, int sample_rate,
                        FeatureType feat_type = FeatureType::kKaldi)
      : num_bins(num_bins),       // 80 dim fbank
        sample_rate(sample_rate) {  // 16k sample rate
    frame_length = sample_rate / 1000 * 25;  // frame length 25ms
    frame_shift = sample_rate / 1000 * 10;   // frame shift 10ms
    if (feat_type == FeatureType::kKaldi) {
      low_freq = 20.0;
      pre_emphasis = true;
      log_floor = std::numeric_limits<float>::epsilon();
      log_base = LogBase::kBaseE;
      window_type = WindowType::kPovey;
      mel_type = MelType::kHTK;
      norm_type = NormalizationType::kKaldi;
      scale_input_to_unit = false;
    } else if (feat_type == FeatureType::kWhisper) {
      low_freq = 0.0;
      pre_emphasis = false;
      log_floor = 1e-10;
      log_base = LogBase::kBase10;
      window_type = WindowType::kHanning;
      mel_type = MelType::kSlaney;
      scale_input_to_unit = true;
      norm_type = NormalizationType::kWhisper;
    }
  }

  void Info() const {
    LOG(INFO) << "feature pipeline config"
              << " num_bins " << num_bins << " frame_length " << frame_length
              << " frame_shift " << frame_shift << " low_freq " << low_freq
              << " preemphasis " << pre_emphasis << " log_floor " << log_floor
              << " log_base " << int(log_base) << " window_type "
              << int(window_type) << " mel_type " << int(mel_type)
              << " norm_type " << int(norm_type);
  }
};

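// Example (illustrative sketch): at 16 kHz the formulas above give
// frame_length = 400 samples (25 ms) and frame_shift = 160 samples (10 ms);
// the 80-bin values below are only example settings.
//
//   FeaturePipelineConfig kaldi_config(80, 16000);  // Kaldi-style fbank
//   FeaturePipelineConfig whisper_config(80, 16000, FeatureType::kWhisper);
//   whisper_config.Info();
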
// Typically, FeaturePipeline is used in two threads: one thread A calls
// AcceptWaveform() to add raw wav data and set_input_finished() to notify
// the end of the input wav; another thread B (the decoder thread) calls
// Read() to consume features. So a BlockingQueue is used to make this class
// thread safe.
// Read() is designed as a blocking method when there is no feature in
// feature_queue_ and the input is not finished.
// See bin/decoder_main.cc, websocket/websocket_server.cc and
// decoder/torch_asr_decoder.cc for usage.
class FeaturePipeline {
 public:
  explicit FeaturePipeline(const FeaturePipelineConfig& config);

  // The feature extraction is done in AcceptWaveform().
  void AcceptWaveform(const float* pcm, const int size);
  void AcceptWaveform(const int16_t* pcm, const int size);

  // Number of frames extracted so far.
  int num_frames() const { return num_frames_; }
  int feature_dim() const { return feature_dim_; }
  const FeaturePipelineConfig& config() const { return config_; }

  // The caller should call this method when the speech input has ended.
  // Never call AcceptWaveform() after calling set_input_finished()!
  void set_input_finished();
  bool input_finished() const { return input_finished_; }

  // Returns false if the input is finished and no feature could be read.
  // Returns true if a feature is read.
  // This is a blocking method: it blocks the calling thread when there is
  // no feature in feature_queue_ and the input is not finished.
  bool ReadOne(std::vector<float>* feat);

  // Reads #num_frames frames of features.
  // Returns false if fewer than #num_frames features are read and the
  // input is finished.
  // Returns true if #num_frames features are read.
  // This is a blocking method when there is no feature in feature_queue_
  // and the input is not finished.
  bool Read(int num_frames, std::vector<std::vector<float>>* feats);

  void Reset();
  bool IsLastFrame(int frame) const {
    return input_finished_ && (frame == num_frames_ - 1);
  }

  int NumQueuedFrames() const { return feature_queue_.Size(); }

 private:
  const FeaturePipelineConfig& config_;
  int feature_dim_;
  Fbank fbank_;

  BlockingQueue<std::vector<float>> feature_queue_;
  int num_frames_;
  bool input_finished_;

  // The feature extraction is done in AcceptWaveform().
  // Waveform samples are consumed frame by frame; the residual samples
  // left over after framing are kept and used in the next AcceptWaveform()
  // call.
  std::vector<float> remained_wav_;

  // Used to block Read() when there is no feature in feature_queue_
  // and the input is not finished.
  mutable std::mutex mutex_;
  std::condition_variable finish_condition_;
};
}  // namespace wenet

#endif  // FRONTEND_FEATURE_PIPELINE_H_
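
A minimal usage sketch of the producer/consumer pattern described in the comment above. The include path, the 80-bin/16 kHz configuration, and the 67-frame chunk size are illustrative assumptions, not fixed by this header; see bin/decoder_main.cc and websocket/websocket_server.cc for the actual call sites.

// Illustrative sketch only; include path and chunk size are assumptions.
#include <thread>
#include <vector>

#include "frontend/feature_pipeline.h"

void RunPipeline(const int16_t* pcm, int num_samples) {
  wenet::FeaturePipelineConfig config(80, 16000);  // 80-dim fbank at 16 kHz
  wenet::FeaturePipeline pipeline(config);

  // Thread A: feed raw audio, then signal that no more audio will come.
  std::thread producer([&]() {
    pipeline.AcceptWaveform(pcm, num_samples);
    pipeline.set_input_finished();
  });

  // Thread B (this thread): block on Read() until a chunk of features is
  // available or the input is finished.
  std::vector<std::vector<float>> chunk;
  bool ok = true;
  while (ok) {
    ok = pipeline.Read(67, &chunk);  // 67-frame chunks, an arbitrary choice
    // ... run the decoder on `chunk` here; it may be a short final chunk ...
    chunk.clear();
  }
  producer.join();
}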