未验证 提交 26c419ca 编写于 作者: Y YangZhou 提交者: GitHub

[audio]fix audio get_window security error (#47386)

* fix window security error

* format
上级 0f649b32
...@@ -19,17 +19,39 @@ import paddle ...@@ -19,17 +19,39 @@ import paddle
from paddle import Tensor from paddle import Tensor
class WindowFunctionRegister(object):
def __init__(self):
self._functions_dict = dict()
def register(self, func=None):
def add_subfunction(func):
name = func.__name__
self._functions_dict[name] = func
return func
return add_subfunction
def get(self, name):
return self._functions_dict[name]
window_function_register = WindowFunctionRegister()
@window_function_register.register()
def _cat(x: List[Tensor], data_type: str) -> Tensor: def _cat(x: List[Tensor], data_type: str) -> Tensor:
l = [paddle.to_tensor(_, data_type) for _ in x] l = [paddle.to_tensor(_, data_type) for _ in x]
return paddle.concat(l) return paddle.concat(l)
@window_function_register.register()
def _acosh(x: Union[Tensor, float]) -> Tensor: def _acosh(x: Union[Tensor, float]) -> Tensor:
if isinstance(x, float): if isinstance(x, float):
return math.log(x + math.sqrt(x**2 - 1)) return math.log(x + math.sqrt(x**2 - 1))
return paddle.log(x + paddle.sqrt(paddle.square(x) - 1)) return paddle.log(x + paddle.sqrt(paddle.square(x) - 1))
@window_function_register.register()
def _extend(M: int, sym: bool) -> bool: def _extend(M: int, sym: bool) -> bool:
"""Extend window by 1 sample if needed for DFT-even symmetry.""" """Extend window by 1 sample if needed for DFT-even symmetry."""
if not sym: if not sym:
...@@ -38,6 +60,7 @@ def _extend(M: int, sym: bool) -> bool: ...@@ -38,6 +60,7 @@ def _extend(M: int, sym: bool) -> bool:
return M, False return M, False
@window_function_register.register()
def _len_guards(M: int) -> bool: def _len_guards(M: int) -> bool:
"""Handle small or incorrect window lengths.""" """Handle small or incorrect window lengths."""
if int(M) != M or M < 0: if int(M) != M or M < 0:
...@@ -46,6 +69,7 @@ def _len_guards(M: int) -> bool: ...@@ -46,6 +69,7 @@ def _len_guards(M: int) -> bool:
return M <= 1 return M <= 1
@window_function_register.register()
def _truncate(w: Tensor, needed: bool) -> Tensor: def _truncate(w: Tensor, needed: bool) -> Tensor:
"""Truncate window by 1 sample if needed for DFT-even symmetry.""" """Truncate window by 1 sample if needed for DFT-even symmetry."""
if needed: if needed:
...@@ -54,6 +78,7 @@ def _truncate(w: Tensor, needed: bool) -> Tensor: ...@@ -54,6 +78,7 @@ def _truncate(w: Tensor, needed: bool) -> Tensor:
return w return w
@window_function_register.register()
def _general_gaussian( def _general_gaussian(
M: int, p, sig, sym: bool = True, dtype: str = 'float64' M: int, p, sig, sym: bool = True, dtype: str = 'float64'
) -> Tensor: ) -> Tensor:
...@@ -70,6 +95,7 @@ def _general_gaussian( ...@@ -70,6 +95,7 @@ def _general_gaussian(
return _truncate(w, needs_trunc) return _truncate(w, needs_trunc)
@window_function_register.register()
def _general_cosine( def _general_cosine(
M: int, a: float, sym: bool = True, dtype: str = 'float64' M: int, a: float, sym: bool = True, dtype: str = 'float64'
) -> Tensor: ) -> Tensor:
...@@ -86,6 +112,7 @@ def _general_cosine( ...@@ -86,6 +112,7 @@ def _general_cosine(
return _truncate(w, needs_trunc) return _truncate(w, needs_trunc)
@window_function_register.register()
def _general_hamming( def _general_hamming(
M: int, alpha: float, sym: bool = True, dtype: str = 'float64' M: int, alpha: float, sym: bool = True, dtype: str = 'float64'
) -> Tensor: ) -> Tensor:
...@@ -95,6 +122,7 @@ def _general_hamming( ...@@ -95,6 +122,7 @@ def _general_hamming(
return _general_cosine(M, [alpha, 1.0 - alpha], sym, dtype=dtype) return _general_cosine(M, [alpha, 1.0 - alpha], sym, dtype=dtype)
@window_function_register.register()
def _taylor( def _taylor(
M: int, nbar=4, sll=30, norm=True, sym: bool = True, dtype: str = 'float64' M: int, nbar=4, sll=30, norm=True, sym: bool = True, dtype: str = 'float64'
) -> Tensor: ) -> Tensor:
...@@ -151,6 +179,7 @@ def _taylor( ...@@ -151,6 +179,7 @@ def _taylor(
return _truncate(w, needs_trunc) return _truncate(w, needs_trunc)
@window_function_register.register()
def _hamming(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: def _hamming(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a Hamming window. """Compute a Hamming window.
The Hamming window is a taper formed by using a raised cosine with The Hamming window is a taper formed by using a raised cosine with
...@@ -159,6 +188,7 @@ def _hamming(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: ...@@ -159,6 +188,7 @@ def _hamming(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
return _general_hamming(M, 0.54, sym, dtype=dtype) return _general_hamming(M, 0.54, sym, dtype=dtype)
@window_function_register.register()
def _hann(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: def _hann(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a Hann window. """Compute a Hann window.
The Hann window is a taper formed by using a raised cosine or sine-squared The Hann window is a taper formed by using a raised cosine or sine-squared
...@@ -167,6 +197,7 @@ def _hann(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: ...@@ -167,6 +197,7 @@ def _hann(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
return _general_hamming(M, 0.5, sym, dtype=dtype) return _general_hamming(M, 0.5, sym, dtype=dtype)
@window_function_register.register()
def _tukey( def _tukey(
M: int, alpha=0.5, sym: bool = True, dtype: str = 'float64' M: int, alpha=0.5, sym: bool = True, dtype: str = 'float64'
) -> Tensor: ) -> Tensor:
...@@ -200,6 +231,7 @@ def _tukey( ...@@ -200,6 +231,7 @@ def _tukey(
return _truncate(w, needs_trunc) return _truncate(w, needs_trunc)
@window_function_register.register()
def _kaiser( def _kaiser(
M: int, beta: float, sym: bool = True, dtype: str = 'float64' M: int, beta: float, sym: bool = True, dtype: str = 'float64'
) -> Tensor: ) -> Tensor:
...@@ -209,6 +241,7 @@ def _kaiser( ...@@ -209,6 +241,7 @@ def _kaiser(
raise NotImplementedError() raise NotImplementedError()
@window_function_register.register()
def _gaussian( def _gaussian(
M: int, std: float, sym: bool = True, dtype: str = 'float64' M: int, std: float, sym: bool = True, dtype: str = 'float64'
) -> Tensor: ) -> Tensor:
...@@ -226,6 +259,7 @@ def _gaussian( ...@@ -226,6 +259,7 @@ def _gaussian(
return _truncate(w, needs_trunc) return _truncate(w, needs_trunc)
@window_function_register.register()
def _exponential( def _exponential(
M: int, center=None, tau=1.0, sym: bool = True, dtype: str = 'float64' M: int, center=None, tau=1.0, sym: bool = True, dtype: str = 'float64'
) -> Tensor: ) -> Tensor:
...@@ -245,6 +279,7 @@ def _exponential( ...@@ -245,6 +279,7 @@ def _exponential(
return _truncate(w, needs_trunc) return _truncate(w, needs_trunc)
@window_function_register.register()
def _triang(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: def _triang(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a triangular window.""" """Compute a triangular window."""
if _len_guards(M): if _len_guards(M):
...@@ -262,6 +297,7 @@ def _triang(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: ...@@ -262,6 +297,7 @@ def _triang(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
return _truncate(w, needs_trunc) return _truncate(w, needs_trunc)
@window_function_register.register()
def _bohman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: def _bohman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a Bohman window. """Compute a Bohman window.
The Bohman window is the autocorrelation of a cosine window. The Bohman window is the autocorrelation of a cosine window.
...@@ -279,6 +315,7 @@ def _bohman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: ...@@ -279,6 +315,7 @@ def _bohman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
return _truncate(w, needs_trunc) return _truncate(w, needs_trunc)
@window_function_register.register()
def _blackman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: def _blackman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a Blackman window. """Compute a Blackman window.
The Blackman window is a taper formed by using the first three terms of The Blackman window is a taper formed by using the first three terms of
...@@ -289,6 +326,7 @@ def _blackman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: ...@@ -289,6 +326,7 @@ def _blackman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
return _general_cosine(M, [0.42, 0.50, 0.08], sym, dtype=dtype) return _general_cosine(M, [0.42, 0.50, 0.08], sym, dtype=dtype)
@window_function_register.register()
def _cosine(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: def _cosine(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor:
"""Compute a window with a simple cosine shape.""" """Compute a window with a simple cosine shape."""
if _len_guards(M): if _len_guards(M):
...@@ -308,7 +346,7 @@ def get_window( ...@@ -308,7 +346,7 @@ def get_window(
"""Return a window of a given length and type. """Return a window of a given length and type.
Args: Args:
window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'general_gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'.
win_length (int): Number of samples. win_length (int): Number of samples.
fftbins (bool, optional): If True, create a "periodic" window. Otherwise, create a "symmetric" window, for use in filter design. Defaults to True. fftbins (bool, optional): If True, create a "periodic" window. Otherwise, create a "symmetric" window, for use in filter design. Defaults to True.
dtype (str, optional): The data type of the return window. Defaults to 'float64'. dtype (str, optional): The data type of the return window. Defaults to 'float64'.
...@@ -348,8 +386,8 @@ def get_window( ...@@ -348,8 +386,8 @@ def get_window(
) )
try: try:
winfunc = eval('_' + winstr) winfunc = window_function_register.get('_' + winstr)
except NameError as e: except KeyError as e:
raise ValueError("Unknown window type.") from e raise ValueError("Unknown window type.") from e
params = (win_length,) + args params = (win_length,) + args
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册