Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
a0c89ae7
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a0c89ae7
编写于
8月 30, 2017
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add min cutoff & top n cutoff
上级
a661941a
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
75 addition
and
36 deletion
+75
-36
deploy.py
deploy.py
+11
-3
deploy/ctc_decoders.cpp
deploy/ctc_decoders.cpp
+45
-26
deploy/ctc_decoders.h
deploy/ctc_decoders.h
+2
-0
deploy/scorer.h
deploy/scorer.h
+1
-1
deploy/swig_decoders_wrapper.py
deploy/swig_decoders_wrapper.py
+16
-6
未找到文件。
deploy.py
浏览文件 @
a0c89ae7
...
@@ -18,7 +18,7 @@ import time
...
@@ -18,7 +18,7 @@ import time
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
parser
.
add_argument
(
"--num_samples"
,
"--num_samples"
,
default
=
4
,
default
=
10
,
type
=
int
,
type
=
int
,
help
=
"Number of samples for inference. (default: %(default)s)"
)
help
=
"Number of samples for inference. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -95,12 +95,12 @@ parser.add_argument(
...
@@ -95,12 +95,12 @@ parser.add_argument(
help
=
"Path for language model. (default: %(default)s)"
)
help
=
"Path for language model. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--alpha"
,
"--alpha"
,
default
=
0.26
,
default
=
1.5
,
type
=
float
,
type
=
float
,
help
=
"Parameter associated with language model. (default: %(default)f)"
)
help
=
"Parameter associated with language model. (default: %(default)f)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--beta"
,
"--beta"
,
default
=
0.
1
,
default
=
0.
3
,
type
=
float
,
type
=
float
,
help
=
"Parameter associated with word count. (default: %(default)f)"
)
help
=
"Parameter associated with word count. (default: %(default)f)"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -109,6 +109,12 @@ parser.add_argument(
...
@@ -109,6 +109,12 @@ parser.add_argument(
type
=
float
,
type
=
float
,
help
=
"The cutoff probability of pruning"
help
=
"The cutoff probability of pruning"
"in beam search. (default: %(default)f)"
)
"in beam search. (default: %(default)f)"
)
parser
.
add_argument
(
"--cutoff_top_n"
,
default
=
40
,
type
=
int
,
help
=
"The cutoff number of pruning"
"in beam search. (default: %(default)f)"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -184,6 +190,7 @@ def infer():
...
@@ -184,6 +190,7 @@ def infer():
vocabulary
=
data_generator
.
vocab_list
,
vocabulary
=
data_generator
.
vocab_list
,
blank_id
=
len
(
data_generator
.
vocab_list
),
blank_id
=
len
(
data_generator
.
vocab_list
),
cutoff_prob
=
args
.
cutoff_prob
,
cutoff_prob
=
args
.
cutoff_prob
,
cutoff_top_n
=
args
.
cutoff_top_n
,
ext_scoring_func
=
ext_scorer
,
)
ext_scoring_func
=
ext_scorer
,
)
batch_beam_results
+=
[
beam_result
]
batch_beam_results
+=
[
beam_result
]
else
:
else
:
...
@@ -194,6 +201,7 @@ def infer():
...
@@ -194,6 +201,7 @@ def infer():
blank_id
=
len
(
data_generator
.
vocab_list
),
blank_id
=
len
(
data_generator
.
vocab_list
),
num_processes
=
args
.
num_processes_beam_search
,
num_processes
=
args
.
num_processes_beam_search
,
cutoff_prob
=
args
.
cutoff_prob
,
cutoff_prob
=
args
.
cutoff_prob
,
cutoff_top_n
=
args
.
cutoff_top_n
,
ext_scoring_func
=
ext_scorer
,
)
ext_scoring_func
=
ext_scorer
,
)
for
i
,
beam_result
in
enumerate
(
batch_beam_results
):
for
i
,
beam_result
in
enumerate
(
batch_beam_results
):
...
...
deploy/ctc_decoders.cpp
浏览文件 @
a0c89ae7
...
@@ -62,6 +62,7 @@ std::vector<std::pair<double, std::string> >
...
@@ -62,6 +62,7 @@ std::vector<std::pair<double, std::string> >
std
::
vector
<
std
::
string
>
vocabulary
,
std
::
vector
<
std
::
string
>
vocabulary
,
int
blank_id
,
int
blank_id
,
double
cutoff_prob
,
double
cutoff_prob
,
int
cutoff_top_n
,
Scorer
*
ext_scorer
)
Scorer
*
ext_scorer
)
{
{
// dimension check
// dimension check
...
@@ -116,12 +117,24 @@ std::vector<std::pair<double, std::string> >
...
@@ -116,12 +117,24 @@ std::vector<std::pair<double, std::string> >
prob_idx
.
push_back
(
std
::
pair
<
int
,
double
>
(
i
,
prob
[
i
]));
prob_idx
.
push_back
(
std
::
pair
<
int
,
double
>
(
i
,
prob
[
i
]));
}
}
float
min_cutoff
=
-
NUM_FLT_INF
;
bool
full_beam
=
false
;
if
(
ext_scorer
!=
nullptr
)
{
int
num_prefixes
=
std
::
min
((
int
)
prefixes
.
size
(),
beam_size
);
std
::
sort
(
prefixes
.
begin
(),
prefixes
.
begin
()
+
num_prefixes
,
prefix_compare
);
min_cutoff
=
prefixes
[
num_prefixes
-
1
]
->
_score
+
log
(
prob
[
blank_id
])
-
std
::
max
(
0.0
,
ext_scorer
->
beta
);
full_beam
=
(
num_prefixes
==
beam_size
);
}
// pruning of vocabulary
// pruning of vocabulary
int
cutoff_len
=
prob
.
size
();
int
cutoff_len
=
prob
.
size
();
if
(
cutoff_prob
<
1.0
)
{
if
(
cutoff_prob
<
1.0
||
cutoff_top_n
<
prob
.
size
()
)
{
std
::
sort
(
prob_idx
.
begin
(),
std
::
sort
(
prob_idx
.
begin
(),
prob_idx
.
end
(),
prob_idx
.
end
(),
pair_comp_second_rev
<
int
,
double
>
);
pair_comp_second_rev
<
int
,
double
>
);
if
(
cutoff_prob
<
1.0
)
{
double
cum_prob
=
0.0
;
double
cum_prob
=
0.0
;
cutoff_len
=
0
;
cutoff_len
=
0
;
for
(
int
i
=
0
;
i
<
prob_idx
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
prob_idx
.
size
();
i
++
)
{
...
@@ -129,6 +142,8 @@ std::vector<std::pair<double, std::string> >
...
@@ -129,6 +142,8 @@ std::vector<std::pair<double, std::string> >
cutoff_len
+=
1
;
cutoff_len
+=
1
;
if
(
cum_prob
>=
cutoff_prob
)
break
;
if
(
cum_prob
>=
cutoff_prob
)
break
;
}
}
}
cutoff_len
=
std
::
min
(
cutoff_len
,
cutoff_top_n
);
prob_idx
=
std
::
vector
<
std
::
pair
<
int
,
double
>
>
(
prob_idx
.
begin
(),
prob_idx
=
std
::
vector
<
std
::
pair
<
int
,
double
>
>
(
prob_idx
.
begin
(),
prob_idx
.
begin
()
+
cutoff_len
);
prob_idx
.
begin
()
+
cutoff_len
);
}
}
...
@@ -138,15 +153,17 @@ std::vector<std::pair<double, std::string> >
...
@@ -138,15 +153,17 @@ std::vector<std::pair<double, std::string> >
log_prob_idx
.
push_back
(
std
::
pair
<
int
,
float
>
log_prob_idx
.
push_back
(
std
::
pair
<
int
,
float
>
(
prob_idx
[
i
].
first
,
log
(
prob_idx
[
i
].
second
+
NUM_FLT_MIN
)));
(
prob_idx
[
i
].
first
,
log
(
prob_idx
[
i
].
second
+
NUM_FLT_MIN
)));
}
}
// loop over chars
// loop over chars
for
(
int
index
=
0
;
index
<
log_prob_idx
.
size
();
index
++
)
{
for
(
int
index
=
0
;
index
<
log_prob_idx
.
size
();
index
++
)
{
auto
c
=
log_prob_idx
[
index
].
first
;
auto
c
=
log_prob_idx
[
index
].
first
;
float
log_prob_c
=
log_prob_idx
[
index
].
second
;
float
log_prob_c
=
log_prob_idx
[
index
].
second
;
//float log_probs_prev;
for
(
int
i
=
0
;
i
<
prefixes
.
size
()
&&
i
<
beam_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
prefixes
.
size
()
&&
i
<
beam_size
;
i
++
)
{
auto
prefix
=
prefixes
[
i
];
auto
prefix
=
prefixes
[
i
];
if
(
full_beam
&&
log_prob_c
+
prefix
->
_score
<
min_cutoff
)
{
break
;
}
// blank
// blank
if
(
c
==
blank_id
)
{
if
(
c
==
blank_id
)
{
prefix
->
_log_prob_b_cur
=
log_sum_exp
(
prefix
->
_log_prob_b_cur
=
log_sum_exp
(
...
@@ -178,7 +195,7 @@ std::vector<std::pair<double, std::string> >
...
@@ -178,7 +195,7 @@ std::vector<std::pair<double, std::string> >
(
c
==
space_id
||
ext_scorer
->
is_character_based
())
)
{
(
c
==
space_id
||
ext_scorer
->
is_character_based
())
)
{
PathTrie
*
prefix_to_score
=
nullptr
;
PathTrie
*
prefix_to_score
=
nullptr
;
//
don't score
the space
//
skip scoring
the space
if
(
ext_scorer
->
is_character_based
())
{
if
(
ext_scorer
->
is_character_based
())
{
prefix_to_score
=
prefix_new
;
prefix_to_score
=
prefix_new
;
}
else
{
}
else
{
...
@@ -202,10 +219,10 @@ std::vector<std::pair<double, std::string> >
...
@@ -202,10 +219,10 @@ std::vector<std::pair<double, std::string> >
}
// end of loop over chars
}
// end of loop over chars
prefixes
.
clear
();
prefixes
.
clear
();
// update log probabilities
// update log probs
root
.
iterate_to_vec
(
prefixes
);
root
.
iterate_to_vec
(
prefixes
);
//
sort prefixes by score
//
preserve top beam_size prefixes
if
(
prefixes
.
size
()
>=
beam_size
)
{
if
(
prefixes
.
size
()
>=
beam_size
)
{
std
::
nth_element
(
prefixes
.
begin
(),
std
::
nth_element
(
prefixes
.
begin
(),
prefixes
.
begin
()
+
beam_size
,
prefixes
.
begin
()
+
beam_size
,
...
@@ -218,18 +235,20 @@ std::vector<std::pair<double, std::string> >
...
@@ -218,18 +235,20 @@ std::vector<std::pair<double, std::string> >
}
}
}
}
// compute approximate ctc score as the return score
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
prefixes
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
prefixes
.
size
();
i
++
)
{
double
approx_ctc
=
prefixes
[
i
]
->
_score
;
double
approx_ctc
=
prefixes
[
i
]
->
_score
;
// remove word insert:
if
(
ext_scorer
!=
nullptr
)
{
std
::
vector
<
int
>
output
;
std
::
vector
<
int
>
output
;
prefixes
[
i
]
->
get_path_vec
(
output
);
prefixes
[
i
]
->
get_path_vec
(
output
);
size_t
prefix_length
=
output
.
size
();
size_t
prefix_length
=
output
.
size
();
auto
words
=
ext_scorer
->
split_labels
(
output
);
// remove word insert
approx_ctc
=
approx_ctc
-
prefix_length
*
ext_scorer
->
beta
;
// remove language model weight:
// remove language model weight:
if
(
ext_scorer
!=
nullptr
)
{
approx_ctc
-=
(
ext_scorer
->
get_sent_log_prob
(
words
))
// auto words = split_labels(output);
*
ext_scorer
->
alpha
;
// approx_ctc = approx_ctc - path_length * ext_scorer->beta;
// approx_ctc -= (_lm->get_sent_log_prob(words)) * ext_scorer->alpha;
}
}
prefixes
[
i
]
->
_approx_ctc
=
approx_ctc
;
prefixes
[
i
]
->
_approx_ctc
=
approx_ctc
;
...
@@ -253,11 +272,9 @@ std::vector<std::pair<double, std::string> >
...
@@ -253,11 +272,9 @@ std::vector<std::pair<double, std::string> >
for
(
int
j
=
0
;
j
<
output
.
size
();
j
++
)
{
for
(
int
j
=
0
;
j
<
output
.
size
();
j
++
)
{
output_str
+=
vocabulary
[
output
[
j
]];
output_str
+=
vocabulary
[
output
[
j
]];
}
}
std
::
pair
<
double
,
std
::
string
>
output_pair
(
space_prefixes
[
i
]
->
_score
,
std
::
pair
<
double
,
std
::
string
>
output_str
);
output_pair
(
-
space_prefixes
[
i
]
->
_approx_ctc
,
output_str
);
output_vecs
.
emplace_back
(
output_vecs
.
emplace_back
(
output_pair
);
output_pair
);
}
}
return
output_vecs
;
return
output_vecs
;
...
@@ -272,6 +289,7 @@ std::vector<std::vector<std::pair<double, std::string> > >
...
@@ -272,6 +289,7 @@ std::vector<std::vector<std::pair<double, std::string> > >
int
blank_id
,
int
blank_id
,
int
num_processes
,
int
num_processes
,
double
cutoff_prob
,
double
cutoff_prob
,
int
cutoff_top_n
,
Scorer
*
ext_scorer
Scorer
*
ext_scorer
)
{
)
{
if
(
num_processes
<=
0
)
{
if
(
num_processes
<=
0
)
{
...
@@ -295,7 +313,8 @@ std::vector<std::vector<std::pair<double, std::string> > >
...
@@ -295,7 +313,8 @@ std::vector<std::vector<std::pair<double, std::string> > >
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
res
.
emplace_back
(
res
.
emplace_back
(
pool
.
enqueue
(
ctc_beam_search_decoder
,
probs_split
[
i
],
pool
.
enqueue
(
ctc_beam_search_decoder
,
probs_split
[
i
],
beam_size
,
vocabulary
,
blank_id
,
cutoff_prob
,
ext_scorer
)
beam_size
,
vocabulary
,
blank_id
,
cutoff_prob
,
cutoff_top_n
,
ext_scorer
)
);
);
}
}
// get decoding results
// get decoding results
...
...
deploy/ctc_decoders.h
浏览文件 @
a0c89ae7
...
@@ -39,6 +39,7 @@ std::vector<std::pair<double, std::string> >
...
@@ -39,6 +39,7 @@ std::vector<std::pair<double, std::string> >
std
::
vector
<
std
::
string
>
vocabulary
,
std
::
vector
<
std
::
string
>
vocabulary
,
int
blank_id
,
int
blank_id
,
double
cutoff_prob
=
1
.
0
,
double
cutoff_prob
=
1
.
0
,
int
cutoff_top_n
=
40
,
Scorer
*
ext_scorer
=
NULL
Scorer
*
ext_scorer
=
NULL
);
);
...
@@ -66,6 +67,7 @@ std::vector<std::vector<std::pair<double, std::string>>>
...
@@ -66,6 +67,7 @@ std::vector<std::vector<std::pair<double, std::string>>>
int
blank_id
,
int
blank_id
,
int
num_processes
,
int
num_processes
,
double
cutoff_prob
=
1
.
0
,
double
cutoff_prob
=
1
.
0
,
int
cutoff_top_n
=
40
,
Scorer
*
ext_scorer
=
NULL
Scorer
*
ext_scorer
=
NULL
);
);
...
...
deploy/scorer.h
浏览文件 @
a0c89ae7
...
@@ -50,6 +50,7 @@ public:
...
@@ -50,6 +50,7 @@ public:
void
fill_dictionary
(
bool
add_space
);
void
fill_dictionary
(
bool
add_space
);
// set char map
// set char map
void
set_char_map
(
std
::
vector
<
std
::
string
>
char_list
);
void
set_char_map
(
std
::
vector
<
std
::
string
>
char_list
);
std
::
vector
<
std
::
string
>
split_labels
(
const
std
::
vector
<
int
>
&
labels
);
// expose to decoder
// expose to decoder
double
alpha
;
double
alpha
;
double
beta
;
double
beta
;
...
@@ -60,7 +61,6 @@ protected:
...
@@ -60,7 +61,6 @@ protected:
void
load_LM
(
const
char
*
filename
);
void
load_LM
(
const
char
*
filename
);
double
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
);
double
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
);
std
::
string
vec2str
(
const
std
::
vector
<
int
>
&
input
);
std
::
string
vec2str
(
const
std
::
vector
<
int
>
&
input
);
std
::
vector
<
std
::
string
>
split_labels
(
const
std
::
vector
<
int
>
&
labels
);
private:
private:
void
*
_language_model
;
void
*
_language_model
;
...
...
deploy/swig_decoders_wrapper.py
浏览文件 @
a0c89ae7
...
@@ -43,6 +43,7 @@ def ctc_beam_search_decoder(probs_seq,
...
@@ -43,6 +43,7 @@ def ctc_beam_search_decoder(probs_seq,
vocabulary
,
vocabulary
,
blank_id
,
blank_id
,
cutoff_prob
=
1.0
,
cutoff_prob
=
1.0
,
cutoff_top_n
=
40
,
ext_scoring_func
=
None
):
ext_scoring_func
=
None
):
"""Wrapper for the CTC Beam Search Decoder.
"""Wrapper for the CTC Beam Search Decoder.
...
@@ -59,6 +60,10 @@ def ctc_beam_search_decoder(probs_seq,
...
@@ -59,6 +60,10 @@ def ctc_beam_search_decoder(probs_seq,
:param cutoff_prob: Cutoff probability in pruning,
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
default 1.0, no pruning.
:type cutoff_prob: float
:type cutoff_prob: float
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
characters with highest probs in vocabulary will be
used in beam search, default 40.
:type cutoff_top_n: int
:param ext_scoring_func: External scoring function for
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
partially decoded sentence, e.g. word count
or language model.
or language model.
...
@@ -67,9 +72,9 @@ def ctc_beam_search_decoder(probs_seq,
...
@@ -67,9 +72,9 @@ def ctc_beam_search_decoder(probs_seq,
results, in descending order of the probability.
results, in descending order of the probability.
:rtype: list
:rtype: list
"""
"""
return
swig_decoders
.
ctc_beam_search_decoder
(
probs_seq
.
tolist
(),
beam_size
,
return
swig_decoders
.
ctc_beam_search_decoder
(
vocabulary
,
blank_id
,
probs_seq
.
tolist
(),
beam_size
,
vocabulary
,
blank_id
,
cutoff_prob
,
cutoff_prob
,
ext_scoring_func
)
cutoff_top_n
,
ext_scoring_func
)
def
ctc_beam_search_decoder_batch
(
probs_split
,
def
ctc_beam_search_decoder_batch
(
probs_split
,
...
@@ -78,6 +83,7 @@ def ctc_beam_search_decoder_batch(probs_split,
...
@@ -78,6 +83,7 @@ def ctc_beam_search_decoder_batch(probs_split,
blank_id
,
blank_id
,
num_processes
,
num_processes
,
cutoff_prob
=
1.0
,
cutoff_prob
=
1.0
,
cutoff_top_n
=
40
,
ext_scoring_func
=
None
):
ext_scoring_func
=
None
):
"""Wrapper for the batched CTC beam search decoder.
"""Wrapper for the batched CTC beam search decoder.
...
@@ -92,11 +98,15 @@ def ctc_beam_search_decoder_batch(probs_split,
...
@@ -92,11 +98,15 @@ def ctc_beam_search_decoder_batch(probs_split,
:type blank_id: int
:type blank_id: int
:param num_processes: Number of parallel processes.
:param num_processes: Number of parallel processes.
:type num_processes: int
:type num_processes: int
:param cutoff_prob: Cutoff probability in pruning,
:param cutoff_prob: Cutoff probability in vocabulary pruning,
default 1.0, no pruning.
default 1.0, no pruning.
:type cutoff_prob: float
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
characters with highest probs in vocabulary will be
used in beam search, default 40.
:type cutoff_top_n: int
:param num_processes: Number of parallel processes.
:param num_processes: Number of parallel processes.
:type num_processes: int
:type num_processes: int
:type cutoff_prob: float
:param ext_scoring_func: External scoring function for
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
partially decoded sentence, e.g. word count
or language model.
or language model.
...
@@ -109,4 +119,4 @@ def ctc_beam_search_decoder_batch(probs_split,
...
@@ -109,4 +119,4 @@ def ctc_beam_search_decoder_batch(probs_split,
return
swig_decoders
.
ctc_beam_search_decoder_batch
(
return
swig_decoders
.
ctc_beam_search_decoder_batch
(
probs_split
,
beam_size
,
vocabulary
,
blank_id
,
num_processes
,
probs_split
,
beam_size
,
vocabulary
,
blank_id
,
num_processes
,
cutoff_prob
,
ext_scoring_func
)
cutoff_prob
,
cutoff_top_n
,
ext_scoring_func
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录