From ad8856aa06cab2874cebe4877017925c8ac08507 Mon Sep 17 00:00:00 2001 From: ranchlai Date: Fri, 16 Jul 2021 11:27:59 +0800 Subject: [PATCH] Update doc string with examples/shapes, controllable backends, and some bug fixed (#5324) * added sound classication * added liscense, clean code, add pre-commit * update req * moved to PaddlePaddle-models * code re-structure * update README.md * update README.md * Update README.md * add audioset training * default resample mode to kaiser_fast * delete some comments * precommit check * sha->rev * add config.ymal * remove SoundClassification from paddlespeech, since it's in PaddleAudio now * add labels * remove old labels * update code * empty * #5300 * add evaluate, etc * remove trace| * import evaluate * path update * precommit check * recover slowfast * restore README.md to paddle:develop * refactor * update readme * update README.md * refactor * refactor * refactor * refactor * precommit fixed * update README.md * Update README.md * Update README.md * Update train.py changed prefixed, removed some comments * add wav file for testing * bug fixed eval,new checkpoint map=0.416 * Update README.md * added dcase task1b example * update README.md * code fixed for last review * fixed level string formating * fixed according to PR reviews * added wav2vec2.0 * restore datatsets * add liscense, remove scipy, move test_audio to cloud * remove 3rd-party dependency:pathos * add testing for wav2vec2 * update README.md * updated README.md, added librispeech results * Revert "updated README.md, added librispeech results" This reverts commit da4012958e8e0bf2d7f4b608f74518583dd7d73b. 
* code fixed from reviews * add librispeech test * remove pathos imports * updated README.md * update README.md * minor-fix according to code reviews * updated README_LP.MD * fixed according to code review * fixed according to code review * added preprocessing example * removed dcase2021_task1b from examples * remove preprocessing from examples * added amsoftmax to losses * added eer/min_dcf to metrics * updated __init__.py * add stft,spectrogram, melspectrogram, log-melspectrogram * add _internal, transoform, functional to imports * add new module: functional * add new module: window.py to _internel/ * add correspoding new unit-test for the new modules * added ISTFT * clean code and docstring, clean unit test * clean code and docstring * functional * added back preprocessing * add README.md * remove preprocessing for now * clean code, add doc * change _internal to signal * add new transoforms * add new functionals * add eps to amsoftmax, return the prediction * add ffmpeg backend * remove dithering in depth-convert, add ffmpeg to backend * add Mudecode/enccode/RandomCodec * changed variable name, fixed bug * use namedtuple for returning * refactor utils * refactor * add melspectrogram/spectrogram, add doc string * add doc string, clean code * rename window to windowing * updated docstring, minor bug fixed * move losses.py to future examples * remove mu_encode/decode * refactor * move metrics to future examples * remove features/ * naming changes for mu law algorithms * update test, add testing utils * fixed import * fixed import * fixed duplicate output in logging * add code examples, shape info, etc * add doc for public functions * make backend controllable * fixed coding stype in docstring --- PaddleAudio/paddleaudio/backends/audio.py | 102 +++++-- PaddleAudio/paddleaudio/core/windowing.py | 130 ++++++-- PaddleAudio/paddleaudio/functional.py | 353 +++++++++++++++++++--- PaddleAudio/paddleaudio/transforms.py | 165 +++++++++- PaddleAudio/paddleaudio/utils/utils.py 
| 15 +- 5 files changed, 657 insertions(+), 108 deletions(-) diff --git a/PaddleAudio/paddleaudio/backends/audio.py b/PaddleAudio/paddleaudio/backends/audio.py index e90fade5..82179293 100644 --- a/PaddleAudio/paddleaudio/backends/audio.py +++ b/PaddleAudio/paddleaudio/backends/audio.py @@ -13,6 +13,8 @@ # limitations under the License. __all__ = [ + 'set_backend', + 'get_backends', 'resample', 'to_mono', 'depth_convert', @@ -36,16 +38,39 @@ from ._ffmpeg import DecodingError, FFmpegAudioFile NORMALMIZE_TYPES = ['linear', 'gaussian'] MERGE_TYPES = ['ch0', 'ch1', 'random', 'average'] RESAMPLE_MODES = ['kaiser_best', 'kaiser_fast'] +SUPPORT_BACKENDS = ['ffmpeg', 'soundfile'] + EPS = 1e-8 +BACK_END = None + + +def set_backend(backend: Union[str, None] = 'ffmpeg'): + """Set audio decoding backend. + Parameters: + backend(str|None): The name of the backend to use. If None, paddleaudio will + choose the optimal backend automatically. + + Notes: + Use get_backends() to get available backends. + + """ + global BACK_END + if backend and backend not in SUPPORT_BACKENDS: + raise ParameterError(f'Unsupported backend {backend} ,' + + f'supported backends are {SUPPORT_BACKENDS}') + BACK_END = backend + + +def get_backends(): + return SUPPORT_BACKENDS + def _safe_cast(y: array, dtype: Union[type, str]) -> array: """Data type casting in a safe way, i.e., prevent overflow or underflow. Notes: This function is used internally. """ - import pdb - pdb.set_trace() return np.clip(y, np.iinfo(dtype).min, np.iinfo(dtype).max).astype(dtype) @@ -80,8 +105,8 @@ def _sound_file_load(file: os.PathLike, offset: Optional[float] = None, dtype: str = 'int16', duration: Optional[int] = None) -> Tuple[array, int]: - """Load audio using soundfile library - This function load audio file using libsndfile. + """Load audio using soundfile library. + This function loads audio file using libsndfile. 
Reference: http://www.mega-nerd.com/libsndfile/#Features @@ -102,8 +127,8 @@ def _sound_file_load(file: os.PathLike, def _sox_file_load(): - """Load audio using sox library - This function load audio file using sox. + """Load audio using sox library. + This function loads audio file using sox. Reference: http://sox.sourceforge.net/ @@ -127,13 +152,13 @@ def depth_convert(y: array, dtype: Union[type, str]) -> array: SUPPORT_DTYPE = ['int16', 'int8', 'float32', 'float64'] if y.dtype not in SUPPORT_DTYPE: raise ParameterError( - f'Unsupported audio dtype, ' - 'y.dtype is {y.dtype}, supported dtypes are {SUPPORT_DTYPE}') + f'Unsupported audio dtype, ' + + f'y.dtype is {y.dtype}, supported dtypes are {SUPPORT_DTYPE}') if dtype not in SUPPORT_DTYPE: raise ParameterError( - f'Unsupported audio dtype, ' - 'target dtype is {dtype}, supported dtypes are {SUPPORT_DTYPE}') + f'Unsupported audio dtype, ' + + f'target dtype is {dtype}, supported dtypes are {SUPPORT_DTYPE}') if dtype == y.dtype: return y @@ -171,21 +196,22 @@ def resample(y: array, src_sr: int, target_sr: int, mode: str = 'kaiser_fast') -> array: - """Apply resampling to the input audio array. + """Apply audio resampling to the input audio array. Notes: 1. This function uses resampy.resample to do the resampling. 2. The default mode is kaiser_fast. For better audio quality, - use mode = 'kaiser_fast' + use mode = 'kaiser_best' """ if mode == 'kaiser_best': warnings.warn( - f'Using resampy in kaiser_best to {src_sr}=>{target_sr}. This function is pretty slow, \ - we recommend the mode kaiser_fast in large scale audio trainning') + f'Using resampy in kaiser_best to {src_sr}=>{target_sr}.' 
+ + f'This function is pretty slow, ' + + f'we recommend the mode kaiser_fast in large scale audio training') if not isinstance(y, np.ndarray): - raise ParameterError( - 'Only support numpy array, but received y in {type(y)}') + raise TypeError( + f'Only support numpy array, but received y in {type(y)}') if mode not in RESAMPLE_MODES: raise ParameterError(f'resample mode must in {RESAMPLE_MODES}') @@ -193,7 +219,7 @@ def resample(y: array, return resampy.resample(y, src_sr, target_sr, filter=mode) -def to_mono(y: array, merge_type: str = 'average') -> array: +def to_mono(y: array, merge_type: str = 'ch0') -> array: """Convert stereo audio to mono audio. Parameters: y(array): the input audio array of shape [2,n], where n is the number of audio samples. @@ -229,17 +255,14 @@ def to_mono(y: array, merge_type: str = 'average') -> array: # need to do averaging according to dtype if y.dtype == 'float32': - y_out = (y[0] + y[1]) * 0.5 + y_out = y.mean(0) elif y.dtype == 'int16': - y_out = y.astype('int32') - y_out = (y_out[0] + y_out[1]) // 2 + y_out = y.mean(0) y_out = np.clip(y_out, np.iinfo(y.dtype).min, np.iinfo(y.dtype).max).astype(y.dtype) - elif y.dtype == 'int8': - y_out = y.astype('int16') - y_out = (y_out[0] + y_out[1]) // 2 + y_out = y.mean(0) y_out = np.clip(y_out, np.iinfo(y.dtype).min, np.iinfo(y.dtype).max).astype(y.dtype) @@ -293,6 +316,11 @@ def save_wav(y: array, sr: int, file: os.PathLike) -> None: Notes: The function only supports raw wav format. """ + if y.ndim == 2 and y.shape[0] > y.shape[1]: + warnings.warn( + f'The audio array tried to saved has {y.shape[0]} channels ' + + f'and the wave length is {y.shape[1]}. It\'s that what you mean?' 
+ + f'If not, try to tranpose the array before saving.') if not file.endswith('.wav'): raise ParameterError( f'only .wav file supported, but dst file name is: {file}') @@ -309,7 +337,7 @@ def save_wav(y: array, sr: int, file: os.PathLike) -> None: else: y_out = y - wavfile.write(file, sr, y_out) + wavfile.write(file, sr, y_out.T) def load( @@ -337,7 +365,7 @@ def load( if it is originally steore. See to_mono() for more details. The default value is True. merge_type(str): the merging algorithm. See to_mono() for more details. - The default value is 'average'. + The default value is 'ch0'. normal(bool): whether to normalize the audio waveform. If True, the audio will be normalized using algorithm specified in norm_type. See normalize() for more details. The default value is True. @@ -360,19 +388,27 @@ def load( DecodingError, if audio file is not supported """ - try: + if BACK_END == 'ffmpeg': + y, r = _ffmpeg_load(file, offset=offset, duration=duration) + elif BACK_END == 'soundfile': y, r = _sound_file_load(file, offset=offset, dtype=dtype, duration=duration) - except FileNotFoundError: - raise FileNotFoundError( - f'Trying to load a file that doesnot exist {file}') - except: + else: try: - y, r = _ffmpeg_load(file, offset=offset, duration=duration) - except DecodingError: - raise DecodingError(f'Failed to load and decode file {file}') + y, r = _sound_file_load(file, + offset=offset, + dtype=dtype, + duration=duration) + except FileNotFoundError: + raise FileNotFoundError( + f'Trying to load a file that doesnot exist {file}') + except: + try: + y, r = _ffmpeg_load(file, offset=offset, duration=duration) + except DecodingError: + raise DecodingError(f'Failed to load and decode file {file}') if not ((y.ndim == 1 and len(y) > 0) or (y.ndim == 2 and len(y[0]) > 0)): return np.array([], dtype=dtype) # return empty audio diff --git a/PaddleAudio/paddleaudio/core/windowing.py b/PaddleAudio/paddleaudio/core/windowing.py index 0fc2cc24..62ffdc32 100644 --- 
a/PaddleAudio/paddleaudio/core/windowing.py +++ b/PaddleAudio/paddleaudio/core/windowing.py @@ -30,7 +30,7 @@ __all__ = [ 'tukey', 'taylor', ] -_PI = 3.141592653589793 +math.pi = 3.141592653589793 def _cat(a: List[Tensor], data_type: str) -> Tensor: @@ -93,11 +93,17 @@ def general_hamming(M: int, alpha: float, sym: bool = True) -> Tensor: def taylor(M: int, nbar=4, sll=30, norm=True, sym: bool = True) -> Tensor: """Compute a Taylor window. - The Taylor window taper function approximates the Dolph-Chebyshev window's constant sidelobe level for a parameterized number of near-in sidelobes. - - This function is consistent with scipy.signal.windows.taylor(). + Parameters: + M(int): window size + nbar, sil, norm: the window-specific parameter. + sym(bool):whether to return symmetric window. + The default value is True + Returns: + Tensor: the window tensor + Notes: + This function is consistent with scipy.signal.windows.taylor(). """ if _len_guards(M): return paddle.ones((M, ), dtype='float32') @@ -106,7 +112,7 @@ def taylor(M: int, nbar=4, sll=30, norm=True, sym: bool = True) -> Tensor: # it in the calculation of B. To keep consistent with other methods we # assume the sidelobe level parameter to be positive. B = 10**(sll / 20) - A = _acosh(B) / _PI + A = _acosh(B) / math.pi s2 = nbar**2 / (A**2 + (nbar - 0.5)**2) ma = paddle.arange(1, nbar, dtype='float32') @@ -131,7 +137,7 @@ def taylor(M: int, nbar=4, sll=30, norm=True, sym: bool = True) -> Tensor: def W(n): return 1 + 2 * paddle.matmul( Fm.unsqueeze(0), - paddle.cos(2 * _PI * ma.unsqueeze(1) * (n - M / 2. + 0.5) / M)) + paddle.cos(2 * math.pi * ma.unsqueeze(1) * (n - M / 2. 
+ 0.5) / M)) w = W(paddle.arange(0, M, dtype='float32')) @@ -151,7 +157,7 @@ def general_cosine(M: int, a: float, sym: bool = True) -> Tensor: if _len_guards(M): return paddle.ones((M, ), dtype='float32') M, needs_trunc = _extend(M, sym) - fac = paddle.linspace(-_PI, _PI, M) + fac = paddle.linspace(-math.pi, math.pi, M) w = paddle.zeros((M, ), dtype='float32') for k in range(len(a)): w += a[k] * paddle.cos(k * fac) @@ -162,8 +168,14 @@ def hamming(M: int, sym: bool = True) -> Tensor: """Compute a Hamming window. The Hamming window is a taper formed by using a raised cosine with non-zero endpoints, optimized to minimize the nearest side lobe. - - This function is consistent with scipy.signal.windows.hamming(). + Parameters: + M(int): window size + sym(bool):whether to return symmetric window. + The default value is True + Returns: + Tensor: the window tensor + Notes: + This function is consistent with scipy.signal.windows.hamming(). """ return general_hamming(M, 0.54, sym) @@ -172,8 +184,14 @@ def hann(M: int, sym: bool = True) -> Tensor: """Compute a Hann window. The Hann window is a taper formed by using a raised cosine or sine-squared with ends that touch zero. - - This function is consistent with scipy.signal.windows.hann(). + Parameters: + M(int): window size + sym(bool):whether to return symmetric window. + The default value is True + Returns: + Tensor: the window tensor + Notes: + This function is consistent with scipy.signal.windows.hann(). """ return general_hamming(M, 0.5, sym) @@ -181,8 +199,14 @@ def hann(M: int, sym: bool = True) -> Tensor: def tukey(M: int, alpha=0.5, sym: bool = True) -> Tensor: """Compute a Tukey window. The Tukey window is also known as a tapered cosine window. - - This function is consistent with scipy.signal.windows.tukey(). + Parameters: + M(int): window size + sym(bool):whether to return symmetric window. 
+ The default value is True + Returns: + Tensor: the window tensor + Notes: + This function is consistent with scipy.signal.windows.tukey(). """ if _len_guards(M): return paddle.ones((M, ), dtype='float32') @@ -200,10 +224,10 @@ def tukey(M: int, alpha=0.5, sym: bool = True) -> Tensor: n2 = n[width + 1:M - width - 1] n3 = n[M - width - 1:] - w1 = 0.5 * (1 + paddle.cos(_PI * (-1 + 2.0 * n1 / alpha / (M - 1)))) + w1 = 0.5 * (1 + paddle.cos(math.pi * (-1 + 2.0 * n1 / alpha / (M - 1)))) w2 = paddle.ones(n2.shape, dtype='float32') - w3 = 0.5 * (1 + paddle.cos(_PI * (-2.0 / alpha + 1 + 2.0 * n3 / alpha / - (M - 1)))) + w3 = 0.5 * (1 + paddle.cos(math.pi * (-2.0 / alpha + 1 + 2.0 * n3 / alpha / + (M - 1)))) w = paddle.concat([w1, w2, w3]) return _truncate(w, needs_trunc) @@ -212,6 +236,15 @@ def tukey(M: int, alpha=0.5, sym: bool = True) -> Tensor: def kaiser(M: int, beta: float, sym: bool = True) -> Tensor: """Compute a Kaiser window. The Kaiser window is a taper formed by using a Bessel function. + Parameters: + M(int): window size. + beta(float): the window-specific parameter. + sym(bool):whether to return symmetric window. + The default value is True + Returns: + Tensor: the window tensor + Notes: + This function is consistent with scipy.signal.windows.kaiser(). """ @@ -222,7 +255,15 @@ def gaussian(M: int, std: float, sym: bool = True) -> Tensor: """Compute a Gaussian window. The Gaussian widows has a Gaussian shape defined by the standard deviation(std). - This function is consistent with scipy.signal.windows.gaussian(). + Parameters: + M(int): window size. + std(float): the window-specific parameter. + sym(bool):whether to return symmetric window. + The default value is True + Returns: + Tensor: the window tensor + Notes: + This function is consistent with scipy.signal.windows.gaussian(). 
""" if _len_guards(M): return paddle.ones((M, ), dtype='float32') @@ -237,8 +278,15 @@ def gaussian(M: int, std: float, sym: bool = True) -> Tensor: def exponential(M: int, center=None, tau=1., sym: bool = True) -> Tensor: """Compute an exponential (or Poisson) window. - - This function is consistent with scipy.signal.windows.exponential(). + Parameters: + M(int): window size. + tau(float): the window-specific parameter. + sym(bool):whether to return symmetric window. + The default value is True + Returns: + Tensor: the window tensor + Notes: + This function is consistent with scipy.signal.windows.exponential(). """ if sym and center is not None: raise ValueError("If sym==True, center must be None.") @@ -257,8 +305,14 @@ def exponential(M: int, center=None, tau=1., sym: bool = True) -> Tensor: def triang(M: int, sym: bool = True) -> Tensor: """Compute a triangular window. - - This function is consistent with scipy.signal.windows.triang(). + Parameters: + M(int): window size. + sym(bool):whether to return symmetric window. + The default value is True + Returns: + Tensor: the window tensor + Notes: + This function is consistent with scipy.signal.windows.triang(). """ if _len_guards(M): return paddle.ones((M, ), dtype='float32') @@ -278,15 +332,22 @@ def triang(M: int, sym: bool = True) -> Tensor: def bohman(M: int, sym: bool = True) -> Tensor: """Compute a Bohman window. The Bohman window is the autocorrelation of a cosine window. - - This function is consistent with scipy.signal.windows.bohman(). + Parameters: + M(int): window size. + sym(bool):whether to return symmetric window. + The default value is True + Returns: + Tensor: the window tensor + Notes: + This function is consistent with scipy.signal.windows.bohman(). 
""" if _len_guards(M): return paddle.ones((M, ), dtype='float32') M, needs_trunc = _extend(M, sym) fac = paddle.abs(paddle.linspace(-1, 1, M)[1:-1]) - w = (1 - fac) * paddle.cos(_PI * fac) + 1.0 / _PI * paddle.sin(_PI * fac) + w = (1 - fac) * paddle.cos(math.pi * fac) + 1.0 / math.pi * paddle.sin( + math.pi * fac) w = _cat([0, w, 0], 'float32') return _truncate(w, needs_trunc) @@ -299,19 +360,32 @@ def blackman(M: int, sym: bool = True) -> Tensor: leakage possible. It is close to optimal, only slightly worse than a Kaiser window. - This function is consistent with scipy.signal.windows.blackman(). + Parameters: + M(int): window size. + sym(bool):whether to return symmetric window. + The default value is True + Returns: + Tensor: the window tensor + Notes: + This function is consistent with scipy.signal.windows.blackman(). """ return general_cosine(M, [0.42, 0.50, 0.08], sym) def cosine(M: int, sym: bool = True) -> Tensor: """Compute a window with a simple cosine shape. - - This function is consistent with scipy.signal.windows.cosine(). + Parameters: + M(int): window size. + sym(bool):whether to return symmetric window. + The default value is True + Returns: + Tensor: the window tensor + Notes: + This function is consistent with scipy.signal.windows.cosine(). """ if _len_guards(M): return paddle.ones((M, ), dtype='float32') M, needs_trunc = _extend(M, sym) - w = paddle.sin(_PI / M * (paddle.arange(0, M) + .5)) + w = paddle.sin(math.pi / M * (paddle.arange(0, M) + .5)) return _truncate(w, needs_trunc) diff --git a/PaddleAudio/paddleaudio/functional.py b/PaddleAudio/paddleaudio/functional.py index 4b6dff04..a9c95e91 100644 --- a/PaddleAudio/paddleaudio/functional.py +++ b/PaddleAudio/paddleaudio/functional.py @@ -59,14 +59,24 @@ def complex_norm(x: Tensor) -> Tensor: """Compute compext norm of a given tensor. Typically, the input tensor is the result of a complex Fourier transform. 
Parameters: - x(Tensor): The input tensor of shape [..., 2] + x(Tensor): The input tensor of shape (..., 2) Returns: The element-wised l2-norm of input complex tensor. + Examples: + + .. code-block:: python + + x = paddle.rand((32, 16000)) + y = F.stft(x, n_fft=512) + z = F.complex_norm(y) + print(z.shape) + >> [32, 257, 126] + """ if x.shape[-1] != 2: raise ParameterError( - f'complex tensor must be of shape [..., 2], but received {x.shape} instead' + f'complex tensor must be of shape (..., 2), but received {x.shape} instead' ) return paddle.sqrt(paddle.square(x).sum(axis=-1)) @@ -75,13 +85,27 @@ def magphase(x: Tensor) -> Tuple[Tensor, Tensor]: """Compute compext norm of a given tensor. Typically,the input tensor is the result of a complex Fourier transform. Parameters: - x(Tensor): The input tensor of shape [..., 2]. + x(Tensor): The input tensor of shape (..., 2). Returns: The tuple of magnitude and phase. + + Shape: + x: the shape of x is arbitrary, with the shape of last axis being 2 + outputs: the shapes of magnitude and phase are both input.shape[:-1] + + Examples: + + .. code-block:: python + + import paddle + import paddleaudio.functional as F + x = paddle.randn((10, 10, 2)) + angle, phase = F.magphase(x) + """ if x.shape[-1] != 2: raise ParameterError( - f'complex tensor must be of shape [..., 2], but received {x.shape} instead' + f'complex tensor must be of shape (..., 2), but received {x.shape} instead' ) mag = paddle.sqrt(paddle.square(x).sum(axis=-1)) x0 = x.reshape((-1, 2)) @@ -91,7 +115,8 @@ def magphase(x: Tensor) -> Tuple[Tensor, Tensor]: return mag, phase -def hz_to_mel(freq: Union[Tensor, float], htk: bool = False) -> float: +def hz_to_mel(freq: Union[Tensor, float], + htk: bool = False) -> Union[Tensor, float]: """Convert Hz to Mels. Parameters: @@ -102,6 +127,19 @@ def hz_to_mel(freq: Union[Tensor, float], htk: bool = False) -> float: The frequencies represented in Mel-scale. Notes: This function is consistent with librosa.hz_to_mel(). 
+ + Examples: + + .. code-block:: python + + import paddle + import paddleaudio.functional as F + print(F.hz_to_mel(10)) + >> 10 + print(F.hz_to_mel(paddle.to_tensor([0, 100, 1600]))) + >> Tensor(shape=[3], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + [0., 1.50000000, 21.83624077]) + """ if htk: @@ -135,7 +173,8 @@ def hz_to_mel(freq: Union[Tensor, float], htk: bool = False) -> float: return mels -def mel_to_hz(mel: Union[float, Tensor], htk: bool = False) -> Tensor: +def mel_to_hz(mel: Union[float, Tensor], + htk: bool = False) -> Union[float, Tensor]: """Convert mel bin numbers to frequencies. Parameters: @@ -145,6 +184,19 @@ def mel_to_hz(mel: Union[float, Tensor], htk: bool = False) -> Tensor: The frequencies represented in hz. Notes: This function is consistent with librosa.mel_to_hz(). + + Examples: + + .. code-block:: python + + import paddle + import paddleaudio.functional as F + print(F.mel_to_hz(10)) + >> 666.6666666666667 + print(F.mel_to_hz(paddle.to_tensor([0, 1.0, 10.0]))) + >> Tensor(shape=[3], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + [0., 66.66666412, 666.66662598]) + """ if htk: return 700.0 * (10.0**(mel / 2595.0) - 1.0) @@ -171,7 +223,7 @@ def mel_to_hz(mel: Union[float, Tensor], htk: bool = False) -> Tensor: def mel_frequencies(n_mels: int = 128, f_min: float = 0.0, f_max: float = 11025.0, - htk: bool = False): + htk: bool = False) -> Tensor: """Compute mel frequencies. Parameters: @@ -181,8 +233,21 @@ def mel_frequencies(n_mels: int = 128, htk: whether to use htk formula. Returns: The frequencies represented in Mel-scale + Notes: This function is consistent with librosa.mel_frequencies(). + + Examples: + + .. 
code-block:: python + + import paddle + import paddleaudio.functional as F + print(F.mel_frequencies(8)) + >> Tensor(shape=[8], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + [0., 475.33898926, 950.67797852, 1551.68481445, 2533.36230469, + 4136.09960938, 6752.81396484, 11024.99902344]) + """ # 'Center freqs' of mel bands - uniformly spaced between limits min_mel = hz_to_mel(f_min, htk=htk) @@ -200,8 +265,18 @@ def fft_frequencies(sr: int, n_fft: int) -> Tensor: n_fft(float): he number of fft bins. Returns: The frequencies represented in hz. + Notes: + This function is consistent with librosa.fft_frequencies(). + + Examples: + .. code-block:: python + + import paddle + import paddleaudio.functional as F + print(F.fft_frequencies(16000, 512)) + >> Tensor(shape=[257], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + [0., 31.25000000, 62.50000000, ...] - This function is consistent with librosa.fft_frequencies(). """ return paddle.linspace(0, float(sr) / 2, int(1 + n_fft // 2)) @@ -211,7 +286,7 @@ def compute_fbank_matrix(sr: int, n_mels: int = 128, f_min: float = 0.0, f_max: Optional[float] = None, - htk: bool = False): + htk: bool = False) -> Tensor: """Compute fbank matrix. Parameters: @@ -225,9 +300,22 @@ def compute_fbank_matrix(sr: int, be complex type. Otherwise, the real and image part will be stored in the last axis of returned tensor. Returns: - The fbank matrix of shape [n_mels, int(1+n_fft//2)]. + The fbank matrix of shape (n_mels, int(1+n_fft//2)). + Shape: + output: (n_mels, int(1+n_fft//2)) Notes: This function is consistent with librosa.filters.mel(). + + Examples: + + .. code-block:: python + + import paddle + import paddleaudio.functional as F + m = F.compute_fbank_matrix(16000, 512) + print(m.shape) + >>[128, 257] + """ if f_max is None: @@ -270,15 +358,34 @@ def dft_matrix(n: int, return_complex: bool = False) -> Tensor: return_complex(bool): whether to return complex matrix. If True, the matrix will be complex type. 
Otherwise, the real and image part will be stored in the last axis of returned tensor. + Shape: + output: [n, n] or [n,n,2] + Returns: - Complex tensor of shape [n,n] if return_complex=True, and of shape [n,n,2] otherwise. + Complex tensor of shape (n,n) if return_complex=True, and of shape (n,n,2) otherwise. + + Examples: + + .. code-block:: python + + import paddle + import paddleaudio.functional as F + m = F.dft_matrix(512) + print(m.shape) + >> [512, 512, 2] + m = F.dft_matrix(512, return_complex=True) + print(m.shape) + >> [512, 512] + """ x, y = paddle.meshgrid(paddle.arange(0, n), paddle.arange(0, n)) z = x * y * (-2 * math.pi / n) - cos = paddle.cos(z).unsqueeze(-1) - sin = paddle.sin(z).unsqueeze(-1) + cos = paddle.cos(z) + sin = paddle.sin(z) if return_complex: return cos + paddle.to_tensor([1j]) * sin + cos = cos.unsqueeze(-1) + sin = sin.unsqueeze(-1) return paddle.concat([cos, sin], -1) @@ -291,15 +398,30 @@ def idft_matrix(n: int, return_complex: bool = False) -> Tensor: be complex type. Otherwise, the real and image part will be stored in the last axis of returned tensor. Returns: - Complex tensor of shape [n,n] if return_complex=True, and of shape [n,n,2] otherwise. + Complex tensor of shape (n,n) if return_complex=True, and of shape (n,n,2) otherwise. + Examples: + + .. code-block:: python + + import paddle + import paddleaudio.functional as F + m = F.dft_matrix(512) + print(m.shape) + >> [512, 512, 2] + m = F.dft_matrix(512, return_complex=True) + print(m.shape) + >> [512, 512] + """ x, y = paddle.meshgrid(paddle.arange(0, n), paddle.arange(0, n)) z = x * y * (2 * math.pi / n) - cos = paddle.cos(z).unsqueeze(-1) - sin = paddle.sin(z).unsqueeze(-1) + cos = paddle.cos(z) + sin = paddle.sin(z) if return_complex: return cos + paddle.to_tensor([1j]) * sin + cos = cos.unsqueeze(-1) + sin = sin.unsqueeze(-1) return paddle.concat([cos, sin], -1) @@ -316,6 +438,16 @@ def get_window(window: Union[str, Tuple[str, float]], The window represented as a tensor. 
Notes: This functional is consistent with scipy.signal.get_window() + Examples: + + .. code-block:: python + + import paddle + import paddleaudio.functional as F + w = F.get_window('hann', win_length=128) + print(w.shape) + >> [128] + """ sym = not fftbins @@ -363,8 +495,22 @@ def power_to_db(magnitude: Tensor, spectrum is clipped(to to_db). Returns: The spectrogram in log-scale. + shape: + input: any shape + output: same as input Notes: This function is consistent with librosa.power_to_db(). + Examples: + + .. code-block:: python + + import paddle + import paddleaudio.functional as F + F.power_to_db(paddle.rand((10, 10))) + >> Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + [[-6.22858429, -3.51512218], + [-0.38168561, -1.44466150]]) + """ if amin <= 0: raise ParameterError("amin must be strictly positive") @@ -397,9 +543,18 @@ def mu_law_encode(x: Tensor, mu: int = 256, quantized: bool = True) -> Tensor: clip to be in range [0,mu-1]. quantized(bool): indicate whether the signal will quantized to integers. + Examples: + .. code-block:: python + + import paddle + import paddleaudio.functional as F + F.mu_law_encode(paddle.randn((2, 8))) + >> Tensor(shape=[2, 8], dtype=int32, place=CUDAPlace(0), stop_gradient=True, + [[0, 5, 30, 255, 255, 255, 12, 13], + [0, 241, 8, 243, 7, 35, 84, 228]]) + Reference: https://en.wikipedia.org/wiki/%CE%9C-law_algorithm - """ mu = mu - 1 y = paddle.sign(x) * paddle.log1p(mu * paddle.abs(x)) / math.log1p(mu) @@ -420,14 +575,29 @@ def mu_law_decode(x: Tensor, mu: int = 256, quantized: bool = True) -> Tensor: quantized(bool): whether the signal has been quantized to integers. The value should be the same as that used in mu_law_encode() + shape: + input: any shape + output: same as input Notes: This function assumes that the input x is in the - range [0,mu-1] when quantize is True and [-1,1] otherwise. + range [0,mu-1] when quantize is True and [-1,1] otherwise. + + + + Examples: + + .. 
code-block:: python + + import paddle + import paddleaudio.functional as F + F.mu_law_decode(paddle.randint(0, 255, shape=(2, 8))) + >> Tensor(shape=[2, 8], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + [[0.00796641, -0.28048742, -0.13789690, 0.67482352, -0.05550348, -0.00377374, 0.64593655, 0.03134083], + [0.45497340, -0.29312974, 0.29312995, -0.70499402, 0.51892924, -0.15078513, 0.07322186, 0.70499456]]) Reference: https://en.wikipedia.org/wiki/%CE%9C-law_algorithm - """ if mu < 1: raise ParameterError('mu is typically set as 2**k-1, k=1, 2, 3,...') @@ -452,7 +622,7 @@ def deframe(frames: Tensor, The frames are typically the output of inverse STFT that needs to be converted back to audio signals. Parameters: - frames(Tensor): the input audio frames of shape [N,n_fft,frame_number] or [n_fft,frame_number] + frames(Tensor): the input audio frames of shape (N,n_fft,frame_number) or (n_fft,frame_number) The frames are typically obtained from the output of inverse STFT. n_fft(int): the number of fft bins, see paddleaudio.functional.stft() hop_length(int): the hop length, see paddleaudio.functional.stft() @@ -466,7 +636,20 @@ def deframe(frames: Tensor, This function is implemented by transposing and reshaping. Shape: - - frames: 2-D tensor of shape [batch, signal_length]. + - input: (N,n_fft,frame_number] or (n_fft,frame_number) + - output: ( N, signal_length) + + Examples: + + .. code-block:: python + + import paddle + import paddleaudio.functional as F + x = paddle.rand((128, 200)) + x = F.deframe(x, n_fft=128, hop_length=64, win_length=200) + print(x.shape) + >> [128, 200] + """ assert frames.ndim == 2 or frames.ndim == 3, ( f'The input frame must be a 2-d or 3-d tensor, ' + @@ -514,6 +697,22 @@ def random_masking(x: Tensor, The default value is -1. Returns: Tensor: the tensor after masking. + Examples: + + .. 
code-block:: python + + x = paddle.rand((64, 100)) + x = F.random_masking(x, max_mask_count=10, max_mask_width=2, axis=0) + print((x[:, 0] == 0).astype('int32').sum()) + >> Tensor(shape=[1], dtype=int32, place=CUDAPlace(0), stop_gradient=True, + [5]) + + x = paddle.rand((64, 100)) + x = F.random_masking(x, max_mask_count=10, max_mask_width=2, axis=1) + print((x[0, :] == 0).astype('int32').sum()) + >> Tensor(shape=[1], dtype=int32, place=CUDAPlace(0), stop_gradient=True, + [8]) + """ assert x.ndim == 2 or x.ndim, (f'only supports 2d or 3d tensor, ' + @@ -564,7 +763,22 @@ def random_cropping(x: Tensor, target_size: int, axis=-1) -> Tensor: Returns: Tensor: the cropped tensor. If target_size >= x.shape[axis], the original input tensor is returned without cropping. + Examples: + + .. code-block:: python + + import paddle + import paddleaudio.functional as F + x = paddle.randn((2, 8)) + y = F.random_cropping(x, target_size=6) + print(y.shape) + >> [2, 6] + y = F.random_cropping(x, target_size=10) + print(y.shape) + >> [2, 8] # same as x + """ + assert axis < x.ndim, ('axis must be smaller than x.ndim, ' + f'but received aixs={axis},x.ndim={x.ndim}') @@ -592,8 +806,8 @@ def center_padding(x: Tensor, The function pads input x with pad_value to target_size along axis, such that output.shape[axis] == target_size Parameters: - x(Tensor): the input tesnor to apply padding in a central way. - target_size(int): the target lenght after padding. + x(Tensor): the input tensor to apply padding in a central way. + target_size(int): the target length after padding. axis(int):the axis along which to apply padding. The default value is -1. pad_value(int):the padding value. @@ -601,6 +815,16 @@ def center_padding(x: Tensor, Returns: Tensor: the padded tensor. If target_size <= x.shape[axis], the original input tensor is returned without padding. + Examples: + + .. 
code-block:: python + + import paddle + import paddleaudio.functional as F + x = F.center_padding(paddle.randn(([8, 10])), target_size=12, axis=1) + print(x.shape) + >> [8, 12] + """ assert axis < x.ndim, ('axis must be smaller than x.ndim, ' + f'but received aixs={axis},x.ndim={x.ndim}') @@ -659,18 +883,28 @@ def stft(x: Tensor, Otherwise, it will return the full spectrum that have n_fft+1 frequency values. The default value is True. Shape: - - x: 1-D tensor with shape: (signal_length,) or 2-D tensor with shape (batch, signal_length). - - output: 2-D tensor with shape [batch_size, freq_dim, frame_number,2], + - x: 1-D tensor with shape: (signal_length,) or 2-D tensor with shape (N, signal_length). + - output: 4-D tensor with shape (N, freq_dim, frame_number,2), where freq_dim = n_fft+1 if one_sided is False and n_fft//2+1 if True. - The batch_size is set to 1 if input singal x is 1D tensor. + The batch size N is set to 1 if input signal x is 1D tensor. Notes: This result of stft function is consistent with librosa.stft() for the default value setting. + Examples: + + .. code-block:: python + + import paddle + import paddleaudio.functional as F + x = F.stft(paddle.randn((8, 16000))) + print(x.shape) + >> [8, 1025, 32, 2] + """ assert x.ndim in [ 1, 2 ], (f'The input signal x must be a 1-d tensor for ' + 'non-batched signal or 2-d tensor for batched signal, ' + - f'but received ndim={input.ndim} instead') + f'but received ndim={x.ndim} instead') if x.ndim == 1: x = x.unsqueeze((0, 1)) @@ -715,7 +949,7 @@ def istft(x: Tensor, window: str = 'hann', center: bool = True, pad_mode: str = 'reflect', - signal_length: Optional[int] = None): + signal_length: Optional[int] = None) -> Tensor: """Compute inverse short-time Fourier transform(ISTFT) of a given spectrum signal x. To accurately recover the input signal, the exact value of parameters should match those used in stft. @@ -727,9 +961,26 @@ def istft(x: Tensor, and win_length.
The default value is None. Shape: - - x: 1-D tensor with shape: (signal_length,) or 2-D tensor with shape (batch, signal_length). - - output: the signal represented as a 2-D tensor with shape [batch_size, single_length] - The batch_size is set to 1 if input singal x is 1D tensor. + - x: 1-D tensor with shape: (signal_length,) or 2-D tensor with shape (N, signal_length). + - output: the signal represented as a 2-D tensor with shape (N, signal_length) + The batch size N is set to 1 if input signal x is 1D tensor. + + Examples: + .. code-block:: python + + import paddle + import paddleaudio.functional as F + x = paddle.rand((32, 16000)) + y = F.stft(x, n_fft=512) + print(x.shape) + >> [32, 16000] + z = F.istft(y, n_fft=512, signal_length=16000) + print(z.shape) + >> [32, 16000] + print((z-x).abs().mean()) + >> Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + [0.00000707]) + """ assert pad_mode in [ 'constant', 'reflect' @@ -795,7 +1046,7 @@ def spectrogram(x, window: str = 'hann', center: bool = True, pad_mode: str = 'reflect', - power: float = 2.0): + power: float = 2.0) -> Tensor: """Compute spectrogram of a given signal, typically an audio waveform. The spectorgram is defined as the complex norm of the short-time Fourier transformation. @@ -819,6 +1070,20 @@ def spectrogram(x, The default value is 'reflect'. power(float): The power of the complex norm. The default value is 2.0 + Shape: + - x: 1-D tensor with shape: (signal_length,) or 2-D tensor with shape (N, signal_length). + - output: 3-D tensor with shape (N, n_fft//2+1, frame_number), + The batch size N is set to 1 if input signal x is 1D tensor. + + Examples: + + ..
code-block:: python + + import paddle + import paddleaudio.functional as F + x = F.spectrogram(paddle.randn((8, 16000,))) + print(x.shape) + >> [8, 1025, 32] """ fft_signal = stft(x, @@ -850,7 +1115,7 @@ def melspectrogram(x: Tensor, f_min: float = 0.0, f_max: Optional[float] = None, to_db: bool = False, - **kwargs): + **kwargs) -> Tensor: """Compute the melspectrogram of a given signal, typically an audio waveform. The melspectrogram is also known as filterbank or fbank feature in audio community. It is computed by multiplying spectrogram with Mel filter bank matrix. @@ -886,14 +1151,30 @@ def melspectrogram(x: Tensor, smaller than half of sample rate. The default value is None. - to_db(bool): whether to convert the manitude to db scale. + to_db(bool): whether to convert the magnitude to db scale. The default value is False. kwargs: the key-word arguments that are passed to F.power_to_db if to_db is True + Shape: + - x: 1-D tensor with shape: (signal_length,) or 2-D tensor with shape (N, signal_length). + - output: 3-D tensor with shape (N, n_mels, frame_number), + The batch size N is set to 1 if input signal x is 1D tensor. + Notes: 1. The melspectrogram function relies on F.spectrogram and F.compute_fbank_matrix. - 2. The melspectrogram function r does not convert magnitude to db by default. - """ + 2. The melspectrogram function does not convert magnitude to db by default. + + Examples: + + ..
code-block:: python + + import paddle + import paddleaudio.functional as F + x = F.melspectrogram(paddle.randn((8, 16000,))) + print(x.shape) + >> [8, 128, 32] + + """ x = spectrogram(x, n_fft, hop_length, win_length, window, center, pad_mode, power) diff --git a/PaddleAudio/paddleaudio/transforms.py b/PaddleAudio/paddleaudio/transforms.py index d4b45798..7cb8edb1 100644 --- a/PaddleAudio/paddleaudio/transforms.py +++ b/PaddleAudio/paddleaudio/transforms.py @@ -30,8 +30,8 @@ __all__ = [ 'CenterPadding', 'RandomCropping', 'RandomMuLawCodec', - 'MuLawDecoding', 'MuLawEncoding', + 'MuLawDecoding', ] @@ -69,6 +69,19 @@ class STFT(nn.Layer): The batch_size is set to 1 if input singal x is 1D tensor. Notes: This result of stft transform is consistent with librosa.stft() for the default value setting. + + Examples: + + .. code-block:: python + + import paddle + import paddleaudio.transforms as T + stft = T.STFT(n_fft=512) + x = paddle.randn((8, 16000,)) + y = stft(x) + print(y.shape) + >> [8, 257, 126, 2] + """ def __init__(self, n_fft: int = 2048, @@ -186,8 +199,19 @@ class Spectrogram(nn.Layer): The Spectrogram transform relies on STFT transform to compute the spectrogram. By default, the weights are not learnable. To fine-tune the Fourier coefficients, set stop_gradient=False before training. + For more information, see STFT(). + + Examples: + + .. code-block:: python - For more information, see STFT(). + import paddle + import paddleaudio.transforms as T + spectrogram = T.Spectrogram(n_fft=512) + x = paddle.randn((8, 16000)) + y = spectrogram(x) + print(y.shape) + >> [8, 257, 126] """ super(Spectrogram, self).__init__() @@ -259,6 +283,17 @@ class MelSpectrogram(nn.Layer): By default, the Fourier coefficients are not learnable. To fine-tune the Fourier coefficients, set stop_gradient=False before training. The fbank matrix is handcrafted and not learnable regardless of the setting of stop_gradient. + Examples: + + .. 
code-block:: python + + import paddle + import paddleaudio.transforms as T + melspectrogram = T.MelSpectrogram(n_fft=512, n_mels=64) + x = paddle.randn((8, 16000,)) + y = melspectrogram(x) + print(y.shape) + >> [8, 64, 126] """ super(MelSpectrogram, self).__init__() @@ -339,6 +374,18 @@ class LogMelSpectrogram(nn.Layer): By default, the weights are not learnable. To fine-tune the Fourier coefficients, set stop_gradient=False before training. + Examples: + + .. code-block:: python + + import paddle + import paddleaudio.transforms as T + melspectrogram = T.LogMelSpectrogram(n_fft=512, n_mels=64) + x = paddle.randn((8, 16000,)) + y = melspectrogram(x) + print(y.shape) + >> [8, 64, 126] + """ super(LogMelSpectrogram, self).__init__() self._melspectrogram = MelSpectrogram(sr, n_fft, hop_length, win_length, @@ -388,6 +435,17 @@ class ISTFT(nn.Layer): - output: the signal represented as a 2-D tensor with shape [batch_size, single_length] The batch_size is set to 1 if input singal x is 1D tensor. + Examples: + + .. code-block:: python + + import paddle + import paddleaudio.transforms as T + melspectrogram = T.LogMelSpectrogram(n_fft=512, n_mels=64) + x = paddle.randn((8, 16000,)) + y = melspectrogram(x) + print(y.shape) + >> [8, 64, 126] """ def __init__(self, n_fft: int = 2048, @@ -479,6 +537,18 @@ class RandomMasking(nn.Layer): Notes: Please refer to paddleaudio.functional.random_masking() for more details. + Examples: + + .. code-block:: python + + import paddle + import paddleaudio.transforms as T + transform = T.RandomMasking(max_mask_count=10, max_mask_width=2, axis=1) + x = paddle.rand((64, 100)) + x = transform(x) + print((x[0, :] == 0).astype('int32').sum()) + >> Tensor(shape=[1], dtype=int32, place=CUDAPlace(0), stop_gradient=True, + [8]) """ def __init__(self, max_mask_count: int = 3, @@ -508,6 +578,21 @@ class Compose(): Parameters: transforms: a list of transforms. + Examples: + + .. 
code-block:: python + + import paddle + import paddleaudio.transforms as T + x = paddle.randn((2, 18000)) + transform = T.Compose([ + T.RandomCropping(target_size=16000), + T.MelSpectrogram(sr=16000, n_fft=256, n_mels=64), + T.RandomMasking() + ]) + y = transform(x) + print(y.shape) + >> [2, 64, 251] """ def __init__(self, transforms: List[Any]): @@ -537,6 +622,21 @@ class RandomCropping(nn.Layer): Notes: Please refer to paddleaudio.functional.RandomCropping() for more details. + Examples: + + .. code-block:: python + + import paddle + import paddleaudio.transforms as T + transform = T.RandomCropping(target_size=8, axis=1) + y = transform(x) + print(y.shape) + >> [64, 8] + transform = T.RandomCropping(target_size=100, axis=1) + y = transform(x) + print(y.shape) + >> [64, 100] + """ def __init__(self, target_size: int, axis: int = -1): super(RandomCropping, self).__init__() @@ -562,6 +662,18 @@ class CenterPadding(nn.Layer): Notes: Please refer to paddleaudio.functional.center_padding() for more details. + Examples: + + .. code-block:: python + + import paddle + import paddleaudio.transforms as T + x = paddle.rand((8, 10)) + transform = T.CenterPadding(target_size=12, axis=1) + y = transform(x) + print(y.shape) + >> [8, 12] + """ def __init__(self, target_size: int, axis: int = -1): super(CenterPadding, self).__init__() @@ -588,7 +700,21 @@ class MuLawEncoding(nn.Layer): the result will be converted to integer in range [0,mu-1]. Otherwise, the resulting signal is in range [-1,1] Notes: - Please refer to paddleaudio.functional.mu_encode() for more details. + Please refer to paddleaudio.functional.mu_law_encode() for more details. + + Examples: + + .. 
code-block:: python + + import paddle + import paddleaudio.transforms as T + x = paddle.randn((2,8)) + transform = T.MuLawEncoding() + y = transform(x) + print(y) + >> Tensor(shape=[2, 8], dtype=int32, place=CUDAPlace(0), stop_gradient=True, + [[0 , 252, 77 , 250, 221, 34 , 51 , 0 ], + [227, 33 , 0 , 255, 11 , 213, 255, 10 ]]) """ def __init__(self, mu: int = 256): @@ -597,7 +723,7 @@ class MuLawEncoding(nn.Layer): self.mu = mu def forward(self, x: Tensor) -> Tensor: - return F.mu_encode(x, mu=self.mu) + return F.mu_law_encode(x, mu=self.mu) def __repr__(self, ): return self.__class__.__name__ + f'(mu={self.mu})' @@ -613,7 +739,20 @@ class MuLawDecoding(nn.Layer): quantized(bool): indicate whether the signal has been quantized. The value of quantized parameter should be consistent with that used in MuLawEncoding. Notes: - Please refer to paddleaudio.functional.mu_decode() for more details. + Please refer to paddleaudio.functional.mu_law_decode() for more details. + Examples: + + .. code-block:: python + + import paddle + import paddleaudio.transforms as T + x = paddle.randint(0, 255, shape=(2, 8)) + transform = T.MuLawDecoding() + y = transform(x) + print(y) + >> Tensor(shape=[2, 8], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + [[-0.01151094, -0.02702747, 0.00796641, -0.91636580, 0.45497340, 0.49667698, 0.01151095, -0.24569811], + [0.21516445, -0.30633399, 0.01291343, -0.01991909, -0.00904676, 0.00105976, 0.03990653, -0.20584014]]) """ def __init__(self, mu: int = 256): @@ -622,7 +761,7 @@ class MuLawDecoding(nn.Layer): self.mu = mu def forward(self, x: Tensor) -> Tensor: - return F.mu_decode(x, mu=self.mu) + return F.mu_law_decode(x, mu=self.mu) def __repr__(self, ): return self.__class__.__name__ + f'(mu={self.mu})' @@ -640,6 +779,20 @@ class RandomMuLawCodec(nn.Layer): Notes: Please refer to MuLawDecoding() and MuLawEncoding() for more details. + Examples: + + .. 
code-block:: python + + import paddle + import paddleaudio.transforms as T + x = paddle.randn((2, 8)) + transform = T.RandomMuLawCodec() + y = transform(x) + print(y) + >> Tensor(shape=[2, 8], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + [[0.61542195, -0.35218054, 0.30605811, -0.12115669, -0.75794631, 0.03876950, -0.23082513, -0.49945647], + [-0.35218054, -0.87066686, -0.53548712, 1., -1., 0.49945661, 1., -0.93311179]]) + """ def __init__(self, min_mu: int = 63, max_mu: int = 255): super(RandomMuLawCodec, self).__init__() diff --git a/PaddleAudio/paddleaudio/utils/utils.py b/PaddleAudio/paddleaudio/utils/utils.py index 0b833295..56c74520 100644 --- a/PaddleAudio/paddleaudio/utils/utils.py +++ b/PaddleAudio/paddleaudio/utils/utils.py @@ -46,6 +46,9 @@ def get_logger(name: Optional[str] = None, if name is None: name = __file__ + def list_handlers(logger): + return {str(h) for h in logger.handlers} + logger = logging.getLogger(name) logging_level = getattr(logging, 'INFO') logger.setLevel(logging_level) @@ -55,8 +58,8 @@ def get_logger(name: Optional[str] = None, stdout_handler = logging.StreamHandler(sys.stdout) stdout_handler.setLevel(logging_level) stdout_handler.setFormatter(formatter) - logger.addHandler(stdout_handler) - + if str(stdout_handler) not in list_handlers(logger): + logger.addHandler(stdout_handler) if log_dir: #logging to file if log_file_name is None: log_file_name = 'log' @@ -64,15 +67,17 @@ def get_logger(name: Optional[str] = None, fh = logging.FileHandler(log_file) fh.setLevel(logging_level) fh.setFormatter(formatter) - logger.addHandler(fh) + if str(fh) not in list_handlers(logger): + logger.addHandler(fh) if use_error_log: stderr_handler = logging.StreamHandler(sys.stderr) stderr_handler.setLevel(logging.WARNING) stderr_handler.setFormatter(formatter) - logger.addHandler(stderr_handler) + if str(stderr_handler) not in list_handlers(logger): + logger.addHandler(stderr_handler) - logger.propagate = False + logger.propagate = 0 return 
logger -- GitLab