提交 1d8cc4a5 编写于 作者: X Xinghai Sun

Add multi-threading support for DS2 data generator.

上级 a5dcd23b
...@@ -44,6 +44,8 @@ class DataGenerator(object): ...@@ -44,6 +44,8 @@ class DataGenerator(object):
:types max_freq: None|float :types max_freq: None|float
:param specgram_type: Specgram feature type. Options: 'linear'. :param specgram_type: Specgram feature type. Options: 'linear'.
:type specgram_type: str :type specgram_type: str
:param num_threads: Number of CPU threads for processing data.
:type num_threads: int
:param random_seed: Random seed. :param random_seed: Random seed.
:type random_seed: int :type random_seed: int
""" """
...@@ -58,6 +60,7 @@ class DataGenerator(object): ...@@ -58,6 +60,7 @@ class DataGenerator(object):
window_ms=20.0, window_ms=20.0,
max_freq=None, max_freq=None,
specgram_type='linear', specgram_type='linear',
num_threads=12,
random_seed=0): random_seed=0):
self._max_duration = max_duration self._max_duration = max_duration
self._min_duration = min_duration self._min_duration = min_duration
...@@ -70,6 +73,7 @@ class DataGenerator(object): ...@@ -70,6 +73,7 @@ class DataGenerator(object):
stride_ms=stride_ms, stride_ms=stride_ms,
window_ms=window_ms, window_ms=window_ms,
max_freq=max_freq) max_freq=max_freq)
self._num_threads = num_threads
self._rng = random.Random(random_seed) self._rng = random.Random(random_seed)
self._epoch = 0 self._epoch = 0
...@@ -207,10 +211,14 @@ class DataGenerator(object): ...@@ -207,10 +211,14 @@ class DataGenerator(object):
def reader(): def reader():
for instance in manifest: for instance in manifest:
yield self._process_utterance(instance["audio_filepath"], yield instance
instance["text"])
return reader def mapper(instance):
return self._process_utterance(instance["audio_filepath"],
instance["text"])
return paddle.reader.xmap_readers(
mapper, reader, self._num_threads, 1024, order=True)
def _padding_batch(self, batch, padding_to=-1, flatten=False): def _padding_batch(self, batch, padding_to=-1, flatten=False):
""" """
......
...@@ -94,7 +94,7 @@ class SpeechSegment(AudioSegment): ...@@ -94,7 +94,7 @@ class SpeechSegment(AudioSegment):
return cls(samples, sample_rate, transcripts) return cls(samples, sample_rate, transcripts)
@classmethod @classmethod
def slice_from_file(cls, filepath, start=None, end=None, transcript): def slice_from_file(cls, filepath, transcript, start=None, end=None):
"""Loads a small section of an speech without having to load """Loads a small section of an speech without having to load
the entire file into the memory which can be incredibly wasteful. the entire file into the memory which can be incredibly wasteful.
......
...@@ -38,6 +38,11 @@ parser.add_argument( ...@@ -38,6 +38,11 @@ parser.add_argument(
default=True, default=True,
type=distutils.util.strtobool, type=distutils.util.strtobool,
help="Use gpu or not. (default: %(default)s)") help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
"--num_threads_data",
default=12,
type=int,
help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--mean_std_filepath", "--mean_std_filepath",
default='mean_std.npz', default='mean_std.npz',
...@@ -67,7 +72,8 @@ def infer(): ...@@ -67,7 +72,8 @@ def infer():
data_generator = DataGenerator( data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath, vocab_filepath=args.vocab_filepath,
mean_std_filepath=args.mean_std_filepath, mean_std_filepath=args.mean_std_filepath,
augmentation_config='{}') augmentation_config='{}',
num_threads=args.num_threads_data)
# create network config # create network config
# paddle.data_type.dense_array is used for variable batch input. # paddle.data_type.dense_array is used for variable batch input.
......
...@@ -52,6 +52,18 @@ parser.add_argument( ...@@ -52,6 +52,18 @@ parser.add_argument(
default=True, default=True,
type=distutils.util.strtobool, type=distutils.util.strtobool,
help="Use sortagrad or not. (default: %(default)s)") help="Use sortagrad or not. (default: %(default)s)")
parser.add_argument(
"--max_duration",
default=100.0,
type=float,
help="Audios with duration larger than this will be discarded. "
"(default: %(default)s)")
parser.add_argument(
"--min_duration",
default=0.0,
type=float,
help="Audios with duration smaller than this will be discarded. "
"(default: %(default)s)")
parser.add_argument( parser.add_argument(
"--shuffle_method", "--shuffle_method",
default='instance_shuffle', default='instance_shuffle',
...@@ -63,6 +75,11 @@ parser.add_argument( ...@@ -63,6 +75,11 @@ parser.add_argument(
default=4, default=4,
type=int, type=int,
help="Trainer number. (default: %(default)s)") help="Trainer number. (default: %(default)s)")
parser.add_argument(
"--num_threads_data",
default=12,
type=int,
help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--mean_std_filepath", "--mean_std_filepath",
default='mean_std.npz', default='mean_std.npz',
...@@ -107,7 +124,10 @@ def train(): ...@@ -107,7 +124,10 @@ def train():
return DataGenerator( return DataGenerator(
vocab_filepath=args.vocab_filepath, vocab_filepath=args.vocab_filepath,
mean_std_filepath=args.mean_std_filepath, mean_std_filepath=args.mean_std_filepath,
augmentation_config=args.augmentation_config) augmentation_config=args.augmentation_config,
max_duration=args.max_duration,
min_duration=args.min_duration,
num_threads=args.num_threads_data)
train_generator = data_generator() train_generator = data_generator()
test_generator = data_generator() test_generator = data_generator()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册