未验证 提交 c98af923 编写于 作者: Y YangZhou 提交者: GitHub

[audio] Fix split/fold assignment in the TESS dataset (#47328)

* Fix split/fold assignment in the TESS dataset

* Validate `split` and `n_folds`

* Fix the TESS assertion message

* Fix formatting

* Apply formatting
上级 c334405f
......@@ -143,8 +143,11 @@ class ESC50(AudioClassificationDataset):
split: int = 1,
feat_type: str = 'raw',
archive=None,
**kwargs
**kwargs,
):
assert split in range(
1, 6
), f'The selected split should be integer, and 1 <= split <= 5, but got {split}'
if archive is not None:
self.archive = archive
files, labels = self._get_data(mode, split)
......
......@@ -91,17 +91,19 @@ class TESS(AudioClassificationDataset):
def __init__(
self,
mode='train',
n_folds=5,
split=1,
feat_type='raw',
mode: str = 'train',
n_folds: int = 5,
split: int = 1,
feat_type: str = 'raw',
archive=None,
**kwargs,
):
""" """
assert (
split <= n_folds
), f'The selected split should not be larger than n_fold, but got {split} > {n_folds}'
assert isinstance(n_folds, int) and (
n_folds >= 1
), f'the n_folds should be integer and n_folds >= 1, but got {n_folds}'
assert split in range(
1, n_folds + 1
), f'The selected split should be integer and should be 1 <= split <= {n_folds}, but got {split}'
if archive is not None:
self.archive = archive
files, labels = self._get_data(mode, n_folds, split)
......@@ -116,7 +118,9 @@ class TESS(AudioClassificationDataset):
ret.append(self.meta_info(*basename_without_extend.split('_')))
return ret
def _get_data(self, mode, n_folds, split) -> Tuple[List[str], List[int]]:
def _get_data(
self, mode: str, n_folds: int, split: int
) -> Tuple[List[str], List[int]]:
if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)):
download.get_path_from_url(
self.archive['url'],
......@@ -135,11 +139,10 @@ class TESS(AudioClassificationDataset):
files = []
labels = []
n_samples_per_fold = len(meta_info) // n_folds
for idx, sample in enumerate(meta_info):
_, _, emotion = sample
target = self.label_list.index(emotion)
fold = idx // n_samples_per_fold + 1
fold = idx % n_folds + 1
if mode == 'train' and int(fold) != split:
files.append(wav_files[idx])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册