Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
8c5576d9
M
models
项目概览
PaddlePaddle
/
models
大约 2 年 前同步成功
通知
232
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8c5576d9
编写于
9月 17, 2017
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
format variables' names & add more comments
上级
e6740af4
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
129 addition
and
126 deletion
+129
-126
deep_speech_2/decoders/swig/ctc_beam_search_decoder.cpp
deep_speech_2/decoders/swig/ctc_beam_search_decoder.cpp
+7
-8
deep_speech_2/decoders/swig/ctc_beam_search_decoder.h
deep_speech_2/decoders/swig/ctc_beam_search_decoder.h
+4
-5
deep_speech_2/decoders/swig/path_trie.cpp
deep_speech_2/decoders/swig/path_trie.cpp
+38
-38
deep_speech_2/decoders/swig/path_trie.h
deep_speech_2/decoders/swig/path_trie.h
+8
-8
deep_speech_2/decoders/swig/scorer.cpp
deep_speech_2/decoders/swig/scorer.cpp
+41
-41
deep_speech_2/decoders/swig/scorer.h
deep_speech_2/decoders/swig/scorer.h
+22
-17
deep_speech_2/decoders/swig_wrapper.py
deep_speech_2/decoders/swig_wrapper.py
+9
-9
未找到文件。
deep_speech_2/decoders/swig/ctc_beam_search_decoder.cpp
浏览文件 @
8c5576d9
...
@@ -18,8 +18,8 @@ using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
...
@@ -18,8 +18,8 @@ using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
ctc_beam_search_decoder
(
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
ctc_beam_search_decoder
(
const
std
::
vector
<
std
::
vector
<
double
>>
&
probs_seq
,
const
std
::
vector
<
std
::
vector
<
double
>>
&
probs_seq
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
size_t
beam_size
,
size_t
beam_size
,
std
::
vector
<
std
::
string
>
vocabulary
,
double
cutoff_prob
,
double
cutoff_prob
,
size_t
cutoff_top_n
,
size_t
cutoff_top_n
,
Scorer
*
ext_scorer
)
{
Scorer
*
ext_scorer
)
{
...
@@ -36,8 +36,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
...
@@ -36,8 +36,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
size_t
blank_id
=
vocabulary
.
size
();
size_t
blank_id
=
vocabulary
.
size
();
// assign space id
// assign space id
std
::
vector
<
std
::
string
>::
iterator
it
=
auto
it
=
std
::
find
(
vocabulary
.
begin
(),
vocabulary
.
end
(),
" "
);
std
::
find
(
vocabulary
.
begin
(),
vocabulary
.
end
(),
" "
);
int
space_id
=
it
-
vocabulary
.
begin
();
int
space_id
=
it
-
vocabulary
.
begin
();
// if no space in vocabulary
// if no space in vocabulary
if
((
size_t
)
space_id
>=
vocabulary
.
size
())
{
if
((
size_t
)
space_id
>=
vocabulary
.
size
())
{
...
@@ -173,11 +172,11 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
...
@@ -173,11 +172,11 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
ctc_beam_search_decoder_batch
(
ctc_beam_search_decoder_batch
(
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
double
>>>
&
probs_split
,
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
double
>>>
&
probs_split
,
const
size_t
beam_size
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
const
size_t
num_processes
,
size_t
beam_size
,
const
double
cutoff_prob
,
size_t
num_processes
,
const
size_t
cutoff_top_n
,
double
cutoff_prob
,
size_t
cutoff_top_n
,
Scorer
*
ext_scorer
)
{
Scorer
*
ext_scorer
)
{
VALID_CHECK_GT
(
num_processes
,
0
,
"num_processes must be nonnegative!"
);
VALID_CHECK_GT
(
num_processes
,
0
,
"num_processes must be nonnegative!"
);
// thread pool
// thread pool
...
@@ -190,8 +189,8 @@ ctc_beam_search_decoder_batch(
...
@@ -190,8 +189,8 @@ ctc_beam_search_decoder_batch(
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
res
.
emplace_back
(
pool
.
enqueue
(
ctc_beam_search_decoder
,
res
.
emplace_back
(
pool
.
enqueue
(
ctc_beam_search_decoder
,
probs_split
[
i
],
probs_split
[
i
],
beam_size
,
vocabulary
,
vocabulary
,
beam_size
,
cutoff_prob
,
cutoff_prob
,
cutoff_top_n
,
cutoff_top_n
,
ext_scorer
));
ext_scorer
));
...
...
deep_speech_2/decoders/swig/ctc_beam_search_decoder.h
浏览文件 @
8c5576d9
...
@@ -12,8 +12,8 @@
...
@@ -12,8 +12,8 @@
* Parameters:
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
* probs_seq: 2-D vector that each element is a vector of probabilities
* over vocabulary of one time step.
* over vocabulary of one time step.
* beam_size: The width of beam search.
* vocabulary: A vector of vocabulary.
* vocabulary: A vector of vocabulary.
* beam_size: The width of beam search.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* ext_scorer: External scorer to evaluate a prefix, which consists of
...
@@ -25,8 +25,8 @@
...
@@ -25,8 +25,8 @@
*/
*/
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
ctc_beam_search_decoder
(
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
ctc_beam_search_decoder
(
const
std
::
vector
<
std
::
vector
<
double
>>
&
probs_seq
,
const
std
::
vector
<
std
::
vector
<
double
>>
&
probs_seq
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
size_t
beam_size
,
size_t
beam_size
,
std
::
vector
<
std
::
string
>
vocabulary
,
double
cutoff_prob
=
1
.
0
,
double
cutoff_prob
=
1
.
0
,
size_t
cutoff_top_n
=
40
,
size_t
cutoff_top_n
=
40
,
Scorer
*
ext_scorer
=
nullptr
);
Scorer
*
ext_scorer
=
nullptr
);
...
@@ -36,9 +36,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
...
@@ -36,9 +36,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
* Parameters:
* Parameters:
* probs_seq: 3-D vector that each element is a 2-D vector that can be used
* probs_seq: 3-D vector that each element is a 2-D vector that can be used
* by ctc_beam_search_decoder().
* by ctc_beam_search_decoder().
* .
* beam_size: The width of beam search.
* vocabulary: A vector of vocabulary.
* vocabulary: A vector of vocabulary.
* beam_size: The width of beam search.
* num_processes: Number of threads for beam search.
* num_processes: Number of threads for beam search.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* cutoff_top_n: Cutoff number for pruning.
...
@@ -52,8 +51,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
...
@@ -52,8 +51,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
ctc_beam_search_decoder_batch
(
ctc_beam_search_decoder_batch
(
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
double
>>>
&
probs_split
,
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
double
>>>
&
probs_split
,
size_t
beam_size
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
size_t
beam_size
,
size_t
num_processes
,
size_t
num_processes
,
double
cutoff_prob
=
1
.
0
,
double
cutoff_prob
=
1
.
0
,
size_t
cutoff_top_n
=
40
,
size_t
cutoff_top_n
=
40
,
...
...
deep_speech_2/decoders/swig/path_trie.cpp
浏览文件 @
8c5576d9
...
@@ -15,32 +15,32 @@ PathTrie::PathTrie() {
...
@@ -15,32 +15,32 @@ PathTrie::PathTrie() {
log_prob_nb_cur
=
-
NUM_FLT_INF
;
log_prob_nb_cur
=
-
NUM_FLT_INF
;
score
=
-
NUM_FLT_INF
;
score
=
-
NUM_FLT_INF
;
_ROOT
=
-
1
;
ROOT_
=
-
1
;
character
=
_ROOT
;
character
=
ROOT_
;
_exists
=
true
;
exists_
=
true
;
parent
=
nullptr
;
parent
=
nullptr
;
_dictionary
=
nullptr
;
dictionary_
=
nullptr
;
_dictionary_state
=
0
;
dictionary_state_
=
0
;
_has_dictionary
=
false
;
has_dictionary_
=
false
;
_matcher
=
nullptr
;
matcher_
=
nullptr
;
}
}
PathTrie
::~
PathTrie
()
{
PathTrie
::~
PathTrie
()
{
for
(
auto
child
:
_children
)
{
for
(
auto
child
:
children_
)
{
delete
child
.
second
;
delete
child
.
second
;
}
}
}
}
PathTrie
*
PathTrie
::
get_path_trie
(
int
new_char
,
bool
reset
)
{
PathTrie
*
PathTrie
::
get_path_trie
(
int
new_char
,
bool
reset
)
{
auto
child
=
_children
.
begin
();
auto
child
=
children_
.
begin
();
for
(
child
=
_children
.
begin
();
child
!=
_children
.
end
();
++
child
)
{
for
(
child
=
children_
.
begin
();
child
!=
children_
.
end
();
++
child
)
{
if
(
child
->
first
==
new_char
)
{
if
(
child
->
first
==
new_char
)
{
break
;
break
;
}
}
}
}
if
(
child
!=
_children
.
end
())
{
if
(
child
!=
children_
.
end
())
{
if
(
!
child
->
second
->
_exists
)
{
if
(
!
child
->
second
->
exists_
)
{
child
->
second
->
_exists
=
true
;
child
->
second
->
exists_
=
true
;
child
->
second
->
log_prob_b_prev
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_b_prev
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_nb_prev
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_nb_prev
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_b_cur
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_b_cur
=
-
NUM_FLT_INF
;
...
@@ -48,47 +48,47 @@ PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
...
@@ -48,47 +48,47 @@ PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
}
}
return
(
child
->
second
);
return
(
child
->
second
);
}
else
{
}
else
{
if
(
_has_dictionary
)
{
if
(
has_dictionary_
)
{
_matcher
->
SetState
(
_dictionary_state
);
matcher_
->
SetState
(
dictionary_state_
);
bool
found
=
_matcher
->
Find
(
new_char
);
bool
found
=
matcher_
->
Find
(
new_char
);
if
(
!
found
)
{
if
(
!
found
)
{
// Adding this character causes word outside dictionary
// Adding this character causes word outside dictionary
auto
FSTZERO
=
fst
::
TropicalWeight
::
Zero
();
auto
FSTZERO
=
fst
::
TropicalWeight
::
Zero
();
auto
final_weight
=
_dictionary
->
Final
(
_dictionary_state
);
auto
final_weight
=
dictionary_
->
Final
(
dictionary_state_
);
bool
is_final
=
(
final_weight
!=
FSTZERO
);
bool
is_final
=
(
final_weight
!=
FSTZERO
);
if
(
is_final
&&
reset
)
{
if
(
is_final
&&
reset
)
{
_dictionary_state
=
_dictionary
->
Start
();
dictionary_state_
=
dictionary_
->
Start
();
}
}
return
nullptr
;
return
nullptr
;
}
else
{
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
PathTrie
*
new_path
=
new
PathTrie
;
new_path
->
character
=
new_char
;
new_path
->
character
=
new_char
;
new_path
->
parent
=
this
;
new_path
->
parent
=
this
;
new_path
->
_dictionary
=
_dictionary
;
new_path
->
dictionary_
=
dictionary_
;
new_path
->
_dictionary_state
=
_matcher
->
Value
().
nextstate
;
new_path
->
dictionary_state_
=
matcher_
->
Value
().
nextstate
;
new_path
->
_has_dictionary
=
true
;
new_path
->
has_dictionary_
=
true
;
new_path
->
_matcher
=
_matcher
;
new_path
->
matcher_
=
matcher_
;
_children
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
children_
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
return
new_path
;
return
new_path
;
}
}
}
else
{
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
PathTrie
*
new_path
=
new
PathTrie
;
new_path
->
character
=
new_char
;
new_path
->
character
=
new_char
;
new_path
->
parent
=
this
;
new_path
->
parent
=
this
;
_children
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
children_
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
return
new_path
;
return
new_path
;
}
}
}
}
}
}
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
)
{
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
)
{
return
get_path_vec
(
output
,
_ROOT
);
return
get_path_vec
(
output
,
ROOT_
);
}
}
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
,
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
,
int
stop
,
int
stop
,
size_t
max_steps
)
{
size_t
max_steps
)
{
if
(
character
==
stop
||
character
==
_ROOT
||
output
.
size
()
==
max_steps
)
{
if
(
character
==
stop
||
character
==
ROOT_
||
output
.
size
()
==
max_steps
)
{
std
::
reverse
(
output
.
begin
(),
output
.
end
());
std
::
reverse
(
output
.
begin
(),
output
.
end
());
return
this
;
return
this
;
}
else
{
}
else
{
...
@@ -98,7 +98,7 @@ PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
...
@@ -98,7 +98,7 @@ PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
}
}
void
PathTrie
::
iterate_to_vec
(
std
::
vector
<
PathTrie
*>&
output
)
{
void
PathTrie
::
iterate_to_vec
(
std
::
vector
<
PathTrie
*>&
output
)
{
if
(
_exists
)
{
if
(
exists_
)
{
log_prob_b_prev
=
log_prob_b_cur
;
log_prob_b_prev
=
log_prob_b_cur
;
log_prob_nb_prev
=
log_prob_nb_cur
;
log_prob_nb_prev
=
log_prob_nb_cur
;
...
@@ -108,25 +108,25 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
...
@@ -108,25 +108,25 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
score
=
log_sum_exp
(
log_prob_b_prev
,
log_prob_nb_prev
);
score
=
log_sum_exp
(
log_prob_b_prev
,
log_prob_nb_prev
);
output
.
push_back
(
this
);
output
.
push_back
(
this
);
}
}
for
(
auto
child
:
_children
)
{
for
(
auto
child
:
children_
)
{
child
.
second
->
iterate_to_vec
(
output
);
child
.
second
->
iterate_to_vec
(
output
);
}
}
}
}
void
PathTrie
::
remove
()
{
void
PathTrie
::
remove
()
{
_exists
=
false
;
exists_
=
false
;
if
(
_children
.
size
()
==
0
)
{
if
(
children_
.
size
()
==
0
)
{
auto
child
=
parent
->
_children
.
begin
();
auto
child
=
parent
->
children_
.
begin
();
for
(
child
=
parent
->
_children
.
begin
();
child
!=
parent
->
_children
.
end
();
for
(
child
=
parent
->
children_
.
begin
();
child
!=
parent
->
children_
.
end
();
++
child
)
{
++
child
)
{
if
(
child
->
first
==
character
)
{
if
(
child
->
first
==
character
)
{
parent
->
_children
.
erase
(
child
);
parent
->
children_
.
erase
(
child
);
break
;
break
;
}
}
}
}
if
(
parent
->
_children
.
size
()
==
0
&&
!
parent
->
_exists
)
{
if
(
parent
->
children_
.
size
()
==
0
&&
!
parent
->
exists_
)
{
parent
->
remove
();
parent
->
remove
();
}
}
...
@@ -135,12 +135,12 @@ void PathTrie::remove() {
...
@@ -135,12 +135,12 @@ void PathTrie::remove() {
}
}
void
PathTrie
::
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
)
{
void
PathTrie
::
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
)
{
_dictionary
=
dictionary
;
dictionary_
=
dictionary
;
_dictionary_state
=
dictionary
->
Start
();
dictionary_state_
=
dictionary
->
Start
();
_has_dictionary
=
true
;
has_dictionary_
=
true
;
}
}
using
FSTMATCH
=
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>
;
using
FSTMATCH
=
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>
;
void
PathTrie
::
set_matcher
(
std
::
shared_ptr
<
FSTMATCH
>
matcher
)
{
void
PathTrie
::
set_matcher
(
std
::
shared_ptr
<
FSTMATCH
>
matcher
)
{
_matcher
=
matcher
;
matcher_
=
matcher
;
}
}
deep_speech_2/decoders/swig/path_trie.h
浏览文件 @
8c5576d9
...
@@ -36,7 +36,7 @@ public:
...
@@ -36,7 +36,7 @@ public:
void
set_matcher
(
std
::
shared_ptr
<
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>>
);
void
set_matcher
(
std
::
shared_ptr
<
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>>
);
bool
is_empty
()
{
return
_ROOT
==
character
;
}
bool
is_empty
()
{
return
ROOT_
==
character
;
}
// remove current path from root
// remove current path from root
void
remove
();
void
remove
();
...
@@ -51,17 +51,17 @@ public:
...
@@ -51,17 +51,17 @@ public:
PathTrie
*
parent
;
PathTrie
*
parent
;
private:
private:
int
_ROOT
;
int
ROOT_
;
bool
_exists
;
bool
exists_
;
bool
_has_dictionary
;
bool
has_dictionary_
;
std
::
vector
<
std
::
pair
<
int
,
PathTrie
*>>
_children
;
std
::
vector
<
std
::
pair
<
int
,
PathTrie
*>>
children_
;
// pointer to dictionary of FST
// pointer to dictionary of FST
fst
::
StdVectorFst
*
_dictionary
;
fst
::
StdVectorFst
*
dictionary_
;
fst
::
StdVectorFst
::
StateId
_dictionary_state
;
fst
::
StdVectorFst
::
StateId
dictionary_state_
;
// true if finding arcs in FST
// true if finding arcs in FST
std
::
shared_ptr
<
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>>
_matcher
;
std
::
shared_ptr
<
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>>
matcher_
;
};
};
#endif // PATH_TRIE_H
#endif // PATH_TRIE_H
deep_speech_2/decoders/swig/scorer.cpp
浏览文件 @
8c5576d9
...
@@ -19,19 +19,19 @@ Scorer::Scorer(double alpha,
...
@@ -19,19 +19,19 @@ Scorer::Scorer(double alpha,
const
std
::
vector
<
std
::
string
>&
vocab_list
)
{
const
std
::
vector
<
std
::
string
>&
vocab_list
)
{
this
->
alpha
=
alpha
;
this
->
alpha
=
alpha
;
this
->
beta
=
beta
;
this
->
beta
=
beta
;
_is_character_based
=
true
;
is_character_based_
=
true
;
_language_model
=
nullptr
;
language_model_
=
nullptr
;
dictionary
=
nullptr
;
dictionary
=
nullptr
;
_max_order
=
0
;
max_order_
=
0
;
_dict_size
=
0
;
dict_size_
=
0
;
_SPACE_ID
=
-
1
;
SPACE_ID_
=
-
1
;
setup
(
lm_path
,
vocab_list
);
setup
(
lm_path
,
vocab_list
);
}
}
Scorer
::~
Scorer
()
{
Scorer
::~
Scorer
()
{
if
(
_language_model
!=
nullptr
)
{
if
(
language_model_
!=
nullptr
)
{
delete
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
);
delete
static_cast
<
lm
::
base
::
Model
*>
(
language_model_
);
}
}
if
(
dictionary
!=
nullptr
)
{
if
(
dictionary
!=
nullptr
)
{
delete
static_cast
<
fst
::
StdVectorFst
*>
(
dictionary
);
delete
static_cast
<
fst
::
StdVectorFst
*>
(
dictionary
);
...
@@ -57,20 +57,20 @@ void Scorer::load_lm(const std::string& lm_path) {
...
@@ -57,20 +57,20 @@ void Scorer::load_lm(const std::string& lm_path) {
RetriveStrEnumerateVocab
enumerate
;
RetriveStrEnumerateVocab
enumerate
;
lm
::
ngram
::
Config
config
;
lm
::
ngram
::
Config
config
;
config
.
enumerate_vocab
=
&
enumerate
;
config
.
enumerate_vocab
=
&
enumerate
;
_language_model
=
lm
::
ngram
::
LoadVirtual
(
filename
,
config
);
language_model_
=
lm
::
ngram
::
LoadVirtual
(
filename
,
config
);
_max_order
=
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
)
->
Order
();
max_order_
=
static_cast
<
lm
::
base
::
Model
*>
(
language_model_
)
->
Order
();
_vocabulary
=
enumerate
.
vocabulary
;
vocabulary_
=
enumerate
.
vocabulary
;
for
(
size_t
i
=
0
;
i
<
_vocabulary
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
vocabulary_
.
size
();
++
i
)
{
if
(
_is_character_based
&&
_vocabulary
[
i
]
!=
UNK_TOKEN
&&
if
(
is_character_based_
&&
vocabulary_
[
i
]
!=
UNK_TOKEN
&&
_vocabulary
[
i
]
!=
START_TOKEN
&&
_vocabulary
[
i
]
!=
END_TOKEN
&&
vocabulary_
[
i
]
!=
START_TOKEN
&&
vocabulary_
[
i
]
!=
END_TOKEN
&&
get_utf8_str_len
(
enumerate
.
vocabulary
[
i
])
>
1
)
{
get_utf8_str_len
(
enumerate
.
vocabulary
[
i
])
>
1
)
{
_is_character_based
=
false
;
is_character_based_
=
false
;
}
}
}
}
}
}
double
Scorer
::
get_log_cond_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
double
Scorer
::
get_log_cond_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
lm
::
base
::
Model
*
model
=
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
);
lm
::
base
::
Model
*
model
=
static_cast
<
lm
::
base
::
Model
*>
(
language_model_
);
double
cond_prob
;
double
cond_prob
;
lm
::
ngram
::
State
state
,
tmp_state
,
out_state
;
lm
::
ngram
::
State
state
,
tmp_state
,
out_state
;
// avoid inserting <s> at the beginning
// avoid inserting <s> at the beginning
...
@@ -93,11 +93,11 @@ double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
...
@@ -93,11 +93,11 @@ double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
double
Scorer
::
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
double
Scorer
::
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
std
::
vector
<
std
::
string
>
sentence
;
std
::
vector
<
std
::
string
>
sentence
;
if
(
words
.
size
()
==
0
)
{
if
(
words
.
size
()
==
0
)
{
for
(
size_t
i
=
0
;
i
<
_max_order
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
max_order_
;
++
i
)
{
sentence
.
push_back
(
START_TOKEN
);
sentence
.
push_back
(
START_TOKEN
);
}
}
}
else
{
}
else
{
for
(
size_t
i
=
0
;
i
<
_max_order
-
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
max_order_
-
1
;
++
i
)
{
sentence
.
push_back
(
START_TOKEN
);
sentence
.
push_back
(
START_TOKEN
);
}
}
sentence
.
insert
(
sentence
.
end
(),
words
.
begin
(),
words
.
end
());
sentence
.
insert
(
sentence
.
end
(),
words
.
begin
(),
words
.
end
());
...
@@ -107,11 +107,11 @@ double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
...
@@ -107,11 +107,11 @@ double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
}
}
double
Scorer
::
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
double
Scorer
::
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
assert
(
words
.
size
()
>
_max_order
);
assert
(
words
.
size
()
>
max_order_
);
double
score
=
0.0
;
double
score
=
0.0
;
for
(
size_t
i
=
0
;
i
<
words
.
size
()
-
_max_order
+
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
words
.
size
()
-
max_order_
+
1
;
++
i
)
{
std
::
vector
<
std
::
string
>
ngram
(
words
.
begin
()
+
i
,
std
::
vector
<
std
::
string
>
ngram
(
words
.
begin
()
+
i
,
words
.
begin
()
+
i
+
_max_order
);
words
.
begin
()
+
i
+
max_order_
);
score
+=
get_log_cond_prob
(
ngram
);
score
+=
get_log_cond_prob
(
ngram
);
}
}
return
score
;
return
score
;
...
@@ -125,7 +125,7 @@ void Scorer::reset_params(float alpha, float beta) {
...
@@ -125,7 +125,7 @@ void Scorer::reset_params(float alpha, float beta) {
std
::
string
Scorer
::
vec2str
(
const
std
::
vector
<
int
>&
input
)
{
std
::
string
Scorer
::
vec2str
(
const
std
::
vector
<
int
>&
input
)
{
std
::
string
word
;
std
::
string
word
;
for
(
auto
ind
:
input
)
{
for
(
auto
ind
:
input
)
{
word
+=
_char_list
[
ind
];
word
+=
char_list_
[
ind
];
}
}
return
word
;
return
word
;
}
}
...
@@ -135,7 +135,7 @@ std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
...
@@ -135,7 +135,7 @@ std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
std
::
string
s
=
vec2str
(
labels
);
std
::
string
s
=
vec2str
(
labels
);
std
::
vector
<
std
::
string
>
words
;
std
::
vector
<
std
::
string
>
words
;
if
(
_is_character_based
)
{
if
(
is_character_based_
)
{
words
=
split_utf8_str
(
s
);
words
=
split_utf8_str
(
s
);
}
else
{
}
else
{
words
=
split_str
(
s
,
" "
);
words
=
split_str
(
s
,
" "
);
...
@@ -144,15 +144,15 @@ std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
...
@@ -144,15 +144,15 @@ std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
}
}
void
Scorer
::
set_char_map
(
const
std
::
vector
<
std
::
string
>&
char_list
)
{
void
Scorer
::
set_char_map
(
const
std
::
vector
<
std
::
string
>&
char_list
)
{
_char_list
=
char_list
;
char_list_
=
char_list
;
_char_map
.
clear
();
char_map_
.
clear
();
for
(
unsigned
int
i
=
0
;
i
<
_char_list
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
char_list_
.
size
();
i
++
)
{
if
(
_char_list
[
i
]
==
" "
)
{
if
(
char_list_
[
i
]
==
" "
)
{
_SPACE_ID
=
i
;
SPACE_ID_
=
i
;
_char_map
[
' '
]
=
i
;
char_map_
[
' '
]
=
i
;
}
else
if
(
_char_list
[
i
].
size
()
==
1
)
{
}
else
if
(
char_list_
[
i
].
size
()
==
1
)
{
_char_map
[
_char_list
[
i
][
0
]]
=
i
;
char_map_
[
char_list_
[
i
][
0
]]
=
i
;
}
}
}
}
}
}
...
@@ -162,14 +162,14 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
...
@@ -162,14 +162,14 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
PathTrie
*
current_node
=
prefix
;
PathTrie
*
current_node
=
prefix
;
PathTrie
*
new_node
=
nullptr
;
PathTrie
*
new_node
=
nullptr
;
for
(
int
order
=
0
;
order
<
_max_order
;
order
++
)
{
for
(
int
order
=
0
;
order
<
max_order_
;
order
++
)
{
std
::
vector
<
int
>
prefix_vec
;
std
::
vector
<
int
>
prefix_vec
;
if
(
_is_character_based
)
{
if
(
is_character_based_
)
{
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
_SPACE_ID
,
1
);
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
SPACE_ID_
,
1
);
current_node
=
new_node
;
current_node
=
new_node
;
}
else
{
}
else
{
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
_SPACE_ID
);
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
SPACE_ID_
);
current_node
=
new_node
->
parent
;
// Skipping spaces
current_node
=
new_node
->
parent
;
// Skipping spaces
}
}
...
@@ -179,7 +179,7 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
...
@@ -179,7 +179,7 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
if
(
new_node
->
character
==
-
1
)
{
if
(
new_node
->
character
==
-
1
)
{
// No more spaces, but still need order
// No more spaces, but still need order
for
(
int
i
=
0
;
i
<
_max_order
-
order
-
1
;
i
++
)
{
for
(
int
i
=
0
;
i
<
max_order_
-
order
-
1
;
i
++
)
{
ngram
.
push_back
(
START_TOKEN
);
ngram
.
push_back
(
START_TOKEN
);
}
}
break
;
break
;
...
@@ -193,19 +193,19 @@ void Scorer::fill_dictionary(bool add_space) {
...
@@ -193,19 +193,19 @@ void Scorer::fill_dictionary(bool add_space) {
fst
::
StdVectorFst
dictionary
;
fst
::
StdVectorFst
dictionary
;
// First reverse char_list so ints can be accessed by chars
// First reverse char_list so ints can be accessed by chars
std
::
unordered_map
<
std
::
string
,
int
>
char_map
;
std
::
unordered_map
<
std
::
string
,
int
>
char_map
;
for
(
unsigned
int
i
=
0
;
i
<
_char_list
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
char_list_
.
size
();
i
++
)
{
char_map
[
_char_list
[
i
]]
=
i
;
char_map
[
char_list_
[
i
]]
=
i
;
}
}
// For each unigram convert to ints and put in trie
// For each unigram convert to ints and put in trie
int
dict_size
=
0
;
int
dict_size
=
0
;
for
(
const
auto
&
word
:
_vocabulary
)
{
for
(
const
auto
&
word
:
vocabulary_
)
{
bool
added
=
add_word_to_dictionary
(
bool
added
=
add_word_to_dictionary
(
word
,
char_map
,
add_space
,
_SPACE_ID
,
&
dictionary
);
word
,
char_map
,
add_space
,
SPACE_ID_
,
&
dictionary
);
dict_size
+=
added
?
1
:
0
;
dict_size
+=
added
?
1
:
0
;
}
}
_dict_size
=
dict_size
;
dict_size_
=
dict_size
;
/* Simplify FST
/* Simplify FST
...
...
deep_speech_2/decoders/swig/scorer.h
浏览文件 @
8c5576d9
...
@@ -18,7 +18,7 @@ const std::string START_TOKEN = "<s>";
...
@@ -18,7 +18,7 @@ const std::string START_TOKEN = "<s>";
const
std
::
string
UNK_TOKEN
=
"<unk>"
;
const
std
::
string
UNK_TOKEN
=
"<unk>"
;
const
std
::
string
END_TOKEN
=
"</s>"
;
const
std
::
string
END_TOKEN
=
"</s>"
;
// Implement a callback to retrieve
string vocabulary
.
// Implement a callback to retrieve
the dictionary of language model
.
class
RetriveStrEnumerateVocab
:
public
lm
::
EnumerateVocab
{
class
RetriveStrEnumerateVocab
:
public
lm
::
EnumerateVocab
{
public:
public:
RetriveStrEnumerateVocab
()
{}
RetriveStrEnumerateVocab
()
{}
...
@@ -50,13 +50,14 @@ public:
...
@@ -50,13 +50,14 @@ public:
double
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>
&
words
);
double
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>
&
words
);
size_t
get_max_order
()
const
{
return
_max_order
;
}
// return the max order
size_t
get_max_order
()
const
{
return
max_order_
;
}
size_t
get_dict_size
()
const
{
return
_dict_size
;
}
// return the dictionary size of language model
size_t
get_dict_size
()
const
{
return
dict_size_
;
}
bool
is_char_map_empty
()
const
{
return
_char_map
.
size
()
==
0
;
}
// return true if the language model is character based
bool
is_character_based
()
const
{
return
is_character_based_
;
}
bool
is_character_based
()
const
{
return
_is_character_based
;
}
// reset params alpha & beta
// reset params alpha & beta
void
reset_params
(
float
alpha
,
float
beta
);
void
reset_params
(
float
alpha
,
float
beta
);
...
@@ -68,20 +69,23 @@ public:
...
@@ -68,20 +69,23 @@ public:
// the vector of characters (character based lm)
// the vector of characters (character based lm)
std
::
vector
<
std
::
string
>
split_labels
(
const
std
::
vector
<
int
>
&
labels
);
std
::
vector
<
std
::
string
>
split_labels
(
const
std
::
vector
<
int
>
&
labels
);
//
expose to decoder
//
language model weight
double
alpha
;
double
alpha
;
// word insertion weight
double
beta
;
double
beta
;
//
fst dictionary
//
pointer to the dictionary of FST
void
*
dictionary
;
void
*
dictionary
;
protected:
protected:
// necessary setup: load language model, set char map, fill FST's dictionary
void
setup
(
const
std
::
string
&
lm_path
,
void
setup
(
const
std
::
string
&
lm_path
,
const
std
::
vector
<
std
::
string
>
&
vocab_list
);
const
std
::
vector
<
std
::
string
>
&
vocab_list
);
// load language model from given path
void
load_lm
(
const
std
::
string
&
lm_path
);
void
load_lm
(
const
std
::
string
&
lm_path
);
// fill dictionary for
fst
// fill dictionary for
FST
void
fill_dictionary
(
bool
add_space
);
void
fill_dictionary
(
bool
add_space
);
// set char map
// set char map
...
@@ -89,19 +93,20 @@ protected:
...
@@ -89,19 +93,20 @@ protected:
double
get_log_prob
(
const
std
::
vector
<
std
::
string
>
&
words
);
double
get_log_prob
(
const
std
::
vector
<
std
::
string
>
&
words
);
// translate the vector in index to string
std
::
string
vec2str
(
const
std
::
vector
<
int
>
&
input
);
std
::
string
vec2str
(
const
std
::
vector
<
int
>
&
input
);
private:
private:
void
*
_language_model
;
void
*
language_model_
;
bool
_is_character_based
;
bool
is_character_based_
;
size_t
_max_order
;
size_t
max_order_
;
size_t
_dict_size
;
size_t
dict_size_
;
int
_SPACE_ID
;
int
SPACE_ID_
;
std
::
vector
<
std
::
string
>
_char_list
;
std
::
vector
<
std
::
string
>
char_list_
;
std
::
unordered_map
<
char
,
int
>
_char_map
;
std
::
unordered_map
<
char
,
int
>
char_map_
;
std
::
vector
<
std
::
string
>
_vocabulary
;
std
::
vector
<
std
::
string
>
vocabulary_
;
};
};
#endif // SCORER_H_
#endif // SCORER_H_
deep_speech_2/decoders/swig_wrapper.py
浏览文件 @
8c5576d9
...
@@ -39,8 +39,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
...
@@ -39,8 +39,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
def
ctc_beam_search_decoder
(
probs_seq
,
def
ctc_beam_search_decoder
(
probs_seq
,
beam_size
,
vocabulary
,
vocabulary
,
beam_size
,
cutoff_prob
=
1.0
,
cutoff_prob
=
1.0
,
cutoff_top_n
=
40
,
cutoff_top_n
=
40
,
ext_scoring_func
=
None
):
ext_scoring_func
=
None
):
...
@@ -50,10 +50,10 @@ def ctc_beam_search_decoder(probs_seq,
...
@@ -50,10 +50,10 @@ def ctc_beam_search_decoder(probs_seq,
step, with each element being a list of normalized
step, with each element being a list of normalized
probabilities over vocabulary and blank.
probabilities over vocabulary and blank.
:type probs_seq: 2-D list
:type probs_seq: 2-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:param vocabulary: Vocabulary list.
:type vocabulary: list
:type vocabulary: list
:param beam_size: Width for beam search.
:type beam_size: int
:param cutoff_prob: Cutoff probability in pruning,
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
default 1.0, no pruning.
:type cutoff_prob: float
:type cutoff_prob: float
...
@@ -69,14 +69,14 @@ def ctc_beam_search_decoder(probs_seq,
...
@@ -69,14 +69,14 @@ def ctc_beam_search_decoder(probs_seq,
results, in descending order of the probability.
results, in descending order of the probability.
:rtype: list
:rtype: list
"""
"""
return
swig_decoders
.
ctc_beam_search_decoder
(
probs_seq
.
tolist
(),
beam_size
,
return
swig_decoders
.
ctc_beam_search_decoder
(
probs_seq
.
tolist
(),
vocabulary
,
vocabulary
,
cutoff_prob
,
beam_size
,
cutoff_prob
,
cutoff_top_n
,
ext_scoring_func
)
cutoff_top_n
,
ext_scoring_func
)
def
ctc_beam_search_decoder_batch
(
probs_split
,
def
ctc_beam_search_decoder_batch
(
probs_split
,
beam_size
,
vocabulary
,
vocabulary
,
beam_size
,
num_processes
,
num_processes
,
cutoff_prob
=
1.0
,
cutoff_prob
=
1.0
,
cutoff_top_n
=
40
,
cutoff_top_n
=
40
,
...
@@ -86,10 +86,10 @@ def ctc_beam_search_decoder_batch(probs_split,
...
@@ -86,10 +86,10 @@ def ctc_beam_search_decoder_batch(probs_split,
:param probs_seq: 3-D list with each element as an instance of 2-D list
:param probs_seq: 3-D list with each element as an instance of 2-D list
of probabilities used by ctc_beam_search_decoder().
of probabilities used by ctc_beam_search_decoder().
:type probs_seq: 3-D list
:type probs_seq: 3-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:param vocabulary: Vocabulary list.
:type vocabulary: list
:type vocabulary: list
:param beam_size: Width for beam search.
:type beam_size: int
:param num_processes: Number of parallel processes.
:param num_processes: Number of parallel processes.
:type num_processes: int
:type num_processes: int
:param cutoff_prob: Cutoff probability in vocabulary pruning,
:param cutoff_prob: Cutoff probability in vocabulary pruning,
...
@@ -112,5 +112,5 @@ def ctc_beam_search_decoder_batch(probs_split,
...
@@ -112,5 +112,5 @@ def ctc_beam_search_decoder_batch(probs_split,
probs_split
=
[
probs_seq
.
tolist
()
for
probs_seq
in
probs_split
]
probs_split
=
[
probs_seq
.
tolist
()
for
probs_seq
in
probs_split
]
return
swig_decoders
.
ctc_beam_search_decoder_batch
(
return
swig_decoders
.
ctc_beam_search_decoder_batch
(
probs_split
,
beam_size
,
vocabulary
,
num_processes
,
cutoff_prob
,
probs_split
,
vocabulary
,
beam_size
,
num_processes
,
cutoff_prob
,
cutoff_top_n
,
ext_scoring_func
)
cutoff_top_n
,
ext_scoring_func
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录