提交 a9c817cd 编写于 作者: X Xinghai Sun 提交者: GitHub

Merge pull request #111 from xinghai-sun/ds2_threads

Add multi-threading support for DS2 data generator.
......@@ -7,6 +7,7 @@ from __future__ import print_function
import random
import numpy as np
import multiprocessing
import paddle.v2 as paddle
from data_utils import utils
from data_utils.augmentor.augmentation import AugmentationPipeline
......@@ -44,6 +45,8 @@ class DataGenerator(object):
:types max_freq: None|float
:param specgram_type: Specgram feature type. Options: 'linear'.
:type specgram_type: str
:param num_threads: Number of CPU threads for processing data.
:type num_threads: int
:param random_seed: Random seed.
:type random_seed: int
"""
......@@ -58,6 +61,7 @@ class DataGenerator(object):
window_ms=20.0,
max_freq=None,
specgram_type='linear',
num_threads=multiprocessing.cpu_count(),
random_seed=0):
self._max_duration = max_duration
self._min_duration = min_duration
......@@ -70,6 +74,7 @@ class DataGenerator(object):
stride_ms=stride_ms,
window_ms=window_ms,
max_freq=max_freq)
self._num_threads = num_threads
self._rng = random.Random(random_seed)
self._epoch = 0
......@@ -207,10 +212,14 @@ class DataGenerator(object):
def reader():
for instance in manifest:
yield self._process_utterance(instance["audio_filepath"],
yield instance
def mapper(instance):
return self._process_utterance(instance["audio_filepath"],
instance["text"])
return reader
return paddle.reader.xmap_readers(
mapper, reader, self._num_threads, 1024, order=True)
def _padding_batch(self, batch, padding_to=-1, flatten=False):
"""
......
......@@ -94,7 +94,7 @@ class SpeechSegment(AudioSegment):
return cls(samples, sample_rate, transcripts)
@classmethod
def slice_from_file(cls, filepath, start=None, end=None, transcript):
def slice_from_file(cls, filepath, transcript, start=None, end=None):
"""Loads a small section of an speech without having to load
the entire file into the memory which can be incredibly wasteful.
......
......@@ -6,6 +6,7 @@ from __future__ import print_function
import argparse
import gzip
import distutils.util
import multiprocessing
import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import deep_speech2
......@@ -38,6 +39,11 @@ parser.add_argument(
default=True,
type=distutils.util.strtobool,
help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
"--num_threads_data",
default=multiprocessing.cpu_count(),
type=int,
help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument(
"--mean_std_filepath",
default='mean_std.npz',
......@@ -67,7 +73,8 @@ def infer():
data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath,
mean_std_filepath=args.mean_std_filepath,
augmentation_config='{}')
augmentation_config='{}',
num_threads=args.num_threads_data)
# create network config
# paddle.data_type.dense_array is used for variable batch input.
......
......@@ -9,6 +9,7 @@ import argparse
import gzip
import time
import distutils.util
import multiprocessing
import paddle.v2 as paddle
from model import deep_speech2
from data_utils.data import DataGenerator
......@@ -52,6 +53,18 @@ parser.add_argument(
default=True,
type=distutils.util.strtobool,
help="Use sortagrad or not. (default: %(default)s)")
parser.add_argument(
"--max_duration",
default=100.0,
type=float,
help="Audios with duration larger than this will be discarded. "
"(default: %(default)s)")
parser.add_argument(
"--min_duration",
default=0.0,
type=float,
help="Audios with duration smaller than this will be discarded. "
"(default: %(default)s)")
parser.add_argument(
"--shuffle_method",
default='instance_shuffle',
......@@ -63,6 +76,11 @@ parser.add_argument(
default=4,
type=int,
help="Trainer number. (default: %(default)s)")
parser.add_argument(
"--num_threads_data",
default=multiprocessing.cpu_count(),
type=int,
help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument(
"--mean_std_filepath",
default='mean_std.npz',
......@@ -107,7 +125,10 @@ def train():
return DataGenerator(
vocab_filepath=args.vocab_filepath,
mean_std_filepath=args.mean_std_filepath,
augmentation_config=args.augmentation_config)
augmentation_config=args.augmentation_config,
max_duration=args.max_duration,
min_duration=args.min_duration,
num_threads=args.num_threads_data)
train_generator = data_generator()
test_generator = data_generator()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册