tess.py 5.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import os
from typing import List
from typing import Tuple

from paddle.utils import download
from paddle.dataset.common import DATA_HOME
from .dataset import AudioClassificationDataset

__all__ = ['TESS']


class TESS(AudioClassificationDataset):
    """
    TESS is a set of 200 target words spoken in the carrier phrase
    "Say the word _____" by two actresses (aged 26 and 64 years), with
    recordings of the set made portraying each of seven emotions (anger,
    disgust, fear, happiness, pleasant surprise, sadness, and neutral).
    There are 2800 stimuli in total.

    Reference:
        Toronto emotional speech set (TESS) https://tspace.library.utoronto.ca/handle/1807/24487
        https://doi.org/10.5683/SP2/E8H2MF

    Args:
       mode (str, optional): It identifies the dataset mode (train or dev). Any value other than 'train' selects the dev fold. Defaults to train.
       n_folds (int, optional): Split the dataset into n folds. 1 fold for dev dataset and n-1 for train dataset. Defaults to 5.
       split (int, optional): It specifies the (1-based) fold used as the dev dataset. Defaults to 1.
       feat_type (str, optional): It identifies the feature type that user wants to extract from an audio file. Defaults to raw.
       archive (dict, optional): It tells where to download the audio archive, with keys 'url' and 'md5'. Defaults to None, which uses the class-level archive.

    Returns:
        :ref:`api_paddle_io_Dataset`. An instance of TESS dataset.

    Examples:

        .. code-block:: python

            import paddle

            mode = 'dev'
            tess_dataset = paddle.audio.datasets.TESS(mode=mode,
                                                    feat_type='raw')
            for idx in range(5):
                audio, label = tess_dataset[idx]
                # do something with audio, label
                print(audio.shape, label)
                # [audio_data_length] , label_id

            tess_dataset = paddle.audio.datasets.TESS(mode=mode,
                                                    feat_type='mfcc',
                                                    n_mfcc=40)
            for idx in range(5):
                audio, label = tess_dataset[idx]
                # do something with mfcc feature, label
                print(audio.shape, label)
                # [feature_dim, num_frames] , label_id
    """

    archive = {
        'url': 'https://bj.bcebos.com/paddleaudio/datasets/TESS_Toronto_emotional_speech_set.zip',
        'md5': '1465311b24d1de704c4c63e4ccc470c7',
    }

    # Label ids are the positions in this list; the strings must match the
    # emotion token at the end of each wav file name.
    label_list = [
        'angry',
        'disgust',
        'fear',
        'happy',
        'neutral',
        'ps',  # pleasant surprise
        'sad',
    ]

    # Fields parsed from a file name of the form `speaker_word_emotion.wav`.
    meta_info = collections.namedtuple(
        'META_INFO', ('speaker', 'word', 'emotion')
    )

    # Directory (relative to DATA_HOME) that the archive extracts into.
    audio_path = 'TESS_Toronto_emotional_speech_set'

    def __init__(
        self,
        mode: str = 'train',
        n_folds: int = 5,
        split: int = 1,
        feat_type: str = 'raw',
        archive=None,
        **kwargs,
    ):
        assert isinstance(n_folds, int) and (
            n_folds >= 1
        ), f'the n_folds should be integer and n_folds >= 1, but got {n_folds}'
        assert split in range(
            1, n_folds + 1
        ), f'The selected split should be integer and should be 1 <= split <= {n_folds}, but got {split}'
        if archive is not None:
            # The instance attribute shadows the class-level default archive.
            self.archive = archive
        files, labels = self._get_data(mode, n_folds, split)
        super().__init__(
            files=files, labels=labels, feat_type=feat_type, **kwargs
        )

    def _get_meta_info(self, files) -> List[Tuple[str, str, str]]:
        """Parse each wav path's base name into a ``meta_info`` tuple.

        The file name (without extension) is split on ``'_'``. The pieces
        are passed positionally to ``meta_info``, so a name that does not
        split into exactly three parts raises ``TypeError``.
        """
        return [
            self.meta_info(
                *os.path.splitext(os.path.basename(file))[0].split('_')
            )
            for file in files
        ]

    def _get_data(
        self, mode: str, n_folds: int, split: int
    ) -> Tuple[List[str], List[int]]:
        """Collect the wav files and label ids for the requested fold.

        Downloads and extracts ``self.archive`` into ``DATA_HOME`` if the
        audio directory is missing. Files are assigned round-robin to folds
        ``1..n_folds`` by their position in the sorted file list. For
        ``mode == 'train'``, every fold except ``split`` is returned;
        otherwise only fold ``split`` is returned.

        Returns:
            A ``(files, labels)`` pair of parallel lists.

        Raises:
            ValueError: if a file's emotion token is not in ``label_list``.
        """
        data_dir = os.path.join(DATA_HOME, self.audio_path)
        if not os.path.isdir(data_dir):
            download.get_path_from_url(
                self.archive['url'],
                DATA_HOME,
                self.archive['md5'],
                decompress=True,
            )

        # os.walk yields entries in arbitrary, filesystem-dependent order.
        # Fold membership depends on each file's position, so sort to keep
        # train/dev splits reproducible across machines and runs.
        wav_files = sorted(
            os.path.join(root, file)
            for root, _, filenames in os.walk(data_dir)
            for file in filenames
            if file.endswith('.wav')
        )

        meta_info = self._get_meta_info(wav_files)

        is_train = mode == 'train'
        files = []
        labels = []
        for idx, (wav_file, sample) in enumerate(zip(wav_files, meta_info)):
            # Resolve the label for every file (not only selected ones) so an
            # unknown emotion token is reported regardless of the chosen fold.
            target = self.label_list.index(sample.emotion)
            fold = idx % n_folds + 1
            in_dev_fold = fold == split
            # Train takes every fold except `split`; dev takes only `split`.
            if in_dev_fold != is_train:
                files.append(wav_file)
                labels.append(target)

        return files, labels