diff --git a/deep_speech_2/data_utils/data.py b/deep_speech_2/data_utils/data.py index 424343a48ffa579a8ab465794987f957de36abdb..44af7ffaa999c618a7dcd4884f528ef60e59eefe 100644 --- a/deep_speech_2/data_utils/data.py +++ b/deep_speech_2/data_utils/data.py @@ -7,6 +7,7 @@ from __future__ import print_function import random import numpy as np +import multiprocessing import paddle.v2 as paddle from data_utils import utils from data_utils.augmentor.augmentation import AugmentationPipeline @@ -44,6 +45,8 @@ class DataGenerator(object): :types max_freq: None|float :param specgram_type: Specgram feature type. Options: 'linear'. :type specgram_type: str + :param num_threads: Number of CPU threads for processing data. + :type num_threads: int :param random_seed: Random seed. :type random_seed: int """ @@ -58,6 +61,7 @@ class DataGenerator(object): window_ms=20.0, max_freq=None, specgram_type='linear', + num_threads=multiprocessing.cpu_count(), random_seed=0): self._max_duration = max_duration self._min_duration = min_duration @@ -70,6 +74,7 @@ class DataGenerator(object): stride_ms=stride_ms, window_ms=window_ms, max_freq=max_freq) + self._num_threads = num_threads self._rng = random.Random(random_seed) self._epoch = 0 @@ -207,10 +212,14 @@ class DataGenerator(object): def reader(): for instance in manifest: - yield self._process_utterance(instance["audio_filepath"], - instance["text"]) + yield instance - return reader + def mapper(instance): + return self._process_utterance(instance["audio_filepath"], + instance["text"]) + + return paddle.reader.xmap_readers( + mapper, reader, self._num_threads, 1024, order=True) def _padding_batch(self, batch, padding_to=-1, flatten=False): """ diff --git a/deep_speech_2/data_utils/speech.py b/deep_speech_2/data_utils/speech.py index fc031ff46f4a3820e3b13f7804c91b33948712d1..568e4443ba557149505dfb4de6f230b4962e332a 100644 --- a/deep_speech_2/data_utils/speech.py +++ b/deep_speech_2/data_utils/speech.py @@ -94,7 +94,7 @@ class SpeechSegment(AudioSegment): return cls(samples, sample_rate, transcripts) @classmethod - def slice_from_file(cls, filepath, start=None, end=None, transcript): + def slice_from_file(cls, filepath, transcript, start=None, end=None): """Loads a small section of an speech without having to load the entire file into the memory which can be incredibly wasteful. diff --git a/deep_speech_2/infer.py b/deep_speech_2/infer.py index 06449ab05c7960ec78acc9ce5bb664cf1058a845..71518133a347c459bbcf2670fa5d1dc226a619c8 100644 --- a/deep_speech_2/infer.py +++ b/deep_speech_2/infer.py @@ -6,6 +6,7 @@ from __future__ import print_function import argparse import gzip import distutils.util +import multiprocessing import paddle.v2 as paddle from data_utils.data import DataGenerator from model import deep_speech2 @@ -38,6 +39,11 @@ parser.add_argument( default=True, type=distutils.util.strtobool, help="Use gpu or not. (default: %(default)s)") +parser.add_argument( + "--num_threads_data", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu threads for preprocessing data. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', @@ -67,7 +73,8 @@ def infer(): data_generator = DataGenerator( vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, - augmentation_config='{}') + augmentation_config='{}', + num_threads=args.num_threads_data) # create network config # paddle.data_type.dense_array is used for variable batch input. diff --git a/deep_speech_2/train.py b/deep_speech_2/train.py index c60a039b69d91a89eb20e83ec1e090c8600d47a3..fc23ec72692f319b556a75004a7508990df5357e 100644 --- a/deep_speech_2/train.py +++ b/deep_speech_2/train.py @@ -9,6 +9,7 @@ import argparse import gzip import time import distutils.util +import multiprocessing import paddle.v2 as paddle from model import deep_speech2 from data_utils.data import DataGenerator @@ -52,6 +53,18 @@ parser.add_argument( default=True, type=distutils.util.strtobool, help="Use sortagrad or not. (default: %(default)s)") +parser.add_argument( + "--max_duration", + default=100.0, + type=float, + help="Audios with duration larger than this will be discarded. " + "(default: %(default)s)") +parser.add_argument( + "--min_duration", + default=0.0, + type=float, + help="Audios with duration smaller than this will be discarded. " + "(default: %(default)s)") parser.add_argument( "--shuffle_method", default='instance_shuffle', @@ -63,6 +76,11 @@ parser.add_argument( default=4, type=int, help="Trainer number. (default: %(default)s)") +parser.add_argument( + "--num_threads_data", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu threads for preprocessing data. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', @@ -107,7 +125,10 @@ def train(): return DataGenerator( vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, - augmentation_config=args.augmentation_config) + augmentation_config=args.augmentation_config, + max_duration=args.max_duration, + min_duration=args.min_duration, + num_threads=args.num_threads_data) train_generator = data_generator() test_generator = data_generator()