You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

361 lines
12 KiB

  1. // Copyright (c) 2017 Personal (Binbin Zhang)
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #ifndef FRONTEND_FBANK_H_
  15. #define FRONTEND_FBANK_H_
  #include <cmath>
  #include <cstring>
  #include <limits>
  #include <random>
  #include <stdexcept>
  #include <utility>
  #include <vector>

  #include "frontend/fft.h"
  #ifndef FST_LOG_H_
  #include "fst/log.h"
  #endif
  25. namespace wenet {
  26. using namespace fst;
  // The feature extraction below is modeled on Kaldi's Fbank implementation:
  // https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-fbank.cc

  // 2^15, the magnitude of the most negative int16 sample. Dividing
  // int16-range samples by it maps them into [-1, 1).
  static constexpr int kS16AbsMax = 32768;

  // Shape of the analysis window applied to every frame.
  enum class WindowType {
    kPovey = 0,    // Kaldi "povey" window: a Hann-like curve raised to 0.85.
    kHanning = 1,  // Periodic Hann window.
  };

  // Hz <-> mel mapping used to place the triangular filters.
  enum class MelType {
    kHTK = 0,     // mel = 1127 * ln(1 + f / 700).
    kSlaney = 1,  // Slaney-style: linear below 1 kHz, logarithmic above.
  };

  // Post-processing applied to the whole feature matrix after extraction.
  enum class NormalizationType {
    kKaldi = 0,    // No extra normalization.
    kWhisper = 1,  // Floor every value at (max - 8), then map x -> (x + 4) / 4.
  };

  // Base of the logarithm taken of the mel energies.
  enum class LogBase {
    kBaseE = 0,   // Natural logarithm.
    kBase10 = 1,  // Base-10 logarithm.
  };
  46. class Fbank {
  47. public:
  48. Fbank(int num_bins, int sample_rate, int frame_length, int frame_shift,
  49. float low_freq = 20, bool pre_emphasis = true,
  50. bool scale_input_to_unit = false,
  51. float log_floor = std::numeric_limits<float>::epsilon(),
  52. LogBase log_base = LogBase::kBaseE,
  53. WindowType window_type = WindowType::kPovey,
  54. MelType mel_type = MelType::kHTK,
  55. NormalizationType norm_type = NormalizationType::kKaldi)
  56. : num_bins_(num_bins),
  57. sample_rate_(sample_rate),
  58. frame_length_(frame_length),
  59. frame_shift_(frame_shift),
  60. use_log_(true),
  61. remove_dc_offset_(true),
  62. generator_(0),
  63. distribution_(0, 1.0),
  64. dither_(0.0),
  65. low_freq_(low_freq),
  66. high_freq_(sample_rate / 2),
  67. pre_emphasis_(pre_emphasis),
  68. scale_input_to_unit_(scale_input_to_unit),
  69. log_floor_(log_floor),
  70. log_base_(log_base),
  71. norm_type_(norm_type) {
  72. fft_points_ = UpperPowerOfTwo(frame_length_);
  73. // generate bit reversal table and trigonometric function table
  74. const int fft_points_4 = fft_points_ / 4;
  75. bitrev_.resize(fft_points_);
  76. sintbl_.resize(fft_points_ + fft_points_4);
  77. make_sintbl(fft_points_, sintbl_.data());
  78. make_bitrev(fft_points_, bitrev_.data());
  79. InitMelFilters(mel_type);
  80. InitWindow(window_type);
  81. }
  82. void InitMelFilters(MelType mel_type) {
  83. int num_fft_bins = fft_points_ / 2;
  84. float fft_bin_width = static_cast<float>(sample_rate_) / fft_points_;
  85. float mel_low_freq = MelScale(low_freq_, mel_type);
  86. float mel_high_freq = MelScale(high_freq_, mel_type);
  87. float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins_ + 1);
  88. bins_.resize(num_bins_);
  89. center_freqs_.resize(num_bins_);
  90. for (int bin = 0; bin < num_bins_; ++bin) {
  91. float left_mel = mel_low_freq + bin * mel_freq_delta,
  92. center_mel = mel_low_freq + (bin + 1) * mel_freq_delta,
  93. right_mel = mel_low_freq + (bin + 2) * mel_freq_delta;
  94. center_freqs_[bin] = InverseMelScale(center_mel, mel_type);
  95. std::vector<float> this_bin(num_fft_bins);
  96. int first_index = -1, last_index = -1;
  97. for (int i = 0; i < num_fft_bins; ++i) {
  98. float freq = (fft_bin_width * i); // Center frequency of this fft
  99. // bin.
  100. float mel = MelScale(freq, mel_type);
  101. if (mel > left_mel && mel < right_mel) {
  102. float weight;
  103. if (mel_type == MelType::kHTK) {
  104. if (mel <= center_mel)
  105. weight = (mel - left_mel) / (center_mel - left_mel);
  106. else if (mel > center_mel)
  107. weight = (right_mel - mel) / (right_mel - center_mel);
  108. } else if (mel_type == MelType::kSlaney) {
  109. if (mel <= center_mel) {
  110. weight = (InverseMelScale(mel, mel_type) -
  111. InverseMelScale(left_mel, mel_type)) /
  112. (InverseMelScale(center_mel, mel_type) -
  113. InverseMelScale(left_mel, mel_type));
  114. weight *= 2.0 / (InverseMelScale(right_mel, mel_type) -
  115. InverseMelScale(left_mel, mel_type));
  116. } else if (mel > center_mel) {
  117. weight = (InverseMelScale(right_mel, mel_type) -
  118. InverseMelScale(mel, mel_type)) /
  119. (InverseMelScale(right_mel, mel_type) -
  120. InverseMelScale(center_mel, mel_type));
  121. weight *= 2.0 / (InverseMelScale(right_mel, mel_type) -
  122. InverseMelScale(left_mel, mel_type));
  123. }
  124. }
  125. this_bin[i] = weight;
  126. if (first_index == -1) first_index = i;
  127. last_index = i;
  128. }
  129. }
  130. CHECK(first_index != -1 && last_index >= first_index);
  131. bins_[bin].first = first_index;
  132. int size = last_index + 1 - first_index;
  133. bins_[bin].second.resize(size);
  134. for (int i = 0; i < size; ++i) {
  135. bins_[bin].second[i] = this_bin[first_index + i];
  136. }
  137. }
  138. }
  139. void InitWindow(WindowType window_type) {
  140. window_.resize(frame_length_);
  141. if (window_type == WindowType::kPovey) {
  142. // povey window
  143. double a = M_2PI / (frame_length_ - 1);
  144. for (int i = 0; i < frame_length_; ++i)
  145. window_[i] = pow(0.5 - 0.5 * cos(a * i), 0.85);
  146. } else if (window_type == WindowType::kHanning) {
  147. // periodic hanning window
  148. double a = M_2PI / (frame_length_);
  149. for (int i = 0; i < frame_length_; ++i)
  150. window_[i] = 0.5 * (1.0 - cos(i * a));
  151. }
  152. }
  153. void set_use_log(bool use_log) { use_log_ = use_log; }
  154. void set_remove_dc_offset(bool remove_dc_offset) {
  155. remove_dc_offset_ = remove_dc_offset;
  156. }
  157. void set_dither(float dither) { dither_ = dither; }
  158. int num_bins() const { return num_bins_; }
  159. static inline float InverseMelScale(float mel_freq,
  160. MelType mel_type = MelType::kHTK) {
  161. if (mel_type == MelType::kHTK) {
  162. return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
  163. } else if (mel_type == MelType::kSlaney) {
  164. float f_min = 0.0;
  165. float f_sp = 200.0f / 3.0f;
  166. float min_log_hz = 1000.0;
  167. float freq = f_min + f_sp * mel_freq;
  168. float min_log_mel = (min_log_hz - f_min) / f_sp;
  169. float logstep = logf(6.4) / 27.0f;
  170. if (mel_freq >= min_log_mel) {
  171. return min_log_hz * expf(logstep * (mel_freq - min_log_mel));
  172. } else {
  173. return freq;
  174. }
  175. } else {
  176. throw std::invalid_argument("Unsupported mel type!");
  177. }
  178. }
  179. static inline float MelScale(float freq, MelType mel_type = MelType::kHTK) {
  180. if (mel_type == MelType::kHTK) {
  181. return 1127.0f * logf(1.0f + freq / 700.0f);
  182. } else if (mel_type == MelType::kSlaney) {
  183. float f_min = 0.0;
  184. float f_sp = 200.0f / 3.0f;
  185. float min_log_hz = 1000.0;
  186. float mel = (freq - f_min) / f_sp;
  187. float min_log_mel = (min_log_hz - f_min) / f_sp;
  188. float logstep = logf(6.4) / 27.0f;
  189. if (freq >= min_log_hz) {
  190. return min_log_mel + logf(freq / min_log_hz) / logstep;
  191. } else {
  192. return mel;
  193. }
  194. } else {
  195. throw std::invalid_argument("Unsupported mel type!");
  196. }
  197. }
  198. static int UpperPowerOfTwo(int n) {
  199. return static_cast<int>(pow(2, ceil(log(n) / log(2))));
  200. }
  201. // pre emphasis
  202. void PreEmphasis(float coeff, std::vector<float>* data) const {
  203. if (coeff == 0.0) return;
  204. for (int i = data->size() - 1; i > 0; i--)
  205. (*data)[i] -= coeff * (*data)[i - 1];
  206. (*data)[0] -= coeff * (*data)[0];
  207. }
  208. // Apply window on data in place
  209. void ApplyWindow(std::vector<float>* data) const {
  210. CHECK_GE(data->size(), window_.size());
  211. for (size_t i = 0; i < window_.size(); ++i) {
  212. (*data)[i] *= window_[i];
  213. }
  214. }
  215. void WhisperNorm(std::vector<std::vector<float>>* feat,
  216. float max_mel_engery) {
  217. int num_frames = feat->size();
  218. for (int i = 0; i < num_frames; ++i) {
  219. for (int j = 0; j < num_bins_; ++j) {
  220. float energy = (*feat)[i][j];
  221. if (energy < max_mel_engery - 8) energy = max_mel_engery - 8;
  222. energy = (energy + 4.0) / 4.0;
  223. (*feat)[i][j] = energy;
  224. }
  225. }
  226. }
  227. // Compute fbank feat, return num frames
  228. int Compute(const std::vector<float>& wave,
  229. std::vector<std::vector<float>>* feat) {
  230. int num_samples = wave.size();
  231. if (num_samples < frame_length_) return 0;
  232. int num_frames = 1 + ((num_samples - frame_length_) / frame_shift_);
  233. feat->resize(num_frames);
  234. std::vector<float> fft_real(fft_points_, 0), fft_img(fft_points_, 0);
  235. std::vector<float> power(fft_points_ / 2);
  236. float max_mel_engery = std::numeric_limits<float>::min();
  237. for (int i = 0; i < num_frames; ++i) {
  238. std::vector<float> data(wave.data() + i * frame_shift_,
  239. wave.data() + i * frame_shift_ + frame_length_);
  240. if (scale_input_to_unit_) {
  241. for (int j = 0; j < frame_length_; ++j) {
  242. data[j] = data[j] / kS16AbsMax;
  243. }
  244. }
  245. // optional add noise
  246. if (dither_ != 0.0) {
  247. for (size_t j = 0; j < data.size(); ++j)
  248. data[j] += dither_ * distribution_(generator_);
  249. }
  250. // optinal remove dc offset
  251. if (remove_dc_offset_) {
  252. float mean = 0.0;
  253. for (size_t j = 0; j < data.size(); ++j) mean += data[j];
  254. mean /= data.size();
  255. for (size_t j = 0; j < data.size(); ++j) data[j] -= mean;
  256. }
  257. if (pre_emphasis_) {
  258. PreEmphasis(0.97, &data);
  259. }
  260. ApplyWindow(&data);
  261. // copy data to fft_real
  262. memset(fft_img.data(), 0, sizeof(float) * fft_points_);
  263. memset(fft_real.data() + frame_length_, 0,
  264. sizeof(float) * (fft_points_ - frame_length_));
  265. memcpy(fft_real.data(), data.data(), sizeof(float) * frame_length_);
  266. fft(bitrev_.data(), sintbl_.data(), fft_real.data(), fft_img.data(),
  267. fft_points_);
  268. // power
  269. for (int j = 0; j < fft_points_ / 2; ++j) {
  270. power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j];
  271. }
  272. (*feat)[i].resize(num_bins_);
  273. // cepstral coefficients, triangle filter array
  274. for (int j = 0; j < num_bins_; ++j) {
  275. float mel_energy = 0.0;
  276. int s = bins_[j].first;
  277. for (size_t k = 0; k < bins_[j].second.size(); ++k) {
  278. mel_energy += bins_[j].second[k] * power[s + k];
  279. }
  280. // optional use log
  281. if (use_log_) {
  282. if (mel_energy < log_floor_) mel_energy = log_floor_;
  283. if (log_base_ == LogBase::kBaseE)
  284. mel_energy = logf(mel_energy);
  285. else if (log_base_ == LogBase::kBase10)
  286. mel_energy = log10(mel_energy);
  287. }
  288. if (max_mel_engery < mel_energy) max_mel_engery = mel_energy;
  289. (*feat)[i][j] = mel_energy;
  290. }
  291. }
  292. if (norm_type_ == NormalizationType::kWhisper)
  293. WhisperNorm(feat, max_mel_engery);
  294. return num_frames;
  295. }
  296. private:
  297. int num_bins_;
  298. int sample_rate_;
  299. int frame_length_, frame_shift_;
  300. int fft_points_;
  301. bool use_log_;
  302. bool remove_dc_offset_;
  303. bool pre_emphasis_;
  304. bool scale_input_to_unit_;
  305. float low_freq_;
  306. float log_floor_;
  307. float high_freq_;
  308. LogBase log_base_;
  309. NormalizationType norm_type_;
  310. std::vector<float> center_freqs_;
  311. std::vector<std::pair<int, std::vector<float>>> bins_;
  312. std::vector<float> window_;
  313. std::default_random_engine generator_;
  314. std::normal_distribution<float> distribution_;
  315. float dither_;
  316. // bit reversal table
  317. std::vector<int> bitrev_;
  318. // trigonometric function table
  319. std::vector<float> sintbl_;
  320. };
  321. } // namespace wenet
  322. #endif // FRONTEND_FBANK_H_