From c98af923a3b451b425b0f179425c40ca81b1c9ae Mon Sep 17 00:00:00 2001
From: YangZhou <56786796+SmileGoat@users.noreply.github.com>
Date: Wed, 26 Oct 2022 12:37:46 +0800
Subject: [PATCH] [audio]fix split fold in tess dataset (#47328)

* fix split fold in tess dataset

* check split and n_folds

* fix tess assert log

* fix format

* format
---
 python/paddle/audio/datasets/esc50.py |  5 ++++-
 python/paddle/audio/datasets/tess.py  | 25 ++++++++++++++-----------
 2 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/python/paddle/audio/datasets/esc50.py b/python/paddle/audio/datasets/esc50.py
index 4d106c27129..1d3e0e07baf 100644
--- a/python/paddle/audio/datasets/esc50.py
+++ b/python/paddle/audio/datasets/esc50.py
@@ -143,8 +143,11 @@ class ESC50(AudioClassificationDataset):
         split: int = 1,
         feat_type: str = 'raw',
         archive=None,
-        **kwargs
+        **kwargs,
     ):
+        assert split in range(
+            1, 6
+        ), f'The selected split should be integer, and 1 <= split <= 5, but got {split}'
         if archive is not None:
             self.archive = archive
         files, labels = self._get_data(mode, split)
diff --git a/python/paddle/audio/datasets/tess.py b/python/paddle/audio/datasets/tess.py
index 5cdf8cc65f3..b180d704e68 100644
--- a/python/paddle/audio/datasets/tess.py
+++ b/python/paddle/audio/datasets/tess.py
@@ -91,17 +91,19 @@ class TESS(AudioClassificationDataset):
 
     def __init__(
         self,
-        mode='train',
-        n_folds=5,
-        split=1,
-        feat_type='raw',
+        mode: str = 'train',
+        n_folds: int = 5,
+        split: int = 1,
+        feat_type: str = 'raw',
         archive=None,
         **kwargs,
     ):
-        """ """
-        assert (
-            split <= n_folds
-        ), f'The selected split should not be larger than n_fold, but got {split} > {n_folds}'
+        assert isinstance(n_folds, int) and (
+            n_folds >= 1
+        ), f'the n_folds should be integer and n_folds >= 1, but got {n_folds}'
+        assert split in range(
+            1, n_folds + 1
+        ), f'The selected split should be integer and should be 1 <= split <= {n_folds}, but got {split}'
         if archive is not None:
             self.archive = archive
         files, labels = self._get_data(mode, n_folds, split)
@@ -116,7 +118,9 @@ class TESS(AudioClassificationDataset):
             ret.append(self.meta_info(*basename_without_extend.split('_')))
         return ret
 
-    def _get_data(self, mode, n_folds, split) -> Tuple[List[str], List[int]]:
+    def _get_data(
+        self, mode: str, n_folds: int, split: int
+    ) -> Tuple[List[str], List[int]]:
         if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)):
             download.get_path_from_url(
                 self.archive['url'],
@@ -135,11 +139,10 @@ class TESS(AudioClassificationDataset):
 
         files = []
         labels = []
-        n_samples_per_fold = len(meta_info) // n_folds
         for idx, sample in enumerate(meta_info):
             _, _, emotion = sample
             target = self.label_list.index(emotion)
-            fold = idx // n_samples_per_fold + 1
+            fold = idx % n_folds + 1
 
             if mode == 'train' and int(fold) != split:
                 files.append(wav_files[idx])
-- 
GitLab