Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
5208b8e4
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5208b8e4
编写于
9月 06, 2017
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
format C++ source code
上级
a2ddfe8d
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
749 addition
and
779 deletion
+749
-779
deploy/ctc_decoders.cpp
deploy/ctc_decoders.cpp
+292
-300
deploy/ctc_decoders.h
deploy/ctc_decoders.h
+21
-23
deploy/decoder_utils.cpp
deploy/decoder_utils.cpp
+79
-81
deploy/decoder_utils.h
deploy/decoder_utils.h
+21
-23
deploy/path_trie.cpp
deploy/path_trie.cpp
+103
-106
deploy/path_trie.h
deploy/path_trie.h
+30
-32
deploy/scorer.cpp
deploy/scorer.cpp
+160
-171
deploy/scorer.h
deploy/scorer.h
+43
-43
未找到文件。
deploy/ctc_decoders.cpp
浏览文件 @
5208b8e4
#include <iostream>
#include "ctc_decoders.h"
#include <map>
#include <algorithm>
#include <algorithm>
#include <utility>
#include <cmath>
#include <cmath>
#include <iostream>
#include <limits>
#include <limits>
#include "fst/fstlib.h"
#include <map>
#include "ctc_decoders.h"
#include <utility>
#include "ThreadPool.h"
#include "decoder_utils.h"
#include "decoder_utils.h"
#include "fst/fstlib.h"
#include "path_trie.h"
#include "path_trie.h"
#include "ThreadPool.h"
std
::
string
ctc_best_path_decoder
(
std
::
vector
<
std
::
vector
<
double
>
>
probs_seq
,
std
::
string
ctc_best_path_decoder
(
std
::
vector
<
std
::
vector
<
double
>>
probs_seq
,
std
::
vector
<
std
::
string
>
vocabulary
)
std
::
vector
<
std
::
string
>
vocabulary
)
{
{
// dimension check
// dimension check
int
num_time_steps
=
probs_seq
.
size
();
int
num_time_steps
=
probs_seq
.
size
();
for
(
int
i
=
0
;
i
<
num_time_steps
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_time_steps
;
i
++
)
{
if
(
probs_seq
[
i
].
size
()
!=
vocabulary
.
size
()
+
1
)
{
if
(
probs_seq
[
i
].
size
()
!=
vocabulary
.
size
()
+
1
)
{
std
::
cout
<<
"The shape of probs_seq does not match"
std
::
cout
<<
"The shape of probs_seq does not match"
<<
" with the shape of the vocabulary!"
<<
std
::
endl
;
<<
" with the shape of the vocabulary!"
<<
std
::
endl
;
exit
(
1
);
exit
(
1
);
}
}
}
}
int
blank_id
=
vocabulary
.
size
();
int
blank_id
=
vocabulary
.
size
();
std
::
vector
<
int
>
max_idx_vec
;
double
max_prob
=
0.0
;
std
::
vector
<
int
>
max_idx_vec
;
int
max_idx
=
0
;
double
max_prob
=
0.0
;
for
(
int
i
=
0
;
i
<
num_time_steps
;
i
++
)
{
int
max_idx
=
0
;
for
(
int
j
=
0
;
j
<
probs_seq
[
i
].
size
();
j
++
)
{
for
(
int
i
=
0
;
i
<
num_time_steps
;
i
++
)
{
if
(
max_prob
<
probs_seq
[
i
][
j
])
{
for
(
int
j
=
0
;
j
<
probs_seq
[
i
].
size
();
j
++
)
{
max_idx
=
j
;
if
(
max_prob
<
probs_seq
[
i
][
j
])
{
max_prob
=
probs_seq
[
i
][
j
];
max_idx
=
j
;
}
max_prob
=
probs_seq
[
i
][
j
];
}
}
max_idx_vec
.
push_back
(
max_idx
);
max_prob
=
0.0
;
max_idx
=
0
;
}
}
max_idx_vec
.
push_back
(
max_idx
);
std
::
vector
<
int
>
idx_vec
;
max_prob
=
0.0
;
for
(
int
i
=
0
;
i
<
max_idx_vec
.
size
();
i
++
)
{
max_idx
=
0
;
if
((
i
==
0
)
||
((
i
>
0
)
&&
max_idx_vec
[
i
]
!=
max_idx_vec
[
i
-
1
]))
{
}
idx_vec
.
push_back
(
max_idx_vec
[
i
]);
}
std
::
vector
<
int
>
idx_vec
;
for
(
int
i
=
0
;
i
<
max_idx_vec
.
size
();
i
++
)
{
if
((
i
==
0
)
||
((
i
>
0
)
&&
max_idx_vec
[
i
]
!=
max_idx_vec
[
i
-
1
]))
{
idx_vec
.
push_back
(
max_idx_vec
[
i
]);
}
}
}
std
::
string
best_path_result
;
std
::
string
best_path_result
;
for
(
int
i
=
0
;
i
<
idx_vec
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
idx_vec
.
size
();
i
++
)
{
if
(
idx_vec
[
i
]
!=
blank_id
)
{
if
(
idx_vec
[
i
]
!=
blank_id
)
{
best_path_result
+=
vocabulary
[
idx_vec
[
i
]];
best_path_result
+=
vocabulary
[
idx_vec
[
i
]];
}
}
}
return
best_path_result
;
}
return
best_path_result
;
}
}
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>
>
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
ctc_beam_search_decoder
(
ctc_beam_search_decoder
(
std
::
vector
<
std
::
vector
<
double
>
>
probs_seq
,
std
::
vector
<
std
::
vector
<
double
>>
probs_seq
,
int
beam_size
,
int
beam_size
,
std
::
vector
<
std
::
string
>
vocabulary
,
std
::
vector
<
std
::
string
>
vocabulary
,
int
blank_id
,
int
blank_id
,
double
cutoff_prob
,
double
cutoff_prob
,
int
cutoff_top_n
,
int
cutoff_top_n
,
Scorer
*
ext_scorer
)
Scorer
*
extscorer
)
{
{
// dimension check
// dimension check
int
num_time_steps
=
probs_seq
.
size
();
int
num_time_steps
=
probs_seq
.
size
();
for
(
int
i
=
0
;
i
<
num_time_steps
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_time_steps
;
i
++
)
{
if
(
probs_seq
[
i
].
size
()
!=
vocabulary
.
size
()
+
1
)
{
if
(
probs_seq
[
i
].
size
()
!=
vocabulary
.
size
()
+
1
)
{
std
::
cout
<<
" The shape of probs_seq does not match"
std
::
cout
<<
" The shape of probs_seq does not match"
<<
" with the shape of the vocabulary!"
<<
std
::
endl
;
<<
" with the shape of the vocabulary!"
<<
std
::
endl
;
exit
(
1
);
exit
(
1
);
}
}
}
}
// blank_id check
if
(
blank_id
>
vocabulary
.
size
())
{
// blank_id check
std
::
cout
<<
" Invalid blank_id! "
<<
std
::
endl
;
if
(
blank_id
>
vocabulary
.
size
())
{
exit
(
1
);
std
::
cout
<<
" Invalid blank_id! "
<<
std
::
endl
;
exit
(
1
);
}
// assign space ID
std
::
vector
<
std
::
string
>::
iterator
it
=
std
::
find
(
vocabulary
.
begin
(),
vocabulary
.
end
(),
" "
);
int
space_id
=
it
-
vocabulary
.
begin
();
// if no space in vocabulary
if
(
space_id
>=
vocabulary
.
size
())
{
space_id
=
-
2
;
}
// init prefixes' root
PathTrie
root
;
root
.
score
=
root
.
log_prob_b_prev
=
0.0
;
std
::
vector
<
PathTrie
*>
prefixes
;
prefixes
.
push_back
(
&
root
);
if
(
extscorer
!=
nullptr
)
{
if
(
extscorer
->
is_char_map_empty
())
{
extscorer
->
set_char_map
(
vocabulary
);
}
}
if
(
!
extscorer
->
is_character_based
())
{
// assign space ID
if
(
extscorer
->
dictionary
==
nullptr
)
{
std
::
vector
<
std
::
string
>::
iterator
it
=
std
::
find
(
vocabulary
.
begin
(),
// fill dictionary for fst
vocabulary
.
end
(),
" "
);
extscorer
->
fill_dictionary
(
true
);
int
space_id
=
it
-
vocabulary
.
begin
();
}
// if no space in vocabulary
auto
fst_dict
=
static_cast
<
fst
::
StdVectorFst
*>
(
extscorer
->
dictionary
);
if
(
space_id
>=
vocabulary
.
size
())
{
fst
::
StdVectorFst
*
dict_ptr
=
fst_dict
->
Copy
(
true
);
space_id
=
-
2
;
root
.
set_dictionary
(
dict_ptr
);
auto
matcher
=
std
::
make_shared
<
FSTMATCH
>
(
*
dict_ptr
,
fst
::
MATCH_INPUT
);
root
.
set_matcher
(
matcher
);
}
}
// prefix search over time
for
(
int
time_step
=
0
;
time_step
<
num_time_steps
;
time_step
++
)
{
std
::
vector
<
double
>
prob
=
probs_seq
[
time_step
];
std
::
vector
<
std
::
pair
<
int
,
double
>>
prob_idx
;
for
(
int
i
=
0
;
i
<
prob
.
size
();
i
++
)
{
prob_idx
.
push_back
(
std
::
pair
<
int
,
double
>
(
i
,
prob
[
i
]));
}
}
// init prefixes' root
float
min_cutoff
=
-
NUM_FLT_INF
;
PathTrie
root
;
bool
full_beam
=
false
;
root
.
_score
=
root
.
_log_prob_b_prev
=
0.0
;
if
(
extscorer
!=
nullptr
)
{
std
::
vector
<
PathTrie
*>
prefixes
;
int
num_prefixes
=
std
::
min
((
int
)
prefixes
.
size
(),
beam_size
);
prefixes
.
push_back
(
&
root
);
std
::
sort
(
prefixes
.
begin
(),
prefixes
.
begin
()
+
num_prefixes
,
prefix_compare
);
min_cutoff
=
prefixes
[
num_prefixes
-
1
]
->
score
+
log
(
prob
[
blank_id
])
-
std
::
max
(
0.0
,
extscorer
->
beta
);
full_beam
=
(
num_prefixes
==
beam_size
);
}
if
(
ext_scorer
!=
nullptr
)
{
// pruning of vacobulary
if
(
ext_scorer
->
is_char_map_empty
())
{
int
cutoff_len
=
prob
.
size
();
ext_scorer
->
set_char_map
(
vocabulary
);
if
(
cutoff_prob
<
1.0
||
cutoff_top_n
<
prob
.
size
())
{
}
std
::
sort
(
if
(
!
ext_scorer
->
is_character_based
())
{
prob_idx
.
begin
(),
prob_idx
.
end
(),
pair_comp_second_rev
<
int
,
double
>
);
if
(
ext_scorer
->
dictionary
==
nullptr
)
{
if
(
cutoff_prob
<
1.0
)
{
// fill dictionary for fst
double
cum_prob
=
0.0
;
ext_scorer
->
fill_dictionary
(
true
);
cutoff_len
=
0
;
}
for
(
int
i
=
0
;
i
<
prob_idx
.
size
();
i
++
)
{
auto
fst_dict
=
static_cast
<
fst
::
StdVectorFst
*>
cum_prob
+=
prob_idx
[
i
].
second
;
(
ext_scorer
->
dictionary
);
cutoff_len
+=
1
;
fst
::
StdVectorFst
*
dict_ptr
=
fst_dict
->
Copy
(
true
);
if
(
cum_prob
>=
cutoff_prob
)
break
;
root
.
set_dictionary
(
dict_ptr
);
auto
matcher
=
std
::
make_shared
<
FSTMATCH
>
(
*
dict_ptr
,
fst
::
MATCH_INPUT
);
root
.
set_matcher
(
matcher
);
}
}
}
cutoff_len
=
std
::
min
(
cutoff_len
,
cutoff_top_n
);
prob_idx
=
std
::
vector
<
std
::
pair
<
int
,
double
>>
(
prob_idx
.
begin
(),
prob_idx
.
begin
()
+
cutoff_len
);
}
std
::
vector
<
std
::
pair
<
int
,
float
>>
log_prob_idx
;
for
(
int
i
=
0
;
i
<
cutoff_len
;
i
++
)
{
log_prob_idx
.
push_back
(
std
::
pair
<
int
,
float
>
(
prob_idx
[
i
].
first
,
log
(
prob_idx
[
i
].
second
+
NUM_FLT_MIN
)));
}
}
// prefix search over time
// loop over chars
for
(
int
time_step
=
0
;
time_step
<
num_time_steps
;
time_step
++
)
{
for
(
int
index
=
0
;
index
<
log_prob_idx
.
size
();
index
++
)
{
std
::
vector
<
double
>
prob
=
probs_seq
[
time_step
];
auto
c
=
log_prob_idx
[
index
].
first
;
std
::
vector
<
std
::
pair
<
int
,
double
>
>
prob_idx
;
float
log_prob_c
=
log_prob_idx
[
index
].
second
;
for
(
int
i
=
0
;
i
<
prob
.
size
();
i
++
)
{
prob_idx
.
push_back
(
std
::
pair
<
int
,
double
>
(
i
,
prob
[
i
]));
}
float
min_cutoff
=
-
NUM_FLT_INF
;
for
(
int
i
=
0
;
i
<
prefixes
.
size
()
&&
i
<
beam_size
;
i
++
)
{
bool
full_beam
=
false
;
auto
prefix
=
prefixes
[
i
];
if
(
ext_scorer
!=
nullptr
)
{
int
num_prefixes
=
std
::
min
((
int
)
prefixes
.
size
(),
beam_size
);
std
::
sort
(
prefixes
.
begin
(),
prefixes
.
begin
()
+
num_prefixes
,
prefix_compare
);
min_cutoff
=
prefixes
[
num_prefixes
-
1
]
->
_score
+
log
(
prob
[
blank_id
])
-
std
::
max
(
0.0
,
ext_scorer
->
beta
);
full_beam
=
(
num_prefixes
==
beam_size
);
}
// pruning of vacobulary
int
cutoff_len
=
prob
.
size
();
if
(
cutoff_prob
<
1.0
||
cutoff_top_n
<
prob
.
size
())
{
std
::
sort
(
prob_idx
.
begin
(),
prob_idx
.
end
(),
pair_comp_second_rev
<
int
,
double
>
);
if
(
cutoff_prob
<
1.0
)
{
double
cum_prob
=
0.0
;
cutoff_len
=
0
;
for
(
int
i
=
0
;
i
<
prob_idx
.
size
();
i
++
)
{
cum_prob
+=
prob_idx
[
i
].
second
;
cutoff_len
+=
1
;
if
(
cum_prob
>=
cutoff_prob
)
break
;
}
}
cutoff_len
=
std
::
min
(
cutoff_len
,
cutoff_top_n
);
prob_idx
=
std
::
vector
<
std
::
pair
<
int
,
double
>
>
(
prob_idx
.
begin
(),
prob_idx
.
begin
()
+
cutoff_len
);
}
std
::
vector
<
std
::
pair
<
int
,
float
>
>
log_prob_idx
;
for
(
int
i
=
0
;
i
<
cutoff_len
;
i
++
)
{
log_prob_idx
.
push_back
(
std
::
pair
<
int
,
float
>
(
prob_idx
[
i
].
first
,
log
(
prob_idx
[
i
].
second
+
NUM_FLT_MIN
)));
}
// loop over chars
if
(
full_beam
&&
log_prob_c
+
prefix
->
score
<
min_cutoff
)
{
for
(
int
index
=
0
;
index
<
log_prob_idx
.
size
();
index
++
)
{
break
;
auto
c
=
log_prob_idx
[
index
].
first
;
float
log_prob_c
=
log_prob_idx
[
index
].
second
;
for
(
int
i
=
0
;
i
<
prefixes
.
size
()
&&
i
<
beam_size
;
i
++
)
{
auto
prefix
=
prefixes
[
i
];
if
(
full_beam
&&
log_prob_c
+
prefix
->
_score
<
min_cutoff
)
{
break
;
}
// blank
if
(
c
==
blank_id
)
{
prefix
->
_log_prob_b_cur
=
log_sum_exp
(
prefix
->
_log_prob_b_cur
,
log_prob_c
+
prefix
->
_score
);
continue
;
}
// repeated character
if
(
c
==
prefix
->
_character
)
{
prefix
->
_log_prob_nb_cur
=
log_sum_exp
(
prefix
->
_log_prob_nb_cur
,
log_prob_c
+
prefix
->
_log_prob_nb_prev
);
}
// get new prefix
auto
prefix_new
=
prefix
->
get_path_trie
(
c
);
if
(
prefix_new
!=
nullptr
)
{
float
log_p
=
-
NUM_FLT_INF
;
if
(
c
==
prefix
->
_character
&&
prefix
->
_log_prob_b_prev
>
-
NUM_FLT_INF
)
{
log_p
=
log_prob_c
+
prefix
->
_log_prob_b_prev
;
}
else
if
(
c
!=
prefix
->
_character
)
{
log_p
=
log_prob_c
+
prefix
->
_score
;
}
// language model scoring
if
(
ext_scorer
!=
nullptr
&&
(
c
==
space_id
||
ext_scorer
->
is_character_based
())
)
{
PathTrie
*
prefix_to_score
=
nullptr
;
// skip scoring the space
if
(
ext_scorer
->
is_character_based
())
{
prefix_to_score
=
prefix_new
;
}
else
{
prefix_to_score
=
prefix
;
}
double
score
=
0.0
;
std
::
vector
<
std
::
string
>
ngram
;
ngram
=
ext_scorer
->
make_ngram
(
prefix_to_score
);
score
=
ext_scorer
->
get_log_cond_prob
(
ngram
)
*
ext_scorer
->
alpha
;
log_p
+=
score
;
log_p
+=
ext_scorer
->
beta
;
}
prefix_new
->
_log_prob_nb_cur
=
log_sum_exp
(
prefix_new
->
_log_prob_nb_cur
,
log_p
);
}
}
// end of loop over prefix
}
// end of loop over chars
prefixes
.
clear
();
// update log probs
root
.
iterate_to_vec
(
prefixes
);
// only preserve top beam_size prefixes
if
(
prefixes
.
size
()
>=
beam_size
)
{
std
::
nth_element
(
prefixes
.
begin
(),
prefixes
.
begin
()
+
beam_size
,
prefixes
.
end
(),
prefix_compare
);
for
(
size_t
i
=
beam_size
;
i
<
prefixes
.
size
();
i
++
)
{
prefixes
[
i
]
->
remove
();
}
}
}
}
// end of loop over time
// blank
if
(
c
==
blank_id
)
{
// compute aproximate ctc score as the return score
prefix
->
log_prob_b_cur
=
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
prefixes
.
size
();
i
++
)
{
log_sum_exp
(
prefix
->
log_prob_b_cur
,
log_prob_c
+
prefix
->
score
);
double
approx_ctc
=
prefixes
[
i
]
->
_score
;
continue
;
}
if
(
ext_scorer
!=
nullptr
)
{
// repeated character
std
::
vector
<
int
>
output
;
if
(
c
==
prefix
->
character
)
{
prefixes
[
i
]
->
get_path_vec
(
output
);
prefix
->
log_prob_nb_cur
=
log_sum_exp
(
size_t
prefix_length
=
output
.
size
();
prefix
->
log_prob_nb_cur
,
log_prob_c
+
prefix
->
log_prob_nb_prev
);
auto
words
=
ext_scorer
->
split_labels
(
output
);
// remove word insert
approx_ctc
=
approx_ctc
-
prefix_length
*
ext_scorer
->
beta
;
// remove language model weight:
approx_ctc
-=
(
ext_scorer
->
get_sent_log_prob
(
words
))
*
ext_scorer
->
alpha
;
}
}
// get new prefix
auto
prefix_new
=
prefix
->
get_path_trie
(
c
);
if
(
prefix_new
!=
nullptr
)
{
float
log_p
=
-
NUM_FLT_INF
;
if
(
c
==
prefix
->
character
&&
prefix
->
log_prob_b_prev
>
-
NUM_FLT_INF
)
{
log_p
=
log_prob_c
+
prefix
->
log_prob_b_prev
;
}
else
if
(
c
!=
prefix
->
character
)
{
log_p
=
log_prob_c
+
prefix
->
score
;
}
// language model scoring
if
(
extscorer
!=
nullptr
&&
(
c
==
space_id
||
extscorer
->
is_character_based
()))
{
PathTrie
*
prefix_toscore
=
nullptr
;
// skip scoring the space
if
(
extscorer
->
is_character_based
())
{
prefix_toscore
=
prefix_new
;
}
else
{
prefix_toscore
=
prefix
;
}
prefixes
[
i
]
->
_approx_ctc
=
approx_ctc
;
double
score
=
0.0
;
}
std
::
vector
<
std
::
string
>
ngram
;
ngram
=
extscorer
->
make_ngram
(
prefix_toscore
);
score
=
extscorer
->
get_log_cond_prob
(
ngram
)
*
extscorer
->
alpha
;
// allow for the post processing
log_p
+=
score
;
std
::
vector
<
PathTrie
*>
space_prefixes
;
log_p
+=
extscorer
->
beta
;
if
(
space_prefixes
.
empty
())
{
}
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
prefixes
.
size
();
i
++
)
{
prefix_new
->
log_prob_nb_cur
=
space_prefixes
.
push_back
(
prefixes
[
i
]
);
log_sum_exp
(
prefix_new
->
log_prob_nb_cur
,
log_p
);
}
}
}
// end of loop over prefix
}
// end of loop over chars
prefixes
.
clear
();
// update log probs
root
.
iterate_to_vec
(
prefixes
);
// only preserve top beam_size prefixes
if
(
prefixes
.
size
()
>=
beam_size
)
{
std
::
nth_element
(
prefixes
.
begin
(),
prefixes
.
begin
()
+
beam_size
,
prefixes
.
end
(),
prefix_compare
);
for
(
size_t
i
=
beam_size
;
i
<
prefixes
.
size
();
i
++
)
{
prefixes
[
i
]
->
remove
();
}
}
}
}
// end of loop over time
std
::
sort
(
space_prefixes
.
begin
(),
space_prefixes
.
end
(),
prefix_compare
);
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>
>
output_vecs
;
// compute aproximate ctc score as the return score
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
space_prefixes
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
prefixes
.
size
();
i
++
)
{
std
::
vector
<
int
>
output
;
double
approx_ctc
=
prefixes
[
i
]
->
score
;
space_prefixes
[
i
]
->
get_path_vec
(
output
);
// convert index to string
if
(
extscorer
!=
nullptr
)
{
std
::
string
output_str
;
std
::
vector
<
int
>
output
;
for
(
int
j
=
0
;
j
<
output
.
size
();
j
++
)
{
prefixes
[
i
]
->
get_path_vec
(
output
);
output_str
+=
vocabulary
[
output
[
j
]];
size_t
prefix_length
=
output
.
size
();
}
auto
words
=
extscorer
->
split_labels
(
output
);
std
::
pair
<
double
,
std
::
string
>
// remove word insert
output_pair
(
-
space_prefixes
[
i
]
->
_approx_ctc
,
output_str
);
approx_ctc
=
approx_ctc
-
prefix_length
*
extscorer
->
beta
;
output_vecs
.
emplace_back
(
output_pair
);
// remove language model weight:
approx_ctc
-=
(
extscorer
->
get_sent_log_prob
(
words
))
*
extscorer
->
alpha
;
}
}
return
output_vecs
;
prefixes
[
i
]
->
approx_ctc
=
approx_ctc
;
}
}
// allow for the post processing
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>
>
>
std
::
vector
<
PathTrie
*>
space_prefixes
;
ctc_beam_search_decoder_batch
(
if
(
space_prefixes
.
empty
())
{
std
::
vector
<
std
::
vector
<
std
::
vector
<
double
>>>
probs_split
,
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
prefixes
.
size
();
i
++
)
{
int
beam_size
,
space_prefixes
.
push_back
(
prefixes
[
i
]);
std
::
vector
<
std
::
string
>
vocabulary
,
int
blank_id
,
int
num_processes
,
double
cutoff_prob
,
int
cutoff_top_n
,
Scorer
*
ext_scorer
)
{
if
(
num_processes
<=
0
)
{
std
::
cout
<<
"num_processes must be nonnegative!"
<<
std
::
endl
;
exit
(
1
);
}
}
// thread pool
}
ThreadPool
pool
(
num_processes
);
// number of samples
std
::
sort
(
space_prefixes
.
begin
(),
space_prefixes
.
end
(),
prefix_compare
);
int
batch_size
=
probs_split
.
size
();
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
output_vecs
;
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
space_prefixes
.
size
();
i
++
)
{
// scorer filling up
std
::
vector
<
int
>
output
;
if
(
ext_scorer
!=
nullptr
)
{
space_prefixes
[
i
]
->
get_path_vec
(
output
);
if
(
ext_scorer
->
is_char_map_empty
())
{
// convert index to string
ext_scorer
->
set_char_map
(
vocabulary
);
std
::
string
output_str
;
}
for
(
int
j
=
0
;
j
<
output
.
size
();
j
++
)
{
if
(
!
ext_scorer
->
is_character_based
()
output_str
+=
vocabulary
[
output
[
j
]];
&&
ext_scorer
->
dictionary
==
nullptr
)
{
// init dictionary
ext_scorer
->
fill_dictionary
(
true
);
}
}
}
std
::
pair
<
double
,
std
::
string
>
output_pair
(
-
space_prefixes
[
i
]
->
approx_ctc
,
output_str
);
output_vecs
.
emplace_back
(
output_pair
);
}
// enqueue the tasks of decoding
return
output_vecs
;
std
::
vector
<
std
::
future
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>>
res
;
}
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
res
.
emplace_back
(
pool
.
enqueue
(
ctc_beam_search_decoder
,
probs_split
[
i
],
beam_size
,
vocabulary
,
blank_id
,
cutoff_prob
,
cutoff_top_n
,
ext_scorer
)
);
}
// get decoding results
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>
>
>
batch_results
;
ctc_beam_search_decoder_batch
(
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
std
::
vector
<
std
::
vector
<
std
::
vector
<
double
>>>
probs_split
,
batch_results
.
emplace_back
(
res
[
i
].
get
());
int
beam_size
,
std
::
vector
<
std
::
string
>
vocabulary
,
int
blank_id
,
int
num_processes
,
double
cutoff_prob
,
int
cutoff_top_n
,
Scorer
*
extscorer
)
{
if
(
num_processes
<=
0
)
{
std
::
cout
<<
"num_processes must be nonnegative!"
<<
std
::
endl
;
exit
(
1
);
}
// thread pool
ThreadPool
pool
(
num_processes
);
// number of samples
int
batch_size
=
probs_split
.
size
();
// scorer filling up
if
(
extscorer
!=
nullptr
)
{
if
(
extscorer
->
is_char_map_empty
())
{
extscorer
->
set_char_map
(
vocabulary
);
}
if
(
!
extscorer
->
is_character_based
()
&&
extscorer
->
dictionary
==
nullptr
)
{
// init dictionary
extscorer
->
fill_dictionary
(
true
);
}
}
return
batch_results
;
}
// enqueue the tasks of decoding
std
::
vector
<
std
::
future
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>>
res
;
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
res
.
emplace_back
(
pool
.
enqueue
(
ctc_beam_search_decoder
,
probs_split
[
i
],
beam_size
,
vocabulary
,
blank_id
,
cutoff_prob
,
cutoff_top_n
,
extscorer
));
}
// get decoding results
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
batch_results
;
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
batch_results
.
emplace_back
(
res
[
i
].
get
());
}
return
batch_results
;
}
}
deploy/ctc_decoders.h
浏览文件 @
5208b8e4
#ifndef CTC_BEAM_SEARCH_DECODER_H_
#ifndef CTC_BEAM_SEARCH_DECODER_H_
#define CTC_BEAM_SEARCH_DECODER_H_
#define CTC_BEAM_SEARCH_DECODER_H_
#include <vector>
#include <string>
#include <string>
#include <utility>
#include <utility>
#include <vector>
#include "scorer.h"
#include "scorer.h"
/* CTC Best Path Decoder
/* CTC Best Path Decoder
...
@@ -16,8 +16,8 @@
...
@@ -16,8 +16,8 @@
* A vector that each element is a pair of score and decoding result,
* A vector that each element is a pair of score and decoding result,
* in desending order.
* in desending order.
*/
*/
std
::
string
ctc_best_path_decoder
(
std
::
vector
<
std
::
vector
<
double
>
>
probs_seq
,
std
::
string
ctc_best_path_decoder
(
std
::
vector
<
std
::
vector
<
double
>>
probs_seq
,
std
::
vector
<
std
::
string
>
vocabulary
);
std
::
vector
<
std
::
string
>
vocabulary
);
/* CTC Beam Search Decoder
/* CTC Beam Search Decoder
...
@@ -34,15 +34,14 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
...
@@ -34,15 +34,14 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
* A vector that each element is a pair of score and decoding result,
* A vector that each element is a pair of score and decoding result,
* in desending order.
* in desending order.
*/
*/
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>
>
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
ctc_beam_search_decoder
(
ctc_beam_search_decoder
(
std
::
vector
<
std
::
vector
<
double
>
>
probs_seq
,
std
::
vector
<
std
::
vector
<
double
>>
probs_seq
,
int
beam_size
,
int
beam_size
,
std
::
vector
<
std
::
string
>
vocabulary
,
std
::
vector
<
std
::
string
>
vocabulary
,
int
blank_id
,
int
blank_id
,
double
cutoff_prob
=
1
.
0
,
double
cutoff_prob
=
1
.
0
,
int
cutoff_top_n
=
40
,
int
cutoff_top_n
=
40
,
Scorer
*
ext_scorer
=
NULL
Scorer
*
ext_scorer
=
NULL
);
);
/* CTC Beam Search Decoder for batch data, the interface is consistent with the
/* CTC Beam Search Decoder for batch data, the interface is consistent with the
* original decoder in Python version.
* original decoder in Python version.
...
@@ -63,15 +62,14 @@ std::vector<std::pair<double, std::string> >
...
@@ -63,15 +62,14 @@ std::vector<std::pair<double, std::string> >
* sample.
* sample.
*/
*/
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
ctc_beam_search_decoder_batch
(
std
::
vector
<
std
::
vector
<
std
::
vector
<
double
>>>
probs_split
,
ctc_beam_search_decoder_batch
(
int
beam_size
,
std
::
vector
<
std
::
vector
<
std
::
vector
<
double
>>>
probs_split
,
std
::
vector
<
std
::
string
>
vocabulary
,
int
beam_size
,
int
blank_id
,
std
::
vector
<
std
::
string
>
vocabulary
,
int
num_processes
,
int
blank_id
,
double
cutoff_prob
=
1
.
0
,
int
num_processes
,
int
cutoff_top_n
=
40
,
double
cutoff_prob
=
1
.
0
,
Scorer
*
ext_scorer
=
NULL
int
cutoff_top_n
=
40
,
);
Scorer
*
ext_scorer
=
NULL
);
#endif // CTC_BEAM_SEARCH_DECODER_H_
#endif
// CTC_BEAM_SEARCH_DECODER_H_
deploy/decoder_utils.cpp
浏览文件 @
5208b8e4
#include
<limits>
#include
"decoder_utils.h"
#include <algorithm>
#include <algorithm>
#include <cmath>
#include <cmath>
#include
"decoder_utils.h"
#include
<limits>
size_t
get_utf8_str_len
(
const
std
::
string
&
str
)
{
size_t
get_utf8_str_len
(
const
std
::
string
&
str
)
{
size_t
str_len
=
0
;
size_t
str_len
=
0
;
for
(
char
c
:
str
)
{
for
(
char
c
:
str
)
{
str_len
+=
((
c
&
0xc0
)
!=
0x80
);
str_len
+=
((
c
&
0xc0
)
!=
0x80
);
}
}
return
str_len
;
return
str_len
;
}
}
std
::
vector
<
std
::
string
>
split_utf8_str
(
const
std
::
string
&
str
)
std
::
vector
<
std
::
string
>
split_utf8_str
(
const
std
::
string
&
str
)
{
{
std
::
vector
<
std
::
string
>
result
;
std
::
vector
<
std
::
string
>
result
;
std
::
string
out_str
;
std
::
string
out_str
;
for
(
char
c
:
str
)
for
(
char
c
:
str
)
{
if
((
c
&
0xc0
)
!=
0x80
)
// new UTF-8 character
{
{
if
((
c
&
0xc0
)
!=
0x80
)
//new UTF-8 character
if
(
!
out_str
.
empty
())
{
{
result
.
push_back
(
out_str
);
if
(
!
out_str
.
empty
())
out_str
.
clear
();
{
}
result
.
push_back
(
out_str
);
out_str
.
clear
();
}
}
out_str
.
append
(
1
,
c
);
}
}
out_str
.
append
(
1
,
c
);
}
result
.
push_back
(
out_str
);
result
.
push_back
(
out_str
);
return
result
;
return
result
;
}
}
std
::
vector
<
std
::
string
>
split_str
(
const
std
::
string
&
s
,
std
::
vector
<
std
::
string
>
split_str
(
const
std
::
string
&
s
,
const
std
::
string
&
delim
)
{
const
std
::
string
&
delim
)
{
std
::
vector
<
std
::
string
>
result
;
std
::
vector
<
std
::
string
>
result
;
std
::
size_t
start
=
0
,
delim_len
=
delim
.
size
();
std
::
size_t
start
=
0
,
delim_len
=
delim
.
size
();
while
(
true
)
{
while
(
true
)
{
std
::
size_t
end
=
s
.
find
(
delim
,
start
);
std
::
size_t
end
=
s
.
find
(
delim
,
start
);
if
(
end
==
std
::
string
::
npos
)
{
if
(
end
==
std
::
string
::
npos
)
{
if
(
start
<
s
.
size
())
{
if
(
start
<
s
.
size
())
{
result
.
push_back
(
s
.
substr
(
start
));
result
.
push_back
(
s
.
substr
(
start
));
}
}
break
;
break
;
}
}
if
(
end
>
start
)
{
if
(
end
>
start
)
{
result
.
push_back
(
s
.
substr
(
start
,
end
-
start
));
result
.
push_back
(
s
.
substr
(
start
,
end
-
start
));
}
start
=
end
+
delim_len
;
}
}
return
result
;
start
=
end
+
delim_len
;
}
return
result
;
}
}
bool
prefix_compare
(
const
PathTrie
*
x
,
const
PathTrie
*
y
)
{
bool
prefix_compare
(
const
PathTrie
*
x
,
const
PathTrie
*
y
)
{
if
(
x
->
_score
==
y
->
_score
)
{
if
(
x
->
score
==
y
->
score
)
{
if
(
x
->
_character
==
y
->
_character
)
{
if
(
x
->
character
==
y
->
character
)
{
return
false
;
return
false
;
}
else
{
return
(
x
->
_character
<
y
->
_character
);
}
}
else
{
}
else
{
return
x
->
_score
>
y
->
_score
;
return
(
x
->
character
<
y
->
character
)
;
}
}
}
else
{
return
x
->
score
>
y
->
score
;
}
}
}
void
add_word_to_fst
(
const
std
::
vector
<
int
>&
word
,
void
add_word_to_fst
(
const
std
::
vector
<
int
>&
word
,
fst
::
StdVectorFst
*
dictionary
)
{
fst
::
StdVectorFst
*
dictionary
)
{
if
(
dictionary
->
NumStates
()
==
0
)
{
if
(
dictionary
->
NumStates
()
==
0
)
{
fst
::
StdVectorFst
::
StateId
start
=
dictionary
->
AddState
();
fst
::
StdVectorFst
::
StateId
start
=
dictionary
->
AddState
();
assert
(
start
==
0
);
assert
(
start
==
0
);
dictionary
->
SetStart
(
start
);
dictionary
->
SetStart
(
start
);
}
}
fst
::
StdVectorFst
::
StateId
src
=
dictionary
->
Start
();
fst
::
StdVectorFst
::
StateId
src
=
dictionary
->
Start
();
fst
::
StdVectorFst
::
StateId
dst
;
fst
::
StdVectorFst
::
StateId
dst
;
for
(
auto
c
:
word
)
{
for
(
auto
c
:
word
)
{
dst
=
dictionary
->
AddState
();
dst
=
dictionary
->
AddState
();
dictionary
->
AddArc
(
src
,
fst
::
StdArc
(
c
,
c
,
0
,
dst
));
dictionary
->
AddArc
(
src
,
fst
::
StdArc
(
c
,
c
,
0
,
dst
));
src
=
dst
;
src
=
dst
;
}
}
dictionary
->
SetFinal
(
dst
,
fst
::
StdArc
::
Weight
::
One
());
dictionary
->
SetFinal
(
dst
,
fst
::
StdArc
::
Weight
::
One
());
}
}
bool
add_word_to_dictionary
(
const
std
::
string
&
word
,
bool
add_word_to_dictionary
(
const
std
::
unordered_map
<
std
::
string
,
int
>&
char_map
,
const
std
::
string
&
word
,
bool
add_space
,
const
std
::
unordered_map
<
std
::
string
,
int
>&
char_map
,
int
SPACE_ID
,
bool
add_space
,
fst
::
StdVectorFst
*
dictionary
)
{
int
SPACE_ID
,
auto
characters
=
split_utf8_str
(
word
);
fst
::
StdVectorFst
*
dictionary
)
{
auto
characters
=
split_utf8_str
(
word
);
std
::
vector
<
int
>
int_word
;
std
::
vector
<
int
>
int_word
;
for
(
auto
&
c
:
characters
)
{
for
(
auto
&
c
:
characters
)
{
if
(
c
==
" "
)
{
if
(
c
==
" "
)
{
int_word
.
push_back
(
SPACE_ID
);
int_word
.
push_back
(
SPACE_ID
);
}
else
{
}
else
{
auto
int_c
=
char_map
.
find
(
c
);
auto
int_c
=
char_map
.
find
(
c
);
if
(
int_c
!=
char_map
.
end
())
{
if
(
int_c
!=
char_map
.
end
())
{
int_word
.
push_back
(
int_c
->
second
);
int_word
.
push_back
(
int_c
->
second
);
}
else
{
}
else
{
return
false
;
// return without adding
return
false
;
// return without adding
}
}
}
}
}
}
if
(
add_space
)
{
if
(
add_space
)
{
int_word
.
push_back
(
SPACE_ID
);
int_word
.
push_back
(
SPACE_ID
);
}
}
add_word_to_fst
(
int_word
,
dictionary
);
add_word_to_fst
(
int_word
,
dictionary
);
return
true
;
return
true
;
}
}
deploy/decoder_utils.h
浏览文件 @
5208b8e4
...
@@ -10,34 +10,31 @@ const float NUM_FLT_MIN = std::numeric_limits<float>::min();
...
@@ -10,34 +10,31 @@ const float NUM_FLT_MIN = std::numeric_limits<float>::min();
// Function template for comparing two pairs
// Function template for comparing two pairs
template
<
typename
T1
,
typename
T2
>
template
<
typename
T1
,
typename
T2
>
bool
pair_comp_first_rev
(
const
std
::
pair
<
T1
,
T2
>
&
a
,
bool
pair_comp_first_rev
(
const
std
::
pair
<
T1
,
T2
>
&
a
,
const
std
::
pair
<
T1
,
T2
>
&
b
)
const
std
::
pair
<
T1
,
T2
>
&
b
)
{
{
return
a
.
first
>
b
.
first
;
return
a
.
first
>
b
.
first
;
}
}
template
<
typename
T1
,
typename
T2
>
template
<
typename
T1
,
typename
T2
>
bool
pair_comp_second_rev
(
const
std
::
pair
<
T1
,
T2
>
&
a
,
bool
pair_comp_second_rev
(
const
std
::
pair
<
T1
,
T2
>
&
a
,
const
std
::
pair
<
T1
,
T2
>
&
b
)
const
std
::
pair
<
T1
,
T2
>
&
b
)
{
{
return
a
.
second
>
b
.
second
;
return
a
.
second
>
b
.
second
;
}
}
template
<
typename
T
>
template
<
typename
T
>
T
log_sum_exp
(
const
T
&
x
,
const
T
&
y
)
T
log_sum_exp
(
const
T
&
x
,
const
T
&
y
)
{
{
static
T
num_min
=
-
std
::
numeric_limits
<
T
>::
max
();
static
T
num_min
=
-
std
::
numeric_limits
<
T
>::
max
();
if
(
x
<=
num_min
)
return
y
;
if
(
x
<=
num_min
)
return
y
;
if
(
y
<=
num_min
)
return
x
;
if
(
y
<=
num_min
)
return
x
;
T
xmax
=
std
::
max
(
x
,
y
);
T
xmax
=
std
::
max
(
x
,
y
);
return
std
::
log
(
std
::
exp
(
x
-
xmax
)
+
std
::
exp
(
y
-
xmax
))
+
xmax
;
return
std
::
log
(
std
::
exp
(
x
-
xmax
)
+
std
::
exp
(
y
-
xmax
))
+
xmax
;
}
}
// Functor for prefix comparsion
// Functor for prefix comparsion
bool
prefix_compare
(
const
PathTrie
*
x
,
const
PathTrie
*
y
);
bool
prefix_compare
(
const
PathTrie
*
x
,
const
PathTrie
*
y
);
// Get length of utf8 encoding string
// Get length of utf8 encoding string
// See: http://stackoverflow.com/a/4063229
// See: http://stackoverflow.com/a/4063229
size_t
get_utf8_str_len
(
const
std
::
string
&
str
);
size_t
get_utf8_str_len
(
const
std
::
string
&
str
);
// Split a string into a list of strings on a given string
// Split a string into a list of strings on a given string
// delimiter. NB: delimiters on beginning / end of string are
// delimiter. NB: delimiters on beginning / end of string are
...
@@ -50,13 +47,14 @@ std::vector<std::string> split_str(const std::string &s,
...
@@ -50,13 +47,14 @@ std::vector<std::string> split_str(const std::string &s,
std
::
vector
<
std
::
string
>
split_utf8_str
(
const
std
::
string
&
str
);
std
::
vector
<
std
::
string
>
split_utf8_str
(
const
std
::
string
&
str
);
// Add a word in index to the dicionary of fst
// Add a word in index to the dicionary of fst
void
add_word_to_fst
(
const
std
::
vector
<
int
>
&
word
,
void
add_word_to_fst
(
const
std
::
vector
<
int
>
&
word
,
fst
::
StdVectorFst
*
dictionary
);
fst
::
StdVectorFst
*
dictionary
);
// Add a word in string to dictionary
// Add a word in string to dictionary
bool
add_word_to_dictionary
(
const
std
::
string
&
word
,
bool
add_word_to_dictionary
(
const
std
::
unordered_map
<
std
::
string
,
int
>&
char_map
,
const
std
::
string
&
word
,
bool
add_space
,
const
std
::
unordered_map
<
std
::
string
,
int
>
&
char_map
,
int
SPACE_ID
,
bool
add_space
,
fst
::
StdVectorFst
*
dictionary
);
int
SPACE_ID
,
#endif // DECODER_UTILS_H
fst
::
StdVectorFst
*
dictionary
);
#endif // DECODER_UTILS_H
deploy/path_trie.cpp
浏览文件 @
5208b8e4
...
@@ -4,145 +4,142 @@
...
@@ -4,145 +4,142 @@
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#include "path_trie.h"
#include "decoder_utils.h"
#include "decoder_utils.h"
#include "path_trie.h"
PathTrie
::
PathTrie
()
{
PathTrie
::
PathTrie
()
{
_
log_prob_b_prev
=
-
NUM_FLT_INF
;
log_prob_b_prev
=
-
NUM_FLT_INF
;
_
log_prob_nb_prev
=
-
NUM_FLT_INF
;
log_prob_nb_prev
=
-
NUM_FLT_INF
;
_
log_prob_b_cur
=
-
NUM_FLT_INF
;
log_prob_b_cur
=
-
NUM_FLT_INF
;
_
log_prob_nb_cur
=
-
NUM_FLT_INF
;
log_prob_nb_cur
=
-
NUM_FLT_INF
;
_
score
=
-
NUM_FLT_INF
;
score
=
-
NUM_FLT_INF
;
_ROOT
=
-
1
;
_ROOT
=
-
1
;
_
character
=
_ROOT
;
character
=
_ROOT
;
_exists
=
true
;
_exists
=
true
;
_
parent
=
nullptr
;
parent
=
nullptr
;
_dictionary
=
nullptr
;
_dictionary
=
nullptr
;
_dictionary_state
=
0
;
_dictionary_state
=
0
;
_has_dictionary
=
false
;
_has_dictionary
=
false
;
_matcher
=
nullptr
;
// finds arcs in FST
_matcher
=
nullptr
;
// finds arcs in FST
}
}
PathTrie
::~
PathTrie
()
{
PathTrie
::~
PathTrie
()
{
for
(
auto
child
:
_children
)
{
for
(
auto
child
:
_children
)
{
delete
child
.
second
;
delete
child
.
second
;
}
}
}
}
PathTrie
*
PathTrie
::
get_path_trie
(
int
new_char
,
bool
reset
)
{
PathTrie
*
PathTrie
::
get_path_trie
(
int
new_char
,
bool
reset
)
{
auto
child
=
_children
.
begin
();
auto
child
=
_children
.
begin
();
for
(
child
=
_children
.
begin
();
child
!=
_children
.
end
();
++
child
)
{
for
(
child
=
_children
.
begin
();
child
!=
_children
.
end
();
++
child
)
{
if
(
child
->
first
==
new_char
)
{
if
(
child
->
first
==
new_char
)
{
break
;
break
;
}
}
}
if
(
child
!=
_children
.
end
()
)
{
}
if
(
!
child
->
second
->
_exists
)
{
if
(
child
!=
_children
.
end
())
{
child
->
second
->
_exists
=
true
;
if
(
!
child
->
second
->
_exists
)
{
child
->
second
->
_log_prob_b_prev
=
-
NUM_FLT_INF
;
child
->
second
->
_exists
=
true
;
child
->
second
->
_log_prob_nb_prev
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_b_prev
=
-
NUM_FLT_INF
;
child
->
second
->
_log_prob_b_cur
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_nb_prev
=
-
NUM_FLT_INF
;
child
->
second
->
_log_prob_nb_cur
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_b_cur
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_nb_cur
=
-
NUM_FLT_INF
;
}
return
(
child
->
second
);
}
else
{
if
(
_has_dictionary
)
{
_matcher
->
SetState
(
_dictionary_state
);
bool
found
=
_matcher
->
Find
(
new_char
);
if
(
!
found
)
{
// Adding this character causes word outside dictionary
auto
FSTZERO
=
fst
::
TropicalWeight
::
Zero
();
auto
final_weight
=
_dictionary
->
Final
(
_dictionary_state
);
bool
is_final
=
(
final_weight
!=
FSTZERO
);
if
(
is_final
&&
reset
)
{
_dictionary_state
=
_dictionary
->
Start
();
}
}
return
(
child
->
second
);
return
nullptr
;
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
new_path
->
character
=
new_char
;
new_path
->
parent
=
this
;
new_path
->
_dictionary
=
_dictionary
;
new_path
->
_dictionary_state
=
_matcher
->
Value
().
nextstate
;
new_path
->
_has_dictionary
=
true
;
new_path
->
_matcher
=
_matcher
;
_children
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
return
new_path
;
}
}
else
{
}
else
{
if
(
_has_dictionary
)
{
PathTrie
*
new_path
=
new
PathTrie
;
_matcher
->
SetState
(
_dictionary_state
);
new_path
->
character
=
new_char
;
bool
found
=
_matcher
->
Find
(
new_char
);
new_path
->
parent
=
this
;
if
(
!
found
)
{
_children
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
// Adding this character causes word outside dictionary
return
new_path
;
auto
FSTZERO
=
fst
::
TropicalWeight
::
Zero
();
auto
final_weight
=
_dictionary
->
Final
(
_dictionary_state
);
bool
is_final
=
(
final_weight
!=
FSTZERO
);
if
(
is_final
&&
reset
)
{
_dictionary_state
=
_dictionary
->
Start
();
}
return
nullptr
;
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
new_path
->
_character
=
new_char
;
new_path
->
_parent
=
this
;
new_path
->
_dictionary
=
_dictionary
;
new_path
->
_dictionary_state
=
_matcher
->
Value
().
nextstate
;
new_path
->
_has_dictionary
=
true
;
new_path
->
_matcher
=
_matcher
;
_children
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
return
new_path
;
}
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
new_path
->
_character
=
new_char
;
new_path
->
_parent
=
this
;
_children
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
return
new_path
;
}
}
}
}
}
}
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
)
{
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
)
{
return
get_path_vec
(
output
,
_ROOT
);
return
get_path_vec
(
output
,
_ROOT
);
}
}
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
,
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
,
int
stop
,
int
stop
,
size_t
max_steps
)
{
size_t
max_steps
)
{
if
(
_character
==
stop
||
if
(
character
==
stop
||
character
==
_ROOT
||
output
.
size
()
==
max_steps
)
{
_character
==
_ROOT
||
std
::
reverse
(
output
.
begin
(),
output
.
end
());
output
.
size
()
==
max_steps
)
{
return
this
;
std
::
reverse
(
output
.
begin
(),
output
.
end
());
}
else
{
return
this
;
output
.
push_back
(
character
);
}
else
{
return
parent
->
get_path_vec
(
output
,
stop
,
max_steps
);
output
.
push_back
(
_character
);
}
return
_parent
->
get_path_vec
(
output
,
stop
,
max_steps
);
}
}
}
void
PathTrie
::
iterate_to_vec
(
void
PathTrie
::
iterate_to_vec
(
std
::
vector
<
PathTrie
*>&
output
)
{
std
::
vector
<
PathTrie
*>&
output
)
{
if
(
_exists
)
{
if
(
_exists
)
{
log_prob_b_prev
=
log_prob_b_cur
;
_log_prob_b_prev
=
_log_prob_b_cur
;
log_prob_nb_prev
=
log_prob_nb_cur
;
_log_prob_nb_prev
=
_log_prob_nb_cur
;
_
log_prob_b_cur
=
-
NUM_FLT_INF
;
log_prob_b_cur
=
-
NUM_FLT_INF
;
_
log_prob_nb_cur
=
-
NUM_FLT_INF
;
log_prob_nb_cur
=
-
NUM_FLT_INF
;
_score
=
log_sum_exp
(
_log_prob_b_prev
,
_
log_prob_nb_prev
);
score
=
log_sum_exp
(
log_prob_b_prev
,
log_prob_nb_prev
);
output
.
push_back
(
this
);
output
.
push_back
(
this
);
}
}
for
(
auto
child
:
_children
)
{
for
(
auto
child
:
_children
)
{
child
.
second
->
iterate_to_vec
(
output
);
child
.
second
->
iterate_to_vec
(
output
);
}
}
}
}
void
PathTrie
::
remove
()
{
void
PathTrie
::
remove
()
{
_exists
=
false
;
_exists
=
false
;
if
(
_children
.
size
()
==
0
)
{
if
(
_children
.
size
()
==
0
)
{
auto
child
=
_parent
->
_children
.
begin
();
auto
child
=
parent
->
_children
.
begin
();
for
(
child
=
_parent
->
_children
.
begin
();
for
(
child
=
parent
->
_children
.
begin
();
child
!=
parent
->
_children
.
end
();
child
!=
_parent
->
_children
.
end
();
++
child
)
{
++
child
)
{
if
(
child
->
first
==
_character
)
{
if
(
child
->
first
==
character
)
{
_parent
->
_children
.
erase
(
child
);
parent
->
_children
.
erase
(
child
);
break
;
break
;
}
}
}
}
if
(
_parent
->
_children
.
size
()
==
0
&&
!
_parent
->
_exists
)
{
_parent
->
remove
();
}
delete
this
;
if
(
parent
->
_children
.
size
()
==
0
&&
!
parent
->
_exists
)
{
parent
->
remove
();
}
}
delete
this
;
}
}
}
void
PathTrie
::
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
)
{
void
PathTrie
::
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
)
{
_dictionary
=
dictionary
;
_dictionary
=
dictionary
;
_dictionary_state
=
dictionary
->
Start
();
_dictionary_state
=
dictionary
->
Start
();
_has_dictionary
=
true
;
_has_dictionary
=
true
;
}
}
using
FSTMATCH
=
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>
;
using
FSTMATCH
=
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>
;
void
PathTrie
::
set_matcher
(
std
::
shared_ptr
<
FSTMATCH
>
matcher
)
{
void
PathTrie
::
set_matcher
(
std
::
shared_ptr
<
FSTMATCH
>
matcher
)
{
_matcher
=
matcher
;
_matcher
=
matcher
;
}
}
deploy/path_trie.h
浏览文件 @
5208b8e4
#ifndef PATH_TRIE_H
#ifndef PATH_TRIE_H
#define PATH_TRIE_H
#define PATH_TRIE_H
#pragma once
#pragma once
#include <fst/fstlib.h>
#include <algorithm>
#include <algorithm>
#include <limits>
#include <limits>
#include <memory>
#include <memory>
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#include <fst/fstlib.h>
using
FSTMATCH
=
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>
;
using
FSTMATCH
=
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>
;
class
PathTrie
{
class
PathTrie
{
public:
public:
PathTrie
();
PathTrie
();
~
PathTrie
();
~
PathTrie
();
PathTrie
*
get_path_trie
(
int
new_char
,
bool
reset
=
true
);
PathTrie
*
get_path_vec
(
std
::
vector
<
int
>
&
output
);
PathTrie
*
get_path_trie
(
int
new_char
,
bool
reset
=
true
);
PathTrie
*
get_path_vec
(
std
::
vector
<
int
>&
output
,
PathTrie
*
get_path_vec
(
std
::
vector
<
int
>&
output
);
int
stop
,
size_t
max_steps
=
std
::
numeric_limits
<
size_t
>::
max
());
void
iterate_to_vec
(
std
::
vector
<
PathTrie
*>
&
output
);
PathTrie
*
get_path_vec
(
std
::
vector
<
int
>&
output
,
int
stop
,
size_t
max_steps
=
std
::
numeric_limits
<
size_t
>::
max
());
void
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
);
void
iterate_to_vec
(
std
::
vector
<
PathTrie
*>&
output
);
void
set_matcher
(
std
::
shared_ptr
<
FSTMATCH
>
matcher
);
void
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
);
bool
is_empty
()
{
void
set_matcher
(
std
::
shared_ptr
<
FSTMATCH
>
matcher
);
return
_ROOT
==
_character
;
}
void
remove
();
bool
is_empty
()
{
return
_ROOT
==
character
;
}
float
_log_prob_b_prev
;
void
remove
();
float
_log_prob_nb_prev
;
float
_log_prob_b_cur
;
float
_log_prob_nb_cur
;
float
_score
;
float
_approx_ctc
;
float
log_prob_b_prev
;
float
log_prob_nb_prev
;
float
log_prob_b_cur
;
float
log_prob_nb_cur
;
float
score
;
float
approx_ctc
;
int
character
;
PathTrie
*
parent
;
int
_ROOT
;
private:
int
_character
;
int
_ROOT
;
bool
_exists
;
bool
_exists
;
PathTrie
*
_parent
;
std
::
vector
<
std
::
pair
<
int
,
PathTrie
*>>
_children
;
std
::
vector
<
std
::
pair
<
int
,
PathTrie
*>
>
_children
;
fst
::
StdVectorFst
*
_dictionary
;
fst
::
StdVectorFst
*
_dictionary
;
fst
::
StdVectorFst
::
StateId
_dictionary_state
;
fst
::
StdVectorFst
::
StateId
_dictionary_state
;
bool
_has_dictionary
;
bool
_has_dictionary
;
std
::
shared_ptr
<
FSTMATCH
>
_matcher
;
std
::
shared_ptr
<
FSTMATCH
>
_matcher
;
};
};
#endif // PATH_TRIE_H
#endif
// PATH_TRIE_H
deploy/scorer.cpp
浏览文件 @
5208b8e4
#include
<iostream>
#include
"scorer.h"
#include <unistd.h>
#include <unistd.h>
#include <iostream>
#include "decoder_utils.h"
#include "lm/config.hh"
#include "lm/config.hh"
#include "lm/state.hh"
#include "lm/model.hh"
#include "lm/model.hh"
#include "
util/tokenize_piec
e.hh"
#include "
lm/stat
e.hh"
#include "util/string_piece.hh"
#include "util/string_piece.hh"
#include "scorer.h"
#include "util/tokenize_piece.hh"
#include "decoder_utils.h"
using
namespace
lm
::
ngram
;
using
namespace
lm
::
ngram
;
Scorer
::
Scorer
(
double
alpha
,
double
beta
,
const
std
::
string
&
lm_path
)
{
Scorer
::
Scorer
(
double
alpha
,
double
beta
,
const
std
::
string
&
lm_path
)
{
this
->
alpha
=
alpha
;
this
->
alpha
=
alpha
;
this
->
beta
=
beta
;
this
->
beta
=
beta
;
_is_character_based
=
true
;
_is_character_based
=
true
;
_language_model
=
nullptr
;
_language_model
=
nullptr
;
dictionary
=
nullptr
;
dictionary
=
nullptr
;
_max_order
=
0
;
_max_order
=
0
;
_SPACE_ID
=
-
1
;
_SPACE_ID
=
-
1
;
// load language model
// load language model
load_LM
(
lm_path
.
c_str
());
load_LM
(
lm_path
.
c_str
());
}
}
Scorer
::~
Scorer
()
{
Scorer
::~
Scorer
()
{
if
(
_language_model
!=
nullptr
)
if
(
_language_model
!=
nullptr
)
delete
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
);
delete
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
);
if
(
dictionary
!=
nullptr
)
if
(
dictionary
!=
nullptr
)
delete
static_cast
<
fst
::
StdVectorFst
*>
(
dictionary
);
delete
static_cast
<
fst
::
StdVectorFst
*>
(
dictionary
);
}
}
void
Scorer
::
load_LM
(
const
char
*
filename
)
{
void
Scorer
::
load_LM
(
const
char
*
filename
)
{
if
(
access
(
filename
,
F_OK
)
!=
0
)
{
if
(
access
(
filename
,
F_OK
)
!=
0
)
{
std
::
cerr
<<
"Invalid language model file !!!"
<<
std
::
endl
;
std
::
cerr
<<
"Invalid language model file !!!"
<<
std
::
endl
;
exit
(
1
);
exit
(
1
);
}
}
RetriveStrEnumerateVocab
enumerate
;
RetriveStrEnumerateVocab
enumerate
;
lm
::
ngram
::
Config
config
;
lm
::
ngram
::
Config
config
;
config
.
enumerate_vocab
=
&
enumerate
;
config
.
enumerate_vocab
=
&
enumerate
;
_language_model
=
lm
::
ngram
::
LoadVirtual
(
filename
,
config
);
_language_model
=
lm
::
ngram
::
LoadVirtual
(
filename
,
config
);
_max_order
=
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
)
->
Order
();
_max_order
=
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
)
->
Order
();
_vocabulary
=
enumerate
.
vocabulary
;
_vocabulary
=
enumerate
.
vocabulary
;
for
(
size_t
i
=
0
;
i
<
_vocabulary
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
_vocabulary
.
size
();
++
i
)
{
if
(
_is_character_based
if
(
_is_character_based
&&
_vocabulary
[
i
]
!=
UNK_TOKEN
&&
&&
_vocabulary
[
i
]
!=
UNK_TOKEN
_vocabulary
[
i
]
!=
START_TOKEN
&&
_vocabulary
[
i
]
!=
END_TOKEN
&&
&&
_vocabulary
[
i
]
!=
START_TOKEN
get_utf8_str_len
(
enumerate
.
vocabulary
[
i
])
>
1
)
{
&&
_vocabulary
[
i
]
!=
END_TOKEN
_is_character_based
=
false
;
&&
get_utf8_str_len
(
enumerate
.
vocabulary
[
i
])
>
1
)
{
_is_character_based
=
false
;
}
}
}
}
}
}
double
Scorer
::
get_log_cond_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
double
Scorer
::
get_log_cond_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
lm
::
base
::
Model
*
model
=
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
);
lm
::
base
::
Model
*
model
=
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
);
double
cond_prob
;
double
cond_prob
;
lm
::
ngram
::
State
state
,
tmp_state
,
out_state
;
lm
::
ngram
::
State
state
,
tmp_state
,
out_state
;
// avoid to inserting <s> in begin
// avoid to inserting <s> in begin
model
->
NullContextWrite
(
&
state
);
model
->
NullContextWrite
(
&
state
);
for
(
size_t
i
=
0
;
i
<
words
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
words
.
size
();
++
i
)
{
lm
::
WordIndex
word_index
=
model
->
BaseVocabulary
().
Index
(
words
[
i
]);
lm
::
WordIndex
word_index
=
model
->
BaseVocabulary
().
Index
(
words
[
i
]);
// encounter OOV
// encounter OOV
if
(
word_index
==
0
)
{
if
(
word_index
==
0
)
{
return
OOV_SCORE
;
return
OOV_SCORE
;
}
cond_prob
=
model
->
BaseScore
(
&
state
,
word_index
,
&
out_state
);
tmp_state
=
state
;
state
=
out_state
;
out_state
=
tmp_state
;
}
}
// log10 prob
cond_prob
=
model
->
BaseScore
(
&
state
,
word_index
,
&
out_state
);
return
cond_prob
;
tmp_state
=
state
;
state
=
out_state
;
out_state
=
tmp_state
;
}
// log10 prob
return
cond_prob
;
}
}
double
Scorer
::
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
double
Scorer
::
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
std
::
vector
<
std
::
string
>
sentence
;
std
::
vector
<
std
::
string
>
sentence
;
if
(
words
.
size
()
==
0
)
{
if
(
words
.
size
()
==
0
)
{
for
(
size_t
i
=
0
;
i
<
_max_order
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
_max_order
;
++
i
)
{
sentence
.
push_back
(
START_TOKEN
);
sentence
.
push_back
(
START_TOKEN
);
}
}
else
{
for
(
size_t
i
=
0
;
i
<
_max_order
-
1
;
++
i
)
{
sentence
.
push_back
(
START_TOKEN
);
}
sentence
.
insert
(
sentence
.
end
(),
words
.
begin
(),
words
.
end
());
}
}
sentence
.
push_back
(
END_TOKEN
);
}
else
{
return
get_log_prob
(
sentence
);
for
(
size_t
i
=
0
;
i
<
_max_order
-
1
;
++
i
)
{
sentence
.
push_back
(
START_TOKEN
);
}
sentence
.
insert
(
sentence
.
end
(),
words
.
begin
(),
words
.
end
());
}
sentence
.
push_back
(
END_TOKEN
);
return
get_log_prob
(
sentence
);
}
}
double
Scorer
::
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
double
Scorer
::
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
assert
(
words
.
size
()
>
_max_order
);
assert
(
words
.
size
()
>
_max_order
);
double
score
=
0.0
;
double
score
=
0.0
;
for
(
size_t
i
=
0
;
i
<
words
.
size
()
-
_max_order
+
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
words
.
size
()
-
_max_order
+
1
;
++
i
)
{
std
::
vector
<
std
::
string
>
ngram
(
words
.
begin
()
+
i
,
std
::
vector
<
std
::
string
>
ngram
(
words
.
begin
()
+
i
,
words
.
begin
()
+
i
+
_max_order
);
words
.
begin
()
+
i
+
_max_order
);
score
+=
get_log_cond_prob
(
ngram
);
score
+=
get_log_cond_prob
(
ngram
);
}
}
return
score
;
return
score
;
}
}
void
Scorer
::
reset_params
(
float
alpha
,
float
beta
)
{
void
Scorer
::
reset_params
(
float
alpha
,
float
beta
)
{
this
->
alpha
=
alpha
;
this
->
alpha
=
alpha
;
this
->
beta
=
beta
;
this
->
beta
=
beta
;
}
}
std
::
string
Scorer
::
vec2str
(
const
std
::
vector
<
int
>&
input
)
{
std
::
string
Scorer
::
vec2str
(
const
std
::
vector
<
int
>&
input
)
{
std
::
string
word
;
std
::
string
word
;
for
(
auto
ind
:
input
)
{
for
(
auto
ind
:
input
)
{
word
+=
_char_list
[
ind
];
word
+=
_char_list
[
ind
];
}
}
return
word
;
return
word
;
}
}
std
::
vector
<
std
::
string
>
std
::
vector
<
std
::
string
>
Scorer
::
split_labels
(
const
std
::
vector
<
int
>&
labels
)
{
Scorer
::
split_labels
(
const
std
::
vector
<
int
>
&
labels
)
{
if
(
labels
.
empty
())
return
{};
if
(
labels
.
empty
())
return
{};
std
::
string
s
=
vec2str
(
labels
);
std
::
vector
<
std
::
string
>
words
;
std
::
string
s
=
vec2str
(
labels
);
if
(
_is_character_based
)
{
std
::
vector
<
std
::
string
>
words
;
words
=
split_utf8_str
(
s
);
if
(
_is_character_based
)
{
}
else
{
words
=
split_utf8_str
(
s
);
words
=
split_str
(
s
,
" "
);
}
else
{
}
words
=
split_str
(
s
,
" "
);
return
words
;
}
return
words
;
}
}
void
Scorer
::
set_char_map
(
std
::
vector
<
std
::
string
>
char_list
)
{
void
Scorer
::
set_char_map
(
std
::
vector
<
std
::
string
>
char_list
)
{
_char_list
=
char_list
;
_char_list
=
char_list
;
_char_map
.
clear
();
_char_map
.
clear
();
for
(
unsigned
int
i
=
0
;
i
<
_char_list
.
size
();
i
++
)
for
(
unsigned
int
i
=
0
;
i
<
_char_list
.
size
();
i
++
)
{
{
if
(
_char_list
[
i
]
==
" "
)
{
if
(
_char_list
[
i
]
==
" "
)
{
_SPACE_ID
=
i
;
_SPACE_ID
=
i
;
_char_map
[
' '
]
=
i
;
_char_map
[
' '
]
=
i
;
}
else
if
(
_char_list
[
i
].
size
()
==
1
)
{
}
else
if
(
_char_list
[
i
].
size
()
==
1
){
_char_map
[
_char_list
[
i
][
0
]]
=
i
;
_char_map
[
_char_list
[
i
][
0
]]
=
i
;
}
}
}
}
}
}
std
::
vector
<
std
::
string
>
Scorer
::
make_ngram
(
PathTrie
*
prefix
)
{
std
::
vector
<
std
::
string
>
Scorer
::
make_ngram
(
PathTrie
*
prefix
)
{
std
::
vector
<
std
::
string
>
ngram
;
std
::
vector
<
std
::
string
>
ngram
;
PathTrie
*
current_node
=
prefix
;
PathTrie
*
current_node
=
prefix
;
PathTrie
*
new_node
=
nullptr
;
PathTrie
*
new_node
=
nullptr
;
for
(
int
order
=
0
;
order
<
_max_order
;
order
++
)
{
std
::
vector
<
int
>
prefix_vec
;
if
(
_is_character_based
)
{
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
_SPACE_ID
,
1
);
current_node
=
new_node
;
}
else
{
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
_SPACE_ID
);
current_node
=
new_node
->
_parent
;
// Skipping spaces
}
// reconstruct word
std
::
string
word
=
vec2str
(
prefix_vec
);
ngram
.
push_back
(
word
);
if
(
new_node
->
_character
==
-
1
)
{
// No more spaces, but still need order
for
(
int
i
=
0
;
i
<
_max_order
-
order
-
1
;
i
++
)
{
ngram
.
push_back
(
START_TOKEN
);
}
break
;
}
}
std
::
reverse
(
ngram
.
begin
(),
ngram
.
end
());
return
ngram
;
}
void
Scorer
::
fill_dictionary
(
bool
add_space
)
{
fst
::
StdVectorFst
dictionary
;
for
(
int
order
=
0
;
order
<
_max_order
;
order
++
)
{
// First reverse char_list so ints can be accessed by chars
std
::
vector
<
int
>
prefix_vec
;
std
::
unordered_map
<
std
::
string
,
int
>
char_map
;
for
(
unsigned
int
i
=
0
;
i
<
_char_list
.
size
();
i
++
)
{
char_map
[
_char_list
[
i
]]
=
i
;
}
// For each unigram convert to ints and put in trie
if
(
_is_character_based
)
{
int
vocab_size
=
0
;
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
_SPACE_ID
,
1
);
for
(
const
auto
&
word
:
_vocabulary
)
{
current_node
=
new_node
;
bool
added
=
add_word_to_dictionary
(
word
,
}
else
{
char_map
,
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
_SPACE_ID
);
add_space
,
current_node
=
new_node
->
parent
;
// Skipping spaces
_SPACE_ID
,
&
dictionary
);
vocab_size
+=
added
?
1
:
0
;
}
}
std
::
cerr
<<
"Vocab Size "
<<
vocab_size
<<
std
::
endl
;
// reconstruct word
std
::
string
word
=
vec2str
(
prefix_vec
);
// Simplify FST
ngram
.
push_back
(
word
);
// This gets rid of "epsilon" transitions in the FST.
// These are transitions that don't require a string input to be taken.
// Getting rid of them is necessary to make the FST determinisitc, but
// can greatly increase the size of the FST
fst
::
RmEpsilon
(
&
dictionary
);
fst
::
StdVectorFst
*
new_dict
=
new
fst
::
StdVectorFst
;
// This makes the FST deterministic, meaning for any string input there's
if
(
new_node
->
character
==
-
1
)
{
// only one possible state the FST could be in. It is assumed our
// No more spaces, but still need order
// dictionary is deterministic when using it.
for
(
int
i
=
0
;
i
<
_max_order
-
order
-
1
;
i
++
)
{
// (lest we'd have to check for multiple transitions at each state)
ngram
.
push_back
(
START_TOKEN
);
fst
::
Determinize
(
dictionary
,
new_dict
);
}
break
;
// Finds the simplest equivalent fst. This is unnecessary but decreases
}
// memory usage of the dictionary
}
fst
::
Minimize
(
new_dict
);
std
::
reverse
(
ngram
.
begin
(),
ngram
.
end
());
this
->
dictionary
=
new_dict
;
return
ngram
;
}
void
Scorer
::
fill_dictionary
(
bool
add_space
)
{
fst
::
StdVectorFst
dictionary
;
// First reverse char_list so ints can be accessed by chars
std
::
unordered_map
<
std
::
string
,
int
>
char_map
;
for
(
unsigned
int
i
=
0
;
i
<
_char_list
.
size
();
i
++
)
{
char_map
[
_char_list
[
i
]]
=
i
;
}
// For each unigram convert to ints and put in trie
int
vocab_size
=
0
;
for
(
const
auto
&
word
:
_vocabulary
)
{
bool
added
=
add_word_to_dictionary
(
word
,
char_map
,
add_space
,
_SPACE_ID
,
&
dictionary
);
vocab_size
+=
added
?
1
:
0
;
}
std
::
cerr
<<
"Vocab Size "
<<
vocab_size
<<
std
::
endl
;
// Simplify FST
// This gets rid of "epsilon" transitions in the FST.
// These are transitions that don't require a string input to be taken.
// Getting rid of them is necessary to make the FST determinisitc, but
// can greatly increase the size of the FST
fst
::
RmEpsilon
(
&
dictionary
);
fst
::
StdVectorFst
*
new_dict
=
new
fst
::
StdVectorFst
;
// This makes the FST deterministic, meaning for any string input there's
// only one possible state the FST could be in. It is assumed our
// dictionary is deterministic when using it.
// (lest we'd have to check for multiple transitions at each state)
fst
::
Determinize
(
dictionary
,
new_dict
);
// Finds the simplest equivalent fst. This is unnecessary but decreases
// memory usage of the dictionary
fst
::
Minimize
(
new_dict
);
this
->
dictionary
=
new_dict
;
}
}
deploy/scorer.h
浏览文件 @
5208b8e4
#ifndef SCORER_H_
#ifndef SCORER_H_
#define SCORER_H_
#define SCORER_H_
#include <string>
#include <memory>
#include <memory>
#include <
vector
>
#include <
string
>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include "lm/enumerate_vocab.hh"
#include "lm/enumerate_vocab.hh"
#include "lm/word_index.hh"
#include "lm/virtual_interface.hh"
#include "lm/virtual_interface.hh"
#include "
util/string_piece
.hh"
#include "
lm/word_index
.hh"
#include "path_trie.h"
#include "path_trie.h"
#include "util/string_piece.hh"
const
double
OOV_SCORE
=
-
1000.0
;
const
double
OOV_SCORE
=
-
1000.0
;
const
std
::
string
START_TOKEN
=
"<s>"
;
const
std
::
string
START_TOKEN
=
"<s>"
;
const
std
::
string
UNK_TOKEN
=
"<unk>"
;
const
std
::
string
UNK_TOKEN
=
"<unk>"
;
const
std
::
string
END_TOKEN
=
"</s>"
;
const
std
::
string
END_TOKEN
=
"</s>"
;
// Implement a callback to retrive string vocabulary.
// Implement a callback to retrive string vocabulary.
class
RetriveStrEnumerateVocab
:
public
lm
::
EnumerateVocab
{
class
RetriveStrEnumerateVocab
:
public
lm
::
EnumerateVocab
{
public:
public:
RetriveStrEnumerateVocab
()
{}
RetriveStrEnumerateVocab
()
{}
void
Add
(
lm
::
WordIndex
index
,
const
StringPiece
&
str
)
{
void
Add
(
lm
::
WordIndex
index
,
const
StringPiece
&
str
)
{
vocabulary
.
push_back
(
std
::
string
(
str
.
data
(),
str
.
length
()));
vocabulary
.
push_back
(
std
::
string
(
str
.
data
(),
str
.
length
()));
}
}
std
::
vector
<
std
::
string
>
vocabulary
;
std
::
vector
<
std
::
string
>
vocabulary
;
};
};
// External scorer to query languange score for n-gram or sentence.
// External scorer to query languange score for n-gram or sentence.
...
@@ -33,59 +33,59 @@ public:
...
@@ -33,59 +33,59 @@ public:
// Scorer scorer(alpha, beta, "path_of_language_model");
// Scorer scorer(alpha, beta, "path_of_language_model");
// scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
// scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
// scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
// scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
class
Scorer
{
class
Scorer
{
public:
public:
Scorer
(
double
alpha
,
double
beta
,
const
std
::
string
&
lm_path
);
Scorer
(
double
alpha
,
double
beta
,
const
std
::
string
&
lm_path
);
~
Scorer
();
~
Scorer
();
double
get_log_cond_prob
(
const
std
::
vector
<
std
::
string
>&
words
);
double
get_log_cond_prob
(
const
std
::
vector
<
std
::
string
>&
words
);
double
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
);
double
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
);
size_t
get_max_order
()
{
return
_max_order
;
}
size_t
get_max_order
()
{
return
_max_order
;
}
bool
is_char_map_empty
()
{
return
_char_map
.
size
()
==
0
;
}
bool
is_char_map_empty
()
{
return
_char_map
.
size
()
==
0
;
}
bool
is_character_based
()
{
return
_is_character_based
;
}
bool
is_character_based
()
{
return
_is_character_based
;
}
// reset params alpha & beta
// reset params alpha & beta
void
reset_params
(
float
alpha
,
float
beta
);
void
reset_params
(
float
alpha
,
float
beta
);
// make ngram
// make ngram
std
::
vector
<
std
::
string
>
make_ngram
(
PathTrie
*
prefix
);
std
::
vector
<
std
::
string
>
make_ngram
(
PathTrie
*
prefix
);
// fill dictionary for fst
// fill dictionary for fst
void
fill_dictionary
(
bool
add_space
);
void
fill_dictionary
(
bool
add_space
);
// set char map
// set char map
void
set_char_map
(
std
::
vector
<
std
::
string
>
char_list
);
void
set_char_map
(
std
::
vector
<
std
::
string
>
char_list
);
std
::
vector
<
std
::
string
>
split_labels
(
const
std
::
vector
<
int
>
&
labels
);
std
::
vector
<
std
::
string
>
split_labels
(
const
std
::
vector
<
int
>&
labels
);
// expose to decoder
// expose to decoder
double
alpha
;
double
alpha
;
double
beta
;
double
beta
;
// fst dictionary
// fst dictionary
void
*
dictionary
;
void
*
dictionary
;
protected:
protected:
void
load_LM
(
const
char
*
filename
);
void
load_LM
(
const
char
*
filename
);
double
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
);
double
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
);
std
::
string
vec2str
(
const
std
::
vector
<
int
>
&
input
);
std
::
string
vec2str
(
const
std
::
vector
<
int
>&
input
);
private:
private:
void
*
_language_model
;
void
*
_language_model
;
bool
_is_character_based
;
bool
_is_character_based
;
size_t
_max_order
;
size_t
_max_order
;
int
_SPACE_ID
;
int
_SPACE_ID
;
std
::
vector
<
std
::
string
>
_char_list
;
std
::
vector
<
std
::
string
>
_char_list
;
std
::
unordered_map
<
char
,
int
>
_char_map
;
std
::
unordered_map
<
char
,
int
>
_char_map
;
std
::
vector
<
std
::
string
>
_vocabulary
;
std
::
vector
<
std
::
string
>
_vocabulary
;
};
};
#endif // SCORER_H_
#endif
// SCORER_H_
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录