未验证 提交 a45c354e 编写于 作者: 走神的阿圆's avatar 走神的阿圆 提交者: GitHub

Modify add_audio (#751)

上级 7db24c50
...@@ -14,35 +14,13 @@ ...@@ -14,35 +14,13 @@
# ======================================================================= # =======================================================================
# coding=utf-8 # coding=utf-8
from visualdl import LogWriter from visualdl import LogWriter
import numpy as np from scipy.io import wavfile
import wave
def read_audio_data(audio_path):
"""
Get audio data.
"""
CHUNK = 4096
f = wave.open(audio_path, "rb")
rate = f.getframerate()
width = f.getsampwidth()
channel = f.getnchannels()
wavdata = []
chunk = f.readframes(CHUNK)
while chunk:
data = np.frombuffer(chunk, dtype='uint8')
wavdata.extend(data)
chunk = f.readframes(CHUNK)
shape = [rate, width, channel]
return shape, wavdata
if __name__ == '__main__': if __name__ == '__main__':
with LogWriter(logdir="vdl_audio_0713") as writer: with LogWriter(logdir="./log/audio_test/train") as writer:
audio_shape, audio_data = read_audio_data("./test.wav") sample_rate, audio_data = wavfile.read('./test.wav')
audio_data = np.array(audio_data)
writer.add_audio(tag="audio_tag", writer.add_audio(tag="audio_tag",
audio_array=audio_data, audio_array=audio_data,
step=0, step=0,
sample_rate=audio_shape[0]) sample_rate=sample_rate)
...@@ -131,15 +131,26 @@ def audio(tag, audio_array, sample_rate, step, walltime): ...@@ -131,15 +131,26 @@ def audio(tag, audio_array, sample_rate, step, walltime):
Return: Return:
Package with format of record_pb2.Record Package with format of record_pb2.Record
""" """
audio_array = audio_array.squeeze()
if abs(audio_array).max() > 1:
print('warning: audio amplitude out of range, auto clipped.')
audio_array = audio_array.clip(-1, 1)
assert (audio_array.ndim == 1), 'input tensor should be 1 dimensional.'
audio_array = [int(32767.0 * x) for x in audio_array]
import io import io
import wave import wave
import struct
fio = io.BytesIO() fio = io.BytesIO()
wave_writer = wave.open(fio, 'wb') wave_writer = wave.open(fio, 'wb')
wave_writer.setnchannels(1) wave_writer.setnchannels(1)
wave_writer.setsampwidth(2) wave_writer.setsampwidth(2)
wave_writer.setframerate(sample_rate) wave_writer.setframerate(sample_rate)
wave_writer.writeframes(audio_array) audio_enc = b''
audio_enc += struct.pack("<" + "h" * len(audio_array), *audio_array)
wave_writer.writeframes(audio_enc)
wave_writer.close() wave_writer.close()
audio_string = fio.getvalue() audio_string = fio.getvalue()
fio.close() fio.close()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册