PaddlePaddle / DeepSpeech

Commit 526e18b1
Committed on Aug 02, 2017 by Xinghai Sun
Add function docs for layer.py and model.py and update other details.
Parent: 8122dd9c
Showing 7 changed files with 136 additions and 41 deletions (+136 -41)
decoder.py  +1 -1
infer.py    +1 -1
layer.py    +53 -31
model.py    +72 -2
setup.sh    +0 -3
train.py    +7 -1
tune.py     +2 -2
decoder.py

@@ -205,9 +205,9 @@ def ctc_beam_search_decoder_batch(probs_split,
     :type num_processes: int
     :param cutoff_prob: Cutoff probability in pruning,
                         default 1.0, no pruning.
+    :type cutoff_prob: float
     :param num_processes: Number of parallel processes.
     :type num_processes: int
-    :type cutoff_prob: float
     :param ext_scoring_func: External scoring function for
                              partially decoded sentence, e.g. word count
                              or language model.
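For context, a minimal sketch of calling the batch decoder whose docstring is reordered above. Only probs_split, num_processes, cutoff_prob, and ext_scoring_func are visible in this diff; the remaining parameter names and all values are assumptions, and the inputs are presumed to be prepared elsewhere (e.g. by infer.py).

from decoder import ctc_beam_search_decoder_batch

# Hedged sketch: names other than probs_split, num_processes, cutoff_prob
# and ext_scoring_func are assumed, as are all values.
transcripts = ctc_beam_search_decoder_batch(
    probs_split=probs_split,      # per-utterance CTC probability matrices
    vocabulary=vocab_list,        # assumed parameter name
    beam_size=500,                # assumed parameter name and value
    num_processes=8,              # number of parallel decoding processes
    cutoff_prob=0.99,             # prune tokens beyond this cumulative prob
    ext_scoring_func=ext_scorer)  # e.g. a word-count or language-model scorer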
infer.py

@@ -40,7 +40,7 @@ parser.add_argument(
     help="Use gpu or not. (default: %(default)s)")
 parser.add_argument(
     "--num_threads_data",
-    default=multiprocessing.cpu_count(),
+    default=1,
     type=int,
     help="Number of cpu threads for preprocessing data. (default: %(default)s)")
 parser.add_argument(
layer.py

@@ -5,13 +5,27 @@ from __future__ import print_function
 import paddle.v2 as paddle
 
-DISABLE_CUDNN_BATCH_NORM = True
-
 
 def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride,
                   padding, act):
-    """
-    Convolution layer with batch normalization.
+    """Convolution layer with batch normalization.
+
+    :param input: Input layer.
+    :type input: LayerOutput
+    :param filter_size: The x dimension of a filter kernel. Or input a tuple for
+                        two image dimension.
+    :type filter_size: int|tuple|list
+    :param num_channels_in: Number of input channels.
+    :type num_channels_in: int
+    :type num_channels_out: Number of output channels.
+    :type num_channels_in: out
+    :param padding: The x dimension of the padding. Or input a tuple for two
+                    image dimension.
+    :type padding: int|tuple|list
+    :param act: Activation type.
+    :type act: BaseActivation
+    :return: Batch norm layer after convolution layer.
+    :rtype: LayerOutput
     """
     conv_layer = paddle.layer.img_conv(
         input=input,

@@ -22,30 +36,28 @@ def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride,
         padding=padding,
         act=paddle.activation.Linear(),
         bias_attr=False)
-    if DISABLE_CUDNN_BATCH_NORM:
-        # temopary patch, need to be removed.
-        return paddle.layer.batch_norm(
-            input=conv_layer, act=act, batch_norm_type="batch_norm")
-    else:
-        return paddle.layer.batch_norm(input=conv_layer, act=act)
+    return paddle.layer.batch_norm(input=conv_layer, act=act)
 
 
 def bidirectional_simple_rnn_bn_layer(name, input, size, act):
-    """
-    Bidirectonal simple rnn layer with sequence-wise batch normalization.
+    """Bidirectonal simple rnn layer with sequence-wise batch normalization.
     The batch normalization is only performed on input-state weights.
+
+    :param name: Name of the layer.
+    :type name: string
+    :param input: Input layer.
+    :type input: LayerOutput
+    :param size: Number of RNN cells.
+    :type size: int
+    :param act: Activation type.
+    :type act: BaseActivation
+    :return: Bidirectional simple rnn layer.
+    :rtype: LayerOutput
     """
     # input-hidden weights shared across bi-direcitonal rnn.
     input_proj = paddle.layer.fc(
         input=input, size=size, act=paddle.activation.Linear(), bias_attr=False)
     # batch norm is only performed on input-state projection
-    if DISABLE_CUDNN_BATCH_NORM:
-        # temopary patch, need to be removed.
-        input_proj_bn = paddle.layer.batch_norm(
-            input=input_proj,
-            act=paddle.activation.Linear(),
-            batch_norm_type="batch_norm")
-    else:
-        input_proj_bn = paddle.layer.batch_norm(
-            input=input_proj, act=paddle.activation.Linear())
+    input_proj_bn = paddle.layer.batch_norm(
+        input=input_proj, act=paddle.activation.Linear())
     # forward and backward in time

@@ -57,8 +69,14 @@ def bidirectional_simple_rnn_bn_layer(name, input, size, act):
 def conv_group(input, num_stacks):
-    """
-    Convolution group with several stacking convolution layers.
+    """Convolution group with stacked convolution layers.
+
+    :param input: Input layer.
+    :type input: LayerOutput
+    :param num_stacks: Number of stacked convolution layers.
+    :type num_stacks: int
+    :return: Output layer of the convolution group.
+    :rtype: LayerOutput
     """
     conv = conv_bn_layer(
         input=input,

@@ -83,8 +101,16 @@ def conv_group(input, num_stacks):
 def rnn_group(input, size, num_stacks):
-    """
-    RNN group with several stacking RNN layers.
+    """RNN group with stacked bidirectional simple RNN layers.
+
+    :param input: Input layer.
+    :type input: LayerOutput
+    :param size: Number of RNN cells in each layer.
+    :type size: int
+    :param num_stacks: Number of stacked rnn layers.
+    :type num_stacks: int
+    :return: Output layer of the RNN group.
+    :rtype: LayerOutput
     """
     output = input
     for i in xrange(num_stacks):

@@ -114,12 +140,8 @@ def deep_speech2(audio_data,
     :type num_rnn_layers: int
     :param rnn_size: RNN layer size (number of RNN cells).
     :type rnn_size: int
-    :param is_inference: False in the training mode, and True in the
-                         inferene mode.
-    :type is_inference: bool
-    :return: If is_inference set False, return a ctc cost layer;
-             if is_inference set True, return a sequence layer of output
-             probability distribution.
+    :return: A tuple of an output unnormalized log probability layer (
+             before softmax) and a ctc cost layer.
     :rtype: tuple of LayerOutput
     """
     # convolution group
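The new docstrings pin down the helper signatures above. A minimal usage sketch of conv_bn_layer follows; the data-layer type echoes the placeholder comment in model.py, while the kernel, stride, and padding values are purely illustrative assumptions.

import paddle.v2 as paddle
from layer import conv_bn_layer

# Hedged sketch: the 161 * 161 size is the placeholder value mentioned in
# model.py's comment; the kernel/stride/padding values below are assumptions.
audio = paddle.layer.data(
    name="audio_spectrogram",
    type=paddle.data_type.dense_array(161 * 161))
conv = conv_bn_layer(
    input=audio,
    filter_size=(11, 41),   # assumed kernel shape
    num_channels_in=1,
    num_channels_out=32,
    stride=(3, 2),          # assumed strides
    padding=(5, 20),        # assumed padding
    act=paddle.activation.BRelu())  # any BaseActivation fits the signature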
model.py

@@ -14,6 +14,21 @@ from layer import *
 class DeepSpeech2Model(object):
+    """DeepSpeech2Model class.
+
+    :param vocab_size: Decoding vocabulary size.
+    :type vocab_size: int
+    :param num_conv_layers: Number of stacking convolution layers.
+    :type num_conv_layers: int
+    :param num_rnn_layers: Number of stacking RNN layers.
+    :type num_rnn_layers: int
+    :param rnn_layer_size: RNN layer size (number of RNN cells).
+    :type rnn_layer_size: int
+    :param pretrained_model_path: Pretrained model path. If None, will train
+                                  from stratch.
+    :type pretrained_model_path: basestring|None
+    """
+
     def __init__(self, vocab_size, num_conv_layers, num_rnn_layers,
                  rnn_layer_size, pretrained_model_path):
         self._create_network(vocab_size, num_conv_layers, num_rnn_layers,

@@ -29,8 +44,33 @@ class DeepSpeech2Model(object):
               learning_rate,
               gradient_clipping,
               num_passes,
-              num_iterations_print=100,
-              output_model_dir='checkpoints'):
+              output_model_dir,
+              num_iterations_print=100):
+        """Train the model.
+
+        :param train_batch_reader: Train data reader.
+        :type train_batch_reader: callable
+        :param dev_batch_reader: Validation data reader.
+        :type dev_batch_reader: callable
+        :param feeding_dict: Feeding is a map of field name and tuple index
+                             of the data that reader returns.
+        :type feeding_dict: dict|list
+        :param learning_rate: Learning rate for ADAM optimizer.
+        :type learning_rate: float
+        :param gradient_clipping: Gradient clipping threshold.
+        :type gradient_clipping: float
+        :param num_passes: Number of training epochs.
+        :type num_passes: int
+        :param num_iterations_print: Number of training iterations for printing
+                                     a training loss.
+        :type rnn_iteratons_print: int
+        :param output_model_dir: Directory for saving the model (every pass).
+        :type output_model_dir: basestring
+        """
+        # prepare model output directory
+        if not os.path.exists(output_model_dir):
+            os.mkdir(output_model_dir)
+
         # prepare optimizer and trainer
         optimizer = paddle.optimizer.Adam(
             learning_rate=learning_rate,

@@ -81,6 +121,34 @@ class DeepSpeech2Model(object):
     def infer_batch(self, infer_data, decode_method, beam_alpha, beam_beta,
                     beam_size, cutoff_prob, vocab_list, language_model_path,
                     num_processes):
+        """Model inference. Infer the transcription for a batch of speech
+        utterances.
+
+        :param infer_data: List of utterances to infer, with each utterance a
+                           tuple of audio features and transcription text (empty
+                           string).
+        :type infer_data: list
+        :param decode_method: Decoding method name, 'best_path' or
+                              'beam search'.
+        :param decode_method: string
+        :param beam_alpha: Parameter associated with language model.
+        :type beam_alpha: float
+        :param beam_beta: Parameter associated with word count.
+        :type beam_beta: float
+        :param beam_size: Width for Beam search.
+        :type beam_size: int
+        :param cutoff_prob: Cutoff probability in pruning,
+                            default 1.0, no pruning.
+        :type cutoff_prob: float
+        :param vocab_list: List of tokens in the vocabulary, for decoding.
+        :type vocab_list: list
+        :param language_model_path: Filepath for language model.
+        :type language_model_path: basestring|None
+        :param num_processes: Number of processes (CPU) for decoder.
+        :type num_processes: int
+        :return: List of transcription texts.
+        :rtype: List of basestring
+        """
         # define inferer
         if self._inferer == None:
             self._inferer = paddle.inference.Inference(

@@ -126,6 +194,7 @@ class DeepSpeech2Model(object):
         return results
 
     def _create_parameters(self, model_path=None):
+        """Load or create model parameters."""
         if model_path is None:
             self._parameters = paddle.parameters.create(self._loss)
         else:

@@ -134,6 +203,7 @@ class DeepSpeech2Model(object):
     def _create_network(self, vocab_size, num_conv_layers, num_rnn_layers,
                         rnn_layer_size):
+        """Create data layers and model network."""
         # paddle.data_type.dense_array is used for variable batch input.
         # The size 161 * 161 is only an placeholder value and the real shape
         # of input batch data will be induced during training.
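Taken together, the docstrings above document the model's public workflow. Below is a hedged end-to-end sketch built only from the signatures shown in this diff; the readers, feeding dict, vocabulary, and every numeric value except gradient_clipping (which train.py passes as 400) are assumptions.

from model import DeepSpeech2Model

model = DeepSpeech2Model(
    vocab_size=28,                  # assumed vocabulary size
    num_conv_layers=2,              # assumed
    num_rnn_layers=3,               # assumed
    rnn_layer_size=512,             # assumed
    pretrained_model_path=None)     # None: train from scratch

# train() now creates output_model_dir itself (see the os.mkdir hunk above),
# which is why setup.sh no longer runs `mkdir checkpoints`.
model.train(
    train_batch_reader=train_reader,   # assumed reader, built as in train.py
    dev_batch_reader=dev_reader,       # assumed reader
    feeding_dict=feeding_dict,         # map of field name to tuple index
    learning_rate=5e-4,                # assumed
    gradient_clipping=400,             # the value train.py passes
    num_passes=20,                     # assumed
    output_model_dir="./checkpoints",  # train.py's new default
    num_iterations_print=100)

transcripts = model.infer_batch(
    infer_data=infer_data,             # list of (audio features, "") tuples
    decode_method="best_path",         # or beam search, per the docstring
    beam_alpha=0.36,                   # assumed language-model weight
    beam_beta=0.25,                    # assumed word-count weight
    beam_size=500,                     # assumed
    cutoff_prob=0.99,                  # assumed
    vocab_list=vocab_list,             # assumed vocabulary token list
    language_model_path=None,          # no language model in this sketch
    num_processes=8)                   # assumed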
setup.sh

@@ -26,7 +26,4 @@ if [ $? != 0 ]; then
     rm libsndfile-1.0.28.tar.gz
 fi
 
-# prepare ./checkpoints
-mkdir checkpoints
-
 echo "Install all dependencies successfully."
train.py

@@ -116,6 +116,11 @@ parser.add_argument(
     help="If set None, the training will start from scratch. "
     "Otherwise, the training will resume from "
     "the existing model of this path. (default: %(default)s)")
+parser.add_argument(
+    "--output_model_dir",
+    default="./checkpoints",
+    type=str,
+    help="Directory for saving models. (default: %(default)s)")
 parser.add_argument(
     "--augmentation_config",
     default='[{"type": "shift", '

@@ -169,7 +174,8 @@ def train():
         learning_rate=args.adam_learning_rate,
         gradient_clipping=400,
         num_passes=args.num_passes,
-        num_iterations_print=args.num_iterations_print)
+        num_iterations_print=args.num_iterations_print,
+        output_model_dir=args.output_model_dir)
 
 
 def main():
tune.py

@@ -46,7 +46,7 @@ parser.add_argument(
     help="Trainer number. (default: %(default)s)")
 parser.add_argument(
     "--num_threads_data",
-    default=multiprocessing.cpu_count(),
+    default=1,
     type=int,
     help="Number of cpu threads for preprocessing data. (default: %(default)s)")
 parser.add_argument(

@@ -67,7 +67,7 @@ parser.add_argument(
     help="Manifest path for normalizer. (default: %(default)s)")
 parser.add_argument(
     "--tune_manifest_path",
-    default='datasets/manifest.test',
+    default='datasets/manifest.dev',
     type=str,
     help="Manifest path for tuning. (default: %(default)s)")
 parser.add_argument(