Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
94a68116
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
206
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
94a68116
编写于
7月 06, 2017
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
code cleanup for the deployment decoder
上级
2e76f82c
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
105 addition
and
48 deletion
+105
-48
deploy/ctc_beam_search_decoder.cpp
deploy/ctc_beam_search_decoder.cpp
+45
-27
deploy/ctc_beam_search_decoder.h
deploy/ctc_beam_search_decoder.h
+25
-9
deploy/decoder_setup.py
deploy/decoder_setup.py
+3
-4
deploy/scorer.cpp
deploy/scorer.cpp
+12
-2
deploy/scorer.h
deploy/scorer.h
+17
-3
deploy/scorer_setup.py
deploy/scorer_setup.py
+3
-3
未找到文件。
deploy/ctc_beam_search_decoder.cpp
浏览文件 @
94a68116
...
@@ -6,35 +6,47 @@
...
@@ -6,35 +6,47 @@
#include "ctc_beam_search_decoder.h"
#include "ctc_beam_search_decoder.h"
template
<
typename
T1
,
typename
T2
>
template
<
typename
T1
,
typename
T2
>
bool
pair_comp_first_rev
(
const
std
::
pair
<
T1
,
T2
>
a
,
const
std
::
pair
<
T1
,
T2
>
b
)
{
bool
pair_comp_first_rev
(
const
std
::
pair
<
T1
,
T2
>
a
,
const
std
::
pair
<
T1
,
T2
>
b
)
{
return
a
.
first
>
b
.
first
;
return
a
.
first
>
b
.
first
;
}
}
template
<
typename
T1
,
typename
T2
>
template
<
typename
T1
,
typename
T2
>
bool
pair_comp_second_rev
(
const
std
::
pair
<
T1
,
T2
>
a
,
const
std
::
pair
<
T1
,
T2
>
b
)
{
bool
pair_comp_second_rev
(
const
std
::
pair
<
T1
,
T2
>
a
,
const
std
::
pair
<
T1
,
T2
>
b
)
{
return
a
.
second
>
b
.
second
;
return
a
.
second
>
b
.
second
;
}
}
/* CTC beam search decoder in C++, the interface is consistent with the original
decoder in Python version.
*/
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>
>
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>
>
ctc_beam_search_decoder
(
std
::
vector
<
std
::
vector
<
double
>
>
probs_seq
,
ctc_beam_search_decoder
(
std
::
vector
<
std
::
vector
<
double
>
>
probs_seq
,
int
beam_size
,
int
beam_size
,
std
::
vector
<
std
::
string
>
vocabulary
,
std
::
vector
<
std
::
string
>
vocabulary
,
int
blank_id
,
int
blank_id
,
double
cutoff_prob
,
double
cutoff_prob
,
Scorer
*
ext_scorer
,
Scorer
*
ext_scorer
,
bool
nproc
bool
nproc
)
{
)
// dimension check
{
int
num_time_steps
=
probs_seq
.
size
();
int
num_time_steps
=
probs_seq
.
size
();
for
(
int
i
=
0
;
i
<
num_time_steps
;
i
++
)
{
if
(
probs_seq
[
i
].
size
()
!=
vocabulary
.
size
()
+
1
)
{
std
::
cout
<<
"The shape of probs_seq does not match"
<<
" with the shape of the vocabulary!"
<<
std
::
endl
;
exit
(
1
);
}
}
// blank_id check
if
(
blank_id
>
vocabulary
.
size
())
{
std
::
cout
<<
"Invalid blank_id!"
<<
std
::
endl
;
exit
(
1
);
}
// assign space ID
// assign space ID
std
::
vector
<
std
::
string
>::
iterator
it
=
std
::
find
(
vocabulary
.
begin
(),
vocabulary
.
end
(),
" "
);
std
::
vector
<
std
::
string
>::
iterator
it
=
std
::
find
(
vocabulary
.
begin
(),
int
space_id
=
it
-
vocabulary
.
begin
();
vocabulary
.
end
(),
" "
);
int
space_id
=
it
-
vocabulary
.
begin
();
if
(
space_id
>=
vocabulary
.
size
())
{
if
(
space_id
>=
vocabulary
.
size
())
{
std
::
cout
<<
"The character space is not in the vocabulary!"
;
std
::
cout
<<
"The character space is not in the vocabulary!"
<<
std
::
endl
;
exit
(
1
);
exit
(
1
);
}
}
...
@@ -60,7 +72,8 @@ std::vector<std::pair<double, std::string> >
...
@@ -60,7 +72,8 @@ std::vector<std::pair<double, std::string> >
}
}
// pruning of vacobulary
// pruning of vacobulary
if
(
cutoff_prob
<
1.0
)
{
if
(
cutoff_prob
<
1.0
)
{
std
::
sort
(
prob_idx
.
begin
(),
prob_idx
.
end
(),
pair_comp_second_rev
<
int
,
double
>
);
std
::
sort
(
prob_idx
.
begin
(),
prob_idx
.
end
(),
pair_comp_second_rev
<
int
,
double
>
);
float
cum_prob
=
0.0
;
float
cum_prob
=
0.0
;
int
cutoff_len
=
0
;
int
cutoff_len
=
0
;
for
(
int
i
=
0
;
i
<
prob_idx
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
prob_idx
.
size
();
i
++
)
{
...
@@ -68,7 +81,8 @@ std::vector<std::pair<double, std::string> >
...
@@ -68,7 +81,8 @@ std::vector<std::pair<double, std::string> >
cutoff_len
+=
1
;
cutoff_len
+=
1
;
if
(
cum_prob
>=
cutoff_prob
)
break
;
if
(
cum_prob
>=
cutoff_prob
)
break
;
}
}
prob_idx
=
std
::
vector
<
std
::
pair
<
int
,
double
>
>
(
prob_idx
.
begin
(),
prob_idx
.
begin
()
+
cutoff_len
);
prob_idx
=
std
::
vector
<
std
::
pair
<
int
,
double
>
>
(
prob_idx
.
begin
(),
prob_idx
.
begin
()
+
cutoff_len
);
}
}
// extend prefix
// extend prefix
for
(
std
::
map
<
std
::
string
,
double
>::
iterator
it
=
prefix_set_prev
.
begin
();
for
(
std
::
map
<
std
::
string
,
double
>::
iterator
it
=
prefix_set_prev
.
begin
();
...
@@ -82,11 +96,11 @@ std::vector<std::pair<double, std::string> >
...
@@ -82,11 +96,11 @@ std::vector<std::pair<double, std::string> >
int
c
=
prob_idx
[
index
].
first
;
int
c
=
prob_idx
[
index
].
first
;
double
prob_c
=
prob_idx
[
index
].
second
;
double
prob_c
=
prob_idx
[
index
].
second
;
if
(
c
==
blank_id
)
{
if
(
c
==
blank_id
)
{
probs_b_cur
[
l
]
+=
prob_c
*
(
probs_b_prev
[
l
]
+
probs_nb_prev
[
l
]);
probs_b_cur
[
l
]
+=
prob_c
*
(
probs_b_prev
[
l
]
+
probs_nb_prev
[
l
]);
}
else
{
}
else
{
std
::
string
last_char
=
l
.
substr
(
l
.
size
()
-
1
,
1
);
std
::
string
last_char
=
l
.
substr
(
l
.
size
()
-
1
,
1
);
std
::
string
new_char
=
vocabulary
[
c
];
std
::
string
new_char
=
vocabulary
[
c
];
std
::
string
l_plus
=
l
+
new_char
;
std
::
string
l_plus
=
l
+
new_char
;
if
(
prefix_set_next
.
find
(
l_plus
)
==
prefix_set_next
.
end
())
{
if
(
prefix_set_next
.
find
(
l_plus
)
==
prefix_set_next
.
end
())
{
probs_b_cur
[
l_plus
]
=
probs_nb_cur
[
l_plus
]
=
0.0
;
probs_b_cur
[
l_plus
]
=
probs_nb_cur
[
l_plus
]
=
0.0
;
...
@@ -105,19 +119,22 @@ std::vector<std::pair<double, std::string> >
...
@@ -105,19 +119,22 @@ std::vector<std::pair<double, std::string> >
probs_nb_cur
[
l_plus
]
+=
prob_c
*
(
probs_nb_cur
[
l_plus
]
+=
prob_c
*
(
probs_b_prev
[
l
]
+
probs_nb_prev
[
l
]);
probs_b_prev
[
l
]
+
probs_nb_prev
[
l
]);
}
}
prefix_set_next
[
l_plus
]
=
probs_nb_cur
[
l_plus
]
+
probs_b_cur
[
l_plus
];
prefix_set_next
[
l_plus
]
=
probs_nb_cur
[
l_plus
]
+
probs_b_cur
[
l_plus
];
}
}
}
}
prefix_set_next
[
l
]
=
probs_b_cur
[
l
]
+
probs_nb_cur
[
l
];
prefix_set_next
[
l
]
=
probs_b_cur
[
l
]
+
probs_nb_cur
[
l
];
}
}
probs_b_prev
=
probs_b_cur
;
probs_b_prev
=
probs_b_cur
;
probs_nb_prev
=
probs_nb_cur
;
probs_nb_prev
=
probs_nb_cur
;
std
::
vector
<
std
::
pair
<
std
::
string
,
double
>
>
std
::
vector
<
std
::
pair
<
std
::
string
,
double
>
>
prefix_vec_next
(
prefix_set_next
.
begin
(),
prefix_set_next
.
end
());
prefix_vec_next
(
prefix_set_next
.
begin
(),
std
::
sort
(
prefix_vec_next
.
begin
(),
prefix_vec_next
.
end
(),
pair_comp_second_rev
<
std
::
string
,
double
>
);
prefix_set_next
.
end
());
int
k
=
beam_size
<
prefix_vec_next
.
size
()
?
beam_size
:
prefix_vec_next
.
size
();
std
::
sort
(
prefix_vec_next
.
begin
(),
prefix_vec_next
.
end
(),
pair_comp_second_rev
<
std
::
string
,
double
>
);
int
k
=
beam_size
<
prefix_vec_next
.
size
()
?
beam_size
:
prefix_vec_next
.
size
();
prefix_set_prev
=
std
::
map
<
std
::
string
,
double
>
prefix_set_prev
=
std
::
map
<
std
::
string
,
double
>
(
prefix_vec_next
.
begin
(),
prefix_vec_next
.
begin
()
+
k
);
(
prefix_vec_next
.
begin
(),
prefix_vec_next
.
begin
()
+
k
);
}
}
...
@@ -138,6 +155,7 @@ std::vector<std::pair<double, std::string> >
...
@@ -138,6 +155,7 @@ std::vector<std::pair<double, std::string> >
}
}
}
}
// sort the result and return
// sort the result and return
std
::
sort
(
beam_result
.
begin
(),
beam_result
.
end
(),
pair_comp_first_rev
<
double
,
std
::
string
>
);
std
::
sort
(
beam_result
.
begin
(),
beam_result
.
end
(),
pair_comp_first_rev
<
double
,
std
::
string
>
);
return
beam_result
;
return
beam_result
;
}
}
deploy/ctc_beam_search_decoder.h
浏览文件 @
94a68116
...
@@ -6,14 +6,30 @@
...
@@ -6,14 +6,30 @@
#include <utility>
#include <utility>
#include "scorer.h"
#include "scorer.h"
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>
>
/* CTC Beam Search Decoder, the interface is consistent with the
ctc_beam_search_decoder
(
std
::
vector
<
std
::
vector
<
double
>
>
probs_seq
,
* original decoder in Python version.
int
beam_size
,
std
::
vector
<
std
::
string
>
vocabulary
,
* Parameters:
int
blank_id
=
0
,
* probs_seq: 2-D vector that each element is a vector of probabilities
double
cutoff_prob
=
1
.
0
,
* over vocabulary of one time step.
Scorer
*
ext_scorer
=
NULL
,
* beam_size: The width of beam search.
bool
nproc
=
false
* vocabulary: A vector of vocabulary.
);
* blank_id: ID of blank.
* cutoff_prob: Cutoff probability of pruning
* ext_scorer: External scorer to evaluate a prefix.
* nproc: Whether this function used in multiprocessing.
* Return:
* A vector that each element is a pair of score and decoding result,
* in desending order.
*/
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>
>
ctc_beam_search_decoder
(
std
::
vector
<
std
::
vector
<
double
>
>
probs_seq
,
int
beam_size
,
std
::
vector
<
std
::
string
>
vocabulary
,
int
blank_id
,
double
cutoff_prob
=
1
.
0
,
Scorer
*
ext_scorer
=
NULL
,
bool
nproc
=
false
);
#endif // CTC_BEAM_SEARCH_DECODER_H_
#endif // CTC_BEAM_SEARCH_DECODER_H_
deploy/decoder_setup.py
浏览文件 @
94a68116
...
@@ -10,8 +10,8 @@ def compile_test(header, library):
...
@@ -10,8 +10,8 @@ def compile_test(header, library):
return
os
.
system
(
command
)
==
0
return
os
.
system
(
command
)
==
0
FILES
=
glob
.
glob
(
'
util/*.cc'
)
+
glob
.
glob
(
'
lm/*.cc'
)
+
glob
.
glob
(
FILES
=
glob
.
glob
(
'
kenlm/util/*.cc'
)
+
glob
.
glob
(
'kenlm/
lm/*.cc'
)
+
glob
.
glob
(
'util/double-conversion/*.cc'
)
'
kenlm/
util/double-conversion/*.cc'
)
FILES
=
[
FILES
=
[
fn
for
fn
in
FILES
if
not
(
fn
.
endswith
(
'main.cc'
)
or
fn
.
endswith
(
'test.cc'
))
fn
for
fn
in
FILES
if
not
(
fn
.
endswith
(
'main.cc'
)
or
fn
.
endswith
(
'test.cc'
))
]
]
...
@@ -44,7 +44,7 @@ ctc_beam_search_decoder_module = [
...
@@ -44,7 +44,7 @@ ctc_beam_search_decoder_module = [
'ctc_beam_search_decoder.cpp'
'ctc_beam_search_decoder.cpp'
],
],
language
=
'C++'
,
language
=
'C++'
,
include_dirs
=
[
'.'
],
include_dirs
=
[
'.'
,
'./kenlm'
],
libraries
=
LIBS
,
libraries
=
LIBS
,
extra_compile_args
=
ARGS
)
extra_compile_args
=
ARGS
)
]
]
...
@@ -52,7 +52,6 @@ ctc_beam_search_decoder_module = [
...
@@ -52,7 +52,6 @@ ctc_beam_search_decoder_module = [
setup
(
setup
(
name
=
'swig_ctc_beam_search_decoder'
,
name
=
'swig_ctc_beam_search_decoder'
,
version
=
'0.1'
,
version
=
'0.1'
,
author
=
'Yibing Liu'
,
description
=
"""CTC beam search decoder"""
,
description
=
"""CTC beam search decoder"""
,
ext_modules
=
ctc_beam_search_decoder_module
,
ext_modules
=
ctc_beam_search_decoder_module
,
py_modules
=
[
'swig_ctc_beam_search_decoder'
],
)
py_modules
=
[
'swig_ctc_beam_search_decoder'
],
)
deploy/scorer.cpp
浏览文件 @
94a68116
#include <iostream>
#include <iostream>
#include "scorer.h"
#include "scorer.h"
#include "lm/model.hh"
#include "lm/model.hh"
#include "util/tokenize_piece.hh"
#include "util/tokenize_piece.hh"
...
@@ -17,6 +16,13 @@ Scorer::~Scorer(){
...
@@ -17,6 +16,13 @@ Scorer::~Scorer(){
delete
(
Model
*
)
this
->
_language_model
;
delete
(
Model
*
)
this
->
_language_model
;
}
}
/* Strip a input sentence
* Parameters:
* str: A reference to the objective string
* ch: The character to prune
* Return:
* void
*/
inline
void
strip
(
std
::
string
&
str
,
char
ch
=
' '
)
{
inline
void
strip
(
std
::
string
&
str
,
char
ch
=
' '
)
{
if
(
str
.
size
()
==
0
)
return
;
if
(
str
.
size
()
==
0
)
return
;
int
start
=
0
;
int
start
=
0
;
...
@@ -69,10 +75,14 @@ double Scorer::language_model_score(std::string sentence) {
...
@@ -69,10 +75,14 @@ double Scorer::language_model_score(std::string sentence) {
}
}
//log10 prob
//log10 prob
double
log_prob
=
ret
.
prob
;
double
log_prob
=
ret
.
prob
;
return
log_prob
;
return
log_prob
;
}
}
void
Scorer
::
reset_params
(
float
alpha
,
float
beta
)
{
this
->
_alpha
=
alpha
;
this
->
_beta
=
beta
;
}
double
Scorer
::
get_score
(
std
::
string
sentence
)
{
double
Scorer
::
get_score
(
std
::
string
sentence
)
{
double
lm_score
=
language_model_score
(
sentence
);
double
lm_score
=
language_model_score
(
sentence
);
int
word_cnt
=
word_count
(
sentence
);
int
word_cnt
=
word_count
(
sentence
);
...
...
deploy/scorer.h
浏览文件 @
94a68116
...
@@ -3,20 +3,34 @@
...
@@ -3,20 +3,34 @@
#include <string>
#include <string>
/* External scorer to evaluate a prefix or a complete sentence
* when a new word appended during decoding, consisting of word
* count and language model scoring.
* Example:
* Scorer ext_scorer(alpha, beta, "path_to_language_model.klm");
* double score = ext_scorer.get_score("sentence_to_score");
*/
class
Scorer
{
class
Scorer
{
private:
private:
float
_alpha
;
float
_alpha
;
float
_beta
;
float
_beta
;
void
*
_language_model
;
void
*
_language_model
;
// word insertion term
int
word_count
(
std
::
string
);
// n-gram language model scoring
double
language_model_score
(
std
::
string
);
public:
public:
Scorer
(){}
Scorer
(){}
Scorer
(
float
alpha
,
float
beta
,
std
::
string
lm_model_path
);
Scorer
(
float
alpha
,
float
beta
,
std
::
string
lm_model_path
);
~
Scorer
();
~
Scorer
();
int
word_count
(
std
::
string
);
double
language_model_score
(
std
::
string
);
// reset params alpha & beta
void
reset_params
(
float
alpha
,
float
beta
);
// get the final score
double
get_score
(
std
::
string
);
double
get_score
(
std
::
string
);
};
};
#endif
#endif
//SCORER_H_
deploy/scorer_setup.py
浏览文件 @
94a68116
...
@@ -10,8 +10,8 @@ def compile_test(header, library):
...
@@ -10,8 +10,8 @@ def compile_test(header, library):
return
os
.
system
(
command
)
==
0
return
os
.
system
(
command
)
==
0
FILES
=
glob
.
glob
(
'
util/*.cc'
)
+
glob
.
glob
(
'
lm/*.cc'
)
+
glob
.
glob
(
FILES
=
glob
.
glob
(
'
kenlm/util/*.cc'
)
+
glob
.
glob
(
'kenlm/
lm/*.cc'
)
+
glob
.
glob
(
'util/double-conversion/*.cc'
)
'
kenlm/
util/double-conversion/*.cc'
)
FILES
=
[
FILES
=
[
fn
for
fn
in
FILES
if
not
(
fn
.
endswith
(
'main.cc'
)
or
fn
.
endswith
(
'test.cc'
))
fn
for
fn
in
FILES
if
not
(
fn
.
endswith
(
'main.cc'
)
or
fn
.
endswith
(
'test.cc'
))
]
]
...
@@ -41,7 +41,7 @@ ext_modules = [
...
@@ -41,7 +41,7 @@ ext_modules = [
name
=
'_swig_scorer'
,
name
=
'_swig_scorer'
,
sources
=
FILES
+
[
'scorer_wrap.cxx'
,
'scorer.cpp'
],
sources
=
FILES
+
[
'scorer_wrap.cxx'
,
'scorer.cpp'
],
language
=
'C++'
,
language
=
'C++'
,
include_dirs
=
[
'.'
],
include_dirs
=
[
'.'
,
'./kenlm'
],
libraries
=
LIBS
,
libraries
=
LIBS
,
extra_compile_args
=
ARGS
)
extra_compile_args
=
ARGS
)
]
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录