未验证 提交 12a3d6b7 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #2388 from SmileGoat/split_paddle_audio

[draft]split paddle audio
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import compliance
from . import datasets
from . import features
from . import functional
from . import io
from . import metric
from . import sox_effects
from . import backends
import importlib.util
import warnings
from functools import wraps
from typing import Optional
#code is from https://github.com/pytorch/audio/blob/main/torchaudio/_internal/module_utils.py
def is_module_available(*modules: str) -> bool:
r"""Returns if a top-level module with :attr:`name` exists *without**
importing it. This is generally safer than try-catch block around a
`import X`. It avoids third party libraries breaking assumptions of some of
our tests, e.g., setting multiprocessing start method when imported
(see librosa/#747, torchvision/#544).
"""
return all(importlib.util.find_spec(m) is not None for m in modules)
def requires_module(*modules: str):
"""Decorate function to give error message if invoked without required optional modules.
This decorator is to give better error message to users rather
than raising ``NameError: name 'module' is not defined`` at random places.
"""
missing = [m for m in modules if not is_module_available(m)]
if not missing:
# fall through. If all the modules are available, no need to decorate
def decorator(func):
return func
else:
req = f"module: {missing[0]}" if len(
missing) == 1 else f"modules: {missing}"
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(
f"{func.__module__}.{func.__name__} requires {req}")
return wrapped
return decorator
def deprecated(direction: str, version: Optional[str]=None):
"""Decorator to add deprecation message
Args:
direction (str): Migration steps to be given to users.
version (str or int): The version when the object will be removed
"""
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
message = (
f"{func.__module__}.{func.__name__} has been deprecated "
f'and will be removed from {"future" if version is None else version} release. '
f"{direction}")
warnings.warn(message, stacklevel=2)
return func(*args, **kwargs)
return wrapped
return decorator
def is_kaldi_available():
return is_module_available("paddleaudio._paddleaudio")
def requires_kaldi():
if is_kaldi_available():
def decorator(func):
return func
else:
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(
f"{func.__module__}.{func.__name__} requires kaldi")
return wrapped
return decorator
def _check_soundfile_importable():
if not is_module_available("soundfile"):
return False
try:
import soundfile # noqa: F401
return True
except Exception:
warnings.warn(
"Failed to import soundfile. 'soundfile' backend is not available.")
return False
_is_soundfile_importable = _check_soundfile_importable()
def is_soundfile_available():
return _is_soundfile_importable
def requires_soundfile():
if is_soundfile_available():
def decorator(func):
return func
else:
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(
f"{func.__module__}.{func.__name__} requires soundfile")
return wrapped
return decorator
def is_sox_available():
return is_module_available("paddleaudio._paddleaudio")
def requires_sox():
if is_sox_available():
def decorator(func):
return func
else:
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(
f"{func.__module__}.{func.__name__} requires sox")
return wrapped
return decorator
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .soundfile_backend import depth_convert
from .soundfile_backend import soundfile_load
from .soundfile_backend import normalize
from .soundfile_backend import resample
from .soundfile_backend import soundfile_save
from .soundfile_backend import to_mono
from . import utils
from .utils import get_audio_backend
from .utils import list_audio_backends
from .utils import set_audio_backend
utils._init_audio_backend()
\ No newline at end of file
# Token form https://github.com/pytorch/audio/blob/main/torchaudio/backend/common.py with modification.
class AudioInfo:
"""return of info function.
This class is used by :ref:`"sox_io" backend<sox_io_backend>` and
:ref:`"soundfile" backend with the new interface<soundfile_backend>`.
:ivar int sample_rate: Sample rate
:ivar int num_frames: The number of frames
:ivar int num_channels: The number of channels
:ivar int bits_per_sample: The number of bits per sample. This is 0 for lossy formats,
or when it cannot be accurately inferred.
:ivar str encoding: Audio encoding
The values encoding can take are one of the following:
* ``PCM_S``: Signed integer linear PCM
* ``PCM_U``: Unsigned integer linear PCM
* ``PCM_F``: Floating point linear PCM
* ``FLAC``: Flac, Free Lossless Audio Codec
* ``ULAW``: Mu-law
* ``ALAW``: A-law
* ``MP3`` : MP3, MPEG-1 Audio Layer III
* ``VORBIS``: OGG Vorbis
* ``AMR_WB``: Adaptive Multi-Rate
* ``AMR_NB``: Adaptive Multi-Rate Wideband
* ``OPUS``: Opus
* ``HTK``: Single channel 16-bit PCM
* ``UNKNOWN`` : None of above
"""
def __init__(
self,
sample_rate: int,
num_frames: int,
num_channels: int,
bits_per_sample: int,
encoding: str,
):
self.sample_rate = sample_rate
self.num_frames = num_frames
self.num_channels = num_channels
self.bits_per_sample = bits_per_sample
self.encoding = encoding
def __str__(self):
return (
f"AudioMetaData("
f"sample_rate={self.sample_rate}, "
f"num_frames={self.num_frames}, "
f"num_channels={self.num_channels}, "
f"bits_per_sample={self.bits_per_sample}, "
f"encoding={self.encoding}"
f")"
)
from pathlib import Path
from typing import Callable
from typing import Optional
from typing import Tuple
from typing import Union
from paddle import Tensor
#code is from: https://github.com/pytorch/audio/blob/main/torchaudio/backend/no_backend.py
def load(
filepath: Union[str, Path],
out: Optional[Tensor]=None,
normalization: Union[bool, float, Callable]=True,
channels_first: bool=True,
num_frames: int=0,
offset: int=0,
filetype: Optional[str]=None, ) -> Tuple[Tensor, int]:
raise RuntimeError("No audio I/O backend is available.")
def save(filepath: str,
src: Tensor,
sample_rate: int,
precision: int=16,
channels_first: bool=True) -> None:
raise RuntimeError("No audio I/O backend is available.")
def info(filepath: str) -> None:
raise RuntimeError("No audio I/O backend is available.")
此差异已折叠。
from pathlib import Path
from typing import Callable
from typing import Optional, Tuple, Union
import paddle
import paddleaudio
from paddle import Tensor
from .common import AudioInfo
import os
from paddleaudio._internal import module_utils as _mod_utils
#https://github.com/pytorch/audio/blob/main/torchaudio/backend/sox_io_backend.py
def _fail_info(filepath: str, format: Optional[str]) -> AudioInfo:
raise RuntimeError("Failed to fetch metadata from {}".format(filepath))
def _fail_info_fileobj(fileobj, format: Optional[str]) -> AudioInfo:
raise RuntimeError("Failed to fetch metadata from {}".format(fileobj))
# Note: need to comply TorchScript syntax -- need annotation and no f-string
def _fail_load(
filepath: str,
frame_offset: int = 0,
num_frames: int = -1,
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
) -> Tuple[Tensor, int]:
raise RuntimeError("Failed to load audio from {}".format(filepath))
def _fail_load_fileobj(fileobj, *args, **kwargs):
raise RuntimeError(f"Failed to load audio from {fileobj}")
_fallback_info = _fail_info
_fallback_info_fileobj = _fail_info_fileobj
_fallback_load = _fail_load
_fallback_load_filebj = _fail_load_fileobj
@_mod_utils.requires_sox()
def load(
filepath: str,
frame_offset: int = 0,
num_frames: int=-1,
normalize: bool = True,
channels_first: bool = True,
format: Optional[str]=None, ) -> Tuple[Tensor, int]:
if hasattr(filepath, "read"):
ret = paddleaudio._paddleaudio.load_audio_fileobj(
filepath, frame_offset, num_frames, normalize, channels_first, format
)
if ret is not None:
audio_tensor = paddle.to_tensor(ret[0])
return (audio_tensor, ret[1])
return _fallback_load_fileobj(filepath, frame_offset, num_frames, normalize, channels_first, format)
filepath = os.fspath(filepath)
ret = paddleaudio._paddleaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first, format
)
if ret is not None:
audio_tensor = paddle.to_tensor(ret[0])
return (audio_tensor, ret[1])
return _fallback_load(filepath, frame_offset, num_frames, normalize, channels_first, format)
@_mod_utils.requires_sox()
def save(filepath: str,
src: Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
):
src_arr = src.numpy()
if hasattr(filepath, "write"):
paddleaudio._paddleaudio.save_audio_fileobj(
filepath, src_arr, sample_rate, channels_first, compression, format, encoding, bits_per_sample
)
return
filepath = os.fspath(filepath)
paddleaudio._paddleaudio.sox_io_save_audio_file(
filepath, src_arr, sample_rate, channels_first, compression, format, encoding, bits_per_sample
)
@_mod_utils.requires_sox()
def info(filepath: str, format: Optional[str] = None,) -> AudioInfo:
if hasattr(filepath, "read"):
sinfo = paddleaudio._paddleaudio.get_info_fileobj(filepath, format)
if sinfo is not None:
return AudioInfo(*sinfo)
return _fallback_info_fileobj(filepath, format)
filepath = os.fspath(filepath)
sinfo = paddleaudio._paddleaudio.get_info_file(filepath, format)
if sinfo is not None:
return AudioInfo(*sinfo)
return _fallback_info(filepath, format)
"""Defines utilities for switching audio backends"""
#code is from: https://github.com/pytorch/audio/blob/main/torchaudio/backend/utils.py
import warnings
from typing import List
from typing import Optional
import paddleaudio
from paddleaudio._internal import module_utils as _mod_utils
from . import no_backend, soundfile_backend, sox_io_backend
__all__ = [
"list_audio_backends",
"get_audio_backend",
"set_audio_backend",
]
def list_audio_backends() -> List[str]:
"""List available backends
Returns:
List[str]: The list of available backends.
"""
backends = []
if _mod_utils.is_module_available("soundfile"):
backends.append("soundfile")
if _mod_utils.is_sox_available():
backends.append("sox_io")
return backends
def set_audio_backend(backend: Optional[str]):
"""Set the backend for I/O operation
Args:
backend (str or None): Name of the backend.
One of ``"sox_io"`` or ``"soundfile"`` based on availability
of the system. If ``None`` is provided the current backend is unassigned.
"""
if backend is not None and backend not in list_audio_backends():
raise RuntimeError(f'Backend "{backend}" is not one of '
f"available backends: {list_audio_backends()}.")
if backend is None:
module = no_backend
elif backend == "sox_io":
module = sox_io_backend
elif backend == "soundfile":
module = soundfile_backend
else:
raise NotImplementedError(f'Unexpected backend "{backend}"')
for func in ["save", "load", "info"]:
setattr(paddleaudio, func, getattr(module, func))
def _init_audio_backend():
backends = list_audio_backends()
if "soundfile" in backends:
set_audio_backend("soundfile")
elif "sox_io" in backends:
set_audio_backend("sox_io")
else:
warnings.warn("No audio backend is available.")
set_audio_backend(None)
def get_audio_backend() -> Optional[str]:
"""Get the name of the current backend
Returns:
Optional[str]: The name of the current backend or ``None`` if no backend is assigned.
"""
if paddleaudio.load == no_backend.load:
return None
if paddleaudio.load == sox_io_backend.load:
return "sox_io"
if paddleaudio.load == soundfile_backend.load:
return "soundfile"
raise ValueError("Unknown backend.")
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import kaldi
from . import librosa
此差异已折叠。
此差异已折叠。
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .esc50 import ESC50
from .gtzan import GTZAN
from .hey_snips import HeySnips
from .rirs_noises import OpenRIRNoise
from .tess import TESS
from .urban_sound import UrbanSound8K
from .voxceleb import VoxCeleb
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import numpy as np
import paddle
from ..backends.soundfile_backend import soundfile_load as load_audio
from ..compliance.kaldi import fbank as kaldi_fbank
from ..compliance.kaldi import mfcc as kaldi_mfcc
from ..compliance.librosa import melspectrogram
from ..compliance.librosa import mfcc
feat_funcs = {
'raw': None,
'melspectrogram': melspectrogram,
'mfcc': mfcc,
'kaldi_fbank': kaldi_fbank,
'kaldi_mfcc': kaldi_mfcc,
}
class AudioClassificationDataset(paddle.io.Dataset):
"""
Base class of audio classification dataset.
"""
def __init__(self,
files: List[str],
labels: List[int],
feat_type: str='raw',
sample_rate: int=None,
**kwargs):
"""
Ags:
files (:obj:`List[str]`): A list of absolute path of audio files.
labels (:obj:`List[int]`): Labels of audio files.
feat_type (:obj:`str`, `optional`, defaults to `raw`):
It identifies the feature type that user wants to extrace of an audio file.
"""
super(AudioClassificationDataset, self).__init__()
if feat_type not in feat_funcs.keys():
raise RuntimeError(
f"Unknown feat_type: {feat_type}, it must be one in {list(feat_funcs.keys())}"
)
self.files = files
self.labels = labels
self.feat_type = feat_type
self.sample_rate = sample_rate
self.feat_config = kwargs # Pass keyword arguments to customize feature config
def _get_data(self, input_file: str):
raise NotImplementedError
def _convert_to_record(self, idx):
file, label = self.files[idx], self.labels[idx]
if self.sample_rate is None:
waveform, sample_rate = load_audio(file)
else:
waveform, sample_rate = load_audio(file, sr=self.sample_rate)
feat_func = feat_funcs[self.feat_type]
record = {}
if self.feat_type in ['kaldi_fbank', 'kaldi_mfcc']:
waveform = paddle.to_tensor(waveform).unsqueeze(0) # (C, T)
record['feat'] = feat_func(
waveform=waveform, sr=self.sample_rate, **self.feat_config)
else:
record['feat'] = feat_func(
waveform, sample_rate,
**self.feat_config) if feat_func else waveform
record['label'] = label
return record
def __getitem__(self, idx):
record = self._convert_to_record(idx)
if self.feat_type in ['kaldi_fbank', 'kaldi_mfcc']:
return self.keys[idx], record['feat'], record['label']
else:
return np.array(record['feat']).transpose(), np.array(
record['label'], dtype=np.int64)
def __len__(self):
return len(self.files)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import os
from typing import List
from typing import Tuple
from ..utils.download import download_and_decompress
from ..utils.env import DATA_HOME
from .dataset import AudioClassificationDataset
__all__ = ['ESC50']
class ESC50(AudioClassificationDataset):
"""
The ESC-50 dataset is a labeled collection of 2000 environmental audio recordings
suitable for benchmarking methods of environmental sound classification. The dataset
consists of 5-second-long recordings organized into 50 semantical classes (with
40 examples per class)
Reference:
ESC: Dataset for Environmental Sound Classification
http://dx.doi.org/10.1145/2733373.2806390
"""
archieves = [
{
'url':
'https://paddleaudio.bj.bcebos.com/datasets/ESC-50-master.zip',
'md5': '7771e4b9d86d0945acce719c7a59305a',
},
]
label_list = [
# Animals
'Dog',
'Rooster',
'Pig',
'Cow',
'Frog',
'Cat',
'Hen',
'Insects (flying)',
'Sheep',
'Crow',
# Natural soundscapes & water sounds
'Rain',
'Sea waves',
'Crackling fire',
'Crickets',
'Chirping birds',
'Water drops',
'Wind',
'Pouring water',
'Toilet flush',
'Thunderstorm',
# Human, non-speech sounds
'Crying baby',
'Sneezing',
'Clapping',
'Breathing',
'Coughing',
'Footsteps',
'Laughing',
'Brushing teeth',
'Snoring',
'Drinking, sipping',
# Interior/domestic sounds
'Door knock',
'Mouse click',
'Keyboard typing',
'Door, wood creaks',
'Can opening',
'Washing machine',
'Vacuum cleaner',
'Clock alarm',
'Clock tick',
'Glass breaking',
# Exterior/urban noises
'Helicopter',
'Chainsaw',
'Siren',
'Car horn',
'Engine',
'Train',
'Church bells',
'Airplane',
'Fireworks',
'Hand saw',
]
meta = os.path.join('ESC-50-master', 'meta', 'esc50.csv')
meta_info = collections.namedtuple(
'META_INFO',
('filename', 'fold', 'target', 'category', 'esc10', 'src_file', 'take'))
audio_path = os.path.join('ESC-50-master', 'audio')
def __init__(self,
mode: str='train',
split: int=1,
feat_type: str='raw',
**kwargs):
"""
Ags:
mode (:obj:`str`, `optional`, defaults to `train`):
It identifies the dataset mode (train or dev).
split (:obj:`int`, `optional`, defaults to 1):
It specify the fold of dev dataset.
feat_type (:obj:`str`, `optional`, defaults to `raw`):
It identifies the feature type that user wants to extrace of an audio file.
"""
files, labels = self._get_data(mode, split)
super(ESC50, self).__init__(
files=files, labels=labels, feat_type=feat_type, **kwargs)
def _get_meta_info(self) -> List[collections.namedtuple]:
ret = []
with open(os.path.join(DATA_HOME, self.meta), 'r') as rf:
for line in rf.readlines()[1:]:
ret.append(self.meta_info(*line.strip().split(',')))
return ret
def _get_data(self, mode: str, split: int) -> Tuple[List[str], List[int]]:
if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)) or \
not os.path.isfile(os.path.join(DATA_HOME, self.meta)):
download_and_decompress(self.archieves, DATA_HOME)
meta_info = self._get_meta_info()
files = []
labels = []
for sample in meta_info:
filename, fold, target, _, _, _, _ = sample
if mode == 'train' and int(fold) != split:
files.append(os.path.join(DATA_HOME, self.audio_path, filename))
labels.append(int(target))
if mode != 'train' and int(fold) == split:
files.append(os.path.join(DATA_HOME, self.audio_path, filename))
labels.append(int(target))
return files, labels
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import os
import random
from typing import List
from typing import Tuple
from ..utils.download import download_and_decompress
from ..utils.env import DATA_HOME
from .dataset import AudioClassificationDataset
__all__ = ['GTZAN']
class GTZAN(AudioClassificationDataset):
"""
The GTZAN dataset consists of 1000 audio tracks each 30 seconds long. It contains 10 genres,
each represented by 100 tracks. The dataset is the most-used public dataset for evaluation
in machine listening research for music genre recognition (MGR).
Reference:
Musical genre classification of audio signals
https://ieeexplore.ieee.org/document/1021072/
"""
archieves = [
{
'url': 'http://opihi.cs.uvic.ca/sound/genres.tar.gz',
'md5': '5b3d6dddb579ab49814ab86dba69e7c7',
},
]
label_list = [
'blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal',
'pop', 'reggae', 'rock'
]
meta = os.path.join('genres', 'input.mf')
meta_info = collections.namedtuple('META_INFO', ('file_path', 'label'))
audio_path = 'genres'
def __init__(self,
mode='train',
seed=0,
n_folds=5,
split=1,
feat_type='raw',
**kwargs):
"""
Ags:
mode (:obj:`str`, `optional`, defaults to `train`):
It identifies the dataset mode (train or dev).
seed (:obj:`int`, `optional`, defaults to 0):
Set the random seed to shuffle samples.
n_folds (:obj:`int`, `optional`, defaults to 5):
Split the dataset into n folds. 1 fold for dev dataset and n-1 for train dataset.
split (:obj:`int`, `optional`, defaults to 1):
It specify the fold of dev dataset.
feat_type (:obj:`str`, `optional`, defaults to `raw`):
It identifies the feature type that user wants to extrace of an audio file.
"""
assert split <= n_folds, f'The selected split should not be larger than n_fold, but got {split} > {n_folds}'
files, labels = self._get_data(mode, seed, n_folds, split)
super(GTZAN, self).__init__(
files=files, labels=labels, feat_type=feat_type, **kwargs)
def _get_meta_info(self) -> List[collections.namedtuple]:
ret = []
with open(os.path.join(DATA_HOME, self.meta), 'r') as rf:
for line in rf.readlines():
ret.append(self.meta_info(*line.strip().split('\t')))
return ret
def _get_data(self, mode, seed, n_folds,
split) -> Tuple[List[str], List[int]]:
if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)) or \
not os.path.isfile(os.path.join(DATA_HOME, self.meta)):
download_and_decompress(self.archieves, DATA_HOME)
meta_info = self._get_meta_info()
random.seed(seed) # shuffle samples to split data
random.shuffle(
meta_info
) # make sure using the same seed to create train and dev dataset
files = []
labels = []
n_samples_per_fold = len(meta_info) // n_folds
for idx, sample in enumerate(meta_info):
file_path, label = sample
filename = os.path.basename(file_path)
target = self.label_list.index(label)
fold = idx // n_samples_per_fold + 1
if mode == 'train' and int(fold) != split:
files.append(
os.path.join(DATA_HOME, self.audio_path, label, filename))
labels.append(target)
if mode != 'train' and int(fold) == split:
files.append(
os.path.join(DATA_HOME, self.audio_path, label, filename))
labels.append(target)
return files, labels
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import json
import os
from typing import List
from typing import Tuple
from .dataset import AudioClassificationDataset
__all__ = ['HeySnips']
class HeySnips(AudioClassificationDataset):
meta_info = collections.namedtuple('META_INFO',
('key', 'label', 'duration', 'wav'))
def __init__(self,
data_dir: os.PathLike,
mode: str='train',
feat_type: str='kaldi_fbank',
sample_rate: int=16000,
**kwargs):
self.data_dir = data_dir
files, labels = self._get_data(mode)
super(HeySnips, self).__init__(
files=files,
labels=labels,
feat_type=feat_type,
sample_rate=sample_rate,
**kwargs)
def _get_meta_info(self, mode) -> List[collections.namedtuple]:
ret = []
with open(os.path.join(self.data_dir, '{}.json'.format(mode)),
'r') as f:
data = json.load(f)
for item in data:
sample = collections.OrderedDict()
if item['duration'] > 0:
sample['key'] = item['id']
sample['label'] = 0 if item['is_hotword'] == 1 else -1
sample['duration'] = item['duration']
sample['wav'] = os.path.join(self.data_dir,
item['audio_file_path'])
ret.append(self.meta_info(*sample.values()))
return ret
def _get_data(self, mode: str) -> Tuple[List[str], List[int]]:
meta_info = self._get_meta_info(mode)
files = []
labels = []
self.keys = []
self.durations = []
for sample in meta_info:
key, target, duration, wav = sample
files.append(wav)
labels.append(int(target))
self.keys.append(key)
self.durations.append(float(duration))
return files, labels
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import csv
import os
import random
from typing import List
from paddle.io import Dataset
from tqdm import tqdm
from ..backends.soundfile_backend import soundfile_load as load_audio
from ..backends.soundfile_backend import soundfile_save as save_wav
from ..utils import DATA_HOME
from ..utils.download import download_and_decompress
from .dataset import feat_funcs
__all__ = ['OpenRIRNoise']
class OpenRIRNoise(Dataset):
archieves = [
{
'url': 'http://www.openslr.org/resources/28/rirs_noises.zip',
'md5': 'e6f48e257286e05de56413b4779d8ffb',
},
]
sample_rate = 16000
meta_info = collections.namedtuple('META_INFO', ('id', 'duration', 'wav'))
base_path = os.path.join(DATA_HOME, 'open_rir_noise')
wav_path = os.path.join(base_path, 'RIRS_NOISES')
csv_path = os.path.join(base_path, 'csv')
subsets = ['rir', 'noise']
def __init__(self,
subset: str='rir',
feat_type: str='raw',
target_dir=None,
random_chunk: bool=True,
chunk_duration: float=3.0,
seed: int=0,
**kwargs):
assert subset in self.subsets, \
'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset)
self.subset = subset
self.feat_type = feat_type
self.feat_config = kwargs
self.random_chunk = random_chunk
self.chunk_duration = chunk_duration
OpenRIRNoise.csv_path = os.path.join(
target_dir, "open_rir_noise",
"csv") if target_dir else self.csv_path
self._data = self._get_data()
super(OpenRIRNoise, self).__init__()
# Set up a seed to reproduce training or predicting result.
# random.seed(seed)
def _get_data(self):
# Download audio files.
print(f"rirs noises base path: {self.base_path}")
if not os.path.isdir(self.base_path):
download_and_decompress(
self.archieves, self.base_path, decompress=True)
else:
print(
f"{self.base_path} already exists, we will not download and decompress again"
)
# Data preparation.
print(f"prepare the csv to {self.csv_path}")
if not os.path.isdir(self.csv_path):
os.makedirs(self.csv_path)
self.prepare_data()
data = []
with open(os.path.join(self.csv_path, f'{self.subset}.csv'), 'r') as rf:
for line in rf.readlines()[1:]:
audio_id, duration, wav = line.strip().split(',')
data.append(self.meta_info(audio_id, float(duration), wav))
random.shuffle(data)
return data
def _convert_to_record(self, idx: int):
sample = self._data[idx]
record = {}
# To show all fields in a namedtuple: `type(sample)._fields`
for field in type(sample)._fields:
record[field] = getattr(sample, field)
waveform, sr = load_audio(record['wav'])
assert self.feat_type in feat_funcs.keys(), \
f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
feat_func = feat_funcs[self.feat_type]
feat = feat_func(
waveform, sr=sr, **self.feat_config) if feat_func else waveform
record.update({'feat': feat})
return record
@staticmethod
def _get_chunks(seg_dur, audio_id, audio_duration):
num_chunks = int(audio_duration / seg_dur) # all in milliseconds
chunk_lst = [
audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
for i in range(num_chunks)
]
return chunk_lst
def _get_audio_info(self, wav_file: str,
split_chunks: bool) -> List[List[str]]:
waveform, sr = load_audio(wav_file)
audio_id = wav_file.split("/open_rir_noise/")[-1].split(".")[0]
audio_duration = waveform.shape[0] / sr
ret = []
if split_chunks and audio_duration > self.chunk_duration: # Split into pieces of self.chunk_duration seconds.
uniq_chunks_list = self._get_chunks(self.chunk_duration, audio_id,
audio_duration)
for idx, chunk in enumerate(uniq_chunks_list):
s, e = chunk.split("_")[-2:] # Timestamps of start and end
start_sample = int(float(s) * sr)
end_sample = int(float(e) * sr)
new_wav_file = os.path.join(self.base_path,
audio_id + f'_chunk_{idx+1:02}.wav')
save_wav(waveform[start_sample:end_sample], sr, new_wav_file)
# id, duration, new_wav
ret.append([chunk, self.chunk_duration, new_wav_file])
else: # Keep whole audio.
ret.append([audio_id, audio_duration, wav_file])
return ret
def generate_csv(self,
wav_files: List[str],
output_file: str,
split_chunks: bool=True):
print(f'Generating csv: {output_file}')
header = ["id", "duration", "wav"]
infos = list(
tqdm(
map(self._get_audio_info, wav_files, [split_chunks] * len(
wav_files)),
total=len(wav_files)))
csv_lines = []
for info in infos:
csv_lines.extend(info)
with open(output_file, mode="w") as csv_f:
csv_writer = csv.writer(
csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
csv_writer.writerow(header)
for line in csv_lines:
csv_writer.writerow(line)
def prepare_data(self):
rir_list = os.path.join(self.wav_path, "real_rirs_isotropic_noises",
"rir_list")
rir_files = []
with open(rir_list, 'r') as f:
for line in f.readlines():
rir_file = line.strip().split(' ')[-1]
rir_files.append(os.path.join(self.base_path, rir_file))
noise_list = os.path.join(self.wav_path, "pointsource_noises",
"noise_list")
noise_files = []
with open(noise_list, 'r') as f:
for line in f.readlines():
noise_file = line.strip().split(' ')[-1]
noise_files.append(os.path.join(self.base_path, noise_file))
self.generate_csv(rir_files, os.path.join(self.csv_path, 'rir.csv'))
self.generate_csv(noise_files, os.path.join(self.csv_path, 'noise.csv'))
def __getitem__(self, idx):
return self._convert_to_record(idx)
def __len__(self):
return len(self._data)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import os
import random
from typing import List
from typing import Tuple
from ..utils.download import download_and_decompress
from ..utils.env import DATA_HOME
from .dataset import AudioClassificationDataset
__all__ = ['TESS']
class TESS(AudioClassificationDataset):
"""
TESS is a set of 200 target words were spoken in the carrier phrase
"Say the word _____' by two actresses (aged 26 and 64 years) and
recordings were made of the set portraying each of seven emotions(anger,
disgust, fear, happiness, pleasant surprise, sadness, and neutral).
There are 2800 stimuli in total.
Reference:
Toronto emotional speech set (TESS)
https://doi.org/10.5683/SP2/E8H2MF
"""
archieves = [
{
'url':
'https://bj.bcebos.com/paddleaudio/datasets/TESS_Toronto_emotional_speech_set.zip',
'md5':
'1465311b24d1de704c4c63e4ccc470c7',
},
]
label_list = [
'angry',
'disgust',
'fear',
'happy',
'neutral',
'ps', # pleasant surprise
'sad',
]
meta_info = collections.namedtuple('META_INFO',
('speaker', 'word', 'emotion'))
audio_path = 'TESS_Toronto_emotional_speech_set'
def __init__(self,
mode='train',
seed=0,
n_folds=5,
split=1,
feat_type='raw',
**kwargs):
"""
Ags:
mode (:obj:`str`, `optional`, defaults to `train`):
It identifies the dataset mode (train or dev).
seed (:obj:`int`, `optional`, defaults to 0):
Set the random seed to shuffle samples.
n_folds (:obj:`int`, `optional`, defaults to 5):
Split the dataset into n folds. 1 fold for dev dataset and n-1 for train dataset.
split (:obj:`int`, `optional`, defaults to 1):
It specify the fold of dev dataset.
feat_type (:obj:`str`, `optional`, defaults to `raw`):
It identifies the feature type that user wants to extrace of an audio file.
"""
assert split <= n_folds, f'The selected split should not be larger than n_fold, but got {split} > {n_folds}'
files, labels = self._get_data(mode, seed, n_folds, split)
super(TESS, self).__init__(
files=files, labels=labels, feat_type=feat_type, **kwargs)
def _get_meta_info(self, files) -> List[collections.namedtuple]:
ret = []
for file in files:
basename_without_extend = os.path.basename(file)[:-4]
ret.append(self.meta_info(*basename_without_extend.split('_')))
return ret
def _get_data(self, mode, seed, n_folds,
split) -> Tuple[List[str], List[int]]:
if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)):
download_and_decompress(self.archieves, DATA_HOME)
wav_files = []
for root, _, files in os.walk(os.path.join(DATA_HOME, self.audio_path)):
for file in files:
if file.endswith('.wav'):
wav_files.append(os.path.join(root, file))
random.seed(seed) # shuffle samples to split data
random.shuffle(
wav_files
) # make sure using the same seed to create train and dev dataset
meta_info = self._get_meta_info(wav_files)
files = []
labels = []
n_samples_per_fold = len(meta_info) // n_folds
for idx, sample in enumerate(meta_info):
_, _, emotion = sample
target = self.label_list.index(emotion)
fold = idx // n_samples_per_fold + 1
if mode == 'train' and int(fold) != split:
files.append(wav_files[idx])
labels.append(target)
if mode != 'train' and int(fold) == split:
files.append(wav_files[idx])
labels.append(target)
return files, labels
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import os
from typing import List
from typing import Tuple
from ..utils.download import download_and_decompress
from ..utils.env import DATA_HOME
from .dataset import AudioClassificationDataset
__all__ = ['UrbanSound8K']
class UrbanSound8K(AudioClassificationDataset):
"""
UrbanSound8K dataset contains 8732 labeled sound excerpts (<=4s) of urban
sounds from 10 classes: air_conditioner, car_horn, children_playing, dog_bark,
drilling, enginge_idling, gun_shot, jackhammer, siren, and street_music. The
classes are drawn from the urban sound taxonomy.
Reference:
A Dataset and Taxonomy for Urban Sound Research
https://dl.acm.org/doi/10.1145/2647868.2655045
"""
archieves = [
{
'url':
'https://zenodo.org/record/1203745/files/UrbanSound8K.tar.gz',
'md5': '9aa69802bbf37fb986f71ec1483a196e',
},
]
label_list = [
"air_conditioner", "car_horn", "children_playing", "dog_bark",
"drilling", "engine_idling", "gun_shot", "jackhammer", "siren",
"street_music"
]
meta = os.path.join('UrbanSound8K', 'metadata', 'UrbanSound8K.csv')
meta_info = collections.namedtuple(
'META_INFO', ('filename', 'fsid', 'start', 'end', 'salience', 'fold',
'class_id', 'label'))
audio_path = os.path.join('UrbanSound8K', 'audio')
def __init__(self,
mode: str='train',
split: int=1,
feat_type: str='raw',
**kwargs):
files, labels = self._get_data(mode, split)
super(UrbanSound8K, self).__init__(
files=files, labels=labels, feat_type=feat_type, **kwargs)
"""
Ags:
mode (:obj:`str`, `optional`, defaults to `train`):
It identifies the dataset mode (train or dev).
split (:obj:`int`, `optional`, defaults to 1):
It specify the fold of dev dataset.
feat_type (:obj:`str`, `optional`, defaults to `raw`):
It identifies the feature type that user wants to extrace of an audio file.
"""
def _get_meta_info(self):
ret = []
with open(os.path.join(DATA_HOME, self.meta), 'r') as rf:
for line in rf.readlines()[1:]:
ret.append(self.meta_info(*line.strip().split(',')))
return ret
def _get_data(self, mode: str, split: int) -> Tuple[List[str], List[int]]:
if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)) or \
not os.path.isfile(os.path.join(DATA_HOME, self.meta)):
download_and_decompress(self.archieves, DATA_HOME)
meta_info = self._get_meta_info()
files = []
labels = []
for sample in meta_info:
filename, _, _, _, _, fold, target, _ = sample
if mode == 'train' and int(fold) != split:
files.append(
os.path.join(DATA_HOME, self.audio_path, f'fold{fold}',
filename))
labels.append(int(target))
if mode != 'train' and int(fold) == split:
files.append(
os.path.join(DATA_HOME, self.audio_path, f'fold{fold}',
filename))
labels.append(int(target))
return files, labels
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import csv
import glob
import os
import random
from multiprocessing import cpu_count
from typing import List
from paddle.io import Dataset
from pathos.multiprocessing import Pool
from tqdm import tqdm
from ..backends.soundfile_backend import soundfile_load as load_audio
from ..utils import DATA_HOME
from ..utils import decompress
from ..utils.download import download_and_decompress
from .dataset import feat_funcs
__all__ = ['VoxCeleb']
class VoxCeleb(Dataset):
source_url = 'https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/'
archieves_audio_dev = [
{
'url': source_url + 'vox1_dev_wav_partaa',
'md5': 'e395d020928bc15670b570a21695ed96',
},
{
'url': source_url + 'vox1_dev_wav_partab',
'md5': 'bbfaaccefab65d82b21903e81a8a8020',
},
{
'url': source_url + 'vox1_dev_wav_partac',
'md5': '017d579a2a96a077f40042ec33e51512',
},
{
'url': source_url + 'vox1_dev_wav_partad',
'md5': '7bb1e9f70fddc7a678fa998ea8b3ba19',
},
]
archieves_audio_test = [
{
'url': source_url + 'vox1_test_wav.zip',
'md5': '185fdc63c3c739954633d50379a3d102',
},
]
archieves_meta = [
{
'url':
'https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt',
'md5':
'b73110731c9223c1461fe49cb48dddfc',
},
]
num_speakers = 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
sample_rate = 16000
meta_info = collections.namedtuple(
'META_INFO', ('id', 'duration', 'wav', 'start', 'stop', 'spk_id'))
base_path = os.path.join(DATA_HOME, 'vox1')
wav_path = os.path.join(base_path, 'wav')
meta_path = os.path.join(base_path, 'meta')
veri_test_file = os.path.join(meta_path, 'veri_test2.txt')
csv_path = os.path.join(base_path, 'csv')
subsets = ['train', 'dev', 'enroll', 'test']
def __init__(
self,
subset: str='train',
feat_type: str='raw',
random_chunk: bool=True,
chunk_duration: float=3.0, # seconds
split_ratio: float=0.9, # train split ratio
seed: int=0,
target_dir: str=None,
vox2_base_path=None,
**kwargs):
"""VoxCeleb data prepare and get the specific dataset audio info
Args:
subset (str, optional): dataset name, such as train, dev, enroll or test. Defaults to 'train'.
feat_type (str, optional): feat type, such raw, melspectrogram(fbank) or mfcc . Defaults to 'raw'.
random_chunk (bool, optional): random select a duration from audio. Defaults to True.
chunk_duration (float, optional): chunk duration if random_chunk flag is set. Defaults to 3.0.
target_dir (str, optional): data dir, audio info will be stored in this directory. Defaults to None.
vox2_base_path (_type_, optional): vox2 directory. vox2 data must be converted from m4a to wav. Defaults to None.
"""
assert subset in self.subsets, \
'Dataset subset must be one in {}, but got {}'.format(self.subsets, subset)
self.subset = subset
self.spk_id2label = {}
self.feat_type = feat_type
self.feat_config = kwargs
self.random_chunk = random_chunk
self.chunk_duration = chunk_duration
self.split_ratio = split_ratio
self.target_dir = target_dir if target_dir else VoxCeleb.base_path
self.vox2_base_path = vox2_base_path
# if we set the target dir, we will change the vox data info data from base path to target dir
VoxCeleb.csv_path = os.path.join(
target_dir, "voxceleb", 'csv') if target_dir else VoxCeleb.csv_path
VoxCeleb.meta_path = os.path.join(
target_dir, "voxceleb",
'meta') if target_dir else VoxCeleb.meta_path
VoxCeleb.veri_test_file = os.path.join(VoxCeleb.meta_path,
'veri_test2.txt')
# self._data = self._get_data()[:1000] # KP: Small dataset test.
self._data = self._get_data()
super(VoxCeleb, self).__init__()
# Set up a seed to reproduce training or predicting result.
# random.seed(seed)
def _get_data(self):
# Download audio files.
# We need the users to decompress all vox1/dev/wav and vox1/test/wav/ to vox1/wav/ dir
# so, we check the vox1/wav dir status
print(f"wav base path: {self.wav_path}")
if not os.path.isdir(self.wav_path):
print("start to download the voxceleb1 dataset")
download_and_decompress( # multi-zip parts concatenate to vox1_dev_wav.zip
self.archieves_audio_dev,
self.base_path,
decompress=False)
download_and_decompress( # download the vox1_test_wav.zip and unzip
self.archieves_audio_test,
self.base_path,
decompress=True)
# Download all parts and concatenate the files into one zip file.
dev_zipfile = os.path.join(self.base_path, 'vox1_dev_wav.zip')
print(f'Concatenating all parts to: {dev_zipfile}')
os.system(
f'cat {os.path.join(self.base_path, "vox1_dev_wav_parta*")} > {dev_zipfile}'
)
# Extract all audio files of dev and test set.
decompress(dev_zipfile, self.base_path)
# Download meta files.
if not os.path.isdir(self.meta_path):
print("prepare the meta data")
download_and_decompress(
self.archieves_meta, self.meta_path, decompress=False)
# Data preparation.
if not os.path.isdir(self.csv_path):
os.makedirs(self.csv_path)
self.prepare_data()
data = []
print(
f"read the {self.subset} from {os.path.join(self.csv_path, f'{self.subset}.csv')}"
)
with open(os.path.join(self.csv_path, f'{self.subset}.csv'), 'r') as rf:
for line in rf.readlines()[1:]:
audio_id, duration, wav, start, stop, spk_id = line.strip(
).split(',')
data.append(
self.meta_info(audio_id,
float(duration), wav,
int(start), int(stop), spk_id))
with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'r') as f:
for line in f.readlines():
spk_id, label = line.strip().split(' ')
self.spk_id2label[spk_id] = int(label)
return data
def _convert_to_record(self, idx: int):
sample = self._data[idx]
record = {}
# To show all fields in a namedtuple: `type(sample)._fields`
for field in type(sample)._fields:
record[field] = getattr(sample, field)
waveform, sr = load_audio(record['wav'])
# random select a chunk audio samples from the audio
if self.random_chunk:
num_wav_samples = waveform.shape[0]
num_chunk_samples = int(self.chunk_duration * sr)
start = random.randint(0, num_wav_samples - num_chunk_samples - 1)
stop = start + num_chunk_samples
else:
start = record['start']
stop = record['stop']
waveform = waveform[start:stop]
assert self.feat_type in feat_funcs.keys(), \
f"Unknown feat_type: {self.feat_type}, it must be one in {list(feat_funcs.keys())}"
feat_func = feat_funcs[self.feat_type]
feat = feat_func(
waveform, sr=sr, **self.feat_config) if feat_func else waveform
record.update({'feat': feat})
if self.subset in ['train',
'dev']: # Labels are available in train and dev.
record.update({'label': self.spk_id2label[record['spk_id']]})
return record
@staticmethod
def _get_chunks(seg_dur, audio_id, audio_duration):
num_chunks = int(audio_duration / seg_dur) # all in milliseconds
chunk_lst = [
audio_id + "_" + str(i * seg_dur) + "_" + str(i * seg_dur + seg_dur)
for i in range(num_chunks)
]
return chunk_lst
def _get_audio_info(self, wav_file: str,
split_chunks: bool) -> List[List[str]]:
waveform, sr = load_audio(wav_file)
spk_id, sess_id, utt_id = wav_file.split("/")[-3:]
audio_id = '-'.join([spk_id, sess_id, utt_id.split(".")[0]])
audio_duration = waveform.shape[0] / sr
ret = []
if split_chunks: # Split into pieces of self.chunk_duration seconds.
uniq_chunks_list = self._get_chunks(self.chunk_duration, audio_id,
audio_duration)
for chunk in uniq_chunks_list:
s, e = chunk.split("_")[-2:] # Timestamps of start and end
start_sample = int(float(s) * sr)
end_sample = int(float(e) * sr)
# id, duration, wav, start, stop, spk_id
ret.append([
chunk, audio_duration, wav_file, start_sample, end_sample,
spk_id
])
else: # Keep whole audio.
ret.append([
audio_id, audio_duration, wav_file, 0, waveform.shape[0], spk_id
])
return ret
def generate_csv(self,
wav_files: List[str],
output_file: str,
split_chunks: bool=True):
print(f'Generating csv: {output_file}')
header = ["id", "duration", "wav", "start", "stop", "spk_id"]
# Note: this may occurs c++ execption, but the program will execute fine
# so we can ignore the execption
with Pool(cpu_count()) as p:
infos = list(
tqdm(
p.imap(lambda x: self._get_audio_info(x, split_chunks),
wav_files),
total=len(wav_files)))
csv_lines = []
for info in infos:
csv_lines.extend(info)
with open(output_file, mode="w") as csv_f:
csv_writer = csv.writer(
csv_f, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL)
csv_writer.writerow(header)
for line in csv_lines:
csv_writer.writerow(line)
def prepare_data(self):
# Audio of speakers in veri_test_file should not be included in training set.
print("start to prepare the data csv file")
enroll_files = set()
test_files = set()
# get the enroll and test audio file path
with open(self.veri_test_file, 'r') as f:
for line in f.readlines():
_, enrol_file, test_file = line.strip().split(' ')
enroll_files.add(os.path.join(self.wav_path, enrol_file))
test_files.add(os.path.join(self.wav_path, test_file))
enroll_files = sorted(enroll_files)
test_files = sorted(test_files)
# get the enroll and test speakers
test_spks = set()
for file in (enroll_files + test_files):
spk = file.split('/wav/')[1].split('/')[0]
test_spks.add(spk)
# get all the train and dev audios file path
audio_files = []
speakers = set()
print("Getting file list...")
for path in [self.wav_path, self.vox2_base_path]:
# if vox2 directory is not set and vox2 is not a directory
# we will not process this directory
if not path or not os.path.exists(path):
print(f"{path} is an invalid path, please check again, "
"and we will ignore the vox2 base path")
continue
for file in glob.glob(
os.path.join(path, "**", "*.wav"), recursive=True):
spk = file.split('/wav/')[1].split('/')[0]
if spk in test_spks:
continue
speakers.add(spk)
audio_files.append(file)
print(
f"start to generate the {os.path.join(self.meta_path, 'spk_id2label.txt')}"
)
# encode the train and dev speakers label to spk_id2label.txt
with open(os.path.join(self.meta_path, 'spk_id2label.txt'), 'w') as f:
for label, spk_id in enumerate(
sorted(speakers)): # 1211 vox1, 5994 vox2, 7205 vox1+2
f.write(f'{spk_id} {label}\n')
audio_files = sorted(audio_files)
random.shuffle(audio_files)
split_idx = int(self.split_ratio * len(audio_files))
# split_ratio to train
train_files, dev_files = audio_files[:split_idx], audio_files[
split_idx:]
self.generate_csv(train_files, os.path.join(self.csv_path, 'train.csv'))
self.generate_csv(dev_files, os.path.join(self.csv_path, 'dev.csv'))
self.generate_csv(
enroll_files,
os.path.join(self.csv_path, 'enroll.csv'),
split_chunks=False)
self.generate_csv(
test_files,
os.path.join(self.csv_path, 'test.csv'),
split_chunks=False)
def __getitem__(self, idx):
return self._convert_to_record(idx)
def __len__(self):
return len(self._data)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .layers import LogMelSpectrogram
from .layers import MelSpectrogram
from .layers import MFCC
from .layers import Spectrogram
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Optional
from typing import Union
import paddle
import paddle.nn as nn
from paddle import Tensor
from ..functional import compute_fbank_matrix
from ..functional import create_dct
from ..functional import power_to_db
from ..functional.window import get_window
__all__ = [
'Spectrogram',
'MelSpectrogram',
'LogMelSpectrogram',
'MFCC',
]
class Spectrogram(nn.Layer):
"""Compute spectrogram of given signals, typically audio waveforms.
The spectorgram is defined as the complex norm of the short-time Fourier transformation.
Args:
n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512.
hop_length (Optional[int], optional): The hop length of the short time FFT. If `None`, it is set to `win_length//4`. Defaults to None.
win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None.
window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'.
power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'.
dtype (str, optional): Data type of input and window. Defaults to 'float32'.
"""
def __init__(self,
n_fft: int=512,
hop_length: Optional[int]=None,
win_length: Optional[int]=None,
window: str='hann',
power: float=2.0,
center: bool=True,
pad_mode: str='reflect',
dtype: str='float32') -> None:
super(Spectrogram, self).__init__()
assert power > 0, 'Power of spectrogram must be > 0.'
self.power = power
if win_length is None:
win_length = n_fft
self.fft_window = get_window(
window, win_length, fftbins=True, dtype=dtype)
self._stft = partial(
paddle.signal.stft,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=self.fft_window,
center=center,
pad_mode=pad_mode)
self.register_buffer('fft_window', self.fft_window)
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x (Tensor): Tensor of waveforms with shape `(N, T)`
Returns:
Tensor: Spectrograms with shape `(N, n_fft//2 + 1, num_frames)`.
"""
stft = self._stft(x)
spectrogram = paddle.pow(paddle.abs(stft), self.power)
return spectrogram
class MelSpectrogram(nn.Layer):
"""Compute the melspectrogram of given signals, typically audio waveforms. It is computed by multiplying spectrogram with Mel filter bank matrix.
Args:
sr (int, optional): Sample rate. Defaults to 22050.
n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512.
hop_length (Optional[int], optional): The hop length of the short time FFT. If `None`, it is set to `win_length//4`. Defaults to None.
win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None.
window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'.
power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'.
n_mels (int, optional): Number of mel bins. Defaults to 64.
f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0.
f_max (Optional[float], optional): Maximum frequency in Hz. Defaults to None.
htk (bool, optional): Use HTK formula in computing fbank matrix. Defaults to False.
norm (Union[str, float], optional): Type of normalization in computing fbank matrix. Slaney-style is used by default. You can specify norm=1.0/2.0 to use customized p-norm normalization. Defaults to 'slaney'.
dtype (str, optional): Data type of input and window. Defaults to 'float32'.
"""
def __init__(self,
sr: int=22050,
n_fft: int=512,
hop_length: Optional[int]=None,
win_length: Optional[int]=None,
window: str='hann',
power: float=2.0,
center: bool=True,
pad_mode: str='reflect',
n_mels: int=64,
f_min: float=50.0,
f_max: Optional[float]=None,
htk: bool=False,
norm: Union[str, float]='slaney',
dtype: str='float32') -> None:
super(MelSpectrogram, self).__init__()
self._spectrogram = Spectrogram(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
power=power,
center=center,
pad_mode=pad_mode,
dtype=dtype)
self.n_mels = n_mels
self.f_min = f_min
self.f_max = f_max
self.htk = htk
self.norm = norm
if f_max is None:
f_max = sr // 2
self.fbank_matrix = compute_fbank_matrix(
sr=sr,
n_fft=n_fft,
n_mels=n_mels,
f_min=f_min,
f_max=f_max,
htk=htk,
norm=norm,
dtype=dtype) # float64 for better numerical results
self.register_buffer('fbank_matrix', self.fbank_matrix)
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x (Tensor): Tensor of waveforms with shape `(N, T)`
Returns:
Tensor: Mel spectrograms with shape `(N, n_mels, num_frames)`.
"""
spect_feature = self._spectrogram(x)
mel_feature = paddle.matmul(self.fbank_matrix, spect_feature)
return mel_feature
class LogMelSpectrogram(nn.Layer):
"""Compute log-mel-spectrogram feature of given signals, typically audio waveforms.
Args:
sr (int, optional): Sample rate. Defaults to 22050.
n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512.
hop_length (Optional[int], optional): The hop length of the short time FFT. If `None`, it is set to `win_length//4`. Defaults to None.
win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None.
window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'.
power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'.
n_mels (int, optional): Number of mel bins. Defaults to 64.
f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0.
f_max (Optional[float], optional): Maximum frequency in Hz. Defaults to None.
htk (bool, optional): Use HTK formula in computing fbank matrix. Defaults to False.
norm (Union[str, float], optional): Type of normalization in computing fbank matrix. Slaney-style is used by default. You can specify norm=1.0/2.0 to use customized p-norm normalization. Defaults to 'slaney'.
ref_value (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0.
amin (float, optional): The minimum value of input magnitude. Defaults to 1e-10.
top_db (Optional[float], optional): The maximum db value of spectrogram. Defaults to None.
dtype (str, optional): Data type of input and window. Defaults to 'float32'.
"""
def __init__(self,
sr: int=22050,
n_fft: int=512,
hop_length: Optional[int]=None,
win_length: Optional[int]=None,
window: str='hann',
power: float=2.0,
center: bool=True,
pad_mode: str='reflect',
n_mels: int=64,
f_min: float=50.0,
f_max: Optional[float]=None,
htk: bool=False,
norm: Union[str, float]='slaney',
ref_value: float=1.0,
amin: float=1e-10,
top_db: Optional[float]=None,
dtype: str='float32') -> None:
super(LogMelSpectrogram, self).__init__()
self._melspectrogram = MelSpectrogram(
sr=sr,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
power=power,
center=center,
pad_mode=pad_mode,
n_mels=n_mels,
f_min=f_min,
f_max=f_max,
htk=htk,
norm=norm,
dtype=dtype)
self.ref_value = ref_value
self.amin = amin
self.top_db = top_db
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x (Tensor): Tensor of waveforms with shape `(N, T)`
Returns:
Tensor: Log mel spectrograms with shape `(N, n_mels, num_frames)`.
"""
mel_feature = self._melspectrogram(x)
log_mel_feature = power_to_db(
mel_feature,
ref_value=self.ref_value,
amin=self.amin,
top_db=self.top_db)
return log_mel_feature
class MFCC(nn.Layer):
"""Compute mel frequency cepstral coefficients(MFCCs) feature of given waveforms.
Args:
sr (int, optional): Sample rate. Defaults to 22050.
n_mfcc (int, optional): [description]. Defaults to 40.
n_fft (int, optional): The number of frequency components of the discrete Fourier transform. Defaults to 512.
hop_length (Optional[int], optional): The hop length of the short time FFT. If `None`, it is set to `win_length//4`. Defaults to None.
win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None.
window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'.
power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'.
n_mels (int, optional): Number of mel bins. Defaults to 64.
f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0.
f_max (Optional[float], optional): Maximum frequency in Hz. Defaults to None.
htk (bool, optional): Use HTK formula in computing fbank matrix. Defaults to False.
norm (Union[str, float], optional): Type of normalization in computing fbank matrix. Slaney-style is used by default. You can specify norm=1.0/2.0 to use customized p-norm normalization. Defaults to 'slaney'.
ref_value (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0.
amin (float, optional): The minimum value of input magnitude. Defaults to 1e-10.
top_db (Optional[float], optional): The maximum db value of spectrogram. Defaults to None.
dtype (str, optional): Data type of input and window. Defaults to 'float32'.
"""
def __init__(self,
sr: int=22050,
n_mfcc: int=40,
n_fft: int=512,
hop_length: Optional[int]=None,
win_length: Optional[int]=None,
window: str='hann',
power: float=2.0,
center: bool=True,
pad_mode: str='reflect',
n_mels: int=64,
f_min: float=50.0,
f_max: Optional[float]=None,
htk: bool=False,
norm: Union[str, float]='slaney',
ref_value: float=1.0,
amin: float=1e-10,
top_db: Optional[float]=None,
dtype: str=paddle.float32) -> None:
super(MFCC, self).__init__()
assert n_mfcc <= n_mels, 'n_mfcc cannot be larger than n_mels: %d vs %d' % (
n_mfcc, n_mels)
self._log_melspectrogram = LogMelSpectrogram(
sr=sr,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
window=window,
power=power,
center=center,
pad_mode=pad_mode,
n_mels=n_mels,
f_min=f_min,
f_max=f_max,
htk=htk,
norm=norm,
ref_value=ref_value,
amin=amin,
top_db=top_db,
dtype=dtype)
self.dct_matrix = create_dct(n_mfcc=n_mfcc, n_mels=n_mels, dtype=dtype)
self.register_buffer('dct_matrix', self.dct_matrix)
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x (Tensor): Tensor of waveforms with shape `(N, T)`
Returns:
Tensor: Mel frequency cepstral coefficients with shape `(N, n_mfcc, num_frames)`.
"""
log_mel_feature = self._log_melspectrogram(x)
mfcc = paddle.matmul(
log_mel_feature.transpose((0, 2, 1)), self.dct_matrix).transpose(
(0, 2, 1)) # (B, n_mels, L)
return mfcc
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .functional import compute_fbank_matrix
from .functional import create_dct
from .functional import fft_frequencies
from .functional import hz_to_mel
from .functional import mel_frequencies
from .functional import mel_to_hz
from .functional import power_to_db
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from librosa(https://github.com/librosa/librosa)
import math
from typing import Optional
from typing import Union
import paddle
from paddle import Tensor
__all__ = [
'hz_to_mel',
'mel_to_hz',
'mel_frequencies',
'fft_frequencies',
'compute_fbank_matrix',
'power_to_db',
'create_dct',
]
def hz_to_mel(freq: Union[Tensor, float],
htk: bool=False) -> Union[Tensor, float]:
"""Convert Hz to Mels.
Args:
freq (Union[Tensor, float]): The input tensor with arbitrary shape.
htk (bool, optional): Use htk scaling. Defaults to False.
Returns:
Union[Tensor, float]: Frequency in mels.
"""
if htk:
if isinstance(freq, Tensor):
return 2595.0 * paddle.log10(1.0 + freq / 700.0)
else:
return 2595.0 * math.log10(1.0 + freq / 700.0)
# Fill in the linear part
f_min = 0.0
f_sp = 200.0 / 3
mels = (freq - f_min) / f_sp
# Fill in the log-scale part
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = math.log(6.4) / 27.0 # step size for log region
if isinstance(freq, Tensor):
target = min_log_mel + paddle.log(
freq / min_log_hz + 1e-10) / logstep # prevent nan with 1e-10
mask = (freq > min_log_hz).astype(freq.dtype)
mels = target * mask + mels * (
1 - mask) # will replace by masked_fill OP in future
else:
if freq >= min_log_hz:
mels = min_log_mel + math.log(freq / min_log_hz + 1e-10) / logstep
return mels
def mel_to_hz(mel: Union[float, Tensor],
htk: bool=False) -> Union[float, Tensor]:
"""Convert mel bin numbers to frequencies.
Args:
mel (Union[float, Tensor]): The mel frequency represented as a tensor with arbitrary shape.
htk (bool, optional): Use htk scaling. Defaults to False.
Returns:
Union[float, Tensor]: Frequencies in Hz.
"""
if htk:
return 700.0 * (10.0**(mel / 2595.0) - 1.0)
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mel
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = math.log(6.4) / 27.0 # step size for log region
if isinstance(mel, Tensor):
target = min_log_hz * paddle.exp(logstep * (mel - min_log_mel))
mask = (mel > min_log_mel).astype(mel.dtype)
freqs = target * mask + freqs * (
1 - mask) # will replace by masked_fill OP in future
else:
if mel >= min_log_mel:
freqs = min_log_hz * math.exp(logstep * (mel - min_log_mel))
return freqs
def mel_frequencies(n_mels: int=64,
f_min: float=0.0,
f_max: float=11025.0,
htk: bool=False,
dtype: str='float32') -> Tensor:
"""Compute mel frequencies.
Args:
n_mels (int, optional): Number of mel bins. Defaults to 64.
f_min (float, optional): Minimum frequency in Hz. Defaults to 0.0.
fmax (float, optional): Maximum frequency in Hz. Defaults to 11025.0.
htk (bool, optional): Use htk scaling. Defaults to False.
dtype (str, optional): The data type of the return frequencies. Defaults to 'float32'.
Returns:
Tensor: Tensor of n_mels frequencies in Hz with shape `(n_mels,)`.
"""
# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = hz_to_mel(f_min, htk=htk)
max_mel = hz_to_mel(f_max, htk=htk)
mels = paddle.linspace(min_mel, max_mel, n_mels, dtype=dtype)
freqs = mel_to_hz(mels, htk=htk)
return freqs
def fft_frequencies(sr: int, n_fft: int, dtype: str='float32') -> Tensor:
"""Compute fourier frequencies.
Args:
sr (int): Sample rate.
n_fft (int): Number of fft bins.
dtype (str, optional): The data type of the return frequencies. Defaults to 'float32'.
Returns:
Tensor: FFT frequencies in Hz with shape `(n_fft//2 + 1,)`.
"""
return paddle.linspace(0, float(sr) / 2, int(1 + n_fft // 2), dtype=dtype)
def compute_fbank_matrix(sr: int,
n_fft: int,
n_mels: int=64,
f_min: float=0.0,
f_max: Optional[float]=None,
htk: bool=False,
norm: Union[str, float]='slaney',
dtype: str='float32') -> Tensor:
"""Compute fbank matrix.
Args:
sr (int): Sample rate.
n_fft (int): Number of fft bins.
n_mels (int, optional): Number of mel bins. Defaults to 64.
f_min (float, optional): Minimum frequency in Hz. Defaults to 0.0.
f_max (Optional[float], optional): Maximum frequency in Hz. Defaults to None.
htk (bool, optional): Use htk scaling. Defaults to False.
norm (Union[str, float], optional): Type of normalization. Defaults to 'slaney'.
dtype (str, optional): The data type of the return matrix. Defaults to 'float32'.
Returns:
Tensor: Mel transform matrix with shape `(n_mels, n_fft//2 + 1)`.
"""
if f_max is None:
f_max = float(sr) / 2
# Initialize the weights
weights = paddle.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
# Center freqs of each FFT bin
fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft, dtype=dtype)
# 'Center freqs' of mel bands - uniformly spaced between limits
mel_f = mel_frequencies(
n_mels + 2, f_min=f_min, f_max=f_max, htk=htk, dtype=dtype)
fdiff = mel_f[1:] - mel_f[:-1] #np.diff(mel_f)
ramps = mel_f.unsqueeze(1) - fftfreqs.unsqueeze(0)
#ramps = np.subtract.outer(mel_f, fftfreqs)
for i in range(n_mels):
# lower and upper slopes for all bins
lower = -ramps[i] / fdiff[i]
upper = ramps[i + 2] / fdiff[i + 1]
# .. then intersect them with each other and zero
weights[i] = paddle.maximum(
paddle.zeros_like(lower), paddle.minimum(lower, upper))
# Slaney-style mel is scaled to be approx constant energy per channel
if norm == 'slaney':
enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels])
weights *= enorm.unsqueeze(1)
elif isinstance(norm, int) or isinstance(norm, float):
weights = paddle.nn.functional.normalize(weights, p=norm, axis=-1)
return weights
def power_to_db(spect: Tensor,
ref_value: float=1.0,
amin: float=1e-10,
top_db: Optional[float]=None) -> Tensor:
"""Convert a power spectrogram (amplitude squared) to decibel (dB) units. The function computes the scaling `10 * log10(x / ref)` in a numerically stable way.
Args:
spect (Tensor): STFT power spectrogram.
ref_value (float, optional): The reference value. If smaller than 1.0, the db level of the signal will be pulled up accordingly. Otherwise, the db level is pushed down. Defaults to 1.0.
amin (float, optional): Minimum threshold. Defaults to 1e-10.
top_db (Optional[float], optional): Threshold the output at `top_db` below the peak. Defaults to None.
Returns:
Tensor: Power spectrogram in db scale.
"""
if amin <= 0:
raise Exception("amin must be strictly positive")
if ref_value <= 0:
raise Exception("ref_value must be strictly positive")
ones = paddle.ones_like(spect)
log_spec = 10.0 * paddle.log10(paddle.maximum(ones * amin, spect))
log_spec -= 10.0 * math.log10(max(ref_value, amin))
if top_db is not None:
if top_db < 0:
raise Exception("top_db must be non-negative")
log_spec = paddle.maximum(log_spec, ones * (log_spec.max() - top_db))
return log_spec
def create_dct(n_mfcc: int,
n_mels: int,
norm: Optional[str]='ortho',
dtype: str='float32') -> Tensor:
"""Create a discrete cosine transform(DCT) matrix.
Args:
n_mfcc (int): Number of mel frequency cepstral coefficients.
n_mels (int): Number of mel filterbanks.
norm (Optional[str], optional): Normalizaiton type. Defaults to 'ortho'.
dtype (str, optional): The data type of the return matrix. Defaults to 'float32'.
Returns:
Tensor: The DCT matrix with shape `(n_mels, n_mfcc)`.
"""
n = paddle.arange(n_mels, dtype=dtype)
k = paddle.arange(n_mfcc, dtype=dtype).unsqueeze(1)
dct = paddle.cos(math.pi / float(n_mels) * (n + 0.5) *
k) # size (n_mfcc, n_mels)
if norm is None:
dct *= 2.0
else:
assert norm == "ortho"
dct[0] *= 1.0 / math.sqrt(2.0)
dct *= math.sqrt(2.0 / float(n_mels))
return dct.T
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import math
from typing import List
from typing import Tuple
from typing import Union
import paddle
from paddle import Tensor
__all__ = [
'get_window',
]
def _cat(x: List[Tensor], data_type: str) -> Tensor:
l = [paddle.to_tensor(_, data_type) for _ in x]
return paddle.concat(l)
def _acosh(x: Union[Tensor, float]) -> Tensor:
if isinstance(x, float):
return math.log(x + math.sqrt(x**2 - 1))
return paddle.log(x + paddle.sqrt(paddle.square(x) - 1))
def _extend(M: int, sym: bool) -> bool:
"""Extend window by 1 sample if needed for DFT-even symmetry. """
if not sym:
return M + 1, True
else:
return M, False
def _len_guards(M: int) -> bool:
"""Handle small or incorrect window lengths. """
if int(M) != M or M < 0:
raise ValueError('Window length M must be a non-negative integer')
return M <= 1
def _truncate(w: Tensor, needed: bool) -> Tensor:
"""Truncate window by 1 sample if needed for DFT-even symmetry. """
if needed:
return w[:-1]
else:
return w
def _general_gaussian(M: int, p, sig, sym: bool=True,
dtype: str='float64') -> Tensor:
"""Compute a window with a generalized Gaussian shape.
This function is consistent with scipy.signal.windows.general_gaussian().
"""
if _len_guards(M):
return paddle.ones((M, ), dtype=dtype)
M, needs_trunc = _extend(M, sym)
n = paddle.arange(0, M, dtype=dtype) - (M - 1.0) / 2.0
w = paddle.exp(-0.5 * paddle.abs(n / sig)**(2 * p))
return _truncate(w, needs_trunc)
def _general_cosine(M: int, a: float, sym: bool=True,
dtype: str='float64') -> Tensor:
"""Compute a generic weighted sum of cosine terms window.
This function is consistent with scipy.signal.windows.general_cosine().
"""
if _len_guards(M):
return paddle.ones((M, ), dtype=dtype)
M, needs_trunc = _extend(M, sym)
fac = paddle.linspace(-math.pi, math.pi, M, dtype=dtype)
w = paddle.zeros((M, ), dtype=dtype)
for k in range(len(a)):
w += a[k] * paddle.cos(k * fac)
return _truncate(w, needs_trunc)
def _general_hamming(M: int, alpha: float, sym: bool=True,
dtype: str='float64') -> Tensor:
"""Compute a generalized Hamming window.
This function is consistent with scipy.signal.windows.general_hamming()
"""
return _general_cosine(M, [alpha, 1. - alpha], sym, dtype=dtype)
def _taylor(M: int,
nbar=4,
sll=30,
norm=True,
sym: bool=True,
dtype: str='float64') -> Tensor:
"""Compute a Taylor window.
The Taylor window taper function approximates the Dolph-Chebyshev window's
constant sidelobe level for a parameterized number of near-in sidelobes.
"""
if _len_guards(M):
return paddle.ones((M, ), dtype=dtype)
M, needs_trunc = _extend(M, sym)
# Original text uses a negative sidelobe level parameter and then negates
# it in the calculation of B. To keep consistent with other methods we
# assume the sidelobe level parameter to be positive.
B = 10**(sll / 20)
A = _acosh(B) / math.pi
s2 = nbar**2 / (A**2 + (nbar - 0.5)**2)
ma = paddle.arange(1, nbar, dtype=dtype)
Fm = paddle.empty((nbar - 1, ), dtype=dtype)
signs = paddle.empty_like(ma)
signs[::2] = 1
signs[1::2] = -1
m2 = ma * ma
for mi in range(len(ma)):
numer = signs[mi] * paddle.prod(1 - m2[mi] / s2 / (A**2 + (ma - 0.5)**2
))
if mi == 0:
denom = 2 * paddle.prod(1 - m2[mi] / m2[mi + 1:])
elif mi == len(ma) - 1:
denom = 2 * paddle.prod(1 - m2[mi] / m2[:mi])
else:
denom = 2 * paddle.prod(1 - m2[mi] / m2[:mi]) * paddle.prod(1 - m2[
mi] / m2[mi + 1:])
Fm[mi] = numer / denom
def W(n):
return 1 + 2 * paddle.matmul(
Fm.unsqueeze(0),
paddle.cos(2 * math.pi * ma.unsqueeze(1) * (n - M / 2. + 0.5) / M))
w = W(paddle.arange(0, M, dtype=dtype))
# normalize (Note that this is not described in the original text [1])
if norm:
scale = 1.0 / W((M - 1) / 2)
w *= scale
w = w.squeeze()
return _truncate(w, needs_trunc)
def _hamming(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
"""Compute a Hamming window.
The Hamming window is a taper formed by using a raised cosine with
non-zero endpoints, optimized to minimize the nearest side lobe.
"""
return _general_hamming(M, 0.54, sym, dtype=dtype)
def _hann(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
"""Compute a Hann window.
The Hann window is a taper formed by using a raised cosine or sine-squared
with ends that touch zero.
"""
return _general_hamming(M, 0.5, sym, dtype=dtype)
def _tukey(M: int, alpha=0.5, sym: bool=True, dtype: str='float64') -> Tensor:
"""Compute a Tukey window.
The Tukey window is also known as a tapered cosine window.
"""
if _len_guards(M):
return paddle.ones((M, ), dtype=dtype)
if alpha <= 0:
return paddle.ones((M, ), dtype=dtype)
elif alpha >= 1.0:
return hann(M, sym=sym)
M, needs_trunc = _extend(M, sym)
n = paddle.arange(0, M, dtype=dtype)
width = int(alpha * (M - 1) / 2.0)
n1 = n[0:width + 1]
n2 = n[width + 1:M - width - 1]
n3 = n[M - width - 1:]
w1 = 0.5 * (1 + paddle.cos(math.pi * (-1 + 2.0 * n1 / alpha / (M - 1))))
w2 = paddle.ones(n2.shape, dtype=dtype)
w3 = 0.5 * (1 + paddle.cos(math.pi * (-2.0 / alpha + 1 + 2.0 * n3 / alpha /
(M - 1))))
w = paddle.concat([w1, w2, w3])
return _truncate(w, needs_trunc)
def _kaiser(M: int, beta: float, sym: bool=True,
dtype: str='float64') -> Tensor:
"""Compute a Kaiser window.
The Kaiser window is a taper formed by using a Bessel function.
"""
raise NotImplementedError()
def _gaussian(M: int, std: float, sym: bool=True,
dtype: str='float64') -> Tensor:
"""Compute a Gaussian window.
The Gaussian widows has a Gaussian shape defined by the standard deviation(std).
"""
if _len_guards(M):
return paddle.ones((M, ), dtype=dtype)
M, needs_trunc = _extend(M, sym)
n = paddle.arange(0, M, dtype=dtype) - (M - 1.0) / 2.0
sig2 = 2 * std * std
w = paddle.exp(-n**2 / sig2)
return _truncate(w, needs_trunc)
def _exponential(M: int,
center=None,
tau=1.,
sym: bool=True,
dtype: str='float64') -> Tensor:
"""Compute an exponential (or Poisson) window. """
if sym and center is not None:
raise ValueError("If sym==True, center must be None.")
if _len_guards(M):
return paddle.ones((M, ), dtype=dtype)
M, needs_trunc = _extend(M, sym)
if center is None:
center = (M - 1) / 2
n = paddle.arange(0, M, dtype=dtype)
w = paddle.exp(-paddle.abs(n - center) / tau)
return _truncate(w, needs_trunc)
def _triang(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
"""Compute a triangular window.
"""
if _len_guards(M):
return paddle.ones((M, ), dtype=dtype)
M, needs_trunc = _extend(M, sym)
n = paddle.arange(1, (M + 1) // 2 + 1, dtype=dtype)
if M % 2 == 0:
w = (2 * n - 1.0) / M
w = paddle.concat([w, w[::-1]])
else:
w = 2 * n / (M + 1.0)
w = paddle.concat([w, w[-2::-1]])
return _truncate(w, needs_trunc)
def _bohman(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
"""Compute a Bohman window.
The Bohman window is the autocorrelation of a cosine window.
"""
if _len_guards(M):
return paddle.ones((M, ), dtype=dtype)
M, needs_trunc = _extend(M, sym)
fac = paddle.abs(paddle.linspace(-1, 1, M, dtype=dtype)[1:-1])
w = (1 - fac) * paddle.cos(math.pi * fac) + 1.0 / math.pi * paddle.sin(
math.pi * fac)
w = _cat([0, w, 0], dtype)
return _truncate(w, needs_trunc)
def _blackman(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
"""Compute a Blackman window.
The Blackman window is a taper formed by using the first three terms of
a summation of cosines. It was designed to have close to the minimal
leakage possible. It is close to optimal, only slightly worse than a
Kaiser window.
"""
return _general_cosine(M, [0.42, 0.50, 0.08], sym, dtype=dtype)
def _cosine(M: int, sym: bool=True, dtype: str='float64') -> Tensor:
"""Compute a window with a simple cosine shape.
"""
if _len_guards(M):
return paddle.ones((M, ), dtype=dtype)
M, needs_trunc = _extend(M, sym)
w = paddle.sin(math.pi / M * (paddle.arange(0, M, dtype=dtype) + .5))
return _truncate(w, needs_trunc)
def get_window(window: Union[str, Tuple[str, float]],
win_length: int,
fftbins: bool=True,
dtype: str='float64') -> Tensor:
"""Return a window of a given length and type.
Args:
window (Union[str, Tuple[str, float]]): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'.
win_length (int): Number of samples.
fftbins (bool, optional): If True, create a "periodic" window. Otherwise, create a "symmetric" window, for use in filter design. Defaults to True.
dtype (str, optional): The data type of the return window. Defaults to 'float64'.
Returns:
Tensor: The window represented as a tensor.
"""
sym = not fftbins
args = ()
if isinstance(window, tuple):
winstr = window[0]
if len(window) > 1:
args = window[1:]
elif isinstance(window, str):
if window in ['gaussian', 'exponential']:
raise ValueError("The '" + window + "' window needs one or "
"more parameters -- pass a tuple.")
else:
winstr = window
else:
raise ValueError("%s as window type is not supported." %
str(type(window)))
try:
winfunc = eval('_' + winstr)
except KeyError as e:
raise ValueError("Unknown window type.") from e
params = (win_length, ) + args
kwargs = {'sym': sym}
return winfunc(*params, dtype=dtype, **kwargs)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .eer import compute_eer
from .eer import compute_minDCF
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import numpy as np
import paddle
from sklearn.metrics import roc_curve
def compute_eer(labels: np.ndarray, scores: np.ndarray) -> List[float]:
"""Compute EER and return score threshold.
Args:
labels (np.ndarray): the trial label, shape: [N], one-dimention, N refer to the samples num
scores (np.ndarray): the trial scores, shape: [N], one-dimention, N refer to the samples num
Returns:
List[float]: eer and the specific threshold
"""
fpr, tpr, threshold = roc_curve(y_true=labels, y_score=scores)
fnr = 1 - tpr
eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))]
eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
return eer, eer_threshold
def compute_minDCF(positive_scores,
negative_scores,
c_miss=1.0,
c_fa=1.0,
p_target=0.01):
"""
This is modified from SpeechBrain
https://github.com/speechbrain/speechbrain/blob/085be635c07f16d42cd1295045bc46c407f1e15b/speechbrain/utils/metric_stats.py#L509
Computes the minDCF metric normally used to evaluate speaker verification
systems. The min_DCF is the minimum of the following C_det function computed
within the defined threshold range:
C_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 -p_target)
where p_miss is the missing probability and p_fa is the probability of having
a false alarm.
Args:
positive_scores (Paddle.Tensor): The scores from entries of the same class.
negative_scores (Paddle.Tensor): The scores from entries of different classes.
c_miss (float, optional): Cost assigned to a missing error (default 1.0).
c_fa (float, optional): Cost assigned to a false alarm (default 1.0).
p_target (float, optional): Prior probability of having a target (default 0.01).
Returns:
List[float]: min dcf and the specific threshold
"""
# Computing candidate thresholds
if len(positive_scores.shape) > 1:
positive_scores = positive_scores.squeeze()
if len(negative_scores.shape) > 1:
negative_scores = negative_scores.squeeze()
thresholds = paddle.sort(paddle.concat([positive_scores, negative_scores]))
thresholds = paddle.unique(thresholds)
# Adding intermediate thresholds
interm_thresholds = (thresholds[0:-1] + thresholds[1:]) / 2
thresholds = paddle.sort(paddle.concat([thresholds, interm_thresholds]))
# Computing False Rejection Rate (miss detection)
positive_scores = paddle.concat(
len(thresholds) * [positive_scores.unsqueeze(0)])
pos_scores_threshold = positive_scores.transpose(perm=[1, 0]) <= thresholds
p_miss = (pos_scores_threshold.sum(0)
).astype("float32") / positive_scores.shape[1]
del positive_scores
del pos_scores_threshold
# Computing False Acceptance Rate (false alarm)
negative_scores = paddle.concat(
len(thresholds) * [negative_scores.unsqueeze(0)])
neg_scores_threshold = negative_scores.transpose(perm=[1, 0]) > thresholds
p_fa = (neg_scores_threshold.sum(0)
).astype("float32") / negative_scores.shape[1]
del negative_scores
del neg_scores_threshold
c_det = c_miss * p_miss * p_target + c_fa * p_fa * (1 - p_target)
c_min = paddle.min(c_det, axis=0)
min_index = paddle.argmin(c_det, axis=0)
return float(c_min), float(thresholds[min_index])
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .download import decompress
from .download import download_and_decompress
from .download import load_state_dict_from_url
from .env import DATA_HOME
from .env import MODEL_HOME
from .env import PPAUDIO_HOME
from .env import USER_HOME
from .error import ParameterError
from .log import Logger
from .log import logger
from .time import seconds_to_hms
from .time import Timer
from .numeric import depth_convert
from .numeric import pcm16to32
此差异已折叠。
此差异已折叠。
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['ParameterError']
class ParameterError(Exception):
"""Exception class for Parameter checking"""
pass
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
此差异已折叠。
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册