diff --git a/demo/components/audio_test.py b/demo/components/audio_test.py
index 542d9514c9fd94341cf2e7150cbd078a2dfa70c5..e56e36388ffaea0839c9b6de8867cf4c0591d999 100644
--- a/demo/components/audio_test.py
+++ b/demo/components/audio_test.py
@@ -14,35 +14,13 @@
 # =======================================================================
 # coding=utf-8
 from visualdl import LogWriter
-import numpy as np
-import wave
-
-
-def read_audio_data(audio_path):
-    """
-    Get audio data.
-    """
-    CHUNK = 4096
-    f = wave.open(audio_path, "rb")
-    rate = f.getframerate()
-    width = f.getsampwidth()
-    channel = f.getnchannels()
-    wavdata = []
-    chunk = f.readframes(CHUNK)
-
-    while chunk:
-        data = np.frombuffer(chunk, dtype='uint8')
-        wavdata.extend(data)
-        chunk = f.readframes(CHUNK)
-    shape = [rate, width, channel]
-    return shape, wavdata
+from scipy.io import wavfile
 
 
 if __name__ == '__main__':
-    with LogWriter(logdir="vdl_audio_0713") as writer:
-        audio_shape, audio_data = read_audio_data("./test.wav")
-        audio_data = np.array(audio_data)
+    with LogWriter(logdir="./log/audio_test/train") as writer:
+        sample_rate, audio_data = wavfile.read('./test.wav')
         writer.add_audio(tag="audio_tag",
                          audio_array=audio_data,
                          step=0,
-                         sample_rate=audio_shape[0])
+                         sample_rate=sample_rate)
diff --git a/demo/components/test.wav b/demo/components/test.wav
index a1170d8057d76ec1073a13089903e85332fe47ea..9c862b8df8af6ee0bf3387fb6f9a84ebe64735db 100644
Binary files a/demo/components/test.wav and b/demo/components/test.wav differ
diff --git a/visualdl/component/base_component.py b/visualdl/component/base_component.py
index ffc9c6c8300218c04675940b436cd2e0cd6b471b..1265aceac0e69ded3deeb785ad61ea74b31f5fab 100644
--- a/visualdl/component/base_component.py
+++ b/visualdl/component/base_component.py
@@ -131,15 +131,26 @@ def audio(tag, audio_array, sample_rate, step, walltime):
     Return:
         Package with format of record_pb2.Record
     """
+    audio_array = audio_array.squeeze()
+    if abs(audio_array).max() > 1:
+        print('warning: audio amplitude out of range, auto clipped.')
+        audio_array = audio_array.clip(-1, 1)
+    assert (audio_array.ndim == 1), 'input tensor should be 1 dimensional.'
+
+    audio_array = [int(32767.0 * x) for x in audio_array]
+
     import io
     import wave
+    import struct
 
     fio = io.BytesIO()
     wave_writer = wave.open(fio, 'wb')
     wave_writer.setnchannels(1)
     wave_writer.setsampwidth(2)
     wave_writer.setframerate(sample_rate)
-    wave_writer.writeframes(audio_array)
+    audio_enc = b''
+    audio_enc += struct.pack("<" + "h" * len(audio_array), *audio_array)
+    wave_writer.writeframes(audio_enc)
     wave_writer.close()
     audio_string = fio.getvalue()
     fio.close()