Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
63a72c1e
M
models
项目概览
PaddlePaddle
/
models
接近 2 年 前同步成功
通知
230
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
63a72c1e
编写于
6月 27, 2017
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine ctc_beam_search_decoder
上级
accaf924
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
187 addition
and
229 deletion
+187
-229
deep_speech_2/decoder.py
deep_speech_2/decoder.py
+68
-60
deep_speech_2/evaluate.py
deep_speech_2/evaluate.py
+38
-51
deep_speech_2/infer.py
deep_speech_2/infer.py
+30
-49
deep_speech_2/lm/__init__.py
deep_speech_2/lm/__init__.py
+0
-0
deep_speech_2/lm/lm_scorer.py
deep_speech_2/lm/lm_scorer.py
+12
-9
deep_speech_2/lm/run.sh
deep_speech_2/lm/run.sh
+3
-0
deep_speech_2/requirements.txt
deep_speech_2/requirements.txt
+1
-0
deep_speech_2/tests/test_decoders.py
deep_speech_2/tests/test_decoders.py
+3
-3
deep_speech_2/tune.py
deep_speech_2/tune.py
+32
-57
未找到文件。
deep_speech_2/decoder.py
浏览文件 @
63a72c1e
...
@@ -8,8 +8,8 @@ import numpy as np
...
@@ -8,8 +8,8 @@ import numpy as np
import
multiprocessing
import
multiprocessing
def
ctc_best_path_decode
(
probs_seq
,
vocabulary
):
def
ctc_best_path_decode
r
(
probs_seq
,
vocabulary
):
"""Best path decod
ing, also called argmax decoding or greedy decoding
.
"""Best path decod
er, also called argmax decoder or greedy decoder
.
Path consisting of the most probable tokens are further post-processed to
Path consisting of the most probable tokens are further post-processed to
remove consecutive repetitions and all blanks.
remove consecutive repetitions and all blanks.
...
@@ -40,73 +40,84 @@ def ctc_best_path_decode(probs_seq, vocabulary):
...
@@ -40,73 +40,84 @@ def ctc_best_path_decode(probs_seq, vocabulary):
def
ctc_beam_search_decoder
(
probs_seq
,
def
ctc_beam_search_decoder
(
probs_seq
,
beam_size
,
beam_size
,
vocabulary
,
vocabulary
,
blank_id
=
0
,
blank_id
,
cutoff_prob
=
1.0
,
cutoff_prob
=
1.0
,
ext_scoring_func
=
None
,
ext_scoring_func
=
None
,
nproc
=
False
):
nproc
=
False
):
'''Beam search decoder for CTC-trained network, using beam search with width
"""Beam search decoder for CTC-trained network. It utilizes beam search
beam_size to find many paths to one label, return beam_size labels in
to approximately select top best decoding labels and returning results
the descending order of probabilities. The implementation is based on Prefix
in the descending order. The implementation is based on Prefix
Beam Search(https://arxiv.org/abs/1408.2873), and the unclear part is
Beam Search (https://arxiv.org/abs/1408.2873), and the unclear part is
redesigned.
redesigned. Two important modifications: 1) in the iterative computation
of probabilities, the assignment operation is changed to accumulation for
:param probs_seq: 2-D list with length num_time_steps, each element
one prefix may comes from different paths; 2) the if condition "if l^+ not
is a list of normalized probabilities over vocabulary
in A_prev then" after probabilities' computation is deprecated for it is
and blank for one time step.
hard to understand and seems unnecessary.
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
probabilities over vocabulary and blank.
:type probs_seq: 2-D list
:type probs_seq: 2-D list
:param beam_size: Width for beam search.
:param beam_size: Width for beam search.
:type beam_size: int
:type beam_size: int
:param vocabulary: Vocabulary list.
:param vocabulary: Vocabulary list.
:type vocabulary: list
:type vocabulary: list
:param blank_id: ID of blank
, default 0
.
:param blank_id: ID of blank.
:type blank_id: int
:type blank_id: int
:param cutoff_prob: Cutoff probability in pruning,
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
default 1.0, no pruning.
:type cutoff_prob: float
:type cutoff_prob: float
:param ext_scoring_func: External
defined
scoring function for
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
partially decoded sentence, e.g. word count
and
language model.
or
language model.
:type external_scoring_func
tion: function
:type external_scoring_func
: callable
:param nproc: Whether the decoder used in multiprocesses.
:param nproc: Whether the decoder used in multiprocesses.
:type nproc: bool
:type nproc: bool
:return: Decoding log probabilities and result sentences in descending order.
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
:rtype: list
'''
"""
# dimension check
# dimension check
for
prob_list
in
probs_seq
:
for
prob_list
in
probs_seq
:
if
not
len
(
prob_list
)
==
len
(
vocabulary
)
+
1
:
if
not
len
(
prob_list
)
==
len
(
vocabulary
)
+
1
:
raise
ValueError
(
"
probs dimension mismatched with vocabulary"
)
raise
ValueError
(
"
The shape of prob_seq does not match with the "
num_time_steps
=
len
(
probs_seq
)
"shape of the vocabulary."
)
# blank_id check
# blank_id check
probs_dim
=
len
(
probs_seq
[
0
])
if
not
blank_id
<
len
(
probs_seq
[
0
]):
if
not
blank_id
<
probs_dim
:
raise
ValueError
(
"blank_id shouldn't be greater than probs dimension"
)
raise
ValueError
(
"blank_id shouldn't be greater than probs dimension"
)
# If the decoder called in the multiprocesses, then use the global scorer
# If the decoder called in the multiprocesses, then use the global scorer
# instantiated in ctc_beam_search_decoder_
nproc
().
# instantiated in ctc_beam_search_decoder_
batch
().
if
nproc
is
True
:
if
nproc
is
True
:
global
ext_nproc_scorer
global
ext_nproc_scorer
ext_scoring_func
=
ext_nproc_scorer
ext_scoring_func
=
ext_nproc_scorer
## initialize
## initialize
# the set containing selected prefixes
# prefix_set_prev: the set containing selected prefixes
prefix_set_prev
=
{
'
\t
'
:
1.0
}
# probs_b_prev: prefixes' probability ending with blank in previous step
probs_b_prev
,
probs_nb_prev
=
{
'
\t
'
:
1.0
},
{
'
\t
'
:
0.0
}
# probs_nb_prev: prefixes' probability ending with non-blank in previous step
prefix_set_prev
,
probs_b_prev
,
probs_nb_prev
=
{
'
\t
'
:
1.0
},
{
'
\t
'
:
1.0
},
{
'
\t
'
:
0.0
}
## extend prefix in loop
## extend prefix in loop
for
time_step
in
xrange
(
num_time_steps
):
for
time_step
in
xrange
(
len
(
probs_seq
)):
# the set containing candidate prefixes
# prefix_set_next: the set containing candidate prefixes
prefix_set_next
=
{}
# probs_b_cur: prefixes' probability ending with blank in current step
probs_b_cur
,
probs_nb_cur
=
{},
{}
# probs_nb_cur: prefixes' probability ending with non-blank in current step
prob
=
probs_seq
[
time_step
]
prefix_set_next
,
probs_b_cur
,
probs_nb_cur
=
{},
{},
{}
prob_idx
=
[[
i
,
prob
[
i
]]
for
i
in
xrange
(
len
(
prob
))]
prob_idx
=
list
(
enumerate
(
probs_seq
[
time_step
]))
cutoff_len
=
len
(
prob_idx
)
cutoff_len
=
len
(
prob_idx
)
#If pruning is enabled
#If pruning is enabled
if
(
cutoff_prob
<
1.0
)
:
if
cutoff_prob
<
1.0
:
prob_idx
=
sorted
(
prob_idx
,
key
=
lambda
asd
:
asd
[
1
],
reverse
=
True
)
prob_idx
=
sorted
(
prob_idx
,
key
=
lambda
asd
:
asd
[
1
],
reverse
=
True
)
cutoff_len
=
0
cutoff_len
,
cum_prob
=
0
,
0.0
cum_prob
=
0.0
for
i
in
xrange
(
len
(
prob_idx
)):
for
i
in
xrange
(
len
(
prob_idx
)):
cum_prob
+=
prob_idx
[
i
][
1
]
cum_prob
+=
prob_idx
[
i
][
1
]
cutoff_len
+=
1
cutoff_len
+=
1
...
@@ -162,54 +173,53 @@ def ctc_beam_search_decoder(probs_seq,
...
@@ -162,54 +173,53 @@ def ctc_beam_search_decoder(probs_seq,
prefix_set_prev
=
dict
(
prefix_set_prev
)
prefix_set_prev
=
dict
(
prefix_set_prev
)
beam_result
=
[]
beam_result
=
[]
for
(
seq
,
prob
)
in
prefix_set_prev
.
items
():
for
seq
,
prob
in
prefix_set_prev
.
items
():
if
prob
>
0.0
and
len
(
seq
)
>
1
:
if
prob
>
0.0
and
len
(
seq
)
>
1
:
result
=
seq
[
1
:]
result
=
seq
[
1
:]
# score last word by external scorer
# score last word by external scorer
if
(
ext_scoring_func
is
not
None
)
and
(
result
[
-
1
]
!=
' '
):
if
(
ext_scoring_func
is
not
None
)
and
(
result
[
-
1
]
!=
' '
):
prob
=
prob
*
ext_scoring_func
(
result
)
prob
=
prob
*
ext_scoring_func
(
result
)
log_prob
=
np
.
log
(
prob
)
log_prob
=
np
.
log
(
prob
)
beam_result
.
append
(
[
log_prob
,
result
]
)
beam_result
.
append
(
(
log_prob
,
result
)
)
## output top beam_size decoding results
## output top beam_size decoding results
beam_result
=
sorted
(
beam_result
,
key
=
lambda
asd
:
asd
[
0
],
reverse
=
True
)
beam_result
=
sorted
(
beam_result
,
key
=
lambda
asd
:
asd
[
0
],
reverse
=
True
)
return
beam_result
return
beam_result
def
ctc_beam_search_decoder_
nproc
(
probs_split
,
def
ctc_beam_search_decoder_
batch
(
probs_split
,
beam_size
,
beam_size
,
vocabulary
,
vocabulary
,
blank_id
=
0
,
blank_id
,
num_processes
,
cutoff_prob
=
1.0
,
cutoff_prob
=
1.0
,
ext_scoring_func
=
None
,
ext_scoring_func
=
None
):
num_processes
=
None
):
"""CTC beam search decoder using multiple processes.
'''Beam search decoder using multiple processes.
:param probs_seq: 3-D list with length batch_size, each element
:param probs_seq: 3-D list with each element as an instance of 2-D list
is a 2-D list of probabilities can be used by
of probabilities used by ctc_beam_search_decoder().
ctc_beam_search_decoder.
:type probs_seq: 3-D list
:type probs_seq: 3-D list
:param beam_size: Width for beam search.
:param beam_size: Width for beam search.
:type beam_size: int
:type beam_size: int
:param vocabulary: Vocabulary list.
:param vocabulary: Vocabulary list.
:type vocabulary: list
:type vocabulary: list
:param blank_id: ID of blank
, default 0
.
:param blank_id: ID of blank.
:type blank_id: int
:type blank_id: int
:param num_processes: Number of parallel processes.
:type num_processes: int
:param cutoff_prob: Cutoff probability in pruning,
:param cutoff_prob: Cutoff probability in pruning,
default 0, no pruning.
default 1.0, no pruning.
:param num_processes: Number of parallel processes.
:type num_processes: int
:type cutoff_prob: float
:type cutoff_prob: float
:param ext_scoring_func: External
defined
scoring function for
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
partially decoded sentence, e.g. word count
and language model.
or language model.
:type external_scoring_function: function
:type external_scoring_function: callable
:param num_processes: Number of processes, default None, equal to the
:return: List of tuples of log probability and sentence as decoding
number of CPUs.
results, in descending order of the probability.
:type num_processes: int
:return: Decoding log probabilities and result sentences in descending order.
:rtype: list
:rtype: list
'''
"""
if
num_processes
is
None
:
num_processes
=
multiprocessing
.
cpu_count
()
if
not
num_processes
>
0
:
if
not
num_processes
>
0
:
raise
ValueError
(
"Number of processes must be positive!"
)
raise
ValueError
(
"Number of processes must be positive!"
)
...
@@ -227,7 +237,5 @@ def ctc_beam_search_decoder_nproc(probs_split,
...
@@ -227,7 +237,5 @@ def ctc_beam_search_decoder_nproc(probs_split,
pool
.
close
()
pool
.
close
()
pool
.
join
()
pool
.
join
()
beam_search_results
=
[]
beam_search_results
=
[
result
.
get
()
for
result
in
results
]
for
result
in
results
:
beam_search_results
.
append
(
result
.
get
())
return
beam_search_results
return
beam_search_results
deep_speech_2/evaluate.py
浏览文件 @
63a72c1e
...
@@ -3,22 +3,22 @@ from __future__ import absolute_import
...
@@ -3,22 +3,22 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
paddle.v2
as
paddle
import
distutils.util
import
distutils.util
import
argparse
import
argparse
import
gzip
import
gzip
import
paddle.v2
as
paddle
from
data_utils.data
import
DataGenerator
from
data_utils.data
import
DataGenerator
from
model
import
deep_speech2
from
model
import
deep_speech2
from
decoder
import
*
from
decoder
import
*
from
scorer
import
Scorer
from
lm.lm_scorer
import
Lm
Scorer
from
error_rate
import
wer
from
error_rate
import
wer
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
parser
.
add_argument
(
"--
num_samples
"
,
"--
batch_size
"
,
default
=
100
,
default
=
100
,
type
=
int
,
type
=
int
,
help
=
"
Number of samples
for evaluation. (default: %(default)s)"
)
help
=
"
Minibatch size
for evaluation. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--num_conv_layers"
,
"--num_conv_layers"
,
default
=
2
,
default
=
2
,
...
@@ -39,6 +39,16 @@ parser.add_argument(
...
@@ -39,6 +39,16 @@ parser.add_argument(
default
=
True
,
default
=
True
,
type
=
distutils
.
util
.
strtobool
,
type
=
distutils
.
util
.
strtobool
,
help
=
"Use gpu or not. (default: %(default)s)"
)
help
=
"Use gpu or not. (default: %(default)s)"
)
parser
.
add_argument
(
"--num_threads_data"
,
default
=
multiprocessing
.
cpu_count
(),
type
=
int
,
help
=
"Number of cpu threads for preprocessing data. (default: %(default)s)"
)
parser
.
add_argument
(
"--num_processes_beam_search"
,
default
=
multiprocessing
.
cpu_count
(),
type
=
int
,
help
=
"Number of cpu processes for beam search. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--mean_std_filepath"
,
"--mean_std_filepath"
,
default
=
'mean_std.npz'
,
default
=
'mean_std.npz'
,
...
@@ -46,10 +56,10 @@ parser.add_argument(
...
@@ -46,10 +56,10 @@ parser.add_argument(
help
=
"Manifest path for normalizer. (default: %(default)s)"
)
help
=
"Manifest path for normalizer. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--decode_method"
,
"--decode_method"
,
default
=
'beam_search
_nproc
'
,
default
=
'beam_search'
,
type
=
str
,
type
=
str
,
help
=
"Method for ctc decoding, best_path
,
"
help
=
"Method for ctc decoding, best_path
or beam_search. (default: %(default)s)
"
"beam_search or beam_search_nproc. (default: %(default)s)"
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--language_model_path"
,
"--language_model_path"
,
default
=
"data/en.00.UNKNOWN.klm"
,
default
=
"data/en.00.UNKNOWN.klm"
,
...
@@ -76,11 +86,6 @@ parser.add_argument(
...
@@ -76,11 +86,6 @@ parser.add_argument(
default
=
500
,
default
=
500
,
type
=
int
,
type
=
int
,
help
=
"Width for beam search decoding. (default: %(default)d)"
)
help
=
"Width for beam search decoding. (default: %(default)d)"
)
parser
.
add_argument
(
"--normalizer_manifest_path"
,
default
=
'data/manifest.libri.train-clean-100'
,
type
=
str
,
help
=
"Manifest path for normalizer. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--decode_manifest_path"
,
"--decode_manifest_path"
,
default
=
'data/manifest.libri.test-clean'
,
default
=
'data/manifest.libri.test-clean'
,
...
@@ -88,7 +93,7 @@ parser.add_argument(
...
@@ -88,7 +93,7 @@ parser.add_argument(
help
=
"Manifest path for decoding. (default: %(default)s)"
)
help
=
"Manifest path for decoding. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--model_filepath"
,
"--model_filepath"
,
default
=
'
./params
.tar.gz'
,
default
=
'
checkpoints/params.latest
.tar.gz'
,
type
=
str
,
type
=
str
,
help
=
"Model filepath. (default: %(default)s)"
)
help
=
"Model filepath. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -101,12 +106,12 @@ args = parser.parse_args()
...
@@ -101,12 +106,12 @@ args = parser.parse_args()
def
evaluate
():
def
evaluate
():
"""Evaluate on whole test data for DeepSpeech2."""
"""Evaluate on whole test data for DeepSpeech2."""
# initialize data generator
# initialize data generator
data_generator
=
DataGenerator
(
data_generator
=
DataGenerator
(
vocab_filepath
=
args
.
vocab_filepath
,
vocab_filepath
=
args
.
vocab_filepath
,
mean_std_filepath
=
args
.
mean_std_filepath
,
mean_std_filepath
=
args
.
mean_std_filepath
,
augmentation_config
=
'{}'
)
augmentation_config
=
'{}'
,
num_threads
=
args
.
num_threads_data
)
# create network config
# create network config
# paddle.data_type.dense_array is used for variable batch input.
# paddle.data_type.dense_array is used for variable batch input.
...
@@ -133,7 +138,7 @@ def evaluate():
...
@@ -133,7 +138,7 @@ def evaluate():
# prepare infer data
# prepare infer data
batch_reader
=
data_generator
.
batch_reader_creator
(
batch_reader
=
data_generator
.
batch_reader_creator
(
manifest_path
=
args
.
decode_manifest_path
,
manifest_path
=
args
.
decode_manifest_path
,
batch_size
=
args
.
num_samples
,
batch_size
=
args
.
batch_size
,
sortagrad
=
False
,
sortagrad
=
False
,
shuffle_method
=
None
)
shuffle_method
=
None
)
...
@@ -142,9 +147,8 @@ def evaluate():
...
@@ -142,9 +147,8 @@ def evaluate():
output_layer
=
output_probs
,
parameters
=
parameters
)
output_layer
=
output_probs
,
parameters
=
parameters
)
# initialize external scorer for beam search decoding
# initialize external scorer for beam search decoding
if
args
.
decode_method
==
'beam_search'
or
\
if
args
.
decode_method
==
'beam_search'
:
args
.
decode_method
==
'beam_search_nproc'
:
ext_scorer
=
LmScorer
(
args
.
alpha
,
args
.
beta
,
args
.
language_model_path
)
ext_scorer
=
Scorer
(
args
.
alpha
,
args
.
beta
,
args
.
language_model_path
)
wer_counter
,
wer_sum
=
0
,
0.0
wer_counter
,
wer_sum
=
0
,
0.0
for
infer_data
in
batch_reader
():
for
infer_data
in
batch_reader
():
...
@@ -155,56 +159,39 @@ def evaluate():
...
@@ -155,56 +159,39 @@ def evaluate():
infer_results
[
i
*
num_steps
:(
i
+
1
)
*
num_steps
]
infer_results
[
i
*
num_steps
:(
i
+
1
)
*
num_steps
]
for
i
in
xrange
(
0
,
len
(
infer_data
))
for
i
in
xrange
(
0
,
len
(
infer_data
))
]
]
# target transcription
target_transcription
=
[
''
.
join
([
data_generator
.
vocab_list
[
index
]
for
index
in
infer_data
[
i
][
1
]
])
for
i
,
probs
in
enumerate
(
probs_split
)
]
# decode and print
# decode and print
# best path decode
# best path decode
if
args
.
decode_method
==
"best_path"
:
if
args
.
decode_method
==
"best_path"
:
for
i
,
probs
in
enumerate
(
probs_split
):
for
i
,
probs
in
enumerate
(
probs_split
):
output_transcription
=
ctc_best_path_decode
(
output_transcription
=
ctc_best_path_decode
r
(
probs_seq
=
probs
,
vocabulary
=
data_generator
.
vocab_list
)
probs_seq
=
probs
,
vocabulary
=
data_generator
.
vocab_list
)
target_transcription
=
''
.
join
([
wer_sum
+=
wer
(
target_transcription
[
i
],
output_transcription
)
data_generator
.
vocab_list
[
index
]
for
index
in
infer_data
[
i
][
1
]
])
wer_sum
+=
wer
(
target_transcription
,
output_transcription
)
wer_counter
+=
1
wer_counter
+=
1
# beam search decode
in single process
# beam search decode
elif
args
.
decode_method
==
"beam_search"
:
elif
args
.
decode_method
==
"beam_search"
:
for
i
,
probs
in
enumerate
(
probs_split
):
target_transcription
=
''
.
join
([
data_generator
.
vocab_list
[
index
]
for
index
in
infer_data
[
i
][
1
]
])
beam_search_result
=
ctc_beam_search_decoder
(
probs_seq
=
probs
,
vocabulary
=
data_generator
.
vocab_list
,
beam_size
=
args
.
beam_size
,
blank_id
=
len
(
data_generator
.
vocab_list
),
ext_scoring_func
=
ext_scorer
,
cutoff_prob
=
args
.
cutoff_prob
,
)
wer_sum
+=
wer
(
target_transcription
,
beam_search_result
[
0
][
1
])
wer_counter
+=
1
# beam search using multiple processes
# beam search using multiple processes
elif
args
.
decode_method
==
"beam_search_nproc"
:
beam_search_results
=
ctc_beam_search_decoder_batch
(
beam_search_nproc_results
=
ctc_beam_search_decoder_nproc
(
probs_split
=
probs_split
,
probs_split
=
probs_split
,
vocabulary
=
data_generator
.
vocab_list
,
vocabulary
=
data_generator
.
vocab_list
,
beam_size
=
args
.
beam_size
,
beam_size
=
args
.
beam_size
,
blank_id
=
len
(
data_generator
.
vocab_list
),
blank_id
=
len
(
data_generator
.
vocab_list
),
num_processes
=
args
.
num_processes_beam_search
,
ext_scoring_func
=
ext_scorer
,
ext_scoring_func
=
ext_scorer
,
cutoff_prob
=
args
.
cutoff_prob
,
)
cutoff_prob
=
args
.
cutoff_prob
,
)
for
i
,
beam_search_result
in
enumerate
(
beam_search_nproc_results
):
for
i
,
beam_search_result
in
enumerate
(
beam_search_results
):
target_transcription
=
''
.
join
([
wer_sum
+=
wer
(
target_transcription
[
i
],
data_generator
.
vocab_list
[
index
]
beam_search_result
[
0
][
1
])
for
index
in
infer_data
[
i
][
1
]
])
wer_sum
+=
wer
(
target_transcription
,
beam_search_result
[
0
][
1
])
wer_counter
+=
1
wer_counter
+=
1
else
:
else
:
raise
ValueError
(
"Decoding method [%s] is not supported."
%
raise
ValueError
(
"Decoding method [%s] is not supported."
%
decode_method
)
decode_method
)
print
(
"Cur WER = %f"
%
(
wer_sum
/
wer_counter
))
print
(
"Final WER = %f"
%
(
wer_sum
/
wer_counter
))
print
(
"Final WER = %f"
%
(
wer_sum
/
wer_counter
))
...
...
deep_speech_2/infer.py
浏览文件 @
63a72c1e
...
@@ -11,14 +11,14 @@ import paddle.v2 as paddle
...
@@ -11,14 +11,14 @@ import paddle.v2 as paddle
from
data_utils.data
import
DataGenerator
from
data_utils.data
import
DataGenerator
from
model
import
deep_speech2
from
model
import
deep_speech2
from
decoder
import
*
from
decoder
import
*
from
scorer
import
Scorer
from
lm.lm_scorer
import
Lm
Scorer
from
error_rate
import
wer
from
error_rate
import
wer
import
utils
import
utils
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
parser
.
add_argument
(
"--num_samples"
,
"--num_samples"
,
default
=
10
0
,
default
=
10
,
type
=
int
,
type
=
int
,
help
=
"Number of samples for inference. (default: %(default)s)"
)
help
=
"Number of samples for inference. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -46,6 +46,11 @@ parser.add_argument(
...
@@ -46,6 +46,11 @@ parser.add_argument(
default
=
multiprocessing
.
cpu_count
(),
default
=
multiprocessing
.
cpu_count
(),
type
=
int
,
type
=
int
,
help
=
"Number of cpu threads for preprocessing data. (default: %(default)s)"
)
help
=
"Number of cpu threads for preprocessing data. (default: %(default)s)"
)
parser
.
add_argument
(
"--num_processes_beam_search"
,
default
=
multiprocessing
.
cpu_count
(),
type
=
int
,
help
=
"Number of cpu processes for beam search. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--mean_std_filepath"
,
"--mean_std_filepath"
,
default
=
'mean_std.npz'
,
default
=
'mean_std.npz'
,
...
@@ -53,12 +58,12 @@ parser.add_argument(
...
@@ -53,12 +58,12 @@ parser.add_argument(
help
=
"Manifest path for normalizer. (default: %(default)s)"
)
help
=
"Manifest path for normalizer. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--decode_manifest_path"
,
"--decode_manifest_path"
,
default
=
'data
/manifest.libri.test-100sample
'
,
default
=
'data
sets/manifest.test
'
,
type
=
str
,
type
=
str
,
help
=
"Manifest path for decoding. (default: %(default)s)"
)
help
=
"Manifest path for decoding. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--model_filepath"
,
"--model_filepath"
,
default
=
'checkpoints/params.
latest.tar.gz
'
,
default
=
'checkpoints/params.
tar.gz.41
'
,
type
=
str
,
type
=
str
,
help
=
"Model filepath. (default: %(default)s)"
)
help
=
"Model filepath. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -68,12 +73,10 @@ parser.add_argument(
...
@@ -68,12 +73,10 @@ parser.add_argument(
help
=
"Vocabulary filepath. (default: %(default)s)"
)
help
=
"Vocabulary filepath. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--decode_method"
,
"--decode_method"
,
default
=
'beam_search
_nproc
'
,
default
=
'beam_search'
,
type
=
str
,
type
=
str
,
help
=
"Method for ctc decoding:"
help
=
"Method for ctc decoding: best_path or beam_search. (default: %(default)s)"
" best_path,"
)
" beam_search, "
" or beam_search_nproc. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--beam_size"
,
"--beam_size"
,
default
=
500
,
default
=
500
,
...
@@ -86,7 +89,7 @@ parser.add_argument(
...
@@ -86,7 +89,7 @@ parser.add_argument(
help
=
"Number of output per sample in beam search. (default: %(default)d)"
)
help
=
"Number of output per sample in beam search. (default: %(default)d)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--language_model_path"
,
"--language_model_path"
,
default
=
"data/en.00.UNKNOWN.klm"
,
default
=
"
lm/
data/en.00.UNKNOWN.klm"
,
type
=
str
,
type
=
str
,
help
=
"Path for language model. (default: %(default)s)"
)
help
=
"Path for language model. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -143,6 +146,7 @@ def infer():
...
@@ -143,6 +146,7 @@ def infer():
batch_reader
=
data_generator
.
batch_reader_creator
(
batch_reader
=
data_generator
.
batch_reader_creator
(
manifest_path
=
args
.
decode_manifest_path
,
manifest_path
=
args
.
decode_manifest_path
,
batch_size
=
args
.
num_samples
,
batch_size
=
args
.
num_samples
,
min_batch_size
=
1
,
sortagrad
=
False
,
sortagrad
=
False
,
shuffle_method
=
None
)
shuffle_method
=
None
)
infer_data
=
batch_reader
().
next
()
infer_data
=
batch_reader
().
next
()
...
@@ -156,68 +160,45 @@ def infer():
...
@@ -156,68 +160,45 @@ def infer():
for
i
in
xrange
(
len
(
infer_data
))
for
i
in
xrange
(
len
(
infer_data
))
]
]
# targe transcription
target_transcription
=
[
''
.
join
(
[
data_generator
.
vocab_list
[
index
]
for
index
in
infer_data
[
i
][
1
]])
for
i
,
probs
in
enumerate
(
probs_split
)
]
## decode and print
## decode and print
# best path decode
# best path decode
wer_sum
,
wer_counter
=
0
,
0
wer_sum
,
wer_counter
=
0
,
0
if
args
.
decode_method
==
"best_path"
:
if
args
.
decode_method
==
"best_path"
:
for
i
,
probs
in
enumerate
(
probs_split
):
for
i
,
probs
in
enumerate
(
probs_split
):
target_transcription
=
''
.
join
([
best_path_transcription
=
ctc_best_path_decoder
(
data_generator
.
vocab_list
[
index
]
for
index
in
infer_data
[
i
][
1
]
])
best_path_transcription
=
ctc_best_path_decode
(
probs_seq
=
probs
,
vocabulary
=
data_generator
.
vocab_list
)
probs_seq
=
probs
,
vocabulary
=
data_generator
.
vocab_list
)
print
(
"
\n
Target Transcription: %s
\n
Output Transcription: %s"
%
print
(
"
\n
Target Transcription: %s
\n
Output Transcription: %s"
%
(
target_transcription
,
best_path_transcription
))
(
target_transcription
[
i
]
,
best_path_transcription
))
wer_cur
=
wer
(
target_transcription
,
best_path_transcription
)
wer_cur
=
wer
(
target_transcription
[
i
]
,
best_path_transcription
)
wer_sum
+=
wer_cur
wer_sum
+=
wer_cur
wer_counter
+=
1
wer_counter
+=
1
print
(
"cur wer = %f, average wer = %f"
%
print
(
"cur wer = %f, average wer = %f"
%
(
wer_cur
,
wer_sum
/
wer_counter
))
(
wer_cur
,
wer_sum
/
wer_counter
))
# beam search decode
# beam search decode
elif
args
.
decode_method
==
"beam_search"
:
elif
args
.
decode_method
==
"beam_search"
:
ext_scorer
=
Scorer
(
args
.
alpha
,
args
.
beta
,
args
.
language_model_path
)
ext_scorer
=
LmScorer
(
args
.
alpha
,
args
.
beta
,
args
.
language_model_path
)
for
i
,
probs
in
enumerate
(
probs_split
):
beam_search_batch_results
=
ctc_beam_search_decoder_batch
(
target_transcription
=
''
.
join
([
data_generator
.
vocab_list
[
index
]
for
index
in
infer_data
[
i
][
1
]
])
beam_search_result
=
ctc_beam_search_decoder
(
probs_seq
=
probs
,
vocabulary
=
data_generator
.
vocab_list
,
beam_size
=
args
.
beam_size
,
blank_id
=
len
(
data_generator
.
vocab_list
),
cutoff_prob
=
args
.
cutoff_prob
,
ext_scoring_func
=
ext_scorer
,
)
print
(
"
\n
Target Transcription:
\t
%s"
%
target_transcription
)
for
index
in
xrange
(
args
.
num_results_per_sample
):
result
=
beam_search_result
[
index
]
#output: index, log prob, beam result
print
(
"Beam %d: %f
\t
%s"
%
(
index
,
result
[
0
],
result
[
1
]))
wer_cur
=
wer
(
target_transcription
,
beam_search_result
[
0
][
1
])
wer_sum
+=
wer_cur
wer_counter
+=
1
print
(
"cur wer = %f , average wer = %f"
%
(
wer_cur
,
wer_sum
/
wer_counter
))
elif
args
.
decode_method
==
"beam_search_nproc"
:
ext_scorer
=
Scorer
(
args
.
alpha
,
args
.
beta
,
args
.
language_model_path
)
beam_search_nproc_results
=
ctc_beam_search_decoder_nproc
(
probs_split
=
probs_split
,
probs_split
=
probs_split
,
vocabulary
=
data_generator
.
vocab_list
,
vocabulary
=
data_generator
.
vocab_list
,
beam_size
=
args
.
beam_size
,
beam_size
=
args
.
beam_size
,
blank_id
=
len
(
data_generator
.
vocab_list
),
blank_id
=
len
(
data_generator
.
vocab_list
),
num_processes
=
args
.
num_processes_beam_search
,
cutoff_prob
=
args
.
cutoff_prob
,
cutoff_prob
=
args
.
cutoff_prob
,
ext_scoring_func
=
ext_scorer
,
)
ext_scoring_func
=
ext_scorer
,
)
for
i
,
beam_search_result
in
enumerate
(
beam_search_nproc_results
):
for
i
,
beam_search_result
in
enumerate
(
beam_search_batch_results
):
target_transcription
=
''
.
join
([
print
(
"
\n
Target Transcription:
\t
%s"
%
target_transcription
[
i
])
data_generator
.
vocab_list
[
index
]
for
index
in
infer_data
[
i
][
1
]
])
print
(
"
\n
Target Transcription:
\t
%s"
%
target_transcription
)
for
index
in
xrange
(
args
.
num_results_per_sample
):
for
index
in
xrange
(
args
.
num_results_per_sample
):
result
=
beam_search_result
[
index
]
result
=
beam_search_result
[
index
]
#output: index, log prob, beam result
#output: index, log prob, beam result
print
(
"Beam %d: %f
\t
%s"
%
(
index
,
result
[
0
],
result
[
1
]))
print
(
"Beam %d: %f
\t
%s"
%
(
index
,
result
[
0
],
result
[
1
]))
wer_cur
=
wer
(
target_transcription
,
beam_search_result
[
0
][
1
])
wer_cur
=
wer
(
target_transcription
[
i
]
,
beam_search_result
[
0
][
1
])
wer_sum
+=
wer_cur
wer_sum
+=
wer_cur
wer_counter
+=
1
wer_counter
+=
1
print
(
"cur wer = %f , average wer = %f"
%
print
(
"cur wer = %f , average wer = %f"
%
...
...
deep_speech_2/lm/__init__.py
0 → 100644
浏览文件 @
63a72c1e
deep_speech_2/scorer.py
→
deep_speech_2/
lm/lm_
scorer.py
浏览文件 @
63a72c1e
...
@@ -8,13 +8,16 @@ import kenlm
...
@@ -8,13 +8,16 @@ import kenlm
import
numpy
as
np
import
numpy
as
np
class
Scorer
(
object
):
class
LmScorer
(
object
):
"""External defined scorer to evaluate a sentence in beam search
"""External scorer to evaluate a prefix or whole sentence in
decoding, consisting of language model and word count.
beam search decoding, including the score from n-gram language
model and word count.
:param alpha: Parameter associated with language model.
:param alpha: Parameter associated with language model. Don't use
language model when alpha = 0.
:type alpha: float
:type alpha: float
:param beta: Parameter associated with word count.
:param beta: Parameter associated with word count. Don't use word
count when beta = 0.
:type beta: float
:type beta: float
:model_path: Path to load language model.
:model_path: Path to load language model.
:type model_path: basestring
:type model_path: basestring
...
@@ -28,14 +31,14 @@ class Scorer(object):
...
@@ -28,14 +31,14 @@ class Scorer(object):
self
.
_language_model
=
kenlm
.
LanguageModel
(
model_path
)
self
.
_language_model
=
kenlm
.
LanguageModel
(
model_path
)
# n-gram language model scoring
# n-gram language model scoring
def
language_model_score
(
self
,
sentence
):
def
_
language_model_score
(
self
,
sentence
):
#log10 prob of last word
#log10 prob of last word
log_cond_prob
=
list
(
log_cond_prob
=
list
(
self
.
_language_model
.
full_scores
(
sentence
,
eos
=
False
))[
-
1
][
0
]
self
.
_language_model
.
full_scores
(
sentence
,
eos
=
False
))[
-
1
][
0
]
return
np
.
power
(
10
,
log_cond_prob
)
return
np
.
power
(
10
,
log_cond_prob
)
# word insertion term
# word insertion term
def
word_count
(
self
,
sentence
):
def
_
word_count
(
self
,
sentence
):
words
=
sentence
.
strip
().
split
(
' '
)
words
=
sentence
.
strip
().
split
(
' '
)
return
len
(
words
)
return
len
(
words
)
...
@@ -51,8 +54,8 @@ class Scorer(object):
...
@@ -51,8 +54,8 @@ class Scorer(object):
:return: Evaluation score, in the decimal or log.
:return: Evaluation score, in the decimal or log.
:rtype: float
:rtype: float
"""
"""
lm
=
self
.
language_model_score
(
sentence
)
lm
=
self
.
_
language_model_score
(
sentence
)
word_cnt
=
self
.
word_count
(
sentence
)
word_cnt
=
self
.
_
word_count
(
sentence
)
if
log
==
False
:
if
log
==
False
:
score
=
np
.
power
(
lm
,
self
.
_alpha
)
\
score
=
np
.
power
(
lm
,
self
.
_alpha
)
\
*
np
.
power
(
word_cnt
,
self
.
_beta
)
*
np
.
power
(
word_cnt
,
self
.
_beta
)
...
...
deep_speech_2/lm/run.sh
0 → 100644
浏览文件 @
63a72c1e
echo
"Downloading language model."
wget
-c
ftp://xxx/xxx/en.00.UNKNOWN.klm
-P
./data
deep_speech_2/requirements.txt
浏览文件 @
63a72c1e
SoundFile==0.9.0.post1
SoundFile==0.9.0.post1
wget==3.2
wget==3.2
scipy==0.13.1
scipy==0.13.1
https://github.com/kpu/kenlm/archive/master.zip
deep_speech_2/tests/test_decoders.py
浏览文件 @
63a72c1e
...
@@ -53,11 +53,11 @@ class TestDecoders(unittest.TestCase):
...
@@ -53,11 +53,11 @@ class TestDecoders(unittest.TestCase):
self
.
beam_search_result
=
[
'acdc'
,
"b'a"
]
self
.
beam_search_result
=
[
'acdc'
,
"b'a"
]
def
test_best_path_decoder_1
(
self
):
def
test_best_path_decoder_1
(
self
):
bst_result
=
ctc_best_path_decode
(
self
.
probs_seq1
,
self
.
vocab_list
)
bst_result
=
ctc_best_path_decode
r
(
self
.
probs_seq1
,
self
.
vocab_list
)
self
.
assertEqual
(
bst_result
,
self
.
best_path_result
[
0
])
self
.
assertEqual
(
bst_result
,
self
.
best_path_result
[
0
])
def
test_best_path_decoder_2
(
self
):
def
test_best_path_decoder_2
(
self
):
bst_result
=
ctc_best_path_decode
(
self
.
probs_seq2
,
self
.
vocab_list
)
bst_result
=
ctc_best_path_decode
r
(
self
.
probs_seq2
,
self
.
vocab_list
)
self
.
assertEqual
(
bst_result
,
self
.
best_path_result
[
1
])
self
.
assertEqual
(
bst_result
,
self
.
best_path_result
[
1
])
def
test_beam_search_decoder_1
(
self
):
def
test_beam_search_decoder_1
(
self
):
...
@@ -77,7 +77,7 @@ class TestDecoders(unittest.TestCase):
...
@@ -77,7 +77,7 @@ class TestDecoders(unittest.TestCase):
self
.
assertEqual
(
beam_result
[
0
][
1
],
self
.
beam_search_result
[
1
])
self
.
assertEqual
(
beam_result
[
0
][
1
],
self
.
beam_search_result
[
1
])
def
test_beam_search_nproc_decoder
(
self
):
def
test_beam_search_nproc_decoder
(
self
):
beam_results
=
ctc_beam_search_decoder_
nproc
(
beam_results
=
ctc_beam_search_decoder_
batch
(
probs_split
=
[
self
.
probs_seq1
,
self
.
probs_seq2
],
probs_split
=
[
self
.
probs_seq1
,
self
.
probs_seq2
],
beam_size
=
self
.
beam_size
,
beam_size
=
self
.
beam_size
,
vocabulary
=
self
.
vocab_list
,
vocabulary
=
self
.
vocab_list
,
...
...
deep_speech_2/tune.py
浏览文件 @
63a72c1e
...
@@ -3,14 +3,14 @@ from __future__ import absolute_import
...
@@ -3,14 +3,14 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
paddle.v2
as
paddle
import
distutils.util
import
distutils.util
import
argparse
import
argparse
import
gzip
import
gzip
import
paddle.v2
as
paddle
from
data_utils.data
import
DataGenerator
from
data_utils.data
import
DataGenerator
from
model
import
deep_speech2
from
model
import
deep_speech2
from
decoder
import
*
from
decoder
import
*
from
scorer
import
Scorer
from
lm.lm_scorer
import
Lm
Scorer
from
error_rate
import
wer
from
error_rate
import
wer
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
...
@@ -39,24 +39,29 @@ parser.add_argument(
...
@@ -39,24 +39,29 @@ parser.add_argument(
default
=
True
,
default
=
True
,
type
=
distutils
.
util
.
strtobool
,
type
=
distutils
.
util
.
strtobool
,
help
=
"Use gpu or not. (default: %(default)s)"
)
help
=
"Use gpu or not. (default: %(default)s)"
)
parser
.
add_argument
(
"--num_threads_data"
,
default
=
multiprocessing
.
cpu_count
(),
type
=
int
,
help
=
"Number of cpu threads for preprocessing data. (default: %(default)s)"
)
parser
.
add_argument
(
"--num_processes_beam_search"
,
default
=
multiprocessing
.
cpu_count
(),
type
=
int
,
help
=
"Number of cpu processes for beam search. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--mean_std_filepath"
,
"--mean_std_filepath"
,
default
=
'mean_std.npz'
,
default
=
'mean_std.npz'
,
type
=
str
,
type
=
str
,
help
=
"Manifest path for normalizer. (default: %(default)s)"
)
help
=
"Manifest path for normalizer. (default: %(default)s)"
)
parser
.
add_argument
(
"--normalizer_manifest_path"
,
default
=
'data/manifest.libri.train-clean-100'
,
type
=
str
,
help
=
"Manifest path for normalizer. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--decode_manifest_path"
,
"--decode_manifest_path"
,
default
=
'data
/manifest.libri.test-100sample
'
,
default
=
'data
sets/manifest.test
'
,
type
=
str
,
type
=
str
,
help
=
"Manifest path for decoding. (default: %(default)s)"
)
help
=
"Manifest path for decoding. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--model_filepath"
,
"--model_filepath"
,
default
=
'
./params
.tar.gz'
,
default
=
'
checkpoints/params.latest
.tar.gz'
,
type
=
str
,
type
=
str
,
help
=
"Model filepath. (default: %(default)s)"
)
help
=
"Model filepath. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -64,25 +69,14 @@ parser.add_argument(
...
@@ -64,25 +69,14 @@ parser.add_argument(
default
=
'datasets/vocab/eng_vocab.txt'
,
default
=
'datasets/vocab/eng_vocab.txt'
,
type
=
str
,
type
=
str
,
help
=
"Vocabulary filepath. (default: %(default)s)"
)
help
=
"Vocabulary filepath. (default: %(default)s)"
)
parser
.
add_argument
(
"--decode_method"
,
default
=
'beam_search_nproc'
,
type
=
str
,
help
=
"Method for decoding, beam_search or beam_search_nproc. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--beam_size"
,
"--beam_size"
,
default
=
500
,
default
=
500
,
type
=
int
,
type
=
int
,
help
=
"Width for beam search decoding. (default: %(default)d)"
)
help
=
"Width for beam search decoding. (default: %(default)d)"
)
parser
.
add_argument
(
"--num_results_per_sample"
,
default
=
1
,
type
=
int
,
help
=
"Number of outputs per sample in beam search. (default: %(default)d)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--language_model_path"
,
"--language_model_path"
,
default
=
"data/en.00.UNKNOWN.klm"
,
default
=
"
lm/
data/en.00.UNKNOWN.klm"
,
type
=
str
,
type
=
str
,
help
=
"Path for language model. (default: %(default)s)"
)
help
=
"Path for language model. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -137,7 +131,8 @@ def tune():
...
@@ -137,7 +131,8 @@ def tune():
data_generator
=
DataGenerator
(
data_generator
=
DataGenerator
(
vocab_filepath
=
args
.
vocab_filepath
,
vocab_filepath
=
args
.
vocab_filepath
,
mean_std_filepath
=
args
.
mean_std_filepath
,
mean_std_filepath
=
args
.
mean_std_filepath
,
augmentation_config
=
'{}'
)
augmentation_config
=
'{}'
,
num_threads
=
args
.
num_threads_data
)
# create network config
# create network config
# paddle.data_type.dense_array is used for variable batch input.
# paddle.data_type.dense_array is used for variable batch input.
...
@@ -188,42 +183,22 @@ def tune():
...
@@ -188,42 +183,22 @@ def tune():
## tune parameters in loop
## tune parameters in loop
for
(
alpha
,
beta
)
in
params_grid
:
for
(
alpha
,
beta
)
in
params_grid
:
wer_sum
,
wer_counter
=
0
,
0
wer_sum
,
wer_counter
=
0
,
0
ext_scorer
=
Scorer
(
alpha
,
beta
,
args
.
language_model_path
)
ext_scorer
=
LmScorer
(
alpha
,
beta
,
args
.
language_model_path
)
# beam search decode
if
args
.
decode_method
==
"beam_search"
:
for
i
,
probs
in
enumerate
(
probs_split
):
target_transcription
=
''
.
join
([
data_generator
.
vocab_list
[
index
]
for
index
in
infer_data
[
i
][
1
]
])
beam_search_result
=
ctc_beam_search_decoder
(
probs_seq
=
probs
,
vocabulary
=
data_generator
.
vocab_list
,
beam_size
=
args
.
beam_size
,
blank_id
=
len
(
data_generator
.
vocab_list
),
cutoff_prob
=
args
.
cutoff_prob
,
ext_scoring_func
=
ext_scorer
,
)
wer_sum
+=
wer
(
target_transcription
,
beam_search_result
[
0
][
1
])
wer_counter
+=
1
# beam search using multiple processes
# beam search using multiple processes
elif
args
.
decode_method
==
"beam_search_nproc"
:
beam_search_results
=
ctc_beam_search_decoder_batch
(
beam_search_nproc_results
=
ctc_beam_search_decoder_nproc
(
probs_split
=
probs_split
,
probs_split
=
probs_split
,
vocabulary
=
data_generator
.
vocab_list
,
vocabulary
=
data_generator
.
vocab_list
,
beam_size
=
args
.
beam_size
,
beam_size
=
args
.
beam_size
,
cutoff_prob
=
args
.
cutoff_prob
,
cutoff_prob
=
args
.
cutoff_prob
,
blank_id
=
len
(
data_generator
.
vocab_list
),
blank_id
=
len
(
data_generator
.
vocab_list
),
num_processes
=
args
.
num_processes_beam_search
,
ext_scoring_func
=
ext_scorer
,
)
ext_scoring_func
=
ext_scorer
,
)
for
i
,
beam_search_result
in
enumerate
(
beam_search_nproc
_results
):
for
i
,
beam_search_result
in
enumerate
(
beam_search
_results
):
target_transcription
=
''
.
join
([
target_transcription
=
''
.
join
([
data_generator
.
vocab_list
[
index
]
data_generator
.
vocab_list
[
index
]
for
index
in
infer_data
[
i
][
1
]
for
index
in
infer_data
[
i
][
1
]
])
])
wer_sum
+=
wer
(
target_transcription
,
beam_search_result
[
0
][
1
])
wer_sum
+=
wer
(
target_transcription
,
beam_search_result
[
0
][
1
])
wer_counter
+=
1
wer_counter
+=
1
else
:
raise
ValueError
(
"Decoding method [%s] is not supported."
%
decode_method
)
print
(
"alpha = %f
\t
beta = %f
\t
WER = %f"
%
print
(
"alpha = %f
\t
beta = %f
\t
WER = %f"
%
(
alpha
,
beta
,
wer_sum
/
wer_counter
))
(
alpha
,
beta
,
wer_sum
/
wer_counter
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录