You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

160 lines
5.4 KiB

  1. // Copyright (c) 2017 Personal (Binbin Zhang)
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #ifndef FRONTEND_FEATURE_PIPELINE_H_
  15. #define FRONTEND_FEATURE_PIPELINE_H_
#include <condition_variable>
#include <cstdint>
#include <limits>
#include <mutex>
#include <queue>
#include <string>
#include <vector>

#include "fbank.h"
#include "utils/blocking_queue.h"
#include "utils/log.h"
  24. namespace wenet {
  25. enum class FeatureType {
  26. kKaldi = 0,
  27. kWhisper,
  28. };
  29. struct FeaturePipelineConfig {
  30. int num_bins;
  31. int sample_rate;
  32. int frame_length;
  33. int frame_shift;
  34. float low_freq;
  35. bool pre_emphasis;
  36. bool scale_input_to_unit;
  37. float log_floor;
  38. LogBase log_base;
  39. WindowType window_type;
  40. MelType mel_type;
  41. NormalizationType norm_type;
  42. FeaturePipelineConfig(int num_bins, int sample_rate,
  43. FeatureType feat_type = FeatureType::kKaldi)
  44. : num_bins(num_bins), // 80 dim fbank
  45. sample_rate(sample_rate) { // 16k sample rate
  46. frame_length = sample_rate / 1000 * 25; // frame length 25ms
  47. frame_shift = sample_rate / 1000 * 10; // frame shift 10ms
  48. if (feat_type == FeatureType::kKaldi) {
  49. low_freq = 20.0;
  50. pre_emphasis = true;
  51. log_floor = std::numeric_limits<float>::epsilon();
  52. log_base = LogBase::kBaseE;
  53. window_type = WindowType::kPovey;
  54. mel_type = MelType::kHTK;
  55. norm_type = NormalizationType::kKaldi;
  56. scale_input_to_unit = false;
  57. } else if (feat_type == FeatureType::kWhisper) {
  58. low_freq = 0.0;
  59. pre_emphasis = false;
  60. log_floor = 1e-10;
  61. log_base = LogBase::kBase10;
  62. window_type = WindowType::kHanning;
  63. mel_type = MelType::kSlaney;
  64. scale_input_to_unit = true;
  65. norm_type = NormalizationType::kWhisper;
  66. }
  67. }
  68. void Info() const {
  69. LOG(INFO) << "feature pipeline config"
  70. << " num_bins " << num_bins << " frame_length " << frame_length
  71. << " frame_shift " << frame_shift << " low_freq " << low_freq
  72. << " preemphasis " << pre_emphasis << " log_floor " << log_floor
  73. << " log_base " << int(log_base) << " window_type "
  74. << int(window_type) << " mel_type " << int(mel_type)
  75. << " norm_type " << int(norm_type);
  76. }
  77. };
// Typically, FeaturePipeline is used by two threads: thread A calls
// AcceptWaveform() to add raw wav data and set_input_finished() to signal
// the end of the input wav, while thread B (the decoder thread) calls Read()
// to consume features. A BlockingQueue is used to make this class thread safe.
// Read() is designed as a blocking method: it blocks when there is no feature
// in feature_queue_ and the input is not finished.
// See bin/decoder_main.cc, websocket/websocket_server.cc and
// decoder/torch_asr_decoder.cc for usage.
  86. class FeaturePipeline {
  87. public:
  88. explicit FeaturePipeline(const FeaturePipelineConfig& config);
  89. // The feature extraction is done in AcceptWaveform().
  90. void AcceptWaveform(const float* pcm, const int size);
  91. void AcceptWaveform(const int16_t* pcm, const int size);
  92. // Current extracted frames number.
  93. int num_frames() const { return num_frames_; }
  94. int feature_dim() const { return feature_dim_; }
  95. const FeaturePipelineConfig& config() const { return config_; }
  96. // The caller should call this method when speech input is end.
  97. // Never call AcceptWaveform() after calling set_input_finished() !
  98. void set_input_finished();
  99. bool input_finished() const { return input_finished_; }
  100. // Return False if input is finished and no feature could be read.
  101. // Return True if a feature is read.
  102. // This function is a blocking method. It will block the thread when
  103. // there is no feature in feature_queue_ and the input is not finished.
  104. bool ReadOne(std::vector<float>* feat);
  105. // Read #num_frames frame features.
  106. // Return False if less than #num_frames features are read and the
  107. // input is finished.
  108. // Return True if #num_frames features are read.
  109. // This function is a blocking method when there is no feature
  110. // in feature_queue_ and the input is not finished.
  111. bool Read(int num_frames, std::vector<std::vector<float>>* feats);
  112. void Reset();
  113. bool IsLastFrame(int frame) const {
  114. return input_finished_ && (frame == num_frames_ - 1);
  115. }
  116. int NumQueuedFrames() const { return feature_queue_.Size(); }
  117. private:
  118. const FeaturePipelineConfig& config_;
  119. int feature_dim_;
  120. Fbank fbank_;
  121. BlockingQueue<std::vector<float>> feature_queue_;
  122. int num_frames_;
  123. bool input_finished_;
  124. // The feature extraction is done in AcceptWaveform().
  125. // This waveform sample points are consumed by frame size.
  126. // The residual waveform sample points after framing are
  127. // kept to be used in next AcceptWaveform() calling.
  128. std::vector<float> remained_wav_;
  129. // Used to block the Read when there is no feature in feature_queue_
  130. // and the input is not finished.
  131. mutable std::mutex mutex_;
  132. std::condition_variable finish_condition_;
  133. };
  134. } // namespace wenet
  135. #endif // FRONTEND_FEATURE_PIPELINE_H_