"""Contains the data generator that organizes the audio data preprocessing
pipeline and offers the data reader interface required by PaddlePaddle.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
9
import tarfile
10
import multiprocessing
11
import numpy as np
12
import paddle.v2 as paddle
13
from threading import local
14 15 16
from data_utils import utils
from data_utils.augmentor.augmentation import AugmentationPipeline
from data_utils.featurizer.speech_featurizer import SpeechFeaturizer
17
from data_utils.speech import SpeechSegment
18 19
from data_utils.normalizer import FeatureNormalizer

W
wanghaoshuang 已提交
20 21 22 23 24
# for caching tar files info
local_data = local()
local_data.tar2info = {}
local_data.tar2object = {}

25 26 27 28

class DataGenerator(object):
    """
    DataGenerator provides basic audio data preprocessing pipeline, and offers
    data reader interfaces of PaddlePaddle requirements.

    :param vocab_filepath: Vocabulary filepath for indexing tokenized
                           transcripts.
    :type vocab_filepath: basestring
    :param mean_std_filepath: File containing the pre-computed mean and stddev.
    :type mean_std_filepath: None|basestring
    :param augmentation_config: Augmentation configuration in json string.
                                Details see AugmentationPipeline.__doc__.
    :type augmentation_config: str
    :param max_duration: Audio with duration (in seconds) greater than
                         this will be discarded.
    :type max_duration: float
    :param min_duration: Audio with duration (in seconds) smaller than
                         this will be discarded.
    :type min_duration: float
    :param stride_ms: Striding size (in milliseconds) for generating frames.
    :type stride_ms: float
    :param window_ms: Window size (in milliseconds) for generating frames.
    :type window_ms: float
    :param max_freq: Used when specgram_type is 'linear', only FFT bins
                     corresponding to frequencies between [0, max_freq] are
                     returned.
    :type max_freq: None|float
    :param specgram_type: Specgram feature type. Options: 'linear'.
    :type specgram_type: str
    :param use_dB_normalization: Whether to normalize the audio to -20 dB
                                before extracting the features.
    :type use_dB_normalization: bool
    :param num_threads: Number of CPU threads for processing data.
    :type num_threads: int
    :param random_seed: Random seed.
    :type random_seed: int
    """

    def __init__(self,
                 vocab_filepath,
                 mean_std_filepath,
                 augmentation_config='{}',
                 max_duration=float('inf'),
                 min_duration=0.0,
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 specgram_type='linear',
                 use_dB_normalization=True,
                 num_threads=multiprocessing.cpu_count(),
                 random_seed=0):
        self._max_duration = max_duration
        self._min_duration = min_duration
        self._normalizer = FeatureNormalizer(mean_std_filepath)
        self._augmentation_pipeline = AugmentationPipeline(
            augmentation_config=augmentation_config, random_seed=random_seed)
        self._speech_featurizer = SpeechFeaturizer(
            vocab_filepath=vocab_filepath,
            specgram_type=specgram_type,
            stride_ms=stride_ms,
            window_ms=window_ms,
            max_freq=max_freq,
            use_dB_normalization=use_dB_normalization)
        self._num_threads = num_threads
        # Private RNG so that shuffling is reproducible from `random_seed`.
        self._rng = random.Random(random_seed)
        # Number of fully consumed passes of a batch reader; used to apply
        # `sortagrad` to the first epoch only.
        self._epoch = 0
        # for caching tar files info
        # NOTE(review): these two dicts are not read by any method of this
        # class; tar caching goes through the module-level `local_data`.
        # They are kept so the instance attributes stay available.
        self.tar2info = {}
        self.tar2object = {}

    def batch_reader_creator(self,
                             manifest_path,
                             batch_size,
                             min_batch_size=1,
                             padding_to=-1,
                             flatten=False,
                             sortagrad=False,
                             shuffle_method="batch_shuffle"):
        """
        Batch data reader creator for audio data. Return a callable generator
        function to produce batches of data.

        Audio features within one batch will be padded with zeros to have the
        same shape, or a user-defined shape.

        :param manifest_path: Filepath of manifest for audio files.
        :type manifest_path: basestring
        :param batch_size: Number of instances in a batch.
        :type batch_size: int
        :param min_batch_size: Any batch with batch size smaller than this will
                               be discarded. (To be deprecated in the future.)
        :type min_batch_size: int
        :param padding_to:  If set -1, the maximum shape in the batch
                            will be used as the target shape for padding.
                            Otherwise, `padding_to` will be the target shape.
        :type padding_to: int
        :param flatten: If set True, audio features will be flatten to 1darray.
        :type flatten: bool
        :param sortagrad: If set True, sort the instances by audio duration
                          in the first epoch for speed up training.
        :type sortagrad: bool
        :param shuffle_method: Shuffle method. Options:
                                '' or None: no shuffle.
                                'instance_shuffle': instance-wise shuffle.
                                'batch_shuffle': similarly-sized instances are
                                                 put into batches, and then
                                                 batch-wise shuffle the batches.
                                                 For more details, please see
                                                 ``_batch_shuffle.__doc__``.
                                'batch_shuffle_clipped': 'batch_shuffle' with
                                                         head shift and tail
                                                         clipping. For more
                                                         details, please see
                                                         ``_batch_shuffle``.
                              If sortagrad is True, shuffle is disabled
                              for the first epoch.
        :type shuffle_method: None|str
        :return: Batch reader function, producing batches of data when called.
        :rtype: callable
        :raises ValueError: Raised while iterating the returned reader if
                            `shuffle_method` is not one of the options above
                            (not checked in a sortagrad first epoch).
        """

        def batch_reader():
            # read manifest
            manifest = utils.read_manifest(
                manifest_path=manifest_path,
                max_duration=self._max_duration,
                min_duration=self._min_duration)
            # sort (by duration) or batch-wise shuffle the manifest
            if self._epoch == 0 and sortagrad:
                manifest.sort(key=lambda x: x["duration"])
            else:
                if shuffle_method == "batch_shuffle":
                    manifest = self._batch_shuffle(
                        manifest, batch_size, clipped=False)
                elif shuffle_method == "batch_shuffle_clipped":
                    manifest = self._batch_shuffle(
                        manifest, batch_size, clipped=True)
                elif shuffle_method == "instance_shuffle":
                    self._rng.shuffle(manifest)
                elif not shuffle_method:
                    pass
                else:
                    raise ValueError("Unknown shuffle method %s." %
                                     shuffle_method)
            # prepare batches
            instance_reader = self._instance_reader_creator(manifest)
            batch = []
            for instance in instance_reader():
                batch.append(instance)
                if len(batch) == batch_size:
                    yield self._padding_batch(batch, padding_to, flatten)
                    batch = []
            # trailing, incomplete batch
            if len(batch) >= min_batch_size:
                yield self._padding_batch(batch, padding_to, flatten)
            # Only reached when the reader has been iterated to the end.
            self._epoch += 1

        return batch_reader

    @property
    def feeding(self):
        """Returns data reader's feeding dict.

        :return: Data feeding dict.
        :rtype: dict
        """
        return {"audio_spectrogram": 0, "transcript_text": 1}

    @property
    def vocab_size(self):
        """Return the vocabulary size.

        :return: Vocabulary size.
        :rtype: int
        """
        return self._speech_featurizer.vocab_size

    @property
    def vocab_list(self):
        """Return the vocabulary in list.

        :return: Vocabulary in list.
        :rtype: list
        """
        return self._speech_featurizer.vocab_list

    def _parse_tar(self, file):
        """
        Parse a tar file to get a tarfile object and a map from member name
        to its TarInfo.

        The returned tarfile object is left open so that members can be
        extracted from it later; the caller owns it.

        :param file: Path of the tar file.
        :return: Tuple of (opened tarfile object, dict of name -> TarInfo).
        :rtype: tuple
        """
        result = {}
        f = tarfile.open(file)
        for tarinfo in f.getmembers():
            result[tarinfo.name] = tarinfo
        return f, result

    def _read_soundbytes(self, filepath):
        """
        Read bytes from file.

        If `filepath` starts with 'tar:', it is expected to look like
        ``tar:<tar path>#<member name>``; the bytes are read from that member
        and the parsed tar file is cached (per thread) for the next reading
        request. Otherwise `filepath` is read as a regular file.

        :param filepath: Plain file path or ``tar:<tar path>#<member name>``.
        :type filepath: basestring
        :return: Raw content of the audio file.
        """
        if filepath.startswith('tar:'):
            tarpath, filename = filepath.split(':', 1)[1].split('#', 1)
            # `local_data` is thread-local: a thread other than the one that
            # imported the module starts without these attributes.
            if 'tar2info' not in local_data.__dict__:
                local_data.tar2info = {}
            if 'tar2object' not in local_data.__dict__:
                local_data.tar2object = {}
            if tarpath not in local_data.tar2info:
                tar_object, infoes = self._parse_tar(tarpath)
                local_data.tar2info[tarpath] = infoes
                local_data.tar2object[tarpath] = tar_object
            return local_data.tar2object[tarpath].extractfile(
                local_data.tar2info[tarpath][filename]).read()
        # Binary mode keeps the audio bytes untouched on every platform, and
        # the context manager closes the handle deterministically.
        with open(filepath, 'rb') as sound_file:
            return sound_file.read()

    def _process_utterance(self, filename, transcript):
        """Load, augment, featurize and normalize for speech data.

        :param filename: Audio location, as accepted by `_read_soundbytes`.
        :param transcript: Transcript text of the audio.
        :return: Tuple of (normalized spectrogram, token indices).
        :rtype: tuple
        """
        speech_segment = SpeechSegment.from_bytes(
            self._read_soundbytes(filename), transcript)
        # The return value is not used; the segment is presumably augmented
        # in place (confirm against AugmentationPipeline).
        self._augmentation_pipeline.transform_audio(speech_segment)
        specgram, text_ids = self._speech_featurizer.featurize(speech_segment)
        specgram = self._normalizer.apply(specgram)
        return specgram, text_ids

    def _instance_reader_creator(self, manifest):
        """
        Instance reader creator. Create a callable function to produce
        instances of data.

        Instance: a tuple of ndarray of audio spectrogram and a list of
        token indices for transcript.

        :param manifest: Manifest contents. List of dict with at least the
                         keys "audio_filepath" and "text".
        :type manifest: list
        :return: Reader built by ``paddle.reader.xmap_readers``.
        :rtype: callable
        """

        def reader():
            for instance in manifest:
                yield instance

        def mapper(instance):
            return self._process_utterance(instance["audio_filepath"],
                                           instance["text"])

        # Arguments: mapper, reader, self._num_threads, 1024, order=True.
        # NOTE(review): presumably worker count, buffer size and "keep the
        # manifest order" -- confirm against the xmap_readers documentation.
        return paddle.reader.xmap_readers(
            mapper, reader, self._num_threads, 1024, order=True)

    def _padding_batch(self, batch, padding_to=-1, flatten=False):
        """
        Padding audio features with zeros to make them have the same shape (or
        a user-defined shape) within one batch.

        If ``padding_to`` is -1, the maximum shape in the batch will be used
        as the target shape for padding. Otherwise, `padding_to` will be the
        target shape (only refers to the second axis).

        If `flatten` is True, features will be flatten to 1darray.

        :param batch: List of (audio feature 2-D ndarray, text) tuples.
        :type batch: list
        :param padding_to: Target length of the second axis, or -1.
        :type padding_to: int
        :param flatten: Whether to flatten every padded feature.
        :type flatten: bool
        :return: List of (padded audio feature, text) tuples.
        :rtype: list
        :raises ValueError: If `padding_to` is not -1 and is smaller than the
                            longest instance in the batch (also raised by
                            ``max`` for an empty batch).
        """
        new_batch = []
        # get target shape
        max_length = max([audio.shape[1] for audio, text in batch])
        if padding_to != -1:
            if padding_to < max_length:
                raise ValueError("If padding_to is not -1, it should be larger "
                                 "than any instance's shape in the batch")
            max_length = padding_to
        # padding
        for audio, text in batch:
            padded_audio = np.zeros([audio.shape[0], max_length])
            padded_audio[:, :audio.shape[1]] = audio
            if flatten:
                padded_audio = padded_audio.flatten()
            new_batch.append((padded_audio, text))
        return new_batch

    def _batch_shuffle(self, manifest, batch_size, clipped=False):
        """Put similarly-sized instances into minibatches for better efficiency
        and make a batch-wise shuffle.

        1. Sort the audio clips by duration.
        2. Generate a random number `k`, k in [0, batch_size).
        3. Randomly shift `k` instances in order to create different batches
           for different epochs. Create minibatches.
        4. Shuffle the minibatches.

        Note that `manifest` is sorted in place.

        :param manifest: Manifest contents. List of dict.
        :type manifest: list
        :param batch_size: Batch size. This size is also used for generate
                           a random number for batch shuffle.
        :type batch_size: int
        :param clipped: Whether to clip the heading (small shift) and trailing
                        (incomplete batch) instances.
        :type clipped: bool
        :return: Batch shuffled manifest.
        :rtype: list
        """
        manifest.sort(key=lambda x: x["duration"])
        shift_len = self._rng.randint(0, batch_size - 1)
        # Group the shifted manifest into complete batches; zip() drops the
        # trailing incomplete group. list() is needed because zip() returns a
        # lazy iterator on Python 3, which cannot be shuffled.
        batches = list(zip(*[iter(manifest[shift_len:])] * batch_size))
        self._rng.shuffle(batches)
        # Linear-time flattening of the shuffled batches.
        batch_manifest = [
            instance for batch in batches for instance in batch
        ]
        if not clipped:
            res_len = len(manifest) - shift_len - len(batch_manifest)
            # Guard is required: with res_len == 0, manifest[-0:] is the whole
            # manifest and every instance would be duplicated. res_len is
            # negative when shift_len exceeds the manifest length; in that
            # case the head slice below already covers every instance.
            if res_len > 0:
                batch_manifest.extend(manifest[-res_len:])
            batch_manifest.extend(manifest[0:shift_len])
        return batch_manifest