Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
ed5f04af
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ed5f04af
编写于
6月 15, 2017
作者:
X
Xinghai Sun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add shuffle type of instance_shuffle and batch_shuffle_clipped.
上级
04a225ae
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
82 addition
and
29 deletion
+82
-29
data_utils/data.py
data_utils/data.py
+38
-12
datasets/librispeech/librispeech.py
datasets/librispeech/librispeech.py
+1
-2
decoder.py
decoder.py
+2
-4
infer.py
infer.py
+5
-6
train.py
train.py
+11
-5
utils.py
utils.py
+25
-0
未找到文件。
data_utils/data.py
浏览文件 @
ed5f04af
...
@@ -80,7 +80,7 @@ class DataGenerator(object):
...
@@ -80,7 +80,7 @@ class DataGenerator(object):
padding_to
=-
1
,
padding_to
=-
1
,
flatten
=
False
,
flatten
=
False
,
sortagrad
=
False
,
sortagrad
=
False
,
batch_shuffle
=
False
):
shuffle_method
=
"batch_shuffle"
):
"""
"""
Batch data reader creator for audio data. Return a callable generator
Batch data reader creator for audio data. Return a callable generator
function to produce batches of data.
function to produce batches of data.
...
@@ -104,12 +104,22 @@ class DataGenerator(object):
...
@@ -104,12 +104,22 @@ class DataGenerator(object):
:param sortagrad: If set True, sort the instances by audio duration
:param sortagrad: If set True, sort the instances by audio duration
in the first epoch to speed up training.
in the first epoch to speed up training.
:type sortagrad: bool
:type sortagrad: bool
:param batch_shuffle: If set True, instances are batch-wise shuffled.
:param shuffle_method: Shuffle method. Options:
'' or None: no shuffle.
'instance_shuffle': instance-wise shuffle.
'batch_shuffle': similarly-sized instances are
put into batches, and then
batch-wise shuffle the batches.
For more details, please see
For more details, please see
``_batch_shuffle.__doc__``.
``_batch_shuffle.__doc__``.
If sortagrad is True, batch_shuffle is disabled
'batch_shuffle_clipped': 'batch_shuffle' with
head shift and tail
clipping. For more
details, please see
``_batch_shuffle``.
If sortagrad is True, shuffle is disabled
for the first epoch.
for the first epoch.
:type
batch_shuffle: bool
:type
shuffle_method: None|str
:return: Batch reader function, producing batches of data when called.
:return: Batch reader function, producing batches of data when called.
:rtype: callable
:rtype: callable
"""
"""
...
@@ -123,8 +133,20 @@ class DataGenerator(object):
...
@@ -123,8 +133,20 @@ class DataGenerator(object):
# sort (by duration) or batch-wise shuffle the manifest
# sort (by duration) or batch-wise shuffle the manifest
if
self
.
_epoch
==
0
and
sortagrad
:
if
self
.
_epoch
==
0
and
sortagrad
:
manifest
.
sort
(
key
=
lambda
x
:
x
[
"duration"
])
manifest
.
sort
(
key
=
lambda
x
:
x
[
"duration"
])
elif
batch_shuffle
:
else
:
manifest
=
self
.
_batch_shuffle
(
manifest
,
batch_size
)
if
shuffle_method
==
"batch_shuffle"
:
manifest
=
self
.
_batch_shuffle
(
manifest
,
batch_size
,
clipped
=
False
)
elif
shuffle_method
==
"batch_shuffle_clipped"
:
manifest
=
self
.
_batch_shuffle
(
manifest
,
batch_size
,
clipped
=
True
)
elif
shuffle_method
==
"instance_shuffle"
:
self
.
_rng
.
shuffle
(
manifest
)
elif
not
shuffle_method
:
pass
else
:
raise
ValueError
(
"Unknown shuffle method %s."
%
shuffle_method
)
# prepare batches
# prepare batches
instance_reader
=
self
.
_instance_reader_creator
(
manifest
)
instance_reader
=
self
.
_instance_reader_creator
(
manifest
)
batch
=
[]
batch
=
[]
...
@@ -218,7 +240,7 @@ class DataGenerator(object):
...
@@ -218,7 +240,7 @@ class DataGenerator(object):
new_batch
.
append
((
padded_audio
,
text
))
new_batch
.
append
((
padded_audio
,
text
))
return
new_batch
return
new_batch
def
_batch_shuffle
(
self
,
manifest
,
batch_size
):
def
_batch_shuffle
(
self
,
manifest
,
batch_size
,
clipped
=
False
):
"""Put similarly-sized instances into minibatches for better efficiency
"""Put similarly-sized instances into minibatches for better efficiency
and make a batch-wise shuffle.
and make a batch-wise shuffle.
...
@@ -233,6 +255,9 @@ class DataGenerator(object):
...
@@ -233,6 +255,9 @@ class DataGenerator(object):
:param batch_size: Batch size. This size is also used to generate
:param batch_size: Batch size. This size is also used to generate
a random number for batch shuffle.
a random number for batch shuffle.
:type batch_size: int
:type batch_size: int
:param clipped: Whether to clip the heading (small shift) and trailing
(incomplete batch) instances.
:type clipped: bool
:return: Batch shuffled manifest.
:return: Batch shuffled manifest.
:rtype: list
:rtype: list
"""
"""
...
@@ -241,6 +266,7 @@ class DataGenerator(object):
...
@@ -241,6 +266,7 @@ class DataGenerator(object):
batch_manifest
=
zip
(
*
[
iter
(
manifest
[
shift_len
:])]
*
batch_size
)
batch_manifest
=
zip
(
*
[
iter
(
manifest
[
shift_len
:])]
*
batch_size
)
self
.
_rng
.
shuffle
(
batch_manifest
)
self
.
_rng
.
shuffle
(
batch_manifest
)
batch_manifest
=
list
(
sum
(
batch_manifest
,
()))
batch_manifest
=
list
(
sum
(
batch_manifest
,
()))
if
not
clipped
:
res_len
=
len
(
manifest
)
-
shift_len
-
len
(
batch_manifest
)
res_len
=
len
(
manifest
)
-
shift_len
-
len
(
batch_manifest
)
batch_manifest
.
extend
(
manifest
[
-
res_len
:])
batch_manifest
.
extend
(
manifest
[
-
res_len
:])
batch_manifest
.
extend
(
manifest
[
0
:
shift_len
])
batch_manifest
.
extend
(
manifest
[
0
:
shift_len
])
...
...
datasets/librispeech/librispeech.py
浏览文件 @
ed5f04af
...
@@ -37,8 +37,7 @@ MD5_TRAIN_CLEAN_100 = "2a93770f6d5c6c964bc36631d331a522"
...
@@ -37,8 +37,7 @@ MD5_TRAIN_CLEAN_100 = "2a93770f6d5c6c964bc36631d331a522"
MD5_TRAIN_CLEAN_360
=
"c0e676e450a7ff2f54aeade5171606fa"
MD5_TRAIN_CLEAN_360
=
"c0e676e450a7ff2f54aeade5171606fa"
MD5_TRAIN_OTHER_500
=
"d1a0fd59409feb2c614ce4d30c387708"
MD5_TRAIN_OTHER_500
=
"d1a0fd59409feb2c614ce4d30c387708"
parser
=
argparse
.
ArgumentParser
(
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
description
=
'Downloads and prepare LibriSpeech dataset.'
)
parser
.
add_argument
(
parser
.
add_argument
(
"--target_dir"
,
"--target_dir"
,
default
=
DATA_HOME
+
"/Libri"
,
default
=
DATA_HOME
+
"/Libri"
,
...
...
decoder.py
浏览文件 @
ed5f04af
...
@@ -8,8 +8,7 @@ from itertools import groupby
...
@@ -8,8 +8,7 @@ from itertools import groupby
def
ctc_best_path_decode
(
probs_seq
,
vocabulary
):
def
ctc_best_path_decode
(
probs_seq
,
vocabulary
):
"""
"""Best path decoding, also called argmax decoding or greedy decoding.
Best path decoding, also called argmax decoding or greedy decoding.
Path consisting of the most probable tokens are further post-processed to
Path consisting of the most probable tokens are further post-processed to
remove consecutive repetitions and all blanks.
remove consecutive repetitions and all blanks.
...
@@ -38,8 +37,7 @@ def ctc_best_path_decode(probs_seq, vocabulary):
...
@@ -38,8 +37,7 @@ def ctc_best_path_decode(probs_seq, vocabulary):
def
ctc_decode
(
probs_seq
,
vocabulary
,
method
):
def
ctc_decode
(
probs_seq
,
vocabulary
,
method
):
"""
"""CTC-like sequence decoding from a sequence of likelihood probablilites.
CTC-like sequence decoding from a sequence of likelihood probablilites.
:param probs_seq: 2-D list of probabilities over the vocabulary for each
:param probs_seq: 2-D list of probabilities over the vocabulary for each
character. Each element is a list of float probabilities
character. Each element is a list of float probabilities
...
...
infer.py
浏览文件 @
ed5f04af
...
@@ -10,9 +10,9 @@ import paddle.v2 as paddle
...
@@ -10,9 +10,9 @@ import paddle.v2 as paddle
from
data_utils.data
import
DataGenerator
from
data_utils.data
import
DataGenerator
from
model
import
deep_speech2
from
model
import
deep_speech2
from
decoder
import
ctc_decode
from
decoder
import
ctc_decode
import
utils
parser
=
argparse
.
ArgumentParser
(
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
description
=
'Simplified version of DeepSpeech2 inference.'
)
parser
.
add_argument
(
parser
.
add_argument
(
"--num_samples"
,
"--num_samples"
,
default
=
10
,
default
=
10
,
...
@@ -62,9 +62,7 @@ args = parser.parse_args()
...
@@ -62,9 +62,7 @@ args = parser.parse_args()
def
infer
():
def
infer
():
"""
"""Max-ctc-decoding for DeepSpeech2."""
Max-ctc-decoding for DeepSpeech2.
"""
# initialize data generator
# initialize data generator
data_generator
=
DataGenerator
(
data_generator
=
DataGenerator
(
vocab_filepath
=
args
.
vocab_filepath
,
vocab_filepath
=
args
.
vocab_filepath
,
...
@@ -98,7 +96,7 @@ def infer():
...
@@ -98,7 +96,7 @@ def infer():
manifest_path
=
args
.
decode_manifest_path
,
manifest_path
=
args
.
decode_manifest_path
,
batch_size
=
args
.
num_samples
,
batch_size
=
args
.
num_samples
,
sortagrad
=
False
,
sortagrad
=
False
,
batch_shuffle
=
Fals
e
)
shuffle_method
=
Non
e
)
infer_data
=
batch_reader
().
next
()
infer_data
=
batch_reader
().
next
()
# run inference
# run inference
...
@@ -123,6 +121,7 @@ def infer():
...
@@ -123,6 +121,7 @@ def infer():
def
main
():
def
main
():
utils
.
print_arguments
(
args
)
paddle
.
init
(
use_gpu
=
args
.
use_gpu
,
trainer_count
=
1
)
paddle
.
init
(
use_gpu
=
args
.
use_gpu
,
trainer_count
=
1
)
infer
()
infer
()
...
...
train.py
浏览文件 @
ed5f04af
...
@@ -12,6 +12,7 @@ import distutils.util
...
@@ -12,6 +12,7 @@ import distutils.util
import
paddle.v2
as
paddle
import
paddle.v2
as
paddle
from
model
import
deep_speech2
from
model
import
deep_speech2
from
data_utils.data
import
DataGenerator
from
data_utils.data
import
DataGenerator
import
utils
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -51,6 +52,12 @@ parser.add_argument(
...
@@ -51,6 +52,12 @@ parser.add_argument(
default
=
True
,
default
=
True
,
type
=
distutils
.
util
.
strtobool
,
type
=
distutils
.
util
.
strtobool
,
help
=
"Use sortagrad or not. (default: %(default)s)"
)
help
=
"Use sortagrad or not. (default: %(default)s)"
)
parser
.
add_argument
(
"--shuffle_method"
,
default
=
'instance_shuffle'
,
type
=
str
,
help
=
"Shuffle method: 'instance_shuffle', 'batch_shuffle', "
"'batch_shuffle_batch'. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--trainer_count"
,
"--trainer_count"
,
default
=
4
,
default
=
4
,
...
@@ -93,9 +100,7 @@ args = parser.parse_args()
...
@@ -93,9 +100,7 @@ args = parser.parse_args()
def
train
():
def
train
():
"""
"""DeepSpeech2 training."""
DeepSpeech2 training.
"""
# initialize data generator
# initialize data generator
def
data_generator
():
def
data_generator
():
...
@@ -145,13 +150,13 @@ def train():
...
@@ -145,13 +150,13 @@ def train():
batch_size
=
args
.
batch_size
,
batch_size
=
args
.
batch_size
,
min_batch_size
=
args
.
trainer_count
,
min_batch_size
=
args
.
trainer_count
,
sortagrad
=
args
.
use_sortagrad
if
args
.
init_model_path
is
None
else
False
,
sortagrad
=
args
.
use_sortagrad
if
args
.
init_model_path
is
None
else
False
,
batch_shuffle
=
True
)
shuffle_method
=
args
.
shuffle_method
)
test_batch_reader
=
test_generator
.
batch_reader_creator
(
test_batch_reader
=
test_generator
.
batch_reader_creator
(
manifest_path
=
args
.
dev_manifest_path
,
manifest_path
=
args
.
dev_manifest_path
,
batch_size
=
args
.
batch_size
,
batch_size
=
args
.
batch_size
,
min_batch_size
=
1
,
# must be 1, but will have errors.
min_batch_size
=
1
,
# must be 1, but will have errors.
sortagrad
=
False
,
sortagrad
=
False
,
batch_shuffle
=
Fals
e
)
shuffle_method
=
Non
e
)
# create event handler
# create event handler
def
event_handler
(
event
):
def
event_handler
(
event
):
...
@@ -186,6 +191,7 @@ def train():
...
@@ -186,6 +191,7 @@ def train():
def
main
():
def
main
():
utils
.
print_arguments
(
args
)
paddle
.
init
(
use_gpu
=
args
.
use_gpu
,
trainer_count
=
args
.
trainer_count
)
paddle
.
init
(
use_gpu
=
args
.
use_gpu
,
trainer_count
=
args
.
trainer_count
)
train
()
train
()
...
...
utils.py
0 → 100644
浏览文件 @
ed5f04af
"""Contains common utility functions."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
def print_arguments(args):
    """Print argparse's arguments.

    Every attribute of ``args`` is printed as a ``name: value`` line between
    a header and a footer banner. Lines are emitted in alphabetical order of
    the argument name so the output is reproducible across runs.

    Usage:

    .. code-block:: python

        parser = argparse.ArgumentParser()
        parser.add_argument("--name", default="John", type=str,
                            help="User name.")
        args = parser.parse_args()
        print_arguments(args)

    :param args: Input argparse.Namespace for printing.
    :type args: argparse.Namespace
    """
    print("----- Configuration Arguments -----")
    # dict.iteritems() only exists on Python 2; items() works on both
    # Python 2 and 3. Sorting makes the listing order deterministic.
    for arg, value in sorted(vars(args).items()):
        print("%s: %s" % (arg, value))
    print("------------------------------------")
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录