Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
92eacf54
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
92eacf54
编写于
7月 31, 2017
作者:
X
Xinghai Sun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update default config params and result display for evaluator.py and infer.py for DS2.
上级
de212572
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
25 addition
and
10 deletion
+25
-10
evaluate.py
evaluate.py
+18
-8
infer.py
infer.py
+7
-2
未找到文件。
evaluate.py
浏览文件 @
92eacf54
...
@@ -4,6 +4,7 @@ from __future__ import division
...
@@ -4,6 +4,7 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
distutils.util
import
distutils.util
import
sys
import
argparse
import
argparse
import
gzip
import
gzip
import
paddle.v2
as
paddle
import
paddle.v2
as
paddle
...
@@ -12,13 +13,19 @@ from model import deep_speech2
...
@@ -12,13 +13,19 @@ from model import deep_speech2
from
decoder
import
*
from
decoder
import
*
from
lm.lm_scorer
import
LmScorer
from
lm.lm_scorer
import
LmScorer
from
error_rate
import
wer
from
error_rate
import
wer
import
utils
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
parser
.
add_argument
(
"--batch_size"
,
"--batch_size"
,
default
=
1
00
,
default
=
1
28
,
type
=
int
,
type
=
int
,
help
=
"Minibatch size for evaluation. (default: %(default)s)"
)
help
=
"Minibatch size for evaluation. (default: %(default)s)"
)
parser
.
add_argument
(
"--trainer_count"
,
default
=
8
,
type
=
int
,
help
=
"Trainer number. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--num_conv_layers"
,
"--num_conv_layers"
,
default
=
2
,
default
=
2
,
...
@@ -58,8 +65,8 @@ parser.add_argument(
...
@@ -58,8 +65,8 @@ parser.add_argument(
"--decode_method"
,
"--decode_method"
,
default
=
'beam_search'
,
default
=
'beam_search'
,
type
=
str
,
type
=
str
,
help
=
"Method for ctc decoding, best_path or beam_search.
(default: %(default)s)
"
help
=
"Method for ctc decoding, best_path or beam_search. "
)
"(default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--language_model_path"
,
"--language_model_path"
,
default
=
"lm/data/common_crawl_00.prune01111.trie.klm"
,
default
=
"lm/data/common_crawl_00.prune01111.trie.klm"
,
...
@@ -67,12 +74,12 @@ parser.add_argument(
...
@@ -67,12 +74,12 @@ parser.add_argument(
help
=
"Path for language model. (default: %(default)s)"
)
help
=
"Path for language model. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--alpha"
,
"--alpha"
,
default
=
0.
2
6
,
default
=
0.
3
6
,
type
=
float
,
type
=
float
,
help
=
"Parameter associated with language model. (default: %(default)f)"
)
help
=
"Parameter associated with language model. (default: %(default)f)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--beta"
,
"--beta"
,
default
=
0.
1
,
default
=
0.
25
,
type
=
float
,
type
=
float
,
help
=
"Parameter associated with word count. (default: %(default)f)"
)
help
=
"Parameter associated with word count. (default: %(default)f)"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -191,7 +198,7 @@ def evaluate():
...
@@ -191,7 +198,7 @@ def evaluate():
blank_id
=
len
(
data_generator
.
vocab_list
),
blank_id
=
len
(
data_generator
.
vocab_list
),
num_processes
=
args
.
num_processes_beam_search
,
num_processes
=
args
.
num_processes_beam_search
,
ext_scoring_func
=
ext_scorer
,
ext_scoring_func
=
ext_scorer
,
cutoff_prob
=
args
.
cutoff_prob
,
)
cutoff_prob
=
args
.
cutoff_prob
)
for
i
,
beam_search_result
in
enumerate
(
beam_search_results
):
for
i
,
beam_search_result
in
enumerate
(
beam_search_results
):
wer_sum
+=
wer
(
target_transcription
[
i
],
wer_sum
+=
wer
(
target_transcription
[
i
],
beam_search_result
[
0
][
1
])
beam_search_result
[
0
][
1
])
...
@@ -199,12 +206,15 @@ def evaluate():
...
@@ -199,12 +206,15 @@ def evaluate():
else
:
else
:
raise
ValueError
(
"Decoding method [%s] is not supported."
%
raise
ValueError
(
"Decoding method [%s] is not supported."
%
decode_method
)
decode_method
)
print
(
"WER (%d/?) = %f"
%
(
wer_counter
,
wer_sum
/
wer_counter
))
print
(
"Final WER = %f"
%
(
wer_sum
/
wer_counter
))
print
(
"Final WER (%d/%d) = %f"
%
(
wer_counter
,
wer_counter
,
wer_sum
/
wer_counter
))
def
main
():
def
main
():
paddle
.
init
(
use_gpu
=
args
.
use_gpu
,
trainer_count
=
1
)
utils
.
print_arguments
(
args
)
paddle
.
init
(
use_gpu
=
args
.
use_gpu
,
trainer_count
=
args
.
trainer_count
)
evaluate
()
evaluate
()
...
...
infer.py
浏览文件 @
92eacf54
...
@@ -57,6 +57,11 @@ parser.add_argument(
...
@@ -57,6 +57,11 @@ parser.add_argument(
type
=
str
,
type
=
str
,
help
=
"Feature type of audio data: 'linear' (power spectrum)"
help
=
"Feature type of audio data: 'linear' (power spectrum)"
" or 'mfcc'. (default: %(default)s)"
)
" or 'mfcc'. (default: %(default)s)"
)
parser
.
add_argument
(
"--trainer_count"
,
default
=
8
,
type
=
int
,
help
=
"Trainer number. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--mean_std_filepath"
,
"--mean_std_filepath"
,
default
=
'mean_std.npz'
,
default
=
'mean_std.npz'
,
...
@@ -208,7 +213,7 @@ def infer():
...
@@ -208,7 +213,7 @@ def infer():
wer_cur
=
wer
(
target_transcription
[
i
],
beam_search_result
[
0
][
1
])
wer_cur
=
wer
(
target_transcription
[
i
],
beam_search_result
[
0
][
1
])
wer_sum
+=
wer_cur
wer_sum
+=
wer_cur
wer_counter
+=
1
wer_counter
+=
1
print
(
"
cur wer = %f , average wer
= %f"
%
print
(
"
Current WER = %f , Average WER
= %f"
%
(
wer_cur
,
wer_sum
/
wer_counter
))
(
wer_cur
,
wer_sum
/
wer_counter
))
else
:
else
:
raise
ValueError
(
"Decoding method [%s] is not supported."
%
raise
ValueError
(
"Decoding method [%s] is not supported."
%
...
@@ -217,7 +222,7 @@ def infer():
...
@@ -217,7 +222,7 @@ def infer():
def
main
():
def
main
():
utils
.
print_arguments
(
args
)
utils
.
print_arguments
(
args
)
paddle
.
init
(
use_gpu
=
args
.
use_gpu
,
trainer_count
=
1
)
paddle
.
init
(
use_gpu
=
args
.
use_gpu
,
trainer_count
=
args
.
trainer_count
)
infer
()
infer
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录