提交 589aaae8 编写于 作者: X Xinghai Sun

Add function docs for layer.py and model.py and update other details.

上级 8e44743e
......@@ -205,9 +205,9 @@ def ctc_beam_search_decoder_batch(probs_split,
:type num_processes: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param num_processes: Number of parallel processes.
:type num_processes: int
:type cutoff_prob: float
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
or language model.
......
......@@ -40,7 +40,7 @@ parser.add_argument(
help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
"--num_threads_data",
default=multiprocessing.cpu_count(),
default=1,
type=int,
help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument(
......
......@@ -5,13 +5,27 @@ from __future__ import print_function
import paddle.v2 as paddle
DISABLE_CUDNN_BATCH_NORM = True
def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride,
padding, act):
"""
Convolution layer with batch normalization.
"""Convolution layer with batch normalization.
:param input: Input layer.
:type input: LayerOutput
:param filter_size: The x dimension of a filter kernel. Or input a tuple for
two image dimension.
:type filter_size: int|tuple|list
:param num_channels_in: Number of input channels.
:type num_channels_in: int
:type num_channels_out: Number of output channels.
:type num_channels_in: out
:param padding: The x dimension of the padding. Or input a tuple for two
image dimension.
:type padding: int|tuple|list
:param act: Activation type.
:type act: BaseActivation
:return: Batch norm layer after convolution layer.
:rtype: LayerOutput
"""
conv_layer = paddle.layer.img_conv(
input=input,
......@@ -22,30 +36,28 @@ def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride,
padding=padding,
act=paddle.activation.Linear(),
bias_attr=False)
if DISABLE_CUDNN_BATCH_NORM:
# temopary patch, need to be removed.
return paddle.layer.batch_norm(
input=conv_layer, act=act, batch_norm_type="batch_norm")
else:
return paddle.layer.batch_norm(input=conv_layer, act=act)
def bidirectional_simple_rnn_bn_layer(name, input, size, act):
"""
Bidirectonal simple rnn layer with sequence-wise batch normalization.
"""Bidirectonal simple rnn layer with sequence-wise batch normalization.
The batch normalization is only performed on input-state weights.
:param name: Name of the layer.
:type name: string
:param input: Input layer.
:type input: LayerOutput
:param size: Number of RNN cells.
:type size: int
:param act: Activation type.
:type act: BaseActivation
:return: Bidirectional simple rnn layer.
:rtype: LayerOutput
"""
# input-hidden weights shared across bi-direcitonal rnn.
input_proj = paddle.layer.fc(
input=input, size=size, act=paddle.activation.Linear(), bias_attr=False)
# batch norm is only performed on input-state projection
if DISABLE_CUDNN_BATCH_NORM:
# temopary patch, need to be removed.
input_proj_bn = paddle.layer.batch_norm(
input=input_proj,
act=paddle.activation.Linear(),
batch_norm_type="batch_norm")
else:
input_proj_bn = paddle.layer.batch_norm(
input=input_proj, act=paddle.activation.Linear())
# forward and backward in time
......@@ -57,8 +69,14 @@ def bidirectional_simple_rnn_bn_layer(name, input, size, act):
def conv_group(input, num_stacks):
"""
Convolution group with several stacking convolution layers.
"""Convolution group with stacked convolution layers.
:param input: Input layer.
:type input: LayerOutput
:param num_stacks: Number of stacked convolution layers.
:type num_stacks: int
:return: Output layer of the convolution group.
:rtype: LayerOutput
"""
conv = conv_bn_layer(
input=input,
......@@ -83,8 +101,16 @@ def conv_group(input, num_stacks):
def rnn_group(input, size, num_stacks):
"""
RNN group with several stacking RNN layers.
"""RNN group with stacked bidirectional simple RNN layers.
:param input: Input layer.
:type input: LayerOutput
:param size: Number of RNN cells in each layer.
:type size: int
:param num_stacks: Number of stacked rnn layers.
:type num_stacks: int
:return: Output layer of the RNN group.
:rtype: LayerOutput
"""
output = input
for i in xrange(num_stacks):
......@@ -114,12 +140,8 @@ def deep_speech2(audio_data,
:type num_rnn_layers: int
:param rnn_size: RNN layer size (number of RNN cells).
:type rnn_size: int
:param is_inference: False in the training mode, and True in the
inferene mode.
:type is_inference: bool
:return: If is_inference set False, return a ctc cost layer;
if is_inference set True, return a sequence layer of output
probability distribution.
:return: A tuple of an output unnormalized log probability layer (
before softmax) and a ctc cost layer.
:rtype: tuple of LayerOutput
"""
# convolution group
......
......@@ -14,6 +14,21 @@ from layer import *
class DeepSpeech2Model(object):
"""DeepSpeech2Model class.
:param vocab_size: Decoding vocabulary size.
:type vocab_size: int
:param num_conv_layers: Number of stacking convolution layers.
:type num_conv_layers: int
:param num_rnn_layers: Number of stacking RNN layers.
:type num_rnn_layers: int
:param rnn_layer_size: RNN layer size (number of RNN cells).
:type rnn_layer_size: int
:param pretrained_model_path: Pretrained model path. If None, will train
from stratch.
:type pretrained_model_path: basestring|None
"""
def __init__(self, vocab_size, num_conv_layers, num_rnn_layers,
rnn_layer_size, pretrained_model_path):
self._create_network(vocab_size, num_conv_layers, num_rnn_layers,
......@@ -29,8 +44,33 @@ class DeepSpeech2Model(object):
learning_rate,
gradient_clipping,
num_passes,
num_iterations_print=100,
output_model_dir='checkpoints'):
output_model_dir,
num_iterations_print=100):
"""Train the model.
:param train_batch_reader: Train data reader.
:type train_batch_reader: callable
:param dev_batch_reader: Validation data reader.
:type dev_batch_reader: callable
:param feeding_dict: Feeding is a map of field name and tuple index
of the data that reader returns.
:type feeding_dict: dict|list
:param learning_rate: Learning rate for ADAM optimizer.
:type learning_rate: float
:param gradient_clipping: Gradient clipping threshold.
:type gradient_clipping: float
:param num_passes: Number of training epochs.
:type num_passes: int
:param num_iterations_print: Number of training iterations for printing
a training loss.
:type rnn_iteratons_print: int
:param output_model_dir: Directory for saving the model (every pass).
:type output_model_dir: basestring
"""
# prepare model output directory
if not os.path.exists(output_model_dir):
os.mkdir(output_model_dir)
# prepare optimizer and trainer
optimizer = paddle.optimizer.Adam(
learning_rate=learning_rate,
......@@ -81,6 +121,34 @@ class DeepSpeech2Model(object):
def infer_batch(self, infer_data, decode_method, beam_alpha, beam_beta,
beam_size, cutoff_prob, vocab_list, language_model_path,
num_processes):
"""Model inference. Infer the transcription for a batch of speech
utterances.
:param infer_data: List of utterances to infer, with each utterance a
tuple of audio features and transcription text (empty
string).
:type infer_data: list
:param decode_method: Decoding method name, 'best_path' or
'beam search'.
:param decode_method: string
:param beam_alpha: Parameter associated with language model.
:type beam_alpha: float
:param beam_beta: Parameter associated with word count.
:type beam_beta: float
:param beam_size: Width for Beam search.
:type beam_size: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param vocab_list: List of tokens in the vocabulary, for decoding.
:type vocab_list: list
:param language_model_path: Filepath for language model.
:type language_model_path: basestring|None
:param num_processes: Number of processes (CPU) for decoder.
:type num_processes: int
:return: List of transcription texts.
:rtype: List of basestring
"""
# define inferer
if self._inferer == None:
self._inferer = paddle.inference.Inference(
......@@ -126,6 +194,7 @@ class DeepSpeech2Model(object):
return results
def _create_parameters(self, model_path=None):
"""Load or create model parameters."""
if model_path is None:
self._parameters = paddle.parameters.create(self._loss)
else:
......@@ -134,6 +203,7 @@ class DeepSpeech2Model(object):
def _create_network(self, vocab_size, num_conv_layers, num_rnn_layers,
rnn_layer_size):
"""Create data layers and model network."""
# paddle.data_type.dense_array is used for variable batch input.
# The size 161 * 161 is only an placeholder value and the real shape
# of input batch data will be induced during training.
......
......@@ -26,7 +26,4 @@ if [ $? != 0 ]; then
rm libsndfile-1.0.28.tar.gz
fi
# prepare ./checkpoints
mkdir checkpoints
echo "Install all dependencies successfully."
......@@ -116,6 +116,11 @@ parser.add_argument(
help="If set None, the training will start from scratch. "
"Otherwise, the training will resume from "
"the existing model of this path. (default: %(default)s)")
parser.add_argument(
"--output_model_dir",
default="./checkpoints",
type=str,
help="Directory for saving models. (default: %(default)s)")
parser.add_argument(
"--augmentation_config",
default='[{"type": "shift", '
......@@ -169,7 +174,8 @@ def train():
learning_rate=args.adam_learning_rate,
gradient_clipping=400,
num_passes=args.num_passes,
num_iterations_print=args.num_iterations_print)
num_iterations_print=args.num_iterations_print,
output_model_dir=args.output_model_dir)
def main():
......
......@@ -46,7 +46,7 @@ parser.add_argument(
help="Trainer number. (default: %(default)s)")
parser.add_argument(
"--num_threads_data",
default=multiprocessing.cpu_count(),
default=1,
type=int,
help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument(
......@@ -67,7 +67,7 @@ parser.add_argument(
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--tune_manifest_path",
default='datasets/manifest.test',
default='datasets/manifest.dev',
type=str,
help="Manifest path for tuning. (default: %(default)s)")
parser.add_argument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册