未验证 提交 85094bce 编写于 作者: Y YangZhou 提交者: GitHub

[Cherry-pick][audio] fix tess split fold (#47350)

* fix tess split fold

* format
上级 12e6dfcf
......@@ -133,22 +133,27 @@ class ESC50(AudioClassificationDataset):
meta = os.path.join('ESC-50-master', 'meta', 'esc50.csv')
meta_info = collections.namedtuple(
'META_INFO',
('filename', 'fold', 'target', 'category', 'esc10', 'src_file', 'take'))
('filename', 'fold', 'target', 'category', 'esc10', 'src_file', 'take'),
)
audio_path = os.path.join('ESC-50-master', 'audio')
def __init__(self,
mode: str = 'train',
split: int = 1,
feat_type: str = 'raw',
archive=None,
**kwargs):
def __init__(
self,
mode: str = 'train',
split: int = 1,
feat_type: str = 'raw',
archive=None,
**kwargs,
):
assert split in range(
1, 6
), f'The selected split should be integer, and 1 <= split <= 5, but got {split}'
if archive is not None:
self.archive = archive
files, labels = self._get_data(mode, split)
super(ESC50, self).__init__(files=files,
labels=labels,
feat_type=feat_type,
**kwargs)
super(ESC50, self).__init__(
files=files, labels=labels, feat_type=feat_type, **kwargs
)
def _get_meta_info(self) -> List[collections.namedtuple]:
ret = []
......@@ -158,12 +163,15 @@ class ESC50(AudioClassificationDataset):
return ret
def _get_data(self, mode: str, split: int) -> Tuple[List[str], List[int]]:
if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)) or \
not os.path.isfile(os.path.join(DATA_HOME, self.meta)):
download.get_path_from_url(self.archive['url'],
DATA_HOME,
self.archive['md5'],
decompress=True)
if not os.path.isdir(
os.path.join(DATA_HOME, self.audio_path)
) or not os.path.isfile(os.path.join(DATA_HOME, self.meta)):
download.get_path_from_url(
self.archive['url'],
DATA_HOME,
self.archive['md5'],
decompress=True,
)
meta_info = self._get_meta_info()
......
......@@ -71,8 +71,7 @@ class TESS(AudioClassificationDataset):
"""
archive = {
'url':
'https://bj.bcebos.com/paddleaudio/datasets/TESS_Toronto_emotional_speech_set.zip',
'url': 'https://bj.bcebos.com/paddleaudio/datasets/TESS_Toronto_emotional_speech_set.zip',
'md5': '1465311b24d1de704c4c63e4ccc470c7',
}
......@@ -85,28 +84,32 @@ class TESS(AudioClassificationDataset):
'ps', # pleasant surprise
'sad',
]
meta_info = collections.namedtuple('META_INFO',
('speaker', 'word', 'emotion'))
meta_info = collections.namedtuple(
'META_INFO', ('speaker', 'word', 'emotion')
)
audio_path = 'TESS_Toronto_emotional_speech_set'
def __init__(self,
mode='train',
n_folds=5,
split=1,
feat_type='raw',
archive=None,
**kwargs):
"""
"""
assert split <= n_folds, f'The selected split should not be larger than n_fold, but got {split} > {n_folds}'
def __init__(
self,
mode: str = 'train',
n_folds: int = 5,
split: int = 1,
feat_type: str = 'raw',
archive=None,
**kwargs,
):
assert isinstance(n_folds, int) and (
n_folds >= 1
), f'the n_folds should be integer and n_folds >= 1, but got {n_folds}'
assert split in range(
1, n_folds + 1
), f'The selected split should be integer and should be 1 <= split <= {n_folds}, but got {split}'
if archive is not None:
self.archive = archive
files, labels = self._get_data(mode, n_folds, split)
super(TESS, self).__init__(files=files,
labels=labels,
feat_type=feat_type,
**kwargs)
super(TESS, self).__init__(
files=files, labels=labels, feat_type=feat_type, **kwargs
)
def _get_meta_info(self, files) -> List[collections.namedtuple]:
ret = []
......@@ -115,12 +118,16 @@ class TESS(AudioClassificationDataset):
ret.append(self.meta_info(*basename_without_extend.split('_')))
return ret
def _get_data(self, mode, n_folds, split) -> Tuple[List[str], List[int]]:
def _get_data(
self, mode: str, n_folds: int, split: int
) -> Tuple[List[str], List[int]]:
if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)):
download.get_path_from_url(self.archive['url'],
DATA_HOME,
self.archive['md5'],
decompress=True)
download.get_path_from_url(
self.archive['url'],
DATA_HOME,
self.archive['md5'],
decompress=True,
)
wav_files = []
for root, _, files in os.walk(os.path.join(DATA_HOME, self.audio_path)):
......@@ -132,11 +139,10 @@ class TESS(AudioClassificationDataset):
files = []
labels = []
n_samples_per_fold = len(meta_info) // n_folds
for idx, sample in enumerate(meta_info):
_, _, emotion = sample
target = self.label_list.index(emotion)
fold = idx // n_samples_per_fold + 1
fold = idx % n_folds + 1
if mode == 'train' and int(fold) != split:
files.append(wav_files[idx])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册