Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
e6a34999
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
接近 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e6a34999
编写于
5月 30, 2017
作者:
X
Xinghai Sun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor data utils into a class and add feature normalization.
上级
9c3cd3c7
变更
2
展开全部
隐藏空白更改
内联
并排
Showing
2 changed file
with
389 addition
and
208 deletion
+389
-208
audio_data_utils.py
audio_data_utils.py
+340
-172
train.py
train.py
+49
-36
未找到文件。
audio_data_utils.py
浏览文件 @
e6a34999
此差异已折叠。
点击以展开。
train.py
浏览文件 @
e6a34999
...
...
@@ -5,16 +5,18 @@
import
paddle.v2
as
paddle
import
argparse
import
gzip
import
time
import
sys
from
model
import
deep_speech2
import
audio_data_utils
from
audio_data_utils
import
DataGenerator
import
numpy
as
np
#TODO: add WER metric
parser
=
argparse
.
ArgumentParser
(
description
=
'Simplified version of DeepSpeech2 trainer.'
)
parser
.
add_argument
(
"--batch_size"
,
default
=
51
2
,
type
=
int
,
help
=
"Minibatch size."
)
"--batch_size"
,
default
=
3
2
,
type
=
int
,
help
=
"Minibatch size."
)
parser
.
add_argument
(
"--trainer"
,
default
=
1
,
type
=
int
,
help
=
"Trainer number."
)
parser
.
add_argument
(
"--num_passes"
,
default
=
20
,
type
=
int
,
help
=
"Training pass number."
)
...
...
@@ -23,7 +25,7 @@ parser.add_argument(
parser
.
add_argument
(
"--num_rnn_layers"
,
default
=
5
,
type
=
int
,
help
=
"RNN layer number."
)
parser
.
add_argument
(
"--rnn_layer_size"
,
default
=
256
,
type
=
int
,
help
=
"RNN layer cell number."
)
"--rnn_layer_size"
,
default
=
512
,
type
=
int
,
help
=
"RNN layer cell number."
)
parser
.
add_argument
(
"--use_gpu"
,
default
=
True
,
type
=
bool
,
help
=
"Use gpu or not."
)
parser
.
add_argument
(
...
...
@@ -37,13 +39,45 @@ def train():
"""
DeepSpeech2 training.
"""
# create data readers
data_generator
=
DataGenerator
(
vocab_filepath
=
'eng_vocab.txt'
,
normalizer_manifest_path
=
'./libri.manifest.train'
,
normalizer_num_samples
=
200
,
max_duration
=
20.0
,
min_duration
=
0.0
,
stride_ms
=
10
,
window_ms
=
20
)
train_batch_reader_sortagrad
=
data_generator
.
batch_reader_creator
(
manifest_path
=
'./libri.manifest.dev.small'
,
batch_size
=
args
.
batch_size
//
args
.
trainer
,
padding_to
=
2000
,
flatten
=
True
,
sort_by_duration
=
True
,
shuffle
=
False
)
train_batch_reader_nosortagrad
=
data_generator
.
batch_reader_creator
(
manifest_path
=
'./libri.manifest.dev.small'
,
batch_size
=
args
.
batch_size
//
args
.
trainer
,
padding_to
=
2000
,
flatten
=
True
,
sort_by_duration
=
False
,
shuffle
=
True
)
test_batch_reader
=
data_generator
.
batch_reader_creator
(
manifest_path
=
'./libri.manifest.test'
,
batch_size
=
args
.
batch_size
//
args
.
trainer
,
padding_to
=
2000
,
flatten
=
True
,
sort_by_duration
=
False
,
shuffle
=
False
)
feeding
=
data_generator
.
data_name_feeding
()
# create network config
dict_size
=
audio_data_utils
.
get_
vocabulary_size
()
dict_size
=
data_generator
.
vocabulary_size
()
audio_data
=
paddle
.
layer
.
data
(
name
=
"audio_spectrogram"
,
height
=
161
,
width
=
1
000
,
type
=
paddle
.
data_type
.
dense_vector
(
161
000
))
width
=
2
000
,
type
=
paddle
.
data_type
.
dense_vector
(
322
000
))
text_data
=
paddle
.
layer
.
data
(
name
=
"transcript_text"
,
type
=
paddle
.
data_type
.
integer_value_sequence
(
dict_size
))
...
...
@@ -58,47 +92,26 @@ def train():
# create parameters and optimizer
parameters
=
paddle
.
parameters
.
create
(
cost
)
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
5e-
4
,
gradient_clipping_threshold
=
400
)
learning_rate
=
5e-
5
,
gradient_clipping_threshold
=
400
)
trainer
=
paddle
.
trainer
.
SGD
(
cost
=
cost
,
parameters
=
parameters
,
update_equation
=
optimizer
)
# create data readers
feeding
=
{
"audio_spectrogram"
:
0
,
"transcript_text"
:
1
,
}
train_batch_reader_with_sortagrad
=
audio_data_utils
.
padding_batch_reader
(
paddle
.
batch
(
audio_data_utils
.
reader_creator
(
manifest_path
=
"./libri.manifest.train"
,
sort_by_duration
=
True
),
batch_size
=
args
.
batch_size
//
args
.
trainer
),
padding
=
[
-
1
,
1000
])
train_batch_reader_without_sortagrad
=
audio_data_utils
.
padding_batch_reader
(
paddle
.
batch
(
audio_data_utils
.
reader_creator
(
manifest_path
=
"./libri.manifest.train"
,
sort_by_duration
=
False
,
shuffle
=
True
),
batch_size
=
args
.
batch_size
//
args
.
trainer
),
padding
=
[
-
1
,
1000
])
test_batch_reader
=
audio_data_utils
.
padding_batch_reader
(
paddle
.
batch
(
audio_data_utils
.
reader_creator
(
manifest_path
=
"./libri.manifest.dev"
,
sort_by_duration
=
False
),
batch_size
=
args
.
batch_size
//
args
.
trainer
),
padding
=
[
-
1
,
1000
])
# create event handler
def
event_handler
(
event
):
global
start_time
if
isinstance
(
event
,
paddle
.
event
.
EndIteration
):
if
event
.
batch_id
%
10
==
0
:
print
"
/
nPass: %d, Batch: %d, TrainCost: %f"
%
(
print
"
\
n
Pass: %d, Batch: %d, TrainCost: %f"
%
(
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
)
else
:
sys
.
stdout
.
write
(
'.'
)
sys
.
stdout
.
flush
()
if
isinstance
(
event
,
paddle
.
event
.
BeginPass
):
start_time
=
time
.
time
()
if
isinstance
(
event
,
paddle
.
event
.
EndPass
):
result
=
trainer
.
test
(
reader
=
test_batch_reader
,
feeding
=
feeding
)
print
"Pass: %d, TestCost: %s"
%
(
event
.
pass_id
,
result
.
cost
)
print
"
\n
------- Time: %d, Pass: %d, TestCost: %s"
%
(
time
.
time
()
-
start_time
,
event
.
pass_id
,
result
.
cost
)
with
gzip
.
open
(
"params.tar.gz"
,
'w'
)
as
f
:
parameters
.
to_tar
(
f
)
...
...
@@ -106,14 +119,14 @@ def train():
# first pass with sortagrad
if
args
.
use_sortagrad
:
trainer
.
train
(
reader
=
train_batch_reader_
with_
sortagrad
,
reader
=
train_batch_reader_sortagrad
,
event_handler
=
event_handler
,
num_passes
=
1
,
feeding
=
feeding
)
args
.
num_passes
-=
1
# other passes without sortagrad
trainer
.
train
(
reader
=
train_batch_reader_
without_
sortagrad
,
reader
=
train_batch_reader_
no
sortagrad
,
event_handler
=
event_handler
,
num_passes
=
args
.
num_passes
,
feeding
=
feeding
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录