未验证 提交 c98af923 编写于 作者: Y YangZhou 提交者: GitHub

[audio]fix split fold in tess dataset (#47328)

* fix split fold in tess dataset

* check split and n_folds

* fix tess assert log

* fix format

* format
上级 c334405f
...@@ -143,8 +143,11 @@ class ESC50(AudioClassificationDataset): ...@@ -143,8 +143,11 @@ class ESC50(AudioClassificationDataset):
split: int = 1, split: int = 1,
feat_type: str = 'raw', feat_type: str = 'raw',
archive=None, archive=None,
**kwargs **kwargs,
): ):
assert split in range(
1, 6
), f'The selected split should be integer, and 1 <= split <= 5, but got {split}'
if archive is not None: if archive is not None:
self.archive = archive self.archive = archive
files, labels = self._get_data(mode, split) files, labels = self._get_data(mode, split)
......
...@@ -91,17 +91,19 @@ class TESS(AudioClassificationDataset): ...@@ -91,17 +91,19 @@ class TESS(AudioClassificationDataset):
def __init__( def __init__(
self, self,
mode='train', mode: str = 'train',
n_folds=5, n_folds: int = 5,
split=1, split: int = 1,
feat_type='raw', feat_type: str = 'raw',
archive=None, archive=None,
**kwargs, **kwargs,
): ):
""" """ assert isinstance(n_folds, int) and (
assert ( n_folds >= 1
split <= n_folds ), f'the n_folds should be integer and n_folds >= 1, but got {n_folds}'
), f'The selected split should not be larger than n_fold, but got {split} > {n_folds}' assert split in range(
1, n_folds + 1
), f'The selected split should be integer and should be 1 <= split <= {n_folds}, but got {split}'
if archive is not None: if archive is not None:
self.archive = archive self.archive = archive
files, labels = self._get_data(mode, n_folds, split) files, labels = self._get_data(mode, n_folds, split)
...@@ -116,7 +118,9 @@ class TESS(AudioClassificationDataset): ...@@ -116,7 +118,9 @@ class TESS(AudioClassificationDataset):
ret.append(self.meta_info(*basename_without_extend.split('_'))) ret.append(self.meta_info(*basename_without_extend.split('_')))
return ret return ret
def _get_data(self, mode, n_folds, split) -> Tuple[List[str], List[int]]: def _get_data(
self, mode: str, n_folds: int, split: int
) -> Tuple[List[str], List[int]]:
if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)): if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)):
download.get_path_from_url( download.get_path_from_url(
self.archive['url'], self.archive['url'],
...@@ -135,11 +139,10 @@ class TESS(AudioClassificationDataset): ...@@ -135,11 +139,10 @@ class TESS(AudioClassificationDataset):
files = [] files = []
labels = [] labels = []
n_samples_per_fold = len(meta_info) // n_folds
for idx, sample in enumerate(meta_info): for idx, sample in enumerate(meta_info):
_, _, emotion = sample _, _, emotion = sample
target = self.label_list.index(emotion) target = self.label_list.index(emotion)
fold = idx // n_samples_per_fold + 1 fold = idx % n_folds + 1
if mode == 'train' and int(fold) != split: if mode == 'train' and int(fold) != split:
files.append(wav_files[idx]) files.append(wav_files[idx])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册