Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
ff01d048
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ff01d048
编写于
6月 18, 2017
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
final refining on old data provider: enable pruning & add evaluation & code cleanup
上级
a633eb9c
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
339 addition
and
72 deletion
+339
-72
decoder.py
decoder.py
+61
-23
evaluate.py
evaluate.py
+214
-0
infer.py
infer.py
+25
-15
tune.py
tune.py
+39
-34
未找到文件。
decoder.py
浏览文件 @
ff01d048
...
...
@@ -5,7 +5,6 @@
import
os
from
itertools
import
groupby
import
numpy
as
np
import
copy
import
kenlm
import
multiprocessing
...
...
@@ -73,11 +72,25 @@ class Scorer(object):
return
len
(
words
)
# execute evaluation
def
__call__
(
self
,
sentence
):
def
__call__
(
self
,
sentence
,
log
=
False
):
"""
Evaluation function
:param sentence: The input sentence for evalutation
:type sentence: basestring
:param log: Whether return the score in log representation.
:type log: bool
:return: Evaluation score, in the decimal or log.
:rtype: float
"""
lm
=
self
.
language_model_score
(
sentence
)
word_cnt
=
self
.
word_count
(
sentence
)
score
=
np
.
power
(
lm
,
self
.
_alpha
)
\
*
np
.
power
(
word_cnt
,
self
.
_beta
)
if
log
==
False
:
score
=
np
.
power
(
lm
,
self
.
_alpha
)
\
*
np
.
power
(
word_cnt
,
self
.
_beta
)
else
:
score
=
self
.
_alpha
*
np
.
log
(
lm
)
\
+
self
.
_beta
*
np
.
log
(
word_cnt
)
return
score
...
...
@@ -85,13 +98,14 @@ def ctc_beam_search_decoder(probs_seq,
beam_size
,
vocabulary
,
blank_id
=
0
,
cutoff_prob
=
1.0
,
ext_scoring_func
=
None
,
nproc
=
False
):
'''
Beam search decoder for CTC-trained network, using beam search with width
beam_size to find many paths to one label, return beam_size labels in
the
order of probabilities. The implementation is based on Prefix Beam
Search(https://arxiv.org/abs/1408.2873), and the unclear part is
the
descending order of probabilities. The implementation is based on Prefix
Beam
Search(https://arxiv.org/abs/1408.2873), and the unclear part is
redesigned, need to be verified.
:param probs_seq: 2-D list with length num_time_steps, each element
...
...
@@ -102,22 +116,25 @@ def ctc_beam_search_decoder(probs_seq,
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param blank_id: ID of blank, default 0.
:type blank_id: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param ext_scoring_func: External defined scoring function for
partially decoded sentence, e.g. word count
and language model.
:type external_scoring_function: function
:param blank_id: id of blank, default 0.
:type blank_id: int
:param nproc: Whether the decoder used in multiprocesses.
:type nproc: bool
:return: Decoding log probabilit
y and result string
.
:return: Decoding log probabilit
ies and result sentences in descending order
.
:rtype: list
'''
# dimension check
for
prob_list
in
probs_seq
:
if
not
len
(
prob_list
)
==
len
(
vocabulary
)
+
1
:
raise
ValueError
(
"probs dimension mismatched
d
with vocabulary"
)
raise
ValueError
(
"probs dimension mismatched with vocabulary"
)
num_time_steps
=
len
(
probs_seq
)
# blank_id check
...
...
@@ -137,19 +154,35 @@ def ctc_beam_search_decoder(probs_seq,
probs_b_prev
,
probs_nb_prev
=
{
'
\t
'
:
1.0
},
{
'
\t
'
:
0.0
}
## extend prefix in loop
for
time_step
in
range
(
num_time_steps
):
for
time_step
in
x
range
(
num_time_steps
):
# the set containing candidate prefixes
prefix_set_next
=
{}
probs_b_cur
,
probs_nb_cur
=
{},
{}
prob
=
probs_seq
[
time_step
]
prob_idx
=
[[
i
,
prob
[
i
]]
for
i
in
xrange
(
len
(
prob
))]
cutoff_len
=
len
(
prob_idx
)
#If pruning is enabled
if
(
cutoff_prob
<
1.0
):
prob_idx
=
sorted
(
prob_idx
,
key
=
lambda
asd
:
asd
[
1
],
reverse
=
True
)
cutoff_len
=
0
cum_prob
=
0.0
for
i
in
xrange
(
len
(
prob_idx
)):
cum_prob
+=
prob_idx
[
i
][
1
]
cutoff_len
+=
1
if
cum_prob
>=
cutoff_prob
:
break
prob_idx
=
prob_idx
[
0
:
cutoff_len
]
for
l
in
prefix_set_prev
:
prob
=
probs_seq
[
time_step
]
if
not
prefix_set_next
.
has_key
(
l
):
probs_b_cur
[
l
],
probs_nb_cur
[
l
]
=
0.0
,
0.0
# extend prefix by travering vocabulary
for
c
in
range
(
0
,
probs_dim
):
# extend prefix by travering prob_idx
for
index
in
xrange
(
cutoff_len
):
c
,
prob_c
=
prob_idx
[
index
][
0
],
prob_idx
[
index
][
1
]
if
c
==
blank_id
:
probs_b_cur
[
l
]
+=
prob
[
c
]
*
(
probs_b_cur
[
l
]
+=
prob
_c
*
(
probs_b_prev
[
l
]
+
probs_nb_prev
[
l
])
else
:
last_char
=
l
[
-
1
]
...
...
@@ -159,18 +192,18 @@ def ctc_beam_search_decoder(probs_seq,
probs_b_cur
[
l_plus
],
probs_nb_cur
[
l_plus
]
=
0.0
,
0.0
if
new_char
==
last_char
:
probs_nb_cur
[
l_plus
]
+=
prob
[
c
]
*
probs_b_prev
[
l
]
probs_nb_cur
[
l
]
+=
prob
[
c
]
*
probs_nb_prev
[
l
]
probs_nb_cur
[
l_plus
]
+=
prob
_c
*
probs_b_prev
[
l
]
probs_nb_cur
[
l
]
+=
prob
_c
*
probs_nb_prev
[
l
]
elif
new_char
==
' '
:
if
(
ext_scoring_func
is
None
)
or
(
len
(
l
)
==
1
):
score
=
1.0
else
:
prefix
=
l
[
1
:]
score
=
ext_scoring_func
(
prefix
)
probs_nb_cur
[
l_plus
]
+=
score
*
prob
[
c
]
*
(
probs_nb_cur
[
l_plus
]
+=
score
*
prob
_c
*
(
probs_b_prev
[
l
]
+
probs_nb_prev
[
l
])
else
:
probs_nb_cur
[
l_plus
]
+=
prob
[
c
]
*
(
probs_nb_cur
[
l_plus
]
+=
prob
_c
*
(
probs_b_prev
[
l
]
+
probs_nb_prev
[
l
])
# add l_plus into prefix_set_next
prefix_set_next
[
l_plus
]
=
probs_nb_cur
[
...
...
@@ -203,6 +236,7 @@ def ctc_beam_search_decoder_nproc(probs_split,
beam_size
,
vocabulary
,
blank_id
=
0
,
cutoff_prob
=
1.0
,
ext_scoring_func
=
None
,
num_processes
=
None
):
'''
...
...
@@ -216,16 +250,19 @@ def ctc_beam_search_decoder_nproc(probs_split,
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param blank_id: ID of blank, default 0.
:type blank_id: int
:param cutoff_prob: Cutoff probability in pruning,
default 0, no pruning.
:type cutoff_prob: float
:param ext_scoring_func: External defined scoring function for
partially decoded sentence, e.g. word count
and language model.
:type external_scoring_function: function
:param blank_id: id of blank, default 0.
:type blank_id: int
:param num_processes: Number of processes, default None, equal to the
number of CPUs.
:type num_processes: int
:return: Decoding log probabilit
y and result string
.
:return: Decoding log probabilit
ies and result sentences in descending order
.
:rtype: list
'''
...
...
@@ -243,7 +280,8 @@ def ctc_beam_search_decoder_nproc(probs_split,
pool
=
multiprocessing
.
Pool
(
processes
=
num_processes
)
results
=
[]
for
i
,
probs_list
in
enumerate
(
probs_split
):
args
=
(
probs_list
,
beam_size
,
vocabulary
,
blank_id
,
None
,
nproc
)
args
=
(
probs_list
,
beam_size
,
vocabulary
,
blank_id
,
cutoff_prob
,
None
,
nproc
)
results
.
append
(
pool
.
apply_async
(
ctc_beam_search_decoder
,
args
))
pool
.
close
()
...
...
evaluate.py
0 → 100644
浏览文件 @
ff01d048
"""
Evaluation for a simplifed version of Baidu DeepSpeech2 model.
"""
import
paddle.v2
as
paddle
import
distutils.util
import
argparse
import
gzip
from
audio_data_utils
import
DataGenerator
from
model
import
deep_speech2
from
decoder
import
*
from
error_rate
import
wer
parser
=
argparse
.
ArgumentParser
(
description
=
'Simplified version of DeepSpeech2 evaluation.'
)
parser
.
add_argument
(
"--num_samples"
,
default
=
100
,
type
=
int
,
help
=
"Number of samples for evaluation. (default: %(default)s)"
)
parser
.
add_argument
(
"--num_conv_layers"
,
default
=
2
,
type
=
int
,
help
=
"Convolution layer number. (default: %(default)s)"
)
parser
.
add_argument
(
"--num_rnn_layers"
,
default
=
3
,
type
=
int
,
help
=
"RNN layer number. (default: %(default)s)"
)
parser
.
add_argument
(
"--rnn_layer_size"
,
default
=
512
,
type
=
int
,
help
=
"RNN layer cell number. (default: %(default)s)"
)
parser
.
add_argument
(
"--use_gpu"
,
default
=
True
,
type
=
distutils
.
util
.
strtobool
,
help
=
"Use gpu or not. (default: %(default)s)"
)
parser
.
add_argument
(
"--decode_method"
,
default
=
'beam_search_nproc'
,
type
=
str
,
help
=
"Method for ctc decoding, best_path, "
"beam_search or beam_search_nproc. (default: %(default)s)"
)
parser
.
add_argument
(
"--language_model_path"
,
default
=
"./data/1Billion.klm"
,
type
=
str
,
help
=
"Path for language model. (default: %(default)s)"
)
parser
.
add_argument
(
"--alpha"
,
default
=
0.26
,
type
=
float
,
help
=
"Parameter associated with language model. (default: %(default)f)"
)
parser
.
add_argument
(
"--beta"
,
default
=
0.1
,
type
=
float
,
help
=
"Parameter associated with word count. (default: %(default)f)"
)
parser
.
add_argument
(
"--cutoff_prob"
,
default
=
0.99
,
type
=
float
,
help
=
"The cutoff probability of pruning"
"in beam search. (default: %(default)f)"
)
parser
.
add_argument
(
"--beam_size"
,
default
=
500
,
type
=
int
,
help
=
"Width for beam search decoding. (default: %(default)d)"
)
parser
.
add_argument
(
"--normalizer_manifest_path"
,
default
=
'data/manifest.libri.train-clean-100'
,
type
=
str
,
help
=
"Manifest path for normalizer. (default: %(default)s)"
)
parser
.
add_argument
(
"--decode_manifest_path"
,
default
=
'data/manifest.libri.test-clean'
,
type
=
str
,
help
=
"Manifest path for decoding. (default: %(default)s)"
)
parser
.
add_argument
(
"--model_filepath"
,
default
=
'./params.tar.gz'
,
type
=
str
,
help
=
"Model filepath. (default: %(default)s)"
)
parser
.
add_argument
(
"--vocab_filepath"
,
default
=
'data/eng_vocab.txt'
,
type
=
str
,
help
=
"Vocabulary filepath. (default: %(default)s)"
)
args
=
parser
.
parse_args
()
def
evaluate
():
"""
Evaluate on whole test data for DeepSpeech2.
"""
# initialize data generator
data_generator
=
DataGenerator
(
vocab_filepath
=
args
.
vocab_filepath
,
normalizer_manifest_path
=
args
.
normalizer_manifest_path
,
normalizer_num_samples
=
200
,
max_duration
=
20.0
,
min_duration
=
0.0
,
stride_ms
=
10
,
window_ms
=
20
)
# create network config
dict_size
=
data_generator
.
vocabulary_size
()
vocab_list
=
data_generator
.
vocabulary_list
()
audio_data
=
paddle
.
layer
.
data
(
name
=
"audio_spectrogram"
,
height
=
161
,
width
=
2000
,
type
=
paddle
.
data_type
.
dense_vector
(
322000
))
text_data
=
paddle
.
layer
.
data
(
name
=
"transcript_text"
,
type
=
paddle
.
data_type
.
integer_value_sequence
(
dict_size
))
output_probs
=
deep_speech2
(
audio_data
=
audio_data
,
text_data
=
text_data
,
dict_size
=
dict_size
,
num_conv_layers
=
args
.
num_conv_layers
,
num_rnn_layers
=
args
.
num_rnn_layers
,
rnn_size
=
args
.
rnn_layer_size
,
is_inference
=
True
)
# load parameters
parameters
=
paddle
.
parameters
.
Parameters
.
from_tar
(
gzip
.
open
(
args
.
model_filepath
))
# prepare infer data
feeding
=
data_generator
.
data_name_feeding
()
test_batch_reader
=
data_generator
.
batch_reader_creator
(
manifest_path
=
args
.
decode_manifest_path
,
batch_size
=
args
.
num_samples
,
padding_to
=
2000
,
flatten
=
True
,
sort_by_duration
=
False
,
shuffle
=
False
)
# define inferer
inferer
=
paddle
.
inference
.
Inference
(
output_layer
=
output_probs
,
parameters
=
parameters
)
# initialize external scorer for beam search decoding
if
args
.
decode_method
==
'beam_search'
or
\
args
.
decode_method
==
'beam_search_nproc'
:
ext_scorer
=
Scorer
(
args
.
alpha
,
args
.
beta
,
args
.
language_model_path
)
wer_counter
,
wer_sum
=
0
,
0.0
for
infer_data
in
test_batch_reader
():
# run inference
infer_results
=
inferer
.
infer
(
input
=
infer_data
)
num_steps
=
len
(
infer_results
)
/
len
(
infer_data
)
probs_split
=
[
infer_results
[
i
*
num_steps
:(
i
+
1
)
*
num_steps
]
for
i
in
xrange
(
0
,
len
(
infer_data
))
]
# decode and print
# best path decode
if
args
.
decode_method
==
"best_path"
:
for
i
,
probs
in
enumerate
(
probs_split
):
output_transcription
=
ctc_decode
(
probs_seq
=
probs
,
vocabulary
=
vocab_list
,
method
=
"best_path"
)
target_transcription
=
''
.
join
(
[
vocab_list
[
index
]
for
index
in
infer_data
[
i
][
1
]])
wer_sum
+=
wer
(
target_transcription
,
output_transcription
)
wer_counter
+=
1
# beam search decode in single process
elif
args
.
decode_method
==
"beam_search"
:
for
i
,
probs
in
enumerate
(
probs_split
):
target_transcription
=
''
.
join
(
[
vocab_list
[
index
]
for
index
in
infer_data
[
i
][
1
]])
beam_search_result
=
ctc_beam_search_decoder
(
probs_seq
=
probs
,
vocabulary
=
vocab_list
,
beam_size
=
args
.
beam_size
,
blank_id
=
len
(
vocab_list
),
ext_scoring_func
=
ext_scorer
,
cutoff_prob
=
args
.
cutoff_prob
,
)
wer_sum
+=
wer
(
target_transcription
,
beam_search_result
[
0
][
1
])
wer_counter
+=
1
# beam search using multiple processes
elif
args
.
decode_method
==
"beam_search_nproc"
:
beam_search_nproc_results
=
ctc_beam_search_decoder_nproc
(
probs_split
=
probs_split
,
vocabulary
=
vocab_list
,
beam_size
=
args
.
beam_size
,
blank_id
=
len
(
vocab_list
),
ext_scoring_func
=
ext_scorer
,
cutoff_prob
=
args
.
cutoff_prob
,
)
for
i
,
beam_search_result
in
enumerate
(
beam_search_nproc_results
):
target_transcription
=
''
.
join
(
[
vocab_list
[
index
]
for
index
in
infer_data
[
i
][
1
]])
wer_sum
+=
wer
(
target_transcription
,
beam_search_result
[
0
][
1
])
wer_counter
+=
1
else
:
raise
ValueError
(
"Decoding method [%s] is not supported."
%
method
)
print
(
"Cur WER = %f"
%
(
wer_sum
/
wer_counter
))
print
(
"Final WER = %f"
%
(
wer_sum
/
wer_counter
))
def
main
():
paddle
.
init
(
use_gpu
=
args
.
use_gpu
,
trainer_count
=
1
)
evaluate
()
if
__name__
==
'__main__'
:
main
()
infer.py
浏览文件 @
ff01d048
...
...
@@ -9,14 +9,14 @@ import gzip
from
audio_data_utils
import
DataGenerator
from
model
import
deep_speech2
from
decoder
import
*
import
kenlm
from
error_rate
import
wer
import
time
parser
=
argparse
.
ArgumentParser
(
description
=
'Simplified version of DeepSpeech2 inference.'
)
parser
.
add_argument
(
"--num_samples"
,
default
=
10
,
default
=
10
0
,
type
=
int
,
help
=
"Number of samples for inference. (default: %(default)s)"
)
parser
.
add_argument
(
...
...
@@ -46,7 +46,7 @@ parser.add_argument(
help
=
"Manifest path for normalizer. (default: %(default)s)"
)
parser
.
add_argument
(
"--decode_manifest_path"
,
default
=
'data/manifest.libri.test-
clean
'
,
default
=
'data/manifest.libri.test-
100sample
'
,
type
=
str
,
help
=
"Manifest path for decoding. (default: %(default)s)"
)
parser
.
add_argument
(
...
...
@@ -63,11 +63,13 @@ parser.add_argument(
"--decode_method"
,
default
=
'beam_search_nproc'
,
type
=
str
,
help
=
"Method for ctc decoding, best_path, beam_search or beam_search_nproc. (default: %(default)s)"
)
help
=
"Method for ctc decoding:"
" best_path,"
" beam_search, "
" or beam_search_nproc. (default: %(default)s)"
)
parser
.
add_argument
(
"--beam_size"
,
default
=
50
,
default
=
50
0
,
type
=
int
,
help
=
"Width for beam search decoding. (default: %(default)d)"
)
parser
.
add_argument
(
...
...
@@ -82,14 +84,20 @@ parser.add_argument(
help
=
"Path for language model. (default: %(default)s)"
)
parser
.
add_argument
(
"--alpha"
,
default
=
0.
0
,
default
=
0.
26
,
type
=
float
,
help
=
"Parameter associated with language model. (default: %(default)f)"
)
parser
.
add_argument
(
"--beta"
,
default
=
0.
0
,
default
=
0.
1
,
type
=
float
,
help
=
"Parameter associated with word count. (default: %(default)f)"
)
parser
.
add_argument
(
"--cutoff_prob"
,
default
=
0.99
,
type
=
float
,
help
=
"The cutoff probability of pruning"
"in beam search. (default: %(default)f)"
)
args
=
parser
.
parse_args
()
...
...
@@ -154,6 +162,7 @@ def infer():
## decode and print
# best path decode
wer_sum
,
wer_counter
=
0
,
0
total_time
=
0.0
if
args
.
decode_method
==
"best_path"
:
for
i
,
probs
in
enumerate
(
probs_split
):
target_transcription
=
''
.
join
(
...
...
@@ -177,11 +186,12 @@ def infer():
probs_seq
=
probs
,
vocabulary
=
vocab_list
,
beam_size
=
args
.
beam_size
,
ext_scoring_func
=
ext_scorer
,
blank_id
=
len
(
vocab_list
))
blank_id
=
len
(
vocab_list
),
cutoff_prob
=
args
.
cutoff_prob
,
ext_scoring_func
=
ext_scorer
,
)
print
(
"
\n
Target Transcription:
\t
%s"
%
target_transcription
)
for
index
in
range
(
args
.
num_results_per_sample
):
for
index
in
x
range
(
args
.
num_results_per_sample
):
result
=
beam_search_result
[
index
]
#output: index, log prob, beam result
print
(
"Beam %d: %f
\t
%s"
%
(
index
,
result
[
0
],
result
[
1
]))
...
...
@@ -190,21 +200,21 @@ def infer():
wer_counter
+=
1
print
(
"cur wer = %f , average wer = %f"
%
(
wer_cur
,
wer_sum
/
wer_counter
))
# beam search using multiple processes
elif
args
.
decode_method
==
"beam_search_nproc"
:
ext_scorer
=
Scorer
(
args
.
alpha
,
args
.
beta
,
args
.
language_model_path
)
beam_search_nproc_results
=
ctc_beam_search_decoder_nproc
(
probs_split
=
probs_split
,
vocabulary
=
vocab_list
,
beam_size
=
args
.
beam_size
,
ext_scoring_func
=
ext_scorer
,
blank_id
=
len
(
vocab_list
))
blank_id
=
len
(
vocab_list
),
cutoff_prob
=
args
.
cutoff_prob
,
ext_scoring_func
=
ext_scorer
,
)
for
i
,
beam_search_result
in
enumerate
(
beam_search_nproc_results
):
target_transcription
=
''
.
join
(
[
vocab_list
[
index
]
for
index
in
infer_data
[
i
][
1
]])
print
(
"
\n
Target Transcription:
\t
%s"
%
target_transcription
)
for
index
in
range
(
args
.
num_results_per_sample
):
for
index
in
x
range
(
args
.
num_results_per_sample
):
result
=
beam_search_result
[
index
]
#output: index, log prob, beam result
print
(
"Beam %d: %f
\t
%s"
%
(
index
,
result
[
0
],
result
[
1
]))
...
...
tune.py
浏览文件 @
ff01d048
"""
Tune parameters
for beam search decoder in Deep Speech 2.
Parameters tuning
for beam search decoder in Deep Speech 2.
"""
import
paddle.v2
as
paddle
...
...
@@ -12,7 +12,7 @@ from decoder import *
from
error_rate
import
wer
parser
=
argparse
.
ArgumentParser
(
description
=
'Parameters tuning
script
for ctc beam search decoder in Deep Speech 2.'
description
=
'Parameters tuning for ctc beam search decoder in Deep Speech 2.'
)
parser
.
add_argument
(
"--num_samples"
,
...
...
@@ -82,34 +82,40 @@ parser.add_argument(
help
=
"Path for language model. (default: %(default)s)"
)
parser
.
add_argument
(
"--alpha_from"
,
default
=
0.
0
,
default
=
0.
1
,
type
=
float
,
help
=
"Where alpha starts from
, <= alpha_to
. (default: %(default)f)"
)
help
=
"Where alpha starts from. (default: %(default)f)"
)
parser
.
add_argument
(
"--
alpha_stride
"
,
default
=
0.001
,
type
=
floa
t
,
help
=
"
Step length for varying alpha. (default: %(default)f
)"
)
"--
num_alphas
"
,
default
=
14
,
type
=
in
t
,
help
=
"
Number of candidate alphas. (default: %(default)d
)"
)
parser
.
add_argument
(
"--alpha_to"
,
default
=
0.
01
,
default
=
0.
36
,
type
=
float
,
help
=
"Where alpha ends with
, >= alpha_from
. (default: %(default)f)"
)
help
=
"Where alpha ends with. (default: %(default)f)"
)
parser
.
add_argument
(
"--beta_from"
,
default
=
0.0
,
default
=
0.0
5
,
type
=
float
,
help
=
"Where beta starts from
, <= beta_to
. (default: %(default)f)"
)
help
=
"Where beta starts from. (default: %(default)f)"
)
parser
.
add_argument
(
"--
beta_stride
"
,
default
=
0.01
,
"--
num_betas
"
,
default
=
20
,
type
=
float
,
help
=
"
Step length for varying beta. (default: %(default)f
)"
)
help
=
"
Number of candidate betas. (default: %(default)d
)"
)
parser
.
add_argument
(
"--beta_to"
,
default
=
0
.0
,
default
=
1
.0
,
type
=
float
,
help
=
"Where beta ends with, >= beta_from. (default: %(default)f)"
)
help
=
"Where beta ends with. (default: %(default)f)"
)
parser
.
add_argument
(
"--cutoff_prob"
,
default
=
0.99
,
type
=
float
,
help
=
"The cutoff probability of pruning"
"in beam search. (default: %(default)f)"
)
args
=
parser
.
parse_args
()
...
...
@@ -118,15 +124,11 @@ def tune():
Tune parameters alpha and beta on one minibatch.
"""
if
not
args
.
alpha_from
<=
args
.
alpha_to
:
raise
ValueError
(
"alpha_from <= alpha_to doesn't satisfy!"
)
if
not
args
.
alpha_stride
>
0
:
raise
ValueError
(
"alpha_stride shouldn't be negative!"
)
if
not
args
.
num_alphas
>=
0
:
raise
ValueError
(
"num_alphas must be non-negative!"
)
if
not
args
.
beta_from
<=
args
.
beta_to
:
raise
ValueError
(
"beta_from <= beta_to doesn't satisfy!"
)
if
not
args
.
beta_stride
>
0
:
raise
ValueError
(
"beta_stride shouldn't be negative!"
)
if
not
args
.
num_betas
>=
0
:
raise
ValueError
(
"num_betas must be non-negative!"
)
# initialize data generator
data_generator
=
DataGenerator
(
...
...
@@ -171,6 +173,7 @@ def tune():
flatten
=
True
,
sort_by_duration
=
False
,
shuffle
=
False
)
# get one batch data for tuning
infer_data
=
test_batch_reader
().
next
()
# run inference
...
...
@@ -182,11 +185,12 @@ def tune():
for
i
in
xrange
(
0
,
len
(
infer_data
))
]
cand_alpha
=
np
.
arange
(
args
.
alpha_from
,
args
.
alpha_to
+
args
.
alpha_stride
,
args
.
alpha_stride
)
cand_beta
=
np
.
arange
(
args
.
beta_from
,
args
.
beta_to
+
args
.
beta_stride
,
args
.
beta_stride
)
params_grid
=
[(
alpha
,
beta
)
for
alpha
in
cand_alpha
for
beta
in
cand_beta
]
# create grid for search
cand_alphas
=
np
.
linspace
(
args
.
alpha_from
,
args
.
alpha_to
,
args
.
num_alphas
)
cand_betas
=
np
.
linspace
(
args
.
beta_from
,
args
.
beta_to
,
args
.
num_betas
)
params_grid
=
[(
alpha
,
beta
)
for
alpha
in
cand_alphas
for
beta
in
cand_betas
]
## tune parameters in loop
for
(
alpha
,
beta
)
in
params_grid
:
wer_sum
,
wer_counter
=
0
,
0
...
...
@@ -200,8 +204,9 @@ def tune():
probs_seq
=
probs
,
vocabulary
=
vocab_list
,
beam_size
=
args
.
beam_size
,
ext_scoring_func
=
ext_scorer
,
blank_id
=
len
(
vocab_list
))
blank_id
=
len
(
vocab_list
),
cutoff_prob
=
args
.
cutoff_prob
,
ext_scoring_func
=
ext_scorer
,
)
wer_sum
+=
wer
(
target_transcription
,
beam_search_result
[
0
][
1
])
wer_counter
+=
1
# beam search using multiple processes
...
...
@@ -210,9 +215,9 @@ def tune():
probs_split
=
probs_split
,
vocabulary
=
vocab_list
,
beam_size
=
args
.
beam_size
,
ext_scoring_func
=
ext_scorer
,
cutoff_prob
=
args
.
cutoff_prob
,
blank_id
=
len
(
vocab_list
),
num_processes
=
1
)
ext_scoring_func
=
ext_scorer
,
)
for
i
,
beam_search_result
in
enumerate
(
beam_search_nproc_results
):
target_transcription
=
''
.
join
(
[
vocab_list
[
index
]
for
index
in
infer_data
[
i
][
1
]])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录