diff --git a/python/paddle/audio/functional/window.py b/python/paddle/audio/functional/window.py index 844e2fc26335f98b75ff5d24dfd5b446fe13ed39..315d5a50a323f7f8f34b80954d884302cd77bf40 100644 --- a/python/paddle/audio/functional/window.py +++ b/python/paddle/audio/functional/window.py @@ -19,17 +19,39 @@ import paddle from paddle import Tensor +class WindowFunctionRegister(object): + def __init__(self): + self._functions_dict = dict() + + def register(self, func=None): + def add_subfunction(func): + name = func.__name__ + self._functions_dict[name] = func + return func + + return add_subfunction + + def get(self, name): + return self._functions_dict[name] + + +window_function_register = WindowFunctionRegister() + + +@window_function_register.register() def _cat(x: List[Tensor], data_type: str) -> Tensor: l = [paddle.to_tensor(_, data_type) for _ in x] return paddle.concat(l) +@window_function_register.register() def _acosh(x: Union[Tensor, float]) -> Tensor: if isinstance(x, float): return math.log(x + math.sqrt(x**2 - 1)) return paddle.log(x + paddle.sqrt(paddle.square(x) - 1)) +@window_function_register.register() def _extend(M: int, sym: bool) -> bool: """Extend window by 1 sample if needed for DFT-even symmetry.""" if not sym: @@ -38,6 +60,7 @@ def _extend(M: int, sym: bool) -> bool: return M, False +@window_function_register.register() def _len_guards(M: int) -> bool: """Handle small or incorrect window lengths.""" if int(M) != M or M < 0: @@ -46,6 +69,7 @@ def _len_guards(M: int) -> bool: return M <= 1 +@window_function_register.register() def _truncate(w: Tensor, needed: bool) -> Tensor: """Truncate window by 1 sample if needed for DFT-even symmetry.""" if needed: @@ -54,6 +78,7 @@ def _truncate(w: Tensor, needed: bool) -> Tensor: return w +@window_function_register.register() def _general_gaussian( M: int, p, sig, sym: bool = True, dtype: str = 'float64' ) -> Tensor: @@ -70,6 +95,7 @@ def _general_gaussian( return _truncate(w, needs_trunc) +@window_function_register.register() def _general_cosine( M: int, a: float, sym: bool = True, dtype: str = 'float64' ) -> Tensor: @@ -86,6 +112,7 @@ def _general_cosine( return _truncate(w, needs_trunc) +@window_function_register.register() def _general_hamming( M: int, alpha: float, sym: bool = True, dtype: str = 'float64' ) -> Tensor: @@ -95,6 +122,7 @@ def _general_hamming( return _general_cosine(M, [alpha, 1.0 - alpha], sym, dtype=dtype) +@window_function_register.register() def _taylor( M: int, nbar=4, sll=30, norm=True, sym: bool = True, dtype: str = 'float64' ) -> Tensor: @@ -151,6 +179,7 @@ def _taylor( return _truncate(w, needs_trunc) +@window_function_register.register() def _hamming(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: """Compute a Hamming window. The Hamming window is a taper formed by using a raised cosine with @@ -159,6 +188,7 @@ def _hamming(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: return _general_hamming(M, 0.54, sym, dtype=dtype) +@window_function_register.register() def _hann(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: """Compute a Hann window. The Hann window is a taper formed by using a raised cosine or sine-squared @@ -167,6 +197,7 @@ def _hann(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: return _general_hamming(M, 0.5, sym, dtype=dtype) +@window_function_register.register() def _tukey( M: int, alpha=0.5, sym: bool = True, dtype: str = 'float64' ) -> Tensor: @@ -200,6 +231,7 @@ def _tukey( return _truncate(w, needs_trunc) +@window_function_register.register() def _kaiser( M: int, beta: float, sym: bool = True, dtype: str = 'float64' ) -> Tensor: @@ -209,6 +241,7 @@ def _kaiser( raise NotImplementedError() +@window_function_register.register() def _gaussian( M: int, std: float, sym: bool = True, dtype: str = 'float64' ) -> Tensor: @@ -226,6 +259,7 @@ def _gaussian( return _truncate(w, needs_trunc) +@window_function_register.register() def _exponential( M: int, center=None, tau=1.0, sym: bool = True, dtype: str = 'float64' ) -> Tensor: @@ -245,6 +279,7 @@ def _exponential( return _truncate(w, needs_trunc) +@window_function_register.register() def _triang(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: """Compute a triangular window.""" if _len_guards(M): @@ -262,6 +297,7 @@ def _triang(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: return _truncate(w, needs_trunc) +@window_function_register.register() def _bohman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: """Compute a Bohman window. The Bohman window is the autocorrelation of a cosine window. @@ -279,6 +315,7 @@ def _bohman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: return _truncate(w, needs_trunc) +@window_function_register.register() def _blackman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: """Compute a Blackman window. The Blackman window is a taper formed by using the first three terms of @@ -289,6 +326,7 @@ def _blackman(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: return _general_cosine(M, [0.42, 0.50, 0.08], sym, dtype=dtype) +@window_function_register.register() def _cosine(M: int, sym: bool = True, dtype: str = 'float64') -> Tensor: """Compute a window with a simple cosine shape.""" if _len_guards(M): @@ -308,7 +346,7 @@ def get_window( """Return a window of a given length and type. Args: - window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. + window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'general_gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. win_length (int): Number of samples. fftbins (bool, optional): If True, create a "periodic" window. Otherwise, create a "symmetric" window, for use in filter design. Defaults to True. dtype (str, optional): The data type of the return window. Defaults to 'float64'. @@ -348,8 +386,8 @@ def get_window( ) try: - winfunc = eval('_' + winstr) - except NameError as e: + winfunc = window_function_register.get('_' + winstr) + except KeyError as e: raise ValueError("Unknown window type.") from e params = (win_length,) + args