提交 598fe0d4 编写于 作者: H Hui Zhang

stft complex, powspec, magspec

上级 daf9abda
......@@ -132,7 +132,7 @@ class STFT(nn.Layer):
wsin = np.empty((self.n_bin, kernel_size)) #[Cout, kernel_size]
wcos = np.empty((self.n_bin, kernel_size)) #[Cout, kernel_size]
for k in range(self.n_bin): # Only half of the bins contain useful info
wsin[k,:] = np.sin(2*np.pi*k*n/self.n_fft)[:kernel_size]
wsin[k,:] = -np.sin(2*np.pi*k*n/self.n_fft)[:kernel_size]
wcos[k,:] = np.cos(2*np.pi*k*n/self.n_fft)[:kernel_size]
w_real = wcos
w_imag = wsin
......@@ -144,8 +144,7 @@ class STFT(nn.Layer):
# w_imag = weight.imag
# (2 * n_bins, kernel_size)
#w = np.concatenate([w_real, w_imag], axis=0)
w = w_real
w = np.concatenate([w_real, w_imag], axis=0)
w = w * window
# (2 * n_bins, 1, kernel_size) # (C_out, C_in, kernel_size)
......@@ -163,7 +162,7 @@ class STFT(nn.Layer):
Number of samples of each waveform.
Returns
------------
D : Tensor
C : Tensor
Shape(B, T', n_bins, 2) Spectrogram.
num_frames: Tensor
......@@ -178,11 +177,37 @@ class STFT(nn.Layer):
batch_size, _ = paddle.shape(x)
x = x.unsqueeze(-1)
D = F.conv1d(x, self.weight,
C = F.conv1d(x, self.weight,
stride=(self.stride_length, ),
padding=padding,
data_format="NLC")
#D = paddle.reshape(D, [batch_size, -1, self.n_bin, 2])
D = paddle.reshape(D, [batch_size, -1, self.n_bin, 1])
return D, num_frames
C = paddle.reshape(C, [batch_size, -1, 2, self.n_bin])
C = C.transpose([0, 1, 3, 2])
return C, num_frames
def powspec(C: Tensor) -> Tensor:
    """Return the power spectrum of a complex spectrogram.

    The last axis of ``C`` is split into two halves, which are treated
    as the real and imaginary components; the result is the sum of
    their squares.

    Args:
        C (Tensor): [B, T, C, 2]

    Returns:
        Tensor: [B, T, C]
    """
    # Split the trailing axis into its two halves (real, imaginary).
    re_part, im_part = paddle.chunk(C, 2, axis=-1)
    # Drop the trailing axis left behind by the split before squaring.
    re_sq = paddle.square(re_part.squeeze(-1))
    im_sq = paddle.square(im_part.squeeze(-1))
    return re_sq + im_sq
def magspec(C: Tensor, eps=1e-10) -> Tensor:
    """Return the magnitude spectrum of a complex spectrogram.

    Computed as ``sqrt(powspec(C) + eps)``.

    Args:
        C (Tensor): [B, T, C, 2]
        eps (float): constant added to the power spectrum before the
            square root is taken.

    Returns:
        Tensor: [B, T, C]
    """
    power = powspec(C)
    shifted = power + eps
    return paddle.sqrt(shifted)
\ No newline at end of file
......@@ -235,33 +235,95 @@ class TestKaldiFE(unittest.TestCase):
for wintype in ['', 'hamm', 'hann', 'povey']:
print(wintype)
self.wintype=wintype
sftf_win, stft_c_win, _, stft_c = stft_with_window(wav, samplerate=sr,
_, stft_c_win, _, _ = stft_with_window(wav, samplerate=sr,
winlen=self.winlen, winstep=self.winstep,
nfilt=self.nfilt, nfft=self.nfft,
lowfreq=self.lowfreq, highfreq=self.highfreq,
wintype=self.wintype)
print('py', stft_c_win.real)
#print(stft_c_win.imag)
print('py', stft_c_win.imag)
t_wav = paddle.to_tensor([wav], dtype='float32')
t_wavlen = paddle.to_tensor([len(wav)])
stft_class = kaldi.STFT(self.nfft, sr, self.winlen, self.winstep, window_type=self.wintype, clip=False)
t_stft, t_nframe = stft_class(t_wav, t_wavlen)
t_stft = t_stft.astype(sftf_win.dtype)[0]
t_stft = t_stft.astype(stft_c_win.real.dtype)[0]
t_real = t_stft[:, :, 0]
#t_imag = t_stft[:, :, 1]
t_imag = t_stft[:, :, 1]
print('pd', t_real.numpy())
#print(t_imag.numpy())
print('pd', t_imag.numpy())
self.assertEqual(t_nframe.item(), sftf_win.shape[0])
self.assertEqual(t_nframe.item(), stft_c_win.real.shape[0])
self.assertLess(np.sum(t_real.numpy()) - np.sum(stft_c_win.real), 1)
print(np.sum(t_real.numpy()))
print(np.sum(stft_c_win.real))
self.assertTrue(np.allclose(t_real.numpy(), stft_c_win.real, atol=1e-1))
#self.assertTrue(np.allclose(t_imag.numpy(), stft_c_win.imag))
self.assertLess(np.sum(t_imag.numpy()) - np.sum(stft_c_win.imag), 1)
print(np.sum(t_imag.numpy()))
print(np.sum(stft_c_win.imag))
self.assertTrue(np.allclose(t_imag.numpy(), stft_c_win.imag, atol=1e-1))
def test_magspec(self):
    """Check kaldi.magspec against the reference spectrum per window type.

    For every window type the first value returned by
    ``stft_with_window`` is used as the reference and compared with
    ``kaldi.magspec`` applied to the output of ``kaldi.STFT``.
    """
    sr, wav = kaldi.read(self.wavpath)
    for window_name in ('', 'hamm', 'hann', 'povey'):
        print(window_name)
        self.wintype = window_name

        # Reference spectrum; the other three return values are unused.
        ref_spec, _, _, _ = stft_with_window(
            wav,
            samplerate=sr,
            winlen=self.winlen,
            winstep=self.winstep,
            nfilt=self.nfilt,
            nfft=self.nfft,
            lowfreq=self.lowfreq,
            highfreq=self.highfreq,
            wintype=self.wintype)
        print('py', ref_spec)

        # Batch of one waveform plus its length, as the layer is called here.
        wav_batch = paddle.to_tensor([wav], dtype='float32')
        wav_lengths = paddle.to_tensor([len(wav)])
        stft_layer = kaldi.STFT(
            self.nfft,
            sr,
            self.winlen,
            self.winstep,
            window_type=self.wintype,
            clip=False)
        complex_spec, frame_count = stft_layer(wav_batch, wav_lengths)
        # Match the reference dtype before comparing values.
        complex_spec = complex_spec.astype(ref_spec.dtype)
        pd_spec = kaldi.magspec(complex_spec)[0]
        print('pd', pd_spec.numpy())

        self.assertEqual(frame_count.item(), ref_spec.shape[0])
        # NOTE(review): this difference is signed (no abs), so a much
        # smaller paddle sum also passes -- confirm that is intended.
        self.assertLess(np.sum(pd_spec.numpy()) - np.sum(ref_spec), 1)
        print(np.sum(pd_spec.numpy()))
        print(np.sum(ref_spec))
        self.assertTrue(np.allclose(pd_spec.numpy(), ref_spec, atol=1e-1))
def test_powspec(self):
    """Check kaldi.powspec against the squared reference per window type.

    For every window type the first value returned by
    ``stft_with_window`` is squared element-wise and compared with
    ``kaldi.powspec`` applied to the output of ``kaldi.STFT``.
    """
    sr, wav = kaldi.read(self.wavpath)
    for window_name in ('', 'hamm', 'hann', 'povey'):
        print(window_name)
        self.wintype = window_name

        # Reference spectrum; the other three return values are unused.
        ref_spec, _, _, _ = stft_with_window(
            wav,
            samplerate=sr,
            winlen=self.winlen,
            winstep=self.winstep,
            nfilt=self.nfilt,
            nfft=self.nfft,
            lowfreq=self.lowfreq,
            highfreq=self.highfreq,
            wintype=self.wintype)
        # Square the reference to obtain the power values to compare against.
        ref_pow = np.square(ref_spec)
        print('py', ref_pow)

        # Batch of one waveform plus its length, as the layer is called here.
        wav_batch = paddle.to_tensor([wav], dtype='float32')
        wav_lengths = paddle.to_tensor([len(wav)])
        stft_layer = kaldi.STFT(
            self.nfft,
            sr,
            self.winlen,
            self.winstep,
            window_type=self.wintype,
            clip=False)
        complex_spec, frame_count = stft_layer(wav_batch, wav_lengths)
        # Match the reference dtype before comparing values.
        complex_spec = complex_spec.astype(ref_pow.dtype)
        pd_pow = kaldi.powspec(complex_spec)[0]
        print('pd', pd_pow.numpy())

        self.assertEqual(frame_count.item(), ref_pow.shape[0])
        # NOTE(review): the summed difference is signed (no abs), so
        # large negative deviations also pass -- confirm that is intended.
        self.assertLess(np.sum(pd_pow.numpy() - ref_pow), 2e4)
        print(np.sum(pd_pow.numpy()))
        print(np.sum(ref_pow))
        self.assertTrue(np.allclose(pd_pow.numpy(), ref_pow, atol=1e2))
# from python_speech_features import mfcc
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册