Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
b5602054
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b5602054
编写于
8月 24, 2017
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
convert data structure for prefix from map to trie tree
上级
eef364d1
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
492 addition
and
115 deletion
+492
-115
deploy.py
deploy.py
+5
-4
deploy/ctc_decoders.cpp
deploy/ctc_decoders.cpp
+139
-111
deploy/decoder_utils.cpp
deploy/decoder_utils.cpp
+70
-0
deploy/decoder_utils.h
deploy/decoder_utils.h
+14
-0
deploy/path_trie.cpp
deploy/path_trie.cpp
+153
-0
deploy/path_trie.h
deploy/path_trie.h
+59
-0
deploy/scorer.cpp
deploy/scorer.cpp
+39
-0
deploy/scorer.h
deploy/scorer.h
+13
-0
未找到文件。
deploy.py
浏览文件 @
b5602054
...
...
@@ -18,7 +18,7 @@ import time
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
"--num_samples"
,
default
=
32
,
default
=
5
,
type
=
int
,
help
=
"Number of samples for inference. (default: %(default)s)"
)
parser
.
add_argument
(
...
...
@@ -79,7 +79,7 @@ parser.add_argument(
"(default: %(default)s)"
)
parser
.
add_argument
(
"--beam_size"
,
default
=
20
0
,
default
=
20
,
type
=
int
,
help
=
"Width for beam search decoding. (default: %(default)d)"
)
parser
.
add_argument
(
...
...
@@ -104,7 +104,7 @@ parser.add_argument(
help
=
"Parameter associated with word count. (default: %(default)f)"
)
parser
.
add_argument
(
"--cutoff_prob"
,
default
=
0.99
,
default
=
1.0
,
type
=
float
,
help
=
"The cutoff probability of pruning"
"in beam search. (default: %(default)f)"
)
...
...
@@ -183,7 +183,8 @@ def infer():
vocabulary
=
data_generator
.
vocab_list
,
blank_id
=
len
(
data_generator
.
vocab_list
),
cutoff_prob
=
args
.
cutoff_prob
,
ext_scoring_func
=
ext_scorer
,
)
# ext_scoring_func=ext_scorer,
)
batch_beam_results
+=
[
beam_result
]
else
:
batch_beam_results
=
ctc_beam_search_decoder_batch
(
...
...
deploy/ctc_decoders.cpp
浏览文件 @
b5602054
...
...
@@ -4,11 +4,13 @@
#include <utility>
#include <cmath>
#include <limits>
#include "fst/fstlib.h"
#include "ctc_decoders.h"
#include "decoder_utils.h"
#include "path_trie.h"
#include "ThreadPool.h"
typedef
double
log_prob_type
;
typedef
float
log_prob_type
;
std
::
string
ctc_best_path_decoder
(
std
::
vector
<
std
::
vector
<
double
>
>
probs_seq
,
std
::
vector
<
std
::
string
>
vocabulary
)
...
...
@@ -89,24 +91,30 @@ std::vector<std::pair<double, std::string> >
exit
(
1
);
}
// initialize
// two sets containing selected and candidate prefixes respectively
std
::
map
<
std
::
string
,
log_prob_type
>
prefix_set_prev
,
prefix_set_next
;
// probability of prefixes ending with blank and non-blank
std
::
map
<
std
::
string
,
log_prob_type
>
log_probs_b_prev
,
log_probs_nb_prev
;
std
::
map
<
std
::
string
,
log_prob_type
>
log_probs_b_cur
,
log_probs_nb_cur
;
static
log_prob_type
POS_INF
=
std
::
numeric_limits
<
log_prob_type
>::
max
();
static
log_prob_type
NEG_INF
=
-
POS_INF
;
static
log_prob_type
NUM_MIN
=
std
::
numeric_limits
<
log_prob_type
>::
min
();
static
log_prob_type
NUM_MAX
=
std
::
numeric_limits
<
log_prob_type
>::
max
();
prefix_set_prev
[
"
\t
"
]
=
0.0
;
log_probs_b_prev
[
"
\t
"
]
=
0.0
;
log_probs_nb_prev
[
"
\t
"
]
=
-
NUM_MAX
;
// init
PathTrie
root
;
root
.
_log_prob_b_prev
=
0.0
;
root
.
_score
=
0.0
;
std
::
vector
<
PathTrie
*>
prefixes
;
prefixes
.
push_back
(
&
root
);
for
(
int
time_step
=
0
;
time_step
<
num_time_steps
;
time_step
++
)
{
prefix_set_next
.
clear
();
log_probs_b_cur
.
clear
();
log_probs_nb_cur
.
clear
();
std
::
vector
<
double
>
prob
=
probs_seq
[
time_step
];
if
(
ext_scorer
!=
nullptr
&&
!
ext_scorer
->
is_character_based
())
{
if
(
ext_scorer
->
dictionary
==
nullptr
)
{
// TODO: init dictionary
}
auto
fst_dict
=
static_cast
<
fst
::
StdVectorFst
*>
(
ext_scorer
->
dictionary
);
fst
::
StdVectorFst
*
dict_ptr
=
fst_dict
->
Copy
(
true
);
root
.
set_dictionary
(
dict_ptr
);
auto
matcher
=
std
::
make_shared
<
FSTMATCH
>
(
*
dict_ptr
,
fst
::
MATCH_INPUT
);
root
.
set_matcher
(
matcher
);
}
for
(
int
time_step
=
0
;
time_step
<
num_time_steps
;
time_step
++
)
{
std
::
vector
<
double
>
prob
=
probs_seq
[
time_step
];
std
::
vector
<
std
::
pair
<
int
,
double
>
>
prob_idx
;
for
(
int
i
=
0
;
i
<
prob
.
size
();
i
++
)
{
prob_idx
.
push_back
(
std
::
pair
<
int
,
double
>
(
i
,
prob
[
i
]));
...
...
@@ -132,113 +140,134 @@ std::vector<std::pair<double, std::string> >
std
::
vector
<
std
::
pair
<
int
,
log_prob_type
>
>
log_prob_idx
;
for
(
int
i
=
0
;
i
<
cutoff_len
;
i
++
)
{
log_prob_idx
.
push_back
(
std
::
pair
<
int
,
log_prob_type
>
(
prob_idx
[
i
].
first
,
log
(
prob_idx
[
i
].
second
)));
(
prob_idx
[
i
].
first
,
log
(
prob_idx
[
i
].
second
+
NUM_MIN
)));
}
// extend prefix
for
(
std
::
map
<
std
::
string
,
log_prob_type
>::
iterator
it
=
prefix_set_prev
.
begin
();
it
!=
prefix_set_prev
.
end
();
it
++
)
{
std
::
string
l
=
it
->
first
;
if
(
prefix_set_next
.
find
(
l
)
==
prefix_set_next
.
end
())
{
log_probs_b_cur
[
l
]
=
log_probs_nb_cur
[
l
]
=
-
NUM_MAX
;
}
for
(
int
index
=
0
;
index
<
log_prob_idx
.
size
();
index
++
)
{
int
c
=
log_prob_idx
[
index
].
first
;
// loop over chars
for
(
int
index
=
0
;
index
<
log_prob_idx
.
size
();
index
++
)
{
auto
c
=
log_prob_idx
[
index
].
first
;
log_prob_type
log_prob_c
=
log_prob_idx
[
index
].
second
;
log_prob_type
log_probs_prev
;
//log_prob_type log_probs_prev;
for
(
int
i
=
0
;
i
<
prefixes
.
size
()
&&
i
<
beam_size
;
i
++
)
{
auto
prefix
=
prefixes
[
i
];
// blank
if
(
c
==
blank_id
)
{
log_probs_prev
=
log_sum_exp
(
log_probs_b_prev
[
l
],
log_probs_nb_prev
[
l
]);
log_probs_b_cur
[
l
]
=
log_sum_exp
(
log_probs_b_cur
[
l
],
log_prob_c
+
log_probs_prev
);
}
else
{
std
::
string
last_char
=
l
.
substr
(
l
.
size
()
-
1
,
1
);
std
::
string
new_char
=
vocabulary
[
c
];
std
::
string
l_plus
=
l
+
new_char
;
if
(
prefix_set_next
.
find
(
l_plus
)
==
prefix_set_next
.
end
())
{
log_probs_b_cur
[
l_plus
]
=
-
NUM_MAX
;
log_probs_nb_cur
[
l_plus
]
=
-
NUM_MAX
;
}
if
(
last_char
==
new_char
)
{
log_probs_nb_cur
[
l_plus
]
=
log_sum_exp
(
log_probs_nb_cur
[
l_plus
],
log_prob_c
+
log_probs_b_prev
[
l
]
);
log_probs_nb_cur
[
l
]
=
log_sum_exp
(
log_probs_nb_cur
[
l
],
log_prob_c
+
log_probs_nb_prev
[
l
]
);
}
else
if
(
new_char
==
" "
)
{
float
score
=
0.0
;
if
(
ext_scorer
!=
NULL
&&
l
.
size
()
>
1
)
{
score
=
ext_scorer
->
get_score
(
l
.
substr
(
1
),
true
);
}
log_probs_prev
=
log_sum_exp
(
log_probs_b_prev
[
l
],
log_probs_nb_prev
[
l
]);
log_probs_nb_cur
[
l_plus
]
=
log_sum_exp
(
log_probs_nb_cur
[
l_plus
],
score
+
log_prob_c
+
log_probs_prev
prefix
->
_log_prob_b_cur
=
log_sum_exp
(
prefix
->
_log_prob_b_cur
,
log_prob_c
+
prefix
->
_score
);
continue
;
}
// repeated character
if
(
c
==
prefix
->
_character
)
{
prefix
->
_log_prob_nb_cur
=
log_sum_exp
(
prefix
->
_log_prob_nb_cur
,
log_prob_c
+
prefix
->
_log_prob_nb_prev
);
}
// get new prefix
auto
prefix_new
=
prefix
->
get_path_trie
(
c
);
if
(
prefix_new
!=
nullptr
)
{
float
log_p
=
NEG_INF
;
if
(
c
==
prefix
->
_character
&&
prefix
->
_log_prob_b_prev
>
NEG_INF
)
{
log_p
=
log_prob_c
+
prefix
->
_log_prob_b_prev
;
}
else
if
(
c
!=
prefix
->
_character
)
{
log_p
=
log_prob_c
+
prefix
->
_score
;
}
// language model scoring
if
(
ext_scorer
!=
nullptr
&&
(
c
==
space_id
||
ext_scorer
->
is_character_based
())
)
{
PathTrie
*
prefix_to_score
=
nullptr
;
// don't score the space
if
(
ext_scorer
->
is_character_based
())
{
prefix_to_score
=
prefix_new
;
}
else
{
log_probs_prev
=
log_sum_exp
(
log_probs_b_prev
[
l
],
log_probs_nb_prev
[
l
]);
log_probs_nb_cur
[
l_plus
]
=
log_sum_exp
(
log_probs_nb_cur
[
l_plus
],
log_prob_c
+
log_probs_prev
);
prefix_to_score
=
prefix
;
}
double
score
=
0.0
;
std
::
vector
<
std
::
string
>
ngram
;
ngram
=
ext_scorer
->
make_ngram
(
prefix_to_score
);
score
=
ext_scorer
->
get_log_cond_prob
(
ngram
)
*
ext_scorer
->
alpha
;
log_p
+=
score
;
log_p
+=
ext_scorer
->
beta
;
}
prefix_new
->
_log_prob_nb_cur
=
log_sum_exp
(
prefix_new
->
_log_prob_nb_cur
,
log_p
);
}
}
}
// end of loop over chars
prefixes
.
clear
();
// update log probabilities
root
.
iterate_to_vec
(
prefixes
);
// sort prefixes by score
if
(
prefixes
.
size
()
>=
beam_size
)
{
std
::
nth_element
(
prefixes
.
begin
(),
prefixes
.
begin
()
+
beam_size
,
prefixes
.
end
(),
prefix_compare
);
for
(
size_t
i
=
beam_size
;
i
<
prefixes
.
size
();
i
++
)
{
prefixes
[
i
]
->
remove
();
}
prefix_set_next
[
l_plus
]
=
log_sum_exp
(
log_probs_nb_cur
[
l_plus
],
log_probs_b_cur
[
l_plus
]
);
}
}
prefix_set_next
[
l
]
=
log_sum_exp
(
log_probs_b_cur
[
l
],
log_probs_nb_cur
[
l
]);
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
prefixes
.
size
();
i
++
)
{
double
approx_ctc
=
prefixes
[
i
]
->
_score
;
// remove word insert:
std
::
vector
<
int
>
output
;
prefixes
[
i
]
->
get_path_vec
(
output
);
size_t
prefix_length
=
output
.
size
();
// remove language model weight:
if
(
ext_scorer
!=
nullptr
)
{
// auto words = split_labels(output);
// approx_ctc = approx_ctc - path_length * ext_scorer->beta;
// approx_ctc -= (_lm->get_sent_log_prob(words)) * ext_scorer->alpha;
}
log_probs_b_prev
=
log_probs_b_cur
;
log_probs_nb_prev
=
log_probs_nb_cur
;
std
::
vector
<
std
::
pair
<
std
::
string
,
log_prob_type
>
>
prefix_vec_next
(
prefix_set_next
.
begin
(),
prefix_set_next
.
end
());
std
::
sort
(
prefix_vec_next
.
begin
(),
prefix_vec_next
.
end
(),
pair_comp_second_rev
<
std
::
string
,
log_prob_type
>
);
int
num_prefixes_next
=
prefix_vec_next
.
size
();
int
k
=
beam_size
<
num_prefixes_next
?
beam_size
:
num_prefixes_next
;
prefix_set_prev
=
std
::
map
<
std
::
string
,
log_prob_type
>
(
prefix_vec_next
.
begin
(),
prefix_vec_next
.
begin
()
+
k
);
prefixes
[
i
]
->
_approx_ctc
=
approx_ctc
;
}
// post processing
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>
>
beam_result
;
for
(
std
::
map
<
std
::
string
,
log_prob_type
>::
iterator
it
=
prefix_set_prev
.
begin
();
it
!=
prefix_set_prev
.
end
();
it
++
)
{
if
(
it
->
second
>
-
NUM_MAX
&&
it
->
first
.
size
()
>
1
)
{
log_prob_type
log_prob
=
it
->
second
;
std
::
string
sentence
=
it
->
first
.
substr
(
1
);
// scoring the last word
if
(
ext_scorer
!=
NULL
&&
sentence
[
sentence
.
size
()
-
1
]
!=
' '
)
{
log_prob
=
log_prob
+
ext_scorer
->
get_score
(
sentence
,
true
);
// allow for the post processing
std
::
vector
<
PathTrie
*>
space_prefixes
;
if
(
space_prefixes
.
empty
())
{
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
prefixes
.
size
();
i
++
)
{
space_prefixes
.
push_back
(
prefixes
[
i
]);
}
if
(
log_prob
>
-
NUM_MAX
)
{
std
::
pair
<
double
,
std
::
string
>
cur_result
(
log_prob
,
sentence
);
beam_result
.
push_back
(
cur_result
);
}
std
::
sort
(
space_prefixes
.
begin
(),
space_prefixes
.
end
(),
prefix_compare
);
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>
>
output_vecs
;
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
space_prefixes
.
size
();
i
++
)
{
std
::
vector
<
int
>
output
;
space_prefixes
[
i
]
->
get_path_vec
(
output
);
// convert index to string
std
::
string
output_str
;
for
(
int
j
=
0
;
j
<
output
.
size
();
j
++
)
{
output_str
+=
vocabulary
[
output
[
j
]];
}
std
::
pair
<
double
,
std
::
string
>
output_pair
(
space_prefixes
[
i
]
->
_score
,
output_str
);
output_vecs
.
emplace_back
(
output_pair
);
}
return
output_vecs
;
}
// sort the result and return
std
::
sort
(
beam_result
.
begin
(),
beam_result
.
end
(),
pair_comp_first_rev
<
double
,
std
::
string
>
);
return
beam_result
;
}
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
...
...
@@ -250,8 +279,7 @@ std::vector<std::vector<std::pair<double, std::string>>>
int
num_processes
,
double
cutoff_prob
,
Scorer
*
ext_scorer
)
{
)
{
if
(
num_processes
<=
0
)
{
std
::
cout
<<
"num_processes must be nonnegative!"
<<
std
::
endl
;
exit
(
1
);
...
...
deploy/decoder_utils.cpp
浏览文件 @
b5602054
...
...
@@ -10,3 +10,73 @@ size_t get_utf8_str_len(const std::string& str) {
}
return
str_len
;
}
//-------------------------------------------------------
// Overriding less than operator for sorting
//-------------------------------------------------------
bool
prefix_compare
(
const
PathTrie
*
x
,
const
PathTrie
*
y
)
{
if
(
x
->
_score
==
y
->
_score
)
{
if
(
x
->
_character
==
y
->
_character
)
{
return
false
;
}
else
{
return
(
x
->
_character
<
y
->
_character
);
}
}
else
{
return
x
->
_score
>
y
->
_score
;
}
}
//---------- End path_compare ---------------------------
// --------------------------------------------------------------
// Adds word to fst without copying entire dictionary
// --------------------------------------------------------------
void
add_word_to_fst
(
const
std
::
vector
<
int
>&
word
,
fst
::
StdVectorFst
*
dictionary
)
{
if
(
dictionary
->
NumStates
()
==
0
)
{
fst
::
StdVectorFst
::
StateId
start
=
dictionary
->
AddState
();
assert
(
start
==
0
);
dictionary
->
SetStart
(
start
);
}
fst
::
StdVectorFst
::
StateId
src
=
dictionary
->
Start
();
fst
::
StdVectorFst
::
StateId
dst
;
for
(
auto
c
:
word
)
{
dst
=
dictionary
->
AddState
();
dictionary
->
AddArc
(
src
,
fst
::
StdArc
(
c
,
c
,
0
,
dst
));
src
=
dst
;
}
dictionary
->
SetFinal
(
dst
,
fst
::
StdArc
::
Weight
::
One
());
}
// ------------ End of add_word_to_fst -----------------------
// ---------------------------------------------------------
// Adds a word to the dictionary FST based on char_map
// ---------------------------------------------------------
bool
addWordToDictionary
(
const
std
::
string
&
word
,
const
std
::
unordered_map
<
std
::
string
,
int
>&
char_map
,
bool
add_space
,
int
SPACE
,
fst
::
StdVectorFst
*
dictionary
)
{
/*
auto characters = UTF8_split(word);
std::vector<int> int_word;
for (auto& c : characters) {
if (c == " ") {
int_word.push_back(SPACE);
} else {
auto int_c = char_map.find(c);
if (int_c != char_map.end()) {
int_word.push_back(int_c->second);
} else {
return false; // return without adding
}
}
}
if (add_space) {
int_word.push_back(SPACE);
}
add_word_to_fst(int_word, dictionary);
*/
return
true
;
}
// -------------- End of addWordToDictionary ------------
deploy/decoder_utils.h
浏览文件 @
b5602054
...
...
@@ -2,6 +2,7 @@
#define DECODER_UTILS_H_
#include <utility>
#include "path_trie.h"
template
<
typename
T1
,
typename
T2
>
bool
pair_comp_first_rev
(
const
std
::
pair
<
T1
,
T2
>
&
a
,
const
std
::
pair
<
T1
,
T2
>
&
b
)
...
...
@@ -25,8 +26,21 @@ T log_sum_exp(const T &x, const T &y)
return
std
::
log
(
std
::
exp
(
x
-
xmax
)
+
std
::
exp
(
y
-
xmax
))
+
xmax
;
}
//-------------------------------------------------------
// Overriding less than operator for sorting
//-------------------------------------------------------
bool
prefix_compare
(
const
PathTrie
*
x
,
const
PathTrie
*
y
);
// Get length of utf8 encoding string
// See: http://stackoverflow.com/a/4063229
size_t
get_utf8_str_len
(
const
std
::
string
&
str
);
void
add_word_to_fst
(
const
std
::
vector
<
int
>&
word
,
fst
::
StdVectorFst
*
dictionary
);
bool
addWordToDictionary
(
const
std
::
string
&
word
,
const
std
::
unordered_map
<
std
::
string
,
int
>&
char_map
,
bool
add_space
,
int
SPACE
,
fst
::
StdVectorFst
*
dictionary
);
#endif // DECODER_UTILS_H
deploy/path_trie.cpp
0 → 100644
浏览文件 @
b5602054
#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include "path_trie.h"
#include "decoder_utils.h"
PathTrie
::
PathTrie
()
{
float
lowest
=
-
1.0
*
std
::
numeric_limits
<
float
>::
max
();
_log_prob_b_prev
=
lowest
;
_log_prob_nb_prev
=
lowest
;
_log_prob_b_cur
=
lowest
;
_log_prob_nb_cur
=
lowest
;
_score
=
lowest
;
_ROOT
=
-
1
;
_character
=
_ROOT
;
_exists
=
true
;
_parent
=
nullptr
;
_dictionary
=
nullptr
;
_dictionary_state
=
0
;
_has_dictionary
=
false
;
_matcher
=
nullptr
;
// finds arcs in FST
}
PathTrie
::~
PathTrie
()
{
for
(
auto
child
:
_children
)
{
delete
child
.
second
;
}
}
PathTrie
*
PathTrie
::
get_path_trie
(
int
new_char
,
bool
reset
)
{
auto
child
=
_children
.
begin
();
for
(
child
=
_children
.
begin
();
child
!=
_children
.
end
();
++
child
)
{
if
(
child
->
first
==
new_char
)
{
break
;
}
}
if
(
child
!=
_children
.
end
()
)
{
if
(
!
child
->
second
->
_exists
)
{
child
->
second
->
_exists
=
true
;
float
lowest
=
-
1.0
*
std
::
numeric_limits
<
float
>::
max
();
child
->
second
->
_log_prob_b_prev
=
lowest
;
child
->
second
->
_log_prob_nb_prev
=
lowest
;
child
->
second
->
_log_prob_b_cur
=
lowest
;
child
->
second
->
_log_prob_nb_cur
=
lowest
;
}
return
(
child
->
second
);
}
else
{
if
(
_has_dictionary
)
{
_matcher
->
SetState
(
_dictionary_state
);
bool
found
=
_matcher
->
Find
(
new_char
);
if
(
!
found
)
{
// Adding this character causes word outside dictionary
auto
FSTZERO
=
fst
::
TropicalWeight
::
Zero
();
auto
final_weight
=
_dictionary
->
Final
(
_dictionary_state
);
bool
is_final
=
(
final_weight
!=
FSTZERO
);
if
(
is_final
&&
reset
)
{
_dictionary_state
=
_dictionary
->
Start
();
}
return
nullptr
;
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
new_path
->
_character
=
new_char
;
new_path
->
_parent
=
this
;
new_path
->
_dictionary
=
_dictionary
;
new_path
->
_dictionary_state
=
_matcher
->
Value
().
nextstate
;
new_path
->
_has_dictionary
=
true
;
new_path
->
_matcher
=
_matcher
;
_children
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
return
new_path
;
}
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
new_path
->
_character
=
new_char
;
new_path
->
_parent
=
this
;
_children
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
return
new_path
;
}
}
}
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
)
{
return
get_path_vec
(
output
,
_ROOT
);
}
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
,
int
stop
,
size_t
max_steps
/*= std::numeric_limits<size_t>::max() */
)
{
if
(
_character
==
stop
||
_character
==
_ROOT
||
output
.
size
()
==
max_steps
)
{
std
::
reverse
(
output
.
begin
(),
output
.
end
());
return
this
;
}
else
{
output
.
push_back
(
_character
);
return
_parent
->
get_path_vec
(
output
,
stop
,
max_steps
);
}
}
void
PathTrie
::
iterate_to_vec
(
std
::
vector
<
PathTrie
*>&
output
)
{
if
(
_exists
)
{
_log_prob_b_prev
=
_log_prob_b_cur
;
_log_prob_nb_prev
=
_log_prob_nb_cur
;
_log_prob_b_cur
=
-
1.0
*
std
::
numeric_limits
<
float
>::
max
();
_log_prob_nb_cur
=
-
1.0
*
std
::
numeric_limits
<
float
>::
max
();
_score
=
log_sum_exp
(
_log_prob_b_prev
,
_log_prob_nb_prev
);
output
.
push_back
(
this
);
}
for
(
auto
child
:
_children
)
{
child
.
second
->
iterate_to_vec
(
output
);
}
}
//-------------------------------------------------------
// Effectively removes node
//-------------------------------------------------------
void
PathTrie
::
remove
()
{
_exists
=
false
;
if
(
_children
.
size
()
==
0
)
{
auto
child
=
_parent
->
_children
.
begin
();
for
(
child
=
_parent
->
_children
.
begin
();
child
!=
_parent
->
_children
.
end
();
++
child
)
{
if
(
child
->
first
==
_character
)
{
_parent
->
_children
.
erase
(
child
);
break
;
}
}
if
(
_parent
->
_children
.
size
()
==
0
&&
!
_parent
->
_exists
)
{
_parent
->
remove
();
}
delete
this
;
}
}
void
PathTrie
::
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
)
{
_dictionary
=
dictionary
;
_dictionary_state
=
dictionary
->
Start
();
_has_dictionary
=
true
;
}
using
FSTMATCH
=
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>
;
void
PathTrie
::
set_matcher
(
std
::
shared_ptr
<
FSTMATCH
>
matcher
)
{
_matcher
=
matcher
;
}
deploy/path_trie.h
0 → 100644
浏览文件 @
b5602054
#ifndef PATH_TRIE_H
#define PATH_TRIE_H
#pragma once
#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include <fst/fstlib.h>
using
FSTMATCH
=
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>
;
class
PathTrie
{
public:
PathTrie
();
~
PathTrie
();
PathTrie
*
get_path_trie
(
int
new_char
,
bool
reset
=
true
);
PathTrie
*
get_path_vec
(
std
::
vector
<
int
>
&
output
);
PathTrie
*
get_path_vec
(
std
::
vector
<
int
>&
output
,
int
stop
,
size_t
max_steps
=
std
::
numeric_limits
<
size_t
>::
max
());
void
iterate_to_vec
(
std
::
vector
<
PathTrie
*>
&
output
);
void
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
);
void
set_matcher
(
std
::
shared_ptr
<
FSTMATCH
>
matcher
);
bool
is_empty
()
{
return
_ROOT
==
_character
;
}
void
remove
();
float
_log_prob_b_prev
;
float
_log_prob_nb_prev
;
float
_log_prob_b_cur
;
float
_log_prob_nb_cur
;
float
_score
;
float
_approx_ctc
;
int
_ROOT
;
int
_character
;
bool
_exists
;
PathTrie
*
_parent
;
std
::
vector
<
std
::
pair
<
int
,
PathTrie
*>
>
_children
;
fst
::
StdVectorFst
*
_dictionary
;
fst
::
StdVectorFst
::
StateId
_dictionary_state
;
bool
_has_dictionary
;
std
::
shared_ptr
<
FSTMATCH
>
_matcher
;
};
#endif // PATH_TRIE_H
deploy/scorer.cpp
浏览文件 @
b5602054
...
...
@@ -175,3 +175,42 @@ double Scorer::get_score(std::string sentence, bool log) {
}
return
final_score
;
}
//--------------------------------------------------
// Turn indices back into strings of chars
//--------------------------------------------------
std
::
vector
<
std
::
string
>
Scorer
::
make_ngram
(
PathTrie
*
prefix
)
{
/*
std::vector<std::string> ngram;
PathTrie* current_node = prefix;
PathTrie* new_node = nullptr;
for (int order = 0; order < _max_order; order++) {
std::vector<int> prefix_vec;
if (_is_character_based) {
new_node = current_node->get_path_vec(prefix_vec, ' ', 1);
current_node = new_node;
} else {
new_node = current_node->getPathVec(prefix_vec, ' ');
current_node = new_node->_parent; // Skipping spaces
}
// reconstruct word
std::string word = vec2str(prefix_vec);
ngram.push_back(word);
if (new_node->_character == -1) {
// No more spaces, but still need order
for (int i = 0; i < max_order - order - 1; i++) {
ngram.push_back("<s>");
}
break;
}
}
std::reverse(ngram.begin(), ngram.end());
*/
std
::
vector
<
std
::
string
>
ngram
;
ngram
.
push_back
(
"this"
);
return
ngram
;
}
//---------------- End makeNgrams ------------------
deploy/scorer.h
浏览文件 @
b5602054
...
...
@@ -4,10 +4,12 @@
#include <string>
#include <memory>
#include <vector>
#include <unordered_map>
#include "lm/enumerate_vocab.hh"
#include "lm/word_index.hh"
#include "lm/virtual_interface.hh"
#include "util/string_piece.hh"
#include "path_trie.h"
const
double
OOV_SCOER
=
-
1000.0
;
const
std
::
string
START_TOKEN
=
"<s>"
;
...
...
@@ -49,18 +51,29 @@ public:
void
reset_params
(
float
alpha
,
float
beta
);
// get the final score
double
get_score
(
std
::
string
,
bool
log
=
false
);
// make ngram
std
::
vector
<
std
::
string
>
make_ngram
(
PathTrie
*
prefix
);
// expose to decoder
double
alpha
;
double
beta
;
// fst dictionary
void
*
dictionary
;
protected:
void
load_LM
(
const
char
*
filename
);
double
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
);
private:
void
_init_char_list
();
void
_init_char_map
();
void
*
_language_model
;
bool
_is_character_based
;
size_t
_max_order
;
std
::
vector
<
std
::
string
>
_char_list
;
std
::
unordered_map
<
char
,
int
>
_char_map
;
std
::
vector
<
std
::
string
>
_vocabulary
;
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录