Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
8c5576d9
M
models
项目概览
PaddlePaddle
/
models
大约 2 年 前同步成功
通知
232
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8c5576d9
编写于
9月 17, 2017
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
format variables' names & add more comments
上级
e6740af4
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
129 addition
and
126 deletion
+129
-126
deep_speech_2/decoders/swig/ctc_beam_search_decoder.cpp
deep_speech_2/decoders/swig/ctc_beam_search_decoder.cpp
+7
-8
deep_speech_2/decoders/swig/ctc_beam_search_decoder.h
deep_speech_2/decoders/swig/ctc_beam_search_decoder.h
+4
-5
deep_speech_2/decoders/swig/path_trie.cpp
deep_speech_2/decoders/swig/path_trie.cpp
+38
-38
deep_speech_2/decoders/swig/path_trie.h
deep_speech_2/decoders/swig/path_trie.h
+8
-8
deep_speech_2/decoders/swig/scorer.cpp
deep_speech_2/decoders/swig/scorer.cpp
+41
-41
deep_speech_2/decoders/swig/scorer.h
deep_speech_2/decoders/swig/scorer.h
+22
-17
deep_speech_2/decoders/swig_wrapper.py
deep_speech_2/decoders/swig_wrapper.py
+9
-9
未找到文件。
deep_speech_2/decoders/swig/ctc_beam_search_decoder.cpp
浏览文件 @
8c5576d9
...
...
@@ -18,8 +18,8 @@ using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
ctc_beam_search_decoder
(
const
std
::
vector
<
std
::
vector
<
double
>>
&
probs_seq
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
size_t
beam_size
,
std
::
vector
<
std
::
string
>
vocabulary
,
double
cutoff_prob
,
size_t
cutoff_top_n
,
Scorer
*
ext_scorer
)
{
...
...
@@ -36,8 +36,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
size_t
blank_id
=
vocabulary
.
size
();
// assign space id
std
::
vector
<
std
::
string
>::
iterator
it
=
std
::
find
(
vocabulary
.
begin
(),
vocabulary
.
end
(),
" "
);
auto
it
=
std
::
find
(
vocabulary
.
begin
(),
vocabulary
.
end
(),
" "
);
int
space_id
=
it
-
vocabulary
.
begin
();
// if no space in vocabulary
if
((
size_t
)
space_id
>=
vocabulary
.
size
())
{
...
...
@@ -173,11 +172,11 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
ctc_beam_search_decoder_batch
(
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
double
>>>
&
probs_split
,
const
size_t
beam_size
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
const
size_t
num_processes
,
const
double
cutoff_prob
,
const
size_t
cutoff_top_n
,
size_t
beam_size
,
size_t
num_processes
,
double
cutoff_prob
,
size_t
cutoff_top_n
,
Scorer
*
ext_scorer
)
{
VALID_CHECK_GT
(
num_processes
,
0
,
"num_processes must be nonnegative!"
);
// thread pool
...
...
@@ -190,8 +189,8 @@ ctc_beam_search_decoder_batch(
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
res
.
emplace_back
(
pool
.
enqueue
(
ctc_beam_search_decoder
,
probs_split
[
i
],
beam_size
,
vocabulary
,
beam_size
,
cutoff_prob
,
cutoff_top_n
,
ext_scorer
));
...
...
deep_speech_2/decoders/swig/ctc_beam_search_decoder.h
浏览文件 @
8c5576d9
...
...
@@ -12,8 +12,8 @@
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
* over vocabulary of one time step.
* beam_size: The width of beam search.
* vocabulary: A vector of vocabulary.
* beam_size: The width of beam search.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix, which consists of
...
...
@@ -25,8 +25,8 @@
*/
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
ctc_beam_search_decoder
(
const
std
::
vector
<
std
::
vector
<
double
>>
&
probs_seq
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
size_t
beam_size
,
std
::
vector
<
std
::
string
>
vocabulary
,
double
cutoff_prob
=
1
.
0
,
size_t
cutoff_top_n
=
40
,
Scorer
*
ext_scorer
=
nullptr
);
...
...
@@ -36,9 +36,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
* Parameters:
* probs_seq: 3-D vector that each element is a 2-D vector that can be used
* by ctc_beam_search_decoder().
* .
* beam_size: The width of beam search.
* vocabulary: A vector of vocabulary.
* beam_size: The width of beam search.
* num_processes: Number of threads for beam search.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
...
...
@@ -52,8 +51,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
ctc_beam_search_decoder_batch
(
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
double
>>>
&
probs_split
,
size_t
beam_size
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
size_t
beam_size
,
size_t
num_processes
,
double
cutoff_prob
=
1
.
0
,
size_t
cutoff_top_n
=
40
,
...
...
deep_speech_2/decoders/swig/path_trie.cpp
浏览文件 @
8c5576d9
...
...
@@ -15,32 +15,32 @@ PathTrie::PathTrie() {
log_prob_nb_cur
=
-
NUM_FLT_INF
;
score
=
-
NUM_FLT_INF
;
_ROOT
=
-
1
;
character
=
_ROOT
;
_exists
=
true
;
ROOT_
=
-
1
;
character
=
ROOT_
;
exists_
=
true
;
parent
=
nullptr
;
_dictionary
=
nullptr
;
_dictionary_state
=
0
;
_has_dictionary
=
false
;
_matcher
=
nullptr
;
dictionary_
=
nullptr
;
dictionary_state_
=
0
;
has_dictionary_
=
false
;
matcher_
=
nullptr
;
}
PathTrie
::~
PathTrie
()
{
for
(
auto
child
:
_children
)
{
for
(
auto
child
:
children_
)
{
delete
child
.
second
;
}
}
PathTrie
*
PathTrie
::
get_path_trie
(
int
new_char
,
bool
reset
)
{
auto
child
=
_children
.
begin
();
for
(
child
=
_children
.
begin
();
child
!=
_children
.
end
();
++
child
)
{
auto
child
=
children_
.
begin
();
for
(
child
=
children_
.
begin
();
child
!=
children_
.
end
();
++
child
)
{
if
(
child
->
first
==
new_char
)
{
break
;
}
}
if
(
child
!=
_children
.
end
())
{
if
(
!
child
->
second
->
_exists
)
{
child
->
second
->
_exists
=
true
;
if
(
child
!=
children_
.
end
())
{
if
(
!
child
->
second
->
exists_
)
{
child
->
second
->
exists_
=
true
;
child
->
second
->
log_prob_b_prev
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_nb_prev
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_b_cur
=
-
NUM_FLT_INF
;
...
...
@@ -48,47 +48,47 @@ PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
}
return
(
child
->
second
);
}
else
{
if
(
_has_dictionary
)
{
_matcher
->
SetState
(
_dictionary_state
);
bool
found
=
_matcher
->
Find
(
new_char
);
if
(
has_dictionary_
)
{
matcher_
->
SetState
(
dictionary_state_
);
bool
found
=
matcher_
->
Find
(
new_char
);
if
(
!
found
)
{
// Adding this character causes word outside dictionary
auto
FSTZERO
=
fst
::
TropicalWeight
::
Zero
();
auto
final_weight
=
_dictionary
->
Final
(
_dictionary_state
);
auto
final_weight
=
dictionary_
->
Final
(
dictionary_state_
);
bool
is_final
=
(
final_weight
!=
FSTZERO
);
if
(
is_final
&&
reset
)
{
_dictionary_state
=
_dictionary
->
Start
();
dictionary_state_
=
dictionary_
->
Start
();
}
return
nullptr
;
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
new_path
->
character
=
new_char
;
new_path
->
parent
=
this
;
new_path
->
_dictionary
=
_dictionary
;
new_path
->
_dictionary_state
=
_matcher
->
Value
().
nextstate
;
new_path
->
_has_dictionary
=
true
;
new_path
->
_matcher
=
_matcher
;
_children
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
new_path
->
dictionary_
=
dictionary_
;
new_path
->
dictionary_state_
=
matcher_
->
Value
().
nextstate
;
new_path
->
has_dictionary_
=
true
;
new_path
->
matcher_
=
matcher_
;
children_
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
return
new_path
;
}
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
new_path
->
character
=
new_char
;
new_path
->
parent
=
this
;
_children
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
children_
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
return
new_path
;
}
}
}
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
)
{
return
get_path_vec
(
output
,
_ROOT
);
return
get_path_vec
(
output
,
ROOT_
);
}
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
,
int
stop
,
size_t
max_steps
)
{
if
(
character
==
stop
||
character
==
_ROOT
||
output
.
size
()
==
max_steps
)
{
if
(
character
==
stop
||
character
==
ROOT_
||
output
.
size
()
==
max_steps
)
{
std
::
reverse
(
output
.
begin
(),
output
.
end
());
return
this
;
}
else
{
...
...
@@ -98,7 +98,7 @@ PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
}
void
PathTrie
::
iterate_to_vec
(
std
::
vector
<
PathTrie
*>&
output
)
{
if
(
_exists
)
{
if
(
exists_
)
{
log_prob_b_prev
=
log_prob_b_cur
;
log_prob_nb_prev
=
log_prob_nb_cur
;
...
...
@@ -108,25 +108,25 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
score
=
log_sum_exp
(
log_prob_b_prev
,
log_prob_nb_prev
);
output
.
push_back
(
this
);
}
for
(
auto
child
:
_children
)
{
for
(
auto
child
:
children_
)
{
child
.
second
->
iterate_to_vec
(
output
);
}
}
void
PathTrie
::
remove
()
{
_exists
=
false
;
exists_
=
false
;
if
(
_children
.
size
()
==
0
)
{
auto
child
=
parent
->
_children
.
begin
();
for
(
child
=
parent
->
_children
.
begin
();
child
!=
parent
->
_children
.
end
();
if
(
children_
.
size
()
==
0
)
{
auto
child
=
parent
->
children_
.
begin
();
for
(
child
=
parent
->
children_
.
begin
();
child
!=
parent
->
children_
.
end
();
++
child
)
{
if
(
child
->
first
==
character
)
{
parent
->
_children
.
erase
(
child
);
parent
->
children_
.
erase
(
child
);
break
;
}
}
if
(
parent
->
_children
.
size
()
==
0
&&
!
parent
->
_exists
)
{
if
(
parent
->
children_
.
size
()
==
0
&&
!
parent
->
exists_
)
{
parent
->
remove
();
}
...
...
@@ -135,12 +135,12 @@ void PathTrie::remove() {
}
void
PathTrie
::
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
)
{
_dictionary
=
dictionary
;
_dictionary_state
=
dictionary
->
Start
();
_has_dictionary
=
true
;
dictionary_
=
dictionary
;
dictionary_state_
=
dictionary
->
Start
();
has_dictionary_
=
true
;
}
using
FSTMATCH
=
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>
;
void
PathTrie
::
set_matcher
(
std
::
shared_ptr
<
FSTMATCH
>
matcher
)
{
_matcher
=
matcher
;
matcher_
=
matcher
;
}
deep_speech_2/decoders/swig/path_trie.h
浏览文件 @
8c5576d9
...
...
@@ -36,7 +36,7 @@ public:
void
set_matcher
(
std
::
shared_ptr
<
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>>
);
bool
is_empty
()
{
return
_ROOT
==
character
;
}
bool
is_empty
()
{
return
ROOT_
==
character
;
}
// remove current path from root
void
remove
();
...
...
@@ -51,17 +51,17 @@ public:
PathTrie
*
parent
;
private:
int
_ROOT
;
bool
_exists
;
bool
_has_dictionary
;
int
ROOT_
;
bool
exists_
;
bool
has_dictionary_
;
std
::
vector
<
std
::
pair
<
int
,
PathTrie
*>>
_children
;
std
::
vector
<
std
::
pair
<
int
,
PathTrie
*>>
children_
;
// pointer to dictionary of FST
fst
::
StdVectorFst
*
_dictionary
;
fst
::
StdVectorFst
::
StateId
_dictionary_state
;
fst
::
StdVectorFst
*
dictionary_
;
fst
::
StdVectorFst
::
StateId
dictionary_state_
;
// true if finding arcs in FST
std
::
shared_ptr
<
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>>
_matcher
;
std
::
shared_ptr
<
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>>
matcher_
;
};
#endif // PATH_TRIE_H
deep_speech_2/decoders/swig/scorer.cpp
浏览文件 @
8c5576d9
...
...
@@ -19,19 +19,19 @@ Scorer::Scorer(double alpha,
const
std
::
vector
<
std
::
string
>&
vocab_list
)
{
this
->
alpha
=
alpha
;
this
->
beta
=
beta
;
_is_character_based
=
true
;
_language_model
=
nullptr
;
is_character_based_
=
true
;
language_model_
=
nullptr
;
dictionary
=
nullptr
;
_max_order
=
0
;
_dict_size
=
0
;
_SPACE_ID
=
-
1
;
max_order_
=
0
;
dict_size_
=
0
;
SPACE_ID_
=
-
1
;
setup
(
lm_path
,
vocab_list
);
}
Scorer
::~
Scorer
()
{
if
(
_language_model
!=
nullptr
)
{
delete
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
);
if
(
language_model_
!=
nullptr
)
{
delete
static_cast
<
lm
::
base
::
Model
*>
(
language_model_
);
}
if
(
dictionary
!=
nullptr
)
{
delete
static_cast
<
fst
::
StdVectorFst
*>
(
dictionary
);
...
...
@@ -57,20 +57,20 @@ void Scorer::load_lm(const std::string& lm_path) {
RetriveStrEnumerateVocab
enumerate
;
lm
::
ngram
::
Config
config
;
config
.
enumerate_vocab
=
&
enumerate
;
_language_model
=
lm
::
ngram
::
LoadVirtual
(
filename
,
config
);
_max_order
=
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
)
->
Order
();
_vocabulary
=
enumerate
.
vocabulary
;
for
(
size_t
i
=
0
;
i
<
_vocabulary
.
size
();
++
i
)
{
if
(
_is_character_based
&&
_vocabulary
[
i
]
!=
UNK_TOKEN
&&
_vocabulary
[
i
]
!=
START_TOKEN
&&
_vocabulary
[
i
]
!=
END_TOKEN
&&
language_model_
=
lm
::
ngram
::
LoadVirtual
(
filename
,
config
);
max_order_
=
static_cast
<
lm
::
base
::
Model
*>
(
language_model_
)
->
Order
();
vocabulary_
=
enumerate
.
vocabulary
;
for
(
size_t
i
=
0
;
i
<
vocabulary_
.
size
();
++
i
)
{
if
(
is_character_based_
&&
vocabulary_
[
i
]
!=
UNK_TOKEN
&&
vocabulary_
[
i
]
!=
START_TOKEN
&&
vocabulary_
[
i
]
!=
END_TOKEN
&&
get_utf8_str_len
(
enumerate
.
vocabulary
[
i
])
>
1
)
{
_is_character_based
=
false
;
is_character_based_
=
false
;
}
}
}
double
Scorer
::
get_log_cond_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
lm
::
base
::
Model
*
model
=
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
);
lm
::
base
::
Model
*
model
=
static_cast
<
lm
::
base
::
Model
*>
(
language_model_
);
double
cond_prob
;
lm
::
ngram
::
State
state
,
tmp_state
,
out_state
;
// avoid to inserting <s> in begin
...
...
@@ -93,11 +93,11 @@ double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
double
Scorer
::
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
std
::
vector
<
std
::
string
>
sentence
;
if
(
words
.
size
()
==
0
)
{
for
(
size_t
i
=
0
;
i
<
_max_order
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
max_order_
;
++
i
)
{
sentence
.
push_back
(
START_TOKEN
);
}
}
else
{
for
(
size_t
i
=
0
;
i
<
_max_order
-
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
max_order_
-
1
;
++
i
)
{
sentence
.
push_back
(
START_TOKEN
);
}
sentence
.
insert
(
sentence
.
end
(),
words
.
begin
(),
words
.
end
());
...
...
@@ -107,11 +107,11 @@ double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
}
double
Scorer
::
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
assert
(
words
.
size
()
>
_max_order
);
assert
(
words
.
size
()
>
max_order_
);
double
score
=
0.0
;
for
(
size_t
i
=
0
;
i
<
words
.
size
()
-
_max_order
+
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
words
.
size
()
-
max_order_
+
1
;
++
i
)
{
std
::
vector
<
std
::
string
>
ngram
(
words
.
begin
()
+
i
,
words
.
begin
()
+
i
+
_max_order
);
words
.
begin
()
+
i
+
max_order_
);
score
+=
get_log_cond_prob
(
ngram
);
}
return
score
;
...
...
@@ -125,7 +125,7 @@ void Scorer::reset_params(float alpha, float beta) {
std
::
string
Scorer
::
vec2str
(
const
std
::
vector
<
int
>&
input
)
{
std
::
string
word
;
for
(
auto
ind
:
input
)
{
word
+=
_char_list
[
ind
];
word
+=
char_list_
[
ind
];
}
return
word
;
}
...
...
@@ -135,7 +135,7 @@ std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
std
::
string
s
=
vec2str
(
labels
);
std
::
vector
<
std
::
string
>
words
;
if
(
_is_character_based
)
{
if
(
is_character_based_
)
{
words
=
split_utf8_str
(
s
);
}
else
{
words
=
split_str
(
s
,
" "
);
...
...
@@ -144,15 +144,15 @@ std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
}
void
Scorer
::
set_char_map
(
const
std
::
vector
<
std
::
string
>&
char_list
)
{
_char_list
=
char_list
;
_char_map
.
clear
();
for
(
unsigned
int
i
=
0
;
i
<
_char_list
.
size
();
i
++
)
{
if
(
_char_list
[
i
]
==
" "
)
{
_SPACE_ID
=
i
;
_char_map
[
' '
]
=
i
;
}
else
if
(
_char_list
[
i
].
size
()
==
1
)
{
_char_map
[
_char_list
[
i
][
0
]]
=
i
;
char_list_
=
char_list
;
char_map_
.
clear
();
for
(
size_t
i
=
0
;
i
<
char_list_
.
size
();
i
++
)
{
if
(
char_list_
[
i
]
==
" "
)
{
SPACE_ID_
=
i
;
char_map_
[
' '
]
=
i
;
}
else
if
(
char_list_
[
i
].
size
()
==
1
)
{
char_map_
[
char_list_
[
i
][
0
]]
=
i
;
}
}
}
...
...
@@ -162,14 +162,14 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
PathTrie
*
current_node
=
prefix
;
PathTrie
*
new_node
=
nullptr
;
for
(
int
order
=
0
;
order
<
_max_order
;
order
++
)
{
for
(
int
order
=
0
;
order
<
max_order_
;
order
++
)
{
std
::
vector
<
int
>
prefix_vec
;
if
(
_is_character_based
)
{
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
_SPACE_ID
,
1
);
if
(
is_character_based_
)
{
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
SPACE_ID_
,
1
);
current_node
=
new_node
;
}
else
{
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
_SPACE_ID
);
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
SPACE_ID_
);
current_node
=
new_node
->
parent
;
// Skipping spaces
}
...
...
@@ -179,7 +179,7 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
if
(
new_node
->
character
==
-
1
)
{
// No more spaces, but still need order
for
(
int
i
=
0
;
i
<
_max_order
-
order
-
1
;
i
++
)
{
for
(
int
i
=
0
;
i
<
max_order_
-
order
-
1
;
i
++
)
{
ngram
.
push_back
(
START_TOKEN
);
}
break
;
...
...
@@ -193,19 +193,19 @@ void Scorer::fill_dictionary(bool add_space) {
fst
::
StdVectorFst
dictionary
;
// First reverse char_list so ints can be accessed by chars
std
::
unordered_map
<
std
::
string
,
int
>
char_map
;
for
(
unsigned
int
i
=
0
;
i
<
_char_list
.
size
();
i
++
)
{
char_map
[
_char_list
[
i
]]
=
i
;
for
(
size_t
i
=
0
;
i
<
char_list_
.
size
();
i
++
)
{
char_map
[
char_list_
[
i
]]
=
i
;
}
// For each unigram convert to ints and put in trie
int
dict_size
=
0
;
for
(
const
auto
&
word
:
_vocabulary
)
{
for
(
const
auto
&
word
:
vocabulary_
)
{
bool
added
=
add_word_to_dictionary
(
word
,
char_map
,
add_space
,
_SPACE_ID
,
&
dictionary
);
word
,
char_map
,
add_space
,
SPACE_ID_
,
&
dictionary
);
dict_size
+=
added
?
1
:
0
;
}
_dict_size
=
dict_size
;
dict_size_
=
dict_size
;
/* Simplify FST
...
...
deep_speech_2/decoders/swig/scorer.h
浏览文件 @
8c5576d9
...
...
@@ -18,7 +18,7 @@ const std::string START_TOKEN = "<s>";
const
std
::
string
UNK_TOKEN
=
"<unk>"
;
const
std
::
string
END_TOKEN
=
"</s>"
;
// Implement a callback to retrive
string vocabulary
.
// Implement a callback to retrieve
the dictionary of language model
.
class
RetriveStrEnumerateVocab
:
public
lm
::
EnumerateVocab
{
public:
RetriveStrEnumerateVocab
()
{}
...
...
@@ -50,13 +50,14 @@ public:
double
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>
&
words
);
size_t
get_max_order
()
const
{
return
_max_order
;
}
// return the max order
size_t
get_max_order
()
const
{
return
max_order_
;
}
size_t
get_dict_size
()
const
{
return
_dict_size
;
}
// return the dictionary size of language model
size_t
get_dict_size
()
const
{
return
dict_size_
;
}
bool
is_char_map_empty
()
const
{
return
_char_map
.
size
()
==
0
;
}
bool
is_character_based
()
const
{
return
_is_character_based
;
}
// return true if the language model is character based
bool
is_character_based
()
const
{
return
is_character_based_
;
}
// reset params alpha & beta
void
reset_params
(
float
alpha
,
float
beta
);
...
...
@@ -68,20 +69,23 @@ public:
// the vector of characters (character based lm)
std
::
vector
<
std
::
string
>
split_labels
(
const
std
::
vector
<
int
>
&
labels
);
//
expose to decoder
//
language model weight
double
alpha
;
// word insertion weight
double
beta
;
//
fst dictionary
//
pointer to the dictionary of FST
void
*
dictionary
;
protected:
// necessary setup: load language model, set char map, fill FST's dictionary
void
setup
(
const
std
::
string
&
lm_path
,
const
std
::
vector
<
std
::
string
>
&
vocab_list
);
// load language model from given path
void
load_lm
(
const
std
::
string
&
lm_path
);
// fill dictionary for
fst
// fill dictionary for
FST
void
fill_dictionary
(
bool
add_space
);
// set char map
...
...
@@ -89,19 +93,20 @@ protected:
double
get_log_prob
(
const
std
::
vector
<
std
::
string
>
&
words
);
// translate the vector in index to string
std
::
string
vec2str
(
const
std
::
vector
<
int
>
&
input
);
private:
void
*
_language_model
;
bool
_is_character_based
;
size_t
_max_order
;
size_t
_dict_size
;
void
*
language_model_
;
bool
is_character_based_
;
size_t
max_order_
;
size_t
dict_size_
;
int
_SPACE_ID
;
std
::
vector
<
std
::
string
>
_char_list
;
std
::
unordered_map
<
char
,
int
>
_char_map
;
int
SPACE_ID_
;
std
::
vector
<
std
::
string
>
char_list_
;
std
::
unordered_map
<
char
,
int
>
char_map_
;
std
::
vector
<
std
::
string
>
_vocabulary
;
std
::
vector
<
std
::
string
>
vocabulary_
;
};
#endif // SCORER_H_
deep_speech_2/decoders/swig_wrapper.py
浏览文件 @
8c5576d9
...
...
@@ -39,8 +39,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
def
ctc_beam_search_decoder
(
probs_seq
,
beam_size
,
vocabulary
,
beam_size
,
cutoff_prob
=
1.0
,
cutoff_top_n
=
40
,
ext_scoring_func
=
None
):
...
...
@@ -50,10 +50,10 @@ def ctc_beam_search_decoder(probs_seq,
step, with each element being a list of normalized
probabilities over vocabulary and blank.
:type probs_seq: 2-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param beam_size: Width for beam search.
:type beam_size: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
...
...
@@ -69,14 +69,14 @@ def ctc_beam_search_decoder(probs_seq,
results, in descending order of the probability.
:rtype: list
"""
return
swig_decoders
.
ctc_beam_search_decoder
(
probs_seq
.
tolist
(),
beam_size
,
vocabulary
,
cutoff_prob
,
return
swig_decoders
.
ctc_beam_search_decoder
(
probs_seq
.
tolist
(),
vocabulary
,
beam_size
,
cutoff_prob
,
cutoff_top_n
,
ext_scoring_func
)
def
ctc_beam_search_decoder_batch
(
probs_split
,
beam_size
,
vocabulary
,
beam_size
,
num_processes
,
cutoff_prob
=
1.0
,
cutoff_top_n
=
40
,
...
...
@@ -86,10 +86,10 @@ def ctc_beam_search_decoder_batch(probs_split,
:param probs_seq: 3-D list with each element as an instance of 2-D list
of probabilities used by ctc_beam_search_decoder().
:type probs_seq: 3-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param beam_size: Width for beam search.
:type beam_size: int
:param num_processes: Number of parallel processes.
:type num_processes: int
:param cutoff_prob: Cutoff probability in vocabulary pruning,
...
...
@@ -112,5 +112,5 @@ def ctc_beam_search_decoder_batch(probs_split,
probs_split
=
[
probs_seq
.
tolist
()
for
probs_seq
in
probs_split
]
return
swig_decoders
.
ctc_beam_search_decoder_batch
(
probs_split
,
beam_size
,
vocabulary
,
num_processes
,
cutoff_prob
,
probs_split
,
vocabulary
,
beam_size
,
num_processes
,
cutoff_prob
,
cutoff_top_n
,
ext_scoring_func
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录