Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
7e093ed1
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7e093ed1
编写于
9月 16, 2017
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
expose param cutoff_top_n
上级
1cf61a15
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
35 addition
and
23 deletion
+35
-23
data_utils/featurizer/text_featurizer.py
data_utils/featurizer/text_featurizer.py
+0
-2
decoders/decoder_deprecated.py
decoders/decoder_deprecated.py
+8
-12
decoders/lm_scorer_deprecated.py
decoders/lm_scorer_deprecated.py
+1
-1
decoders/swig/ctc_decoders.cpp
decoders/swig/ctc_decoders.cpp
+1
-1
examples/librispeech/run_infer.sh
examples/librispeech/run_infer.sh
+1
-0
examples/librispeech/run_infer_golden.sh
examples/librispeech/run_infer_golden.sh
+1
-0
examples/librispeech/run_test_golden.sh
examples/librispeech/run_test_golden.sh
+1
-0
infer.py
infer.py
+7
-2
model_utils/model.py
model_utils/model.py
+8
-3
test.py
test.py
+7
-2
未找到文件。
data_utils/featurizer/text_featurizer.py
浏览文件 @
7e093ed1
...
...
@@ -22,8 +22,6 @@ class TextFeaturizer(object):
def
__init__
(
self
,
vocab_filepath
):
self
.
_vocab_dict
,
self
.
_vocab_list
=
self
.
_load_vocabulary_from_file
(
vocab_filepath
)
# from unicode to string
self
.
_vocab_list
=
[
chars
.
encode
(
"utf-8"
)
for
chars
in
self
.
_vocab_list
]
def
featurize
(
self
,
text
):
"""Convert text string to a list of token indices in char-level.Note
...
...
decoders/decoder_deprecated.py
浏览文件 @
7e093ed1
...
...
@@ -42,8 +42,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
def
ctc_beam_search_decoder
(
probs_seq
,
beam_size
,
vocabulary
,
blank_id
,
cutoff_prob
=
1.0
,
cutoff_top_n
=
40
,
ext_scoring_func
=
None
,
nproc
=
False
):
"""CTC Beam search decoder.
...
...
@@ -66,8 +66,6 @@ def ctc_beam_search_decoder(probs_seq,
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param blank_id: ID of blank.
:type blank_id: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
...
...
@@ -87,9 +85,8 @@ def ctc_beam_search_decoder(probs_seq,
raise
ValueError
(
"The shape of prob_seq does not match with the "
"shape of the vocabulary."
)
# blank_id check
if
not
blank_id
<
len
(
probs_seq
[
0
]):
raise
ValueError
(
"blank_id shouldn't be greater than probs dimension"
)
# blank_id assign
blank_id
=
len
(
vocabulary
)
# If the decoder called in the multiprocesses, then use the global scorer
# instantiated in ctc_beam_search_decoder_batch().
...
...
@@ -114,7 +111,7 @@ def ctc_beam_search_decoder(probs_seq,
prob_idx
=
list
(
enumerate
(
probs_seq
[
time_step
]))
cutoff_len
=
len
(
prob_idx
)
#If pruning is enabled
if
cutoff_prob
<
1.0
:
if
cutoff_prob
<
1.0
or
cutoff_top_n
<
cutoff_len
:
prob_idx
=
sorted
(
prob_idx
,
key
=
lambda
asd
:
asd
[
1
],
reverse
=
True
)
cutoff_len
,
cum_prob
=
0
,
0.0
for
i
in
xrange
(
len
(
prob_idx
)):
...
...
@@ -122,6 +119,7 @@ def ctc_beam_search_decoder(probs_seq,
cutoff_len
+=
1
if
cum_prob
>=
cutoff_prob
:
break
cutoff_len
=
min
(
cutoff_top_n
,
cutoff_top_n
)
prob_idx
=
prob_idx
[
0
:
cutoff_len
]
for
l
in
prefix_set_prev
:
...
...
@@ -191,9 +189,9 @@ def ctc_beam_search_decoder(probs_seq,
def
ctc_beam_search_decoder_batch
(
probs_split
,
beam_size
,
vocabulary
,
blank_id
,
num_processes
,
cutoff_prob
=
1.0
,
cutoff_top_n
=
40
,
ext_scoring_func
=
None
):
"""CTC beam search decoder using multiple processes.
...
...
@@ -204,8 +202,6 @@ def ctc_beam_search_decoder_batch(probs_split,
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param blank_id: ID of blank.
:type blank_id: int
:param num_processes: Number of parallel processes.
:type num_processes: int
:param cutoff_prob: Cutoff probability in pruning,
...
...
@@ -232,8 +228,8 @@ def ctc_beam_search_decoder_batch(probs_split,
pool
=
multiprocessing
.
Pool
(
processes
=
num_processes
)
results
=
[]
for
i
,
probs_list
in
enumerate
(
probs_split
):
args
=
(
probs_list
,
beam_size
,
vocabulary
,
blank_id
,
cutoff_prob
,
None
,
nproc
)
args
=
(
probs_list
,
beam_size
,
vocabulary
,
blank_id
,
cutoff_prob
,
cutoff_top_n
,
None
,
nproc
)
results
.
append
(
pool
.
apply_async
(
ctc_beam_search_decoder
,
args
))
pool
.
close
()
...
...
decoders/lm_scorer_deprecated.py
浏览文件 @
7e093ed1
...
...
@@ -8,7 +8,7 @@ import kenlm
import
numpy
as
np
class
Lm
Scorer
(
object
):
class
Scorer
(
object
):
"""External scorer to evaluate a prefix or whole sentence in
beam search decoding, including the score from n-gram language
model and word count.
...
...
decoders/swig/ctc_decoders.cpp
浏览文件 @
7e093ed1
...
...
@@ -128,7 +128,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
// pruning of vacobulary
size_t
cutoff_len
=
prob
.
size
();
if
(
cutoff_prob
<
1.0
||
cutoff_top_n
<
prob
.
size
()
)
{
if
(
cutoff_prob
<
1.0
||
cutoff_top_n
<
cutoff_len
)
{
std
::
sort
(
prob_idx
.
begin
(),
prob_idx
.
end
(),
pair_comp_second_rev
<
int
,
double
>
);
if
(
cutoff_prob
<
1.0
)
{
...
...
examples/librispeech/run_infer.sh
浏览文件 @
7e093ed1
...
...
@@ -24,6 +24,7 @@ python -u infer.py \
--alpha
=
2.15
\
--beta
=
0.35
\
--cutoff_prob
=
1.0
\
--cutoff_top_n
=
40
\
--use_gru
=
False
\
--use_gpu
=
True
\
--share_rnn_weights
=
True
\
...
...
examples/librispeech/run_infer_golden.sh
浏览文件 @
7e093ed1
...
...
@@ -33,6 +33,7 @@ python -u infer.py \
--alpha
=
2.15
\
--beta
=
0.35
\
--cutoff_prob
=
1.0
\
--cutoff_top_n
=
40
\
--use_gru
=
False
\
--use_gpu
=
True
\
--share_rnn_weights
=
True
\
...
...
examples/librispeech/run_test_golden.sh
浏览文件 @
7e093ed1
...
...
@@ -34,6 +34,7 @@ python -u test.py \
--alpha
=
2.15
\
--beta
=
0.35
\
--cutoff_prob
=
1.0
\
--cutoff_top_n
=
40
\
--use_gru
=
False
\
--use_gpu
=
True
\
--share_rnn_weights
=
True
\
...
...
infer.py
浏览文件 @
7e093ed1
...
...
@@ -23,7 +23,8 @@ add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
add_arg
(
'rnn_layer_size'
,
int
,
2048
,
"# of recurrent cells per layer."
)
add_arg
(
'alpha'
,
float
,
2.15
,
"Coef of LM for beam search."
)
add_arg
(
'beta'
,
float
,
0.35
,
"Coef of WC for beam search."
)
add_arg
(
'cutoff_prob'
,
float
,
1.0
,
"Cutoff probability for pruning."
)
add_arg
(
'cutoff_prob'
,
float
,
1.0
,
"Cutoff probability for pruning."
)
add_arg
(
'cutoff_top_n'
,
int
,
40
,
"Cutoff number for pruning."
)
add_arg
(
'use_gru'
,
bool
,
False
,
"Use GRUs instead of simple RNNs."
)
add_arg
(
'use_gpu'
,
bool
,
True
,
"Use GPU or not."
)
add_arg
(
'share_rnn_weights'
,
bool
,
True
,
"Share input-hidden weights across "
...
...
@@ -85,6 +86,9 @@ def infer():
pretrained_model_path
=
args
.
model_path
,
share_rnn_weights
=
args
.
share_rnn_weights
)
# decoders only accept string encoded in utf-8
vocab_list
=
[
chars
.
encode
(
"utf-8"
)
for
chars
in
data_generator
.
vocab_list
]
result_transcripts
=
ds2_model
.
infer_batch
(
infer_data
=
infer_data
,
decoding_method
=
args
.
decoding_method
,
...
...
@@ -92,7 +96,8 @@ def infer():
beam_beta
=
args
.
beta
,
beam_size
=
args
.
beam_size
,
cutoff_prob
=
args
.
cutoff_prob
,
vocab_list
=
data_generator
.
vocab_list
,
cutoff_top_n
=
args
.
cutoff_top_n
,
vocab_list
=
vocab_list
,
language_model_path
=
args
.
lang_model_path
,
num_processes
=
args
.
num_proc_bsearch
)
...
...
model_utils/model.py
浏览文件 @
7e093ed1
...
...
@@ -148,8 +148,8 @@ class DeepSpeech2Model(object):
return
self
.
_loss_inferer
.
infer
(
input
=
infer_data
)
def
infer_batch
(
self
,
infer_data
,
decoding_method
,
beam_alpha
,
beam_beta
,
beam_size
,
cutoff_prob
,
vocab_list
,
language_model_path
,
num_processes
):
beam_size
,
cutoff_prob
,
cutoff_top_n
,
vocab_list
,
language_model_path
,
num_processes
):
"""Model inference. Infer the transcription for a batch of speech
utterances.
...
...
@@ -169,6 +169,10 @@ class DeepSpeech2Model(object):
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
characters with highest probs in vocabulary will be
used in beam search, default 40.
:type cutoff_top_n: int
:param vocab_list: List of tokens in the vocabulary, for decoding.
:type vocab_list: list
:param language_model_path: Filepath for language model.
...
...
@@ -216,7 +220,8 @@ class DeepSpeech2Model(object):
beam_size
=
beam_size
,
num_processes
=
num_processes
,
ext_scoring_func
=
self
.
_ext_scorer
,
cutoff_prob
=
cutoff_prob
)
cutoff_prob
=
cutoff_prob
,
cutoff_top_n
=
cutoff_top_n
)
results
=
[
result
[
0
][
1
]
for
result
in
beam_search_results
]
else
:
...
...
test.py
浏览文件 @
7e093ed1
...
...
@@ -24,7 +24,8 @@ add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
add_arg
(
'rnn_layer_size'
,
int
,
2048
,
"# of recurrent cells per layer."
)
add_arg
(
'alpha'
,
float
,
2.15
,
"Coef of LM for beam search."
)
add_arg
(
'beta'
,
float
,
0.35
,
"Coef of WC for beam search."
)
add_arg
(
'cutoff_prob'
,
float
,
1.0
,
"Cutoff probability for pruning."
)
add_arg
(
'cutoff_prob'
,
float
,
1.0
,
"Cutoff probability for pruning."
)
add_arg
(
'cutoff_top_n'
,
int
,
40
,
"Cutoff number for pruning."
)
add_arg
(
'use_gru'
,
bool
,
False
,
"Use GRUs instead of simple RNNs."
)
add_arg
(
'use_gpu'
,
bool
,
True
,
"Use GPU or not."
)
add_arg
(
'share_rnn_weights'
,
bool
,
True
,
"Share input-hidden weights across "
...
...
@@ -85,6 +86,9 @@ def evaluate():
pretrained_model_path
=
args
.
model_path
,
share_rnn_weights
=
args
.
share_rnn_weights
)
# decoders only accept string encoded in utf-8
vocab_list
=
[
chars
.
encode
(
"utf-8"
)
for
chars
in
data_generator
.
vocab_list
]
error_rate_func
=
cer
if
args
.
error_rate_type
==
'cer'
else
wer
error_sum
,
num_ins
=
0.0
,
0
for
infer_data
in
batch_reader
():
...
...
@@ -95,7 +99,8 @@ def evaluate():
beam_beta
=
args
.
beta
,
beam_size
=
args
.
beam_size
,
cutoff_prob
=
args
.
cutoff_prob
,
vocab_list
=
data_generator
.
vocab_list
,
cutoff_top_n
=
args
.
cutoff_top_n
,
vocab_list
=
vocab_list
,
language_model_path
=
args
.
lang_model_path
,
num_processes
=
args
.
num_proc_bsearch
)
target_transcripts
=
[
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录