Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
1d8cc4a5
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1d8cc4a5
编写于
6月 20, 2017
作者:
X
Xinghai Sun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add multi-threading support for DS2 data generator.
上级
a5dcd23b
变更
4
显示空白变更内容
内联
并排
Showing
4 changed files
with
40 additions
and
6 deletions
+40
-6
data_utils/data.py
data_utils/data.py
+11
-3
data_utils/speech.py
data_utils/speech.py
+1
-1
infer.py
infer.py
+7
-1
train.py
train.py
+21
-1
未找到文件。
data_utils/data.py
浏览文件 @
1d8cc4a5
...
...
@@ -44,6 +44,8 @@ class DataGenerator(object):
:type max_freq: None|float
:param specgram_type: Specgram feature type. Options: 'linear'.
:type specgram_type: str
:param num_threads: Number of CPU threads for processing data.
:type num_threads: int
:param random_seed: Random seed.
:type random_seed: int
"""
...
...
@@ -58,6 +60,7 @@ class DataGenerator(object):
window_ms
=
20.0
,
max_freq
=
None
,
specgram_type
=
'linear'
,
num_threads
=
12
,
random_seed
=
0
):
self
.
_max_duration
=
max_duration
self
.
_min_duration
=
min_duration
...
...
@@ -70,6 +73,7 @@ class DataGenerator(object):
stride_ms
=
stride_ms
,
window_ms
=
window_ms
,
max_freq
=
max_freq
)
self
.
_num_threads
=
num_threads
self
.
_rng
=
random
.
Random
(
random_seed
)
self
.
_epoch
=
0
...
...
@@ -207,10 +211,14 @@ class DataGenerator(object):
def
reader
():
for
instance
in
manifest
:
yield
self
.
_process_utterance
(
instance
[
"audio_filepath"
],
yield
instance
def
mapper
(
instance
):
return
self
.
_process_utterance
(
instance
[
"audio_filepath"
],
instance
[
"text"
])
return
reader
return
paddle
.
reader
.
xmap_readers
(
mapper
,
reader
,
self
.
_num_threads
,
1024
,
order
=
True
)
def
_padding_batch
(
self
,
batch
,
padding_to
=-
1
,
flatten
=
False
):
"""
...
...
data_utils/speech.py
浏览文件 @
1d8cc4a5
...
...
@@ -94,7 +94,7 @@ class SpeechSegment(AudioSegment):
return
cls
(
samples
,
sample_rate
,
transcripts
)
@
classmethod
def
slice_from_file
(
cls
,
filepath
,
start
=
None
,
end
=
None
,
transcript
):
def
slice_from_file
(
cls
,
filepath
,
transcript
,
start
=
None
,
end
=
None
):
"""Loads a small section of a speech without having to load
the entire file into the memory which can be incredibly wasteful.
...
...
infer.py
浏览文件 @
1d8cc4a5
...
...
@@ -38,6 +38,11 @@ parser.add_argument(
default
=
True
,
type
=
distutils
.
util
.
strtobool
,
help
=
"Use gpu or not. (default: %(default)s)"
)
parser
.
add_argument
(
"--num_threads_data"
,
default
=
12
,
type
=
int
,
help
=
"Number of cpu threads for preprocessing data. (default: %(default)s)"
)
parser
.
add_argument
(
"--mean_std_filepath"
,
default
=
'mean_std.npz'
,
...
...
@@ -67,7 +72,8 @@ def infer():
data_generator
=
DataGenerator
(
vocab_filepath
=
args
.
vocab_filepath
,
mean_std_filepath
=
args
.
mean_std_filepath
,
augmentation_config
=
'{}'
)
augmentation_config
=
'{}'
,
num_threads
=
args
.
num_threads_data
)
# create network config
# paddle.data_type.dense_array is used for variable batch input.
...
...
train.py
浏览文件 @
1d8cc4a5
...
...
@@ -52,6 +52,18 @@ parser.add_argument(
default
=
True
,
type
=
distutils
.
util
.
strtobool
,
help
=
"Use sortagrad or not. (default: %(default)s)"
)
parser
.
add_argument
(
"--max_duration"
,
default
=
100.0
,
type
=
float
,
help
=
"Audios with duration larger than this will be discarded. "
"(default: %(default)s)"
)
parser
.
add_argument
(
"--min_duration"
,
default
=
0.0
,
type
=
float
,
help
=
"Audios with duration smaller than this will be discarded. "
"(default: %(default)s)"
)
parser
.
add_argument
(
"--shuffle_method"
,
default
=
'instance_shuffle'
,
...
...
@@ -63,6 +75,11 @@ parser.add_argument(
default
=
4
,
type
=
int
,
help
=
"Trainer number. (default: %(default)s)"
)
parser
.
add_argument
(
"--num_threads_data"
,
default
=
12
,
type
=
int
,
help
=
"Number of cpu threads for preprocessing data. (default: %(default)s)"
)
parser
.
add_argument
(
"--mean_std_filepath"
,
default
=
'mean_std.npz'
,
...
...
@@ -107,7 +124,10 @@ def train():
return
DataGenerator
(
vocab_filepath
=
args
.
vocab_filepath
,
mean_std_filepath
=
args
.
mean_std_filepath
,
augmentation_config
=
args
.
augmentation_config
)
augmentation_config
=
args
.
augmentation_config
,
max_duration
=
args
.
max_duration
,
min_duration
=
args
.
min_duration
,
num_threads
=
args
.
num_threads_data
)
train_generator
=
data_generator
()
test_generator
=
data_generator
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录