Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
ae796a9d
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ae796a9d
编写于
12月 04, 2017
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Correct the error rate's computation for multiple sentences
上级
78968af6
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
20 addition
and
12 deletion
+20
-12
test.py
test.py
+8
-6
tools/tune.py
tools/tune.py
+12
-6
未找到文件。
test.py
浏览文件 @
ae796a9d
...
@@ -8,7 +8,7 @@ import functools
...
@@ -8,7 +8,7 @@ import functools
import
paddle.v2
as
paddle
import
paddle.v2
as
paddle
from
data_utils.data
import
DataGenerator
from
data_utils.data
import
DataGenerator
from
model_utils.model
import
DeepSpeech2Model
from
model_utils.model
import
DeepSpeech2Model
from
utils.error_rate
import
wer
,
cer
from
utils.error_rate
import
char_errors
,
word_errors
from
utils.utility
import
add_arguments
,
print_arguments
from
utils.utility
import
add_arguments
,
print_arguments
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
...
@@ -91,8 +91,8 @@ def evaluate():
...
@@ -91,8 +91,8 @@ def evaluate():
# decoders only accept string encoded in utf-8
# decoders only accept string encoded in utf-8
vocab_list
=
[
chars
.
encode
(
"utf-8"
)
for
chars
in
data_generator
.
vocab_list
]
vocab_list
=
[
chars
.
encode
(
"utf-8"
)
for
chars
in
data_generator
.
vocab_list
]
error
_rate_func
=
cer
if
args
.
error_rate_type
==
'cer'
else
wer
error
s_func
=
char_errors
if
args
.
error_rate_type
==
'cer'
else
word_errors
error
_sum
,
num_ins
=
0.
0
,
0
error
s_sum
,
len_refs
,
num_ins
=
0.0
,
0
,
0
for
infer_data
in
batch_reader
():
for
infer_data
in
batch_reader
():
result_transcripts
=
ds2_model
.
infer_batch
(
result_transcripts
=
ds2_model
.
infer_batch
(
infer_data
=
infer_data
,
infer_data
=
infer_data
,
...
@@ -108,12 +108,14 @@ def evaluate():
...
@@ -108,12 +108,14 @@ def evaluate():
feeding_dict
=
data_generator
.
feeding
)
feeding_dict
=
data_generator
.
feeding
)
target_transcripts
=
[
data
[
1
]
for
data
in
infer_data
]
target_transcripts
=
[
data
[
1
]
for
data
in
infer_data
]
for
target
,
result
in
zip
(
target_transcripts
,
result_transcripts
):
for
target
,
result
in
zip
(
target_transcripts
,
result_transcripts
):
error_sum
+=
error_rate_func
(
target
,
result
)
errors
,
len_ref
=
errors_func
(
target
,
result
)
errors_sum
+=
errors
len_refs
+=
len_ref
num_ins
+=
1
num_ins
+=
1
print
(
"Error rate [%s] (%d/?) = %f"
%
print
(
"Error rate [%s] (%d/?) = %f"
%
(
args
.
error_rate_type
,
num_ins
,
error
_sum
/
num_in
s
))
(
args
.
error_rate_type
,
num_ins
,
error
s_sum
/
len_ref
s
))
print
(
"Final error rate [%s] (%d/%d) = %f"
%
print
(
"Final error rate [%s] (%d/%d) = %f"
%
(
args
.
error_rate_type
,
num_ins
,
num_ins
,
error
_sum
/
num_in
s
))
(
args
.
error_rate_type
,
num_ins
,
num_ins
,
error
s_sum
/
len_ref
s
))
ds2_model
.
logger
.
info
(
"finish evaluation"
)
ds2_model
.
logger
.
info
(
"finish evaluation"
)
...
...
tools/tune.py
浏览文件 @
ae796a9d
...
@@ -16,7 +16,7 @@ from data_utils.data import DataGenerator
...
@@ -16,7 +16,7 @@ from data_utils.data import DataGenerator
from
decoders.swig_wrapper
import
Scorer
from
decoders.swig_wrapper
import
Scorer
from
decoders.swig_wrapper
import
ctc_beam_search_decoder_batch
from
decoders.swig_wrapper
import
ctc_beam_search_decoder_batch
from
model_utils.model
import
deep_speech_v2_network
from
model_utils.model
import
deep_speech_v2_network
from
utils.error_rate
import
wer
,
cer
from
utils.error_rate
import
char_errors
,
word_errors
from
utils.utility
import
add_arguments
,
print_arguments
from
utils.utility
import
add_arguments
,
print_arguments
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
...
@@ -158,7 +158,7 @@ def tune():
...
@@ -158,7 +158,7 @@ def tune():
" dict_size = %d"
%
ext_scorer
.
get_dict_size
())
" dict_size = %d"
%
ext_scorer
.
get_dict_size
())
logger
.
info
(
"end initializing scorer. Start tuning ..."
)
logger
.
info
(
"end initializing scorer. Start tuning ..."
)
error
_rate_func
=
cer
if
args
.
error_rate_type
==
'cer'
else
wer
error
s_func
=
char_errors
if
args
.
error_rate_type
==
'cer'
else
word_errors
# create grid for search
# create grid for search
cand_alphas
=
np
.
linspace
(
args
.
alpha_from
,
args
.
alpha_to
,
args
.
num_alphas
)
cand_alphas
=
np
.
linspace
(
args
.
alpha_from
,
args
.
alpha_to
,
args
.
num_alphas
)
cand_betas
=
np
.
linspace
(
args
.
beta_from
,
args
.
beta_to
,
args
.
num_betas
)
cand_betas
=
np
.
linspace
(
args
.
beta_from
,
args
.
beta_to
,
args
.
num_betas
)
...
@@ -167,7 +167,7 @@ def tune():
...
@@ -167,7 +167,7 @@ def tune():
err_sum
=
[
0.0
for
i
in
xrange
(
len
(
params_grid
))]
err_sum
=
[
0.0
for
i
in
xrange
(
len
(
params_grid
))]
err_ave
=
[
0.0
for
i
in
xrange
(
len
(
params_grid
))]
err_ave
=
[
0.0
for
i
in
xrange
(
len
(
params_grid
))]
num_ins
,
cur_batch
=
0
,
0
num_ins
,
len_refs
,
cur_batch
=
0
,
0
,
0
## incremental tuning parameters over multiple batches
## incremental tuning parameters over multiple batches
for
infer_data
in
batch_reader
():
for
infer_data
in
batch_reader
():
if
(
args
.
num_batches
>=
0
)
and
(
cur_batch
>=
args
.
num_batches
):
if
(
args
.
num_batches
>=
0
)
and
(
cur_batch
>=
args
.
num_batches
):
...
@@ -200,8 +200,14 @@ def tune():
...
@@ -200,8 +200,14 @@ def tune():
result_transcripts
=
[
res
[
0
][
1
]
for
res
in
beam_search_results
]
result_transcripts
=
[
res
[
0
][
1
]
for
res
in
beam_search_results
]
for
target
,
result
in
zip
(
target_transcripts
,
result_transcripts
):
for
target
,
result
in
zip
(
target_transcripts
,
result_transcripts
):
err_sum
[
index
]
+=
error_rate_func
(
target
,
result
)
errors
,
len_ref
=
errors_func
(
target
,
result
)
err_ave
[
index
]
=
err_sum
[
index
]
/
num_ins
err_sum
[
index
]
+=
errors
# accumulate the length of references of every batch
# in the first iteration
if
args
.
alpha_from
==
alpha
and
args
.
beta_from
==
beta
:
len_refs
+=
len_ref
err_ave
[
index
]
=
err_sum
[
index
]
/
len_refs
if
index
%
2
==
0
:
if
index
%
2
==
0
:
sys
.
stdout
.
write
(
'.'
)
sys
.
stdout
.
write
(
'.'
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
...
@@ -226,7 +232,7 @@ def tune():
...
@@ -226,7 +232,7 @@ def tune():
err_ave_min
=
min
(
err_ave
)
err_ave_min
=
min
(
err_ave
)
min_index
=
err_ave
.
index
(
err_ave_min
)
min_index
=
err_ave
.
index
(
err_ave_min
)
print
(
"
\n
Finish tuning on %d batches, final opt (alpha, beta) = (%s, %s)"
print
(
"
\n
Finish tuning on %d batches, final opt (alpha, beta) = (%s, %s)"
%
(
args
.
num_batches
,
"%.3f"
%
params_grid
[
min_index
][
0
],
%
(
cur_batch
,
"%.3f"
%
params_grid
[
min_index
][
0
],
"%.3f"
%
params_grid
[
min_index
][
1
]))
"%.3f"
%
params_grid
[
min_index
][
1
]))
logger
.
info
(
"finish tuning"
)
logger
.
info
(
"finish tuning"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录