Unverified commit 90b456d5, authored by Yang yaming, committed by GitHub

Merge pull request #71 from pkuyym/fix-65

Decouple data provider and model configuration.
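In short: num_conv_layers moves out of DataGenerator and into DeepSpeech2Model, which now adapts the feeding dict and the data readers internally via the new _adapt_feeding_dict and _adapt_data helpers shown below. A minimal sketch of the resulting call pattern (paths are placeholders; module names and constructor arguments are assumed from the repository's scripts):

    from data_utils.data import DataGenerator
    from model import DeepSpeech2Model

    # The data layer no longer needs to know the network's conv geometry.
    data_generator = DataGenerator(
        vocab_filepath='data/vocab.txt',        # placeholder path
        mean_std_filepath='data/mean_std.npz',  # placeholder path
        augmentation_config='{}',
        keep_transcription_text=True)

    # Only the model carries num_conv_layers now.
    ds2_model = DeepSpeech2Model(
        vocab_size=data_generator.vocab_size,
        num_conv_layers=2,
        num_rnn_layers=3,
        rnn_layer_size=2048,
        pretrained_model_path=None)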
......@@ -60,9 +60,6 @@ class DataGenerator(object):
be passed forward directly without
converting to index sequence.
:type keep_transcription_text: bool
:param num_conv_layers: The number of convolution layers, used to compute
the sequence length.
:type num_conv_layers: int
"""
def __init__(self,
......@@ -78,8 +75,7 @@ class DataGenerator(object):
use_dB_normalization=True,
num_threads=multiprocessing.cpu_count() // 2,
random_seed=0,
keep_transcription_text=False,
num_conv_layers=2):
keep_transcription_text=False):
self._max_duration = max_duration
self._min_duration = min_duration
self._normalizer = FeatureNormalizer(mean_std_filepath)
......@@ -100,7 +96,6 @@ class DataGenerator(object):
self._local_data = local()
self._local_data.tar2info = {}
self._local_data.tar2object = {}
self._num_conv_layers = num_conv_layers
def process_utterance(self, filename, transcript):
"""Load, augment, featurize and normalize for speech data.
......@@ -219,14 +214,7 @@ class DataGenerator(object):
:return: Data feeding dict.
:rtype: dict
"""
feeding_dict = {
"audio_spectrogram": 0,
"transcript_text": 1,
"sequence_offset": 2,
"sequence_length": 3
}
for i in xrange(self._num_conv_layers):
feeding_dict["conv%d_index_range" % i] = len(feeding_dict)
feeding_dict = {"audio_spectrogram": 0, "transcript_text": 1}
return feeding_dict
@property
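For context, the feeding dict simply names each slot of the tuples that the reader yields; after this change the generator describes only the two raw fields, and the model appends the rest:

    # instance = (audio_spectrogram, transcript_text)
    # feeding  = {"audio_spectrogram": 0, "transcript_text": 1}
    # 'sequence_offset', 'sequence_length' and 'convN_index_range' are
    # appended later by DeepSpeech2Model._adapt_feeding_dict, not here.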
......@@ -322,29 +310,7 @@ class DataGenerator(object):
padded_audio[:, :audio.shape[1]] = audio
if flatten:
padded_audio = padded_audio.flatten()
# Stride size for conv0 is (3, 2)
# Stride size for conv1 to convN is (1, 2)
# Same as the network, hard-coded here
padded_instance = [padded_audio, text]
padded_conv0_h = (padded_audio.shape[0] - 1) // 2 + 1
padded_conv0_w = (padded_audio.shape[1] - 1) // 3 + 1
valid_w = (audio.shape[1] - 1) // 3 + 1
padded_instance += [
[0], # sequence offset, always 0
[valid_w], # valid sequence length
# Index ranges for channel, height and width
# Please refer to the scale_sub_region layer for details
[1, 32, 1, padded_conv0_h, valid_w + 1, padded_conv0_w]
]
pre_padded_h = padded_conv0_h
for i in xrange(self._num_conv_layers - 1):
padded_h = (pre_padded_h - 1) // 2 + 1
pre_padded_h = padded_h
padded_instance += [
[1, 32, 1, padded_h, valid_w + 1, padded_conv0_w]
]
padded_instance = [padded_audio, text, audio.shape[1]]
new_batch.append(padded_instance)
return new_batch
......
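Each padded instance now also carries the raw (pre-padding) width, so all conv-geometry arithmetic can move into the model. An illustrative instance, assuming a 161-bin spectrogram of 300 frames padded to a 500-frame batch:

    # audio.shape        == (161, 300)   # illustrative numbers
    # padded_audio.shape == (161, 500)
    # instance handed to the model:
    #   [padded_audio, text, 300]        # 300 == valid width before padding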
......@@ -147,8 +147,7 @@ def start_server():
augmentation_config='{}',
specgram_type=args.specgram_type,
num_threads=1,
keep_transcription_text=True,
num_conv_layers=args.num_conv_layers)
keep_transcription_text=True)
# prepare ASR model
ds2_model = DeepSpeech2Model(
vocab_size=data_generator.vocab_size,
......@@ -164,20 +163,9 @@ def start_server():
# prepare ASR inference handler
def file_to_transcript(filename):
feature = data_generator.process_utterance(filename, "")
ins = []
conv0_h = (feature[0].shape[0] - 1) // 2 + 1
conv0_w = (feature[0].shape[1] - 1) // 3 + 1
ins += [feature[0], feature[1],
[0], [conv0_w],
[1, 32, 1, conv0_h, conv0_w + 1, conv0_w]]
pre_h = conv0_h
for i in xrange(args.num_conv_layers - 1):
h = (pre_h - 1) // 2 + 1
pre_h = h
ins += [[1, 32, 1, h, conv0_w + 1, conv0_w]]
result_transcript = ds2_model.infer_batch(
infer_data=[ins],
infer_data=[feature],
decoding_method=args.decoding_method,
beam_alpha=args.alpha,
beam_beta=args.beta,
......
......@@ -69,8 +69,7 @@ def infer():
augmentation_config='{}',
specgram_type=args.specgram_type,
num_threads=1,
keep_transcription_text=True,
num_conv_layers=args.num_conv_layers)
keep_transcription_text=True)
batch_reader = data_generator.batch_reader_creator(
manifest_path=args.infer_manifest,
batch_size=args.num_samples,
......
......@@ -8,6 +8,8 @@ import os
import time
import logging
import gzip
import copy
import inspect
from distutils.dir_util import mkpath
import paddle.v2 as paddle
from decoders.swig_wrapper import Scorer
......@@ -48,6 +50,7 @@ class DeepSpeech2Model(object):
self._inferer = None
self._loss_inferer = None
self._ext_scorer = None
self._num_conv_layers = num_conv_layers
self.logger = logging.getLogger("")
self.logger.setLevel(level=logging.INFO)
......@@ -91,6 +94,11 @@ class DeepSpeech2Model(object):
if not os.path.exists(output_model_dir):
mkpath(output_model_dir)
# adapt the feeding dict and reader according to the network
adapted_feeding_dict = self._adapt_feeding_dict(feeding_dict)
adapted_train_batch_reader = self._adapt_data(train_batch_reader)
adapted_dev_batch_reader = self._adapt_data(dev_batch_reader)
# prepare optimizer and trainer
optimizer = paddle.optimizer.Adam(
learning_rate=learning_rate,
......@@ -128,7 +136,8 @@ class DeepSpeech2Model(object):
(time.time() - start_time, event.pass_id))
else:
result = trainer.test(
reader=dev_batch_reader, feeding=feeding_dict)
reader=adapted_dev_batch_reader,
feeding=adapted_feeding_dict)
print(
"\n------- Time: %d sec, Pass: %d, "
"ValidationCost: %s" %
......@@ -140,11 +149,12 @@ class DeepSpeech2Model(object):
# run train
trainer.train(
reader=train_batch_reader,
reader=adapted_train_batch_reader,
event_handler=event_handler,
num_passes=num_passes,
feeding=feeding_dict)
feeding=adapted_feeding_dict)
# TODO(@pkuyym) merge this function into infer_batch
def infer_loss_batch(self, infer_data):
"""Model inference. Infer the ctc loss for a batch of speech
utterances.
......@@ -205,15 +215,17 @@ class DeepSpeech2Model(object):
if self._inferer is None:
self._inferer = paddle.inference.Inference(
output_layer=self._log_probs, parameters=self._parameters)
adapted_feeding_dict = self._adapt_feeding_dict(feeding_dict)
adapted_infer_data = self._adapt_data(infer_data)
# run inference
infer_results = self._inferer.infer(
input=infer_data, feeding=feeding_dict)
start_pos = [0] * (len(infer_data) + 1)
for i in xrange(len(infer_data)):
start_pos[i + 1] = start_pos[i] + infer_data[i][3][0]
input=adapted_infer_data, feeding=adapted_feeding_dict)
start_pos = [0] * (len(adapted_infer_data) + 1)
for i in xrange(len(adapted_infer_data)):
start_pos[i + 1] = start_pos[i] + adapted_infer_data[i][3][0]
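# (Field 3 of each adapted instance is the one-element 'sequence_length'
# list appended by _adapt_data, i.e. the valid width after conv0.)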
probs_split = [
infer_results[start_pos[i]:start_pos[i + 1]]
for i in xrange(0, len(infer_data))
for i in xrange(0, len(adapted_infer_data))
]
# run decoder
results = []
......@@ -260,6 +272,100 @@ class DeepSpeech2Model(object):
decoding_method)
return results
def _adapt_feeding_dict(self, feeding_dict):
"""Adapt feeding dict according to network struct.
To remove impacts from padding part, we add scale_sub_region layer and
sub_seq layer. For sub_seq layer, 'sequence_offset' and
'sequence_length' fields are appended. For each scale_sub_region layer
'convN_index_range' field is appended.
:param feeding_dict: Feeding is a map of field name and tuple index
of the data that reader returns.
:type feeding_dict: dict|list
:return: Adapted feeding dict.
:rtype: dict|list
"""
adapted_feeding_dict = copy.deepcopy(feeding_dict)
if isinstance(feeding_dict, dict):
adapted_feeding_dict["sequence_offset"] = len(adapted_feeding_dict)
adapted_feeding_dict["sequence_length"] = len(adapted_feeding_dict)
for i in xrange(self._num_conv_layers):
adapted_feeding_dict["conv%d_index_range" %i] = \
len(adapted_feeding_dict)
elif isinstance(feeding_dict, list):
adapted_feeding_dict.append("sequence_offset")
adapted_feeding_dict.append("sequence_length")
for i in xrange(self._num_conv_layers):
adapted_feeding_dict.append("conv%d_index_range" % i)
else:
raise ValueError("Type of feeding_dict is %s, not supported." %
type(feeding_dict))
return adapted_feeding_dict
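# Illustrative result for num_conv_layers=2: starting from
# {"audio_spectrogram": 0, "transcript_text": 1}, the adapted dict is
# {"audio_spectrogram": 0, "transcript_text": 1,
#  "sequence_offset": 2, "sequence_length": 3,
#  "conv0_index_range": 4, "conv1_index_range": 5}.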
def _adapt_data(self, data):
"""Adapt data according to network struct.
For each convolution layer in the conv_group, to remove impacts from
padding data, we can multiply zero to the padding part of the outputs
of each batch normalization layer. We add a scale_sub_region layer after
each batch normalization layer to reset the padding data.
For rnn layers, to remove impacts from padding data, we can truncate the
padding part before output data feeded into the first rnn layer. We use
sub_seq layer to achieve this.
:param data: Data from data_provider.
:type data: list|function
:return: Adapted data.
:rtype: list|function
"""
def adapt_instance(instance):
if len(instance) < 2 or len(instance) > 3:
raise ValueError("Size of instance should be 2 or 3.")
padded_audio = instance[0]
text = instance[1]
# no explicit valid length carried; treat the full width as valid
if len(instance) == 2:
audio_len = padded_audio.shape[1]
else:
audio_len = instance[2]
adapted_instance = [padded_audio, text]
# Stride size for conv0 is (3, 2)
# Stride size for conv1 to convN is (1, 2)
# Same as the network, hard-coded here
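# (Each (x - 1) // stride + 1 below is the standard strided-conv
# output-size formula, e.g. width 300 with stride 3 gives
# (300 - 1) // 3 + 1 = 100.)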
padded_conv0_h = (padded_audio.shape[0] - 1) // 2 + 1
padded_conv0_w = (padded_audio.shape[1] - 1) // 3 + 1
valid_w = (audio_len - 1) // 3 + 1
adapted_instance += [
[0], # sequence offset, always 0
[valid_w], # valid sequence length
# Index ranges for channel, height and width
# Please refer to the scale_sub_region layer for details
[1, 32, 1, padded_conv0_h, valid_w + 1, padded_conv0_w]
]
pre_padded_h = padded_conv0_h
for i in xrange(self._num_conv_layers - 1):
padded_h = (pre_padded_h - 1) // 2 + 1
pre_padded_h = padded_h
adapted_instance += [
[1, 32, 1, padded_h, valid_w + 1, padded_conv0_w]
]
return adapted_instance
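# Worked example with illustrative numbers: a padded 161x500
# spectrogram with valid width 300 and num_conv_layers=2 gives
#   padded_conv0_h = (161 - 1) // 2 + 1 = 81
#   padded_conv0_w = (500 - 1) // 3 + 1 = 167
#   valid_w        = (300 - 1) // 3 + 1 = 100
# so adapt_instance returns
#   [padded_audio, text, [0], [100],
#    [1, 32, 1, 81, 101, 167],   # conv0 index range
#    [1, 32, 1, 41, 101, 167]]   # conv1: (81 - 1) // 2 + 1 = 41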
if isinstance(data, list):
return map(adapt_instance, data)
elif inspect.isgeneratorfunction(data):
def adapted_reader():
for batch in data():
yield map(adapt_instance, batch)
return adapted_reader
else:
raise ValueError("Type of data is %s, not supported." % type(data))
def _create_parameters(self, model_path=None):
"""Load or create model parameters."""
if model_path is None:
......
......@@ -70,8 +70,7 @@ def evaluate():
augmentation_config='{}',
specgram_type=args.specgram_type,
num_threads=args.num_proc_data,
keep_transcription_text=True,
num_conv_layers=args.num_conv_layers)
keep_transcription_text=True)
batch_reader = data_generator.batch_reader_creator(
manifest_path=args.test_manifest,
batch_size=args.batch_size,
......
......@@ -75,15 +75,13 @@ def train():
max_duration=args.max_duration,
min_duration=args.min_duration,
specgram_type=args.specgram_type,
num_threads=args.num_proc_data,
num_conv_layers=args.num_conv_layers)
num_threads=args.num_proc_data)
dev_generator = DataGenerator(
vocab_filepath=args.vocab_path,
mean_std_filepath=args.mean_std_path,
augmentation_config="{}",
specgram_type=args.specgram_type,
num_threads=args.num_proc_data,
num_conv_layers=args.num_conv_layers)
num_threads=args.num_proc_data)
train_batch_reader = train_generator.batch_reader_creator(
manifest_path=args.train_manifest,
batch_size=args.batch_size,
......
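After this change the training script stays oblivious to the network's conv geometry; a trimmed sketch of the resulting train() call, assuming the generator's feeding property and using placeholder paths and hyperparameters:

    train_batch_reader = train_generator.batch_reader_creator(
        manifest_path='data/manifest.train',    # placeholder path
        batch_size=32)
    # dev_batch_reader is built analogously from dev_generator.
    # feeding_dict and readers are adapted inside DeepSpeech2Model.train
    # via _adapt_feeding_dict/_adapt_data, so callers pass them as-is.
    ds2_model.train(
        train_batch_reader=train_batch_reader,
        dev_batch_reader=dev_batch_reader,
        feeding_dict=train_generator.feeding,
        learning_rate=5e-4,
        num_passes=20,
        output_model_dir='./checkpoints')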