Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
f842c79a
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
1 年多 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f842c79a
编写于
5月 12, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
format code
上级
e969a8ec
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
720 addition
and
708 deletion
+720
-708
deepspeech/decoders/swig/ctc_beam_search_decoder.cpp
deepspeech/decoders/swig/ctc_beam_search_decoder.cpp
+184
-176
deepspeech/decoders/swig/ctc_greedy_decoder.cpp
deepspeech/decoders/swig/ctc_greedy_decoder.cpp
+33
-33
deepspeech/decoders/swig/decoder_utils.cpp
deepspeech/decoders/swig/decoder_utils.cpp
+124
-122
deepspeech/decoders/swig/decoder_utils.h
deepspeech/decoders/swig/decoder_utils.h
+12
-12
deepspeech/decoders/swig/path_trie.cpp
deepspeech/decoders/swig/path_trie.cpp
+102
-101
deepspeech/decoders/swig/path_trie.h
deepspeech/decoders/swig/path_trie.h
+38
-37
deepspeech/decoders/swig/scorer.cpp
deepspeech/decoders/swig/scorer.cpp
+165
-165
deepspeech/decoders/swig/scorer.h
deepspeech/decoders/swig/scorer.h
+55
-55
third_party/pymmseg-cpp/bin/pymmseg
third_party/pymmseg-cpp/bin/pymmseg
+5
-5
third_party/python-pinyin/pinyin-data/CHANGELOG.md
third_party/python-pinyin/pinyin-data/CHANGELOG.md
+1
-1
third_party/python-pinyin/pinyin-data/README.md
third_party/python-pinyin/pinyin-data/README.md
+1
-1
未找到文件。
deepspeech/decoders/swig/ctc_beam_search_decoder.cpp
浏览文件 @
f842c79a
...
@@ -36,169 +36,177 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
...
@@ -36,169 +36,177 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
double
cutoff_prob
,
double
cutoff_prob
,
size_t
cutoff_top_n
,
size_t
cutoff_top_n
,
Scorer
*
ext_scorer
)
{
Scorer
*
ext_scorer
)
{
// dimension check
// dimension check
size_t
num_time_steps
=
probs_seq
.
size
();
size_t
num_time_steps
=
probs_seq
.
size
();
for
(
size_t
i
=
0
;
i
<
num_time_steps
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
num_time_steps
;
++
i
)
{
VALID_CHECK_EQ
(
probs_seq
[
i
].
size
(),
VALID_CHECK_EQ
(
probs_seq
[
i
].
size
(),
// vocabulary.size() + 1,
// vocabulary.size() + 1,
vocabulary
.
size
(),
vocabulary
.
size
(),
"The shape of probs_seq does not match with "
"The shape of probs_seq does not match with "
"the shape of the vocabulary"
);
"the shape of the vocabulary"
);
}
// assign blank id
//size_t blank_id = vocabulary.size();
size_t
blank_id
=
0
;
// assign space id
auto
it
=
std
::
find
(
vocabulary
.
begin
(),
vocabulary
.
end
(),
" "
);
int
space_id
=
it
-
vocabulary
.
begin
();
// if no space in vocabulary
if
((
size_t
)
space_id
>=
vocabulary
.
size
())
{
space_id
=
-
2
;
}
// init prefixes' root
PathTrie
root
;
root
.
score
=
root
.
log_prob_b_prev
=
0.0
;
std
::
vector
<
PathTrie
*>
prefixes
;
prefixes
.
push_back
(
&
root
);
if
(
ext_scorer
!=
nullptr
&&
!
ext_scorer
->
is_character_based
())
{
auto
fst_dict
=
static_cast
<
fst
::
StdVectorFst
*>
(
ext_scorer
->
dictionary
);
fst
::
StdVectorFst
*
dict_ptr
=
fst_dict
->
Copy
(
true
);
root
.
set_dictionary
(
dict_ptr
);
auto
matcher
=
std
::
make_shared
<
FSTMATCH
>
(
*
dict_ptr
,
fst
::
MATCH_INPUT
);
root
.
set_matcher
(
matcher
);
}
// prefix search over time
for
(
size_t
time_step
=
0
;
time_step
<
num_time_steps
;
++
time_step
)
{
auto
&
prob
=
probs_seq
[
time_step
];
float
min_cutoff
=
-
NUM_FLT_INF
;
bool
full_beam
=
false
;
if
(
ext_scorer
!=
nullptr
)
{
size_t
num_prefixes
=
std
::
min
(
prefixes
.
size
(),
beam_size
);
std
::
sort
(
prefixes
.
begin
(),
prefixes
.
begin
()
+
num_prefixes
,
prefix_compare
);
min_cutoff
=
prefixes
[
num_prefixes
-
1
]
->
score
+
std
::
log
(
prob
[
blank_id
])
-
std
::
max
(
0.0
,
ext_scorer
->
beta
);
full_beam
=
(
num_prefixes
==
beam_size
);
}
}
std
::
vector
<
std
::
pair
<
size_t
,
float
>>
log_prob_idx
=
// assign blank id
get_pruned_log_probs
(
prob
,
cutoff_prob
,
cutoff_top_n
);
// size_t blank_id = vocabulary.size();
// loop over chars
size_t
blank_id
=
0
;
for
(
size_t
index
=
0
;
index
<
log_prob_idx
.
size
();
index
++
)
{
auto
c
=
log_prob_idx
[
index
].
first
;
// assign space id
auto
log_prob_c
=
log_prob_idx
[
index
].
second
;
auto
it
=
std
::
find
(
vocabulary
.
begin
(),
vocabulary
.
end
(),
" "
);
int
space_id
=
it
-
vocabulary
.
begin
();
for
(
size_t
i
=
0
;
i
<
prefixes
.
size
()
&&
i
<
beam_size
;
++
i
)
{
// if no space in vocabulary
auto
prefix
=
prefixes
[
i
];
if
((
size_t
)
space_id
>=
vocabulary
.
size
())
{
if
(
full_beam
&&
log_prob_c
+
prefix
->
score
<
min_cutoff
)
{
space_id
=
-
2
;
break
;
}
}
// blank
// init prefixes' root
if
(
c
==
blank_id
)
{
PathTrie
root
;
prefix
->
log_prob_b_cur
=
root
.
score
=
root
.
log_prob_b_prev
=
0.0
;
log_sum_exp
(
prefix
->
log_prob_b_cur
,
log_prob_c
+
prefix
->
score
);
std
::
vector
<
PathTrie
*>
prefixes
;
continue
;
prefixes
.
push_back
(
&
root
);
if
(
ext_scorer
!=
nullptr
&&
!
ext_scorer
->
is_character_based
())
{
auto
fst_dict
=
static_cast
<
fst
::
StdVectorFst
*>
(
ext_scorer
->
dictionary
);
fst
::
StdVectorFst
*
dict_ptr
=
fst_dict
->
Copy
(
true
);
root
.
set_dictionary
(
dict_ptr
);
auto
matcher
=
std
::
make_shared
<
FSTMATCH
>
(
*
dict_ptr
,
fst
::
MATCH_INPUT
);
root
.
set_matcher
(
matcher
);
}
// prefix search over time
for
(
size_t
time_step
=
0
;
time_step
<
num_time_steps
;
++
time_step
)
{
auto
&
prob
=
probs_seq
[
time_step
];
float
min_cutoff
=
-
NUM_FLT_INF
;
bool
full_beam
=
false
;
if
(
ext_scorer
!=
nullptr
)
{
size_t
num_prefixes
=
std
::
min
(
prefixes
.
size
(),
beam_size
);
std
::
sort
(
prefixes
.
begin
(),
prefixes
.
begin
()
+
num_prefixes
,
prefix_compare
);
min_cutoff
=
prefixes
[
num_prefixes
-
1
]
->
score
+
std
::
log
(
prob
[
blank_id
])
-
std
::
max
(
0.0
,
ext_scorer
->
beta
);
full_beam
=
(
num_prefixes
==
beam_size
);
}
}
// repeated character
if
(
c
==
prefix
->
character
)
{
std
::
vector
<
std
::
pair
<
size_t
,
float
>>
log_prob_idx
=
prefix
->
log_prob_nb_cur
=
log_sum_exp
(
get_pruned_log_probs
(
prob
,
cutoff_prob
,
cutoff_top_n
);
prefix
->
log_prob_nb_cur
,
log_prob_c
+
prefix
->
log_prob_nb_prev
);
// loop over chars
for
(
size_t
index
=
0
;
index
<
log_prob_idx
.
size
();
index
++
)
{
auto
c
=
log_prob_idx
[
index
].
first
;
auto
log_prob_c
=
log_prob_idx
[
index
].
second
;
for
(
size_t
i
=
0
;
i
<
prefixes
.
size
()
&&
i
<
beam_size
;
++
i
)
{
auto
prefix
=
prefixes
[
i
];
if
(
full_beam
&&
log_prob_c
+
prefix
->
score
<
min_cutoff
)
{
break
;
}
// blank
if
(
c
==
blank_id
)
{
prefix
->
log_prob_b_cur
=
log_sum_exp
(
prefix
->
log_prob_b_cur
,
log_prob_c
+
prefix
->
score
);
continue
;
}
// repeated character
if
(
c
==
prefix
->
character
)
{
prefix
->
log_prob_nb_cur
=
log_sum_exp
(
prefix
->
log_prob_nb_cur
,
log_prob_c
+
prefix
->
log_prob_nb_prev
);
}
// get new prefix
auto
prefix_new
=
prefix
->
get_path_trie
(
c
);
if
(
prefix_new
!=
nullptr
)
{
float
log_p
=
-
NUM_FLT_INF
;
if
(
c
==
prefix
->
character
&&
prefix
->
log_prob_b_prev
>
-
NUM_FLT_INF
)
{
log_p
=
log_prob_c
+
prefix
->
log_prob_b_prev
;
}
else
if
(
c
!=
prefix
->
character
)
{
log_p
=
log_prob_c
+
prefix
->
score
;
}
// language model scoring
if
(
ext_scorer
!=
nullptr
&&
(
c
==
space_id
||
ext_scorer
->
is_character_based
()))
{
PathTrie
*
prefix_to_score
=
nullptr
;
// skip scoring the space
if
(
ext_scorer
->
is_character_based
())
{
prefix_to_score
=
prefix_new
;
}
else
{
prefix_to_score
=
prefix
;
}
float
score
=
0.0
;
std
::
vector
<
std
::
string
>
ngram
;
ngram
=
ext_scorer
->
make_ngram
(
prefix_to_score
);
score
=
ext_scorer
->
get_log_cond_prob
(
ngram
)
*
ext_scorer
->
alpha
;
log_p
+=
score
;
log_p
+=
ext_scorer
->
beta
;
}
prefix_new
->
log_prob_nb_cur
=
log_sum_exp
(
prefix_new
->
log_prob_nb_cur
,
log_p
);
}
}
// end of loop over prefix
}
// end of loop over vocabulary
prefixes
.
clear
();
// update log probs
root
.
iterate_to_vec
(
prefixes
);
// only preserve top beam_size prefixes
if
(
prefixes
.
size
()
>=
beam_size
)
{
std
::
nth_element
(
prefixes
.
begin
(),
prefixes
.
begin
()
+
beam_size
,
prefixes
.
end
(),
prefix_compare
);
for
(
size_t
i
=
beam_size
;
i
<
prefixes
.
size
();
++
i
)
{
prefixes
[
i
]
->
remove
();
}
}
}
// get new prefix
}
// end of loop over time
auto
prefix_new
=
prefix
->
get_path_trie
(
c
);
// score the last word of each prefix that doesn't end with space
if
(
prefix_new
!=
nullptr
)
{
if
(
ext_scorer
!=
nullptr
&&
!
ext_scorer
->
is_character_based
())
{
float
log_p
=
-
NUM_FLT_INF
;
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
prefixes
.
size
();
++
i
)
{
auto
prefix
=
prefixes
[
i
];
if
(
c
==
prefix
->
character
&&
if
(
!
prefix
->
is_empty
()
&&
prefix
->
character
!=
space_id
)
{
prefix
->
log_prob_b_prev
>
-
NUM_FLT_INF
)
{
float
score
=
0.0
;
log_p
=
log_prob_c
+
prefix
->
log_prob_b_prev
;
std
::
vector
<
std
::
string
>
ngram
=
ext_scorer
->
make_ngram
(
prefix
);
}
else
if
(
c
!=
prefix
->
character
)
{
score
=
log_p
=
log_prob_c
+
prefix
->
score
;
ext_scorer
->
get_log_cond_prob
(
ngram
)
*
ext_scorer
->
alpha
;
}
score
+=
ext_scorer
->
beta
;
prefix
->
score
+=
score
;
// language model scoring
if
(
ext_scorer
!=
nullptr
&&
(
c
==
space_id
||
ext_scorer
->
is_character_based
()))
{
PathTrie
*
prefix_to_score
=
nullptr
;
// skip scoring the space
if
(
ext_scorer
->
is_character_based
())
{
prefix_to_score
=
prefix_new
;
}
else
{
prefix_to_score
=
prefix
;
}
}
float
score
=
0.0
;
std
::
vector
<
std
::
string
>
ngram
;
ngram
=
ext_scorer
->
make_ngram
(
prefix_to_score
);
score
=
ext_scorer
->
get_log_cond_prob
(
ngram
)
*
ext_scorer
->
alpha
;
log_p
+=
score
;
log_p
+=
ext_scorer
->
beta
;
}
prefix_new
->
log_prob_nb_cur
=
log_sum_exp
(
prefix_new
->
log_prob_nb_cur
,
log_p
);
}
}
}
// end of loop over prefix
}
// end of loop over vocabulary
prefixes
.
clear
();
// update log probs
root
.
iterate_to_vec
(
prefixes
);
// only preserve top beam_size prefixes
if
(
prefixes
.
size
()
>=
beam_size
)
{
std
::
nth_element
(
prefixes
.
begin
(),
prefixes
.
begin
()
+
beam_size
,
prefixes
.
end
(),
prefix_compare
);
for
(
size_t
i
=
beam_size
;
i
<
prefixes
.
size
();
++
i
)
{
prefixes
[
i
]
->
remove
();
}
}
}
}
// end of loop over time
// score the last word of each prefix that doesn't end with space
size_t
num_prefixes
=
std
::
min
(
prefixes
.
size
(),
beam_size
);
if
(
ext_scorer
!=
nullptr
&&
!
ext_scorer
->
is_character_based
())
{
std
::
sort
(
prefixes
.
begin
(),
prefixes
.
begin
()
+
num_prefixes
,
prefix_compare
);
// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
prefixes
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
prefixes
.
size
();
++
i
)
{
auto
prefix
=
prefixes
[
i
];
double
approx_ctc
=
prefixes
[
i
]
->
score
;
if
(
!
prefix
->
is_empty
()
&&
prefix
->
character
!=
space_id
)
{
if
(
ext_scorer
!=
nullptr
)
{
float
score
=
0.0
;
std
::
vector
<
int
>
output
;
std
::
vector
<
std
::
string
>
ngram
=
ext_scorer
->
make_ngram
(
prefix
);
prefixes
[
i
]
->
get_path_vec
(
output
);
score
=
ext_scorer
->
get_log_cond_prob
(
ngram
)
*
ext_scorer
->
alpha
;
auto
prefix_length
=
output
.
size
();
score
+=
ext_scorer
->
beta
;
auto
words
=
ext_scorer
->
split_labels
(
output
);
prefix
->
score
+=
score
;
// remove word insert
}
approx_ctc
=
approx_ctc
-
prefix_length
*
ext_scorer
->
beta
;
}
// remove language model weight:
}
approx_ctc
-=
(
ext_scorer
->
get_sent_log_prob
(
words
))
*
ext_scorer
->
alpha
;
size_t
num_prefixes
=
std
::
min
(
prefixes
.
size
(),
beam_size
);
}
std
::
sort
(
prefixes
.
begin
(),
prefixes
.
begin
()
+
num_prefixes
,
prefix_compare
);
prefixes
[
i
]
->
approx_ctc
=
approx_ctc
;
// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
prefixes
.
size
();
++
i
)
{
double
approx_ctc
=
prefixes
[
i
]
->
score
;
if
(
ext_scorer
!=
nullptr
)
{
std
::
vector
<
int
>
output
;
prefixes
[
i
]
->
get_path_vec
(
output
);
auto
prefix_length
=
output
.
size
();
auto
words
=
ext_scorer
->
split_labels
(
output
);
// remove word insert
approx_ctc
=
approx_ctc
-
prefix_length
*
ext_scorer
->
beta
;
// remove language model weight:
approx_ctc
-=
(
ext_scorer
->
get_sent_log_prob
(
words
))
*
ext_scorer
->
alpha
;
}
}
prefixes
[
i
]
->
approx_ctc
=
approx_ctc
;
}
return
get_beam_search_result
(
prefixes
,
vocabulary
,
beam_size
);
return
get_beam_search_result
(
prefixes
,
vocabulary
,
beam_size
);
}
}
...
@@ -211,28 +219,28 @@ ctc_beam_search_decoder_batch(
...
@@ -211,28 +219,28 @@ ctc_beam_search_decoder_batch(
double
cutoff_prob
,
double
cutoff_prob
,
size_t
cutoff_top_n
,
size_t
cutoff_top_n
,
Scorer
*
ext_scorer
)
{
Scorer
*
ext_scorer
)
{
VALID_CHECK_GT
(
num_processes
,
0
,
"num_processes must be nonnegative!"
);
VALID_CHECK_GT
(
num_processes
,
0
,
"num_processes must be nonnegative!"
);
// thread pool
// thread pool
ThreadPool
pool
(
num_processes
);
ThreadPool
pool
(
num_processes
);
// number of samples
// number of samples
size_t
batch_size
=
probs_split
.
size
();
size_t
batch_size
=
probs_split
.
size
();
// enqueue the tasks of decoding
// enqueue the tasks of decoding
std
::
vector
<
std
::
future
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>>
res
;
std
::
vector
<
std
::
future
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>>
res
;
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
res
.
emplace_back
(
pool
.
enqueue
(
ctc_beam_search_decoder
,
res
.
emplace_back
(
pool
.
enqueue
(
ctc_beam_search_decoder
,
probs_split
[
i
],
probs_split
[
i
],
vocabulary
,
vocabulary
,
beam_size
,
beam_size
,
cutoff_prob
,
cutoff_prob
,
cutoff_top_n
,
cutoff_top_n
,
ext_scorer
));
ext_scorer
));
}
}
// get decoding results
// get decoding results
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
batch_results
;
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
batch_results
;
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
batch_results
.
emplace_back
(
res
[
i
].
get
());
batch_results
.
emplace_back
(
res
[
i
].
get
());
}
}
return
batch_results
;
return
batch_results
;
}
}
deepspeech/decoders/swig/ctc_greedy_decoder.cpp
浏览文件 @
f842c79a
...
@@ -18,42 +18,42 @@
...
@@ -18,42 +18,42 @@
std
::
string
ctc_greedy_decoder
(
std
::
string
ctc_greedy_decoder
(
const
std
::
vector
<
std
::
vector
<
double
>>
&
probs_seq
,
const
std
::
vector
<
std
::
vector
<
double
>>
&
probs_seq
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
)
{
const
std
::
vector
<
std
::
string
>
&
vocabulary
)
{
// dimension check
// dimension check
size_t
num_time_steps
=
probs_seq
.
size
();
size_t
num_time_steps
=
probs_seq
.
size
();
for
(
size_t
i
=
0
;
i
<
num_time_steps
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
num_time_steps
;
++
i
)
{
VALID_CHECK_EQ
(
probs_seq
[
i
].
size
(),
VALID_CHECK_EQ
(
probs_seq
[
i
].
size
(),
vocabulary
.
size
()
+
1
,
vocabulary
.
size
()
+
1
,
"The shape of probs_seq does not match with "
"The shape of probs_seq does not match with "
"the shape of the vocabulary"
);
"the shape of the vocabulary"
);
}
}
size_t
blank_id
=
vocabulary
.
size
();
size_t
blank_id
=
vocabulary
.
size
();
std
::
vector
<
size_t
>
max_idx_vec
(
num_time_steps
,
0
);
std
::
vector
<
size_t
>
max_idx_vec
(
num_time_steps
,
0
);
std
::
vector
<
size_t
>
idx_vec
;
std
::
vector
<
size_t
>
idx_vec
;
for
(
size_t
i
=
0
;
i
<
num_time_steps
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
num_time_steps
;
++
i
)
{
double
max_prob
=
0.0
;
double
max_prob
=
0.0
;
size_t
max_idx
=
0
;
size_t
max_idx
=
0
;
const
std
::
vector
<
double
>
&
probs_step
=
probs_seq
[
i
];
const
std
::
vector
<
double
>
&
probs_step
=
probs_seq
[
i
];
for
(
size_t
j
=
0
;
j
<
probs_step
.
size
();
++
j
)
{
for
(
size_t
j
=
0
;
j
<
probs_step
.
size
();
++
j
)
{
if
(
max_prob
<
probs_step
[
j
])
{
if
(
max_prob
<
probs_step
[
j
])
{
max_idx
=
j
;
max_idx
=
j
;
max_prob
=
probs_step
[
j
];
max_prob
=
probs_step
[
j
];
}
}
}
}
// id with maximum probability in current time step
// id with maximum probability in current time step
max_idx_vec
[
i
]
=
max_idx
;
max_idx_vec
[
i
]
=
max_idx
;
// deduplicate
// deduplicate
if
((
i
==
0
)
||
((
i
>
0
)
&&
max_idx_vec
[
i
]
!=
max_idx_vec
[
i
-
1
]))
{
if
((
i
==
0
)
||
((
i
>
0
)
&&
max_idx_vec
[
i
]
!=
max_idx_vec
[
i
-
1
]))
{
idx_vec
.
push_back
(
max_idx_vec
[
i
]);
idx_vec
.
push_back
(
max_idx_vec
[
i
]);
}
}
}
}
std
::
string
best_path_result
;
std
::
string
best_path_result
;
for
(
size_t
i
=
0
;
i
<
idx_vec
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
idx_vec
.
size
();
++
i
)
{
if
(
idx_vec
[
i
]
!=
blank_id
)
{
if
(
idx_vec
[
i
]
!=
blank_id
)
{
best_path_result
+=
vocabulary
[
idx_vec
[
i
]];
best_path_result
+=
vocabulary
[
idx_vec
[
i
]];
}
}
}
}
return
best_path_result
;
return
best_path_result
;
}
}
deepspeech/decoders/swig/decoder_utils.cpp
浏览文件 @
f842c79a
...
@@ -22,33 +22,35 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
...
@@ -22,33 +22,35 @@ std::vector<std::pair<size_t, float>> get_pruned_log_probs(
const
std
::
vector
<
double
>
&
prob_step
,
const
std
::
vector
<
double
>
&
prob_step
,
double
cutoff_prob
,
double
cutoff_prob
,
size_t
cutoff_top_n
)
{
size_t
cutoff_top_n
)
{
std
::
vector
<
std
::
pair
<
int
,
double
>>
prob_idx
;
std
::
vector
<
std
::
pair
<
int
,
double
>>
prob_idx
;
for
(
size_t
i
=
0
;
i
<
prob_step
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
prob_step
.
size
();
++
i
)
{
prob_idx
.
push_back
(
std
::
pair
<
int
,
double
>
(
i
,
prob_step
[
i
]));
prob_idx
.
push_back
(
std
::
pair
<
int
,
double
>
(
i
,
prob_step
[
i
]));
}
// pruning of vacobulary
size_t
cutoff_len
=
prob_step
.
size
();
if
(
cutoff_prob
<
1.0
||
cutoff_top_n
<
cutoff_len
)
{
std
::
sort
(
prob_idx
.
begin
(),
prob_idx
.
end
(),
pair_comp_second_rev
<
int
,
double
>
);
if
(
cutoff_prob
<
1.0
)
{
double
cum_prob
=
0.0
;
cutoff_len
=
0
;
for
(
size_t
i
=
0
;
i
<
prob_idx
.
size
();
++
i
)
{
cum_prob
+=
prob_idx
[
i
].
second
;
cutoff_len
+=
1
;
if
(
cum_prob
>=
cutoff_prob
||
cutoff_len
>=
cutoff_top_n
)
break
;
}
}
}
prob_idx
=
std
::
vector
<
std
::
pair
<
int
,
double
>>
(
// pruning of vacobulary
prob_idx
.
begin
(),
prob_idx
.
begin
()
+
cutoff_len
);
size_t
cutoff_len
=
prob_step
.
size
();
}
if
(
cutoff_prob
<
1.0
||
cutoff_top_n
<
cutoff_len
)
{
std
::
vector
<
std
::
pair
<
size_t
,
float
>>
log_prob_idx
;
std
::
sort
(
prob_idx
.
begin
(),
for
(
size_t
i
=
0
;
i
<
cutoff_len
;
++
i
)
{
prob_idx
.
end
(),
log_prob_idx
.
push_back
(
std
::
pair
<
int
,
float
>
(
pair_comp_second_rev
<
int
,
double
>
);
prob_idx
[
i
].
first
,
log
(
prob_idx
[
i
].
second
+
NUM_FLT_MIN
)));
if
(
cutoff_prob
<
1.0
)
{
}
double
cum_prob
=
0.0
;
return
log_prob_idx
;
cutoff_len
=
0
;
for
(
size_t
i
=
0
;
i
<
prob_idx
.
size
();
++
i
)
{
cum_prob
+=
prob_idx
[
i
].
second
;
cutoff_len
+=
1
;
if
(
cum_prob
>=
cutoff_prob
||
cutoff_len
>=
cutoff_top_n
)
break
;
}
}
prob_idx
=
std
::
vector
<
std
::
pair
<
int
,
double
>>
(
prob_idx
.
begin
(),
prob_idx
.
begin
()
+
cutoff_len
);
}
std
::
vector
<
std
::
pair
<
size_t
,
float
>>
log_prob_idx
;
for
(
size_t
i
=
0
;
i
<
cutoff_len
;
++
i
)
{
log_prob_idx
.
push_back
(
std
::
pair
<
int
,
float
>
(
prob_idx
[
i
].
first
,
log
(
prob_idx
[
i
].
second
+
NUM_FLT_MIN
)));
}
return
log_prob_idx
;
}
}
...
@@ -56,106 +58,106 @@ std::vector<std::pair<double, std::string>> get_beam_search_result(
...
@@ -56,106 +58,106 @@ std::vector<std::pair<double, std::string>> get_beam_search_result(
const
std
::
vector
<
PathTrie
*>
&
prefixes
,
const
std
::
vector
<
PathTrie
*>
&
prefixes
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
size_t
beam_size
)
{
size_t
beam_size
)
{
// allow for the post processing
// allow for the post processing
std
::
vector
<
PathTrie
*>
space_prefixes
;
std
::
vector
<
PathTrie
*>
space_prefixes
;
if
(
space_prefixes
.
empty
())
{
if
(
space_prefixes
.
empty
())
{
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
prefixes
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
prefixes
.
size
();
++
i
)
{
space_prefixes
.
push_back
(
prefixes
[
i
]);
space_prefixes
.
push_back
(
prefixes
[
i
]);
}
}
}
}
std
::
sort
(
space_prefixes
.
begin
(),
space_prefixes
.
end
(),
prefix_compare
);
std
::
sort
(
space_prefixes
.
begin
(),
space_prefixes
.
end
(),
prefix_compare
);
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
output_vecs
;
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
output_vecs
;
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
space_prefixes
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
space_prefixes
.
size
();
++
i
)
{
std
::
vector
<
int
>
output
;
std
::
vector
<
int
>
output
;
space_prefixes
[
i
]
->
get_path_vec
(
output
);
space_prefixes
[
i
]
->
get_path_vec
(
output
);
// convert index to string
// convert index to string
std
::
string
output_str
;
std
::
string
output_str
;
for
(
size_t
j
=
0
;
j
<
output
.
size
();
j
++
)
{
for
(
size_t
j
=
0
;
j
<
output
.
size
();
j
++
)
{
output_str
+=
vocabulary
[
output
[
j
]];
output_str
+=
vocabulary
[
output
[
j
]];
}
std
::
pair
<
double
,
std
::
string
>
output_pair
(
-
space_prefixes
[
i
]
->
approx_ctc
,
output_str
);
output_vecs
.
emplace_back
(
output_pair
);
}
}
std
::
pair
<
double
,
std
::
string
>
output_pair
(
-
space_prefixes
[
i
]
->
approx_ctc
,
output_str
);
output_vecs
.
emplace_back
(
output_pair
);
}
return
output_vecs
;
return
output_vecs
;
}
}
size_t
get_utf8_str_len
(
const
std
::
string
&
str
)
{
size_t
get_utf8_str_len
(
const
std
::
string
&
str
)
{
size_t
str_len
=
0
;
size_t
str_len
=
0
;
for
(
char
c
:
str
)
{
for
(
char
c
:
str
)
{
str_len
+=
((
c
&
0xc0
)
!=
0x80
);
str_len
+=
((
c
&
0xc0
)
!=
0x80
);
}
}
return
str_len
;
return
str_len
;
}
}
std
::
vector
<
std
::
string
>
split_utf8_str
(
const
std
::
string
&
str
)
{
std
::
vector
<
std
::
string
>
split_utf8_str
(
const
std
::
string
&
str
)
{
std
::
vector
<
std
::
string
>
result
;
std
::
vector
<
std
::
string
>
result
;
std
::
string
out_str
;
std
::
string
out_str
;
for
(
char
c
:
str
)
{
for
(
char
c
:
str
)
{
if
((
c
&
0xc0
)
!=
0x80
)
// new UTF-8 character
if
((
c
&
0xc0
)
!=
0x80
)
// new UTF-8 character
{
{
if
(
!
out_str
.
empty
())
{
if
(
!
out_str
.
empty
())
{
result
.
push_back
(
out_str
);
result
.
push_back
(
out_str
);
out_str
.
clear
();
out_str
.
clear
();
}
}
}
out_str
.
append
(
1
,
c
);
}
}
result
.
push_back
(
out_str
);
out_str
.
append
(
1
,
c
);
return
result
;
}
result
.
push_back
(
out_str
);
return
result
;
}
}
std
::
vector
<
std
::
string
>
split_str
(
const
std
::
string
&
s
,
std
::
vector
<
std
::
string
>
split_str
(
const
std
::
string
&
s
,
const
std
::
string
&
delim
)
{
const
std
::
string
&
delim
)
{
std
::
vector
<
std
::
string
>
result
;
std
::
vector
<
std
::
string
>
result
;
std
::
size_t
start
=
0
,
delim_len
=
delim
.
size
();
std
::
size_t
start
=
0
,
delim_len
=
delim
.
size
();
while
(
true
)
{
while
(
true
)
{
std
::
size_t
end
=
s
.
find
(
delim
,
start
);
std
::
size_t
end
=
s
.
find
(
delim
,
start
);
if
(
end
==
std
::
string
::
npos
)
{
if
(
end
==
std
::
string
::
npos
)
{
if
(
start
<
s
.
size
())
{
if
(
start
<
s
.
size
())
{
result
.
push_back
(
s
.
substr
(
start
));
result
.
push_back
(
s
.
substr
(
start
));
}
}
break
;
break
;
}
}
if
(
end
>
start
)
{
if
(
end
>
start
)
{
result
.
push_back
(
s
.
substr
(
start
,
end
-
start
));
result
.
push_back
(
s
.
substr
(
start
,
end
-
start
));
}
start
=
end
+
delim_len
;
}
}
start
=
end
+
delim_len
;
return
result
;
}
return
result
;
}
}
bool
prefix_compare
(
const
PathTrie
*
x
,
const
PathTrie
*
y
)
{
bool
prefix_compare
(
const
PathTrie
*
x
,
const
PathTrie
*
y
)
{
if
(
x
->
score
==
y
->
score
)
{
if
(
x
->
score
==
y
->
score
)
{
if
(
x
->
character
==
y
->
character
)
{
if
(
x
->
character
==
y
->
character
)
{
return
false
;
return
false
;
}
else
{
return
(
x
->
character
<
y
->
character
);
}
}
else
{
}
else
{
return
(
x
->
character
<
y
->
character
)
;
return
x
->
score
>
y
->
score
;
}
}
}
else
{
return
x
->
score
>
y
->
score
;
}
}
}
void
add_word_to_fst
(
const
std
::
vector
<
int
>
&
word
,
void
add_word_to_fst
(
const
std
::
vector
<
int
>
&
word
,
fst
::
StdVectorFst
*
dictionary
)
{
fst
::
StdVectorFst
*
dictionary
)
{
if
(
dictionary
->
NumStates
()
==
0
)
{
if
(
dictionary
->
NumStates
()
==
0
)
{
fst
::
StdVectorFst
::
StateId
start
=
dictionary
->
AddState
();
fst
::
StdVectorFst
::
StateId
start
=
dictionary
->
AddState
();
assert
(
start
==
0
);
assert
(
start
==
0
);
dictionary
->
SetStart
(
start
);
dictionary
->
SetStart
(
start
);
}
}
fst
::
StdVectorFst
::
StateId
src
=
dictionary
->
Start
();
fst
::
StdVectorFst
::
StateId
src
=
dictionary
->
Start
();
fst
::
StdVectorFst
::
StateId
dst
;
fst
::
StdVectorFst
::
StateId
dst
;
for
(
auto
c
:
word
)
{
for
(
auto
c
:
word
)
{
dst
=
dictionary
->
AddState
();
dst
=
dictionary
->
AddState
();
dictionary
->
AddArc
(
src
,
fst
::
StdArc
(
c
,
c
,
0
,
dst
));
dictionary
->
AddArc
(
src
,
fst
::
StdArc
(
c
,
c
,
0
,
dst
));
src
=
dst
;
src
=
dst
;
}
}
dictionary
->
SetFinal
(
dst
,
fst
::
StdArc
::
Weight
::
One
());
dictionary
->
SetFinal
(
dst
,
fst
::
StdArc
::
Weight
::
One
());
}
}
bool
add_word_to_dictionary
(
bool
add_word_to_dictionary
(
...
@@ -164,27 +166,27 @@ bool add_word_to_dictionary(
...
@@ -164,27 +166,27 @@ bool add_word_to_dictionary(
bool
add_space
,
bool
add_space
,
int
SPACE_ID
,
int
SPACE_ID
,
fst
::
StdVectorFst
*
dictionary
)
{
fst
::
StdVectorFst
*
dictionary
)
{
auto
characters
=
split_utf8_str
(
word
);
auto
characters
=
split_utf8_str
(
word
);
std
::
vector
<
int
>
int_word
;
std
::
vector
<
int
>
int_word
;
for
(
auto
&
c
:
characters
)
{
for
(
auto
&
c
:
characters
)
{
if
(
c
==
" "
)
{
if
(
c
==
" "
)
{
int_word
.
push_back
(
SPACE_ID
);
int_word
.
push_back
(
SPACE_ID
);
}
else
{
}
else
{
auto
int_c
=
char_map
.
find
(
c
);
auto
int_c
=
char_map
.
find
(
c
);
if
(
int_c
!=
char_map
.
end
())
{
if
(
int_c
!=
char_map
.
end
())
{
int_word
.
push_back
(
int_c
->
second
);
int_word
.
push_back
(
int_c
->
second
);
}
else
{
}
else
{
return
false
;
// return without adding
return
false
;
// return without adding
}
}
}
}
}
}
if
(
add_space
)
{
if
(
add_space
)
{
int_word
.
push_back
(
SPACE_ID
);
int_word
.
push_back
(
SPACE_ID
);
}
}
add_word_to_fst
(
int_word
,
dictionary
);
add_word_to_fst
(
int_word
,
dictionary
);
return
true
;
// return with successful adding
return
true
;
// return with successful adding
}
}
deepspeech/decoders/swig/decoder_utils.h
浏览文件 @
f842c79a
...
@@ -25,14 +25,14 @@ const float NUM_FLT_MIN = std::numeric_limits<float>::min();
...
@@ -25,14 +25,14 @@ const float NUM_FLT_MIN = std::numeric_limits<float>::min();
// inline function for validation check
// inline function for validation check
inline
void
check
(
inline
void
check
(
bool
x
,
const
char
*
expr
,
const
char
*
file
,
int
line
,
const
char
*
err
)
{
bool
x
,
const
char
*
expr
,
const
char
*
file
,
int
line
,
const
char
*
err
)
{
if
(
!
x
)
{
if
(
!
x
)
{
std
::
cout
<<
"["
<<
file
<<
":"
<<
line
<<
"] "
;
std
::
cout
<<
"["
<<
file
<<
":"
<<
line
<<
"] "
;
LOG
(
FATAL
)
<<
"
\"
"
<<
expr
<<
"
\"
check failed. "
<<
err
;
LOG
(
FATAL
)
<<
"
\"
"
<<
expr
<<
"
\"
check failed. "
<<
err
;
}
}
}
}
#define VALID_CHECK(x, info) \
#define VALID_CHECK(x, info) \
check(static_cast<bool>(x), #x, __FILE__, __LINE__, info)
check(static_cast<bool>(x), #x, __FILE__, __LINE__, info)
#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info)
#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info)
#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info)
#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info)
#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info)
#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info)
...
@@ -42,24 +42,24 @@ inline void check(
...
@@ -42,24 +42,24 @@ inline void check(
template
<
typename
T1
,
typename
T2
>
template
<
typename
T1
,
typename
T2
>
bool
pair_comp_first_rev
(
const
std
::
pair
<
T1
,
T2
>
&
a
,
bool
pair_comp_first_rev
(
const
std
::
pair
<
T1
,
T2
>
&
a
,
const
std
::
pair
<
T1
,
T2
>
&
b
)
{
const
std
::
pair
<
T1
,
T2
>
&
b
)
{
return
a
.
first
>
b
.
first
;
return
a
.
first
>
b
.
first
;
}
}
// Function template for comparing two pairs
// Function template for comparing two pairs
template
<
typename
T1
,
typename
T2
>
template
<
typename
T1
,
typename
T2
>
bool
pair_comp_second_rev
(
const
std
::
pair
<
T1
,
T2
>
&
a
,
bool
pair_comp_second_rev
(
const
std
::
pair
<
T1
,
T2
>
&
a
,
const
std
::
pair
<
T1
,
T2
>
&
b
)
{
const
std
::
pair
<
T1
,
T2
>
&
b
)
{
return
a
.
second
>
b
.
second
;
return
a
.
second
>
b
.
second
;
}
}
// Return the sum of two probabilities in log scale
// Return the sum of two probabilities in log scale
template
<
typename
T
>
template
<
typename
T
>
T
log_sum_exp
(
const
T
&
x
,
const
T
&
y
)
{
T
log_sum_exp
(
const
T
&
x
,
const
T
&
y
)
{
static
T
num_min
=
-
std
::
numeric_limits
<
T
>::
max
();
static
T
num_min
=
-
std
::
numeric_limits
<
T
>::
max
();
if
(
x
<=
num_min
)
return
y
;
if
(
x
<=
num_min
)
return
y
;
if
(
y
<=
num_min
)
return
x
;
if
(
y
<=
num_min
)
return
x
;
T
xmax
=
std
::
max
(
x
,
y
);
T
xmax
=
std
::
max
(
x
,
y
);
return
std
::
log
(
std
::
exp
(
x
-
xmax
)
+
std
::
exp
(
y
-
xmax
))
+
xmax
;
return
std
::
log
(
std
::
exp
(
x
-
xmax
)
+
std
::
exp
(
y
-
xmax
))
+
xmax
;
}
}
// Get pruned probability vector for each time step's beam search
// Get pruned probability vector for each time step's beam search
...
...
deepspeech/decoders/swig/path_trie.cpp
浏览文件 @
f842c79a
...
@@ -23,140 +23,141 @@
...
@@ -23,140 +23,141 @@
#include "decoder_utils.h"
#include "decoder_utils.h"
PathTrie
::
PathTrie
()
{
PathTrie
::
PathTrie
()
{
log_prob_b_prev
=
-
NUM_FLT_INF
;
log_prob_b_prev
=
-
NUM_FLT_INF
;
log_prob_nb_prev
=
-
NUM_FLT_INF
;
log_prob_nb_prev
=
-
NUM_FLT_INF
;
log_prob_b_cur
=
-
NUM_FLT_INF
;
log_prob_b_cur
=
-
NUM_FLT_INF
;
log_prob_nb_cur
=
-
NUM_FLT_INF
;
log_prob_nb_cur
=
-
NUM_FLT_INF
;
score
=
-
NUM_FLT_INF
;
score
=
-
NUM_FLT_INF
;
ROOT_
=
-
1
;
ROOT_
=
-
1
;
character
=
ROOT_
;
character
=
ROOT_
;
exists_
=
true
;
exists_
=
true
;
parent
=
nullptr
;
parent
=
nullptr
;
dictionary_
=
nullptr
;
dictionary_
=
nullptr
;
dictionary_state_
=
0
;
dictionary_state_
=
0
;
has_dictionary_
=
false
;
has_dictionary_
=
false
;
matcher_
=
nullptr
;
matcher_
=
nullptr
;
}
}
PathTrie
::~
PathTrie
()
{
PathTrie
::~
PathTrie
()
{
for
(
auto
child
:
children_
)
{
for
(
auto
child
:
children_
)
{
delete
child
.
second
;
delete
child
.
second
;
}
}
}
}
PathTrie
*
PathTrie
::
get_path_trie
(
int
new_char
,
bool
reset
)
{
PathTrie
*
PathTrie
::
get_path_trie
(
int
new_char
,
bool
reset
)
{
auto
child
=
children_
.
begin
();
auto
child
=
children_
.
begin
();
for
(
child
=
children_
.
begin
();
child
!=
children_
.
end
();
++
child
)
{
for
(
child
=
children_
.
begin
();
child
!=
children_
.
end
();
++
child
)
{
if
(
child
->
first
==
new_char
)
{
if
(
child
->
first
==
new_char
)
{
break
;
break
;
}
}
}
if
(
child
!=
children_
.
end
())
{
if
(
!
child
->
second
->
exists_
)
{
child
->
second
->
exists_
=
true
;
child
->
second
->
log_prob_b_prev
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_nb_prev
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_b_cur
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_nb_cur
=
-
NUM_FLT_INF
;
}
}
return
(
child
->
second
);
if
(
child
!=
children_
.
end
())
{
}
else
{
if
(
!
child
->
second
->
exists_
)
{
if
(
has_dictionary_
)
{
child
->
second
->
exists_
=
true
;
matcher_
->
SetState
(
dictionary_state_
);
child
->
second
->
log_prob_b_prev
=
-
NUM_FLT_INF
;
bool
found
=
matcher_
->
Find
(
new_char
+
1
);
child
->
second
->
log_prob_nb_prev
=
-
NUM_FLT_INF
;
if
(
!
found
)
{
child
->
second
->
log_prob_b_cur
=
-
NUM_FLT_INF
;
// Adding this character causes word outside dictionary
child
->
second
->
log_prob_nb_cur
=
-
NUM_FLT_INF
;
auto
FSTZERO
=
fst
::
TropicalWeight
::
Zero
();
auto
final_weight
=
dictionary_
->
Final
(
dictionary_state_
);
bool
is_final
=
(
final_weight
!=
FSTZERO
);
if
(
is_final
&&
reset
)
{
dictionary_state_
=
dictionary_
->
Start
();
}
}
return
nullptr
;
return
(
child
->
second
);
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
new_path
->
character
=
new_char
;
new_path
->
parent
=
this
;
new_path
->
dictionary_
=
dictionary_
;
new_path
->
dictionary_state_
=
matcher_
->
Value
().
nextstate
;
new_path
->
has_dictionary_
=
true
;
new_path
->
matcher_
=
matcher_
;
children_
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
return
new_path
;
}
}
else
{
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
if
(
has_dictionary_
)
{
new_path
->
character
=
new_char
;
matcher_
->
SetState
(
dictionary_state_
);
new_path
->
parent
=
this
;
bool
found
=
matcher_
->
Find
(
new_char
+
1
);
children_
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
if
(
!
found
)
{
return
new_path
;
// Adding this character causes word outside dictionary
auto
FSTZERO
=
fst
::
TropicalWeight
::
Zero
();
auto
final_weight
=
dictionary_
->
Final
(
dictionary_state_
);
bool
is_final
=
(
final_weight
!=
FSTZERO
);
if
(
is_final
&&
reset
)
{
dictionary_state_
=
dictionary_
->
Start
();
}
return
nullptr
;
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
new_path
->
character
=
new_char
;
new_path
->
parent
=
this
;
new_path
->
dictionary_
=
dictionary_
;
new_path
->
dictionary_state_
=
matcher_
->
Value
().
nextstate
;
new_path
->
has_dictionary_
=
true
;
new_path
->
matcher_
=
matcher_
;
children_
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
return
new_path
;
}
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
new_path
->
character
=
new_char
;
new_path
->
parent
=
this
;
children_
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
return
new_path
;
}
}
}
}
}
}
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
)
{
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
)
{
return
get_path_vec
(
output
,
ROOT_
);
return
get_path_vec
(
output
,
ROOT_
);
}
}
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
,
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
,
int
stop
,
int
stop
,
size_t
max_steps
)
{
size_t
max_steps
)
{
if
(
character
==
stop
||
character
==
ROOT_
||
output
.
size
()
==
max_steps
)
{
if
(
character
==
stop
||
character
==
ROOT_
||
output
.
size
()
==
max_steps
)
{
std
::
reverse
(
output
.
begin
(),
output
.
end
());
std
::
reverse
(
output
.
begin
(),
output
.
end
());
return
this
;
return
this
;
}
else
{
}
else
{
output
.
push_back
(
character
);
output
.
push_back
(
character
);
return
parent
->
get_path_vec
(
output
,
stop
,
max_steps
);
return
parent
->
get_path_vec
(
output
,
stop
,
max_steps
);
}
}
}
}
void
PathTrie
::
iterate_to_vec
(
std
::
vector
<
PathTrie
*>&
output
)
{
void
PathTrie
::
iterate_to_vec
(
std
::
vector
<
PathTrie
*>&
output
)
{
if
(
exists_
)
{
if
(
exists_
)
{
log_prob_b_prev
=
log_prob_b_cur
;
log_prob_b_prev
=
log_prob_b_cur
;
log_prob_nb_prev
=
log_prob_nb_cur
;
log_prob_nb_prev
=
log_prob_nb_cur
;
log_prob_b_cur
=
-
NUM_FLT_INF
;
log_prob_b_cur
=
-
NUM_FLT_INF
;
log_prob_nb_cur
=
-
NUM_FLT_INF
;
log_prob_nb_cur
=
-
NUM_FLT_INF
;
score
=
log_sum_exp
(
log_prob_b_prev
,
log_prob_nb_prev
);
score
=
log_sum_exp
(
log_prob_b_prev
,
log_prob_nb_prev
);
output
.
push_back
(
this
);
output
.
push_back
(
this
);
}
}
for
(
auto
child
:
children_
)
{
for
(
auto
child
:
children_
)
{
child
.
second
->
iterate_to_vec
(
output
);
child
.
second
->
iterate_to_vec
(
output
);
}
}
}
}
void
PathTrie
::
remove
()
{
void
PathTrie
::
remove
()
{
exists_
=
false
;
exists_
=
false
;
if
(
children_
.
size
()
==
0
)
{
if
(
children_
.
size
()
==
0
)
{
auto
child
=
parent
->
children_
.
begin
();
auto
child
=
parent
->
children_
.
begin
();
for
(
child
=
parent
->
children_
.
begin
();
child
!=
parent
->
children_
.
end
();
for
(
child
=
parent
->
children_
.
begin
();
++
child
)
{
child
!=
parent
->
children_
.
end
();
if
(
child
->
first
==
character
)
{
++
child
)
{
parent
->
children_
.
erase
(
child
);
if
(
child
->
first
==
character
)
{
break
;
parent
->
children_
.
erase
(
child
);
}
break
;
}
}
}
if
(
parent
->
children_
.
size
()
==
0
&&
!
parent
->
exists_
)
{
if
(
parent
->
children_
.
size
()
==
0
&&
!
parent
->
exists_
)
{
parent
->
remove
();
parent
->
remove
();
}
}
delete
this
;
delete
this
;
}
}
}
}
void
PathTrie
::
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
)
{
void
PathTrie
::
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
)
{
dictionary_
=
dictionary
;
dictionary_
=
dictionary
;
dictionary_state_
=
dictionary
->
Start
();
dictionary_state_
=
dictionary
->
Start
();
has_dictionary_
=
true
;
has_dictionary_
=
true
;
}
}
using
FSTMATCH
=
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>
;
using
FSTMATCH
=
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>
;
void
PathTrie
::
set_matcher
(
std
::
shared_ptr
<
FSTMATCH
>
matcher
)
{
void
PathTrie
::
set_matcher
(
std
::
shared_ptr
<
FSTMATCH
>
matcher
)
{
matcher_
=
matcher
;
matcher_
=
matcher
;
}
}
deepspeech/decoders/swig/path_trie.h
浏览文件 @
f842c79a
...
@@ -27,55 +27,56 @@
...
@@ -27,55 +27,56 @@
* finite-state transducer for spelling correction.
* finite-state transducer for spelling correction.
*/
*/
class
PathTrie
{
class
PathTrie
{
public:
public:
PathTrie
();
PathTrie
();
~
PathTrie
();
~
PathTrie
();
// get new prefix after appending new char
// get new prefix after appending new char
PathTrie
*
get_path_trie
(
int
new_char
,
bool
reset
=
true
);
PathTrie
*
get_path_trie
(
int
new_char
,
bool
reset
=
true
);
// get the prefix in index from root to current node
// get the prefix in index from root to current node
PathTrie
*
get_path_vec
(
std
::
vector
<
int
>&
output
);
PathTrie
*
get_path_vec
(
std
::
vector
<
int
>&
output
);
// get the prefix in index from some stop node to current nodel
// get the prefix in index from some stop node to current nodel
PathTrie
*
get_path_vec
(
std
::
vector
<
int
>&
output
,
PathTrie
*
get_path_vec
(
int
stop
,
std
::
vector
<
int
>&
output
,
size_t
max_steps
=
std
::
numeric_limits
<
size_t
>::
max
());
int
stop
,
size_t
max_steps
=
std
::
numeric_limits
<
size_t
>::
max
());
// update log probs
// update log probs
void
iterate_to_vec
(
std
::
vector
<
PathTrie
*>&
output
);
void
iterate_to_vec
(
std
::
vector
<
PathTrie
*>&
output
);
// set dictionary for FST
// set dictionary for FST
void
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
);
void
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
);
void
set_matcher
(
std
::
shared_ptr
<
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>>
);
void
set_matcher
(
std
::
shared_ptr
<
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>>
);
bool
is_empty
()
{
return
ROOT_
==
character
;
}
bool
is_empty
()
{
return
ROOT_
==
character
;
}
// remove current path from root
// remove current path from root
void
remove
();
void
remove
();
float
log_prob_b_prev
;
float
log_prob_b_prev
;
float
log_prob_nb_prev
;
float
log_prob_nb_prev
;
float
log_prob_b_cur
;
float
log_prob_b_cur
;
float
log_prob_nb_cur
;
float
log_prob_nb_cur
;
float
score
;
float
score
;
float
approx_ctc
;
float
approx_ctc
;
int
character
;
int
character
;
PathTrie
*
parent
;
PathTrie
*
parent
;
private:
private:
int
ROOT_
;
int
ROOT_
;
bool
exists_
;
bool
exists_
;
bool
has_dictionary_
;
bool
has_dictionary_
;
std
::
vector
<
std
::
pair
<
int
,
PathTrie
*>>
children_
;
std
::
vector
<
std
::
pair
<
int
,
PathTrie
*>>
children_
;
// pointer to dictionary of FST
// pointer to dictionary of FST
fst
::
StdVectorFst
*
dictionary_
;
fst
::
StdVectorFst
*
dictionary_
;
fst
::
StdVectorFst
::
StateId
dictionary_state_
;
fst
::
StdVectorFst
::
StateId
dictionary_state_
;
// true if finding ars in FST
// true if finding ars in FST
std
::
shared_ptr
<
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>>
matcher_
;
std
::
shared_ptr
<
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>>
matcher_
;
};
};
#endif // PATH_TRIE_H
#endif // PATH_TRIE_H
deepspeech/decoders/swig/scorer.cpp
浏览文件 @
f842c79a
...
@@ -31,214 +31,214 @@ Scorer::Scorer(double alpha,
...
@@ -31,214 +31,214 @@ Scorer::Scorer(double alpha,
double
beta
,
double
beta
,
const
std
::
string
&
lm_path
,
const
std
::
string
&
lm_path
,
const
std
::
vector
<
std
::
string
>&
vocab_list
)
{
const
std
::
vector
<
std
::
string
>&
vocab_list
)
{
this
->
alpha
=
alpha
;
this
->
alpha
=
alpha
;
this
->
beta
=
beta
;
this
->
beta
=
beta
;
dictionary
=
nullptr
;
dictionary
=
nullptr
;
is_character_based_
=
true
;
is_character_based_
=
true
;
language_model_
=
nullptr
;
language_model_
=
nullptr
;
max_order_
=
0
;
max_order_
=
0
;
dict_size_
=
0
;
dict_size_
=
0
;
SPACE_ID_
=
-
1
;
SPACE_ID_
=
-
1
;
setup
(
lm_path
,
vocab_list
);
setup
(
lm_path
,
vocab_list
);
}
}
Scorer
::~
Scorer
()
{
Scorer
::~
Scorer
()
{
if
(
language_model_
!=
nullptr
)
{
if
(
language_model_
!=
nullptr
)
{
delete
static_cast
<
lm
::
base
::
Model
*>
(
language_model_
);
delete
static_cast
<
lm
::
base
::
Model
*>
(
language_model_
);
}
}
if
(
dictionary
!=
nullptr
)
{
if
(
dictionary
!=
nullptr
)
{
delete
static_cast
<
fst
::
StdVectorFst
*>
(
dictionary
);
delete
static_cast
<
fst
::
StdVectorFst
*>
(
dictionary
);
}
}
}
}
void
Scorer
::
setup
(
const
std
::
string
&
lm_path
,
void
Scorer
::
setup
(
const
std
::
string
&
lm_path
,
const
std
::
vector
<
std
::
string
>&
vocab_list
)
{
const
std
::
vector
<
std
::
string
>&
vocab_list
)
{
// load language model
// load language model
load_lm
(
lm_path
);
load_lm
(
lm_path
);
// set char map for scorer
// set char map for scorer
set_char_map
(
vocab_list
);
set_char_map
(
vocab_list
);
// fill the dictionary for FST
// fill the dictionary for FST
if
(
!
is_character_based
())
{
if
(
!
is_character_based
())
{
fill_dictionary
(
true
);
fill_dictionary
(
true
);
}
}
}
}
void
Scorer
::
load_lm
(
const
std
::
string
&
lm_path
)
{
void
Scorer
::
load_lm
(
const
std
::
string
&
lm_path
)
{
const
char
*
filename
=
lm_path
.
c_str
();
const
char
*
filename
=
lm_path
.
c_str
();
VALID_CHECK_EQ
(
access
(
filename
,
F_OK
),
0
,
"Invalid language model path"
);
VALID_CHECK_EQ
(
access
(
filename
,
F_OK
),
0
,
"Invalid language model path"
);
RetriveStrEnumerateVocab
enumerate
;
RetriveStrEnumerateVocab
enumerate
;
lm
::
ngram
::
Config
config
;
lm
::
ngram
::
Config
config
;
config
.
enumerate_vocab
=
&
enumerate
;
config
.
enumerate_vocab
=
&
enumerate
;
language_model_
=
lm
::
ngram
::
LoadVirtual
(
filename
,
config
);
language_model_
=
lm
::
ngram
::
LoadVirtual
(
filename
,
config
);
max_order_
=
static_cast
<
lm
::
base
::
Model
*>
(
language_model_
)
->
Order
();
max_order_
=
static_cast
<
lm
::
base
::
Model
*>
(
language_model_
)
->
Order
();
vocabulary_
=
enumerate
.
vocabulary
;
vocabulary_
=
enumerate
.
vocabulary
;
for
(
size_t
i
=
0
;
i
<
vocabulary_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
vocabulary_
.
size
();
++
i
)
{
if
(
is_character_based_
&&
vocabulary_
[
i
]
!=
UNK_TOKEN
&&
if
(
is_character_based_
&&
vocabulary_
[
i
]
!=
UNK_TOKEN
&&
vocabulary_
[
i
]
!=
START_TOKEN
&&
vocabulary_
[
i
]
!=
END_TOKEN
&&
vocabulary_
[
i
]
!=
START_TOKEN
&&
vocabulary_
[
i
]
!=
END_TOKEN
&&
get_utf8_str_len
(
enumerate
.
vocabulary
[
i
])
>
1
)
{
get_utf8_str_len
(
enumerate
.
vocabulary
[
i
])
>
1
)
{
is_character_based_
=
false
;
is_character_based_
=
false
;
}
}
}
}
}
}
double
Scorer
::
get_log_cond_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
double
Scorer
::
get_log_cond_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
lm
::
base
::
Model
*
model
=
static_cast
<
lm
::
base
::
Model
*>
(
language_model_
);
lm
::
base
::
Model
*
model
=
static_cast
<
lm
::
base
::
Model
*>
(
language_model_
);
double
cond_prob
;
double
cond_prob
;
lm
::
ngram
::
State
state
,
tmp_state
,
out_state
;
lm
::
ngram
::
State
state
,
tmp_state
,
out_state
;
// avoid to inserting <s> in begin
// avoid to inserting <s> in begin
model
->
NullContextWrite
(
&
state
);
model
->
NullContextWrite
(
&
state
);
for
(
size_t
i
=
0
;
i
<
words
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
words
.
size
();
++
i
)
{
lm
::
WordIndex
word_index
=
model
->
BaseVocabulary
().
Index
(
words
[
i
]);
lm
::
WordIndex
word_index
=
model
->
BaseVocabulary
().
Index
(
words
[
i
]);
// encounter OOV
// encounter OOV
if
(
word_index
==
0
)
{
if
(
word_index
==
0
)
{
return
OOV_SCORE
;
return
OOV_SCORE
;
}
cond_prob
=
model
->
BaseScore
(
&
state
,
word_index
,
&
out_state
);
tmp_state
=
state
;
state
=
out_state
;
out_state
=
tmp_state
;
}
}
cond_prob
=
model
->
BaseScore
(
&
state
,
word_index
,
&
out_state
);
// return log10 prob
tmp_state
=
state
;
return
cond_prob
;
state
=
out_state
;
out_state
=
tmp_state
;
}
// return log10 prob
return
cond_prob
;
}
}
double
Scorer
::
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
double
Scorer
::
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
std
::
vector
<
std
::
string
>
sentence
;
std
::
vector
<
std
::
string
>
sentence
;
if
(
words
.
size
()
==
0
)
{
if
(
words
.
size
()
==
0
)
{
for
(
size_t
i
=
0
;
i
<
max_order_
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
max_order_
;
++
i
)
{
sentence
.
push_back
(
START_TOKEN
);
sentence
.
push_back
(
START_TOKEN
);
}
}
}
else
{
}
else
{
for
(
size_t
i
=
0
;
i
<
max_order_
-
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
max_order_
-
1
;
++
i
)
{
sentence
.
push_back
(
START_TOKEN
);
sentence
.
push_back
(
START_TOKEN
);
}
sentence
.
insert
(
sentence
.
end
(),
words
.
begin
(),
words
.
end
());
}
}
sentence
.
insert
(
sentence
.
end
(),
words
.
begin
(),
words
.
end
());
sentence
.
push_back
(
END_TOKEN
);
}
return
get_log_prob
(
sentence
);
sentence
.
push_back
(
END_TOKEN
);
return
get_log_prob
(
sentence
);
}
}
double
Scorer
::
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
double
Scorer
::
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
assert
(
words
.
size
()
>
max_order_
);
assert
(
words
.
size
()
>
max_order_
);
double
score
=
0.0
;
double
score
=
0.0
;
for
(
size_t
i
=
0
;
i
<
words
.
size
()
-
max_order_
+
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
words
.
size
()
-
max_order_
+
1
;
++
i
)
{
std
::
vector
<
std
::
string
>
ngram
(
words
.
begin
()
+
i
,
std
::
vector
<
std
::
string
>
ngram
(
words
.
begin
()
+
i
,
words
.
begin
()
+
i
+
max_order_
);
words
.
begin
()
+
i
+
max_order_
);
score
+=
get_log_cond_prob
(
ngram
);
score
+=
get_log_cond_prob
(
ngram
);
}
}
return
score
;
return
score
;
}
}
void
Scorer
::
reset_params
(
float
alpha
,
float
beta
)
{
void
Scorer
::
reset_params
(
float
alpha
,
float
beta
)
{
this
->
alpha
=
alpha
;
this
->
alpha
=
alpha
;
this
->
beta
=
beta
;
this
->
beta
=
beta
;
}
}
std
::
string
Scorer
::
vec2str
(
const
std
::
vector
<
int
>&
input
)
{
std
::
string
Scorer
::
vec2str
(
const
std
::
vector
<
int
>&
input
)
{
std
::
string
word
;
std
::
string
word
;
for
(
auto
ind
:
input
)
{
for
(
auto
ind
:
input
)
{
word
+=
char_list_
[
ind
];
word
+=
char_list_
[
ind
];
}
}
return
word
;
return
word
;
}
}
std
::
vector
<
std
::
string
>
Scorer
::
split_labels
(
const
std
::
vector
<
int
>&
labels
)
{
std
::
vector
<
std
::
string
>
Scorer
::
split_labels
(
const
std
::
vector
<
int
>&
labels
)
{
if
(
labels
.
empty
())
return
{};
if
(
labels
.
empty
())
return
{};
std
::
string
s
=
vec2str
(
labels
);
std
::
string
s
=
vec2str
(
labels
);
std
::
vector
<
std
::
string
>
words
;
std
::
vector
<
std
::
string
>
words
;
if
(
is_character_based_
)
{
if
(
is_character_based_
)
{
words
=
split_utf8_str
(
s
);
words
=
split_utf8_str
(
s
);
}
else
{
}
else
{
words
=
split_str
(
s
,
" "
);
words
=
split_str
(
s
,
" "
);
}
}
return
words
;
return
words
;
}
}
void
Scorer
::
set_char_map
(
const
std
::
vector
<
std
::
string
>&
char_list
)
{
void
Scorer
::
set_char_map
(
const
std
::
vector
<
std
::
string
>&
char_list
)
{
char_list_
=
char_list
;
char_list_
=
char_list
;
char_map_
.
clear
();
char_map_
.
clear
();
// Set the char map for the FST for spelling correction
// Set the char map for the FST for spelling correction
for
(
size_t
i
=
0
;
i
<
char_list_
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
char_list_
.
size
();
i
++
)
{
if
(
char_list_
[
i
]
==
" "
)
{
if
(
char_list_
[
i
]
==
" "
)
{
SPACE_ID_
=
i
;
SPACE_ID_
=
i
;
}
// The initial state of FST is state 0, hence the index of chars in
// the FST should start from 1 to avoid the conflict with the initial
// state, otherwise wrong decoding results would be given.
char_map_
[
char_list_
[
i
]]
=
i
+
1
;
}
}
// The initial state of FST is state 0, hence the index of chars in
// the FST should start from 1 to avoid the conflict with the initial
// state, otherwise wrong decoding results would be given.
char_map_
[
char_list_
[
i
]]
=
i
+
1
;
}
}
}
std
::
vector
<
std
::
string
>
Scorer
::
make_ngram
(
PathTrie
*
prefix
)
{
std
::
vector
<
std
::
string
>
Scorer
::
make_ngram
(
PathTrie
*
prefix
)
{
std
::
vector
<
std
::
string
>
ngram
;
std
::
vector
<
std
::
string
>
ngram
;
PathTrie
*
current_node
=
prefix
;
PathTrie
*
current_node
=
prefix
;
PathTrie
*
new_node
=
nullptr
;
PathTrie
*
new_node
=
nullptr
;
for
(
int
order
=
0
;
order
<
max_order_
;
order
++
)
{
for
(
int
order
=
0
;
order
<
max_order_
;
order
++
)
{
std
::
vector
<
int
>
prefix_vec
;
std
::
vector
<
int
>
prefix_vec
;
if
(
is_character_based_
)
{
if
(
is_character_based_
)
{
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
SPACE_ID_
,
1
);
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
SPACE_ID_
,
1
);
current_node
=
new_node
;
current_node
=
new_node
;
}
else
{
}
else
{
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
SPACE_ID_
);
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
SPACE_ID_
);
current_node
=
new_node
->
parent
;
// Skipping spaces
current_node
=
new_node
->
parent
;
// Skipping spaces
}
// reconstruct word
std
::
string
word
=
vec2str
(
prefix_vec
);
ngram
.
push_back
(
word
);
if
(
new_node
->
character
==
-
1
)
{
// No more spaces, but still need order
for
(
int
i
=
0
;
i
<
max_order_
-
order
-
1
;
i
++
)
{
ngram
.
push_back
(
START_TOKEN
);
}
break
;
}
}
}
std
::
reverse
(
ngram
.
begin
(),
ngram
.
end
());
// reconstruct word
return
ngram
;
std
::
string
word
=
vec2str
(
prefix_vec
);
ngram
.
push_back
(
word
);
if
(
new_node
->
character
==
-
1
)
{
// No more spaces, but still need order
for
(
int
i
=
0
;
i
<
max_order_
-
order
-
1
;
i
++
)
{
ngram
.
push_back
(
START_TOKEN
);
}
break
;
}
}
std
::
reverse
(
ngram
.
begin
(),
ngram
.
end
());
return
ngram
;
}
}
void
Scorer
::
fill_dictionary
(
bool
add_space
)
{
void
Scorer
::
fill_dictionary
(
bool
add_space
)
{
fst
::
StdVectorFst
dictionary
;
fst
::
StdVectorFst
dictionary
;
// For each unigram convert to ints and put in trie
// For each unigram convert to ints and put in trie
int
dict_size
=
0
;
int
dict_size
=
0
;
for
(
const
auto
&
word
:
vocabulary_
)
{
for
(
const
auto
&
word
:
vocabulary_
)
{
bool
added
=
add_word_to_dictionary
(
bool
added
=
add_word_to_dictionary
(
word
,
char_map_
,
add_space
,
SPACE_ID_
+
1
,
&
dictionary
);
word
,
char_map_
,
add_space
,
SPACE_ID_
+
1
,
&
dictionary
);
dict_size
+=
added
?
1
:
0
;
dict_size
+=
added
?
1
:
0
;
}
}
dict_size_
=
dict_size
;
dict_size_
=
dict_size
;
/* Simplify FST
/* Simplify FST
* This gets rid of "epsilon" transitions in the FST.
* This gets rid of "epsilon" transitions in the FST.
* These are transitions that don't require a string input to be taken.
* These are transitions that don't require a string input to be taken.
* Getting rid of them is necessary to make the FST determinisitc, but
* Getting rid of them is necessary to make the FST determinisitc, but
* can greatly increase the size of the FST
* can greatly increase the size of the FST
*/
*/
fst
::
RmEpsilon
(
&
dictionary
);
fst
::
RmEpsilon
(
&
dictionary
);
fst
::
StdVectorFst
*
new_dict
=
new
fst
::
StdVectorFst
;
fst
::
StdVectorFst
*
new_dict
=
new
fst
::
StdVectorFst
;
/* This makes the FST deterministic, meaning for any string input there's
/* This makes the FST deterministic, meaning for any string input there's
* only one possible state the FST could be in. It is assumed our
* only one possible state the FST could be in. It is assumed our
* dictionary is deterministic when using it.
* dictionary is deterministic when using it.
* (lest we'd have to check for multiple transitions at each state)
* (lest we'd have to check for multiple transitions at each state)
*/
*/
fst
::
Determinize
(
dictionary
,
new_dict
);
fst
::
Determinize
(
dictionary
,
new_dict
);
/* Finds the simplest equivalent fst. This is unnecessary but decreases
/* Finds the simplest equivalent fst. This is unnecessary but decreases
* memory usage of the dictionary
* memory usage of the dictionary
*/
*/
fst
::
Minimize
(
new_dict
);
fst
::
Minimize
(
new_dict
);
this
->
dictionary
=
new_dict
;
this
->
dictionary
=
new_dict
;
}
}
deepspeech/decoders/swig/scorer.h
浏览文件 @
f842c79a
...
@@ -34,14 +34,14 @@ const std::string END_TOKEN = "</s>";
...
@@ -34,14 +34,14 @@ const std::string END_TOKEN = "</s>";
// Implement a callback to retrive the dictionary of language model.
// Implement a callback to retrive the dictionary of language model.
class
RetriveStrEnumerateVocab
:
public
lm
::
EnumerateVocab
{
class
RetriveStrEnumerateVocab
:
public
lm
::
EnumerateVocab
{
public:
public:
RetriveStrEnumerateVocab
()
{}
RetriveStrEnumerateVocab
()
{}
void
Add
(
lm
::
WordIndex
index
,
const
StringPiece
&
str
)
{
void
Add
(
lm
::
WordIndex
index
,
const
StringPiece
&
str
)
{
vocabulary
.
push_back
(
std
::
string
(
str
.
data
(),
str
.
length
()));
vocabulary
.
push_back
(
std
::
string
(
str
.
data
(),
str
.
length
()));
}
}
std
::
vector
<
std
::
string
>
vocabulary
;
std
::
vector
<
std
::
string
>
vocabulary
;
};
};
/* External scorer to query score for n-gram or sentence, including language
/* External scorer to query score for n-gram or sentence, including language
...
@@ -53,74 +53,74 @@ public:
...
@@ -53,74 +53,74 @@ public:
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
*/
*/
class
Scorer
{
class
Scorer
{
public:
public:
Scorer
(
double
alpha
,
Scorer
(
double
alpha
,
double
beta
,
double
beta
,
const
std
::
string
&
lm_path
,
const
std
::
string
&
lm_path
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
);
const
std
::
vector
<
std
::
string
>
&
vocabulary
);
~
Scorer
();
~
Scorer
();
double
get_log_cond_prob
(
const
std
::
vector
<
std
::
string
>
&
words
);
double
get_log_cond_prob
(
const
std
::
vector
<
std
::
string
>
&
words
);
double
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>
&
words
);
double
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>
&
words
);
// return the max order
// return the max order
size_t
get_max_order
()
const
{
return
max_order_
;
}
size_t
get_max_order
()
const
{
return
max_order_
;
}
// return the dictionary size of language model
// return the dictionary size of language model
size_t
get_dict_size
()
const
{
return
dict_size_
;
}
size_t
get_dict_size
()
const
{
return
dict_size_
;
}
// retrun true if the language model is character based
// retrun true if the language model is character based
bool
is_character_based
()
const
{
return
is_character_based_
;
}
bool
is_character_based
()
const
{
return
is_character_based_
;
}
// reset params alpha & beta
// reset params alpha & beta
void
reset_params
(
float
alpha
,
float
beta
);
void
reset_params
(
float
alpha
,
float
beta
);
// make ngram for a given prefix
// make ngram for a given prefix
std
::
vector
<
std
::
string
>
make_ngram
(
PathTrie
*
prefix
);
std
::
vector
<
std
::
string
>
make_ngram
(
PathTrie
*
prefix
);
// trransform the labels in index to the vector of words (word based lm) or
// trransform the labels in index to the vector of words (word based lm) or
// the vector of characters (character based lm)
// the vector of characters (character based lm)
std
::
vector
<
std
::
string
>
split_labels
(
const
std
::
vector
<
int
>
&
labels
);
std
::
vector
<
std
::
string
>
split_labels
(
const
std
::
vector
<
int
>
&
labels
);
// language model weight
// language model weight
double
alpha
;
double
alpha
;
// word insertion weight
// word insertion weight
double
beta
;
double
beta
;
// pointer to the dictionary of FST
// pointer to the dictionary of FST
void
*
dictionary
;
void
*
dictionary
;
protected:
protected:
// necessary setup: load language model, set char map, fill FST's dictionary
// necessary setup: load language model, set char map, fill FST's dictionary
void
setup
(
const
std
::
string
&
lm_path
,
void
setup
(
const
std
::
string
&
lm_path
,
const
std
::
vector
<
std
::
string
>
&
vocab_list
);
const
std
::
vector
<
std
::
string
>
&
vocab_list
);
// load language model from given path
// load language model from given path
void
load_lm
(
const
std
::
string
&
lm_path
);
void
load_lm
(
const
std
::
string
&
lm_path
);
// fill dictionary for FST
// fill dictionary for FST
void
fill_dictionary
(
bool
add_space
);
void
fill_dictionary
(
bool
add_space
);
// set char map
// set char map
void
set_char_map
(
const
std
::
vector
<
std
::
string
>
&
char_list
);
void
set_char_map
(
const
std
::
vector
<
std
::
string
>
&
char_list
);
double
get_log_prob
(
const
std
::
vector
<
std
::
string
>
&
words
);
double
get_log_prob
(
const
std
::
vector
<
std
::
string
>
&
words
);
// translate the vector in index to string
// translate the vector in index to string
std
::
string
vec2str
(
const
std
::
vector
<
int
>
&
input
);
std
::
string
vec2str
(
const
std
::
vector
<
int
>
&
input
);
private:
private:
void
*
language_model_
;
void
*
language_model_
;
bool
is_character_based_
;
bool
is_character_based_
;
size_t
max_order_
;
size_t
max_order_
;
size_t
dict_size_
;
size_t
dict_size_
;
int
SPACE_ID_
;
int
SPACE_ID_
;
std
::
vector
<
std
::
string
>
char_list_
;
std
::
vector
<
std
::
string
>
char_list_
;
std
::
unordered_map
<
std
::
string
,
int
>
char_map_
;
std
::
unordered_map
<
std
::
string
,
int
>
char_map_
;
std
::
vector
<
std
::
string
>
vocabulary_
;
std
::
vector
<
std
::
string
>
vocabulary_
;
};
};
#endif // SCORER_H_
#endif // SCORER_H_
third_party/pymmseg-cpp/bin/pymmseg
浏览文件 @
f842c79a
#!/usr/bin/env python3
#!/usr/bin/env python3
import
sys
import
pstats
import
cProfile
import
cProfile
from
io
import
StringIO
import
getopt
import
getopt
import
os
import
os
from
os.path
import
dirname
,
join
import
pstats
import
sys
from
io
import
StringIO
from
os.path
import
dirname
from
os.path
import
join
import
mmseg
import
mmseg
...
...
third_party/python-pinyin/pinyin-data/CHANGELOG.md
浏览文件 @
f842c79a
...
@@ -94,7 +94,7 @@
...
@@ -94,7 +94,7 @@
*
Update to the latest version of
[
Unihan Database
](
http://www.unicode.org/charts/unihan.html
)
:
*
Update to the latest version of
[
Unihan Database
](
http://www.unicode.org/charts/unihan.html
)
:
> Date: 2016-06-01 07:01:48 GMT [JHJ]
> Date: 2016-06-01 07:01:48 GMT [JHJ]
> Unicode version: 9.0.0
> Unicode version: 9.0.0
...
...
third_party/python-pinyin/pinyin-data/README.md
浏览文件 @
f842c79a
...
@@ -19,7 +19,7 @@
...
@@ -19,7 +19,7 @@
[
Unihan Database
][
unihan
]
数据版本:
[
Unihan Database
][
unihan
]
数据版本:
> Date: 2020-02-18 18:27:33 GMT [JHJ]
> Date: 2020-02-18 18:27:33 GMT [JHJ]
> Unicode version: 13.0.0
> Unicode version: 13.0.0
*
`kTGHZ2013.txt`
:
[
Unihan Database
][
unihan
]
中
[
kTGHZ2013
](
http://www.unicode.org/reports/tr38/#kTGHZ2013
)
部分的拼音数据(来源于《通用规范汉字字典》的拼音数据)
*
`kTGHZ2013.txt`
:
[
Unihan Database
][
unihan
]
中
[
kTGHZ2013
](
http://www.unicode.org/reports/tr38/#kTGHZ2013
)
部分的拼音数据(来源于《通用规范汉字字典》的拼音数据)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录