提交 0d7d8712 编写于 作者: H Hui Zhang

simplify feature pipeline graph

上级 8690a00b
...@@ -357,10 +357,13 @@ def _get_mel_banks(num_bins: int, ...@@ -357,10 +357,13 @@ def _get_mel_banks(num_bins: int,
('Bad values in options: vtln-low {} and vtln-high {}, versus ' ('Bad values in options: vtln-low {} and vtln-high {}, versus '
'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq)) 'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq))
bin = paddle.arange(num_bins).unsqueeze(1) bin = paddle.arange(num_bins, dtype=paddle.float32).unsqueeze(1)
# left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1)
# center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # (num_bins, 1)
# right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # (num_bins, 1)
left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1) left_mel = mel_low_freq + bin * mel_freq_delta # (num_bins, 1)
center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # (num_bins, 1) center_mel = left_mel + mel_freq_delta
right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # (num_bins, 1) right_mel = center_mel + mel_freq_delta
if vtln_warp_factor != 1.0: if vtln_warp_factor != 1.0:
left_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, left_mel = _vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq,
...@@ -373,7 +376,7 @@ def _get_mel_banks(num_bins: int, ...@@ -373,7 +376,7 @@ def _get_mel_banks(num_bins: int,
center_freqs = _inverse_mel_scale(center_mel) # (num_bins) center_freqs = _inverse_mel_scale(center_mel) # (num_bins)
# (1, num_fft_bins) # (1, num_fft_bins)
mel = _mel_scale(fft_bin_width * paddle.arange(num_fft_bins)).unsqueeze(0) mel = _mel_scale(fft_bin_width * paddle.arange(num_fft_bins, dtype=paddle.float32)).unsqueeze(0)
# (num_bins, num_fft_bins) # (num_bins, num_fft_bins)
up_slope = (mel - left_mel) / (center_mel - left_mel) up_slope = (mel - left_mel) / (center_mel - left_mel)
...@@ -472,7 +475,8 @@ def fbank(waveform: Tensor, ...@@ -472,7 +475,8 @@ def fbank(waveform: Tensor,
# (n_mels, padded_window_size // 2) # (n_mels, padded_window_size // 2)
mel_energies, _ = _get_mel_banks(n_mels, padded_window_size, sr, low_freq, mel_energies, _ = _get_mel_banks(n_mels, padded_window_size, sr, low_freq,
high_freq, vtln_low, vtln_high, vtln_warp) high_freq, vtln_low, vtln_high, vtln_warp)
mel_energies = mel_energies.astype(dtype) # mel_energies = mel_energies.astype(dtype)
assert mel_energies.dtype == dtype
# (n_mels, padded_window_size // 2 + 1) # (n_mels, padded_window_size // 2 + 1)
mel_energies = paddle.nn.functional.pad( mel_energies = paddle.nn.functional.pad(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册