提交 220fe203 编写于 作者: H Hui Zhang

test with dither, remove dc offset, preemphasis

上级 42f93b2c
......@@ -28,8 +28,8 @@ def read(wavpath:str, sr:int = None, start=0, stop=None, dtype='int16', always_2
def write(wavpath: str, wav: np.ndarray, sr: int, dtype='PCM_16'):
    """Write audio data to ``wavpath``.

    Thin wrapper that forwards its arguments to ``sf.write``.

    Args:
        wavpath (str): destination, passed as the first positional argument.
        wav (np.ndarray): data to write, passed as the second positional
            argument.
        sr (int): passed as the third positional argument.
        dtype (str, optional): forwarded as the ``subtype`` keyword.
            Defaults to 'PCM_16'.
    """
    sf.write(
        wavpath,
        wav,
        sr,
        subtype=dtype,
    )
def frames(x: Tensor,
num_samples: Tensor,
sr: int,
......@@ -51,7 +51,7 @@ def frames(x: Tensor,
stride_length : float
Stride length in ms.
clip : bool, optional
Whether to clip audio that does not fit into the last frame, by
Whether to clip audio that does not fit into the last frame, by
default True
Returns
......@@ -64,7 +64,7 @@ def frames(x: Tensor,
assert stride_length <= win_length
stride_length = int(stride_length * sr)
win_length = int(win_length * sr)
num_frames = (num_samples - win_length) // stride_length
padding = (0, 0)
if not clip:
......@@ -92,10 +92,11 @@ def dither(signal:Tensor, dither_value=1.0)->Tensor:
Returns:
Tensor: [B, T, D]
"""
signal += paddle.normal(shape=[1, 1, signal.shape[-1]]) * dither_value
D = paddle.shape(signal)[-1]
signal += paddle.normal(shape=[1, 1, D]) * dither_value
return signal
def remove_dc_offset(signal:Tensor)->Tensor:
"""remove dc.
......@@ -105,7 +106,7 @@ def remove_dc_offset(signal:Tensor)->Tensor:
Returns:
Tensor: [B, T, D]
"""
signal -= paddle.mean(signal, axis=-1)
signal -= paddle.mean(signal, axis=-1, keepdim=True)
return signal
def preemphasis(signal:Tensor, coeff=0.97)->Tensor:
......@@ -125,21 +126,21 @@ def preemphasis(signal:Tensor, coeff=0.97)->Tensor:
class STFT(nn.Layer):
"""A module for computing stft transformation in a differentiable way.
"""A module for computing stft transformation in a differentiable way.
http://practicalcryptography.com/miscellaneous/machine-learning/intuitive-guide-discrete-fourier-transform/
Parameters
------------
------------
n_fft : int
Number of samples in a frame.
sr: int
Sampling rate.
stride_length : float
Number of samples shifted between adjacent frames.
win_length : float
Length of the window.
......@@ -151,7 +152,7 @@ class STFT(nn.Layer):
sr: int,
win_length: float,
stride_length: float,
dither:float=1.0,
dither:float=0.0,
preemph_coeff:float=0.97,
remove_dc_offset:bool=True,
window_type: str = 'povey',
......@@ -165,17 +166,17 @@ class STFT(nn.Layer):
self.remove_dc_offset = remove_dc_offset
self.window_type = window_type
self.clip = clip
self.n_fft = n_fft
self.n_bin = 1 + n_fft // 2
w_real, w_imag, kernel_size = dft_matrix(
self.n_fft, int(self.win_length * self.sr), self.n_bin
)
# calculate window
window = get_window(window_type, kernel_size)
# (2 * n_bins, kernel_size)
w = np.concatenate([w_real, w_imag], axis=0)
w = w * window
......@@ -203,7 +204,7 @@ class STFT(nn.Layer):
batch_size = paddle.shape(num_samples)
F, nframe = frames(x, num_samples, self.sr, self.win_length, self.stride_length, clip=self.clip)
if self.dither:
F = dither(F, dither)
F = dither(F, self.dither)
if self.remove_dc_offset:
F = remove_dc_offset(F)
if self.preemph_coeff:
......@@ -215,7 +216,7 @@ class STFT(nn.Layer):
def powspec(C:Tensor) -> Tensor:
"""Compute the power spectrum.
"""Compute the power spectrum.
Args:
C (Tensor): [B, T, C, 2]
......@@ -225,10 +226,10 @@ def powspec(C:Tensor) -> Tensor:
"""
real, imag = paddle.chunk(C, 2, axis=-1)
return paddle.square(real.squeeze(-1)) + paddle.square(imag.squeeze(-1))
def magspec(C: Tensor, eps=1e-10) -> Tensor:
"""Compute the magnitude spectrum.
"""Compute the magnitude spectrum.
Args:
C (Tensor): [B, T, C, 2]
......
......@@ -397,20 +397,18 @@ class TestKaldiFE(unittest.TestCase):
self.assertEqual(t_nframe.item(), fs.shape[0])
self.assertTrue(np.allclose(t_fs.numpy(), fs))
def test_stft(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
for wintype in ['', 'hamm', 'hann', 'povey']:
print(wintype)
self.wintype=wintype
_, stft_c_win, _, _ = stft_with_window(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep,
nfilt=self.nfilt, nfft=self.nfft,
lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype)
print('py', stft_c_win.real)
print('py', stft_c_win.imag)
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
......@@ -420,33 +418,26 @@ class TestKaldiFE(unittest.TestCase):
t_stft = t_stft.astype(stft_c_win.real.dtype)[0]
t_real = t_stft[:, :, 0]
t_imag = t_stft[:, :, 1]
print('pd', t_real.numpy())
print('pd', t_imag.numpy())
self.assertEqual(t_nframe.item(), stft_c_win.real.shape[0])
self.assertLess(np.sum(t_real.numpy()) - np.sum(stft_c_win.real), 1)
print(np.sum(t_real.numpy()))
print(np.sum(stft_c_win.real))
self.assertTrue(np.allclose(t_real.numpy(), stft_c_win.real, atol=1e-1))
self.assertLess(np.sum(t_imag.numpy()) - np.sum(stft_c_win.imag), 1)
print(np.sum(t_imag.numpy()))
print(np.sum(stft_c_win.imag))
self.assertTrue(np.allclose(t_imag.numpy(), stft_c_win.imag, atol=1e-1))
def test_magspec(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
for wintype in ['', 'hamm', 'hann', 'povey']:
print(wintype)
self.wintype=wintype
stft_win, _, _, _ = stft_with_window(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep,
nfilt=self.nfilt, nfft=self.nfft,
lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype)
print('py', stft_win)
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
......@@ -455,20 +446,39 @@ class TestKaldiFE(unittest.TestCase):
t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(stft_win.dtype)
t_spec = kaldi.magspec(t_stft)[0]
print('pd', t_spec.numpy())
self.assertEqual(t_nframe.item(), stft_win.shape[0])
self.assertLess(np.sum(t_spec.numpy()) - np.sum(stft_win), 1)
print(np.sum(t_spec.numpy()))
print(np.sum(stft_win))
self.assertTrue(np.allclose(t_spec.numpy(), stft_win, atol=1e-1))
def test_magsepc_winprocess(self):
    """Compare ``kaldi.STFT`` + ``kaldi.magspec`` with ``framesig`` + ``magspec``.

    Both pipelines run with a 'povey' window, dither 0.0, pre-emphasis
    0.97 and DC-offset removal enabled. The frame count, the summed
    spectrum and the element-wise values are then checked against the
    reference.
    """
    sr, samples = kaldi.read(self.wavpath)
    # Keep column 0 only (presumably the first channel -- confirm against
    # what kaldi.read returns).
    mono = samples[:, 0]

    # Reference pipeline.
    ref_frames, _ = framesig(
        mono,
        self.winlen * sr,
        self.winstep * sr,
        dither=0.0,
        preemph=0.97,
        remove_dc_offset=True,
        wintype='povey',
        stride_trick=True)
    # nearly the same until this part
    ref_spec = magspec(ref_frames, self.nfft)

    # Paddle pipeline on a batch holding the single waveform.
    batch = paddle.to_tensor([mono], dtype='float32')
    batch_len = paddle.to_tensor([len(mono)])
    stft_layer = kaldi.STFT(
        self.nfft,
        sr,
        self.winlen,
        self.winstep,
        window_type='povey',
        dither=0.0,
        preemph_coeff=0.97,
        remove_dc_offset=True,
        clip=False)
    stft_out, frame_count = stft_layer(batch, batch_len)
    # Cast to the reference dtype before comparing.
    stft_out = stft_out.astype(ref_spec.dtype)
    pd_spec = kaldi.magspec(stft_out)[0].numpy()

    self.assertEqual(frame_count.item(), ref_frames.shape[0])
    # NOTE(review): this is a signed difference, so the check only bounds
    # how far the paddle sum exceeds the reference sum.
    self.assertLess(np.sum(pd_spec) - np.sum(ref_spec), 1)
    self.assertTrue(np.allclose(pd_spec, ref_spec, atol=1e-1))
def test_powspec(self):
sr, wav = kaldi.read(self.wavpath)
wav = wav[:, 0]
for wintype in ['', 'hamm', 'hann', 'povey']:
print(wintype)
self.wintype=wintype
stft_win, _, _, _ = stft_with_window(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep,
......@@ -476,7 +486,6 @@ class TestKaldiFE(unittest.TestCase):
lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype)
stft_win = np.square(stft_win)
print('py', stft_win)
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
......@@ -485,13 +494,10 @@ class TestKaldiFE(unittest.TestCase):
t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(stft_win.dtype)
t_spec = kaldi.powspec(t_stft)[0]
print('pd', t_spec.numpy())
self.assertEqual(t_nframe.item(), stft_win.shape[0])
self.assertLess(np.sum(t_spec.numpy() - stft_win), 5e4)
print(np.sum(t_spec.numpy()))
print(np.sum(stft_win))
self.assertTrue(np.allclose(t_spec.numpy(), stft_win, atol=1e2))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册