Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
3018dcb4
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
接近 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3018dcb4
编写于
9月 17, 2017
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
format varabiables' name & add more comments
上级
a24d0138
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
129 addition
and
126 deletion
+129
-126
decoders/swig/ctc_beam_search_decoder.cpp
decoders/swig/ctc_beam_search_decoder.cpp
+7
-8
decoders/swig/ctc_beam_search_decoder.h
decoders/swig/ctc_beam_search_decoder.h
+4
-5
decoders/swig/path_trie.cpp
decoders/swig/path_trie.cpp
+38
-38
decoders/swig/path_trie.h
decoders/swig/path_trie.h
+8
-8
decoders/swig/scorer.cpp
decoders/swig/scorer.cpp
+41
-41
decoders/swig/scorer.h
decoders/swig/scorer.h
+22
-17
decoders/swig_wrapper.py
decoders/swig_wrapper.py
+9
-9
未找到文件。
decoders/swig/ctc_beam_search_decoder.cpp
浏览文件 @
3018dcb4
...
...
@@ -18,8 +18,8 @@ using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
ctc_beam_search_decoder
(
const
std
::
vector
<
std
::
vector
<
double
>>
&
probs_seq
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
size_t
beam_size
,
std
::
vector
<
std
::
string
>
vocabulary
,
double
cutoff_prob
,
size_t
cutoff_top_n
,
Scorer
*
ext_scorer
)
{
...
...
@@ -36,8 +36,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
size_t
blank_id
=
vocabulary
.
size
();
// assign space id
std
::
vector
<
std
::
string
>::
iterator
it
=
std
::
find
(
vocabulary
.
begin
(),
vocabulary
.
end
(),
" "
);
auto
it
=
std
::
find
(
vocabulary
.
begin
(),
vocabulary
.
end
(),
" "
);
int
space_id
=
it
-
vocabulary
.
begin
();
// if no space in vocabulary
if
((
size_t
)
space_id
>=
vocabulary
.
size
())
{
...
...
@@ -173,11 +172,11 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
ctc_beam_search_decoder_batch
(
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
double
>>>
&
probs_split
,
const
size_t
beam_size
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
const
size_t
num_processes
,
const
double
cutoff_prob
,
const
size_t
cutoff_top_n
,
size_t
beam_size
,
size_t
num_processes
,
double
cutoff_prob
,
size_t
cutoff_top_n
,
Scorer
*
ext_scorer
)
{
VALID_CHECK_GT
(
num_processes
,
0
,
"num_processes must be nonnegative!"
);
// thread pool
...
...
@@ -190,8 +189,8 @@ ctc_beam_search_decoder_batch(
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
res
.
emplace_back
(
pool
.
enqueue
(
ctc_beam_search_decoder
,
probs_split
[
i
],
beam_size
,
vocabulary
,
beam_size
,
cutoff_prob
,
cutoff_top_n
,
ext_scorer
));
...
...
decoders/swig/ctc_beam_search_decoder.h
浏览文件 @
3018dcb4
...
...
@@ -12,8 +12,8 @@
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
* over vocabulary of one time step.
* beam_size: The width of beam search.
* vocabulary: A vector of vocabulary.
* beam_size: The width of beam search.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix, which consists of
...
...
@@ -25,8 +25,8 @@
*/
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
ctc_beam_search_decoder
(
const
std
::
vector
<
std
::
vector
<
double
>>
&
probs_seq
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
size_t
beam_size
,
std
::
vector
<
std
::
string
>
vocabulary
,
double
cutoff_prob
=
1
.
0
,
size_t
cutoff_top_n
=
40
,
Scorer
*
ext_scorer
=
nullptr
);
...
...
@@ -36,9 +36,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
* Parameters:
* probs_seq: 3-D vector that each element is a 2-D vector that can be used
* by ctc_beam_search_decoder().
* .
* beam_size: The width of beam search.
* vocabulary: A vector of vocabulary.
* beam_size: The width of beam search.
* num_processes: Number of threads for beam search.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
...
...
@@ -52,8 +51,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
ctc_beam_search_decoder_batch
(
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
double
>>>
&
probs_split
,
size_t
beam_size
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
size_t
beam_size
,
size_t
num_processes
,
double
cutoff_prob
=
1
.
0
,
size_t
cutoff_top_n
=
40
,
...
...
decoders/swig/path_trie.cpp
浏览文件 @
3018dcb4
...
...
@@ -15,32 +15,32 @@ PathTrie::PathTrie() {
log_prob_nb_cur
=
-
NUM_FLT_INF
;
score
=
-
NUM_FLT_INF
;
_ROOT
=
-
1
;
character
=
_ROOT
;
_exists
=
true
;
ROOT_
=
-
1
;
character
=
ROOT_
;
exists_
=
true
;
parent
=
nullptr
;
_dictionary
=
nullptr
;
_dictionary_state
=
0
;
_has_dictionary
=
false
;
_matcher
=
nullptr
;
dictionary_
=
nullptr
;
dictionary_state_
=
0
;
has_dictionary_
=
false
;
matcher_
=
nullptr
;
}
PathTrie
::~
PathTrie
()
{
for
(
auto
child
:
_children
)
{
for
(
auto
child
:
children_
)
{
delete
child
.
second
;
}
}
PathTrie
*
PathTrie
::
get_path_trie
(
int
new_char
,
bool
reset
)
{
auto
child
=
_children
.
begin
();
for
(
child
=
_children
.
begin
();
child
!=
_children
.
end
();
++
child
)
{
auto
child
=
children_
.
begin
();
for
(
child
=
children_
.
begin
();
child
!=
children_
.
end
();
++
child
)
{
if
(
child
->
first
==
new_char
)
{
break
;
}
}
if
(
child
!=
_children
.
end
())
{
if
(
!
child
->
second
->
_exists
)
{
child
->
second
->
_exists
=
true
;
if
(
child
!=
children_
.
end
())
{
if
(
!
child
->
second
->
exists_
)
{
child
->
second
->
exists_
=
true
;
child
->
second
->
log_prob_b_prev
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_nb_prev
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_b_cur
=
-
NUM_FLT_INF
;
...
...
@@ -48,47 +48,47 @@ PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
}
return
(
child
->
second
);
}
else
{
if
(
_has_dictionary
)
{
_matcher
->
SetState
(
_dictionary_state
);
bool
found
=
_matcher
->
Find
(
new_char
);
if
(
has_dictionary_
)
{
matcher_
->
SetState
(
dictionary_state_
);
bool
found
=
matcher_
->
Find
(
new_char
);
if
(
!
found
)
{
// Adding this character causes word outside dictionary
auto
FSTZERO
=
fst
::
TropicalWeight
::
Zero
();
auto
final_weight
=
_dictionary
->
Final
(
_dictionary_state
);
auto
final_weight
=
dictionary_
->
Final
(
dictionary_state_
);
bool
is_final
=
(
final_weight
!=
FSTZERO
);
if
(
is_final
&&
reset
)
{
_dictionary_state
=
_dictionary
->
Start
();
dictionary_state_
=
dictionary_
->
Start
();
}
return
nullptr
;
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
new_path
->
character
=
new_char
;
new_path
->
parent
=
this
;
new_path
->
_dictionary
=
_dictionary
;
new_path
->
_dictionary_state
=
_matcher
->
Value
().
nextstate
;
new_path
->
_has_dictionary
=
true
;
new_path
->
_matcher
=
_matcher
;
_children
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
new_path
->
dictionary_
=
dictionary_
;
new_path
->
dictionary_state_
=
matcher_
->
Value
().
nextstate
;
new_path
->
has_dictionary_
=
true
;
new_path
->
matcher_
=
matcher_
;
children_
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
return
new_path
;
}
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
new_path
->
character
=
new_char
;
new_path
->
parent
=
this
;
_children
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
children_
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
return
new_path
;
}
}
}
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
)
{
return
get_path_vec
(
output
,
_ROOT
);
return
get_path_vec
(
output
,
ROOT_
);
}
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
,
int
stop
,
size_t
max_steps
)
{
if
(
character
==
stop
||
character
==
_ROOT
||
output
.
size
()
==
max_steps
)
{
if
(
character
==
stop
||
character
==
ROOT_
||
output
.
size
()
==
max_steps
)
{
std
::
reverse
(
output
.
begin
(),
output
.
end
());
return
this
;
}
else
{
...
...
@@ -98,7 +98,7 @@ PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
}
void
PathTrie
::
iterate_to_vec
(
std
::
vector
<
PathTrie
*>&
output
)
{
if
(
_exists
)
{
if
(
exists_
)
{
log_prob_b_prev
=
log_prob_b_cur
;
log_prob_nb_prev
=
log_prob_nb_cur
;
...
...
@@ -108,25 +108,25 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
score
=
log_sum_exp
(
log_prob_b_prev
,
log_prob_nb_prev
);
output
.
push_back
(
this
);
}
for
(
auto
child
:
_children
)
{
for
(
auto
child
:
children_
)
{
child
.
second
->
iterate_to_vec
(
output
);
}
}
void
PathTrie
::
remove
()
{
_exists
=
false
;
exists_
=
false
;
if
(
_children
.
size
()
==
0
)
{
auto
child
=
parent
->
_children
.
begin
();
for
(
child
=
parent
->
_children
.
begin
();
child
!=
parent
->
_children
.
end
();
if
(
children_
.
size
()
==
0
)
{
auto
child
=
parent
->
children_
.
begin
();
for
(
child
=
parent
->
children_
.
begin
();
child
!=
parent
->
children_
.
end
();
++
child
)
{
if
(
child
->
first
==
character
)
{
parent
->
_children
.
erase
(
child
);
parent
->
children_
.
erase
(
child
);
break
;
}
}
if
(
parent
->
_children
.
size
()
==
0
&&
!
parent
->
_exists
)
{
if
(
parent
->
children_
.
size
()
==
0
&&
!
parent
->
exists_
)
{
parent
->
remove
();
}
...
...
@@ -135,12 +135,12 @@ void PathTrie::remove() {
}
void
PathTrie
::
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
)
{
_dictionary
=
dictionary
;
_dictionary_state
=
dictionary
->
Start
();
_has_dictionary
=
true
;
dictionary_
=
dictionary
;
dictionary_state_
=
dictionary
->
Start
();
has_dictionary_
=
true
;
}
using
FSTMATCH
=
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>
;
void
PathTrie
::
set_matcher
(
std
::
shared_ptr
<
FSTMATCH
>
matcher
)
{
_matcher
=
matcher
;
matcher_
=
matcher
;
}
decoders/swig/path_trie.h
浏览文件 @
3018dcb4
...
...
@@ -36,7 +36,7 @@ public:
void
set_matcher
(
std
::
shared_ptr
<
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>>
);
bool
is_empty
()
{
return
_ROOT
==
character
;
}
bool
is_empty
()
{
return
ROOT_
==
character
;
}
// remove current path from root
void
remove
();
...
...
@@ -51,17 +51,17 @@ public:
PathTrie
*
parent
;
private:
int
_ROOT
;
bool
_exists
;
bool
_has_dictionary
;
int
ROOT_
;
bool
exists_
;
bool
has_dictionary_
;
std
::
vector
<
std
::
pair
<
int
,
PathTrie
*>>
_children
;
std
::
vector
<
std
::
pair
<
int
,
PathTrie
*>>
children_
;
// pointer to dictionary of FST
fst
::
StdVectorFst
*
_dictionary
;
fst
::
StdVectorFst
::
StateId
_dictionary_state
;
fst
::
StdVectorFst
*
dictionary_
;
fst
::
StdVectorFst
::
StateId
dictionary_state_
;
// true if finding ars in FST
std
::
shared_ptr
<
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>>
_matcher
;
std
::
shared_ptr
<
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>>
matcher_
;
};
#endif // PATH_TRIE_H
decoders/swig/scorer.cpp
浏览文件 @
3018dcb4
...
...
@@ -19,19 +19,19 @@ Scorer::Scorer(double alpha,
const
std
::
vector
<
std
::
string
>&
vocab_list
)
{
this
->
alpha
=
alpha
;
this
->
beta
=
beta
;
_is_character_based
=
true
;
_language_model
=
nullptr
;
is_character_based_
=
true
;
language_model_
=
nullptr
;
dictionary
=
nullptr
;
_max_order
=
0
;
_dict_size
=
0
;
_SPACE_ID
=
-
1
;
max_order_
=
0
;
dict_size_
=
0
;
SPACE_ID_
=
-
1
;
setup
(
lm_path
,
vocab_list
);
}
Scorer
::~
Scorer
()
{
if
(
_language_model
!=
nullptr
)
{
delete
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
);
if
(
language_model_
!=
nullptr
)
{
delete
static_cast
<
lm
::
base
::
Model
*>
(
language_model_
);
}
if
(
dictionary
!=
nullptr
)
{
delete
static_cast
<
fst
::
StdVectorFst
*>
(
dictionary
);
...
...
@@ -57,20 +57,20 @@ void Scorer::load_lm(const std::string& lm_path) {
RetriveStrEnumerateVocab
enumerate
;
lm
::
ngram
::
Config
config
;
config
.
enumerate_vocab
=
&
enumerate
;
_language_model
=
lm
::
ngram
::
LoadVirtual
(
filename
,
config
);
_max_order
=
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
)
->
Order
();
_vocabulary
=
enumerate
.
vocabulary
;
for
(
size_t
i
=
0
;
i
<
_vocabulary
.
size
();
++
i
)
{
if
(
_is_character_based
&&
_vocabulary
[
i
]
!=
UNK_TOKEN
&&
_vocabulary
[
i
]
!=
START_TOKEN
&&
_vocabulary
[
i
]
!=
END_TOKEN
&&
language_model_
=
lm
::
ngram
::
LoadVirtual
(
filename
,
config
);
max_order_
=
static_cast
<
lm
::
base
::
Model
*>
(
language_model_
)
->
Order
();
vocabulary_
=
enumerate
.
vocabulary
;
for
(
size_t
i
=
0
;
i
<
vocabulary_
.
size
();
++
i
)
{
if
(
is_character_based_
&&
vocabulary_
[
i
]
!=
UNK_TOKEN
&&
vocabulary_
[
i
]
!=
START_TOKEN
&&
vocabulary_
[
i
]
!=
END_TOKEN
&&
get_utf8_str_len
(
enumerate
.
vocabulary
[
i
])
>
1
)
{
_is_character_based
=
false
;
is_character_based_
=
false
;
}
}
}
double
Scorer
::
get_log_cond_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
lm
::
base
::
Model
*
model
=
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
);
lm
::
base
::
Model
*
model
=
static_cast
<
lm
::
base
::
Model
*>
(
language_model_
);
double
cond_prob
;
lm
::
ngram
::
State
state
,
tmp_state
,
out_state
;
// avoid to inserting <s> in begin
...
...
@@ -93,11 +93,11 @@ double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
double
Scorer
::
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
std
::
vector
<
std
::
string
>
sentence
;
if
(
words
.
size
()
==
0
)
{
for
(
size_t
i
=
0
;
i
<
_max_order
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
max_order_
;
++
i
)
{
sentence
.
push_back
(
START_TOKEN
);
}
}
else
{
for
(
size_t
i
=
0
;
i
<
_max_order
-
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
max_order_
-
1
;
++
i
)
{
sentence
.
push_back
(
START_TOKEN
);
}
sentence
.
insert
(
sentence
.
end
(),
words
.
begin
(),
words
.
end
());
...
...
@@ -107,11 +107,11 @@ double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
}
double
Scorer
::
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
assert
(
words
.
size
()
>
_max_order
);
assert
(
words
.
size
()
>
max_order_
);
double
score
=
0.0
;
for
(
size_t
i
=
0
;
i
<
words
.
size
()
-
_max_order
+
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
words
.
size
()
-
max_order_
+
1
;
++
i
)
{
std
::
vector
<
std
::
string
>
ngram
(
words
.
begin
()
+
i
,
words
.
begin
()
+
i
+
_max_order
);
words
.
begin
()
+
i
+
max_order_
);
score
+=
get_log_cond_prob
(
ngram
);
}
return
score
;
...
...
@@ -125,7 +125,7 @@ void Scorer::reset_params(float alpha, float beta) {
std
::
string
Scorer
::
vec2str
(
const
std
::
vector
<
int
>&
input
)
{
std
::
string
word
;
for
(
auto
ind
:
input
)
{
word
+=
_char_list
[
ind
];
word
+=
char_list_
[
ind
];
}
return
word
;
}
...
...
@@ -135,7 +135,7 @@ std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
std
::
string
s
=
vec2str
(
labels
);
std
::
vector
<
std
::
string
>
words
;
if
(
_is_character_based
)
{
if
(
is_character_based_
)
{
words
=
split_utf8_str
(
s
);
}
else
{
words
=
split_str
(
s
,
" "
);
...
...
@@ -144,15 +144,15 @@ std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
}
void
Scorer
::
set_char_map
(
const
std
::
vector
<
std
::
string
>&
char_list
)
{
_char_list
=
char_list
;
_char_map
.
clear
();
for
(
unsigned
int
i
=
0
;
i
<
_char_list
.
size
();
i
++
)
{
if
(
_char_list
[
i
]
==
" "
)
{
_SPACE_ID
=
i
;
_char_map
[
' '
]
=
i
;
}
else
if
(
_char_list
[
i
].
size
()
==
1
)
{
_char_map
[
_char_list
[
i
][
0
]]
=
i
;
char_list_
=
char_list
;
char_map_
.
clear
();
for
(
size_t
i
=
0
;
i
<
char_list_
.
size
();
i
++
)
{
if
(
char_list_
[
i
]
==
" "
)
{
SPACE_ID_
=
i
;
char_map_
[
' '
]
=
i
;
}
else
if
(
char_list_
[
i
].
size
()
==
1
)
{
char_map_
[
char_list_
[
i
][
0
]]
=
i
;
}
}
}
...
...
@@ -162,14 +162,14 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
PathTrie
*
current_node
=
prefix
;
PathTrie
*
new_node
=
nullptr
;
for
(
int
order
=
0
;
order
<
_max_order
;
order
++
)
{
for
(
int
order
=
0
;
order
<
max_order_
;
order
++
)
{
std
::
vector
<
int
>
prefix_vec
;
if
(
_is_character_based
)
{
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
_SPACE_ID
,
1
);
if
(
is_character_based_
)
{
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
SPACE_ID_
,
1
);
current_node
=
new_node
;
}
else
{
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
_SPACE_ID
);
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
SPACE_ID_
);
current_node
=
new_node
->
parent
;
// Skipping spaces
}
...
...
@@ -179,7 +179,7 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
if
(
new_node
->
character
==
-
1
)
{
// No more spaces, but still need order
for
(
int
i
=
0
;
i
<
_max_order
-
order
-
1
;
i
++
)
{
for
(
int
i
=
0
;
i
<
max_order_
-
order
-
1
;
i
++
)
{
ngram
.
push_back
(
START_TOKEN
);
}
break
;
...
...
@@ -193,19 +193,19 @@ void Scorer::fill_dictionary(bool add_space) {
fst
::
StdVectorFst
dictionary
;
// First reverse char_list so ints can be accessed by chars
std
::
unordered_map
<
std
::
string
,
int
>
char_map
;
for
(
unsigned
int
i
=
0
;
i
<
_char_list
.
size
();
i
++
)
{
char_map
[
_char_list
[
i
]]
=
i
;
for
(
size_t
i
=
0
;
i
<
char_list_
.
size
();
i
++
)
{
char_map
[
char_list_
[
i
]]
=
i
;
}
// For each unigram convert to ints and put in trie
int
dict_size
=
0
;
for
(
const
auto
&
word
:
_vocabulary
)
{
for
(
const
auto
&
word
:
vocabulary_
)
{
bool
added
=
add_word_to_dictionary
(
word
,
char_map
,
add_space
,
_SPACE_ID
,
&
dictionary
);
word
,
char_map
,
add_space
,
SPACE_ID_
,
&
dictionary
);
dict_size
+=
added
?
1
:
0
;
}
_dict_size
=
dict_size
;
dict_size_
=
dict_size
;
/* Simplify FST
...
...
decoders/swig/scorer.h
浏览文件 @
3018dcb4
...
...
@@ -18,7 +18,7 @@ const std::string START_TOKEN = "<s>";
const
std
::
string
UNK_TOKEN
=
"<unk>"
;
const
std
::
string
END_TOKEN
=
"</s>"
;
// Implement a callback to retrive
string vocabulary
.
// Implement a callback to retrive
the dictionary of language model
.
class
RetriveStrEnumerateVocab
:
public
lm
::
EnumerateVocab
{
public:
RetriveStrEnumerateVocab
()
{}
...
...
@@ -50,13 +50,14 @@ public:
double
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>
&
words
);
size_t
get_max_order
()
const
{
return
_max_order
;
}
// return the max order
size_t
get_max_order
()
const
{
return
max_order_
;
}
size_t
get_dict_size
()
const
{
return
_dict_size
;
}
// return the dictionary size of language model
size_t
get_dict_size
()
const
{
return
dict_size_
;
}
bool
is_char_map_empty
()
const
{
return
_char_map
.
size
()
==
0
;
}
bool
is_character_based
()
const
{
return
_is_character_based
;
}
// retrun true if the language model is character based
bool
is_character_based
()
const
{
return
is_character_based_
;
}
// reset params alpha & beta
void
reset_params
(
float
alpha
,
float
beta
);
...
...
@@ -68,20 +69,23 @@ public:
// the vector of characters (character based lm)
std
::
vector
<
std
::
string
>
split_labels
(
const
std
::
vector
<
int
>
&
labels
);
//
expose to decoder
//
language model weight
double
alpha
;
// word insertion weight
double
beta
;
//
fst dictionary
//
pointer to the dictionary of FST
void
*
dictionary
;
protected:
// necessary setup: load language model, set char map, fill FST's dictionary
void
setup
(
const
std
::
string
&
lm_path
,
const
std
::
vector
<
std
::
string
>
&
vocab_list
);
// load language model from given path
void
load_lm
(
const
std
::
string
&
lm_path
);
// fill dictionary for
fst
// fill dictionary for
FST
void
fill_dictionary
(
bool
add_space
);
// set char map
...
...
@@ -89,19 +93,20 @@ protected:
double
get_log_prob
(
const
std
::
vector
<
std
::
string
>
&
words
);
// translate the vector in index to string
std
::
string
vec2str
(
const
std
::
vector
<
int
>
&
input
);
private:
void
*
_language_model
;
bool
_is_character_based
;
size_t
_max_order
;
size_t
_dict_size
;
void
*
language_model_
;
bool
is_character_based_
;
size_t
max_order_
;
size_t
dict_size_
;
int
_SPACE_ID
;
std
::
vector
<
std
::
string
>
_char_list
;
std
::
unordered_map
<
char
,
int
>
_char_map
;
int
SPACE_ID_
;
std
::
vector
<
std
::
string
>
char_list_
;
std
::
unordered_map
<
char
,
int
>
char_map_
;
std
::
vector
<
std
::
string
>
_vocabulary
;
std
::
vector
<
std
::
string
>
vocabulary_
;
};
#endif // SCORER_H_
decoders/swig_wrapper.py
浏览文件 @
3018dcb4
...
...
@@ -39,8 +39,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
def
ctc_beam_search_decoder
(
probs_seq
,
beam_size
,
vocabulary
,
beam_size
,
cutoff_prob
=
1.0
,
cutoff_top_n
=
40
,
ext_scoring_func
=
None
):
...
...
@@ -50,10 +50,10 @@ def ctc_beam_search_decoder(probs_seq,
step, with each element being a list of normalized
probabilities over vocabulary and blank.
:type probs_seq: 2-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param beam_size: Width for beam search.
:type beam_size: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
...
...
@@ -69,14 +69,14 @@ def ctc_beam_search_decoder(probs_seq,
results, in descending order of the probability.
:rtype: list
"""
return
swig_decoders
.
ctc_beam_search_decoder
(
probs_seq
.
tolist
(),
beam_size
,
vocabulary
,
cutoff_prob
,
return
swig_decoders
.
ctc_beam_search_decoder
(
probs_seq
.
tolist
(),
vocabulary
,
beam_size
,
cutoff_prob
,
cutoff_top_n
,
ext_scoring_func
)
def
ctc_beam_search_decoder_batch
(
probs_split
,
beam_size
,
vocabulary
,
beam_size
,
num_processes
,
cutoff_prob
=
1.0
,
cutoff_top_n
=
40
,
...
...
@@ -86,10 +86,10 @@ def ctc_beam_search_decoder_batch(probs_split,
:param probs_seq: 3-D list with each element as an instance of 2-D list
of probabilities used by ctc_beam_search_decoder().
:type probs_seq: 3-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param beam_size: Width for beam search.
:type beam_size: int
:param num_processes: Number of parallel processes.
:type num_processes: int
:param cutoff_prob: Cutoff probability in vocabulary pruning,
...
...
@@ -112,5 +112,5 @@ def ctc_beam_search_decoder_batch(probs_split,
probs_split
=
[
probs_seq
.
tolist
()
for
probs_seq
in
probs_split
]
return
swig_decoders
.
ctc_beam_search_decoder_batch
(
probs_split
,
beam_size
,
vocabulary
,
num_processes
,
cutoff_prob
,
probs_split
,
vocabulary
,
beam_size
,
num_processes
,
cutoff_prob
,
cutoff_top_n
,
ext_scoring_func
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录