Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
8ff6221d
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8ff6221d
编写于
8月 29, 2017
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enable finite-state transducer in beam search decoding
上级
b5602054
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
189 addition
and
22 deletion
+189
-22
deploy.py
deploy.py
+4
-4
deploy/ctc_decoders.cpp
deploy/ctc_decoders.cpp
+13
-2
deploy/decoder_utils.cpp
deploy/decoder_utils.cpp
+27
-3
deploy/decoder_utils.h
deploy/decoder_utils.h
+3
-1
deploy/scorer.cpp
deploy/scorer.cpp
+132
-11
deploy/scorer.h
deploy/scorer.h
+10
-1
未找到文件。
deploy.py
浏览文件 @
8ff6221d
...
@@ -18,7 +18,7 @@ import time
...
@@ -18,7 +18,7 @@ import time
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
parser
.
add_argument
(
"--num_samples"
,
"--num_samples"
,
default
=
5
,
default
=
4
,
type
=
int
,
type
=
int
,
help
=
"Number of samples for inference. (default: %(default)s)"
)
help
=
"Number of samples for inference. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -89,7 +89,8 @@ parser.add_argument(
...
@@ -89,7 +89,8 @@ parser.add_argument(
help
=
"Number of output per sample in beam search. (default: %(default)d)"
)
help
=
"Number of output per sample in beam search. (default: %(default)d)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--language_model_path"
,
"--language_model_path"
,
default
=
"lm/data/common_crawl_00.prune01111.trie.klm"
,
default
=
"/home/work/liuyibing/lm_bak/common_crawl_00.prune01111.trie.klm"
,
#default="ptb_all.arpa",
type
=
str
,
type
=
str
,
help
=
"Path for language model. (default: %(default)s)"
)
help
=
"Path for language model. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -183,8 +184,7 @@ def infer():
...
@@ -183,8 +184,7 @@ def infer():
vocabulary
=
data_generator
.
vocab_list
,
vocabulary
=
data_generator
.
vocab_list
,
blank_id
=
len
(
data_generator
.
vocab_list
),
blank_id
=
len
(
data_generator
.
vocab_list
),
cutoff_prob
=
args
.
cutoff_prob
,
cutoff_prob
=
args
.
cutoff_prob
,
# ext_scoring_func=ext_scorer,
ext_scoring_func
=
ext_scorer
,
)
)
batch_beam_results
+=
[
beam_result
]
batch_beam_results
+=
[
beam_result
]
else
:
else
:
batch_beam_results
=
ctc_beam_search_decoder_batch
(
batch_beam_results
=
ctc_beam_search_decoder_batch
(
...
...
deploy/ctc_decoders.cpp
浏览文件 @
8ff6221d
...
@@ -103,10 +103,13 @@ std::vector<std::pair<double, std::string> >
...
@@ -103,10 +103,13 @@ std::vector<std::pair<double, std::string> >
prefixes
.
push_back
(
&
root
);
prefixes
.
push_back
(
&
root
);
if
(
ext_scorer
!=
nullptr
&&
!
ext_scorer
->
is_character_based
())
{
if
(
ext_scorer
!=
nullptr
&&
!
ext_scorer
->
is_character_based
())
{
if
(
ext_scorer
->
dictionary
==
nullptr
)
{
if
(
ext_scorer
->
_
dictionary
==
nullptr
)
{
// TODO: init dictionary
// TODO: init dictionary
ext_scorer
->
set_char_map
(
vocabulary
);
// add_space should be true?
ext_scorer
->
fill_dictionary
(
true
);
}
}
auto
fst_dict
=
static_cast
<
fst
::
StdVectorFst
*>
(
ext_scorer
->
dictionary
);
auto
fst_dict
=
static_cast
<
fst
::
StdVectorFst
*>
(
ext_scorer
->
_
dictionary
);
fst
::
StdVectorFst
*
dict_ptr
=
fst_dict
->
Copy
(
true
);
fst
::
StdVectorFst
*
dict_ptr
=
fst_dict
->
Copy
(
true
);
root
.
set_dictionary
(
dict_ptr
);
root
.
set_dictionary
(
dict_ptr
);
auto
matcher
=
std
::
make_shared
<
FSTMATCH
>
(
*
dict_ptr
,
fst
::
MATCH_INPUT
);
auto
matcher
=
std
::
make_shared
<
FSTMATCH
>
(
*
dict_ptr
,
fst
::
MATCH_INPUT
);
...
@@ -288,6 +291,14 @@ std::vector<std::vector<std::pair<double, std::string>>>
...
@@ -288,6 +291,14 @@ std::vector<std::vector<std::pair<double, std::string>>>
ThreadPool
pool
(
num_processes
);
ThreadPool
pool
(
num_processes
);
// number of samples
// number of samples
int
batch_size
=
probs_split
.
size
();
int
batch_size
=
probs_split
.
size
();
// dictionary init
if
(
ext_scorer
!=
nullptr
)
{
if
(
ext_scorer
->
_dictionary
==
nullptr
)
{
// TODO: init dictionary
ext_scorer
->
set_char_map
(
vocabulary
);
ext_scorer
->
fill_dictionary
(
true
);
}
}
// enqueue the tasks of decoding
// enqueue the tasks of decoding
std
::
vector
<
std
::
future
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>>
res
;
std
::
vector
<
std
::
future
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>>
res
;
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
...
...
deploy/decoder_utils.cpp
浏览文件 @
8ff6221d
...
@@ -11,6 +11,32 @@ size_t get_utf8_str_len(const std::string& str) {
...
@@ -11,6 +11,32 @@ size_t get_utf8_str_len(const std::string& str) {
return
str_len
;
return
str_len
;
}
}
//------------------------------------------------------
//Splits string into vector of strings representing
//UTF-8 characters (not same as chars)
//------------------------------------------------------
std
::
vector
<
std
::
string
>
UTF8_split
(
const
std
::
string
&
str
)
{
std
::
vector
<
std
::
string
>
result
;
std
::
string
out_str
;
for
(
char
c
:
str
)
{
if
((
c
&
0xc0
)
!=
0x80
)
//new UTF-8 character
{
if
(
!
out_str
.
empty
())
{
result
.
push_back
(
out_str
);
out_str
.
clear
();
}
}
out_str
.
append
(
1
,
c
);
}
result
.
push_back
(
out_str
);
return
result
;
}
//-------------------------------------------------------
//-------------------------------------------------------
// Overriding less than operator for sorting
// Overriding less than operator for sorting
//-------------------------------------------------------
//-------------------------------------------------------
...
@@ -49,12 +75,11 @@ void add_word_to_fst(const std::vector<int>& word,
...
@@ -49,12 +75,11 @@ void add_word_to_fst(const std::vector<int>& word,
// ---------------------------------------------------------
// ---------------------------------------------------------
// Adds a word to the dictionary FST based on char_map
// Adds a word to the dictionary FST based on char_map
// ---------------------------------------------------------
// ---------------------------------------------------------
bool
add
WordToD
ictionary
(
const
std
::
string
&
word
,
bool
add
_word_to_d
ictionary
(
const
std
::
string
&
word
,
const
std
::
unordered_map
<
std
::
string
,
int
>&
char_map
,
const
std
::
unordered_map
<
std
::
string
,
int
>&
char_map
,
bool
add_space
,
bool
add_space
,
int
SPACE
,
int
SPACE
,
fst
::
StdVectorFst
*
dictionary
)
{
fst
::
StdVectorFst
*
dictionary
)
{
/*
auto
characters
=
UTF8_split
(
word
);
auto
characters
=
UTF8_split
(
word
);
std
::
vector
<
int
>
int_word
;
std
::
vector
<
int
>
int_word
;
...
@@ -77,6 +102,5 @@ bool addWordToDictionary(const std::string& word,
...
@@ -77,6 +102,5 @@ bool addWordToDictionary(const std::string& word,
}
}
add_word_to_fst
(
int_word
,
dictionary
);
add_word_to_fst
(
int_word
,
dictionary
);
*/
return
true
;
return
true
;
}
// -------------- End of addWordToDictionary ------------
}
// -------------- End of addWordToDictionary ------------
deploy/decoder_utils.h
浏览文件 @
8ff6221d
...
@@ -35,10 +35,12 @@ bool prefix_compare(const PathTrie* x, const PathTrie* y);
...
@@ -35,10 +35,12 @@ bool prefix_compare(const PathTrie* x, const PathTrie* y);
// See: http://stackoverflow.com/a/4063229
// See: http://stackoverflow.com/a/4063229
size_t
get_utf8_str_len
(
const
std
::
string
&
str
);
size_t
get_utf8_str_len
(
const
std
::
string
&
str
);
std
::
vector
<
std
::
string
>
UTF8_split
(
const
std
::
string
&
str
);
void
add_word_to_fst
(
const
std
::
vector
<
int
>&
word
,
void
add_word_to_fst
(
const
std
::
vector
<
int
>&
word
,
fst
::
StdVectorFst
*
dictionary
);
fst
::
StdVectorFst
*
dictionary
);
bool
add
WordToD
ictionary
(
const
std
::
string
&
word
,
bool
add
_word_to_d
ictionary
(
const
std
::
string
&
word
,
const
std
::
unordered_map
<
std
::
string
,
int
>&
char_map
,
const
std
::
unordered_map
<
std
::
string
,
int
>&
char_map
,
bool
add_space
,
bool
add_space
,
int
SPACE
,
int
SPACE
,
...
...
deploy/scorer.cpp
浏览文件 @
8ff6221d
...
@@ -15,7 +15,9 @@ Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
...
@@ -15,7 +15,9 @@ Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
this
->
beta
=
beta
;
this
->
beta
=
beta
;
_is_character_based
=
true
;
_is_character_based
=
true
;
_language_model
=
nullptr
;
_language_model
=
nullptr
;
_dictionary
=
nullptr
;
_max_order
=
0
;
_max_order
=
0
;
_SPACE
=
-
1
;
// load language model
// load language model
load_LM
(
lm_path
.
c_str
());
load_LM
(
lm_path
.
c_str
());
}
}
...
@@ -23,6 +25,8 @@ Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
...
@@ -23,6 +25,8 @@ Scorer::Scorer(double alpha, double beta, const std::string& lm_path) {
Scorer
::~
Scorer
()
{
Scorer
::~
Scorer
()
{
if
(
_language_model
!=
nullptr
)
if
(
_language_model
!=
nullptr
)
delete
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
);
delete
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
);
if
(
_dictionary
!=
nullptr
)
delete
static_cast
<
fst
::
StdVectorFst
*>
(
_dictionary
);
}
}
void
Scorer
::
load_LM
(
const
char
*
filename
)
{
void
Scorer
::
load_LM
(
const
char
*
filename
)
{
...
@@ -176,11 +180,83 @@ double Scorer::get_score(std::string sentence, bool log) {
...
@@ -176,11 +180,83 @@ double Scorer::get_score(std::string sentence, bool log) {
return
final_score
;
return
final_score
;
}
}
//--------------------------------------------------
std
::
string
Scorer
::
vec2str
(
const
std
::
vector
<
int
>&
input
)
{
// Turn indices back into strings of chars
std
::
string
word
;
//--------------------------------------------------
for
(
auto
ind
:
input
)
{
word
+=
_char_list
[
ind
];
}
return
word
;
}
std
::
vector
<
std
::
string
>
Scorer
::
split_labels
(
const
std
::
vector
<
int
>
&
labels
)
{
if
(
labels
.
empty
())
return
{};
std
::
string
s
=
vec2str
(
labels
);
std
::
vector
<
std
::
string
>
words
;
if
(
_is_character_based
)
{
words
=
UTF8_split
(
s
);
}
else
{
words
=
split_str
(
s
,
" "
);
}
return
words
;
}
// Split a string into a list of strings on a given string
// delimiter. NB: delimiters on beginning / end of string are
// trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
std
::
vector
<
std
::
string
>
Scorer
::
split_str
(
const
std
::
string
&
s
,
const
std
::
string
&
delim
)
{
std
::
vector
<
std
::
string
>
result
;
std
::
size_t
start
=
0
,
delim_len
=
delim
.
size
();
while
(
true
)
{
std
::
size_t
end
=
s
.
find
(
delim
,
start
);
if
(
end
==
std
::
string
::
npos
)
{
if
(
start
<
s
.
size
())
{
result
.
push_back
(
s
.
substr
(
start
));
}
break
;
}
if
(
end
>
start
)
{
result
.
push_back
(
s
.
substr
(
start
,
end
-
start
));
}
start
=
end
+
delim_len
;
}
return
result
;
}
//---------------------------------------------------
// Add index to char list for searching language model
//---------------------------------------------------
void
Scorer
::
set_char_map
(
std
::
vector
<
std
::
string
>
char_list
)
{
_char_list
=
char_list
;
std
::
string
_SPACE_STR
=
" "
;
for
(
unsigned
int
i
=
0
;
i
<
_char_list
.
size
();
i
++
)
{
// if (_char_list[i] == _BLANK_STR) {
// _BLANK = i;
// } else
if
(
_char_list
[
i
]
==
_SPACE_STR
)
{
_SPACE
=
i
;
}
}
_char_map
.
clear
();
for
(
unsigned
int
i
=
0
;
i
<
_char_list
.
size
();
i
++
)
{
if
(
i
==
(
unsigned
int
)
_SPACE
){
_char_map
[
' '
]
=
i
;
}
else
if
(
_char_list
[
i
].
size
()
==
1
){
_char_map
[
_char_list
[
i
][
0
]]
=
i
;
}
}
}
//------------- End of set_char_map ----------------
std
::
vector
<
std
::
string
>
Scorer
::
make_ngram
(
PathTrie
*
prefix
)
{
std
::
vector
<
std
::
string
>
Scorer
::
make_ngram
(
PathTrie
*
prefix
)
{
/*
std
::
vector
<
std
::
string
>
ngram
;
std
::
vector
<
std
::
string
>
ngram
;
PathTrie
*
current_node
=
prefix
;
PathTrie
*
current_node
=
prefix
;
PathTrie
*
new_node
=
nullptr
;
PathTrie
*
new_node
=
nullptr
;
...
@@ -189,10 +265,10 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
...
@@ -189,10 +265,10 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
std
::
vector
<
int
>
prefix_vec
;
std
::
vector
<
int
>
prefix_vec
;
if
(
_is_character_based
)
{
if
(
_is_character_based
)
{
new_node = current_node->get_path_vec(prefix_vec,
' '
, 1);
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
_SPACE
,
1
);
current_node
=
new_node
;
current_node
=
new_node
;
}
else
{
}
else
{
new_node = current_node->get
PathVec(prefix_vec, ' '
);
new_node
=
current_node
->
get
_path_vec
(
prefix_vec
,
_SPACE
);
current_node
=
new_node
->
_parent
;
// Skipping spaces
current_node
=
new_node
->
_parent
;
// Skipping spaces
}
}
...
@@ -202,15 +278,60 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
...
@@ -202,15 +278,60 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
if
(
new_node
->
_character
==
-
1
)
{
if
(
new_node
->
_character
==
-
1
)
{
// No more spaces, but still need order
// No more spaces, but still need order
for (int i = 0; i < max_order - order - 1; i++) {
for
(
int
i
=
0
;
i
<
_
max_order
-
order
-
1
;
i
++
)
{
ngram
.
push_back
(
"<s>"
);
ngram
.
push_back
(
"<s>"
);
}
}
break
;
break
;
}
}
}
}
std
::
reverse
(
ngram
.
begin
(),
ngram
.
end
());
std
::
reverse
(
ngram
.
begin
(),
ngram
.
end
());
*/
std
::
vector
<
std
::
string
>
ngram
;
ngram
.
push_back
(
"this"
);
return
ngram
;
return
ngram
;
}
//---------------- End makeNgrams ------------------
}
//---------------------------------------------------------
// Helper function to populate Trie with a vocab using the
// char_list for maping from string to int
//---------------------------------------------------------
void
Scorer
::
fill_dictionary
(
bool
add_space
)
{
fst
::
StdVectorFst
dictionary
;
// First reverse char_list so ints can be accessed by chars
std
::
unordered_map
<
std
::
string
,
int
>
char_map
;
for
(
unsigned
int
i
=
0
;
i
<
_char_list
.
size
();
i
++
)
{
char_map
[
_char_list
[
i
]]
=
i
;
}
// For each unigram convert to ints and put in trie
int
vocab_size
=
0
;
for
(
const
auto
&
word
:
_vocabulary
)
{
bool
added
=
add_word_to_dictionary
(
word
,
char_map
,
add_space
,
_SPACE
,
&
dictionary
);
vocab_size
+=
added
?
1
:
0
;
}
std
::
cerr
<<
"Vocab Size "
<<
vocab_size
<<
std
::
endl
;
// Simplify FST
// This gets rid of "epsilon" transitions in the FST.
// These are transitions that don't require a string input to be taken.
// Getting rid of them is necessary to make the FST determinisitc, but
// can greatly increase the size of the FST
fst
::
RmEpsilon
(
&
dictionary
);
fst
::
StdVectorFst
*
new_dict
=
new
fst
::
StdVectorFst
;
// This makes the FST deterministic, meaning for any string input there's
// only one possible state the FST could be in. It is assumed our
// dictionary is deterministic when using it.
// (lest we'd have to check for multiple transitions at each state)
fst
::
Determinize
(
dictionary
,
new_dict
);
// Finds the simplest equivalent fst. This is unnecessary but decreases
// memory usage of the dictionary
fst
::
Minimize
(
new_dict
);
_dictionary
=
new_dict
;
}
deploy/scorer.h
浏览文件 @
8ff6221d
...
@@ -53,15 +53,23 @@ public:
...
@@ -53,15 +53,23 @@ public:
double
get_score
(
std
::
string
,
bool
log
=
false
);
double
get_score
(
std
::
string
,
bool
log
=
false
);
// make ngram
// make ngram
std
::
vector
<
std
::
string
>
make_ngram
(
PathTrie
*
prefix
);
std
::
vector
<
std
::
string
>
make_ngram
(
PathTrie
*
prefix
);
// fill dictionary for fst
void
fill_dictionary
(
bool
add_space
);
// set char map
void
set_char_map
(
std
::
vector
<
std
::
string
>
char_list
);
// expose to decoder
// expose to decoder
double
alpha
;
double
alpha
;
double
beta
;
double
beta
;
// fst dictionary
// fst dictionary
void
*
dictionary
;
void
*
_
dictionary
;
protected:
protected:
void
load_LM
(
const
char
*
filename
);
void
load_LM
(
const
char
*
filename
);
double
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
);
double
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
);
std
::
string
vec2str
(
const
std
::
vector
<
int
>
&
input
);
std
::
vector
<
std
::
string
>
split_labels
(
const
std
::
vector
<
int
>
&
labels
);
std
::
vector
<
std
::
string
>
split_str
(
const
std
::
string
&
s
,
const
std
::
string
&
delim
);
private:
private:
void
_init_char_list
();
void
_init_char_list
();
...
@@ -71,6 +79,7 @@ private:
...
@@ -71,6 +79,7 @@ private:
bool
_is_character_based
;
bool
_is_character_based
;
size_t
_max_order
;
size_t
_max_order
;
unsigned
int
_SPACE
;
std
::
vector
<
std
::
string
>
_char_list
;
std
::
vector
<
std
::
string
>
_char_list
;
std
::
unordered_map
<
char
,
int
>
_char_map
;
std
::
unordered_map
<
char
,
int
>
_char_map
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录