未验证 提交 85094bce 编写于 作者: Y YangZhou 提交者: GitHub

[Cherry-pick][audio] fix tess split fold (#47350)

* fix tess split fold

* format
上级 12e6dfcf
...@@ -133,22 +133,27 @@ class ESC50(AudioClassificationDataset): ...@@ -133,22 +133,27 @@ class ESC50(AudioClassificationDataset):
meta = os.path.join('ESC-50-master', 'meta', 'esc50.csv') meta = os.path.join('ESC-50-master', 'meta', 'esc50.csv')
meta_info = collections.namedtuple( meta_info = collections.namedtuple(
'META_INFO', 'META_INFO',
('filename', 'fold', 'target', 'category', 'esc10', 'src_file', 'take')) ('filename', 'fold', 'target', 'category', 'esc10', 'src_file', 'take'),
)
audio_path = os.path.join('ESC-50-master', 'audio') audio_path = os.path.join('ESC-50-master', 'audio')
def __init__(self, def __init__(
self,
mode: str = 'train', mode: str = 'train',
split: int = 1, split: int = 1,
feat_type: str = 'raw', feat_type: str = 'raw',
archive=None, archive=None,
**kwargs): **kwargs,
):
assert split in range(
1, 6
), f'The selected split should be integer, and 1 <= split <= 5, but got {split}'
if archive is not None: if archive is not None:
self.archive = archive self.archive = archive
files, labels = self._get_data(mode, split) files, labels = self._get_data(mode, split)
super(ESC50, self).__init__(files=files, super(ESC50, self).__init__(
labels=labels, files=files, labels=labels, feat_type=feat_type, **kwargs
feat_type=feat_type, )
**kwargs)
def _get_meta_info(self) -> List[collections.namedtuple]: def _get_meta_info(self) -> List[collections.namedtuple]:
ret = [] ret = []
...@@ -158,12 +163,15 @@ class ESC50(AudioClassificationDataset): ...@@ -158,12 +163,15 @@ class ESC50(AudioClassificationDataset):
return ret return ret
def _get_data(self, mode: str, split: int) -> Tuple[List[str], List[int]]: def _get_data(self, mode: str, split: int) -> Tuple[List[str], List[int]]:
if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)) or \ if not os.path.isdir(
not os.path.isfile(os.path.join(DATA_HOME, self.meta)): os.path.join(DATA_HOME, self.audio_path)
download.get_path_from_url(self.archive['url'], ) or not os.path.isfile(os.path.join(DATA_HOME, self.meta)):
download.get_path_from_url(
self.archive['url'],
DATA_HOME, DATA_HOME,
self.archive['md5'], self.archive['md5'],
decompress=True) decompress=True,
)
meta_info = self._get_meta_info() meta_info = self._get_meta_info()
......
...@@ -71,8 +71,7 @@ class TESS(AudioClassificationDataset): ...@@ -71,8 +71,7 @@ class TESS(AudioClassificationDataset):
""" """
archive = { archive = {
'url': 'url': 'https://bj.bcebos.com/paddleaudio/datasets/TESS_Toronto_emotional_speech_set.zip',
'https://bj.bcebos.com/paddleaudio/datasets/TESS_Toronto_emotional_speech_set.zip',
'md5': '1465311b24d1de704c4c63e4ccc470c7', 'md5': '1465311b24d1de704c4c63e4ccc470c7',
} }
...@@ -85,28 +84,32 @@ class TESS(AudioClassificationDataset): ...@@ -85,28 +84,32 @@ class TESS(AudioClassificationDataset):
'ps', # pleasant surprise 'ps', # pleasant surprise
'sad', 'sad',
] ]
meta_info = collections.namedtuple('META_INFO', meta_info = collections.namedtuple(
('speaker', 'word', 'emotion')) 'META_INFO', ('speaker', 'word', 'emotion')
)
audio_path = 'TESS_Toronto_emotional_speech_set' audio_path = 'TESS_Toronto_emotional_speech_set'
def __init__(self, def __init__(
mode='train', self,
n_folds=5, mode: str = 'train',
split=1, n_folds: int = 5,
feat_type='raw', split: int = 1,
feat_type: str = 'raw',
archive=None, archive=None,
**kwargs): **kwargs,
""" ):
assert isinstance(n_folds, int) and (
""" n_folds >= 1
assert split <= n_folds, f'The selected split should not be larger than n_fold, but got {split} > {n_folds}' ), f'the n_folds should be integer and n_folds >= 1, but got {n_folds}'
assert split in range(
1, n_folds + 1
), f'The selected split should be integer and should be 1 <= split <= {n_folds}, but got {split}'
if archive is not None: if archive is not None:
self.archive = archive self.archive = archive
files, labels = self._get_data(mode, n_folds, split) files, labels = self._get_data(mode, n_folds, split)
super(TESS, self).__init__(files=files, super(TESS, self).__init__(
labels=labels, files=files, labels=labels, feat_type=feat_type, **kwargs
feat_type=feat_type, )
**kwargs)
def _get_meta_info(self, files) -> List[collections.namedtuple]: def _get_meta_info(self, files) -> List[collections.namedtuple]:
ret = [] ret = []
...@@ -115,12 +118,16 @@ class TESS(AudioClassificationDataset): ...@@ -115,12 +118,16 @@ class TESS(AudioClassificationDataset):
ret.append(self.meta_info(*basename_without_extend.split('_'))) ret.append(self.meta_info(*basename_without_extend.split('_')))
return ret return ret
def _get_data(self, mode, n_folds, split) -> Tuple[List[str], List[int]]: def _get_data(
self, mode: str, n_folds: int, split: int
) -> Tuple[List[str], List[int]]:
if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)): if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)):
download.get_path_from_url(self.archive['url'], download.get_path_from_url(
self.archive['url'],
DATA_HOME, DATA_HOME,
self.archive['md5'], self.archive['md5'],
decompress=True) decompress=True,
)
wav_files = [] wav_files = []
for root, _, files in os.walk(os.path.join(DATA_HOME, self.audio_path)): for root, _, files in os.walk(os.path.join(DATA_HOME, self.audio_path)):
...@@ -132,11 +139,10 @@ class TESS(AudioClassificationDataset): ...@@ -132,11 +139,10 @@ class TESS(AudioClassificationDataset):
files = [] files = []
labels = [] labels = []
n_samples_per_fold = len(meta_info) // n_folds
for idx, sample in enumerate(meta_info): for idx, sample in enumerate(meta_info):
_, _, emotion = sample _, _, emotion = sample
target = self.label_list.index(emotion) target = self.label_list.index(emotion)
fold = idx // n_samples_per_fold + 1 fold = idx % n_folds + 1
if mode == 'train' and int(fold) != split: if mode == 'train' and int(fold) != split:
files.append(wav_files[idx]) files.append(wav_files[idx])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册