Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
5208b8e4
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5208b8e4
编写于
9月 06, 2017
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
format C++ source code
上级
a2ddfe8d
变更
8
展开全部
隐藏空白更改
内联
并排
Showing
8 changed file
with
749 addition
and
779 deletion
+749
-779
deploy/ctc_decoders.cpp
deploy/ctc_decoders.cpp
+292
-300
deploy/ctc_decoders.h
deploy/ctc_decoders.h
+21
-23
deploy/decoder_utils.cpp
deploy/decoder_utils.cpp
+79
-81
deploy/decoder_utils.h
deploy/decoder_utils.h
+21
-23
deploy/path_trie.cpp
deploy/path_trie.cpp
+103
-106
deploy/path_trie.h
deploy/path_trie.h
+30
-32
deploy/scorer.cpp
deploy/scorer.cpp
+160
-171
deploy/scorer.h
deploy/scorer.h
+43
-43
未找到文件。
deploy/ctc_decoders.cpp
浏览文件 @
5208b8e4
此差异已折叠。
点击以展开。
deploy/ctc_decoders.h
浏览文件 @
5208b8e4
#ifndef CTC_BEAM_SEARCH_DECODER_H_
#ifndef CTC_BEAM_SEARCH_DECODER_H_
#define CTC_BEAM_SEARCH_DECODER_H_
#define CTC_BEAM_SEARCH_DECODER_H_
#include <vector>
#include <string>
#include <string>
#include <utility>
#include <utility>
#include <vector>
#include "scorer.h"
#include "scorer.h"
/* CTC Best Path Decoder
/* CTC Best Path Decoder
...
@@ -16,8 +16,8 @@
...
@@ -16,8 +16,8 @@
* A vector that each element is a pair of score and decoding result,
* A vector that each element is a pair of score and decoding result,
* in desending order.
* in desending order.
*/
*/
std
::
string
ctc_best_path_decoder
(
std
::
vector
<
std
::
vector
<
double
>
>
probs_seq
,
std
::
string
ctc_best_path_decoder
(
std
::
vector
<
std
::
vector
<
double
>>
probs_seq
,
std
::
vector
<
std
::
string
>
vocabulary
);
std
::
vector
<
std
::
string
>
vocabulary
);
/* CTC Beam Search Decoder
/* CTC Beam Search Decoder
...
@@ -34,15 +34,14 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
...
@@ -34,15 +34,14 @@ std::string ctc_best_path_decoder(std::vector<std::vector<double> > probs_seq,
* A vector that each element is a pair of score and decoding result,
* A vector that each element is a pair of score and decoding result,
* in desending order.
* in desending order.
*/
*/
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>
>
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
ctc_beam_search_decoder
(
ctc_beam_search_decoder
(
std
::
vector
<
std
::
vector
<
double
>
>
probs_seq
,
std
::
vector
<
std
::
vector
<
double
>>
probs_seq
,
int
beam_size
,
int
beam_size
,
std
::
vector
<
std
::
string
>
vocabulary
,
std
::
vector
<
std
::
string
>
vocabulary
,
int
blank_id
,
int
blank_id
,
double
cutoff_prob
=
1
.
0
,
double
cutoff_prob
=
1
.
0
,
int
cutoff_top_n
=
40
,
int
cutoff_top_n
=
40
,
Scorer
*
ext_scorer
=
NULL
Scorer
*
ext_scorer
=
NULL
);
);
/* CTC Beam Search Decoder for batch data, the interface is consistent with the
/* CTC Beam Search Decoder for batch data, the interface is consistent with the
* original decoder in Python version.
* original decoder in Python version.
...
@@ -63,15 +62,14 @@ std::vector<std::pair<double, std::string> >
...
@@ -63,15 +62,14 @@ std::vector<std::pair<double, std::string> >
* sample.
* sample.
*/
*/
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
ctc_beam_search_decoder_batch
(
std
::
vector
<
std
::
vector
<
std
::
vector
<
double
>>>
probs_split
,
ctc_beam_search_decoder_batch
(
int
beam_size
,
std
::
vector
<
std
::
vector
<
std
::
vector
<
double
>>>
probs_split
,
std
::
vector
<
std
::
string
>
vocabulary
,
int
beam_size
,
int
blank_id
,
std
::
vector
<
std
::
string
>
vocabulary
,
int
num_processes
,
int
blank_id
,
double
cutoff_prob
=
1
.
0
,
int
num_processes
,
int
cutoff_top_n
=
40
,
double
cutoff_prob
=
1
.
0
,
Scorer
*
ext_scorer
=
NULL
int
cutoff_top_n
=
40
,
);
Scorer
*
ext_scorer
=
NULL
);
#endif // CTC_BEAM_SEARCH_DECODER_H_
#endif
// CTC_BEAM_SEARCH_DECODER_H_
deploy/decoder_utils.cpp
浏览文件 @
5208b8e4
#include
<limits>
#include
"decoder_utils.h"
#include <algorithm>
#include <algorithm>
#include <cmath>
#include <cmath>
#include
"decoder_utils.h"
#include
<limits>
size_t
get_utf8_str_len
(
const
std
::
string
&
str
)
{
size_t
get_utf8_str_len
(
const
std
::
string
&
str
)
{
size_t
str_len
=
0
;
size_t
str_len
=
0
;
for
(
char
c
:
str
)
{
for
(
char
c
:
str
)
{
str_len
+=
((
c
&
0xc0
)
!=
0x80
);
str_len
+=
((
c
&
0xc0
)
!=
0x80
);
}
}
return
str_len
;
return
str_len
;
}
}
std
::
vector
<
std
::
string
>
split_utf8_str
(
const
std
::
string
&
str
)
std
::
vector
<
std
::
string
>
split_utf8_str
(
const
std
::
string
&
str
)
{
{
std
::
vector
<
std
::
string
>
result
;
std
::
vector
<
std
::
string
>
result
;
std
::
string
out_str
;
std
::
string
out_str
;
for
(
char
c
:
str
)
for
(
char
c
:
str
)
{
if
((
c
&
0xc0
)
!=
0x80
)
// new UTF-8 character
{
{
if
((
c
&
0xc0
)
!=
0x80
)
//new UTF-8 character
if
(
!
out_str
.
empty
())
{
{
result
.
push_back
(
out_str
);
if
(
!
out_str
.
empty
())
out_str
.
clear
();
{
}
result
.
push_back
(
out_str
);
out_str
.
clear
();
}
}
out_str
.
append
(
1
,
c
);
}
}
out_str
.
append
(
1
,
c
);
}
result
.
push_back
(
out_str
);
result
.
push_back
(
out_str
);
return
result
;
return
result
;
}
}
std
::
vector
<
std
::
string
>
split_str
(
const
std
::
string
&
s
,
std
::
vector
<
std
::
string
>
split_str
(
const
std
::
string
&
s
,
const
std
::
string
&
delim
)
{
const
std
::
string
&
delim
)
{
std
::
vector
<
std
::
string
>
result
;
std
::
vector
<
std
::
string
>
result
;
std
::
size_t
start
=
0
,
delim_len
=
delim
.
size
();
std
::
size_t
start
=
0
,
delim_len
=
delim
.
size
();
while
(
true
)
{
while
(
true
)
{
std
::
size_t
end
=
s
.
find
(
delim
,
start
);
std
::
size_t
end
=
s
.
find
(
delim
,
start
);
if
(
end
==
std
::
string
::
npos
)
{
if
(
end
==
std
::
string
::
npos
)
{
if
(
start
<
s
.
size
())
{
if
(
start
<
s
.
size
())
{
result
.
push_back
(
s
.
substr
(
start
));
result
.
push_back
(
s
.
substr
(
start
));
}
}
break
;
break
;
}
}
if
(
end
>
start
)
{
if
(
end
>
start
)
{
result
.
push_back
(
s
.
substr
(
start
,
end
-
start
));
result
.
push_back
(
s
.
substr
(
start
,
end
-
start
));
}
start
=
end
+
delim_len
;
}
}
return
result
;
start
=
end
+
delim_len
;
}
return
result
;
}
}
bool
prefix_compare
(
const
PathTrie
*
x
,
const
PathTrie
*
y
)
{
bool
prefix_compare
(
const
PathTrie
*
x
,
const
PathTrie
*
y
)
{
if
(
x
->
_score
==
y
->
_score
)
{
if
(
x
->
score
==
y
->
score
)
{
if
(
x
->
_character
==
y
->
_character
)
{
if
(
x
->
character
==
y
->
character
)
{
return
false
;
return
false
;
}
else
{
return
(
x
->
_character
<
y
->
_character
);
}
}
else
{
}
else
{
return
x
->
_score
>
y
->
_score
;
return
(
x
->
character
<
y
->
character
)
;
}
}
}
else
{
return
x
->
score
>
y
->
score
;
}
}
}
void
add_word_to_fst
(
const
std
::
vector
<
int
>&
word
,
void
add_word_to_fst
(
const
std
::
vector
<
int
>&
word
,
fst
::
StdVectorFst
*
dictionary
)
{
fst
::
StdVectorFst
*
dictionary
)
{
if
(
dictionary
->
NumStates
()
==
0
)
{
if
(
dictionary
->
NumStates
()
==
0
)
{
fst
::
StdVectorFst
::
StateId
start
=
dictionary
->
AddState
();
fst
::
StdVectorFst
::
StateId
start
=
dictionary
->
AddState
();
assert
(
start
==
0
);
assert
(
start
==
0
);
dictionary
->
SetStart
(
start
);
dictionary
->
SetStart
(
start
);
}
}
fst
::
StdVectorFst
::
StateId
src
=
dictionary
->
Start
();
fst
::
StdVectorFst
::
StateId
src
=
dictionary
->
Start
();
fst
::
StdVectorFst
::
StateId
dst
;
fst
::
StdVectorFst
::
StateId
dst
;
for
(
auto
c
:
word
)
{
for
(
auto
c
:
word
)
{
dst
=
dictionary
->
AddState
();
dst
=
dictionary
->
AddState
();
dictionary
->
AddArc
(
src
,
fst
::
StdArc
(
c
,
c
,
0
,
dst
));
dictionary
->
AddArc
(
src
,
fst
::
StdArc
(
c
,
c
,
0
,
dst
));
src
=
dst
;
src
=
dst
;
}
}
dictionary
->
SetFinal
(
dst
,
fst
::
StdArc
::
Weight
::
One
());
dictionary
->
SetFinal
(
dst
,
fst
::
StdArc
::
Weight
::
One
());
}
}
bool
add_word_to_dictionary
(
const
std
::
string
&
word
,
bool
add_word_to_dictionary
(
const
std
::
unordered_map
<
std
::
string
,
int
>&
char_map
,
const
std
::
string
&
word
,
bool
add_space
,
const
std
::
unordered_map
<
std
::
string
,
int
>&
char_map
,
int
SPACE_ID
,
bool
add_space
,
fst
::
StdVectorFst
*
dictionary
)
{
int
SPACE_ID
,
auto
characters
=
split_utf8_str
(
word
);
fst
::
StdVectorFst
*
dictionary
)
{
auto
characters
=
split_utf8_str
(
word
);
std
::
vector
<
int
>
int_word
;
std
::
vector
<
int
>
int_word
;
for
(
auto
&
c
:
characters
)
{
for
(
auto
&
c
:
characters
)
{
if
(
c
==
" "
)
{
if
(
c
==
" "
)
{
int_word
.
push_back
(
SPACE_ID
);
int_word
.
push_back
(
SPACE_ID
);
}
else
{
}
else
{
auto
int_c
=
char_map
.
find
(
c
);
auto
int_c
=
char_map
.
find
(
c
);
if
(
int_c
!=
char_map
.
end
())
{
if
(
int_c
!=
char_map
.
end
())
{
int_word
.
push_back
(
int_c
->
second
);
int_word
.
push_back
(
int_c
->
second
);
}
else
{
}
else
{
return
false
;
// return without adding
return
false
;
// return without adding
}
}
}
}
}
}
if
(
add_space
)
{
if
(
add_space
)
{
int_word
.
push_back
(
SPACE_ID
);
int_word
.
push_back
(
SPACE_ID
);
}
}
add_word_to_fst
(
int_word
,
dictionary
);
add_word_to_fst
(
int_word
,
dictionary
);
return
true
;
return
true
;
}
}
deploy/decoder_utils.h
浏览文件 @
5208b8e4
...
@@ -10,34 +10,31 @@ const float NUM_FLT_MIN = std::numeric_limits<float>::min();
...
@@ -10,34 +10,31 @@ const float NUM_FLT_MIN = std::numeric_limits<float>::min();
// Function template for comparing two pairs
// Function template for comparing two pairs
template
<
typename
T1
,
typename
T2
>
template
<
typename
T1
,
typename
T2
>
bool
pair_comp_first_rev
(
const
std
::
pair
<
T1
,
T2
>
&
a
,
bool
pair_comp_first_rev
(
const
std
::
pair
<
T1
,
T2
>
&
a
,
const
std
::
pair
<
T1
,
T2
>
&
b
)
const
std
::
pair
<
T1
,
T2
>
&
b
)
{
{
return
a
.
first
>
b
.
first
;
return
a
.
first
>
b
.
first
;
}
}
template
<
typename
T1
,
typename
T2
>
template
<
typename
T1
,
typename
T2
>
bool
pair_comp_second_rev
(
const
std
::
pair
<
T1
,
T2
>
&
a
,
bool
pair_comp_second_rev
(
const
std
::
pair
<
T1
,
T2
>
&
a
,
const
std
::
pair
<
T1
,
T2
>
&
b
)
const
std
::
pair
<
T1
,
T2
>
&
b
)
{
{
return
a
.
second
>
b
.
second
;
return
a
.
second
>
b
.
second
;
}
}
template
<
typename
T
>
template
<
typename
T
>
T
log_sum_exp
(
const
T
&
x
,
const
T
&
y
)
T
log_sum_exp
(
const
T
&
x
,
const
T
&
y
)
{
{
static
T
num_min
=
-
std
::
numeric_limits
<
T
>::
max
();
static
T
num_min
=
-
std
::
numeric_limits
<
T
>::
max
();
if
(
x
<=
num_min
)
return
y
;
if
(
x
<=
num_min
)
return
y
;
if
(
y
<=
num_min
)
return
x
;
if
(
y
<=
num_min
)
return
x
;
T
xmax
=
std
::
max
(
x
,
y
);
T
xmax
=
std
::
max
(
x
,
y
);
return
std
::
log
(
std
::
exp
(
x
-
xmax
)
+
std
::
exp
(
y
-
xmax
))
+
xmax
;
return
std
::
log
(
std
::
exp
(
x
-
xmax
)
+
std
::
exp
(
y
-
xmax
))
+
xmax
;
}
}
// Functor for prefix comparsion
// Functor for prefix comparsion
bool
prefix_compare
(
const
PathTrie
*
x
,
const
PathTrie
*
y
);
bool
prefix_compare
(
const
PathTrie
*
x
,
const
PathTrie
*
y
);
// Get length of utf8 encoding string
// Get length of utf8 encoding string
// See: http://stackoverflow.com/a/4063229
// See: http://stackoverflow.com/a/4063229
size_t
get_utf8_str_len
(
const
std
::
string
&
str
);
size_t
get_utf8_str_len
(
const
std
::
string
&
str
);
// Split a string into a list of strings on a given string
// Split a string into a list of strings on a given string
// delimiter. NB: delimiters on beginning / end of string are
// delimiter. NB: delimiters on beginning / end of string are
...
@@ -50,13 +47,14 @@ std::vector<std::string> split_str(const std::string &s,
...
@@ -50,13 +47,14 @@ std::vector<std::string> split_str(const std::string &s,
std
::
vector
<
std
::
string
>
split_utf8_str
(
const
std
::
string
&
str
);
std
::
vector
<
std
::
string
>
split_utf8_str
(
const
std
::
string
&
str
);
// Add a word in index to the dicionary of fst
// Add a word in index to the dicionary of fst
void
add_word_to_fst
(
const
std
::
vector
<
int
>
&
word
,
void
add_word_to_fst
(
const
std
::
vector
<
int
>
&
word
,
fst
::
StdVectorFst
*
dictionary
);
fst
::
StdVectorFst
*
dictionary
);
// Add a word in string to dictionary
// Add a word in string to dictionary
bool
add_word_to_dictionary
(
const
std
::
string
&
word
,
bool
add_word_to_dictionary
(
const
std
::
unordered_map
<
std
::
string
,
int
>&
char_map
,
const
std
::
string
&
word
,
bool
add_space
,
const
std
::
unordered_map
<
std
::
string
,
int
>
&
char_map
,
int
SPACE_ID
,
bool
add_space
,
fst
::
StdVectorFst
*
dictionary
);
int
SPACE_ID
,
#endif // DECODER_UTILS_H
fst
::
StdVectorFst
*
dictionary
);
#endif // DECODER_UTILS_H
deploy/path_trie.cpp
浏览文件 @
5208b8e4
...
@@ -4,145 +4,142 @@
...
@@ -4,145 +4,142 @@
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#include "path_trie.h"
#include "decoder_utils.h"
#include "decoder_utils.h"
#include "path_trie.h"
PathTrie
::
PathTrie
()
{
PathTrie
::
PathTrie
()
{
_
log_prob_b_prev
=
-
NUM_FLT_INF
;
log_prob_b_prev
=
-
NUM_FLT_INF
;
_
log_prob_nb_prev
=
-
NUM_FLT_INF
;
log_prob_nb_prev
=
-
NUM_FLT_INF
;
_
log_prob_b_cur
=
-
NUM_FLT_INF
;
log_prob_b_cur
=
-
NUM_FLT_INF
;
_
log_prob_nb_cur
=
-
NUM_FLT_INF
;
log_prob_nb_cur
=
-
NUM_FLT_INF
;
_
score
=
-
NUM_FLT_INF
;
score
=
-
NUM_FLT_INF
;
_ROOT
=
-
1
;
_ROOT
=
-
1
;
_
character
=
_ROOT
;
character
=
_ROOT
;
_exists
=
true
;
_exists
=
true
;
_
parent
=
nullptr
;
parent
=
nullptr
;
_dictionary
=
nullptr
;
_dictionary
=
nullptr
;
_dictionary_state
=
0
;
_dictionary_state
=
0
;
_has_dictionary
=
false
;
_has_dictionary
=
false
;
_matcher
=
nullptr
;
// finds arcs in FST
_matcher
=
nullptr
;
// finds arcs in FST
}
}
PathTrie
::~
PathTrie
()
{
PathTrie
::~
PathTrie
()
{
for
(
auto
child
:
_children
)
{
for
(
auto
child
:
_children
)
{
delete
child
.
second
;
delete
child
.
second
;
}
}
}
}
PathTrie
*
PathTrie
::
get_path_trie
(
int
new_char
,
bool
reset
)
{
PathTrie
*
PathTrie
::
get_path_trie
(
int
new_char
,
bool
reset
)
{
auto
child
=
_children
.
begin
();
auto
child
=
_children
.
begin
();
for
(
child
=
_children
.
begin
();
child
!=
_children
.
end
();
++
child
)
{
for
(
child
=
_children
.
begin
();
child
!=
_children
.
end
();
++
child
)
{
if
(
child
->
first
==
new_char
)
{
if
(
child
->
first
==
new_char
)
{
break
;
break
;
}
}
}
if
(
child
!=
_children
.
end
()
)
{
}
if
(
!
child
->
second
->
_exists
)
{
if
(
child
!=
_children
.
end
())
{
child
->
second
->
_exists
=
true
;
if
(
!
child
->
second
->
_exists
)
{
child
->
second
->
_log_prob_b_prev
=
-
NUM_FLT_INF
;
child
->
second
->
_exists
=
true
;
child
->
second
->
_log_prob_nb_prev
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_b_prev
=
-
NUM_FLT_INF
;
child
->
second
->
_log_prob_b_cur
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_nb_prev
=
-
NUM_FLT_INF
;
child
->
second
->
_log_prob_nb_cur
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_b_cur
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_nb_cur
=
-
NUM_FLT_INF
;
}
return
(
child
->
second
);
}
else
{
if
(
_has_dictionary
)
{
_matcher
->
SetState
(
_dictionary_state
);
bool
found
=
_matcher
->
Find
(
new_char
);
if
(
!
found
)
{
// Adding this character causes word outside dictionary
auto
FSTZERO
=
fst
::
TropicalWeight
::
Zero
();
auto
final_weight
=
_dictionary
->
Final
(
_dictionary_state
);
bool
is_final
=
(
final_weight
!=
FSTZERO
);
if
(
is_final
&&
reset
)
{
_dictionary_state
=
_dictionary
->
Start
();
}
}
return
(
child
->
second
);
return
nullptr
;
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
new_path
->
character
=
new_char
;
new_path
->
parent
=
this
;
new_path
->
_dictionary
=
_dictionary
;
new_path
->
_dictionary_state
=
_matcher
->
Value
().
nextstate
;
new_path
->
_has_dictionary
=
true
;
new_path
->
_matcher
=
_matcher
;
_children
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
return
new_path
;
}
}
else
{
}
else
{
if
(
_has_dictionary
)
{
PathTrie
*
new_path
=
new
PathTrie
;
_matcher
->
SetState
(
_dictionary_state
);
new_path
->
character
=
new_char
;
bool
found
=
_matcher
->
Find
(
new_char
);
new_path
->
parent
=
this
;
if
(
!
found
)
{
_children
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
// Adding this character causes word outside dictionary
return
new_path
;
auto
FSTZERO
=
fst
::
TropicalWeight
::
Zero
();
auto
final_weight
=
_dictionary
->
Final
(
_dictionary_state
);
bool
is_final
=
(
final_weight
!=
FSTZERO
);
if
(
is_final
&&
reset
)
{
_dictionary_state
=
_dictionary
->
Start
();
}
return
nullptr
;
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
new_path
->
_character
=
new_char
;
new_path
->
_parent
=
this
;
new_path
->
_dictionary
=
_dictionary
;
new_path
->
_dictionary_state
=
_matcher
->
Value
().
nextstate
;
new_path
->
_has_dictionary
=
true
;
new_path
->
_matcher
=
_matcher
;
_children
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
return
new_path
;
}
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
new_path
->
_character
=
new_char
;
new_path
->
_parent
=
this
;
_children
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
return
new_path
;
}
}
}
}
}
}
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
)
{
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
)
{
return
get_path_vec
(
output
,
_ROOT
);
return
get_path_vec
(
output
,
_ROOT
);
}
}
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
,
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
,
int
stop
,
int
stop
,
size_t
max_steps
)
{
size_t
max_steps
)
{
if
(
_character
==
stop
||
if
(
character
==
stop
||
character
==
_ROOT
||
output
.
size
()
==
max_steps
)
{
_character
==
_ROOT
||
std
::
reverse
(
output
.
begin
(),
output
.
end
());
output
.
size
()
==
max_steps
)
{
return
this
;
std
::
reverse
(
output
.
begin
(),
output
.
end
());
}
else
{
return
this
;
output
.
push_back
(
character
);
}
else
{
return
parent
->
get_path_vec
(
output
,
stop
,
max_steps
);
output
.
push_back
(
_character
);
}
return
_parent
->
get_path_vec
(
output
,
stop
,
max_steps
);
}
}
}
void
PathTrie
::
iterate_to_vec
(
void
PathTrie
::
iterate_to_vec
(
std
::
vector
<
PathTrie
*>&
output
)
{
std
::
vector
<
PathTrie
*>&
output
)
{
if
(
_exists
)
{
if
(
_exists
)
{
log_prob_b_prev
=
log_prob_b_cur
;
_log_prob_b_prev
=
_log_prob_b_cur
;
log_prob_nb_prev
=
log_prob_nb_cur
;
_log_prob_nb_prev
=
_log_prob_nb_cur
;
_
log_prob_b_cur
=
-
NUM_FLT_INF
;
log_prob_b_cur
=
-
NUM_FLT_INF
;
_
log_prob_nb_cur
=
-
NUM_FLT_INF
;
log_prob_nb_cur
=
-
NUM_FLT_INF
;
_score
=
log_sum_exp
(
_log_prob_b_prev
,
_
log_prob_nb_prev
);
score
=
log_sum_exp
(
log_prob_b_prev
,
log_prob_nb_prev
);
output
.
push_back
(
this
);
output
.
push_back
(
this
);
}
}
for
(
auto
child
:
_children
)
{
for
(
auto
child
:
_children
)
{
child
.
second
->
iterate_to_vec
(
output
);
child
.
second
->
iterate_to_vec
(
output
);
}
}
}
}
void
PathTrie
::
remove
()
{
void
PathTrie
::
remove
()
{
_exists
=
false
;
_exists
=
false
;
if
(
_children
.
size
()
==
0
)
{
if
(
_children
.
size
()
==
0
)
{
auto
child
=
_parent
->
_children
.
begin
();
auto
child
=
parent
->
_children
.
begin
();
for
(
child
=
_parent
->
_children
.
begin
();
for
(
child
=
parent
->
_children
.
begin
();
child
!=
parent
->
_children
.
end
();
child
!=
_parent
->
_children
.
end
();
++
child
)
{
++
child
)
{
if
(
child
->
first
==
_character
)
{
if
(
child
->
first
==
character
)
{
_parent
->
_children
.
erase
(
child
);
parent
->
_children
.
erase
(
child
);
break
;
break
;
}
}
}
}
if
(
_parent
->
_children
.
size
()
==
0
&&
!
_parent
->
_exists
)
{
_parent
->
remove
();
}
delete
this
;
if
(
parent
->
_children
.
size
()
==
0
&&
!
parent
->
_exists
)
{
parent
->
remove
();
}
}
delete
this
;
}
}
}
void
PathTrie
::
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
)
{
void
PathTrie
::
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
)
{
_dictionary
=
dictionary
;
_dictionary
=
dictionary
;
_dictionary_state
=
dictionary
->
Start
();
_dictionary_state
=
dictionary
->
Start
();
_has_dictionary
=
true
;
_has_dictionary
=
true
;
}
}
using
FSTMATCH
=
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>
;
using
FSTMATCH
=
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>
;
void
PathTrie
::
set_matcher
(
std
::
shared_ptr
<
FSTMATCH
>
matcher
)
{
void
PathTrie
::
set_matcher
(
std
::
shared_ptr
<
FSTMATCH
>
matcher
)
{
_matcher
=
matcher
;
_matcher
=
matcher
;
}
}
deploy/path_trie.h
浏览文件 @
5208b8e4
#ifndef PATH_TRIE_H
#ifndef PATH_TRIE_H
#define PATH_TRIE_H
#define PATH_TRIE_H
#pragma once
#pragma once
#include <fst/fstlib.h>
#include <algorithm>
#include <algorithm>
#include <limits>
#include <limits>
#include <memory>
#include <memory>
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#include <fst/fstlib.h>
using
FSTMATCH
=
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>
;
using
FSTMATCH
=
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>
;
class
PathTrie
{
class
PathTrie
{
public:
public:
PathTrie
();
PathTrie
();
~
PathTrie
();
~
PathTrie
();
PathTrie
*
get_path_trie
(
int
new_char
,
bool
reset
=
true
);
PathTrie
*
get_path_vec
(
std
::
vector
<
int
>
&
output
);
PathTrie
*
get_path_trie
(
int
new_char
,
bool
reset
=
true
);
PathTrie
*
get_path_vec
(
std
::
vector
<
int
>&
output
,
PathTrie
*
get_path_vec
(
std
::
vector
<
int
>&
output
);
int
stop
,
size_t
max_steps
=
std
::
numeric_limits
<
size_t
>::
max
());
void
iterate_to_vec
(
std
::
vector
<
PathTrie
*>
&
output
);
PathTrie
*
get_path_vec
(
std
::
vector
<
int
>&
output
,
int
stop
,
size_t
max_steps
=
std
::
numeric_limits
<
size_t
>::
max
());
void
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
);
void
iterate_to_vec
(
std
::
vector
<
PathTrie
*>&
output
);
void
set_matcher
(
std
::
shared_ptr
<
FSTMATCH
>
matcher
);
void
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
);
bool
is_empty
()
{
void
set_matcher
(
std
::
shared_ptr
<
FSTMATCH
>
matcher
);
return
_ROOT
==
_character
;
}
void
remove
();
bool
is_empty
()
{
return
_ROOT
==
character
;
}
float
_log_prob_b_prev
;
void
remove
();
float
_log_prob_nb_prev
;
float
_log_prob_b_cur
;
float
_log_prob_nb_cur
;
float
_score
;
float
_approx_ctc
;
float
log_prob_b_prev
;
float
log_prob_nb_prev
;
float
log_prob_b_cur
;
float
log_prob_nb_cur
;
float
score
;
float
approx_ctc
;
int
character
;
PathTrie
*
parent
;
int
_ROOT
;
private:
int
_character
;
int
_ROOT
;
bool
_exists
;
bool
_exists
;
PathTrie
*
_parent
;
std
::
vector
<
std
::
pair
<
int
,
PathTrie
*>>
_children
;
std
::
vector
<
std
::
pair
<
int
,
PathTrie
*>
>
_children
;
fst
::
StdVectorFst
*
_dictionary
;
fst
::
StdVectorFst
*
_dictionary
;
fst
::
StdVectorFst
::
StateId
_dictionary_state
;
fst
::
StdVectorFst
::
StateId
_dictionary_state
;
bool
_has_dictionary
;
bool
_has_dictionary
;
std
::
shared_ptr
<
FSTMATCH
>
_matcher
;
std
::
shared_ptr
<
FSTMATCH
>
_matcher
;
};
};
#endif // PATH_TRIE_H
#endif
// PATH_TRIE_H
deploy/scorer.cpp
浏览文件 @
5208b8e4
#include
<iostream>
#include
"scorer.h"
#include <unistd.h>
#include <unistd.h>
#include <iostream>
#include "decoder_utils.h"
#include "lm/config.hh"
#include "lm/config.hh"
#include "lm/state.hh"
#include "lm/model.hh"
#include "lm/model.hh"
#include "
util/tokenize_piec
e.hh"
#include "
lm/stat
e.hh"
#include "util/string_piece.hh"
#include "util/string_piece.hh"
#include "scorer.h"
#include "util/tokenize_piece.hh"
#include "decoder_utils.h"
using
namespace
lm
::
ngram
;
using
namespace
lm
::
ngram
;
Scorer
::
Scorer
(
double
alpha
,
double
beta
,
const
std
::
string
&
lm_path
)
{
Scorer
::
Scorer
(
double
alpha
,
double
beta
,
const
std
::
string
&
lm_path
)
{
this
->
alpha
=
alpha
;
this
->
alpha
=
alpha
;
this
->
beta
=
beta
;
this
->
beta
=
beta
;
_is_character_based
=
true
;
_is_character_based
=
true
;
_language_model
=
nullptr
;
_language_model
=
nullptr
;
dictionary
=
nullptr
;
dictionary
=
nullptr
;
_max_order
=
0
;
_max_order
=
0
;
_SPACE_ID
=
-
1
;
_SPACE_ID
=
-
1
;
// load language model
// load language model
load_LM
(
lm_path
.
c_str
());
load_LM
(
lm_path
.
c_str
());
}
}
Scorer
::~
Scorer
()
{
Scorer
::~
Scorer
()
{
if
(
_language_model
!=
nullptr
)
if
(
_language_model
!=
nullptr
)
delete
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
);
delete
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
);
if
(
dictionary
!=
nullptr
)
if
(
dictionary
!=
nullptr
)
delete
static_cast
<
fst
::
StdVectorFst
*>
(
dictionary
);
delete
static_cast
<
fst
::
StdVectorFst
*>
(
dictionary
);
}
}
void
Scorer
::
load_LM
(
const
char
*
filename
)
{
void
Scorer
::
load_LM
(
const
char
*
filename
)
{
if
(
access
(
filename
,
F_OK
)
!=
0
)
{
if
(
access
(
filename
,
F_OK
)
!=
0
)
{
std
::
cerr
<<
"Invalid language model file !!!"
<<
std
::
endl
;
std
::
cerr
<<
"Invalid language model file !!!"
<<
std
::
endl
;
exit
(
1
);
exit
(
1
);
}
}
RetriveStrEnumerateVocab
enumerate
;
RetriveStrEnumerateVocab
enumerate
;
lm
::
ngram
::
Config
config
;
lm
::
ngram
::
Config
config
;
config
.
enumerate_vocab
=
&
enumerate
;
config
.
enumerate_vocab
=
&
enumerate
;
_language_model
=
lm
::
ngram
::
LoadVirtual
(
filename
,
config
);
_language_model
=
lm
::
ngram
::
LoadVirtual
(
filename
,
config
);
_max_order
=
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
)
->
Order
();
_max_order
=
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
)
->
Order
();
_vocabulary
=
enumerate
.
vocabulary
;
_vocabulary
=
enumerate
.
vocabulary
;
for
(
size_t
i
=
0
;
i
<
_vocabulary
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
_vocabulary
.
size
();
++
i
)
{
if
(
_is_character_based
if
(
_is_character_based
&&
_vocabulary
[
i
]
!=
UNK_TOKEN
&&
&&
_vocabulary
[
i
]
!=
UNK_TOKEN
_vocabulary
[
i
]
!=
START_TOKEN
&&
_vocabulary
[
i
]
!=
END_TOKEN
&&
&&
_vocabulary
[
i
]
!=
START_TOKEN
get_utf8_str_len
(
enumerate
.
vocabulary
[
i
])
>
1
)
{
&&
_vocabulary
[
i
]
!=
END_TOKEN
_is_character_based
=
false
;
&&
get_utf8_str_len
(
enumerate
.
vocabulary
[
i
])
>
1
)
{
_is_character_based
=
false
;
}
}
}
}
}
}
double
Scorer
::
get_log_cond_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
double
Scorer
::
get_log_cond_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
lm
::
base
::
Model
*
model
=
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
);
lm
::
base
::
Model
*
model
=
static_cast
<
lm
::
base
::
Model
*>
(
_language_model
);
double
cond_prob
;
double
cond_prob
;
lm
::
ngram
::
State
state
,
tmp_state
,
out_state
;
lm
::
ngram
::
State
state
,
tmp_state
,
out_state
;
// avoid to inserting <s> in begin
// avoid to inserting <s> in begin
model
->
NullContextWrite
(
&
state
);
model
->
NullContextWrite
(
&
state
);
for
(
size_t
i
=
0
;
i
<
words
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
words
.
size
();
++
i
)
{
lm
::
WordIndex
word_index
=
model
->
BaseVocabulary
().
Index
(
words
[
i
]);
lm
::
WordIndex
word_index
=
model
->
BaseVocabulary
().
Index
(
words
[
i
]);
// encounter OOV
// encounter OOV
if
(
word_index
==
0
)
{
if
(
word_index
==
0
)
{
return
OOV_SCORE
;
return
OOV_SCORE
;
}
cond_prob
=
model
->
BaseScore
(
&
state
,
word_index
,
&
out_state
);
tmp_state
=
state
;
state
=
out_state
;
out_state
=
tmp_state
;
}
}
// log10 prob
cond_prob
=
model
->
BaseScore
(
&
state
,
word_index
,
&
out_state
);
return
cond_prob
;
tmp_state
=
state
;
state
=
out_state
;
out_state
=
tmp_state
;
}
// log10 prob
return
cond_prob
;
}
}
double
Scorer
::
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
double
Scorer
::
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
std
::
vector
<
std
::
string
>
sentence
;
std
::
vector
<
std
::
string
>
sentence
;
if
(
words
.
size
()
==
0
)
{
if
(
words
.
size
()
==
0
)
{
for
(
size_t
i
=
0
;
i
<
_max_order
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
_max_order
;
++
i
)
{
sentence
.
push_back
(
START_TOKEN
);
sentence
.
push_back
(
START_TOKEN
);
}
}
else
{
for
(
size_t
i
=
0
;
i
<
_max_order
-
1
;
++
i
)
{
sentence
.
push_back
(
START_TOKEN
);
}
sentence
.
insert
(
sentence
.
end
(),
words
.
begin
(),
words
.
end
());
}
}
sentence
.
push_back
(
END_TOKEN
);
}
else
{
return
get_log_prob
(
sentence
);
for
(
size_t
i
=
0
;
i
<
_max_order
-
1
;
++
i
)
{
sentence
.
push_back
(
START_TOKEN
);
}
sentence
.
insert
(
sentence
.
end
(),
words
.
begin
(),
words
.
end
());
}
sentence
.
push_back
(
END_TOKEN
);
return
get_log_prob
(
sentence
);
}
}
double
Scorer
::
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
double
Scorer
::
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
assert
(
words
.
size
()
>
_max_order
);
assert
(
words
.
size
()
>
_max_order
);
double
score
=
0.0
;
double
score
=
0.0
;
for
(
size_t
i
=
0
;
i
<
words
.
size
()
-
_max_order
+
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
words
.
size
()
-
_max_order
+
1
;
++
i
)
{
std
::
vector
<
std
::
string
>
ngram
(
words
.
begin
()
+
i
,
std
::
vector
<
std
::
string
>
ngram
(
words
.
begin
()
+
i
,
words
.
begin
()
+
i
+
_max_order
);
words
.
begin
()
+
i
+
_max_order
);
score
+=
get_log_cond_prob
(
ngram
);
score
+=
get_log_cond_prob
(
ngram
);
}
}
return
score
;
return
score
;
}
}
void
Scorer
::
reset_params
(
float
alpha
,
float
beta
)
{
void
Scorer
::
reset_params
(
float
alpha
,
float
beta
)
{
this
->
alpha
=
alpha
;
this
->
alpha
=
alpha
;
this
->
beta
=
beta
;
this
->
beta
=
beta
;
}
}
std
::
string
Scorer
::
vec2str
(
const
std
::
vector
<
int
>&
input
)
{
std
::
string
Scorer
::
vec2str
(
const
std
::
vector
<
int
>&
input
)
{
std
::
string
word
;
std
::
string
word
;
for
(
auto
ind
:
input
)
{
for
(
auto
ind
:
input
)
{
word
+=
_char_list
[
ind
];
word
+=
_char_list
[
ind
];
}
}
return
word
;
return
word
;
}
}
std
::
vector
<
std
::
string
>
std
::
vector
<
std
::
string
>
Scorer
::
split_labels
(
const
std
::
vector
<
int
>&
labels
)
{
Scorer
::
split_labels
(
const
std
::
vector
<
int
>
&
labels
)
{
if
(
labels
.
empty
())
return
{};
if
(
labels
.
empty
())
return
{};
std
::
string
s
=
vec2str
(
labels
);
std
::
vector
<
std
::
string
>
words
;
std
::
string
s
=
vec2str
(
labels
);
if
(
_is_character_based
)
{
std
::
vector
<
std
::
string
>
words
;
words
=
split_utf8_str
(
s
);
if
(
_is_character_based
)
{
}
else
{
words
=
split_utf8_str
(
s
);
words
=
split_str
(
s
,
" "
);
}
else
{
}
words
=
split_str
(
s
,
" "
);
return
words
;
}
return
words
;
}
}
void
Scorer
::
set_char_map
(
std
::
vector
<
std
::
string
>
char_list
)
{
void
Scorer
::
set_char_map
(
std
::
vector
<
std
::
string
>
char_list
)
{
_char_list
=
char_list
;
_char_list
=
char_list
;
_char_map
.
clear
();
_char_map
.
clear
();
for
(
unsigned
int
i
=
0
;
i
<
_char_list
.
size
();
i
++
)
for
(
unsigned
int
i
=
0
;
i
<
_char_list
.
size
();
i
++
)
{
{
if
(
_char_list
[
i
]
==
" "
)
{
if
(
_char_list
[
i
]
==
" "
)
{
_SPACE_ID
=
i
;
_SPACE_ID
=
i
;
_char_map
[
' '
]
=
i
;
_char_map
[
' '
]
=
i
;
}
else
if
(
_char_list
[
i
].
size
()
==
1
)
{
}
else
if
(
_char_list
[
i
].
size
()
==
1
){
_char_map
[
_char_list
[
i
][
0
]]
=
i
;
_char_map
[
_char_list
[
i
][
0
]]
=
i
;
}
}
}
}
}
}
std
::
vector
<
std
::
string
>
Scorer
::
make_ngram
(
PathTrie
*
prefix
)
{
std
::
vector
<
std
::
string
>
Scorer
::
make_ngram
(
PathTrie
*
prefix
)
{
std
::
vector
<
std
::
string
>
ngram
;
std
::
vector
<
std
::
string
>
ngram
;
PathTrie
*
current_node
=
prefix
;
PathTrie
*
current_node
=
prefix
;
PathTrie
*
new_node
=
nullptr
;
PathTrie
*
new_node
=
nullptr
;
for
(
int
order
=
0
;
order
<
_max_order
;
order
++
)
{
std
::
vector
<
int
>
prefix_vec
;
if
(
_is_character_based
)
{
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
_SPACE_ID
,
1
);
current_node
=
new_node
;
}
else
{
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
_SPACE_ID
);
current_node
=
new_node
->
_parent
;
// Skipping spaces
}
// reconstruct word
std
::
string
word
=
vec2str
(
prefix_vec
);
ngram
.
push_back
(
word
);
if
(
new_node
->
_character
==
-
1
)
{
// No more spaces, but still need order
for
(
int
i
=
0
;
i
<
_max_order
-
order
-
1
;
i
++
)
{
ngram
.
push_back
(
START_TOKEN
);
}
break
;
}
}
std
::
reverse
(
ngram
.
begin
(),
ngram
.
end
());
return
ngram
;
}
void
Scorer
::
fill_dictionary
(
bool
add_space
)
{
fst
::
StdVectorFst
dictionary
;
for
(
int
order
=
0
;
order
<
_max_order
;
order
++
)
{
// First reverse char_list so ints can be accessed by chars
std
::
vector
<
int
>
prefix_vec
;
std
::
unordered_map
<
std
::
string
,
int
>
char_map
;
for
(
unsigned
int
i
=
0
;
i
<
_char_list
.
size
();
i
++
)
{
char_map
[
_char_list
[
i
]]
=
i
;
}
// For each unigram convert to ints and put in trie
if
(
_is_character_based
)
{
int
vocab_size
=
0
;
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
_SPACE_ID
,
1
);
for
(
const
auto
&
word
:
_vocabulary
)
{
current_node
=
new_node
;
bool
added
=
add_word_to_dictionary
(
word
,
}
else
{
char_map
,
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
_SPACE_ID
);
add_space
,
current_node
=
new_node
->
parent
;
// Skipping spaces
_SPACE_ID
,
&
dictionary
);
vocab_size
+=
added
?
1
:
0
;
}
}
std
::
cerr
<<
"Vocab Size "
<<
vocab_size
<<
std
::
endl
;
// reconstruct word
std
::
string
word
=
vec2str
(
prefix_vec
);
// Simplify FST
ngram
.
push_back
(
word
);
// This gets rid of "epsilon" transitions in the FST.
// These are transitions that don't require a string input to be taken.
// Getting rid of them is necessary to make the FST determinisitc, but
// can greatly increase the size of the FST
fst
::
RmEpsilon
(
&
dictionary
);
fst
::
StdVectorFst
*
new_dict
=
new
fst
::
StdVectorFst
;
// This makes the FST deterministic, meaning for any string input there's
if
(
new_node
->
character
==
-
1
)
{
// only one possible state the FST could be in. It is assumed our
// No more spaces, but still need order
// dictionary is deterministic when using it.
for
(
int
i
=
0
;
i
<
_max_order
-
order
-
1
;
i
++
)
{
// (lest we'd have to check for multiple transitions at each state)
ngram
.
push_back
(
START_TOKEN
);
fst
::
Determinize
(
dictionary
,
new_dict
);
}
break
;
// Finds the simplest equivalent fst. This is unnecessary but decreases
}
// memory usage of the dictionary
}
fst
::
Minimize
(
new_dict
);
std
::
reverse
(
ngram
.
begin
(),
ngram
.
end
());
this
->
dictionary
=
new_dict
;
return
ngram
;
}
void
Scorer
::
fill_dictionary
(
bool
add_space
)
{
fst
::
StdVectorFst
dictionary
;
// First reverse char_list so ints can be accessed by chars
std
::
unordered_map
<
std
::
string
,
int
>
char_map
;
for
(
unsigned
int
i
=
0
;
i
<
_char_list
.
size
();
i
++
)
{
char_map
[
_char_list
[
i
]]
=
i
;
}
// For each unigram convert to ints and put in trie
int
vocab_size
=
0
;
for
(
const
auto
&
word
:
_vocabulary
)
{
bool
added
=
add_word_to_dictionary
(
word
,
char_map
,
add_space
,
_SPACE_ID
,
&
dictionary
);
vocab_size
+=
added
?
1
:
0
;
}
std
::
cerr
<<
"Vocab Size "
<<
vocab_size
<<
std
::
endl
;
// Simplify FST
// This gets rid of "epsilon" transitions in the FST.
// These are transitions that don't require a string input to be taken.
// Getting rid of them is necessary to make the FST determinisitc, but
// can greatly increase the size of the FST
fst
::
RmEpsilon
(
&
dictionary
);
fst
::
StdVectorFst
*
new_dict
=
new
fst
::
StdVectorFst
;
// This makes the FST deterministic, meaning for any string input there's
// only one possible state the FST could be in. It is assumed our
// dictionary is deterministic when using it.
// (lest we'd have to check for multiple transitions at each state)
fst
::
Determinize
(
dictionary
,
new_dict
);
// Finds the simplest equivalent fst. This is unnecessary but decreases
// memory usage of the dictionary
fst
::
Minimize
(
new_dict
);
this
->
dictionary
=
new_dict
;
}
}
deploy/scorer.h
浏览文件 @
5208b8e4
#ifndef SCORER_H_
#ifndef SCORER_H_
#define SCORER_H_
#define SCORER_H_
#include <string>
#include <memory>
#include <memory>
#include <
vector
>
#include <
string
>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include "lm/enumerate_vocab.hh"
#include "lm/enumerate_vocab.hh"
#include "lm/word_index.hh"
#include "lm/virtual_interface.hh"
#include "lm/virtual_interface.hh"
#include "
util/string_piece
.hh"
#include "
lm/word_index
.hh"
#include "path_trie.h"
#include "path_trie.h"
#include "util/string_piece.hh"
const
double
OOV_SCORE
=
-
1000.0
;
const
double
OOV_SCORE
=
-
1000.0
;
const
std
::
string
START_TOKEN
=
"<s>"
;
const
std
::
string
START_TOKEN
=
"<s>"
;
const
std
::
string
UNK_TOKEN
=
"<unk>"
;
const
std
::
string
UNK_TOKEN
=
"<unk>"
;
const
std
::
string
END_TOKEN
=
"</s>"
;
const
std
::
string
END_TOKEN
=
"</s>"
;
// Implement a callback to retrive string vocabulary.
// Implement a callback to retrive string vocabulary.
class
RetriveStrEnumerateVocab
:
public
lm
::
EnumerateVocab
{
class
RetriveStrEnumerateVocab
:
public
lm
::
EnumerateVocab
{
public:
public:
RetriveStrEnumerateVocab
()
{}
RetriveStrEnumerateVocab
()
{}
void
Add
(
lm
::
WordIndex
index
,
const
StringPiece
&
str
)
{
void
Add
(
lm
::
WordIndex
index
,
const
StringPiece
&
str
)
{
vocabulary
.
push_back
(
std
::
string
(
str
.
data
(),
str
.
length
()));
vocabulary
.
push_back
(
std
::
string
(
str
.
data
(),
str
.
length
()));
}
}
std
::
vector
<
std
::
string
>
vocabulary
;
std
::
vector
<
std
::
string
>
vocabulary
;
};
};
// External scorer to query languange score for n-gram or sentence.
// External scorer to query languange score for n-gram or sentence.
...
@@ -33,59 +33,59 @@ public:
...
@@ -33,59 +33,59 @@ public:
// Scorer scorer(alpha, beta, "path_of_language_model");
// Scorer scorer(alpha, beta, "path_of_language_model");
// scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
// scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
// scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
// scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
class
Scorer
{
class
Scorer
{
public:
public:
Scorer
(
double
alpha
,
double
beta
,
const
std
::
string
&
lm_path
);
Scorer
(
double
alpha
,
double
beta
,
const
std
::
string
&
lm_path
);
~
Scorer
();
~
Scorer
();
double
get_log_cond_prob
(
const
std
::
vector
<
std
::
string
>&
words
);
double
get_log_cond_prob
(
const
std
::
vector
<
std
::
string
>&
words
);
double
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
);
double
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
);
size_t
get_max_order
()
{
return
_max_order
;
}
size_t
get_max_order
()
{
return
_max_order
;
}
bool
is_char_map_empty
()
{
return
_char_map
.
size
()
==
0
;
}
bool
is_char_map_empty
()
{
return
_char_map
.
size
()
==
0
;
}
bool
is_character_based
()
{
return
_is_character_based
;
}
bool
is_character_based
()
{
return
_is_character_based
;
}
// reset params alpha & beta
// reset params alpha & beta
void
reset_params
(
float
alpha
,
float
beta
);
void
reset_params
(
float
alpha
,
float
beta
);
// make ngram
// make ngram
std
::
vector
<
std
::
string
>
make_ngram
(
PathTrie
*
prefix
);
std
::
vector
<
std
::
string
>
make_ngram
(
PathTrie
*
prefix
);
// fill dictionary for fst
// fill dictionary for fst
void
fill_dictionary
(
bool
add_space
);
void
fill_dictionary
(
bool
add_space
);
// set char map
// set char map
void
set_char_map
(
std
::
vector
<
std
::
string
>
char_list
);
void
set_char_map
(
std
::
vector
<
std
::
string
>
char_list
);
std
::
vector
<
std
::
string
>
split_labels
(
const
std
::
vector
<
int
>
&
labels
);
std
::
vector
<
std
::
string
>
split_labels
(
const
std
::
vector
<
int
>&
labels
);
// expose to decoder
// expose to decoder
double
alpha
;
double
alpha
;
double
beta
;
double
beta
;
// fst dictionary
// fst dictionary
void
*
dictionary
;
void
*
dictionary
;
protected:
protected:
void
load_LM
(
const
char
*
filename
);
void
load_LM
(
const
char
*
filename
);
double
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
);
double
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
);
std
::
string
vec2str
(
const
std
::
vector
<
int
>
&
input
);
std
::
string
vec2str
(
const
std
::
vector
<
int
>&
input
);
private:
private:
void
*
_language_model
;
void
*
_language_model
;
bool
_is_character_based
;
bool
_is_character_based
;
size_t
_max_order
;
size_t
_max_order
;
int
_SPACE_ID
;
int
_SPACE_ID
;
std
::
vector
<
std
::
string
>
_char_list
;
std
::
vector
<
std
::
string
>
_char_list
;
std
::
unordered_map
<
char
,
int
>
_char_map
;
std
::
unordered_map
<
char
,
int
>
_char_map
;
std
::
vector
<
std
::
string
>
_vocabulary
;
std
::
vector
<
std
::
string
>
_vocabulary
;
};
};
#endif // SCORER_H_
#endif
// SCORER_H_
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录