From 85094bce453729e5241eb7badecd63672600ce87 Mon Sep 17 00:00:00 2001 From: YangZhou <56786796+SmileGoat@users.noreply.github.com> Date: Wed, 26 Oct 2022 18:12:55 +0800 Subject: [PATCH] [Cherry-pick][audio] fix tess split fold (#47350) * fix tess split fold * format --- python/paddle/audio/datasets/esc50.py | 42 +++++++++++-------- python/paddle/audio/datasets/tess.py | 58 +++++++++++++++------------ 2 files changed, 57 insertions(+), 43 deletions(-) diff --git a/python/paddle/audio/datasets/esc50.py b/python/paddle/audio/datasets/esc50.py index f702fe518fa..1d3e0e07baf 100644 --- a/python/paddle/audio/datasets/esc50.py +++ b/python/paddle/audio/datasets/esc50.py @@ -133,22 +133,27 @@ class ESC50(AudioClassificationDataset): meta = os.path.join('ESC-50-master', 'meta', 'esc50.csv') meta_info = collections.namedtuple( 'META_INFO', - ('filename', 'fold', 'target', 'category', 'esc10', 'src_file', 'take')) + ('filename', 'fold', 'target', 'category', 'esc10', 'src_file', 'take'), + ) audio_path = os.path.join('ESC-50-master', 'audio') - def __init__(self, - mode: str = 'train', - split: int = 1, - feat_type: str = 'raw', - archive=None, - **kwargs): + def __init__( + self, + mode: str = 'train', + split: int = 1, + feat_type: str = 'raw', + archive=None, + **kwargs, + ): + assert split in range( + 1, 6 + ), f'The selected split should be integer, and 1 <= split <= 5, but got {split}' if archive is not None: self.archive = archive files, labels = self._get_data(mode, split) - super(ESC50, self).__init__(files=files, - labels=labels, - feat_type=feat_type, - **kwargs) + super(ESC50, self).__init__( + files=files, labels=labels, feat_type=feat_type, **kwargs + ) def _get_meta_info(self) -> List[collections.namedtuple]: ret = [] @@ -158,12 +163,15 @@ class ESC50(AudioClassificationDataset): return ret def _get_data(self, mode: str, split: int) -> Tuple[List[str], List[int]]: - if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)) or \ - not os.path.isfile(os.path.join(DATA_HOME, self.meta)): - download.get_path_from_url(self.archive['url'], - DATA_HOME, - self.archive['md5'], - decompress=True) + if not os.path.isdir( + os.path.join(DATA_HOME, self.audio_path) + ) or not os.path.isfile(os.path.join(DATA_HOME, self.meta)): + download.get_path_from_url( + self.archive['url'], + DATA_HOME, + self.archive['md5'], + decompress=True, + ) meta_info = self._get_meta_info() diff --git a/python/paddle/audio/datasets/tess.py b/python/paddle/audio/datasets/tess.py index 0f375aa2b01..b180d704e68 100644 --- a/python/paddle/audio/datasets/tess.py +++ b/python/paddle/audio/datasets/tess.py @@ -71,8 +71,7 @@ class TESS(AudioClassificationDataset): """ archive = { - 'url': - 'https://bj.bcebos.com/paddleaudio/datasets/TESS_Toronto_emotional_speech_set.zip', + 'url': 'https://bj.bcebos.com/paddleaudio/datasets/TESS_Toronto_emotional_speech_set.zip', 'md5': '1465311b24d1de704c4c63e4ccc470c7', } @@ -85,28 +84,32 @@ class TESS(AudioClassificationDataset): 'ps', # pleasant surprise 'sad', ] - meta_info = collections.namedtuple('META_INFO', - ('speaker', 'word', 'emotion')) + meta_info = collections.namedtuple( + 'META_INFO', ('speaker', 'word', 'emotion') + ) audio_path = 'TESS_Toronto_emotional_speech_set' - def __init__(self, - mode='train', - n_folds=5, - split=1, - feat_type='raw', - archive=None, - **kwargs): - """ - - """ - assert split <= n_folds, f'The selected split should not be larger than n_fold, but got {split} > {n_folds}' + def __init__( + self, + mode: str = 'train', + n_folds: int = 5, + split: int = 1, + feat_type: str = 'raw', + archive=None, + **kwargs, + ): + assert isinstance(n_folds, int) and ( + n_folds >= 1 + ), f'the n_folds should be integer and n_folds >= 1, but got {n_folds}' + assert split in range( + 1, n_folds + 1 + ), f'The selected split should be integer and should be 1 <= split <= {n_folds}, but got {split}' if archive is not None: self.archive = archive files, labels = self._get_data(mode, n_folds, split) - super(TESS, self).__init__(files=files, - labels=labels, - feat_type=feat_type, - **kwargs) + super(TESS, self).__init__( + files=files, labels=labels, feat_type=feat_type, **kwargs + ) def _get_meta_info(self, files) -> List[collections.namedtuple]: ret = [] @@ -115,12 +118,16 @@ class TESS(AudioClassificationDataset): ret.append(self.meta_info(*basename_without_extend.split('_'))) return ret - def _get_data(self, mode, n_folds, split) -> Tuple[List[str], List[int]]: + def _get_data( + self, mode: str, n_folds: int, split: int + ) -> Tuple[List[str], List[int]]: if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)): - download.get_path_from_url(self.archive['url'], - DATA_HOME, - self.archive['md5'], - decompress=True) + download.get_path_from_url( + self.archive['url'], + DATA_HOME, + self.archive['md5'], + decompress=True, + ) wav_files = [] for root, _, files in os.walk(os.path.join(DATA_HOME, self.audio_path)): @@ -132,11 +139,10 @@ class TESS(AudioClassificationDataset): files = [] labels = [] - n_samples_per_fold = len(meta_info) // n_folds for idx, sample in enumerate(meta_info): _, _, emotion = sample target = self.label_list.index(emotion) - fold = idx // n_samples_per_fold + 1 + fold = idx % n_folds + 1 if mode == 'train' and int(fold) != split: files.append(wav_files[idx]) -- GitLab