未验证 提交 d64553b6 编写于 作者: N Nicky Chan 提交者: GitHub

Complete audio data integration: (#366)

* Complete audio data integration:
* Add num channels, sample width, sample rate as params
* Generate audio file with proper params

* clang format
上级 375dd77a
...@@ -112,7 +112,9 @@ with logw.mode("train") as logger: ...@@ -112,7 +112,9 @@ with logw.mode("train") as logger:
for sample in range(10): for sample in range(10):
idx = audio.is_sample_taken() idx = audio.is_sample_taken()
if idx >= 0: if idx >= 0:
audio.set_sample(idx, 8000, wavdata) # 8k sample rate, 16bit frame, 1 channel
shape = [8000, 2, 1]
audio.set_sample(idx, shape, wavdata)
audio.finish_sampling() audio.finish_sampling()
......
...@@ -287,29 +287,34 @@ PYBIND11_MODULE(core, m) { ...@@ -287,29 +287,34 @@ PYBIND11_MODULE(core, m) {
End a sampling period, it will clear all states for reservoir sampling. End a sampling period, it will clear all states for reservoir sampling.
)pbdoc") )pbdoc")
.def("set_sample", &cp::Audio::SetSample, R"pbdoc( .def("set_sample", &cp::Audio::SetSample, R"pbdoc(
Store the flatten audio data with sample rate specified. Store the flatten audio data as vector of uint8 types. Audio params need to
be specified as a tuple of 3 integers as follows:
sample_rate: number of samples (frames) per second, e.g. 8000, 16000 or 44100
sample_width: size of each sample (frame) in bytes, e.g. 2 for a 16-bit frame
num_channels: number of channels in the audio data, normally 1 or 2
:param index: :param index:
:type index: integer :type index: integer
:param sample_rate: Sample rate of audio :param audio_params: [sample rate, sample width, number of channels]
:type sample_rate: integer :type audio_params: tuple
:param audio_data: Flatten audio data :param audio_data: Flatten audio data
:type audio_data: list :type audio_data: list
)pbdoc") )pbdoc")
.def("add_sample", &cp::Audio::AddSample, R"pbdoc( .def("add_sample", &cp::Audio::AddSample, R"pbdoc(
A combined interface for is_sample_taken and set_sample, simpler but is less efficient. A combined interface for is_sample_taken and set_sample, simpler but is less efficient.
For details on the audio params, see set_sample.
:param sample_rate: Sample rate of audio :param audio_params: [sample rate, sample width, number of channels]
:type sample_rate: integer :type audio_params: tuple
:param audio_data: Flatten audio data :param audio_data: Flatten audio data
:type audio_data: list :type audio_data: list of uint8
)pbdoc"); )pbdoc");
py::class_<cp::AudioReader::AudioRecord>(m, "AudioRecord") py::class_<cp::AudioReader::AudioRecord>(m, "AudioRecord")
// TODO(Nicky) make these copyless. // TODO(Nicky) make these copyless.
.def("data", [](cp::AudioReader::AudioRecord& self) { return self.data; }) .def("data", [](cp::AudioReader::AudioRecord& self) { return self.data; })
.def("sample_rate", .def("shape",
[](cp::AudioReader::AudioRecord& self) { return self.sample_rate; }) [](cp::AudioReader::AudioRecord& self) { return self.shape; })
.def("step_id", .def("step_id",
[](cp::AudioReader::AudioRecord& self) { return self.step_id; }); [](cp::AudioReader::AudioRecord& self) { return self.step_id; });
......
...@@ -459,25 +459,40 @@ void Audio::FinishSampling() { ...@@ -459,25 +459,40 @@ void Audio::FinishSampling() {
} }
} }
void Audio::AddSample(int sample_rate, const std::vector<value_t>& data) { void Audio::AddSample(const std::vector<shape_t>& shape,
const std::vector<value_t>& data) {
auto idx = IndexOfSampleTaken(); auto idx = IndexOfSampleTaken();
if (idx >= 0) { if (idx >= 0) {
SetSample(idx, sample_rate, data); SetSample(idx, shape, data);
} }
} }
void Audio::SetSample(int index, void Audio::SetSample(int index,
int sample_rate, const std::vector<shape_t>& shape,
const std::vector<value_t>& data) { const std::vector<value_t>& data) {
CHECK_GT(sample_rate, 0) CHECK_EQ(shape.size(), 3)
<< "sample rate should be something like 6000, 8000 or 44100"; << "shape need to be (sample rate, sample width, num channel)";
shape_t sample_rate = shape[0];
shape_t sample_width = shape[1];
shape_t num_channels = shape[2];
CHECK_GT(sample_rate, 0) << "sample rate is number of frames per second, "
"should be something like 8000, 16000 or 44100";
CHECK_GT(sample_width, 0)
<< "sample width is frame size in bytes, 16bits frame will be 2";
CHECK_GT(num_channels, 0) << "num channel will be something like 1 or 2";
CHECK_LT(index, num_samples_) CHECK_LT(index, num_samples_)
<< "index should be less than number of samples"; << "index should be less than number of samples";
CHECK_LE(index, num_records_) CHECK_LE(index, num_records_)
<< "index should be less than or equal to number of records"; << "index should be less than or equal to number of records";
BinaryRecord brcd(GenBinaryRecordDir(step_.parent()->dir()), // due to prototype limit size, we create a directory to log binary data such
std::string(data.begin(), data.end())); // as audio or image
BinaryRecord brcd(
GenBinaryRecordDir(step_.parent()->dir()),
std::string(data.begin(),
data.end())); // convert vector to binary string
brcd.tofile(); brcd.tofile();
auto entry = step_.MutableData<std::vector<byte_t>>(index); auto entry = step_.MutableData<std::vector<byte_t>>(index);
...@@ -490,6 +505,7 @@ void Audio::SetSample(int index, ...@@ -490,6 +505,7 @@ void Audio::SetSample(int index,
<< old_path << " failed"; << old_path << " failed";
} }
entry.SetRaw(brcd.filename()); entry.SetRaw(brcd.filename());
entry.SetMulti(shape);
} }
std::string AudioReader::caption() { std::string AudioReader::caption() {
...@@ -511,10 +527,13 @@ AudioReader::AudioRecord AudioReader::record(int offset, int index) { ...@@ -511,10 +527,13 @@ AudioReader::AudioRecord AudioReader::record(int offset, int index) {
<< "g_log_dir should be set in LogReader construction"; << "g_log_dir should be set in LogReader construction";
BinaryRecordReader brcd(GenBinaryRecordDir(g_log_dir), filename); BinaryRecordReader brcd(GenBinaryRecordDir(g_log_dir), filename);
// convert binary string back to vector of uint8_t, equivalent of python
// numpy.fromstring(data, dtype='uint8')
std::transform(brcd.data.begin(), std::transform(brcd.data.begin(),
brcd.data.end(), brcd.data.end(),
std::back_inserter(res.data), std::back_inserter(res.data),
[](byte_t i) { return (int8_t)(i); }); [](byte_t i) { return (uint8_t)(i); });
res.shape = entry.GetMulti<shape_t>();
res.step_id = record.id(); res.step_id = record.id();
return res; return res;
} }
......
...@@ -377,7 +377,8 @@ private: ...@@ -377,7 +377,8 @@ private:
* Image component writer. * Image component writer.
*/ */
struct Audio { struct Audio {
using value_t = float; using value_t = uint8_t;
using shape_t = int32_t;
/* /*
* step_cycle: store every `step_cycle` as a record. * step_cycle: store every `step_cycle` as a record.
...@@ -413,7 +414,8 @@ struct Audio { ...@@ -413,7 +414,8 @@ struct Audio {
* might be * might be
* low efficiency. * low efficiency.
*/ */
void AddSample(int sample_rate, const std::vector<value_t>& data); void AddSample(const std::vector<shape_t>& shape,
const std::vector<value_t>& data);
/* /*
* Will this sample be taken, this interface is introduced to reduce the cost * Will this sample be taken, this interface is introduced to reduce the cost
...@@ -425,7 +427,9 @@ struct Audio { ...@@ -425,7 +427,9 @@ struct Audio {
/* /*
* Store audio data with sample rate * Store audio data with sample rate
*/ */
void SetSample(int index, int sample_rate, const std::vector<value_t>& data); void SetSample(int index,
const std::vector<shape_t>& shape,
const std::vector<value_t>& data);
protected: protected:
bool ToSampleThisStep() { return step_id_ % step_cycle_ == 0; } bool ToSampleThisStep() { return step_id_ % step_cycle_ == 0; }
...@@ -444,11 +448,12 @@ private: ...@@ -444,11 +448,12 @@ private:
*/ */
struct AudioReader { struct AudioReader {
using value_t = typename Audio::value_t; using value_t = typename Audio::value_t;
using shape_t = typename Audio::shape_t;
struct AudioRecord { struct AudioRecord {
int step_id; int step_id;
int sample_rate; std::vector<uint8_t> data;
std::vector<int8_t> data; std::vector<shape_t> shape;
}; };
AudioReader(const std::string& mode, TabletReader tablet) AudioReader(const std::string& mode, TabletReader tablet)
...@@ -475,6 +480,12 @@ struct AudioReader { ...@@ -475,6 +480,12 @@ struct AudioReader {
*/ */
std::vector<value_t> data(int offset, int index); std::vector<value_t> data(int offset, int index);
/*
* offset: offset of a step.
* index: index of a sample.
*/
std::vector<shape_t> shape(int offset, int index);
int stepid(int offset, int index); int stepid(int offset, int index);
private: private:
......
...@@ -222,15 +222,26 @@ def get_individual_audio(storage, mode, tag, step_index, max_size=80): ...@@ -222,15 +222,26 @@ def get_individual_audio(storage, mode, tag, step_index, max_size=80):
audio = reader.audio(tag) audio = reader.audio(tag)
record = audio.record(step_index, offset) record = audio.record(step_index, offset)
data = np.array(record.data(), dtype='uint8') shape = record.shape()
sample_rate = shape[0]
sample_width = shape[1]
num_channels = shape[2]
# sending a temp file to front end
tempfile = NamedTemporaryFile(mode='w+b', suffix='.wav') tempfile = NamedTemporaryFile(mode='w+b', suffix='.wav')
# write audio file to that tempfile
wavfile = wave.open(tempfile, 'wb') wavfile = wave.open(tempfile, 'wb')
wavfile.setnchannels(2)
wavfile.setsampwidth(2) wavfile.setframerate(sample_rate)
wavfile.setnchannels(num_channels)
wavfile.setsampwidth(sample_width)
# convert to binary string to write to wav file
data = np.array(record.data(), dtype='uint8')
wavfile.writeframes(data.tostring()) wavfile.writeframes(data.tostring())
# make sure the marker is at the start of file
tempfile.seek(0, 0) tempfile.seek(0, 0)
return tempfile return tempfile
......
...@@ -269,7 +269,6 @@ def audio(): ...@@ -269,7 +269,6 @@ def audio():
@app.route('/data/plugin/audio/individualAudio') @app.route('/data/plugin/audio/individualAudio')
def individual_audio(): def individual_audio():
mode = request.args.get('run') mode = request.args.get('run')
print mode
tag = request.args.get('tag') # include a index tag = request.args.get('tag') # include a index
step_index = int(request.args.get('index')) # index of step step_index = int(request.args.get('index')) # index of step
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册