Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
dedbfb26
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dedbfb26
编写于
6月 05, 2017
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enable ctc beam search decoder
上级
cfe9d228
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
65 addition
and
14 deletion
+65
-14
ctc_beam_search_decoder.py
ctc_beam_search_decoder.py
+21
-9
decoder.py
decoder.py
+14
-2
infer.py
infer.py
+30
-3
未找到文件。
ctc_beam_search_decoder.py
浏览文件 @
dedbfb26
...
...
@@ -23,10 +23,26 @@ def ids_id2token(ids_list):
return
ids_str
def
language_model
(
ids_list
,
vocabulary
):
# lookup ptb vocabulary
ptb_vocab_path
=
"./data/ptb_vocab.txt"
sentence
=
''
.
join
([
vocabulary
[
ids
]
for
ids
in
ids_list
])
words
=
sentence
.
split
(
' '
)
last_word
=
words
[
-
1
]
with
open
(
ptb_vocab_path
,
'r'
)
as
ptb_vocab
:
f
=
ptb_vocab
.
readline
()
while
f
:
if
f
==
last_word
:
return
1.0
f
=
ptb_vocab
.
readline
()
return
0.0
def
ctc_beam_search_decoder
(
input_probs_matrix
,
beam_size
,
vocabulary
,
max_time_steps
=
None
,
lang_model
=
None
,
lang_model
=
language_model
,
alpha
=
1.0
,
beta
=
1.0
,
blank_id
=
0
,
...
...
@@ -120,7 +136,7 @@ def ctc_beam_search_decoder(input_probs_matrix,
probs_nb_cur
[
l
]
+=
prob
[
c
]
*
probs_nb
[
l
]
elif
c
==
space_id
:
lm
=
1.0
if
lang_model
is
None
\
else
np
.
power
(
lang_model
(
ids_list
),
alpha
)
else
np
.
power
(
lang_model
(
ids_list
,
vocabulary
),
alpha
)
probs_nb_cur
[
l_plus
]
+=
lm
*
prob
[
c
]
*
(
probs_b
[
l
]
+
probs_nb
[
l
])
else
:
...
...
@@ -145,9 +161,10 @@ def ctc_beam_search_decoder(input_probs_matrix,
beam_result
=
[]
for
(
seq
,
prob
)
in
prefix_set_prev
.
items
():
if
prob
>
0.0
:
ids_list
=
ids_str2list
(
seq
)
ids_list
=
ids_str2list
(
seq
)[
1
:]
result
=
''
.
join
([
vocabulary
[
ids
]
for
ids
in
ids_list
])
log_prob
=
np
.
log
(
prob
)
beam_result
.
append
([
log_prob
,
ids_list
[
1
:]
])
beam_result
.
append
([
log_prob
,
result
])
## output top beam_size decoding results
beam_result
=
sorted
(
beam_result
,
key
=
lambda
asd
:
asd
[
0
],
reverse
=
True
)
...
...
@@ -156,11 +173,6 @@ def ctc_beam_search_decoder(input_probs_matrix,
return
beam_result
def
language_model
(
input
):
# TODO
return
random
.
uniform
(
0
,
1
)
def
simple_test
():
input_probs_matrix
=
[[
0.1
,
0.3
,
0.6
],
[
0.2
,
0.1
,
0.7
],
[
0.5
,
0.2
,
0.3
]]
...
...
decoder.py
浏览文件 @
dedbfb26
...
...
@@ -4,6 +4,7 @@
from
itertools
import
groupby
import
numpy
as
np
from
ctc_beam_search_decoder
import
*
def
ctc_best_path_decode
(
probs_seq
,
vocabulary
):
...
...
@@ -36,7 +37,11 @@ def ctc_best_path_decode(probs_seq, vocabulary):
return
''
.
join
([
vocabulary
[
index
]
for
index
in
index_list
])
def
ctc_decode
(
probs_seq
,
vocabulary
,
method
):
def
ctc_decode
(
probs_seq
,
vocabulary
,
method
,
beam_size
=
None
,
num_results_per_sample
=
None
):
"""
CTC-like sequence decoding from a sequence of likelihood probablilites.
...
...
@@ -56,5 +61,12 @@ def ctc_decode(probs_seq, vocabulary, method):
raise
ValueError
(
"probs dimension mismatchedd with vocabulary"
)
if
method
==
"best_path"
:
return
ctc_best_path_decode
(
probs_seq
,
vocabulary
)
elif
method
==
"beam_search"
:
return
ctc_beam_search_decoder
(
input_probs_matrix
=
probs_seq
,
vocabulary
=
vocabulary
,
beam_size
=
beam_size
,
blank_id
=
len
(
vocabulary
),
num_results_per_sample
=
num_results_per_sample
)
else
:
raise
ValueError
(
"Decoding method [%s] is not supported."
)
raise
ValueError
(
"Decoding method [%s] is not supported."
%
method
)
infer.py
浏览文件 @
dedbfb26
...
...
@@ -57,6 +57,23 @@ parser.add_argument(
default
=
'data/eng_vocab.txt'
,
type
=
str
,
help
=
"Vocabulary filepath. (default: %(default)s)"
)
parser
.
add_argument
(
"--decode_method"
,
default
=
'best_path'
,
type
=
str
,
help
=
"Method for ctc decoding, best_path or beam_search. (default: %(default)s)"
)
parser
.
add_argument
(
"--beam_size"
,
default
=
50
,
type
=
int
,
help
=
"Width for beam search decoding. (default: %(default)d)"
)
parser
.
add_argument
(
"--num_result_per_sample"
,
default
=
2
,
type
=
int
,
help
=
"Number of results per given sample in beam search. (default: %(default)d)"
)
args
=
parser
.
parse_args
()
...
...
@@ -120,12 +137,22 @@ def infer():
# decode and print
for
i
,
probs
in
enumerate
(
probs_split
):
output
_transcription
=
ctc_decode
(
best_path
_transcription
=
ctc_decode
(
probs_seq
=
probs
,
vocabulary
=
vocab_list
,
method
=
"best_path"
)
target_transcription
=
''
.
join
(
[
vocab_list
[
index
]
for
index
in
infer_data
[
i
][
1
]])
print
(
"Target Transcription: %s
\n
Output Transcription: %s
\n
"
%
(
target_transcription
,
output_transcription
))
print
(
"
\n
Target Transcription: %s
\n
Bst_path Transcription: %s"
%
(
target_transcription
,
best_path_transcription
))
beam_search_transcription
=
ctc_decode
(
probs_seq
=
probs
,
vocabulary
=
vocab_list
,
method
=
"beam_search"
,
beam_size
=
args
.
beam_size
,
num_results_per_sample
=
args
.
num_result_per_sample
)
for
index
in
range
(
len
(
beam_search_transcription
)):
print
(
"LM No, %d - %4f: %s "
%
(
index
,
beam_search_transcription
[
index
][
0
],
beam_search_transcription
[
index
][
1
]))
def
main
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录