Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
17ebb40a
M
models
项目概览
PaddlePaddle
/
models
大约 1 年 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
17ebb40a
编写于
9月 18, 2017
作者:
Y
Yibing Liu
提交者:
GitHub
9月 18, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #139 from kuke/ctc_decoder_deploy
Add optimized decoders for DS2
上级
23b53d80
9db0d25f
变更
34
隐藏空白更改
内联
并排
Showing
34 changed file
with
1555 addition
and
64 deletion
+1555
-64
.clang_format.hook
.clang_format.hook
+1
-1
deep_speech_2/decoders/__init__.py
deep_speech_2/decoders/__init__.py
+0
-0
deep_speech_2/decoders/decoders_deprecated.py
deep_speech_2/decoders/decoders_deprecated.py
+8
-12
deep_speech_2/decoders/scorer_deprecated.py
deep_speech_2/decoders/scorer_deprecated.py
+1
-1
deep_speech_2/decoders/swig/__init__.py
deep_speech_2/decoders/swig/__init__.py
+0
-0
deep_speech_2/decoders/swig/_init_paths.py
deep_speech_2/decoders/swig/_init_paths.py
+0
-0
deep_speech_2/decoders/swig/ctc_beam_search_decoder.cpp
deep_speech_2/decoders/swig/ctc_beam_search_decoder.cpp
+204
-0
deep_speech_2/decoders/swig/ctc_beam_search_decoder.h
deep_speech_2/decoders/swig/ctc_beam_search_decoder.h
+61
-0
deep_speech_2/decoders/swig/ctc_greedy_decoder.cpp
deep_speech_2/decoders/swig/ctc_greedy_decoder.cpp
+45
-0
deep_speech_2/decoders/swig/ctc_greedy_decoder.h
deep_speech_2/decoders/swig/ctc_greedy_decoder.h
+20
-0
deep_speech_2/decoders/swig/decoder_utils.cpp
deep_speech_2/decoders/swig/decoder_utils.cpp
+176
-0
deep_speech_2/decoders/swig/decoder_utils.h
deep_speech_2/decoders/swig/decoder_utils.h
+94
-0
deep_speech_2/decoders/swig/decoders.i
deep_speech_2/decoders/swig/decoders.i
+33
-0
deep_speech_2/decoders/swig/path_trie.cpp
deep_speech_2/decoders/swig/path_trie.cpp
+148
-0
deep_speech_2/decoders/swig/path_trie.h
deep_speech_2/decoders/swig/path_trie.h
+67
-0
deep_speech_2/decoders/swig/scorer.cpp
deep_speech_2/decoders/swig/scorer.cpp
+234
-0
deep_speech_2/decoders/swig/scorer.h
deep_speech_2/decoders/swig/scorer.h
+112
-0
deep_speech_2/decoders/swig/setup.py
deep_speech_2/decoders/swig/setup.py
+121
-0
deep_speech_2/decoders/swig/setup.sh
deep_speech_2/decoders/swig/setup.sh
+21
-0
deep_speech_2/decoders/swig_wrapper.py
deep_speech_2/decoders/swig_wrapper.py
+116
-0
deep_speech_2/decoders/tests/test_decoders.py
deep_speech_2/decoders/tests/test_decoders.py
+3
-6
deep_speech_2/examples/librispeech/run_infer.sh
deep_speech_2/examples/librispeech/run_infer.sh
+4
-3
deep_speech_2/examples/librispeech/run_infer_golden.sh
deep_speech_2/examples/librispeech/run_infer_golden.sh
+4
-3
deep_speech_2/examples/librispeech/run_test.sh
deep_speech_2/examples/librispeech/run_test.sh
+3
-3
deep_speech_2/examples/librispeech/run_test_golden.sh
deep_speech_2/examples/librispeech/run_test_golden.sh
+4
-3
deep_speech_2/examples/tiny/run_infer.sh
deep_speech_2/examples/tiny/run_infer.sh
+3
-3
deep_speech_2/examples/tiny/run_infer_golden.sh
deep_speech_2/examples/tiny/run_infer_golden.sh
+3
-3
deep_speech_2/examples/tiny/run_test.sh
deep_speech_2/examples/tiny/run_test.sh
+3
-3
deep_speech_2/examples/tiny/run_test_golden.sh
deep_speech_2/examples/tiny/run_test_golden.sh
+3
-3
deep_speech_2/infer.py
deep_speech_2/infer.py
+11
-4
deep_speech_2/model_utils/model.py
deep_speech_2/model_utils/model.py
+31
-9
deep_speech_2/requirements.txt
deep_speech_2/requirements.txt
+0
-1
deep_speech_2/setup.sh
deep_speech_2/setup.sh
+11
-2
deep_speech_2/test.py
deep_speech_2/test.py
+10
-4
未找到文件。
.clang_format.hook
浏览文件 @
17ebb40a
#!/usr/bin/env bash
#!/usr/bin/env bash
set
-e
set
-e
readonly
VERSION
=
"3.
8
"
readonly
VERSION
=
"3.
9
"
version
=
$(
clang-format
-version
)
version
=
$(
clang-format
-version
)
...
...
deep_speech_2/decoders/__init__.py
0 → 100644
浏览文件 @
17ebb40a
deep_speech_2/
model_utils/decoder
.py
→
deep_speech_2/
decoders/decoders_deprecated
.py
浏览文件 @
17ebb40a
...
@@ -42,8 +42,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
...
@@ -42,8 +42,8 @@ def ctc_greedy_decoder(probs_seq, vocabulary):
def
ctc_beam_search_decoder
(
probs_seq
,
def
ctc_beam_search_decoder
(
probs_seq
,
beam_size
,
beam_size
,
vocabulary
,
vocabulary
,
blank_id
,
cutoff_prob
=
1.0
,
cutoff_prob
=
1.0
,
cutoff_top_n
=
40
,
ext_scoring_func
=
None
,
ext_scoring_func
=
None
,
nproc
=
False
):
nproc
=
False
):
"""CTC Beam search decoder.
"""CTC Beam search decoder.
...
@@ -66,8 +66,6 @@ def ctc_beam_search_decoder(probs_seq,
...
@@ -66,8 +66,6 @@ def ctc_beam_search_decoder(probs_seq,
:type beam_size: int
:type beam_size: int
:param vocabulary: Vocabulary list.
:param vocabulary: Vocabulary list.
:type vocabulary: list
:type vocabulary: list
:param blank_id: ID of blank.
:type blank_id: int
:param cutoff_prob: Cutoff probability in pruning,
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
default 1.0, no pruning.
:type cutoff_prob: float
:type cutoff_prob: float
...
@@ -87,9 +85,8 @@ def ctc_beam_search_decoder(probs_seq,
...
@@ -87,9 +85,8 @@ def ctc_beam_search_decoder(probs_seq,
raise
ValueError
(
"The shape of prob_seq does not match with the "
raise
ValueError
(
"The shape of prob_seq does not match with the "
"shape of the vocabulary."
)
"shape of the vocabulary."
)
# blank_id check
# blank_id assign
if
not
blank_id
<
len
(
probs_seq
[
0
]):
blank_id
=
len
(
vocabulary
)
raise
ValueError
(
"blank_id shouldn't be greater than probs dimension"
)
# If the decoder called in the multiprocesses, then use the global scorer
# If the decoder called in the multiprocesses, then use the global scorer
# instantiated in ctc_beam_search_decoder_batch().
# instantiated in ctc_beam_search_decoder_batch().
...
@@ -114,7 +111,7 @@ def ctc_beam_search_decoder(probs_seq,
...
@@ -114,7 +111,7 @@ def ctc_beam_search_decoder(probs_seq,
prob_idx
=
list
(
enumerate
(
probs_seq
[
time_step
]))
prob_idx
=
list
(
enumerate
(
probs_seq
[
time_step
]))
cutoff_len
=
len
(
prob_idx
)
cutoff_len
=
len
(
prob_idx
)
#If pruning is enabled
#If pruning is enabled
if
cutoff_prob
<
1.0
:
if
cutoff_prob
<
1.0
or
cutoff_top_n
<
cutoff_len
:
prob_idx
=
sorted
(
prob_idx
,
key
=
lambda
asd
:
asd
[
1
],
reverse
=
True
)
prob_idx
=
sorted
(
prob_idx
,
key
=
lambda
asd
:
asd
[
1
],
reverse
=
True
)
cutoff_len
,
cum_prob
=
0
,
0.0
cutoff_len
,
cum_prob
=
0
,
0.0
for
i
in
xrange
(
len
(
prob_idx
)):
for
i
in
xrange
(
len
(
prob_idx
)):
...
@@ -122,6 +119,7 @@ def ctc_beam_search_decoder(probs_seq,
...
@@ -122,6 +119,7 @@ def ctc_beam_search_decoder(probs_seq,
cutoff_len
+=
1
cutoff_len
+=
1
if
cum_prob
>=
cutoff_prob
:
if
cum_prob
>=
cutoff_prob
:
break
break
cutoff_len
=
min
(
cutoff_len
,
cutoff_top_n
)
prob_idx
=
prob_idx
[
0
:
cutoff_len
]
prob_idx
=
prob_idx
[
0
:
cutoff_len
]
for
l
in
prefix_set_prev
:
for
l
in
prefix_set_prev
:
...
@@ -191,9 +189,9 @@ def ctc_beam_search_decoder(probs_seq,
...
@@ -191,9 +189,9 @@ def ctc_beam_search_decoder(probs_seq,
def
ctc_beam_search_decoder_batch
(
probs_split
,
def
ctc_beam_search_decoder_batch
(
probs_split
,
beam_size
,
beam_size
,
vocabulary
,
vocabulary
,
blank_id
,
num_processes
,
num_processes
,
cutoff_prob
=
1.0
,
cutoff_prob
=
1.0
,
cutoff_top_n
=
40
,
ext_scoring_func
=
None
):
ext_scoring_func
=
None
):
"""CTC beam search decoder using multiple processes.
"""CTC beam search decoder using multiple processes.
...
@@ -204,8 +202,6 @@ def ctc_beam_search_decoder_batch(probs_split,
...
@@ -204,8 +202,6 @@ def ctc_beam_search_decoder_batch(probs_split,
:type beam_size: int
:type beam_size: int
:param vocabulary: Vocabulary list.
:param vocabulary: Vocabulary list.
:type vocabulary: list
:type vocabulary: list
:param blank_id: ID of blank.
:type blank_id: int
:param num_processes: Number of parallel processes.
:param num_processes: Number of parallel processes.
:type num_processes: int
:type num_processes: int
:param cutoff_prob: Cutoff probability in pruning,
:param cutoff_prob: Cutoff probability in pruning,
...
@@ -232,8 +228,8 @@ def ctc_beam_search_decoder_batch(probs_split,
...
@@ -232,8 +228,8 @@ def ctc_beam_search_decoder_batch(probs_split,
pool
=
multiprocessing
.
Pool
(
processes
=
num_processes
)
pool
=
multiprocessing
.
Pool
(
processes
=
num_processes
)
results
=
[]
results
=
[]
for
i
,
probs_list
in
enumerate
(
probs_split
):
for
i
,
probs_list
in
enumerate
(
probs_split
):
args
=
(
probs_list
,
beam_size
,
vocabulary
,
blank_id
,
cutoff_prob
,
None
,
args
=
(
probs_list
,
beam_size
,
vocabulary
,
cutoff_prob
,
cutoff_top_n
,
nproc
)
None
,
nproc
)
results
.
append
(
pool
.
apply_async
(
ctc_beam_search_decoder
,
args
))
results
.
append
(
pool
.
apply_async
(
ctc_beam_search_decoder
,
args
))
pool
.
close
()
pool
.
close
()
...
...
deep_speech_2/
model_utils/lm_scorer
.py
→
deep_speech_2/
decoders/scorer_deprecated
.py
浏览文件 @
17ebb40a
...
@@ -8,7 +8,7 @@ import kenlm
...
@@ -8,7 +8,7 @@ import kenlm
import
numpy
as
np
import
numpy
as
np
class
Lm
Scorer
(
object
):
class
Scorer
(
object
):
"""External scorer to evaluate a prefix or whole sentence in
"""External scorer to evaluate a prefix or whole sentence in
beam search decoding, including the score from n-gram language
beam search decoding, including the score from n-gram language
model and word count.
model and word count.
...
...
deep_speech_2/decoders/swig/__init__.py
0 → 100644
浏览文件 @
17ebb40a
deep_speech_2/de
ploy
/_init_paths.py
→
deep_speech_2/de
coders/swig
/_init_paths.py
浏览文件 @
17ebb40a
文件已移动
deep_speech_2/decoders/swig/ctc_beam_search_decoder.cpp
0 → 100644
浏览文件 @
17ebb40a
#include "ctc_beam_search_decoder.h"
#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <map>
#include <utility>
#include "ThreadPool.h"
#include "fst/fstlib.h"
#include "decoder_utils.h"
#include "path_trie.h"
using
FSTMATCH
=
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>
;
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
ctc_beam_search_decoder
(
const
std
::
vector
<
std
::
vector
<
double
>>
&
probs_seq
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
size_t
beam_size
,
double
cutoff_prob
,
size_t
cutoff_top_n
,
Scorer
*
ext_scorer
)
{
// dimension check
size_t
num_time_steps
=
probs_seq
.
size
();
for
(
size_t
i
=
0
;
i
<
num_time_steps
;
++
i
)
{
VALID_CHECK_EQ
(
probs_seq
[
i
].
size
(),
vocabulary
.
size
()
+
1
,
"The shape of probs_seq does not match with "
"the shape of the vocabulary"
);
}
// assign blank id
size_t
blank_id
=
vocabulary
.
size
();
// assign space id
auto
it
=
std
::
find
(
vocabulary
.
begin
(),
vocabulary
.
end
(),
" "
);
int
space_id
=
it
-
vocabulary
.
begin
();
// if no space in vocabulary
if
((
size_t
)
space_id
>=
vocabulary
.
size
())
{
space_id
=
-
2
;
}
// init prefixes' root
PathTrie
root
;
root
.
score
=
root
.
log_prob_b_prev
=
0.0
;
std
::
vector
<
PathTrie
*>
prefixes
;
prefixes
.
push_back
(
&
root
);
if
(
ext_scorer
!=
nullptr
&&
!
ext_scorer
->
is_character_based
())
{
auto
fst_dict
=
static_cast
<
fst
::
StdVectorFst
*>
(
ext_scorer
->
dictionary
);
fst
::
StdVectorFst
*
dict_ptr
=
fst_dict
->
Copy
(
true
);
root
.
set_dictionary
(
dict_ptr
);
auto
matcher
=
std
::
make_shared
<
FSTMATCH
>
(
*
dict_ptr
,
fst
::
MATCH_INPUT
);
root
.
set_matcher
(
matcher
);
}
// prefix search over time
for
(
size_t
time_step
=
0
;
time_step
<
num_time_steps
;
++
time_step
)
{
auto
&
prob
=
probs_seq
[
time_step
];
float
min_cutoff
=
-
NUM_FLT_INF
;
bool
full_beam
=
false
;
if
(
ext_scorer
!=
nullptr
)
{
size_t
num_prefixes
=
std
::
min
(
prefixes
.
size
(),
beam_size
);
std
::
sort
(
prefixes
.
begin
(),
prefixes
.
begin
()
+
num_prefixes
,
prefix_compare
);
min_cutoff
=
prefixes
[
num_prefixes
-
1
]
->
score
+
std
::
log
(
prob
[
blank_id
])
-
std
::
max
(
0.0
,
ext_scorer
->
beta
);
full_beam
=
(
num_prefixes
==
beam_size
);
}
std
::
vector
<
std
::
pair
<
size_t
,
float
>>
log_prob_idx
=
get_pruned_log_probs
(
prob
,
cutoff_prob
,
cutoff_top_n
);
// loop over chars
for
(
size_t
index
=
0
;
index
<
log_prob_idx
.
size
();
index
++
)
{
auto
c
=
log_prob_idx
[
index
].
first
;
auto
log_prob_c
=
log_prob_idx
[
index
].
second
;
for
(
size_t
i
=
0
;
i
<
prefixes
.
size
()
&&
i
<
beam_size
;
++
i
)
{
auto
prefix
=
prefixes
[
i
];
if
(
full_beam
&&
log_prob_c
+
prefix
->
score
<
min_cutoff
)
{
break
;
}
// blank
if
(
c
==
blank_id
)
{
prefix
->
log_prob_b_cur
=
log_sum_exp
(
prefix
->
log_prob_b_cur
,
log_prob_c
+
prefix
->
score
);
continue
;
}
// repeated character
if
(
c
==
prefix
->
character
)
{
prefix
->
log_prob_nb_cur
=
log_sum_exp
(
prefix
->
log_prob_nb_cur
,
log_prob_c
+
prefix
->
log_prob_nb_prev
);
}
// get new prefix
auto
prefix_new
=
prefix
->
get_path_trie
(
c
);
if
(
prefix_new
!=
nullptr
)
{
float
log_p
=
-
NUM_FLT_INF
;
if
(
c
==
prefix
->
character
&&
prefix
->
log_prob_b_prev
>
-
NUM_FLT_INF
)
{
log_p
=
log_prob_c
+
prefix
->
log_prob_b_prev
;
}
else
if
(
c
!=
prefix
->
character
)
{
log_p
=
log_prob_c
+
prefix
->
score
;
}
// language model scoring
if
(
ext_scorer
!=
nullptr
&&
(
c
==
space_id
||
ext_scorer
->
is_character_based
()))
{
PathTrie
*
prefix_toscore
=
nullptr
;
// skip scoring the space
if
(
ext_scorer
->
is_character_based
())
{
prefix_toscore
=
prefix_new
;
}
else
{
prefix_toscore
=
prefix
;
}
double
score
=
0.0
;
std
::
vector
<
std
::
string
>
ngram
;
ngram
=
ext_scorer
->
make_ngram
(
prefix_toscore
);
score
=
ext_scorer
->
get_log_cond_prob
(
ngram
)
*
ext_scorer
->
alpha
;
log_p
+=
score
;
log_p
+=
ext_scorer
->
beta
;
}
prefix_new
->
log_prob_nb_cur
=
log_sum_exp
(
prefix_new
->
log_prob_nb_cur
,
log_p
);
}
}
// end of loop over prefix
}
// end of loop over vocabulary
prefixes
.
clear
();
// update log probs
root
.
iterate_to_vec
(
prefixes
);
// only preserve top beam_size prefixes
if
(
prefixes
.
size
()
>=
beam_size
)
{
std
::
nth_element
(
prefixes
.
begin
(),
prefixes
.
begin
()
+
beam_size
,
prefixes
.
end
(),
prefix_compare
);
for
(
size_t
i
=
beam_size
;
i
<
prefixes
.
size
();
++
i
)
{
prefixes
[
i
]
->
remove
();
}
}
}
// end of loop over time
// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
prefixes
.
size
();
++
i
)
{
double
approx_ctc
=
prefixes
[
i
]
->
score
;
if
(
ext_scorer
!=
nullptr
)
{
std
::
vector
<
int
>
output
;
prefixes
[
i
]
->
get_path_vec
(
output
);
auto
prefix_length
=
output
.
size
();
auto
words
=
ext_scorer
->
split_labels
(
output
);
// remove word insert
approx_ctc
=
approx_ctc
-
prefix_length
*
ext_scorer
->
beta
;
// remove language model weight:
approx_ctc
-=
(
ext_scorer
->
get_sent_log_prob
(
words
))
*
ext_scorer
->
alpha
;
}
prefixes
[
i
]
->
approx_ctc
=
approx_ctc
;
}
return
get_beam_search_result
(
prefixes
,
vocabulary
,
beam_size
);
}
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
ctc_beam_search_decoder_batch
(
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
double
>>>
&
probs_split
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
size_t
beam_size
,
size_t
num_processes
,
double
cutoff_prob
,
size_t
cutoff_top_n
,
Scorer
*
ext_scorer
)
{
VALID_CHECK_GT
(
num_processes
,
0
,
"num_processes must be nonnegative!"
);
// thread pool
ThreadPool
pool
(
num_processes
);
// number of samples
size_t
batch_size
=
probs_split
.
size
();
// enqueue the tasks of decoding
std
::
vector
<
std
::
future
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>>
res
;
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
res
.
emplace_back
(
pool
.
enqueue
(
ctc_beam_search_decoder
,
probs_split
[
i
],
vocabulary
,
beam_size
,
cutoff_prob
,
cutoff_top_n
,
ext_scorer
));
}
// get decoding results
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
batch_results
;
for
(
size_t
i
=
0
;
i
<
batch_size
;
++
i
)
{
batch_results
.
emplace_back
(
res
[
i
].
get
());
}
return
batch_results
;
}
deep_speech_2/decoders/swig/ctc_beam_search_decoder.h
0 → 100644
浏览文件 @
17ebb40a
#ifndef CTC_BEAM_SEARCH_DECODER_H_
#define CTC_BEAM_SEARCH_DECODER_H_
#include <string>
#include <utility>
#include <vector>
#include "scorer.h"
/* CTC Beam Search Decoder
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
* over vocabulary of one time step.
* vocabulary: A vector of vocabulary.
* beam_size: The width of beam search.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A vector that each element is a pair of score and decoding result,
* in desending order.
*/
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
ctc_beam_search_decoder
(
const
std
::
vector
<
std
::
vector
<
double
>>
&
probs_seq
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
size_t
beam_size
,
double
cutoff_prob
=
1
.
0
,
size_t
cutoff_top_n
=
40
,
Scorer
*
ext_scorer
=
nullptr
);
/* CTC Beam Search Decoder for batch data
* Parameters:
* probs_seq: 3-D vector that each element is a 2-D vector that can be used
* by ctc_beam_search_decoder().
* vocabulary: A vector of vocabulary.
* beam_size: The width of beam search.
* num_processes: Number of threads for beam search.
* cutoff_prob: Cutoff probability for pruning.
* cutoff_top_n: Cutoff number for pruning.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A 2-D vector that each element is a vector of beam search decoding
* result for one audio sample.
*/
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>>
ctc_beam_search_decoder_batch
(
const
std
::
vector
<
std
::
vector
<
std
::
vector
<
double
>>>
&
probs_split
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
size_t
beam_size
,
size_t
num_processes
,
double
cutoff_prob
=
1
.
0
,
size_t
cutoff_top_n
=
40
,
Scorer
*
ext_scorer
=
nullptr
);
#endif // CTC_BEAM_SEARCH_DECODER_H_
deep_speech_2/decoders/swig/ctc_greedy_decoder.cpp
0 → 100644
浏览文件 @
17ebb40a
#include "ctc_greedy_decoder.h"
#include "decoder_utils.h"
std
::
string
ctc_greedy_decoder
(
const
std
::
vector
<
std
::
vector
<
double
>>
&
probs_seq
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
)
{
// dimension check
size_t
num_time_steps
=
probs_seq
.
size
();
for
(
size_t
i
=
0
;
i
<
num_time_steps
;
++
i
)
{
VALID_CHECK_EQ
(
probs_seq
[
i
].
size
(),
vocabulary
.
size
()
+
1
,
"The shape of probs_seq does not match with "
"the shape of the vocabulary"
);
}
size_t
blank_id
=
vocabulary
.
size
();
std
::
vector
<
size_t
>
max_idx_vec
(
num_time_steps
,
0
);
std
::
vector
<
size_t
>
idx_vec
;
for
(
size_t
i
=
0
;
i
<
num_time_steps
;
++
i
)
{
double
max_prob
=
0.0
;
size_t
max_idx
=
0
;
const
std
::
vector
<
double
>
&
probs_step
=
probs_seq
[
i
];
for
(
size_t
j
=
0
;
j
<
probs_step
.
size
();
++
j
)
{
if
(
max_prob
<
probs_step
[
j
])
{
max_idx
=
j
;
max_prob
=
probs_step
[
j
];
}
}
// id with maximum probability in current time step
max_idx_vec
[
i
]
=
max_idx
;
// deduplicate
if
((
i
==
0
)
||
((
i
>
0
)
&&
max_idx_vec
[
i
]
!=
max_idx_vec
[
i
-
1
]))
{
idx_vec
.
push_back
(
max_idx_vec
[
i
]);
}
}
std
::
string
best_path_result
;
for
(
size_t
i
=
0
;
i
<
idx_vec
.
size
();
++
i
)
{
if
(
idx_vec
[
i
]
!=
blank_id
)
{
best_path_result
+=
vocabulary
[
idx_vec
[
i
]];
}
}
return
best_path_result
;
}
deep_speech_2/decoders/swig/ctc_greedy_decoder.h
0 → 100644
浏览文件 @
17ebb40a
#ifndef CTC_GREEDY_DECODER_H
#define CTC_GREEDY_DECODER_H
#include <string>
#include <vector>
/* CTC Greedy (Best Path) Decoder
*
* Parameters:
* probs_seq: 2-D vector that each element is a vector of probabilities
* over vocabulary of one time step.
* vocabulary: A vector of vocabulary.
* Return:
* The decoding result in string
*/
std
::
string
ctc_greedy_decoder
(
const
std
::
vector
<
std
::
vector
<
double
>>&
probs_seq
,
const
std
::
vector
<
std
::
string
>&
vocabulary
);
#endif // CTC_GREEDY_DECODER_H
deep_speech_2/decoders/swig/decoder_utils.cpp
0 → 100644
浏览文件 @
17ebb40a
#include "decoder_utils.h"
#include <algorithm>
#include <cmath>
#include <limits>
std
::
vector
<
std
::
pair
<
size_t
,
float
>>
get_pruned_log_probs
(
const
std
::
vector
<
double
>
&
prob_step
,
double
cutoff_prob
,
size_t
cutoff_top_n
)
{
std
::
vector
<
std
::
pair
<
int
,
double
>>
prob_idx
;
for
(
size_t
i
=
0
;
i
<
prob_step
.
size
();
++
i
)
{
prob_idx
.
push_back
(
std
::
pair
<
int
,
double
>
(
i
,
prob_step
[
i
]));
}
// pruning of vacobulary
size_t
cutoff_len
=
prob_step
.
size
();
if
(
cutoff_prob
<
1.0
||
cutoff_top_n
<
cutoff_len
)
{
std
::
sort
(
prob_idx
.
begin
(),
prob_idx
.
end
(),
pair_comp_second_rev
<
int
,
double
>
);
if
(
cutoff_prob
<
1.0
)
{
double
cum_prob
=
0.0
;
cutoff_len
=
0
;
for
(
size_t
i
=
0
;
i
<
prob_idx
.
size
();
++
i
)
{
cum_prob
+=
prob_idx
[
i
].
second
;
cutoff_len
+=
1
;
if
(
cum_prob
>=
cutoff_prob
||
cutoff_len
>=
cutoff_top_n
)
break
;
}
}
prob_idx
=
std
::
vector
<
std
::
pair
<
int
,
double
>>
(
prob_idx
.
begin
(),
prob_idx
.
begin
()
+
cutoff_len
);
}
std
::
vector
<
std
::
pair
<
size_t
,
float
>>
log_prob_idx
;
for
(
size_t
i
=
0
;
i
<
cutoff_len
;
++
i
)
{
log_prob_idx
.
push_back
(
std
::
pair
<
int
,
float
>
(
prob_idx
[
i
].
first
,
log
(
prob_idx
[
i
].
second
+
NUM_FLT_MIN
)));
}
return
log_prob_idx
;
}
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
get_beam_search_result
(
const
std
::
vector
<
PathTrie
*>
&
prefixes
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
size_t
beam_size
)
{
// allow for the post processing
std
::
vector
<
PathTrie
*>
space_prefixes
;
if
(
space_prefixes
.
empty
())
{
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
prefixes
.
size
();
++
i
)
{
space_prefixes
.
push_back
(
prefixes
[
i
]);
}
}
std
::
sort
(
space_prefixes
.
begin
(),
space_prefixes
.
end
(),
prefix_compare
);
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
output_vecs
;
for
(
size_t
i
=
0
;
i
<
beam_size
&&
i
<
space_prefixes
.
size
();
++
i
)
{
std
::
vector
<
int
>
output
;
space_prefixes
[
i
]
->
get_path_vec
(
output
);
// convert index to string
std
::
string
output_str
;
for
(
size_t
j
=
0
;
j
<
output
.
size
();
j
++
)
{
output_str
+=
vocabulary
[
output
[
j
]];
}
std
::
pair
<
double
,
std
::
string
>
output_pair
(
-
space_prefixes
[
i
]
->
approx_ctc
,
output_str
);
output_vecs
.
emplace_back
(
output_pair
);
}
return
output_vecs
;
}
size_t
get_utf8_str_len
(
const
std
::
string
&
str
)
{
size_t
str_len
=
0
;
for
(
char
c
:
str
)
{
str_len
+=
((
c
&
0xc0
)
!=
0x80
);
}
return
str_len
;
}
std
::
vector
<
std
::
string
>
split_utf8_str
(
const
std
::
string
&
str
)
{
std
::
vector
<
std
::
string
>
result
;
std
::
string
out_str
;
for
(
char
c
:
str
)
{
if
((
c
&
0xc0
)
!=
0x80
)
// new UTF-8 character
{
if
(
!
out_str
.
empty
())
{
result
.
push_back
(
out_str
);
out_str
.
clear
();
}
}
out_str
.
append
(
1
,
c
);
}
result
.
push_back
(
out_str
);
return
result
;
}
std
::
vector
<
std
::
string
>
split_str
(
const
std
::
string
&
s
,
const
std
::
string
&
delim
)
{
std
::
vector
<
std
::
string
>
result
;
std
::
size_t
start
=
0
,
delim_len
=
delim
.
size
();
while
(
true
)
{
std
::
size_t
end
=
s
.
find
(
delim
,
start
);
if
(
end
==
std
::
string
::
npos
)
{
if
(
start
<
s
.
size
())
{
result
.
push_back
(
s
.
substr
(
start
));
}
break
;
}
if
(
end
>
start
)
{
result
.
push_back
(
s
.
substr
(
start
,
end
-
start
));
}
start
=
end
+
delim_len
;
}
return
result
;
}
bool
prefix_compare
(
const
PathTrie
*
x
,
const
PathTrie
*
y
)
{
if
(
x
->
score
==
y
->
score
)
{
if
(
x
->
character
==
y
->
character
)
{
return
false
;
}
else
{
return
(
x
->
character
<
y
->
character
);
}
}
else
{
return
x
->
score
>
y
->
score
;
}
}
void
add_word_to_fst
(
const
std
::
vector
<
int
>
&
word
,
fst
::
StdVectorFst
*
dictionary
)
{
if
(
dictionary
->
NumStates
()
==
0
)
{
fst
::
StdVectorFst
::
StateId
start
=
dictionary
->
AddState
();
assert
(
start
==
0
);
dictionary
->
SetStart
(
start
);
}
fst
::
StdVectorFst
::
StateId
src
=
dictionary
->
Start
();
fst
::
StdVectorFst
::
StateId
dst
;
for
(
auto
c
:
word
)
{
dst
=
dictionary
->
AddState
();
dictionary
->
AddArc
(
src
,
fst
::
StdArc
(
c
,
c
,
0
,
dst
));
src
=
dst
;
}
dictionary
->
SetFinal
(
dst
,
fst
::
StdArc
::
Weight
::
One
());
}
bool
add_word_to_dictionary
(
const
std
::
string
&
word
,
const
std
::
unordered_map
<
std
::
string
,
int
>
&
char_map
,
bool
add_space
,
int
SPACE_ID
,
fst
::
StdVectorFst
*
dictionary
)
{
auto
characters
=
split_utf8_str
(
word
);
std
::
vector
<
int
>
int_word
;
for
(
auto
&
c
:
characters
)
{
if
(
c
==
" "
)
{
int_word
.
push_back
(
SPACE_ID
);
}
else
{
auto
int_c
=
char_map
.
find
(
c
);
if
(
int_c
!=
char_map
.
end
())
{
int_word
.
push_back
(
int_c
->
second
);
}
else
{
return
false
;
// return without adding
}
}
}
if
(
add_space
)
{
int_word
.
push_back
(
SPACE_ID
);
}
add_word_to_fst
(
int_word
,
dictionary
);
return
true
;
// return with successful adding
}
deep_speech_2/decoders/swig/decoder_utils.h
0 → 100644
浏览文件 @
17ebb40a
#ifndef DECODER_UTILS_H_
#define DECODER_UTILS_H_
#include <utility>
#include "fst/log.h"
#include "path_trie.h"
const
float
NUM_FLT_INF
=
std
::
numeric_limits
<
float
>::
max
();
const
float
NUM_FLT_MIN
=
std
::
numeric_limits
<
float
>::
min
();
// inline function for validation check
inline
void
check
(
bool
x
,
const
char
*
expr
,
const
char
*
file
,
int
line
,
const
char
*
err
)
{
if
(
!
x
)
{
std
::
cout
<<
"["
<<
file
<<
":"
<<
line
<<
"] "
;
LOG
(
FATAL
)
<<
"
\"
"
<<
expr
<<
"
\"
check failed. "
<<
err
;
}
}
#define VALID_CHECK(x, info) \
check(static_cast<bool>(x), #x, __FILE__, __LINE__, info)
#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info)
#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info)
#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info)
// Function template for comparing two pairs
template
<
typename
T1
,
typename
T2
>
bool
pair_comp_first_rev
(
const
std
::
pair
<
T1
,
T2
>
&
a
,
const
std
::
pair
<
T1
,
T2
>
&
b
)
{
return
a
.
first
>
b
.
first
;
}
// Function template for comparing two pairs
template
<
typename
T1
,
typename
T2
>
bool
pair_comp_second_rev
(
const
std
::
pair
<
T1
,
T2
>
&
a
,
const
std
::
pair
<
T1
,
T2
>
&
b
)
{
return
a
.
second
>
b
.
second
;
}
// Return the sum of two probabilities in log scale
template
<
typename
T
>
T
log_sum_exp
(
const
T
&
x
,
const
T
&
y
)
{
static
T
num_min
=
-
std
::
numeric_limits
<
T
>::
max
();
if
(
x
<=
num_min
)
return
y
;
if
(
y
<=
num_min
)
return
x
;
T
xmax
=
std
::
max
(
x
,
y
);
return
std
::
log
(
std
::
exp
(
x
-
xmax
)
+
std
::
exp
(
y
-
xmax
))
+
xmax
;
}
// Get pruned probability vector for each time step's beam search
std
::
vector
<
std
::
pair
<
size_t
,
float
>>
get_pruned_log_probs
(
const
std
::
vector
<
double
>
&
prob_step
,
double
cutoff_prob
,
size_t
cutoff_top_n
);
// Get beam search result from prefixes in trie tree
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>>
get_beam_search_result
(
const
std
::
vector
<
PathTrie
*>
&
prefixes
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
,
size_t
beam_size
);
// Functor for prefix comparsion
bool
prefix_compare
(
const
PathTrie
*
x
,
const
PathTrie
*
y
);
/* Get length of utf8 encoding string
* See: http://stackoverflow.com/a/4063229
*/
size_t
get_utf8_str_len
(
const
std
::
string
&
str
);
/* Split a string into a list of strings on a given string
* delimiter. NB: delimiters on beginning / end of string are
* trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
*/
std
::
vector
<
std
::
string
>
split_str
(
const
std
::
string
&
s
,
const
std
::
string
&
delim
);
/* Splits string into vector of strings representing
* UTF-8 characters (not same as chars)
*/
std
::
vector
<
std
::
string
>
split_utf8_str
(
const
std
::
string
&
str
);
// Add a word in index to the dicionary of fst
void
add_word_to_fst
(
const
std
::
vector
<
int
>
&
word
,
fst
::
StdVectorFst
*
dictionary
);
// Add a word in string to dictionary
bool
add_word_to_dictionary
(
const
std
::
string
&
word
,
const
std
::
unordered_map
<
std
::
string
,
int
>
&
char_map
,
bool
add_space
,
int
SPACE_ID
,
fst
::
StdVectorFst
*
dictionary
);
#endif // DECODER_UTILS_H
deep_speech_2/decoders/swig/decoders.i
0 → 100644
浏览文件 @
17ebb40a
%
module
swig_decoders
%
{
#
include
"scorer.h"
#
include
"ctc_greedy_decoder.h"
#
include
"ctc_beam_search_decoder.h"
#
include
"decoder_utils.h"
%
}
%
include
"std_vector.i"
%
include
"std_pair.i"
%
include
"std_string.i"
%
import
"decoder_utils.h"
namespace
std
{
%
template
(
DoubleVector
)
std
::
vector
<
double
>
;
%
template
(
IntVector
)
std
::
vector
<
int
>
;
%
template
(
StringVector
)
std
::
vector
<
std
::
string
>
;
%
template
(
VectorOfStructVector
)
std
::
vector
<
std
::
vector
<
double
>
>
;
%
template
(
FloatVector
)
std
::
vector
<
float
>
;
%
template
(
Pair
)
std
::
pair
<
float
,
std
::
string
>
;
%
template
(
PairFloatStringVector
)
std
::
vector
<
std
::
pair
<
float
,
std
::
string
>
>
;
%
template
(
PairDoubleStringVector
)
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>
>
;
%
template
(
PairDoubleStringVector2
)
std
::
vector
<
std
::
vector
<
std
::
pair
<
double
,
std
::
string
>
>
>
;
%
template
(
DoubleVector3
)
std
::
vector
<
std
::
vector
<
std
::
vector
<
double
>
>
>
;
}
%
template
(
IntDoublePairCompSecondRev
)
pair_comp_second_rev
<
int
,
double
>
;
%
template
(
StringDoublePairCompSecondRev
)
pair_comp_second_rev
<
std
::
string
,
double
>
;
%
template
(
DoubleStringPairCompFirstRev
)
pair_comp_first_rev
<
double
,
std
::
string
>
;
%
include
"scorer.h"
%
include
"ctc_greedy_decoder.h"
%
include
"ctc_beam_search_decoder.h"
deep_speech_2/decoders/swig/path_trie.cpp
0 → 100644
浏览文件 @
17ebb40a
#include "path_trie.h"
#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include "decoder_utils.h"
PathTrie
::
PathTrie
()
{
log_prob_b_prev
=
-
NUM_FLT_INF
;
log_prob_nb_prev
=
-
NUM_FLT_INF
;
log_prob_b_cur
=
-
NUM_FLT_INF
;
log_prob_nb_cur
=
-
NUM_FLT_INF
;
score
=
-
NUM_FLT_INF
;
ROOT_
=
-
1
;
character
=
ROOT_
;
exists_
=
true
;
parent
=
nullptr
;
dictionary_
=
nullptr
;
dictionary_state_
=
0
;
has_dictionary_
=
false
;
matcher_
=
nullptr
;
}
PathTrie
::~
PathTrie
()
{
for
(
auto
child
:
children_
)
{
delete
child
.
second
;
}
}
PathTrie
*
PathTrie
::
get_path_trie
(
int
new_char
,
bool
reset
)
{
auto
child
=
children_
.
begin
();
for
(
child
=
children_
.
begin
();
child
!=
children_
.
end
();
++
child
)
{
if
(
child
->
first
==
new_char
)
{
break
;
}
}
if
(
child
!=
children_
.
end
())
{
if
(
!
child
->
second
->
exists_
)
{
child
->
second
->
exists_
=
true
;
child
->
second
->
log_prob_b_prev
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_nb_prev
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_b_cur
=
-
NUM_FLT_INF
;
child
->
second
->
log_prob_nb_cur
=
-
NUM_FLT_INF
;
}
return
(
child
->
second
);
}
else
{
if
(
has_dictionary_
)
{
matcher_
->
SetState
(
dictionary_state_
);
bool
found
=
matcher_
->
Find
(
new_char
);
if
(
!
found
)
{
// Adding this character causes word outside dictionary
auto
FSTZERO
=
fst
::
TropicalWeight
::
Zero
();
auto
final_weight
=
dictionary_
->
Final
(
dictionary_state_
);
bool
is_final
=
(
final_weight
!=
FSTZERO
);
if
(
is_final
&&
reset
)
{
dictionary_state_
=
dictionary_
->
Start
();
}
return
nullptr
;
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
new_path
->
character
=
new_char
;
new_path
->
parent
=
this
;
new_path
->
dictionary_
=
dictionary_
;
new_path
->
dictionary_state_
=
matcher_
->
Value
().
nextstate
;
new_path
->
has_dictionary_
=
true
;
new_path
->
matcher_
=
matcher_
;
children_
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
return
new_path
;
}
}
else
{
PathTrie
*
new_path
=
new
PathTrie
;
new_path
->
character
=
new_char
;
new_path
->
parent
=
this
;
children_
.
push_back
(
std
::
make_pair
(
new_char
,
new_path
));
return
new_path
;
}
}
}
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
)
{
return
get_path_vec
(
output
,
ROOT_
);
}
PathTrie
*
PathTrie
::
get_path_vec
(
std
::
vector
<
int
>&
output
,
int
stop
,
size_t
max_steps
)
{
if
(
character
==
stop
||
character
==
ROOT_
||
output
.
size
()
==
max_steps
)
{
std
::
reverse
(
output
.
begin
(),
output
.
end
());
return
this
;
}
else
{
output
.
push_back
(
character
);
return
parent
->
get_path_vec
(
output
,
stop
,
max_steps
);
}
}
void
PathTrie
::
iterate_to_vec
(
std
::
vector
<
PathTrie
*>&
output
)
{
if
(
exists_
)
{
log_prob_b_prev
=
log_prob_b_cur
;
log_prob_nb_prev
=
log_prob_nb_cur
;
log_prob_b_cur
=
-
NUM_FLT_INF
;
log_prob_nb_cur
=
-
NUM_FLT_INF
;
score
=
log_sum_exp
(
log_prob_b_prev
,
log_prob_nb_prev
);
output
.
push_back
(
this
);
}
for
(
auto
child
:
children_
)
{
child
.
second
->
iterate_to_vec
(
output
);
}
}
void
PathTrie
::
remove
()
{
exists_
=
false
;
if
(
children_
.
size
()
==
0
)
{
auto
child
=
parent
->
children_
.
begin
();
for
(
child
=
parent
->
children_
.
begin
();
child
!=
parent
->
children_
.
end
();
++
child
)
{
if
(
child
->
first
==
character
)
{
parent
->
children_
.
erase
(
child
);
break
;
}
}
if
(
parent
->
children_
.
size
()
==
0
&&
!
parent
->
exists_
)
{
parent
->
remove
();
}
delete
this
;
}
}
void
PathTrie
::
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
)
{
dictionary_
=
dictionary
;
dictionary_state_
=
dictionary
->
Start
();
has_dictionary_
=
true
;
}
using
FSTMATCH
=
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>
;
void
PathTrie
::
set_matcher
(
std
::
shared_ptr
<
FSTMATCH
>
matcher
)
{
matcher_
=
matcher
;
}
deep_speech_2/decoders/swig/path_trie.h
0 → 100644
浏览文件 @
17ebb40a
#ifndef PATH_TRIE_H
#define PATH_TRIE_H
#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include "fst/fstlib.h"
/* Trie tree for prefix storing and manipulating, with a dictionary in
* finite-state transducer for spelling correction.
*/
class
PathTrie
{
public:
PathTrie
();
~
PathTrie
();
// get new prefix after appending new char
PathTrie
*
get_path_trie
(
int
new_char
,
bool
reset
=
true
);
// get the prefix in index from root to current node
PathTrie
*
get_path_vec
(
std
::
vector
<
int
>&
output
);
// get the prefix in index from some stop node to current nodel
PathTrie
*
get_path_vec
(
std
::
vector
<
int
>&
output
,
int
stop
,
size_t
max_steps
=
std
::
numeric_limits
<
size_t
>::
max
());
// update log probs
void
iterate_to_vec
(
std
::
vector
<
PathTrie
*>&
output
);
// set dictionary for FST
void
set_dictionary
(
fst
::
StdVectorFst
*
dictionary
);
void
set_matcher
(
std
::
shared_ptr
<
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>>
);
bool
is_empty
()
{
return
ROOT_
==
character
;
}
// remove current path from root
void
remove
();
float
log_prob_b_prev
;
float
log_prob_nb_prev
;
float
log_prob_b_cur
;
float
log_prob_nb_cur
;
float
score
;
float
approx_ctc
;
int
character
;
PathTrie
*
parent
;
private:
int
ROOT_
;
bool
exists_
;
bool
has_dictionary_
;
std
::
vector
<
std
::
pair
<
int
,
PathTrie
*>>
children_
;
// pointer to dictionary of FST
fst
::
StdVectorFst
*
dictionary_
;
fst
::
StdVectorFst
::
StateId
dictionary_state_
;
// true if finding ars in FST
std
::
shared_ptr
<
fst
::
SortedMatcher
<
fst
::
StdVectorFst
>>
matcher_
;
};
#endif // PATH_TRIE_H
deep_speech_2/decoders/swig/scorer.cpp
0 → 100644
浏览文件 @
17ebb40a
#include "scorer.h"
#include <unistd.h>
#include <iostream>
#include "lm/config.hh"
#include "lm/model.hh"
#include "lm/state.hh"
#include "util/string_piece.hh"
#include "util/tokenize_piece.hh"
#include "decoder_utils.h"
using
namespace
lm
::
ngram
;
Scorer
::
Scorer
(
double
alpha
,
double
beta
,
const
std
::
string
&
lm_path
,
const
std
::
vector
<
std
::
string
>&
vocab_list
)
{
this
->
alpha
=
alpha
;
this
->
beta
=
beta
;
dictionary
=
nullptr
;
is_character_based_
=
true
;
language_model_
=
nullptr
;
max_order_
=
0
;
dict_size_
=
0
;
SPACE_ID_
=
-
1
;
setup
(
lm_path
,
vocab_list
);
}
Scorer
::~
Scorer
()
{
if
(
language_model_
!=
nullptr
)
{
delete
static_cast
<
lm
::
base
::
Model
*>
(
language_model_
);
}
if
(
dictionary
!=
nullptr
)
{
delete
static_cast
<
fst
::
StdVectorFst
*>
(
dictionary
);
}
}
void
Scorer
::
setup
(
const
std
::
string
&
lm_path
,
const
std
::
vector
<
std
::
string
>&
vocab_list
)
{
// load language model
load_lm
(
lm_path
);
// set char map for scorer
set_char_map
(
vocab_list
);
// fill the dictionary for FST
if
(
!
is_character_based
())
{
fill_dictionary
(
true
);
}
}
void
Scorer
::
load_lm
(
const
std
::
string
&
lm_path
)
{
const
char
*
filename
=
lm_path
.
c_str
();
VALID_CHECK_EQ
(
access
(
filename
,
F_OK
),
0
,
"Invalid language model path"
);
RetriveStrEnumerateVocab
enumerate
;
lm
::
ngram
::
Config
config
;
config
.
enumerate_vocab
=
&
enumerate
;
language_model_
=
lm
::
ngram
::
LoadVirtual
(
filename
,
config
);
max_order_
=
static_cast
<
lm
::
base
::
Model
*>
(
language_model_
)
->
Order
();
vocabulary_
=
enumerate
.
vocabulary
;
for
(
size_t
i
=
0
;
i
<
vocabulary_
.
size
();
++
i
)
{
if
(
is_character_based_
&&
vocabulary_
[
i
]
!=
UNK_TOKEN
&&
vocabulary_
[
i
]
!=
START_TOKEN
&&
vocabulary_
[
i
]
!=
END_TOKEN
&&
get_utf8_str_len
(
enumerate
.
vocabulary
[
i
])
>
1
)
{
is_character_based_
=
false
;
}
}
}
double
Scorer
::
get_log_cond_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
lm
::
base
::
Model
*
model
=
static_cast
<
lm
::
base
::
Model
*>
(
language_model_
);
double
cond_prob
;
lm
::
ngram
::
State
state
,
tmp_state
,
out_state
;
// avoid to inserting <s> in begin
model
->
NullContextWrite
(
&
state
);
for
(
size_t
i
=
0
;
i
<
words
.
size
();
++
i
)
{
lm
::
WordIndex
word_index
=
model
->
BaseVocabulary
().
Index
(
words
[
i
]);
// encounter OOV
if
(
word_index
==
0
)
{
return
OOV_SCORE
;
}
cond_prob
=
model
->
BaseScore
(
&
state
,
word_index
,
&
out_state
);
tmp_state
=
state
;
state
=
out_state
;
out_state
=
tmp_state
;
}
// return log10 prob
return
cond_prob
;
}
double
Scorer
::
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
std
::
vector
<
std
::
string
>
sentence
;
if
(
words
.
size
()
==
0
)
{
for
(
size_t
i
=
0
;
i
<
max_order_
;
++
i
)
{
sentence
.
push_back
(
START_TOKEN
);
}
}
else
{
for
(
size_t
i
=
0
;
i
<
max_order_
-
1
;
++
i
)
{
sentence
.
push_back
(
START_TOKEN
);
}
sentence
.
insert
(
sentence
.
end
(),
words
.
begin
(),
words
.
end
());
}
sentence
.
push_back
(
END_TOKEN
);
return
get_log_prob
(
sentence
);
}
double
Scorer
::
get_log_prob
(
const
std
::
vector
<
std
::
string
>&
words
)
{
assert
(
words
.
size
()
>
max_order_
);
double
score
=
0.0
;
for
(
size_t
i
=
0
;
i
<
words
.
size
()
-
max_order_
+
1
;
++
i
)
{
std
::
vector
<
std
::
string
>
ngram
(
words
.
begin
()
+
i
,
words
.
begin
()
+
i
+
max_order_
);
score
+=
get_log_cond_prob
(
ngram
);
}
return
score
;
}
void
Scorer
::
reset_params
(
float
alpha
,
float
beta
)
{
this
->
alpha
=
alpha
;
this
->
beta
=
beta
;
}
std
::
string
Scorer
::
vec2str
(
const
std
::
vector
<
int
>&
input
)
{
std
::
string
word
;
for
(
auto
ind
:
input
)
{
word
+=
char_list_
[
ind
];
}
return
word
;
}
std
::
vector
<
std
::
string
>
Scorer
::
split_labels
(
const
std
::
vector
<
int
>&
labels
)
{
if
(
labels
.
empty
())
return
{};
std
::
string
s
=
vec2str
(
labels
);
std
::
vector
<
std
::
string
>
words
;
if
(
is_character_based_
)
{
words
=
split_utf8_str
(
s
);
}
else
{
words
=
split_str
(
s
,
" "
);
}
return
words
;
}
void
Scorer
::
set_char_map
(
const
std
::
vector
<
std
::
string
>&
char_list
)
{
char_list_
=
char_list
;
char_map_
.
clear
();
for
(
size_t
i
=
0
;
i
<
char_list_
.
size
();
i
++
)
{
if
(
char_list_
[
i
]
==
" "
)
{
SPACE_ID_
=
i
;
char_map_
[
' '
]
=
i
;
}
else
if
(
char_list_
[
i
].
size
()
==
1
)
{
char_map_
[
char_list_
[
i
][
0
]]
=
i
;
}
}
}
std
::
vector
<
std
::
string
>
Scorer
::
make_ngram
(
PathTrie
*
prefix
)
{
std
::
vector
<
std
::
string
>
ngram
;
PathTrie
*
current_node
=
prefix
;
PathTrie
*
new_node
=
nullptr
;
for
(
int
order
=
0
;
order
<
max_order_
;
order
++
)
{
std
::
vector
<
int
>
prefix_vec
;
if
(
is_character_based_
)
{
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
SPACE_ID_
,
1
);
current_node
=
new_node
;
}
else
{
new_node
=
current_node
->
get_path_vec
(
prefix_vec
,
SPACE_ID_
);
current_node
=
new_node
->
parent
;
// Skipping spaces
}
// reconstruct word
std
::
string
word
=
vec2str
(
prefix_vec
);
ngram
.
push_back
(
word
);
if
(
new_node
->
character
==
-
1
)
{
// No more spaces, but still need order
for
(
int
i
=
0
;
i
<
max_order_
-
order
-
1
;
i
++
)
{
ngram
.
push_back
(
START_TOKEN
);
}
break
;
}
}
std
::
reverse
(
ngram
.
begin
(),
ngram
.
end
());
return
ngram
;
}
void
Scorer
::
fill_dictionary
(
bool
add_space
)
{
fst
::
StdVectorFst
dictionary
;
// First reverse char_list so ints can be accessed by chars
std
::
unordered_map
<
std
::
string
,
int
>
char_map
;
for
(
size_t
i
=
0
;
i
<
char_list_
.
size
();
i
++
)
{
char_map
[
char_list_
[
i
]]
=
i
;
}
// For each unigram convert to ints and put in trie
int
dict_size
=
0
;
for
(
const
auto
&
word
:
vocabulary_
)
{
bool
added
=
add_word_to_dictionary
(
word
,
char_map
,
add_space
,
SPACE_ID_
,
&
dictionary
);
dict_size
+=
added
?
1
:
0
;
}
dict_size_
=
dict_size
;
/* Simplify FST
* This gets rid of "epsilon" transitions in the FST.
* These are transitions that don't require a string input to be taken.
* Getting rid of them is necessary to make the FST determinisitc, but
* can greatly increase the size of the FST
*/
fst
::
RmEpsilon
(
&
dictionary
);
fst
::
StdVectorFst
*
new_dict
=
new
fst
::
StdVectorFst
;
/* This makes the FST deterministic, meaning for any string input there's
* only one possible state the FST could be in. It is assumed our
* dictionary is deterministic when using it.
* (lest we'd have to check for multiple transitions at each state)
*/
fst
::
Determinize
(
dictionary
,
new_dict
);
/* Finds the simplest equivalent fst. This is unnecessary but decreases
* memory usage of the dictionary
*/
fst
::
Minimize
(
new_dict
);
this
->
dictionary
=
new_dict
;
}
deep_speech_2/decoders/swig/scorer.h
0 → 100644
浏览文件 @
17ebb40a
#ifndef SCORER_H_
#define SCORER_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lm/enumerate_vocab.hh"
#include "lm/virtual_interface.hh"
#include "lm/word_index.hh"
#include "util/string_piece.hh"
#include "path_trie.h"
const
double
OOV_SCORE
=
-
1000.0
;
const
std
::
string
START_TOKEN
=
"<s>"
;
const
std
::
string
UNK_TOKEN
=
"<unk>"
;
const
std
::
string
END_TOKEN
=
"</s>"
;
// Implement a callback to retrive the dictionary of language model.
class
RetriveStrEnumerateVocab
:
public
lm
::
EnumerateVocab
{
public:
RetriveStrEnumerateVocab
()
{}
void
Add
(
lm
::
WordIndex
index
,
const
StringPiece
&
str
)
{
vocabulary
.
push_back
(
std
::
string
(
str
.
data
(),
str
.
length
()));
}
std
::
vector
<
std
::
string
>
vocabulary
;
};
/* External scorer to query score for n-gram or sentence, including language
* model scoring and word insertion.
*
* Example:
* Scorer scorer(alpha, beta, "path_of_language_model");
* scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
*/
class
Scorer
{
public:
Scorer
(
double
alpha
,
double
beta
,
const
std
::
string
&
lm_path
,
const
std
::
vector
<
std
::
string
>
&
vocabulary
);
~
Scorer
();
double
get_log_cond_prob
(
const
std
::
vector
<
std
::
string
>
&
words
);
double
get_sent_log_prob
(
const
std
::
vector
<
std
::
string
>
&
words
);
// return the max order
size_t
get_max_order
()
const
{
return
max_order_
;
}
// return the dictionary size of language model
size_t
get_dict_size
()
const
{
return
dict_size_
;
}
// retrun true if the language model is character based
bool
is_character_based
()
const
{
return
is_character_based_
;
}
// reset params alpha & beta
void
reset_params
(
float
alpha
,
float
beta
);
// make ngram for a given prefix
std
::
vector
<
std
::
string
>
make_ngram
(
PathTrie
*
prefix
);
// trransform the labels in index to the vector of words (word based lm) or
// the vector of characters (character based lm)
std
::
vector
<
std
::
string
>
split_labels
(
const
std
::
vector
<
int
>
&
labels
);
// language model weight
double
alpha
;
// word insertion weight
double
beta
;
// pointer to the dictionary of FST
void
*
dictionary
;
protected:
// necessary setup: load language model, set char map, fill FST's dictionary
void
setup
(
const
std
::
string
&
lm_path
,
const
std
::
vector
<
std
::
string
>
&
vocab_list
);
// load language model from given path
void
load_lm
(
const
std
::
string
&
lm_path
);
// fill dictionary for FST
void
fill_dictionary
(
bool
add_space
);
// set char map
void
set_char_map
(
const
std
::
vector
<
std
::
string
>
&
char_list
);
double
get_log_prob
(
const
std
::
vector
<
std
::
string
>
&
words
);
// translate the vector in index to string
std
::
string
vec2str
(
const
std
::
vector
<
int
>
&
input
);
private:
void
*
language_model_
;
bool
is_character_based_
;
size_t
max_order_
;
size_t
dict_size_
;
int
SPACE_ID_
;
std
::
vector
<
std
::
string
>
char_list_
;
std
::
unordered_map
<
char
,
int
>
char_map_
;
std
::
vector
<
std
::
string
>
vocabulary_
;
};
#endif // SCORER_H_
deep_speech_2/decoders/swig/setup.py
0 → 100644
浏览文件 @
17ebb40a
"""Script to build and install decoder package."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
setuptools
import
setup
,
Extension
,
distutils
import
glob
import
platform
import
os
,
sys
import
multiprocessing.pool
import
argparse
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
"--num_processes"
,
default
=
1
,
type
=
int
,
help
=
"Number of cpu processes to build package. (default: %(default)d)"
)
args
=
parser
.
parse_known_args
()
# reconstruct sys.argv to pass to setup below
sys
.
argv
=
[
sys
.
argv
[
0
]]
+
args
[
1
]
# monkey-patch for parallel compilation
# See: https://stackoverflow.com/a/13176803
def
parallelCCompile
(
self
,
sources
,
output_dir
=
None
,
macros
=
None
,
include_dirs
=
None
,
debug
=
0
,
extra_preargs
=
None
,
extra_postargs
=
None
,
depends
=
None
):
# those lines are copied from distutils.ccompiler.CCompiler directly
macros
,
objects
,
extra_postargs
,
pp_opts
,
build
=
self
.
_setup_compile
(
output_dir
,
macros
,
include_dirs
,
sources
,
depends
,
extra_postargs
)
cc_args
=
self
.
_get_cc_args
(
pp_opts
,
debug
,
extra_preargs
)
# parallel code
def
_single_compile
(
obj
):
try
:
src
,
ext
=
build
[
obj
]
except
KeyError
:
return
self
.
_compile
(
obj
,
src
,
ext
,
cc_args
,
extra_postargs
,
pp_opts
)
# convert to list, imap is evaluated on-demand
thread_pool
=
multiprocessing
.
pool
.
ThreadPool
(
args
[
0
].
num_processes
)
list
(
thread_pool
.
imap
(
_single_compile
,
objects
))
return
objects
def
compile_test
(
header
,
library
):
dummy_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"dummy"
)
command
=
"bash -c
\"
g++ -include "
+
header
\
+
" -l"
+
library
+
" -x c++ - <<<'int main() {}' -o "
\
+
dummy_path
+
" >/dev/null 2>/dev/null && rm "
\
+
dummy_path
+
" 2>/dev/null
\"
"
return
os
.
system
(
command
)
==
0
# hack compile to support parallel compiling
distutils
.
ccompiler
.
CCompiler
.
compile
=
parallelCCompile
FILES
=
glob
.
glob
(
'kenlm/util/*.cc'
)
\
+
glob
.
glob
(
'kenlm/lm/*.cc'
)
\
+
glob
.
glob
(
'kenlm/util/double-conversion/*.cc'
)
FILES
+=
glob
.
glob
(
'openfst-1.6.3/src/lib/*.cc'
)
# FILES + glob.glob('glog/src/*.cc')
FILES
=
[
fn
for
fn
in
FILES
if
not
(
fn
.
endswith
(
'main.cc'
)
or
fn
.
endswith
(
'test.cc'
)
or
fn
.
endswith
(
'unittest.cc'
))
]
LIBS
=
[
'stdc++'
]
if
platform
.
system
()
!=
'Darwin'
:
LIBS
.
append
(
'rt'
)
ARGS
=
[
'-O3'
,
'-DNDEBUG'
,
'-DKENLM_MAX_ORDER=6'
,
'-std=c++11'
]
if
compile_test
(
'zlib.h'
,
'z'
):
ARGS
.
append
(
'-DHAVE_ZLIB'
)
LIBS
.
append
(
'z'
)
if
compile_test
(
'bzlib.h'
,
'bz2'
):
ARGS
.
append
(
'-DHAVE_BZLIB'
)
LIBS
.
append
(
'bz2'
)
if
compile_test
(
'lzma.h'
,
'lzma'
):
ARGS
.
append
(
'-DHAVE_XZLIB'
)
LIBS
.
append
(
'lzma'
)
os
.
system
(
'swig -python -c++ ./decoders.i'
)
decoders_module
=
[
Extension
(
name
=
'_swig_decoders'
,
sources
=
FILES
+
glob
.
glob
(
'*.cxx'
)
+
glob
.
glob
(
'*.cpp'
),
language
=
'c++'
,
include_dirs
=
[
'.'
,
'kenlm'
,
'openfst-1.6.3/src/include'
,
'ThreadPool'
,
#'glog/src'
],
libraries
=
LIBS
,
extra_compile_args
=
ARGS
)
]
setup
(
name
=
'swig_decoders'
,
version
=
'0.1'
,
description
=
"""CTC decoders"""
,
ext_modules
=
decoders_module
,
py_modules
=
[
'swig_decoders'
],
)
deep_speech_2/decoders/swig/setup.sh
0 → 100644
浏览文件 @
17ebb40a
#!/usr/bin/env bash
if
[
!
-d
kenlm
]
;
then
git clone https://github.com/luotao1/kenlm.git
echo
-e
"
\n
"
fi
if
[
!
-d
openfst-1.6.3
]
;
then
echo
"Download and extract openfst ..."
wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz
tar
-xzvf
openfst-1.6.3.tar.gz
echo
-e
"
\n
"
fi
if
[
!
-d
ThreadPool
]
;
then
git clone https://github.com/progschj/ThreadPool.git
echo
-e
"
\n
"
fi
echo
"Install decoders ..."
python setup.py
install
--num_processes
4
deep_speech_2/decoders/swig_wrapper.py
0 → 100644
浏览文件 @
17ebb40a
"""Wrapper for various CTC decoders in SWIG."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
swig_decoders
class
Scorer
(
swig_decoders
.
Scorer
):
"""Wrapper for Scorer.
:param alpha: Parameter associated with language model. Don't use
language model when alpha = 0.
:type alpha: float
:param beta: Parameter associated with word count. Don't use word
count when beta = 0.
:type beta: float
:model_path: Path to load language model.
:type model_path: basestring
"""
def
__init__
(
self
,
alpha
,
beta
,
model_path
,
vocabulary
):
swig_decoders
.
Scorer
.
__init__
(
self
,
alpha
,
beta
,
model_path
,
vocabulary
)
def
ctc_greedy_decoder
(
probs_seq
,
vocabulary
):
"""Wrapper for ctc best path decoder in swig.
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
probabilities over vocabulary and blank.
:type probs_seq: 2-D list
:param vocabulary: Vocabulary list.
:type vocabulary: list
:return: Decoding result string.
:rtype: basestring
"""
return
swig_decoders
.
ctc_greedy_decoder
(
probs_seq
.
tolist
(),
vocabulary
)
def
ctc_beam_search_decoder
(
probs_seq
,
vocabulary
,
beam_size
,
cutoff_prob
=
1.0
,
cutoff_top_n
=
40
,
ext_scoring_func
=
None
):
"""Wrapper for the CTC Beam Search Decoder.
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
probabilities over vocabulary and blank.
:type probs_seq: 2-D list
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param beam_size: Width for beam search.
:type beam_size: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
characters with highest probs in vocabulary will be
used in beam search, default 40.
:type cutoff_top_n: int
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
or language model.
:type external_scoring_func: callable
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
"""
return
swig_decoders
.
ctc_beam_search_decoder
(
probs_seq
.
tolist
(),
vocabulary
,
beam_size
,
cutoff_prob
,
cutoff_top_n
,
ext_scoring_func
)
def
ctc_beam_search_decoder_batch
(
probs_split
,
vocabulary
,
beam_size
,
num_processes
,
cutoff_prob
=
1.0
,
cutoff_top_n
=
40
,
ext_scoring_func
=
None
):
"""Wrapper for the batched CTC beam search decoder.
:param probs_seq: 3-D list with each element as an instance of 2-D list
of probabilities used by ctc_beam_search_decoder().
:type probs_seq: 3-D list
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param beam_size: Width for beam search.
:type beam_size: int
:param num_processes: Number of parallel processes.
:type num_processes: int
:param cutoff_prob: Cutoff probability in vocabulary pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
characters with highest probs in vocabulary will be
used in beam search, default 40.
:type cutoff_top_n: int
:param num_processes: Number of parallel processes.
:type num_processes: int
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
or language model.
:type external_scoring_function: callable
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
"""
probs_split
=
[
probs_seq
.
tolist
()
for
probs_seq
in
probs_split
]
return
swig_decoders
.
ctc_beam_search_decoder_batch
(
probs_split
,
vocabulary
,
beam_size
,
num_processes
,
cutoff_prob
,
cutoff_top_n
,
ext_scoring_func
)
deep_speech_2/
model_util
s/tests/test_decoders.py
→
deep_speech_2/
decoder
s/tests/test_decoders.py
浏览文件 @
17ebb40a
...
@@ -4,7 +4,7 @@ from __future__ import division
...
@@ -4,7 +4,7 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
unittest
import
unittest
from
model_utils
import
decoder
from
decoders
import
decoders_deprecated
as
decoder
class
TestDecoders
(
unittest
.
TestCase
):
class
TestDecoders
(
unittest
.
TestCase
):
...
@@ -66,16 +66,14 @@ class TestDecoders(unittest.TestCase):
...
@@ -66,16 +66,14 @@ class TestDecoders(unittest.TestCase):
beam_result
=
decoder
.
ctc_beam_search_decoder
(
beam_result
=
decoder
.
ctc_beam_search_decoder
(
probs_seq
=
self
.
probs_seq1
,
probs_seq
=
self
.
probs_seq1
,
beam_size
=
self
.
beam_size
,
beam_size
=
self
.
beam_size
,
vocabulary
=
self
.
vocab_list
,
vocabulary
=
self
.
vocab_list
)
blank_id
=
len
(
self
.
vocab_list
))
self
.
assertEqual
(
beam_result
[
0
][
1
],
self
.
beam_search_result
[
0
])
self
.
assertEqual
(
beam_result
[
0
][
1
],
self
.
beam_search_result
[
0
])
def
test_beam_search_decoder_2
(
self
):
def
test_beam_search_decoder_2
(
self
):
beam_result
=
decoder
.
ctc_beam_search_decoder
(
beam_result
=
decoder
.
ctc_beam_search_decoder
(
probs_seq
=
self
.
probs_seq2
,
probs_seq
=
self
.
probs_seq2
,
beam_size
=
self
.
beam_size
,
beam_size
=
self
.
beam_size
,
vocabulary
=
self
.
vocab_list
,
vocabulary
=
self
.
vocab_list
)
blank_id
=
len
(
self
.
vocab_list
))
self
.
assertEqual
(
beam_result
[
0
][
1
],
self
.
beam_search_result
[
1
])
self
.
assertEqual
(
beam_result
[
0
][
1
],
self
.
beam_search_result
[
1
])
def
test_beam_search_decoder_batch
(
self
):
def
test_beam_search_decoder_batch
(
self
):
...
@@ -83,7 +81,6 @@ class TestDecoders(unittest.TestCase):
...
@@ -83,7 +81,6 @@ class TestDecoders(unittest.TestCase):
probs_split
=
[
self
.
probs_seq1
,
self
.
probs_seq2
],
probs_split
=
[
self
.
probs_seq1
,
self
.
probs_seq2
],
beam_size
=
self
.
beam_size
,
beam_size
=
self
.
beam_size
,
vocabulary
=
self
.
vocab_list
,
vocabulary
=
self
.
vocab_list
,
blank_id
=
len
(
self
.
vocab_list
),
num_processes
=
24
)
num_processes
=
24
)
self
.
assertEqual
(
beam_results
[
0
][
0
][
1
],
self
.
beam_search_result
[
0
])
self
.
assertEqual
(
beam_results
[
0
][
0
][
1
],
self
.
beam_search_result
[
0
])
self
.
assertEqual
(
beam_results
[
1
][
0
][
1
],
self
.
beam_search_result
[
1
])
self
.
assertEqual
(
beam_results
[
1
][
0
][
1
],
self
.
beam_search_result
[
1
])
...
...
deep_speech_2/examples/librispeech/run_infer.sh
浏览文件 @
17ebb40a
...
@@ -21,9 +21,10 @@ python -u infer.py \
...
@@ -21,9 +21,10 @@ python -u infer.py \
--num_conv_layers
=
2
\
--num_conv_layers
=
2
\
--num_rnn_layers
=
3
\
--num_rnn_layers
=
3
\
--rnn_layer_size
=
2048
\
--rnn_layer_size
=
2048
\
--alpha
=
0.36
\
--alpha
=
2.15
\
--beta
=
0.25
\
--beta
=
0.35
\
--cutoff_prob
=
0.99
\
--cutoff_prob
=
1.0
\
--cutoff_top_n
=
40
\
--use_gru
=
False
\
--use_gru
=
False
\
--use_gpu
=
True
\
--use_gpu
=
True
\
--share_rnn_weights
=
True
\
--share_rnn_weights
=
True
\
...
...
deep_speech_2/examples/librispeech/run_infer_golden.sh
浏览文件 @
17ebb40a
...
@@ -30,9 +30,10 @@ python -u infer.py \
...
@@ -30,9 +30,10 @@ python -u infer.py \
--num_conv_layers
=
2
\
--num_conv_layers
=
2
\
--num_rnn_layers
=
3
\
--num_rnn_layers
=
3
\
--rnn_layer_size
=
2048
\
--rnn_layer_size
=
2048
\
--alpha
=
0.36
\
--alpha
=
2.15
\
--beta
=
0.25
\
--beta
=
0.35
\
--cutoff_prob
=
0.99
\
--cutoff_prob
=
1.0
\
--cutoff_top_n
=
40
\
--use_gru
=
False
\
--use_gru
=
False
\
--use_gpu
=
True
\
--use_gpu
=
True
\
--share_rnn_weights
=
True
\
--share_rnn_weights
=
True
\
...
...
deep_speech_2/examples/librispeech/run_test.sh
浏览文件 @
17ebb40a
...
@@ -22,9 +22,9 @@ python -u test.py \
...
@@ -22,9 +22,9 @@ python -u test.py \
--num_conv_layers
=
2
\
--num_conv_layers
=
2
\
--num_rnn_layers
=
3
\
--num_rnn_layers
=
3
\
--rnn_layer_size
=
2048
\
--rnn_layer_size
=
2048
\
--alpha
=
0.36
\
--alpha
=
2.15
\
--beta
=
0.
2
5
\
--beta
=
0.
3
5
\
--cutoff_prob
=
0.99
\
--cutoff_prob
=
1.0
\
--use_gru
=
False
\
--use_gru
=
False
\
--use_gpu
=
True
\
--use_gpu
=
True
\
--share_rnn_weights
=
True
\
--share_rnn_weights
=
True
\
...
...
deep_speech_2/examples/librispeech/run_test_golden.sh
浏览文件 @
17ebb40a
...
@@ -31,9 +31,10 @@ python -u test.py \
...
@@ -31,9 +31,10 @@ python -u test.py \
--num_conv_layers
=
2
\
--num_conv_layers
=
2
\
--num_rnn_layers
=
3
\
--num_rnn_layers
=
3
\
--rnn_layer_size
=
2048
\
--rnn_layer_size
=
2048
\
--alpha
=
0.36
\
--alpha
=
2.15
\
--beta
=
0.25
\
--beta
=
0.35
\
--cutoff_prob
=
0.99
\
--cutoff_prob
=
1.0
\
--cutoff_top_n
=
40
\
--use_gru
=
False
\
--use_gru
=
False
\
--use_gpu
=
True
\
--use_gpu
=
True
\
--share_rnn_weights
=
True
\
--share_rnn_weights
=
True
\
...
...
deep_speech_2/examples/tiny/run_infer.sh
浏览文件 @
17ebb40a
...
@@ -21,9 +21,9 @@ python -u infer.py \
...
@@ -21,9 +21,9 @@ python -u infer.py \
--num_conv_layers
=
2
\
--num_conv_layers
=
2
\
--num_rnn_layers
=
3
\
--num_rnn_layers
=
3
\
--rnn_layer_size
=
2048
\
--rnn_layer_size
=
2048
\
--alpha
=
0.36
\
--alpha
=
2.15
\
--beta
=
0.
2
5
\
--beta
=
0.
3
5
\
--cutoff_prob
=
0.99
\
--cutoff_prob
=
1.0
\
--use_gru
=
False
\
--use_gru
=
False
\
--use_gpu
=
True
\
--use_gpu
=
True
\
--share_rnn_weights
=
True
\
--share_rnn_weights
=
True
\
...
...
deep_speech_2/examples/tiny/run_infer_golden.sh
浏览文件 @
17ebb40a
...
@@ -30,9 +30,9 @@ python -u infer.py \
...
@@ -30,9 +30,9 @@ python -u infer.py \
--num_conv_layers
=
2
\
--num_conv_layers
=
2
\
--num_rnn_layers
=
3
\
--num_rnn_layers
=
3
\
--rnn_layer_size
=
2048
\
--rnn_layer_size
=
2048
\
--alpha
=
0.36
\
--alpha
=
2.15
\
--beta
=
0.
2
5
\
--beta
=
0.
3
5
\
--cutoff_prob
=
0.99
\
--cutoff_prob
=
1.0
\
--use_gru
=
False
\
--use_gru
=
False
\
--use_gpu
=
True
\
--use_gpu
=
True
\
--share_rnn_weights
=
True
\
--share_rnn_weights
=
True
\
...
...
deep_speech_2/examples/tiny/run_test.sh
浏览文件 @
17ebb40a
...
@@ -22,9 +22,9 @@ python -u test.py \
...
@@ -22,9 +22,9 @@ python -u test.py \
--num_conv_layers
=
2
\
--num_conv_layers
=
2
\
--num_rnn_layers
=
3
\
--num_rnn_layers
=
3
\
--rnn_layer_size
=
2048
\
--rnn_layer_size
=
2048
\
--alpha
=
0.36
\
--alpha
=
2.15
\
--beta
=
0.
2
5
\
--beta
=
0.
3
5
\
--cutoff_prob
=
0.99
\
--cutoff_prob
=
1.0
\
--use_gru
=
False
\
--use_gru
=
False
\
--use_gpu
=
True
\
--use_gpu
=
True
\
--share_rnn_weights
=
True
\
--share_rnn_weights
=
True
\
...
...
deep_speech_2/examples/tiny/run_test_golden.sh
浏览文件 @
17ebb40a
...
@@ -31,9 +31,9 @@ python -u test.py \
...
@@ -31,9 +31,9 @@ python -u test.py \
--num_conv_layers
=
2
\
--num_conv_layers
=
2
\
--num_rnn_layers
=
3
\
--num_rnn_layers
=
3
\
--rnn_layer_size
=
2048
\
--rnn_layer_size
=
2048
\
--alpha
=
0.36
\
--alpha
=
2.15
\
--beta
=
0.
2
5
\
--beta
=
0.
3
5
\
--cutoff_prob
=
0.99
\
--cutoff_prob
=
1.0
\
--use_gru
=
False
\
--use_gru
=
False
\
--use_gpu
=
True
\
--use_gpu
=
True
\
--share_rnn_weights
=
True
\
--share_rnn_weights
=
True
\
...
...
deep_speech_2/infer.py
浏览文件 @
17ebb40a
...
@@ -21,9 +21,10 @@ add_arg('num_proc_bsearch', int, 12, "# of CPUs for beam search.")
...
@@ -21,9 +21,10 @@ add_arg('num_proc_bsearch', int, 12, "# of CPUs for beam search.")
add_arg
(
'num_conv_layers'
,
int
,
2
,
"# of convolution layers."
)
add_arg
(
'num_conv_layers'
,
int
,
2
,
"# of convolution layers."
)
add_arg
(
'num_rnn_layers'
,
int
,
3
,
"# of recurrent layers."
)
add_arg
(
'num_rnn_layers'
,
int
,
3
,
"# of recurrent layers."
)
add_arg
(
'rnn_layer_size'
,
int
,
2048
,
"# of recurrent cells per layer."
)
add_arg
(
'rnn_layer_size'
,
int
,
2048
,
"# of recurrent cells per layer."
)
add_arg
(
'alpha'
,
float
,
0.36
,
"Coef of LM for beam search."
)
add_arg
(
'alpha'
,
float
,
2.15
,
"Coef of LM for beam search."
)
add_arg
(
'beta'
,
float
,
0.25
,
"Coef of WC for beam search."
)
add_arg
(
'beta'
,
float
,
0.35
,
"Coef of WC for beam search."
)
add_arg
(
'cutoff_prob'
,
float
,
0.99
,
"Cutoff probability for pruning."
)
add_arg
(
'cutoff_prob'
,
float
,
1.0
,
"Cutoff probability for pruning."
)
add_arg
(
'cutoff_top_n'
,
int
,
40
,
"Cutoff number for pruning."
)
add_arg
(
'use_gru'
,
bool
,
False
,
"Use GRUs instead of simple RNNs."
)
add_arg
(
'use_gru'
,
bool
,
False
,
"Use GRUs instead of simple RNNs."
)
add_arg
(
'use_gpu'
,
bool
,
True
,
"Use GPU or not."
)
add_arg
(
'use_gpu'
,
bool
,
True
,
"Use GPU or not."
)
add_arg
(
'share_rnn_weights'
,
bool
,
True
,
"Share input-hidden weights across "
add_arg
(
'share_rnn_weights'
,
bool
,
True
,
"Share input-hidden weights across "
...
@@ -84,6 +85,10 @@ def infer():
...
@@ -84,6 +85,10 @@ def infer():
use_gru
=
args
.
use_gru
,
use_gru
=
args
.
use_gru
,
pretrained_model_path
=
args
.
model_path
,
pretrained_model_path
=
args
.
model_path
,
share_rnn_weights
=
args
.
share_rnn_weights
)
share_rnn_weights
=
args
.
share_rnn_weights
)
# decoders only accept string encoded in utf-8
vocab_list
=
[
chars
.
encode
(
"utf-8"
)
for
chars
in
data_generator
.
vocab_list
]
result_transcripts
=
ds2_model
.
infer_batch
(
result_transcripts
=
ds2_model
.
infer_batch
(
infer_data
=
infer_data
,
infer_data
=
infer_data
,
decoding_method
=
args
.
decoding_method
,
decoding_method
=
args
.
decoding_method
,
...
@@ -91,7 +96,8 @@ def infer():
...
@@ -91,7 +96,8 @@ def infer():
beam_beta
=
args
.
beta
,
beam_beta
=
args
.
beta
,
beam_size
=
args
.
beam_size
,
beam_size
=
args
.
beam_size
,
cutoff_prob
=
args
.
cutoff_prob
,
cutoff_prob
=
args
.
cutoff_prob
,
vocab_list
=
data_generator
.
vocab_list
,
cutoff_top_n
=
args
.
cutoff_top_n
,
vocab_list
=
vocab_list
,
language_model_path
=
args
.
lang_model_path
,
language_model_path
=
args
.
lang_model_path
,
num_processes
=
args
.
num_proc_bsearch
)
num_processes
=
args
.
num_proc_bsearch
)
...
@@ -106,6 +112,7 @@ def infer():
...
@@ -106,6 +112,7 @@ def infer():
print
(
"Current error rate [%s] = %f"
%
print
(
"Current error rate [%s] = %f"
%
(
args
.
error_rate_type
,
error_rate_func
(
target
,
result
)))
(
args
.
error_rate_type
,
error_rate_func
(
target
,
result
)))
ds2_model
.
logger
.
info
(
"finish inference"
)
def
main
():
def
main
():
print_arguments
(
args
)
print_arguments
(
args
)
...
...
deep_speech_2/model_utils/model.py
浏览文件 @
17ebb40a
...
@@ -6,14 +6,18 @@ from __future__ import print_function
...
@@ -6,14 +6,18 @@ from __future__ import print_function
import
sys
import
sys
import
os
import
os
import
time
import
time
import
logging
import
gzip
import
gzip
from
distutils.dir_util
import
mkpath
from
distutils.dir_util
import
mkpath
import
paddle.v2
as
paddle
import
paddle.v2
as
paddle
from
model_utils.lm_scorer
import
Lm
Scorer
from
decoders.swig_wrapper
import
Scorer
from
model_utils.decoder
import
ctc_greedy_decoder
,
ctc_beam_search
_decoder
from
decoders.swig_wrapper
import
ctc_greedy
_decoder
from
model_utils.decod
er
import
ctc_beam_search_decoder_batch
from
decoders.swig_wrapp
er
import
ctc_beam_search_decoder_batch
from
model_utils.network
import
deep_speech_v2_network
from
model_utils.network
import
deep_speech_v2_network
logging
.
basicConfig
(
format
=
'[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
)
class
DeepSpeech2Model
(
object
):
class
DeepSpeech2Model
(
object
):
"""DeepSpeech2Model class.
"""DeepSpeech2Model class.
...
@@ -44,6 +48,8 @@ class DeepSpeech2Model(object):
...
@@ -44,6 +48,8 @@ class DeepSpeech2Model(object):
self
.
_inferer
=
None
self
.
_inferer
=
None
self
.
_loss_inferer
=
None
self
.
_loss_inferer
=
None
self
.
_ext_scorer
=
None
self
.
_ext_scorer
=
None
self
.
logger
=
logging
.
getLogger
(
""
)
self
.
logger
.
setLevel
(
level
=
logging
.
INFO
)
def
train
(
self
,
def
train
(
self
,
train_batch_reader
,
train_batch_reader
,
...
@@ -157,8 +163,8 @@ class DeepSpeech2Model(object):
...
@@ -157,8 +163,8 @@ class DeepSpeech2Model(object):
return
self
.
_loss_inferer
.
infer
(
input
=
infer_data
)
return
self
.
_loss_inferer
.
infer
(
input
=
infer_data
)
def
infer_batch
(
self
,
infer_data
,
decoding_method
,
beam_alpha
,
beam_beta
,
def
infer_batch
(
self
,
infer_data
,
decoding_method
,
beam_alpha
,
beam_beta
,
beam_size
,
cutoff_prob
,
vocab_list
,
language_model_path
,
beam_size
,
cutoff_prob
,
cutoff_top_n
,
vocab_list
,
num_processes
):
language_model_path
,
num_processes
):
"""Model inference. Infer the transcription for a batch of speech
"""Model inference. Infer the transcription for a batch of speech
utterances.
utterances.
...
@@ -178,6 +184,10 @@ class DeepSpeech2Model(object):
...
@@ -178,6 +184,10 @@ class DeepSpeech2Model(object):
:param cutoff_prob: Cutoff probability in pruning,
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
default 1.0, no pruning.
:type cutoff_prob: float
:type cutoff_prob: float
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
characters with highest probs in vocabulary will be
used in beam search, default 40.
:type cutoff_top_n: int
:param vocab_list: List of tokens in the vocabulary, for decoding.
:param vocab_list: List of tokens in the vocabulary, for decoding.
:type vocab_list: list
:type vocab_list: list
:param language_model_path: Filepath for language model.
:param language_model_path: Filepath for language model.
...
@@ -209,21 +219,33 @@ class DeepSpeech2Model(object):
...
@@ -209,21 +219,33 @@ class DeepSpeech2Model(object):
elif
decoding_method
==
"ctc_beam_search"
:
elif
decoding_method
==
"ctc_beam_search"
:
# initialize external scorer
# initialize external scorer
if
self
.
_ext_scorer
==
None
:
if
self
.
_ext_scorer
==
None
:
self
.
_ext_scorer
=
LmScorer
(
beam_alpha
,
beam_beta
,
language_model_path
)
self
.
_loaded_lm_path
=
language_model_path
self
.
_loaded_lm_path
=
language_model_path
self
.
logger
.
info
(
"begin to initialize the external scorer "
"for decoding"
)
self
.
_ext_scorer
=
Scorer
(
beam_alpha
,
beam_beta
,
language_model_path
,
vocab_list
)
lm_char_based
=
self
.
_ext_scorer
.
is_character_based
()
lm_max_order
=
self
.
_ext_scorer
.
get_max_order
()
lm_dict_size
=
self
.
_ext_scorer
.
get_dict_size
()
self
.
logger
.
info
(
"language model: "
"is_character_based = %d,"
%
lm_char_based
+
" max_order = %d,"
%
lm_max_order
+
" dict_size = %d"
%
lm_dict_size
)
self
.
logger
.
info
(
"end initializing scorer. Start decoding ..."
)
else
:
else
:
self
.
_ext_scorer
.
reset_params
(
beam_alpha
,
beam_beta
)
self
.
_ext_scorer
.
reset_params
(
beam_alpha
,
beam_beta
)
assert
self
.
_loaded_lm_path
==
language_model_path
assert
self
.
_loaded_lm_path
==
language_model_path
# beam search decode
# beam search decode
num_processes
=
min
(
num_processes
,
len
(
probs_split
))
beam_search_results
=
ctc_beam_search_decoder_batch
(
beam_search_results
=
ctc_beam_search_decoder_batch
(
probs_split
=
probs_split
,
probs_split
=
probs_split
,
vocabulary
=
vocab_list
,
vocabulary
=
vocab_list
,
beam_size
=
beam_size
,
beam_size
=
beam_size
,
blank_id
=
len
(
vocab_list
),
num_processes
=
num_processes
,
num_processes
=
num_processes
,
ext_scoring_func
=
self
.
_ext_scorer
,
ext_scoring_func
=
self
.
_ext_scorer
,
cutoff_prob
=
cutoff_prob
)
cutoff_prob
=
cutoff_prob
,
cutoff_top_n
=
cutoff_top_n
)
results
=
[
result
[
0
][
1
]
for
result
in
beam_search_results
]
results
=
[
result
[
0
][
1
]
for
result
in
beam_search_results
]
else
:
else
:
...
...
deep_speech_2/requirements.txt
浏览文件 @
17ebb40a
...
@@ -2,4 +2,3 @@ scipy==0.13.1
...
@@ -2,4 +2,3 @@ scipy==0.13.1
resampy==0.1.5
resampy==0.1.5
SoundFile==0.9.0.post1
SoundFile==0.9.0.post1
python_speech_features
python_speech_features
https://github.com/luotao1/kenlm/archive/master.zip
deep_speech_2/setup.sh
浏览文件 @
17ebb40a
#! /usr/bin/env bash
#! /usr/bin/env
bash
# install python dependencies
# install python dependencies
if
[
-f
"requirements.txt"
]
;
then
if
[
-f
"requirements.txt"
]
;
then
...
@@ -20,10 +20,19 @@ if [ $? != 0 ]; then
...
@@ -20,10 +20,19 @@ if [ $? != 0 ]; then
fi
fi
tar
-zxvf
libsndfile-1.0.28.tar.gz
tar
-zxvf
libsndfile-1.0.28.tar.gz
cd
libsndfile-1.0.28
cd
libsndfile-1.0.28
./configure
&&
make
&&
make
insta
ll
./configure
>
/dev/null
&&
make
>
/dev/null
&&
make
install
>
/dev/nu
ll
cd
..
cd
..
rm
-rf
libsndfile-1.0.28
rm
-rf
libsndfile-1.0.28
rm
libsndfile-1.0.28.tar.gz
rm
libsndfile-1.0.28.tar.gz
fi
fi
# install decoders
python
-c
"import swig_decoders"
if
[
$?
!=
0
]
;
then
cd
decoders/swig
>
/dev/null
sh setup.sh
cd
-
>
/dev/null
fi
echo
"Install all dependencies successfully."
echo
"Install all dependencies successfully."
deep_speech_2/test.py
浏览文件 @
17ebb40a
...
@@ -22,9 +22,10 @@ add_arg('num_proc_data', int, 12, "# of CPUs for data preprocessing.")
...
@@ -22,9 +22,10 @@ add_arg('num_proc_data', int, 12, "# of CPUs for data preprocessing.")
add_arg
(
'num_conv_layers'
,
int
,
2
,
"# of convolution layers."
)
add_arg
(
'num_conv_layers'
,
int
,
2
,
"# of convolution layers."
)
add_arg
(
'num_rnn_layers'
,
int
,
3
,
"# of recurrent layers."
)
add_arg
(
'num_rnn_layers'
,
int
,
3
,
"# of recurrent layers."
)
add_arg
(
'rnn_layer_size'
,
int
,
2048
,
"# of recurrent cells per layer."
)
add_arg
(
'rnn_layer_size'
,
int
,
2048
,
"# of recurrent cells per layer."
)
add_arg
(
'alpha'
,
float
,
0.36
,
"Coef of LM for beam search."
)
add_arg
(
'alpha'
,
float
,
2.15
,
"Coef of LM for beam search."
)
add_arg
(
'beta'
,
float
,
0.25
,
"Coef of WC for beam search."
)
add_arg
(
'beta'
,
float
,
0.35
,
"Coef of WC for beam search."
)
add_arg
(
'cutoff_prob'
,
float
,
0.99
,
"Cutoff probability for pruning."
)
add_arg
(
'cutoff_prob'
,
float
,
1.0
,
"Cutoff probability for pruning."
)
add_arg
(
'cutoff_top_n'
,
int
,
40
,
"Cutoff number for pruning."
)
add_arg
(
'use_gru'
,
bool
,
False
,
"Use GRUs instead of simple RNNs."
)
add_arg
(
'use_gru'
,
bool
,
False
,
"Use GRUs instead of simple RNNs."
)
add_arg
(
'use_gpu'
,
bool
,
True
,
"Use GPU or not."
)
add_arg
(
'use_gpu'
,
bool
,
True
,
"Use GPU or not."
)
add_arg
(
'share_rnn_weights'
,
bool
,
True
,
"Share input-hidden weights across "
add_arg
(
'share_rnn_weights'
,
bool
,
True
,
"Share input-hidden weights across "
...
@@ -85,6 +86,9 @@ def evaluate():
...
@@ -85,6 +86,9 @@ def evaluate():
pretrained_model_path
=
args
.
model_path
,
pretrained_model_path
=
args
.
model_path
,
share_rnn_weights
=
args
.
share_rnn_weights
)
share_rnn_weights
=
args
.
share_rnn_weights
)
# decoders only accept string encoded in utf-8
vocab_list
=
[
chars
.
encode
(
"utf-8"
)
for
chars
in
data_generator
.
vocab_list
]
error_rate_func
=
cer
if
args
.
error_rate_type
==
'cer'
else
wer
error_rate_func
=
cer
if
args
.
error_rate_type
==
'cer'
else
wer
error_sum
,
num_ins
=
0.0
,
0
error_sum
,
num_ins
=
0.0
,
0
for
infer_data
in
batch_reader
():
for
infer_data
in
batch_reader
():
...
@@ -95,7 +99,8 @@ def evaluate():
...
@@ -95,7 +99,8 @@ def evaluate():
beam_beta
=
args
.
beta
,
beam_beta
=
args
.
beta
,
beam_size
=
args
.
beam_size
,
beam_size
=
args
.
beam_size
,
cutoff_prob
=
args
.
cutoff_prob
,
cutoff_prob
=
args
.
cutoff_prob
,
vocab_list
=
data_generator
.
vocab_list
,
cutoff_top_n
=
args
.
cutoff_top_n
,
vocab_list
=
vocab_list
,
language_model_path
=
args
.
lang_model_path
,
language_model_path
=
args
.
lang_model_path
,
num_processes
=
args
.
num_proc_bsearch
)
num_processes
=
args
.
num_proc_bsearch
)
target_transcripts
=
[
target_transcripts
=
[
...
@@ -110,6 +115,7 @@ def evaluate():
...
@@ -110,6 +115,7 @@ def evaluate():
print
(
"Final error rate [%s] (%d/%d) = %f"
%
print
(
"Final error rate [%s] (%d/%d) = %f"
%
(
args
.
error_rate_type
,
num_ins
,
num_ins
,
error_sum
/
num_ins
))
(
args
.
error_rate_type
,
num_ins
,
num_ins
,
error_sum
/
num_ins
))
ds2_model
.
logger
.
info
(
"finish evaluation"
)
def
main
():
def
main
():
print_arguments
(
args
)
print_arguments
(
args
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录