未验证 提交 26c419ca 编写于 作者: Y YangZhou 提交者: GitHub

[audio]fix audio get_window security error (#47386)

* fix window security error

* format
上级 0f649b32
......@@ -19,17 +19,39 @@ import paddle
from paddle import Tensor
class WindowFunctionRegister(object):
def __init__(self):
self._functions_dict = dict()
def register(self, func=None):
def add_subfunction(func):
name = func.__name__
self._functions_dict[name] = func
return func
return add_subfunction
def get(self, name):
return self._functions_dict[name]
window_function_register = WindowFunctionRegister()
@window_function_register.register()
def _cat(x: List[Tensor], data_type: str) -> Tensor:
l = [paddle.to_tensor(_, data_type) for _ in x]
return paddle.concat(l)
@window_function_register.register()
def _acosh(x: Union[Tensor, float]) -> Tensor:
if isinstance(x, float):
return math.log(x + math.sqrt(x**2 - 1))
return paddle.log(x + paddle.sqrt(paddle.square(x) - 1))
@window_function_register.register()
def _extend(M: int, sym: bool) -> bool:
"""Extend window by 1 sample if needed for DFT-even symmetry."""
if not sym:
......@@ -38,6 +60,7 @@ def _extend(M: int, sym: bool) -> bool:
return M, False
@window_function_register.register()
def _len_guards(M: int) -> bool:
"""Handle small or incorrect window lengths."""
if int(M) != M or M < 0:
......@@ -46,6 +69,7 @@ def _len_guards(M: int) -> bool:
return M <= 1
@window_function_register.register()
def _truncate(w: Tensor, needed: bool) -> Tensor:
"""Truncate window by 1 sample if needed for DFT-even symmetry."""
if needed:
......@@ -54,6 +78,7 @@ def _truncate(w: Tensor, needed: bool) -> Tensor:
return w
@window_function_register.register()
def _general_gaussian(
M: int, p, sig, sym: bool = True, dtype: str = 'float64'
) -> Tensor:
......@@ -70,6 +95,7 @@ def _general_gaussian(
return _truncate(w, needs_trunc)
@window_function_register.register()
def _general_cosine(
M: int, a: float, sym: bool = True, dtype: str = 'float64'
) -> Tensor:
......@@ -86,6 +112,7 @@ def _general_cosine(
return _truncate(w, needs_trunc)
@window_function_register.register()
def _general_hamming(
M: int, alpha: float, sym: bool = True, dtype: str = 'float64'
) -> Tensor:
......@@ -95,6 +122,7 @@ def _general_hamming(
return _general_cosine(M, [alpha, 1.0 - alpha], sym, dtype=dtype)
@window_function_register.register()
def _taylor(
M: int, nbar=4, sll=30, norm=True, sym: bool = True, dtype: str = 'float64'
) -> Tensor:
......@@ -151,6 +179,7 @@ def _taylor(
return _truncate(w, needs_trunc)
@window_function_register.register()
def _hamming(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a Hamming window.
The Hamming window is a taper formed by using a raised cosine with
......@@ -159,6 +188,7 @@ def _hamming(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
return _general_hamming(M, 0.54, sym, dtype=dtype)
@window_function_register.register()
def _hann(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a Hann window.
The Hann window is a taper formed by using a raised cosine or sine-squared
......@@ -167,6 +197,7 @@ def _hann(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
return _general_hamming(M, 0.5, sym, dtype=dtype)
@window_function_register.register()
def _tukey(
M: int, alpha=0.5, sym: bool = True, dtype: str = 'float64'
) -> Tensor:
......@@ -200,6 +231,7 @@ def _tukey(
return _truncate(w, needs_trunc)
@window_function_register.register()
def _kaiser(
M: int, beta: float, sym: bool = True, dtype: str = 'float64'
) -> Tensor:
......@@ -209,6 +241,7 @@ def _kaiser(
raise NotImplementedError()
@window_function_register.register()
def _gaussian(
M: int, std: float, sym: bool = True, dtype: str = 'float64'
) -> Tensor:
......@@ -226,6 +259,7 @@ def _gaussian(
return _truncate(w, needs_trunc)
@window_function_register.register()
def _exponential(
M: int, center=None, tau=1.0, sym: bool = True, dtype: str = 'float64'
) -> Tensor:
......@@ -245,6 +279,7 @@ def _exponential(
return _truncate(w, needs_trunc)
@window_function_register.register()
def _triang(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a triangular window."""
if _len_guards(M):
......@@ -262,6 +297,7 @@ def _triang(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
return _truncate(w, needs_trunc)
@window_function_register.register()
def _bohman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a Bohman window.
The Bohman window is the autocorrelation of a cosine window.
......@@ -279,6 +315,7 @@ def _bohman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
return _truncate(w, needs_trunc)
@window_function_register.register()
def _blackman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a Blackman window.
The Blackman window is a taper formed by using the first three terms of
......@@ -289,6 +326,7 @@ def _blackman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
return _general_cosine(M, [0.42, 0.50, 0.08], sym, dtype=dtype)
@window_function_register.register()
def _cosine(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a window with a simple cosine shape."""
if _len_guards(M):
......@@ -308,7 +346,7 @@ def get_window(
"""Return a window of a given length and type.
Args:
window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'.
window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'general_gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'.
win_length (int): Number of samples.
fftbins (bool, optional): If True, create a "periodic" window. Otherwise, create a "symmetric" window, for use in filter design. Defaults to True.
dtype (str, optional): The data type of the return window. Defaults to 'float64'.
......@@ -348,8 +386,8 @@ def get_window(
)
try:
winfunc = eval('_' + winstr)
except NameError as e:
winfunc = window_function_register.get('_' + winstr)
except KeyError as e:
raise ValueError("Unknown window type.") from e
params = (win_length,) + args
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册