PaddlePaddle / DeepSpeech
Commit 19eb6343
Authored Sep 27, 2017 by Yibing Liu

Merge branch 'develop' of upstream into develop

Parents: 50cf9c5f, 70e43c18
Showing 1 changed file with 69 additions and 24 deletions

tools/tune.py  +69 -24
@@ -4,13 +4,18 @@ from __future__ import division
 from __future__ import print_function
+
 import sys
 import os
 import numpy as np
 import argparse
 import functools
+import gzip
+import logging
 import paddle.v2 as paddle
 import _init_paths
 from data_utils.data import DataGenerator
-from model_utils.model import DeepSpeech2Model
+from decoders.swig_wrapper import Scorer
+from decoders.swig_wrapper import ctc_beam_search_decoder_batch
+from model_utils.model import deep_speech_v2_network
 from utils.error_rate import wer, cer
 from utils.utility import add_arguments, print_arguments
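A note on the reworked imports: Scorer and ctc_beam_search_decoder_batch come from the compiled SWIG bindings of the native CTC beam-search decoder, and deep_speech_v2_network replaces the DeepSpeech2Model wrapper so the script can drive inference directly. Every Scorer method this file touches is visible in the hunks below, so for a dry run without the native extension a hypothetical stand-in could mirror that surface (a test double, not the real binding):

class FakeScorer(object):
    """Hypothetical test double mirroring the Scorer calls made in tune.py."""

    def __init__(self, alpha, beta, model_path, vocabulary):
        self.alpha, self.beta = alpha, beta

    def reset_params(self, alpha, beta):
        # called once per (alpha, beta) grid point
        self.alpha, self.beta = alpha, beta

    def is_character_based(self):
        return 0  # pretend word-based LM

    def get_max_order(self):
        return 5  # placeholder n-gram order

    def get_dict_size(self):
        return 0  # placeholder dictionary size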
@@ -66,6 +71,9 @@ add_arg('specgram_type', str,
 args = parser.parse_args()
 
+logging.basicConfig(
+    format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')
+
 
 def tune():
     """Tune parameters alpha and beta incrementally."""
     if not args.num_alphas >= 0:
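The new logging.basicConfig call stamps every record with its level, timestamp, and source location. A minimal standalone sketch of what the tuning log looks like under this exact format string (standard library only):

import logging

logging.basicConfig(
    format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')
logger = logging.getLogger("")
logger.setLevel(level=logging.INFO)

# emits something like:
# [INFO 2017-09-27 12:00:00,000 demo.py:10] begin to initialize the external scorer for tuning
logger.info("begin to initialize the external scorer for tuning")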
@@ -79,29 +87,59 @@ def tune():
         augmentation_config='{}',
         specgram_type=args.specgram_type,
         num_threads=1)
+    audio_data = paddle.layer.data(
+        name="audio_spectrogram",
+        type=paddle.data_type.dense_array(161 * 161))
+    text_data = paddle.layer.data(
+        name="transcript_text",
+        type=paddle.data_type.integer_value_sequence(data_generator.vocab_size))
+    output_probs, _ = deep_speech_v2_network(
+        audio_data=audio_data,
+        text_data=text_data,
+        dict_size=data_generator.vocab_size,
+        num_conv_layers=args.num_conv_layers,
+        num_rnn_layers=args.num_rnn_layers,
+        rnn_size=args.rnn_layer_size,
+        use_gru=args.use_gru,
+        share_rnn_weights=args.share_rnn_weights)
 
     batch_reader = data_generator.batch_reader_creator(
         manifest_path=args.tune_manifest,
         batch_size=args.batch_size,
         sortagrad=False,
         shuffle_method=None)
-    tune_data = batch_reader().next()
-    target_transcripts = [
-        ''.join([data_generator.vocab_list[token] for token in transcript])
-        for _, transcript in tune_data
-    ]
-
-    ds2_model = DeepSpeech2Model(
-        vocab_size=data_generator.vocab_size,
-        num_conv_layers=args.num_conv_layers,
-        num_rnn_layers=args.num_rnn_layers,
-        rnn_layer_size=args.rnn_layer_size,
-        use_gru=args.use_gru,
-        pretrained_model_path=args.model_path,
-        share_rnn_weights=args.share_rnn_weights)
+
+    # load parameters
+    if not os.path.isfile(args.model_path):
+        raise IOError("Invaid model path: %s" % args.model_path)
+    parameters = paddle.parameters.Parameters.from_tar(
+        gzip.open(args.model_path))
+
+    inferer = paddle.inference.Inference(
+        output_layer=output_probs, parameters=parameters)
+    # decoders only accept string encoded in utf-8
+    vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
+
+    # init logger
+    logger = logging.getLogger("")
+    logger.setLevel(level=logging.INFO)
+    # init external scorer
+    logger.info("begin to initialize the external scorer for tuning")
+    if not os.path.isfile(args.lang_model_path):
+        raise IOError("Invaid language model path: %s" % args.lang_model_path)
+    ext_scorer = Scorer(
+        alpha=args.alpha_from,
+        beta=args.beta_from,
+        model_path=args.lang_model_path,
+        vocabulary=vocab_list)
+    logger.info("language model: "
+                "is_character_based = %d," % ext_scorer.is_character_based() +
+                " max_order = %d," % ext_scorer.get_max_order() +
+                " dict_size = %d" % ext_scorer.get_dict_size())
+    logger.info("end initializing scorer. Start tuning ...")
 
+    error_rate_func = cer if args.error_rate_type == 'cer' else wer
     # create grid for search
     cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
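The hunk ends at the alpha axis of the search grid; the beta axis and the grid construction itself are collapsed, unchanged context. Judging by cand_alphas and the params_grid iterated below, the collapsed lines presumably cross every candidate alpha with every candidate beta. A hypothetical reconstruction (cand_betas, args.beta_to, and args.num_betas are assumed by symmetry with the alpha flags; the numbers are toy values):

import argparse

import numpy as np

# toy stand-in for the script's parsed flags
args = argparse.Namespace(alpha_from=1.0, alpha_to=3.2, num_alphas=45,
                          beta_from=0.1, beta_to=0.45, num_betas=8)
cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
params_grid = [(alpha, beta) for alpha in cand_alphas for beta in cand_betas]
print(len(params_grid))  # 45 * 8 = 360 candidate (alpha, beta) pairs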
@@ -116,6 +154,13 @@ def tune():
     for infer_data in batch_reader():
         if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
             break
+        infer_results = inferer.infer(input=infer_data)
+        num_steps = len(infer_results) // len(infer_data)
+        probs_split = [
+            infer_results[i * num_steps:(i + 1) * num_steps]
+            for i in xrange(len(infer_data))
+        ]
+
         target_transcripts = [
             ''.join([data_generator.vocab_list[token] for token in transcript])
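The added probs_split block undoes PaddlePaddle's batch flattening: inferer.infer returns the per-time-step probability rows of every utterance concatenated into one list, so rows [i * num_steps, (i + 1) * num_steps) belong to utterance i. A toy illustration of the same slicing, runnable without Paddle (strings stand in for probability rows; the diff's xrange is Python 2, range is used here):

# pretend batch of 2 utterances, 3 time steps each, flattened to 6 rows
infer_data = [("utt0", None), ("utt1", None)]
infer_results = ["u0_t0", "u0_t1", "u0_t2",
                 "u1_t0", "u1_t1", "u1_t2"]

num_steps = len(infer_results) // len(infer_data)  # -> 3
probs_split = [
    infer_results[i * num_steps:(i + 1) * num_steps]
    for i in range(len(infer_data))
]
assert probs_split == [["u0_t0", "u0_t1", "u0_t2"],
                       ["u1_t0", "u1_t1", "u1_t2"]]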
@@ -125,18 +170,18 @@ def tune():
         num_ins += len(target_transcripts)
         # grid search
         for index, (alpha, beta) in enumerate(params_grid):
-            result_transcripts = ds2_model.infer_batch(
-                infer_data=infer_data,
-                decoding_method='ctc_beam_search',
-                beam_alpha=alpha,
-                beam_beta=beta,
-                beam_size=args.beam_size,
-                cutoff_prob=args.cutoff_prob,
-                vocab_list=vocab_list,
-                language_model_path=args.lang_model_path,
-                num_processes=args.num_proc_bsearch)
+            # reset alpha & beta
+            ext_scorer.reset_params(alpha, beta)
+            beam_search_results = ctc_beam_search_decoder_batch(
+                probs_split=probs_split,
+                vocabulary=vocab_list,
+                beam_size=args.beam_size,
+                num_processes=args.num_proc_bsearch,
+                cutoff_prob=args.cutoff_prob,
+                cutoff_top_n=args.cutoff_top_n,
+                ext_scoring_func=ext_scorer, )
+
+            result_transcripts = [res[0][1] for res in beam_search_results]
             for target, result in zip(target_transcripts, result_transcripts):
                 err_sum[index] += error_rate_func(target, result)
             err_ave[index] = err_sum[index] / num_ins
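This hunk carries the point of the refactor: the forward pass now happens once per batch (previous hunk), and each grid point only re-runs beam-search decoding after ext_scorer.reset_params(alpha, beta), instead of calling ds2_model.infer_batch from scratch. The error bookkeeping keeps a running mean across all batches seen so far. A compact, self-contained sketch of that accounting, with a stand-in error function and toy transcripts in place of the repo's wer/cer:

def error_rate_func(target, result):
    # stand-in for utils.error_rate.wer/cer: exact match scores 0, else 1
    return 0.0 if target == result else 1.0

params_grid = [(1.0, 0.1), (1.5, 0.3)]  # toy (alpha, beta) grid
err_sum = [0.0] * len(params_grid)
err_ave = [0.0] * len(params_grid)
num_ins = 0

target_transcripts = ["a b", "c d"]
decoded = {0: ["a b", "c x"], 1: ["a b", "c d"]}  # toy decodes per grid point

num_ins += len(target_transcripts)
for index, (alpha, beta) in enumerate(params_grid):
    # in tune.py the beam-search decode for this grid point happens here
    for target, result in zip(target_transcripts, decoded[index]):
        err_sum[index] += error_rate_func(target, result)
    err_ave[index] = err_sum[index] / num_ins  # running mean over all instances

min_index = min(range(len(params_grid)), key=lambda i: err_ave[i])
print("best (alpha, beta):", params_grid[min_index])  # -> (1.5, 0.3)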
@@ -167,7 +212,7 @@ def tune():
         % (args.num_batches, "%.3f" % params_grid[min_index][0],
            "%.3f" % params_grid[min_index][1]))
-    ds2_model.logger.info("finish inference")
+    logger.info("finish tuning")
 
 
 def main():
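For context, the script is driven entirely by the add_arguments-style flags referenced throughout the diff as args.*. A hypothetical invocation from the repo root (flag names inferred from those references; every path and number below is a placeholder):

python tools/tune.py \
    --tune_manifest data/manifest.tune \
    --model_path checkpoints/params.tar.gz \
    --lang_model_path models/lm/your_lm.klm \
    --alpha_from 1.0 --alpha_to 3.2 --num_alphas 45 \
    --beta_from 0.1 --beta_to 0.45 \
    --beam_size 500 --num_proc_bsearch 8 \
    --error_rate_type cer

The winning pair is the grid point with the lowest running error average, which is exactly what params_grid[min_index] reports in the final hunk.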
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录