Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
7c85e0fd
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
206
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
7c85e0fd
编写于
6月 07, 2017
作者:
D
dangqingqing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support variable input batch and sortagrad.
上级
730d5c4d
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
62 addition
and
55 deletion
+62
-55
audio_data_utils.py
audio_data_utils.py
+40
-16
train.py
train.py
+22
-39
未找到文件。
audio_data_utils.py
浏览文件 @
7c85e0fd
...
...
@@ -8,6 +8,7 @@ import json
import
random
import
soundfile
import
numpy
as
np
import
itertools
import
os
RANDOM_SEED
=
0
...
...
@@ -62,6 +63,7 @@ class DataGenerator(object):
self
.
__stride_ms__
=
stride_ms
self
.
__window_ms__
=
window_ms
self
.
__max_frequency__
=
max_frequency
self
.
__epoc__
=
0
self
.
__random__
=
random
.
Random
(
RANDOM_SEED
)
# load vocabulary (dictionary)
self
.
__vocab_dict__
,
self
.
__vocab_list__
=
\
...
...
@@ -245,9 +247,33 @@ class DataGenerator(object):
new_batch
.
append
((
padded_audio
,
text
))
return
new_batch
def
__batch_shuffle__
(
self
,
manifest
,
batch_size
):
"""
1. Sort the audio clips by duration.
2. Generate a random number `k`, k in [0, batch_size).
3. Randomly remove `k` instances in order to make different mini-batches,
then make minibatches and each minibatch size is batch_size.
4. Shuffle the minibatches.
:param manifest: manifest file.
:type manifest: list
:param batch_size: batch size.
:type batch_size: int
"""
manifest
.
sort
(
key
=
lambda
x
:
x
[
"duration"
])
shift_len
=
self
.
__random__
.
randint
(
0
,
batch_size
-
1
)
batch_manifest
=
zip
(
*
[
iter
(
manifest
[
shift_len
:])]
*
batch_size
)
self
.
__random__
.
shuffle
(
batch_manifest
)
batch_manifest
=
list
(
sum
(
batch_manifest
,
()))
res_len
=
len
(
manifest
)
-
shift_len
-
len
(
batch_manifest
)
batch_manifest
.
extend
(
manifest
[
-
res_len
:])
batch_manifest
.
extend
(
manifest
[
0
:
shift_len
])
return
batch_manifest
def
instance_reader_creator
(
self
,
manifest_path
,
sort_by_duration
=
True
,
batch_size
,
sortagrad
=
True
,
shuffle
=
False
):
"""
Instance reader creator for audio data. Create a callable function to
...
...
@@ -258,18 +284,14 @@ class DataGenerator(object):
:param manifest_path: Filepath of manifest for audio clip files.
:type manifest_path: basestring
:param sort
_by_duration: Sort the audio clips by duration if set True
(for SortaGrad)
.
:type sort
_by_duration
: bool
:param sort
agrad: Sort the audio clips by duration in the first epoch
if set True
.
:type sort
agrad
: bool
:param shuffle: Shuffle the audio clips if set True.
:type shuffle: bool
:return: Data reader function.
:rtype: callable
"""
if
sort_by_duration
and
shuffle
:
sort_by_duration
=
False
logger
.
warn
(
"When shuffle set to true, "
"sort_by_duration is forced to set False."
)
def
reader
():
# read manifest
...
...
@@ -278,16 +300,17 @@ class DataGenerator(object):
max_duration
=
self
.
__max_duration__
,
min_duration
=
self
.
__min_duration__
)
# sort (by duration) or shuffle manifest
if
s
ort_by_duration
:
if
s
elf
.
__epoc__
==
0
and
sortagrad
:
manifest
.
sort
(
key
=
lambda
x
:
x
[
"duration"
])
if
shuffle
:
self
.
__random__
.
shuffle
(
manifest
)
el
if
shuffle
:
manifest
=
self
.
__batch_shuffle__
(
manifest
,
batch_size
)
# extract spectrogram feature
for
instance
in
manifest
:
spectrogram
=
self
.
__audio_featurize__
(
instance
[
"audio_filepath"
])
transcript
=
self
.
__text_featurize__
(
instance
[
"text"
])
yield
(
spectrogram
,
transcript
)
self
.
__epoc__
+=
1
return
reader
...
...
@@ -296,7 +319,7 @@ class DataGenerator(object):
batch_size
,
padding_to
=-
1
,
flatten
=
False
,
sort
_by_duration
=
Tru
e
,
sort
agrad
=
Fals
e
,
shuffle
=
False
):
"""
Batch data reader creator for audio data. Create a callable function to
...
...
@@ -317,9 +340,9 @@ class DataGenerator(object):
:param flatten: If set True, audio data will be flatten to be a 1-dim
ndarray. Otherwise, 2-dim ndarray. Default is False.
:type flatten: bool
:param sort
_by_duration: Sort the audio clips by duration if set True
(for SortaGrad)
.
:type sort
_by_duration
: bool
:param sort
agrad: Sort the audio clips by duration in the first epoch
if set True
.
:type sort
agrad
: bool
:param shuffle: Shuffle the audio clips if set True.
:type shuffle: bool
:return: Batch reader function, producing batches of data when called.
...
...
@@ -329,7 +352,8 @@ class DataGenerator(object):
def
batch_reader
():
instance_reader
=
self
.
instance_reader_creator
(
manifest_path
=
manifest_path
,
sort_by_duration
=
sort_by_duration
,
batch_size
=
batch_size
,
sortagrad
=
sortagrad
,
shuffle
=
shuffle
)
batch
=
[]
for
instance
in
instance_reader
():
...
...
train.py
浏览文件 @
7c85e0fd
...
...
@@ -85,23 +85,27 @@ def train():
"""
DeepSpeech2 training.
"""
# initialize data generator
data_generator
=
DataGenerator
(
vocab_filepath
=
args
.
vocab_filepath
,
normalizer_manifest_path
=
args
.
normalizer_manifest_path
,
normalizer_num_samples
=
200
,
max_duration
=
20.0
,
min_duration
=
0.0
,
stride_ms
=
10
,
window_ms
=
20
)
def
data_generator
():
return
DataGenerator
(
vocab_filepath
=
args
.
vocab_filepath
,
normalizer_manifest_path
=
args
.
normalizer_manifest_path
,
normalizer_num_samples
=
200
,
max_duration
=
20.0
,
min_duration
=
0.0
,
stride_ms
=
10
,
window_ms
=
20
)
train_generator
=
data_generator
()
test_generator
=
data_generator
()
# create network config
dict_size
=
data_generator
.
vocabulary_size
()
dict_size
=
train_generator
.
vocabulary_size
()
# paddle.data_type.dense_array is used for variable batch input.
# the size 161 * 161 is only a placeholder value and the real shape
# of input batch data will be set at each batch.
audio_data
=
paddle
.
layer
.
data
(
name
=
"audio_spectrogram"
,
height
=
161
,
width
=
2000
,
type
=
paddle
.
data_type
.
dense_vector
(
322000
))
name
=
"audio_spectrogram"
,
type
=
paddle
.
data_type
.
dense_array
(
161
*
161
))
text_data
=
paddle
.
layer
.
data
(
name
=
"transcript_text"
,
type
=
paddle
.
data_type
.
integer_value_sequence
(
dict_size
))
...
...
@@ -122,28 +126,16 @@ def train():
cost
=
cost
,
parameters
=
parameters
,
update_equation
=
optimizer
)
# prepare data reader
train_batch_reader_sortagrad
=
data_generator
.
batch_reader_creator
(
manifest_path
=
args
.
train_manifest_path
,
batch_size
=
args
.
batch_size
,
padding_to
=
2000
,
flatten
=
True
,
sort_by_duration
=
True
,
shuffle
=
False
)
train_batch_reader_nosortagrad
=
data_generator
.
batch_reader_creator
(
train_batch_reader
=
train_generator
.
batch_reader_creator
(
manifest_path
=
args
.
train_manifest_path
,
batch_size
=
args
.
batch_size
,
padding_to
=
2000
,
flatten
=
True
,
sort_by_duration
=
False
,
sortagrad
=
True
,
shuffle
=
True
)
test_batch_reader
=
data
_generator
.
batch_reader_creator
(
test_batch_reader
=
test
_generator
.
batch_reader_creator
(
manifest_path
=
args
.
dev_manifest_path
,
batch_size
=
args
.
batch_size
,
padding_to
=
2000
,
flatten
=
True
,
sort_by_duration
=
False
,
shuffle
=
False
)
feeding
=
data
_generator
.
data_name_feeding
()
feeding
=
train
_generator
.
data_name_feeding
()
# create event handler
def
event_handler
(
event
):
...
...
@@ -169,17 +161,8 @@ def train():
time
.
time
()
-
start_time
,
event
.
pass_id
,
result
.
cost
)
# run train
# first pass with sortagrad
if
args
.
use_sortagrad
:
trainer
.
train
(
reader
=
train_batch_reader_sortagrad
,
event_handler
=
event_handler
,
num_passes
=
1
,
feeding
=
feeding
)
args
.
num_passes
-=
1
# other passes without sortagrad
trainer
.
train
(
reader
=
train_batch_reader
_nosortagrad
,
reader
=
train_batch_reader
,
event_handler
=
event_handler
,
num_passes
=
args
.
num_passes
,
feeding
=
feeding
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录