You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

243 lines
7.0 KiB

  1. // Copyright (c) 2016 Personal (Binbin Zhang)
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #ifndef FRONTEND_WAV_H_
  15. #define FRONTEND_WAV_H_
  16. #include <assert.h>
  17. #include <stdint.h>
  18. #include <stdio.h>
  19. #include <stdlib.h>
  20. #include <string.h>
  21. #include "glog/logging.h"
  22. #include <string>
  23. namespace wenet {
  // In-memory image of the canonical 44-byte header of a PCM WAV file: the
  // RIFF descriptor, a 16-byte "fmt " chunk and the header of the "data" chunk.
  //
  // Objects of this type are read from and written to files as raw bytes, so
  // every multi-byte field uses a fixed-width type and the member order must
  // not change.
  // NOTE(review): fields are stored in host byte order while WAV files are
  // little-endian, so this presumably only targets little-endian hosts.
  struct WavHeader {
    char riff[4] = {'R', 'I', 'F', 'F'};
    uint32_t size = 0;  // bytes that follow this field (header - 8 + payload)
    char wav[4] = {'W', 'A', 'V', 'E'};
    char fmt[4] = {'f', 'm', 't', ' '};
    uint32_t fmt_size = 16;  // size of the fmt chunk body
    uint16_t format = 1;     // format tag; 1 is PCM in the WAV specification
    uint16_t channels = 0;
    uint32_t sample_rate = 0;
    uint32_t bytes_per_second = 0;
    uint16_t block_size = 0;  // bytes per sample frame (all channels)
    uint16_t bit = 0;         // bits per sample
    char data[4] = {'d', 'a', 't', 'a'};
    uint32_t data_size = 0;  // payload size in bytes

    // Builds a header whose size fields are all zero.
    WavHeader() {}

    // Builds the header describing `num_samples` sample frames of
    // `num_channel` channels each, at `sample_rate` Hz and `bits_per_sample`
    // bits per sample.
    WavHeader(int num_samples, int num_channel, int sample_rate,
              int bits_per_sample) {
      // Do the size arithmetic in unsigned 32-bit: payloads between 2 GiB and
      // 4 GiB are representable in the header but overflow a signed int.
      const uint32_t bytes_per_frame = static_cast<uint32_t>(num_channel) *
                                       static_cast<uint32_t>(bits_per_sample / 8);
      data_size = static_cast<uint32_t>(num_samples) * bytes_per_frame;
      size = static_cast<uint32_t>(sizeof(WavHeader)) - 8 + data_size;
      channels = static_cast<uint16_t>(num_channel);
      this->sample_rate = static_cast<uint32_t>(sample_rate);
      bytes_per_second = static_cast<uint32_t>(sample_rate) * bytes_per_frame;
      block_size = static_cast<uint16_t>(bytes_per_frame);
      bit = static_cast<uint16_t>(bits_per_sample);
    }
  };
  // The header is copied to and from disk byte for byte, so any padding or
  // field-width change would silently corrupt files.
  static_assert(sizeof(WavHeader) == 44,
                "WavHeader must match the 44-byte canonical WAV header");
  50. class WavReader {
  51. public:
  52. WavReader() : data_(nullptr) {}
  53. explicit WavReader(const std::string& filename) { Open(filename); }
  54. bool Open(const std::string& filename) {
  55. FILE* fp = fopen(filename.c_str(), "rb");
  56. if (NULL == fp) {
  57. LOG(WARNING) << "Error in read " << filename;
  58. return false;
  59. }
  60. WavHeader header;
  61. fread(&header, 1, sizeof(header), fp);
  62. if ((0 != strncmp(header.riff, "RIFF", 4)) ||
  63. (0 != strncmp(header.wav, "WAVE", 4)) ||
  64. (0 != strncmp(header.fmt, "fmt", 3))) {
  65. fprintf(stderr, "WaveData: expect audio format data.\n");
  66. return false;
  67. }
  68. if (header.fmt_size < 16) {
  69. fprintf(stderr,
  70. "WaveData: expect PCM format data "
  71. "to have fmt chunk of at least size 16.\n");
  72. fclose(fp);
  73. return false;
  74. } else if (header.fmt_size > 16) {
  75. int offset = 44 - 8 + header.fmt_size - 16;
  76. fseek(fp, offset, SEEK_SET);
  77. fread(header.data, 8, sizeof(char), fp);
  78. }
  79. // check "RIFF" "WAVE" "fmt " "data"
  80. // Skip any sub-chunks between "fmt" and "data". Usually there will
  81. // be a single "fact" sub chunk, but on Windows there can also be a
  82. // "list" sub chunk.
  83. while (0 != strncmp(header.data, "data", 4)) {
  84. // We will just ignore the data in these chunks.
  85. fseek(fp, header.data_size, SEEK_CUR);
  86. // read next sub chunk
  87. fread(header.data, 8, sizeof(char), fp);
  88. }
  89. num_channel_ = header.channels;
  90. sample_rate_ = header.sample_rate;
  91. bits_per_sample_ = header.bit;
  92. int num_data = header.data_size / (bits_per_sample_ / 8);
  93. data_ = new float[num_data];
  94. num_samples_ = num_data / num_channel_;
  95. for (int i = 0; i < num_data; ++i) {
  96. switch (bits_per_sample_) {
  97. case 8: {
  98. char sample;
  99. fread(&sample, 1, sizeof(char), fp);
  100. data_[i] = static_cast<float>(sample);
  101. break;
  102. }
  103. case 16: {
  104. int16_t sample;
  105. fread(&sample, 1, sizeof(int16_t), fp);
  106. data_[i] = static_cast<float>(sample);
  107. break;
  108. }
  109. case 32: {
  110. int sample;
  111. fread(&sample, 1, sizeof(int), fp);
  112. data_[i] = static_cast<float>(sample);
  113. break;
  114. }
  115. default:
  116. fprintf(stderr, "unsupported quantization bits");
  117. exit(1);
  118. }
  119. }
  120. fclose(fp);
  121. return true;
  122. }
  123. int num_channel() const { return num_channel_; }
  124. int sample_rate() const { return sample_rate_; }
  125. int bits_per_sample() const { return bits_per_sample_; }
  126. int num_samples() const { return num_samples_; }
  127. ~WavReader() { delete[] data_; }
  128. const float* data() const { return data_; }
  129. private:
  130. int num_channel_;
  131. int sample_rate_;
  132. int bits_per_sample_;
  133. int num_samples_; // sample points per channel
  134. float* data_;
  135. };
  136. class WavWriter {
  137. public:
  138. WavWriter(const float* data, int num_samples, int num_channel,
  139. int sample_rate, int bits_per_sample)
  140. : data_(data),
  141. num_samples_(num_samples),
  142. num_channel_(num_channel),
  143. sample_rate_(sample_rate),
  144. bits_per_sample_(bits_per_sample) {}
  145. void Write(const std::string& filename) {
  146. FILE* fp = fopen(filename.c_str(), "wb");
  147. WavHeader header(num_samples_, num_channel_, sample_rate_,
  148. bits_per_sample_);
  149. fwrite(&header, 1, sizeof(header), fp);
  150. for (int i = 0; i < num_samples_; ++i) {
  151. for (int j = 0; j < num_channel_; ++j) {
  152. switch (bits_per_sample_) {
  153. case 8: {
  154. char sample = static_cast<char>(data_[i * num_channel_ + j]);
  155. fwrite(&sample, 1, sizeof(sample), fp);
  156. break;
  157. }
  158. case 16: {
  159. int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
  160. fwrite(&sample, 1, sizeof(sample), fp);
  161. break;
  162. }
  163. case 32: {
  164. int sample = static_cast<int>(data_[i * num_channel_ + j]);
  165. fwrite(&sample, 1, sizeof(sample), fp);
  166. break;
  167. }
  168. }
  169. }
  170. }
  171. fclose(fp);
  172. }
  173. private:
  174. const float* data_;
  175. int num_samples_; // total float points in data_
  176. int num_channel_;
  177. int sample_rate_;
  178. int bits_per_sample_;
  179. };
  180. class StreamWavWriter {
  181. public:
  182. StreamWavWriter(int num_channel, int sample_rate, int bits_per_sample)
  183. : num_channel_(num_channel),
  184. sample_rate_(sample_rate),
  185. bits_per_sample_(bits_per_sample),
  186. total_num_samples_(0) {}
  187. StreamWavWriter(const std::string& filename, int num_channel, int sample_rate,
  188. int bits_per_sample)
  189. : StreamWavWriter(num_channel, sample_rate, bits_per_sample) {
  190. Open(filename);
  191. }
  192. void Open(const std::string& filename) {
  193. fp_ = fopen(filename.c_str(), "wb");
  194. fseek(fp_, sizeof(WavHeader), SEEK_SET);
  195. }
  196. void Write(const int16_t* sample_data, size_t num_samples) {
  197. fwrite(sample_data, sizeof(int16_t), num_samples, fp_);
  198. total_num_samples_ += num_samples;
  199. }
  200. void Close() {
  201. WavHeader header(total_num_samples_, num_channel_, sample_rate_,
  202. bits_per_sample_);
  203. fseek(fp_, 0L, SEEK_SET);
  204. fwrite(&header, 1, sizeof(header), fp_);
  205. fclose(fp_);
  206. }
  207. private:
  208. FILE* fp_;
  209. int num_channel_;
  210. int sample_rate_;
  211. int bits_per_sample_;
  212. size_t total_num_samples_;
  213. };
  214. } // namespace wenet
  215. #endif // FRONTEND_WAV_H_