"""Contains data generator for organizing various audio data preprocessing
pipeline and offering data reader interface of PaddlePaddle requirements.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import tarfile
import multiprocessing
import numpy as np
import paddle.v2 as paddle
from threading import local
from data_utils.utility import read_manifest
from data_utils.utility import xmap_readers_mp
from data_utils.augmentor.augmentation import AugmentationPipeline
from data_utils.featurizer.speech_featurizer import SpeechFeaturizer
from data_utils.speech import SpeechSegment
from data_utils.normalizer import FeatureNormalizer

class DataGenerator(object):
    """
    DataGenerator provides basic audio data preprocessing pipeline, and offers
    data reader interfaces of PaddlePaddle requirements.

    :param vocab_filepath: Vocabulary filepath for indexing tokenized
                           transcripts.
    :type vocab_filepath: basestring
    :param mean_std_filepath: File containing the pre-computed mean and stddev.
    :type mean_std_filepath: None|basestring
    :param augmentation_config: Augmentation configuration in json string.
                                Details see AugmentationPipeline.__doc__.
    :type augmentation_config: str
    :param max_duration: Audio with duration (in seconds) greater than
                         this will be discarded.
    :type max_duration: float
    :param min_duration: Audio with duration (in seconds) smaller than
                         this will be discarded.
    :type min_duration: float
    :param stride_ms: Striding size (in milliseconds) for generating frames.
    :type stride_ms: float
    :param window_ms: Window size (in milliseconds) for generating frames.
    :type window_ms: float
    :param max_freq: Used when specgram_type is 'linear', only FFT bins
                     corresponding to frequencies between [0, max_freq] are
                     returned.
    :type max_freq: None|float
    :param specgram_type: Specgram feature type. Options: 'linear'.
    :type specgram_type: str
    :param use_dB_normalization: Whether to normalize the audio to -20 dB
                                 before extracting the features.
    :type use_dB_normalization: bool
    :param num_threads: Number of CPU threads for processing data.
    :type num_threads: int
    :param random_seed: Random seed.
    :type random_seed: int
    :param keep_transcription_text: If set to True, transcription text will
                                    be passed forward directly without
                                    converting to index sequence.
    :type keep_transcription_text: bool
    """

    def __init__(self,
                 vocab_filepath,
                 mean_std_filepath,
                 augmentation_config='{}',
                 max_duration=float('inf'),
                 min_duration=0.0,
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 specgram_type='linear',
                 use_dB_normalization=True,
                 num_threads=multiprocessing.cpu_count() // 2,
                 random_seed=0,
                 keep_transcription_text=False):
        self._max_duration = max_duration
        self._min_duration = min_duration
        self._normalizer = FeatureNormalizer(mean_std_filepath)
        self._augmentation_pipeline = AugmentationPipeline(
            augmentation_config=augmentation_config, random_seed=random_seed)
        self._speech_featurizer = SpeechFeaturizer(
            vocab_filepath=vocab_filepath,
            specgram_type=specgram_type,
            stride_ms=stride_ms,
            window_ms=window_ms,
            max_freq=max_freq,
            use_dB_normalization=use_dB_normalization)
        self._num_threads = num_threads
        self._rng = random.Random(random_seed)
        self._keep_transcription_text = keep_transcription_text
        # Epoch counter, used to disable shuffling on the sortagrad epoch.
        self._epoch = 0
        # Thread-local caches for tar files (see _subfile_from_tar), so each
        # reader thread keeps its own open tar handles and member indexes.
        self._local_data = local()
        self._local_data.tar2info = {}
        self._local_data.tar2object = {}

    def process_utterance(self, audio_file, transcript):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of audio file.
        :type audio_file: basestring | file
        :param transcript: Transcription text.
        :type transcript: basestring
        :return: Tuple of audio feature tensor and data of transcription part,
                 where transcription part could be token ids or text.
        :rtype: tuple of (2darray, list)
        """
        # 'tar:/path/archive.tar#member' paths are read out of a cached tar.
        if isinstance(audio_file, basestring) and audio_file.startswith('tar:'):
            speech_segment = SpeechSegment.from_file(
                self._subfile_from_tar(audio_file), transcript)
        else:
            speech_segment = SpeechSegment.from_file(audio_file, transcript)
        self._augmentation_pipeline.transform_audio(speech_segment)
        specgram, transcript_part = self._speech_featurizer.featurize(
            speech_segment, self._keep_transcription_text)
        specgram = self._normalizer.apply(specgram)
        return specgram, transcript_part

    def batch_reader_creator(self,
                             manifest_path,
                             batch_size,
                             min_batch_size=1,
                             padding_to=-1,
                             flatten=False,
                             sortagrad=False,
                             shuffle_method="batch_shuffle"):
        """
        Batch data reader creator for audio data. Return a callable generator
        function to produce batches of data.

        Audio features within one batch will be padded with zeros to have the
        same shape, or a user-defined shape.

        :param manifest_path: Filepath of manifest for audio files.
        :type manifest_path: basestring
        :param batch_size: Number of instances in a batch.
        :type batch_size: int
        :param min_batch_size: Any batch with batch size smaller than this will
                               be discarded. (To be deprecated in the future.)
        :type min_batch_size: int
        :param padding_to:  If set -1, the maximum shape in the batch
                            will be used as the target shape for padding.
                            Otherwise, `padding_to` will be the target shape.
        :type padding_to: int
        :param flatten: If set True, audio features will be flatten to 1darray.
        :type flatten: bool
        :param sortagrad: If set True, sort the instances by audio duration
                          in the first epoch for speed up training.
        :type sortagrad: bool
        :param shuffle_method: Shuffle method. Options:
                                '' or None: no shuffle.
                                'instance_shuffle': instance-wise shuffle.
                                'batch_shuffle': similarly-sized instances are
                                                 put into batches, and then
                                                 batch-wise shuffle the batches.
                                                 For more details, please see
                                                 ``_batch_shuffle.__doc__``.
                                'batch_shuffle_clipped': 'batch_shuffle' with
                                                         head shift and tail
                                                         clipping. For more
                                                         details, please see
                                                         ``_batch_shuffle``.
                              If sortagrad is True, shuffle is disabled
                              for the first epoch.
        :type shuffle_method: None|str
        :return: Batch reader function, producing batches of data when called.
        :rtype: callable
        """

        def batch_reader():
            # read manifest
            manifest = read_manifest(
                manifest_path=manifest_path,
                max_duration=self._max_duration,
                min_duration=self._min_duration)
            # sort (by duration) or batch-wise shuffle the manifest
            if self._epoch == 0 and sortagrad:
                manifest.sort(key=lambda x: x["duration"])
            else:
                if shuffle_method == "batch_shuffle":
                    manifest = self._batch_shuffle(
                        manifest, batch_size, clipped=False)
                elif shuffle_method == "batch_shuffle_clipped":
                    manifest = self._batch_shuffle(
                        manifest, batch_size, clipped=True)
                elif shuffle_method == "instance_shuffle":
                    self._rng.shuffle(manifest)
                elif not shuffle_method:
                    # '' or None: no shuffle, as documented above.
                    pass
                else:
                    raise ValueError("Unknown shuffle method %s." %
                                     shuffle_method)
            # prepare batches
            instance_reader, cleanup = self._instance_reader_creator(manifest)
            batch = []
            try:
                for instance in instance_reader():
                    batch.append(instance)
                    if len(batch) == batch_size:
                        yield self._padding_batch(batch, padding_to, flatten)
                        batch = []
                # Trailing partial batch is kept unless it is too small.
                if len(batch) >= min_batch_size:
                    yield self._padding_batch(batch, padding_to, flatten)
            finally:
                # Always stop the worker processes, even on early exit.
                cleanup()
            self._epoch += 1

        return batch_reader

    @property
    def feeding(self):
        """Returns data reader's feeding dict.

        :return: Data feeding dict.
        :rtype: dict
        """
        feeding_dict = {"audio_spectrogram": 0, "transcript_text": 1}
        return feeding_dict

    @property
    def vocab_size(self):
        """Return the vocabulary size.

        :return: Vocabulary size.
        :rtype: int
        """
        return self._speech_featurizer.vocab_size

    @property
    def vocab_list(self):
        """Return the vocabulary in list.

        :return: Vocabulary in list.
        :rtype: list
        """
        return self._speech_featurizer.vocab_list

    def _parse_tar(self, file):
        """Parse a tar file to get a tarfile object
        and a map containing tarinfoes
        """
        result = {}
        f = tarfile.open(file)
        for tarinfo in f.getmembers():
            result[tarinfo.name] = tarinfo
        return f, result

    def _subfile_from_tar(self, file):
        """Get subfile object from tar.

        It will return a subfile object from tar file
        and cached tar file info for next reading request.

        Expects ``file`` in the form 'tar:/path/to/archive.tar#member_name'.
        """
        tarpath, filename = file.split(':', 1)[1].split('#', 1)
        # The thread-local may have been created by another thread; make sure
        # this thread's caches exist before using them.
        if 'tar2info' not in self._local_data.__dict__:
            self._local_data.tar2info = {}
        if 'tar2object' not in self._local_data.__dict__:
            self._local_data.tar2object = {}
        if tarpath not in self._local_data.tar2info:
            # Renamed from `object` to avoid shadowing the builtin.
            tar_object, infoes = self._parse_tar(tarpath)
            self._local_data.tar2info[tarpath] = infoes
            self._local_data.tar2object[tarpath] = tar_object
        return self._local_data.tar2object[tarpath].extractfile(
            self._local_data.tar2info[tarpath][filename])

    def _instance_reader_creator(self, manifest):
        """
        Instance reader creator. Create a callable function to produce
        instances of data.

        Instance: a tuple of ndarray of audio spectrogram and a list of
        token indices for transcript.
        """

        def reader():
            for instance in manifest:
                yield instance

        # Fan the featurization out over a pool of worker processes; the
        # returned cleanup callback terminates the workers.
        reader, cleanup_callback = xmap_readers_mp(
            lambda instance: self.process_utterance(instance["audio_filepath"], instance["text"]),
            reader, self._num_threads, 4096)

        return reader, cleanup_callback

    def _padding_batch(self, batch, padding_to=-1, flatten=False):
        """
        Padding audio features with zeros to make them have the same shape (or
        a user-defined shape) within one batch.

        If ``padding_to`` is -1, the maximum shape in the batch will be used
        as the target shape for padding. Otherwise, `padding_to` will be the
        target shape (only refers to the second axis).

        If `flatten` is True, features will be flatten to 1darray.

        :raises ValueError: If ``padding_to`` is smaller than the longest
                            instance in the batch.
        """
        new_batch = []
        # get target shape
        max_length = max([audio.shape[1] for audio, text in batch])
        if padding_to != -1:
            if padding_to < max_length:
                raise ValueError("If padding_to is not -1, it should be larger "
                                 "than any instance's shape in the batch")
            max_length = padding_to
        # padding
        for audio, text in batch:
            padded_audio = np.zeros([audio.shape[0], max_length])
            padded_audio[:, :audio.shape[1]] = audio
            if flatten:
                padded_audio = padded_audio.flatten()
            # Keep the original (unpadded) length so consumers can mask.
            padded_instance = [padded_audio, text, audio.shape[1]]
            new_batch.append(padded_instance)
        return new_batch

    def _batch_shuffle(self, manifest, batch_size, clipped=False):
        """Put similarly-sized instances into minibatches for better efficiency
        and make a batch-wise shuffle.

        1. Sort the audio clips by duration.
        2. Generate a random number `k`, k in [0, batch_size).
        3. Randomly shift `k` instances in order to create different batches
           for different epochs. Create minibatches.
        4. Shuffle the minibatches.

        :param manifest: Manifest contents. List of dict.
        :type manifest: list
        :param batch_size: Batch size. This size is also used for generate
                           a random number for batch shuffle.
        :type batch_size: int
        :param clipped: Whether to clip the heading (small shift) and trailing
                        (incomplete batch) instances.
        :type clipped: bool
        :return: Batch shuffled manifest.
        :rtype: list
        """
        manifest.sort(key=lambda x: x["duration"])
        shift_len = self._rng.randint(0, batch_size - 1)
        # list() so the result is shuffleable on both py2 and py3 (py3 zip
        # returns an iterator, which random.shuffle cannot handle).
        batch_manifest = list(zip(*[iter(manifest[shift_len:])] * batch_size))
        self._rng.shuffle(batch_manifest)
        batch_manifest = [item for batch in batch_manifest for item in batch]
        if not clipped:
            res_len = len(manifest) - shift_len - len(batch_manifest)
            # Guard res_len == 0: manifest[-0:] would be the WHOLE manifest,
            # duplicating every instance when the shifted length divides
            # evenly into batches.
            if res_len > 0:
                batch_manifest.extend(manifest[-res_len:])
            batch_manifest.extend(manifest[0:shift_len])
        return batch_manifest