// Copyright (c) 2017 Personal (Binbin Zhang)
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#ifndef FRONTEND_FBANK_H_
|
|
#define FRONTEND_FBANK_H_
|
|
|
|
#include <cstring>
|
|
#include <limits>
|
|
#include <random>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "frontend/fft.h"
|
|
#ifndef FST_LOG_H_
|
|
#include "fst/log.h"
|
|
#endif
|
|
|
|
namespace wenet {
|
|
using namespace fst;
|
|
// This code is based on kaldi Fbank implementation, please see
|
|
// https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-fbank.cc
|
|
|
|
static const int kS16AbsMax = 1 << 15;
|
|
|
|
enum class WindowType {
|
|
kPovey = 0,
|
|
kHanning,
|
|
};
|
|
|
|
enum class MelType {
|
|
kHTK = 0,
|
|
kSlaney,
|
|
};
|
|
|
|
enum class NormalizationType {
|
|
kKaldi = 0,
|
|
kWhisper,
|
|
};
|
|
|
|
enum class LogBase {
|
|
kBaseE = 0,
|
|
kBase10,
|
|
};
|
|
|
|
class Fbank {
|
|
public:
|
|
Fbank(int num_bins, int sample_rate, int frame_length, int frame_shift,
|
|
float low_freq = 20, bool pre_emphasis = true,
|
|
bool scale_input_to_unit = false,
|
|
float log_floor = std::numeric_limits<float>::epsilon(),
|
|
LogBase log_base = LogBase::kBaseE,
|
|
WindowType window_type = WindowType::kPovey,
|
|
MelType mel_type = MelType::kHTK,
|
|
NormalizationType norm_type = NormalizationType::kKaldi)
|
|
: num_bins_(num_bins),
|
|
sample_rate_(sample_rate),
|
|
frame_length_(frame_length),
|
|
frame_shift_(frame_shift),
|
|
use_log_(true),
|
|
remove_dc_offset_(true),
|
|
generator_(0),
|
|
distribution_(0, 1.0),
|
|
dither_(0.0),
|
|
low_freq_(low_freq),
|
|
high_freq_(sample_rate / 2),
|
|
pre_emphasis_(pre_emphasis),
|
|
scale_input_to_unit_(scale_input_to_unit),
|
|
log_floor_(log_floor),
|
|
log_base_(log_base),
|
|
norm_type_(norm_type) {
|
|
fft_points_ = UpperPowerOfTwo(frame_length_);
|
|
// generate bit reversal table and trigonometric function table
|
|
const int fft_points_4 = fft_points_ / 4;
|
|
bitrev_.resize(fft_points_);
|
|
sintbl_.resize(fft_points_ + fft_points_4);
|
|
make_sintbl(fft_points_, sintbl_.data());
|
|
make_bitrev(fft_points_, bitrev_.data());
|
|
InitMelFilters(mel_type);
|
|
InitWindow(window_type);
|
|
}
|
|
|
|
void InitMelFilters(MelType mel_type) {
|
|
int num_fft_bins = fft_points_ / 2;
|
|
float fft_bin_width = static_cast<float>(sample_rate_) / fft_points_;
|
|
float mel_low_freq = MelScale(low_freq_, mel_type);
|
|
float mel_high_freq = MelScale(high_freq_, mel_type);
|
|
float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins_ + 1);
|
|
bins_.resize(num_bins_);
|
|
center_freqs_.resize(num_bins_);
|
|
|
|
for (int bin = 0; bin < num_bins_; ++bin) {
|
|
float left_mel = mel_low_freq + bin * mel_freq_delta,
|
|
center_mel = mel_low_freq + (bin + 1) * mel_freq_delta,
|
|
right_mel = mel_low_freq + (bin + 2) * mel_freq_delta;
|
|
center_freqs_[bin] = InverseMelScale(center_mel, mel_type);
|
|
std::vector<float> this_bin(num_fft_bins);
|
|
int first_index = -1, last_index = -1;
|
|
for (int i = 0; i < num_fft_bins; ++i) {
|
|
float freq = (fft_bin_width * i); // Center frequency of this fft
|
|
// bin.
|
|
float mel = MelScale(freq, mel_type);
|
|
if (mel > left_mel && mel < right_mel) {
|
|
float weight;
|
|
if (mel_type == MelType::kHTK) {
|
|
if (mel <= center_mel)
|
|
weight = (mel - left_mel) / (center_mel - left_mel);
|
|
else if (mel > center_mel)
|
|
weight = (right_mel - mel) / (right_mel - center_mel);
|
|
} else if (mel_type == MelType::kSlaney) {
|
|
if (mel <= center_mel) {
|
|
weight = (InverseMelScale(mel, mel_type) -
|
|
InverseMelScale(left_mel, mel_type)) /
|
|
(InverseMelScale(center_mel, mel_type) -
|
|
InverseMelScale(left_mel, mel_type));
|
|
weight *= 2.0 / (InverseMelScale(right_mel, mel_type) -
|
|
InverseMelScale(left_mel, mel_type));
|
|
} else if (mel > center_mel) {
|
|
weight = (InverseMelScale(right_mel, mel_type) -
|
|
InverseMelScale(mel, mel_type)) /
|
|
(InverseMelScale(right_mel, mel_type) -
|
|
InverseMelScale(center_mel, mel_type));
|
|
weight *= 2.0 / (InverseMelScale(right_mel, mel_type) -
|
|
InverseMelScale(left_mel, mel_type));
|
|
}
|
|
}
|
|
this_bin[i] = weight;
|
|
if (first_index == -1) first_index = i;
|
|
last_index = i;
|
|
}
|
|
}
|
|
CHECK(first_index != -1 && last_index >= first_index);
|
|
bins_[bin].first = first_index;
|
|
int size = last_index + 1 - first_index;
|
|
bins_[bin].second.resize(size);
|
|
for (int i = 0; i < size; ++i) {
|
|
bins_[bin].second[i] = this_bin[first_index + i];
|
|
}
|
|
}
|
|
}
|
|
|
|
void InitWindow(WindowType window_type) {
|
|
window_.resize(frame_length_);
|
|
if (window_type == WindowType::kPovey) {
|
|
// povey window
|
|
double a = M_2PI / (frame_length_ - 1);
|
|
for (int i = 0; i < frame_length_; ++i)
|
|
window_[i] = pow(0.5 - 0.5 * cos(a * i), 0.85);
|
|
} else if (window_type == WindowType::kHanning) {
|
|
// periodic hanning window
|
|
double a = M_2PI / (frame_length_);
|
|
for (int i = 0; i < frame_length_; ++i)
|
|
window_[i] = 0.5 * (1.0 - cos(i * a));
|
|
}
|
|
}
|
|
|
|
void set_use_log(bool use_log) { use_log_ = use_log; }
|
|
|
|
void set_remove_dc_offset(bool remove_dc_offset) {
|
|
remove_dc_offset_ = remove_dc_offset;
|
|
}
|
|
|
|
void set_dither(float dither) { dither_ = dither; }
|
|
|
|
int num_bins() const { return num_bins_; }
|
|
|
|
static inline float InverseMelScale(float mel_freq,
|
|
MelType mel_type = MelType::kHTK) {
|
|
if (mel_type == MelType::kHTK) {
|
|
return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
|
|
} else if (mel_type == MelType::kSlaney) {
|
|
float f_min = 0.0;
|
|
float f_sp = 200.0f / 3.0f;
|
|
float min_log_hz = 1000.0;
|
|
float freq = f_min + f_sp * mel_freq;
|
|
float min_log_mel = (min_log_hz - f_min) / f_sp;
|
|
float logstep = logf(6.4) / 27.0f;
|
|
if (mel_freq >= min_log_mel) {
|
|
return min_log_hz * expf(logstep * (mel_freq - min_log_mel));
|
|
} else {
|
|
return freq;
|
|
}
|
|
} else {
|
|
throw std::invalid_argument("Unsupported mel type!");
|
|
}
|
|
}
|
|
|
|
static inline float MelScale(float freq, MelType mel_type = MelType::kHTK) {
|
|
if (mel_type == MelType::kHTK) {
|
|
return 1127.0f * logf(1.0f + freq / 700.0f);
|
|
} else if (mel_type == MelType::kSlaney) {
|
|
float f_min = 0.0;
|
|
float f_sp = 200.0f / 3.0f;
|
|
float min_log_hz = 1000.0;
|
|
float mel = (freq - f_min) / f_sp;
|
|
float min_log_mel = (min_log_hz - f_min) / f_sp;
|
|
float logstep = logf(6.4) / 27.0f;
|
|
if (freq >= min_log_hz) {
|
|
return min_log_mel + logf(freq / min_log_hz) / logstep;
|
|
} else {
|
|
return mel;
|
|
}
|
|
} else {
|
|
throw std::invalid_argument("Unsupported mel type!");
|
|
}
|
|
}
|
|
|
|
static int UpperPowerOfTwo(int n) {
|
|
return static_cast<int>(pow(2, ceil(log(n) / log(2))));
|
|
}
|
|
|
|
// pre emphasis
|
|
void PreEmphasis(float coeff, std::vector<float>* data) const {
|
|
if (coeff == 0.0) return;
|
|
for (int i = data->size() - 1; i > 0; i--)
|
|
(*data)[i] -= coeff * (*data)[i - 1];
|
|
(*data)[0] -= coeff * (*data)[0];
|
|
}
|
|
|
|
// Apply window on data in place
|
|
void ApplyWindow(std::vector<float>* data) const {
|
|
CHECK_GE(data->size(), window_.size());
|
|
for (size_t i = 0; i < window_.size(); ++i) {
|
|
(*data)[i] *= window_[i];
|
|
}
|
|
}
|
|
|
|
void WhisperNorm(std::vector<std::vector<float>>* feat,
|
|
float max_mel_engery) {
|
|
int num_frames = feat->size();
|
|
for (int i = 0; i < num_frames; ++i) {
|
|
for (int j = 0; j < num_bins_; ++j) {
|
|
float energy = (*feat)[i][j];
|
|
if (energy < max_mel_engery - 8) energy = max_mel_engery - 8;
|
|
energy = (energy + 4.0) / 4.0;
|
|
(*feat)[i][j] = energy;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Compute fbank feat, return num frames
|
|
int Compute(const std::vector<float>& wave,
|
|
std::vector<std::vector<float>>* feat) {
|
|
int num_samples = wave.size();
|
|
|
|
if (num_samples < frame_length_) return 0;
|
|
int num_frames = 1 + ((num_samples - frame_length_) / frame_shift_);
|
|
feat->resize(num_frames);
|
|
std::vector<float> fft_real(fft_points_, 0), fft_img(fft_points_, 0);
|
|
std::vector<float> power(fft_points_ / 2);
|
|
|
|
float max_mel_engery = std::numeric_limits<float>::min();
|
|
|
|
for (int i = 0; i < num_frames; ++i) {
|
|
std::vector<float> data(wave.data() + i * frame_shift_,
|
|
wave.data() + i * frame_shift_ + frame_length_);
|
|
|
|
if (scale_input_to_unit_) {
|
|
for (int j = 0; j < frame_length_; ++j) {
|
|
data[j] = data[j] / kS16AbsMax;
|
|
}
|
|
}
|
|
|
|
// optional add noise
|
|
if (dither_ != 0.0) {
|
|
for (size_t j = 0; j < data.size(); ++j)
|
|
data[j] += dither_ * distribution_(generator_);
|
|
}
|
|
// optinal remove dc offset
|
|
if (remove_dc_offset_) {
|
|
float mean = 0.0;
|
|
for (size_t j = 0; j < data.size(); ++j) mean += data[j];
|
|
mean /= data.size();
|
|
for (size_t j = 0; j < data.size(); ++j) data[j] -= mean;
|
|
}
|
|
|
|
if (pre_emphasis_) {
|
|
PreEmphasis(0.97, &data);
|
|
}
|
|
ApplyWindow(&data);
|
|
// copy data to fft_real
|
|
memset(fft_img.data(), 0, sizeof(float) * fft_points_);
|
|
memset(fft_real.data() + frame_length_, 0,
|
|
sizeof(float) * (fft_points_ - frame_length_));
|
|
memcpy(fft_real.data(), data.data(), sizeof(float) * frame_length_);
|
|
fft(bitrev_.data(), sintbl_.data(), fft_real.data(), fft_img.data(),
|
|
fft_points_);
|
|
// power
|
|
for (int j = 0; j < fft_points_ / 2; ++j) {
|
|
power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j];
|
|
}
|
|
|
|
(*feat)[i].resize(num_bins_);
|
|
// cepstral coefficients, triangle filter array
|
|
|
|
for (int j = 0; j < num_bins_; ++j) {
|
|
float mel_energy = 0.0;
|
|
int s = bins_[j].first;
|
|
for (size_t k = 0; k < bins_[j].second.size(); ++k) {
|
|
mel_energy += bins_[j].second[k] * power[s + k];
|
|
}
|
|
// optional use log
|
|
if (use_log_) {
|
|
if (mel_energy < log_floor_) mel_energy = log_floor_;
|
|
|
|
if (log_base_ == LogBase::kBaseE)
|
|
mel_energy = logf(mel_energy);
|
|
else if (log_base_ == LogBase::kBase10)
|
|
mel_energy = log10(mel_energy);
|
|
}
|
|
if (max_mel_engery < mel_energy) max_mel_engery = mel_energy;
|
|
(*feat)[i][j] = mel_energy;
|
|
}
|
|
}
|
|
if (norm_type_ == NormalizationType::kWhisper)
|
|
WhisperNorm(feat, max_mel_engery);
|
|
|
|
return num_frames;
|
|
}
|
|
|
|
private:
|
|
int num_bins_;
|
|
int sample_rate_;
|
|
int frame_length_, frame_shift_;
|
|
int fft_points_;
|
|
bool use_log_;
|
|
bool remove_dc_offset_;
|
|
bool pre_emphasis_;
|
|
bool scale_input_to_unit_;
|
|
float low_freq_;
|
|
float log_floor_;
|
|
float high_freq_;
|
|
LogBase log_base_;
|
|
NormalizationType norm_type_;
|
|
|
|
std::vector<float> center_freqs_;
|
|
std::vector<std::pair<int, std::vector<float>>> bins_;
|
|
std::vector<float> window_;
|
|
std::default_random_engine generator_;
|
|
std::normal_distribution<float> distribution_;
|
|
float dither_;
|
|
|
|
// bit reversal table
|
|
std::vector<int> bitrev_;
|
|
// trigonometric function table
|
|
std::vector<float> sintbl_;
|
|
};
|
|
|
|
} // namespace wenet
|
|
|
|
#endif // FRONTEND_FBANK_H_
|