未验证 提交 26465cdb 编写于 作者: Y YangZhou 提交者: GitHub

fix paddle.audio.function.get_window security error (#47453)

上级 f4788442
...@@ -19,129 +19,155 @@ import paddle ...@@ -19,129 +19,155 @@ import paddle
from paddle import Tensor from paddle import Tensor
class WindowFunctionRegister(object):
def __init__(self):
self._functions_dict = dict()
def register(self, func=None):
def add_subfunction(func):
name = func.__name__
self._functions_dict[name] = func
return func
return add_subfunction
def get(self, name):
return self._functions_dict[name]
window_function_register = WindowFunctionRegister()
@window_function_register.register()
def _cat(x: List[Tensor], data_type: str) -> Tensor: def _cat(x: List[Tensor], data_type: str) -> Tensor:
l = [paddle.to_tensor(_, data_type) for _ in x] l = [paddle.to_tensor(_, data_type) for _ in x]
return paddle.concat(l) return paddle.concat(l)
@window_function_register.register()
def _acosh(x: Union[Tensor, float]) -> Tensor: def _acosh(x: Union[Tensor, float]) -> Tensor:
if isinstance(x, float): if isinstance(x, float):
return math.log(x + math.sqrt(x**2 - 1)) return math.log(x + math.sqrt(x**2 - 1))
return paddle.log(x + paddle.sqrt(paddle.square(x) - 1)) return paddle.log(x + paddle.sqrt(paddle.square(x) - 1))
@window_function_register.register()
def _extend(M: int, sym: bool) -> bool: def _extend(M: int, sym: bool) -> bool:
"""Extend window by 1 sample if needed for DFT-even symmetry. """ """Extend window by 1 sample if needed for DFT-even symmetry."""
if not sym: if not sym:
return M + 1, True return M + 1, True
else: else:
return M, False return M, False
@window_function_register.register()
def _len_guards(M: int) -> bool: def _len_guards(M: int) -> bool:
"""Handle small or incorrect window lengths. """ """Handle small or incorrect window lengths."""
if int(M) != M or M < 0: if int(M) != M or M < 0:
raise ValueError('Window length M must be a non-negative integer') raise ValueError('Window length M must be a non-negative integer')
return M <= 1 return M <= 1
@window_function_register.register()
def _truncate(w: Tensor, needed: bool) -> Tensor: def _truncate(w: Tensor, needed: bool) -> Tensor:
"""Truncate window by 1 sample if needed for DFT-even symmetry. """ """Truncate window by 1 sample if needed for DFT-even symmetry."""
if needed: if needed:
return w[:-1] return w[:-1]
else: else:
return w return w
def _general_gaussian(M: int, @window_function_register.register()
p, def _general_gaussian(
sig, M: int, p, sig, sym: bool = True, dtype: str = 'float64'
sym: bool = True, ) -> Tensor:
dtype: str = 'float64') -> Tensor:
"""Compute a window with a generalized Gaussian shape. """Compute a window with a generalized Gaussian shape.
This function is consistent with scipy.signal.windows.general_gaussian(). This function is consistent with scipy.signal.windows.general_gaussian().
""" """
if _len_guards(M): if _len_guards(M):
return paddle.ones((M, ), dtype=dtype) return paddle.ones((M,), dtype=dtype)
M, needs_trunc = _extend(M, sym) M, needs_trunc = _extend(M, sym)
n = paddle.arange(0, M, dtype=dtype) - (M - 1.0) / 2.0 n = paddle.arange(0, M, dtype=dtype) - (M - 1.0) / 2.0
w = paddle.exp(-0.5 * paddle.abs(n / sig)**(2 * p)) w = paddle.exp(-0.5 * paddle.abs(n / sig) ** (2 * p))
return _truncate(w, needs_trunc) return _truncate(w, needs_trunc)
def _general_cosine(M: int, @window_function_register.register()
a: float, def _general_cosine(
sym: bool = True, M: int, a: float, sym: bool = True, dtype: str = 'float64'
dtype: str = 'float64') -> Tensor: ) -> Tensor:
"""Compute a generic weighted sum of cosine terms window. """Compute a generic weighted sum of cosine terms window.
This function is consistent with scipy.signal.windows.general_cosine(). This function is consistent with scipy.signal.windows.general_cosine().
""" """
if _len_guards(M): if _len_guards(M):
return paddle.ones((M, ), dtype=dtype) return paddle.ones((M,), dtype=dtype)
M, needs_trunc = _extend(M, sym) M, needs_trunc = _extend(M, sym)
fac = paddle.linspace(-math.pi, math.pi, M, dtype=dtype) fac = paddle.linspace(-math.pi, math.pi, M, dtype=dtype)
w = paddle.zeros((M, ), dtype=dtype) w = paddle.zeros((M,), dtype=dtype)
for k in range(len(a)): for k in range(len(a)):
w += a[k] * paddle.cos(k * fac) w += a[k] * paddle.cos(k * fac)
return _truncate(w, needs_trunc) return _truncate(w, needs_trunc)
def _general_hamming(M: int, @window_function_register.register()
alpha: float, def _general_hamming(
sym: bool = True, M: int, alpha: float, sym: bool = True, dtype: str = 'float64'
dtype: str = 'float64') -> Tensor: ) -> Tensor:
"""Compute a generalized Hamming window. """Compute a generalized Hamming window.
This function is consistent with scipy.signal.windows.general_hamming() This function is consistent with scipy.signal.windows.general_hamming()
""" """
return _general_cosine(M, [alpha, 1. - alpha], sym, dtype=dtype) return _general_cosine(M, [alpha, 1.0 - alpha], sym, dtype=dtype)
def _taylor(M: int, @window_function_register.register()
nbar=4, def _taylor(
sll=30, M: int, nbar=4, sll=30, norm=True, sym: bool = True, dtype: str = 'float64'
norm=True, ) -> Tensor:
sym: bool = True,
dtype: str = 'float64') -> Tensor:
"""Compute a Taylor window. """Compute a Taylor window.
The Taylor window taper function approximates the Dolph-Chebyshev window's The Taylor window taper function approximates the Dolph-Chebyshev window's
constant sidelobe level for a parameterized number of near-in sidelobes. constant sidelobe level for a parameterized number of near-in sidelobes.
""" """
if _len_guards(M): if _len_guards(M):
return paddle.ones((M, ), dtype=dtype) return paddle.ones((M,), dtype=dtype)
M, needs_trunc = _extend(M, sym) M, needs_trunc = _extend(M, sym)
# Original text uses a negative sidelobe level parameter and then negates # Original text uses a negative sidelobe level parameter and then negates
# it in the calculation of B. To keep consistent with other methods we # it in the calculation of B. To keep consistent with other methods we
# assume the sidelobe level parameter to be positive. # assume the sidelobe level parameter to be positive.
B = 10**(sll / 20) B = 10 ** (sll / 20)
A = _acosh(B) / math.pi A = _acosh(B) / math.pi
s2 = nbar**2 / (A**2 + (nbar - 0.5)**2) s2 = nbar**2 / (A**2 + (nbar - 0.5) ** 2)
ma = paddle.arange(1, nbar, dtype=dtype) ma = paddle.arange(1, nbar, dtype=dtype)
Fm = paddle.empty((nbar - 1, ), dtype=dtype) Fm = paddle.empty((nbar - 1,), dtype=dtype)
signs = paddle.empty_like(ma) signs = paddle.empty_like(ma)
signs[::2] = 1 signs[::2] = 1
signs[1::2] = -1 signs[1::2] = -1
m2 = ma * ma m2 = ma * ma
for mi in range(len(ma)): for mi in range(len(ma)):
numer = signs[mi] * paddle.prod(1 - m2[mi] / s2 / (A**2 + numer = signs[mi] * paddle.prod(
(ma - 0.5)**2)) 1 - m2[mi] / s2 / (A**2 + (ma - 0.5) ** 2)
)
if mi == 0: if mi == 0:
denom = 2 * paddle.prod(1 - m2[mi] / m2[mi + 1:]) denom = 2 * paddle.prod(1 - m2[mi] / m2[mi + 1 :])
elif mi == len(ma) - 1: elif mi == len(ma) - 1:
denom = 2 * paddle.prod(1 - m2[mi] / m2[:mi]) denom = 2 * paddle.prod(1 - m2[mi] / m2[:mi])
else: else:
denom = 2 * paddle.prod(1 - m2[mi] / m2[:mi]) * paddle.prod( denom = (
1 - m2[mi] / m2[mi + 1:]) 2
* paddle.prod(1 - m2[mi] / m2[:mi])
* paddle.prod(1 - m2[mi] / m2[mi + 1 :])
)
Fm[mi] = numer / denom Fm[mi] = numer / denom
def W(n): def W(n):
return 1 + 2 * paddle.matmul( return 1 + 2 * paddle.matmul(
Fm.unsqueeze(0), Fm.unsqueeze(0),
paddle.cos(2 * math.pi * ma.unsqueeze(1) * (n - M / 2. + 0.5) / M)) paddle.cos(2 * math.pi * ma.unsqueeze(1) * (n - M / 2.0 + 0.5) / M),
)
w = W(paddle.arange(0, M, dtype=dtype)) w = W(paddle.arange(0, M, dtype=dtype))
...@@ -153,6 +179,7 @@ def _taylor(M: int, ...@@ -153,6 +179,7 @@ def _taylor(M: int,
return _truncate(w, needs_trunc) return _truncate(w, needs_trunc)
@window_function_register.register()
def _hamming(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: def _hamming(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a Hamming window. """Compute a Hamming window.
The Hamming window is a taper formed by using a raised cosine with The Hamming window is a taper formed by using a raised cosine with
...@@ -161,6 +188,7 @@ def _hamming(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: ...@@ -161,6 +188,7 @@ def _hamming(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
return _general_hamming(M, 0.54, sym, dtype=dtype) return _general_hamming(M, 0.54, sym, dtype=dtype)
@window_function_register.register()
def _hann(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: def _hann(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a Hann window. """Compute a Hann window.
The Hann window is a taper formed by using a raised cosine or sine-squared The Hann window is a taper formed by using a raised cosine or sine-squared
...@@ -169,18 +197,18 @@ def _hann(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: ...@@ -169,18 +197,18 @@ def _hann(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
return _general_hamming(M, 0.5, sym, dtype=dtype) return _general_hamming(M, 0.5, sym, dtype=dtype)
def _tukey(M: int, @window_function_register.register()
alpha=0.5, def _tukey(
sym: bool = True, M: int, alpha=0.5, sym: bool = True, dtype: str = 'float64'
dtype: str = 'float64') -> Tensor: ) -> Tensor:
"""Compute a Tukey window. """Compute a Tukey window.
The Tukey window is also known as a tapered cosine window. The Tukey window is also known as a tapered cosine window.
""" """
if _len_guards(M): if _len_guards(M):
return paddle.ones((M, ), dtype=dtype) return paddle.ones((M,), dtype=dtype)
if alpha <= 0: if alpha <= 0:
return paddle.ones((M, ), dtype=dtype) return paddle.ones((M,), dtype=dtype)
elif alpha >= 1.0: elif alpha >= 1.0:
return hann(M, sym=sym) return hann(M, sym=sym)
...@@ -188,57 +216,58 @@ def _tukey(M: int, ...@@ -188,57 +216,58 @@ def _tukey(M: int,
n = paddle.arange(0, M, dtype=dtype) n = paddle.arange(0, M, dtype=dtype)
width = int(alpha * (M - 1) / 2.0) width = int(alpha * (M - 1) / 2.0)
n1 = n[0:width + 1] n1 = n[0 : width + 1]
n2 = n[width + 1:M - width - 1] n2 = n[width + 1 : M - width - 1]
n3 = n[M - width - 1:] n3 = n[M - width - 1 :]
w1 = 0.5 * (1 + paddle.cos(math.pi * (-1 + 2.0 * n1 / alpha / (M - 1)))) w1 = 0.5 * (1 + paddle.cos(math.pi * (-1 + 2.0 * n1 / alpha / (M - 1))))
w2 = paddle.ones(n2.shape, dtype=dtype) w2 = paddle.ones(n2.shape, dtype=dtype)
w3 = 0.5 * (1 + paddle.cos(math.pi * (-2.0 / alpha + 1 + 2.0 * n3 / alpha / w3 = 0.5 * (
(M - 1)))) 1
+ paddle.cos(math.pi * (-2.0 / alpha + 1 + 2.0 * n3 / alpha / (M - 1)))
)
w = paddle.concat([w1, w2, w3]) w = paddle.concat([w1, w2, w3])
return _truncate(w, needs_trunc) return _truncate(w, needs_trunc)
def _kaiser(M: int, @window_function_register.register()
beta: float, def _kaiser(
sym: bool = True, M: int, beta: float, sym: bool = True, dtype: str = 'float64'
dtype: str = 'float64') -> Tensor: ) -> Tensor:
"""Compute a Kaiser window. """Compute a Kaiser window.
The Kaiser window is a taper formed by using a Bessel function. The Kaiser window is a taper formed by using a Bessel function.
""" """
raise NotImplementedError() raise NotImplementedError()
def _gaussian(M: int, @window_function_register.register()
std: float, def _gaussian(
sym: bool = True, M: int, std: float, sym: bool = True, dtype: str = 'float64'
dtype: str = 'float64') -> Tensor: ) -> Tensor:
"""Compute a Gaussian window. """Compute a Gaussian window.
The Gaussian widows has a Gaussian shape defined by the standard deviation(std). The Gaussian widows has a Gaussian shape defined by the standard deviation(std).
""" """
if _len_guards(M): if _len_guards(M):
return paddle.ones((M, ), dtype=dtype) return paddle.ones((M,), dtype=dtype)
M, needs_trunc = _extend(M, sym) M, needs_trunc = _extend(M, sym)
n = paddle.arange(0, M, dtype=dtype) - (M - 1.0) / 2.0 n = paddle.arange(0, M, dtype=dtype) - (M - 1.0) / 2.0
sig2 = 2 * std * std sig2 = 2 * std * std
w = paddle.exp(-n**2 / sig2) w = paddle.exp(-(n**2) / sig2)
return _truncate(w, needs_trunc) return _truncate(w, needs_trunc)
def _exponential(M: int, @window_function_register.register()
center=None, def _exponential(
tau=1., M: int, center=None, tau=1.0, sym: bool = True, dtype: str = 'float64'
sym: bool = True, ) -> Tensor:
dtype: str = 'float64') -> Tensor: """Compute an exponential (or Poisson) window."""
"""Compute an exponential (or Poisson) window. """
if sym and center is not None: if sym and center is not None:
raise ValueError("If sym==True, center must be None.") raise ValueError("If sym==True, center must be None.")
if _len_guards(M): if _len_guards(M):
return paddle.ones((M, ), dtype=dtype) return paddle.ones((M,), dtype=dtype)
M, needs_trunc = _extend(M, sym) M, needs_trunc = _extend(M, sym)
if center is None: if center is None:
...@@ -250,11 +279,11 @@ def _exponential(M: int, ...@@ -250,11 +279,11 @@ def _exponential(M: int,
return _truncate(w, needs_trunc) return _truncate(w, needs_trunc)
@window_function_register.register()
def _triang(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: def _triang(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a triangular window. """Compute a triangular window."""
"""
if _len_guards(M): if _len_guards(M):
return paddle.ones((M, ), dtype=dtype) return paddle.ones((M,), dtype=dtype)
M, needs_trunc = _extend(M, sym) M, needs_trunc = _extend(M, sym)
n = paddle.arange(1, (M + 1) // 2 + 1, dtype=dtype) n = paddle.arange(1, (M + 1) // 2 + 1, dtype=dtype)
...@@ -268,22 +297,25 @@ def _triang(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: ...@@ -268,22 +297,25 @@ def _triang(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
return _truncate(w, needs_trunc) return _truncate(w, needs_trunc)
@window_function_register.register()
def _bohman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: def _bohman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a Bohman window. """Compute a Bohman window.
The Bohman window is the autocorrelation of a cosine window. The Bohman window is the autocorrelation of a cosine window.
""" """
if _len_guards(M): if _len_guards(M):
return paddle.ones((M, ), dtype=dtype) return paddle.ones((M,), dtype=dtype)
M, needs_trunc = _extend(M, sym) M, needs_trunc = _extend(M, sym)
fac = paddle.abs(paddle.linspace(-1, 1, M, dtype=dtype)[1:-1]) fac = paddle.abs(paddle.linspace(-1, 1, M, dtype=dtype)[1:-1])
w = (1 - fac) * paddle.cos(math.pi * fac) + 1.0 / math.pi * paddle.sin( w = (1 - fac) * paddle.cos(math.pi * fac) + 1.0 / math.pi * paddle.sin(
math.pi * fac) math.pi * fac
)
w = _cat([0, w, 0], dtype) w = _cat([0, w, 0], dtype)
return _truncate(w, needs_trunc) return _truncate(w, needs_trunc)
@window_function_register.register()
def _blackman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: def _blackman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a Blackman window. """Compute a Blackman window.
The Blackman window is a taper formed by using the first three terms of The Blackman window is a taper formed by using the first three terms of
...@@ -294,25 +326,27 @@ def _blackman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: ...@@ -294,25 +326,27 @@ def _blackman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
return _general_cosine(M, [0.42, 0.50, 0.08], sym, dtype=dtype) return _general_cosine(M, [0.42, 0.50, 0.08], sym, dtype=dtype)
@window_function_register.register()
def _cosine(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: def _cosine(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a window with a simple cosine shape. """Compute a window with a simple cosine shape."""
"""
if _len_guards(M): if _len_guards(M):
return paddle.ones((M, ), dtype=dtype) return paddle.ones((M,), dtype=dtype)
M, needs_trunc = _extend(M, sym) M, needs_trunc = _extend(M, sym)
w = paddle.sin(math.pi / M * (paddle.arange(0, M, dtype=dtype) + .5)) w = paddle.sin(math.pi / M * (paddle.arange(0, M, dtype=dtype) + 0.5))
return _truncate(w, needs_trunc) return _truncate(w, needs_trunc)
def get_window(window: Union[str, Tuple[str, float]], def get_window(
win_length: int, window: Union[str, Tuple[str, float]],
fftbins: bool = True, win_length: int,
dtype: str = 'float64') -> Tensor: fftbins: bool = True,
dtype: str = 'float64',
) -> Tensor:
"""Return a window of a given length and type. """Return a window of a given length and type.
Args: Args:
window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'general_gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'.
win_length (int): Number of samples. win_length (int): Number of samples.
fftbins (bool, optional): If True, create a "periodic" window. Otherwise, create a "symmetric" window, for use in filter design. Defaults to True. fftbins (bool, optional): If True, create a "periodic" window. Otherwise, create a "symmetric" window, for use in filter design. Defaults to True.
dtype (str, optional): The data type of the return window. Defaults to 'float64'. dtype (str, optional): The data type of the return window. Defaults to 'float64'.
...@@ -340,19 +374,22 @@ def get_window(window: Union[str, Tuple[str, float]], ...@@ -340,19 +374,22 @@ def get_window(window: Union[str, Tuple[str, float]],
args = window[1:] args = window[1:]
elif isinstance(window, str): elif isinstance(window, str):
if window in ['gaussian', 'exponential']: if window in ['gaussian', 'exponential']:
raise ValueError("The '" + window + "' window needs one or " raise ValueError(
"more parameters -- pass a tuple.") "The '" + window + "' window needs one or "
"more parameters -- pass a tuple."
)
else: else:
winstr = window winstr = window
else: else:
raise ValueError("%s as window type is not supported." % raise ValueError(
str(type(window))) "%s as window type is not supported." % str(type(window))
)
try: try:
winfunc = eval('_' + winstr) winfunc = window_function_register.get('_' + winstr)
except NameError as e: except KeyError as e:
raise ValueError("Unknown window type.") from e raise ValueError("Unknown window type.") from e
params = (win_length, ) + args params = (win_length,) + args
kwargs = {'sym': sym} kwargs = {'sym': sym}
return winfunc(*params, dtype=dtype, **kwargs) return winfunc(*params, dtype=dtype, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册