Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
f6d820ed
M
models
项目概览
PaddlePaddle
/
models
大约 2 年 前同步成功
通知
232
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f6d820ed
编写于
5月 30, 2017
作者:
X
Xinghai Sun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor data utils into a class and add feature normalization.
上级
f33f7420
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
389 addition
and
208 deletion
+389
-208
deep_speech_2/audio_data_utils.py
deep_speech_2/audio_data_utils.py
+340
-172
deep_speech_2/train.py
deep_speech_2/train.py
+49
-36
未找到文件。
deep_speech_2/audio_data_utils.py
浏览文件 @
f6d820ed
"""
Audio data preprocessing tools and reader creators.
Providing basic audio data preprocessing pipeline, and offering
both instance-level and batch-level data reader interfaces.
"""
import
paddle.v2
as
paddle
import
logging
...
...
@@ -9,20 +10,127 @@ import soundfile
import
numpy
as
np
import
os
# TODO: add z-score normalization.
RANDOM_SEED
=
0
logger
=
logging
.
getLogger
(
__name__
)
ENGLISH_CHAR_VOCAB_FILEPATH
=
"eng_vocab.txt"
logger
=
logging
.
getLogger
(
__name__
)
class
DataGenerator
(
object
):
"""
DataGenerator provides basic audio data preprocessing pipeline, and offer
both instance-level and batch-level data reader interfaces.
Normalized FFT are used as audio features here.
:param vocab_filepath: Vocabulary file path for indexing tokenized
transcriptions.
:type vocab_filepath: basestring
:param normalizer_manifest_path: Manifest filepath for collecting feature
normalization statistics, e.g. mean, std.
:type normalizer_manifest_path: basestring
:param normalizer_num_samples: Number of instances sampled for collecting
feature normalization statistics.
Default is 100.
:type normalizer_num_samples: int
:param max_duration: Audio clips with duration (in seconds) greater than
this will be discarded. Default is 20.0.
:type max_duration: float
:param min_duration: Audio clips with duration (in seconds) smaller than
this will be discarded. Default is 0.0.
:type min_duration: float
:param stride_ms: Striding size (in milliseconds) for generating frames.
Default is 10.0.
:type stride_ms: float
:param window_ms: Window size (in milliseconds) for frames. Default is 20.0.
:type window_ms: float
:param max_frequency: Maximun frequency for FFT features. FFT features of
frequency larger than this will be discarded.
If set None, all features will be kept.
Default is None.
:type max_frequency: float
"""
def
__init__
(
self
,
vocab_filepath
,
normalizer_manifest_path
,
normalizer_num_samples
=
100
,
max_duration
=
20.0
,
min_duration
=
0.0
,
stride_ms
=
10.0
,
window_ms
=
20.0
,
max_frequency
=
None
):
self
.
__max_duration__
=
max_duration
self
.
__min_duration__
=
min_duration
self
.
__stride_ms__
=
stride_ms
self
.
__window_ms__
=
window_ms
self
.
__max_frequency__
=
max_frequency
self
.
__random__
=
random
.
Random
(
RANDOM_SEED
)
# load vocabulary (dictionary)
self
.
__vocab_dict__
,
self
.
__vocab_list__
=
\
self
.
__load_vocabulary_from_file__
(
vocab_filepath
)
# collect normalizer statistics
self
.
__mean__
,
self
.
__std__
=
self
.
__collect_normalizer_statistics__
(
manifest_path
=
normalizer_manifest_path
,
num_samples
=
normalizer_num_samples
)
def
__audio_featurize__
(
self
,
audio_filename
):
"""
Preprocess audio data, including feature extraction, normalization etc..
"""
features
=
self
.
__audio_basic_featurize__
(
audio_filename
)
return
self
.
__normalize__
(
features
)
def
__text_featurize__
(
self
,
text
):
"""
Preprocess text data, including tokenizing and token indexing etc..
"""
return
self
.
__convert_text_to_char_index__
(
text
=
text
,
vocabulary
=
self
.
__vocab_dict__
)
def
__audio_basic_featurize__
(
self
,
audio_filename
):
"""
Compute basic (without normalization etc.) features for audio data.
"""
return
self
.
__spectrogram_from_file__
(
filename
=
audio_filename
,
stride_ms
=
self
.
__stride_ms__
,
window_ms
=
self
.
__window_ms__
,
max_freq
=
self
.
__max_frequency__
)
def
__collect_normalizer_statistics__
(
self
,
manifest_path
,
num_samples
=
100
):
"""
Compute feature normalization statistics, i.e. mean and stddev.
"""
# read manifest
manifest
=
self
.
__read_manifest__
(
manifest_path
=
manifest_path
,
max_duration
=
self
.
__max_duration__
,
min_duration
=
self
.
__min_duration__
)
# sample for statistics
sampled_manifest
=
self
.
__random__
.
sample
(
manifest
,
num_samples
)
# extract spectrogram feature
features
=
[]
for
instance
in
sampled_manifest
:
spectrogram
=
self
.
__audio_basic_featurize__
(
instance
[
"audio_filepath"
])
features
.
append
(
spectrogram
)
features
=
np
.
hstack
(
features
)
mean
=
np
.
mean
(
features
,
axis
=
1
).
reshape
([
-
1
,
1
])
std
=
np
.
std
(
features
,
axis
=
1
).
reshape
([
-
1
,
1
])
return
mean
,
std
def
spectrogram_from_file
(
filename
,
stride_ms
=
10
,
window_ms
=
20
,
def
__normalize__
(
self
,
features
,
eps
=
1e-14
):
"""
Normalize features to be of zero mean and unit stddev.
"""
return
(
features
-
self
.
__mean__
)
/
(
self
.
__std__
+
eps
)
def
__spectrogram_from_file__
(
self
,
filename
,
stride_ms
=
10.0
,
window_ms
=
20.0
,
max_freq
=
None
,
eps
=
1e-14
):
"""
Calculate the log of linear spectrogram from FFT energy
Laod audio data and calculate the log of spectrogram by FFT.
Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
"""
audio
,
sample_rate
=
soundfile
.
read
(
filename
)
...
...
@@ -34,10 +142,11 @@ def spectrogram_from_file(filename,
raise
ValueError
(
"max_freq must be greater than half of "
"sample rate."
)
if
stride_ms
>
window_ms
:
raise
ValueError
(
"Stride size must not be greater than window size."
)
raise
ValueError
(
"Stride size must not be greater than "
"window size."
)
stride_size
=
int
(
0.001
*
sample_rate
*
stride_ms
)
window_size
=
int
(
0.001
*
sample_rate
*
window_ms
)
spectrogram
,
freqs
=
extract_spectrogram
(
spectrogram
,
freqs
=
self
.
__extract_spectrogram__
(
audio
,
window_size
=
window_size
,
stride_size
=
stride_size
,
...
...
@@ -45,10 +154,10 @@ def spectrogram_from_file(filename,
ind
=
np
.
where
(
freqs
<=
max_freq
)[
0
][
-
1
]
+
1
return
np
.
log
(
spectrogram
[:
ind
,
:]
+
eps
)
def
extract_spectrogram
(
samples
,
window_size
,
stride_size
,
sample_rate
):
def
__extract_spectrogram__
(
self
,
samples
,
window_size
,
stride_size
,
sample_rate
):
"""
Compute the spectrogram for a real discrete
signal.
Compute the spectrogram by FFT for a discrete real
signal.
Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
"""
# extract strided windows
...
...
@@ -60,7 +169,7 @@ def extract_spectrogram(samples, window_size, stride_size, sample_rate):
samples
,
shape
=
nshape
,
strides
=
nstrides
)
assert
np
.
all
(
windows
[:,
1
]
==
samples
[
stride_size
:(
stride_size
+
window_size
)])
# window weighting, compute
squared Fast Fourier Transform (fft), scaling
# window weighting,
squared Fast Fourier Transform (fft), scaling
weighting
=
np
.
hanning
(
window_size
)[:,
None
]
fft
=
np
.
fft
.
rfft
(
windows
*
weighting
,
axis
=
0
)
fft
=
np
.
absolute
(
fft
)
**
2
...
...
@@ -71,12 +180,12 @@ def extract_spectrogram(samples, window_size, stride_size, sample_rate):
freqs
=
float
(
sample_rate
)
/
window_size
*
np
.
arange
(
fft
.
shape
[
0
])
return
fft
,
freqs
def
vocabulary_from_file
(
vocabulary_path
):
def
__load_vocabulary_from_file__
(
self
,
vocabulary_path
):
"""
Load vocabulary from file.
"""
if
os
.
path
.
exists
(
vocabulary_path
):
if
not
os
.
path
.
exists
(
vocabulary_path
):
raise
ValueError
(
"Vocabulary file %s not found."
,
vocabulary_path
)
vocab_lines
=
[]
with
open
(
vocabulary_path
,
'r'
)
as
file
:
vocab_lines
.
extend
(
file
.
readlines
())
...
...
@@ -84,56 +193,76 @@ def vocabulary_from_file(vocabulary_path):
vocab_dict
=
dict
(
[(
token
,
id
)
for
(
id
,
token
)
in
enumerate
(
vocab_list
)])
return
vocab_dict
,
vocab_list
else
:
raise
ValueError
(
"Vocabulary file %s not found."
,
vocabulary_path
)
def
get_vocabulary_size
():
def
__convert_text_to_char_index__
(
self
,
text
,
vocabulary
):
"""
Get vocabulary size
.
Convert text string to a list of character index integers
.
"""
vocab_dict
,
_
=
vocabulary_from_file
(
ENGLISH_CHAR_VOCAB_FILEPATH
)
return
len
(
vocab_dict
)
return
[
vocabulary
[
w
]
for
w
in
text
]
def
get_vocabulary
(
):
def
__read_manifest__
(
self
,
manifest_path
,
max_duration
,
min_duration
):
"""
Get vocabulary
.
Load and parse manifest file
.
"""
return
vocabulary_from_file
(
ENGLISH_CHAR_VOCAB_FILEPATH
)
manifest
=
[]
for
json_line
in
open
(
manifest_path
):
try
:
json_data
=
json
.
loads
(
json_line
)
except
Exception
as
e
:
raise
ValueError
(
"Error reading manifest: %s"
%
str
(
e
))
if
(
json_data
[
"duration"
]
<=
max_duration
and
json_data
[
"duration"
]
>=
min_duration
):
manifest
.
append
(
json_data
)
return
manifest
def
parse_transcript
(
text
,
vocabulary
):
def
__padding_batch__
(
self
,
batch
,
padding_to
=-
1
,
flatten
=
False
):
"""
Convert the transcript text string to list of token index integers.
"""
return
[
vocabulary
[
w
]
for
w
in
text
]
Padding audio part of features (only in the time axis -- column axis)
with zeros, to make each instance in the batch share the same
audio feature shape.
If `padding_to` is set -1, the maximun column numbers in the batch will
be used as the target size. Otherwise, `padding_to` will be the target
size. Default is -1.
If `flatten` is set True, audio data will be flatten to be a 1-dim
ndarray. Default is False.
"""
new_batch
=
[]
# get target shape
max_length
=
max
([
audio
.
shape
[
1
]
for
audio
,
text
in
batch
])
if
padding_to
!=
-
1
:
if
padding_to
<
max_length
:
raise
ValueError
(
"If padding_to is not -1, it should be greater"
" or equal to the original instance length."
)
max_length
=
padding_to
# padding
for
audio
,
text
in
batch
:
padded_audio
=
np
.
zeros
([
audio
.
shape
[
0
],
max_length
])
padded_audio
[:,
:
audio
.
shape
[
1
]]
=
audio
if
flatten
:
padded_audio
=
padded_audio
.
flatten
()
new_batch
.
append
((
padded_audio
,
text
))
return
new_batch
def
reader_creator
(
manifest_path
,
def
instance_reader_creator
(
self
,
manifest_path
,
sort_by_duration
=
True
,
shuffle
=
False
,
max_duration
=
10.0
,
min_duration
=
0.0
):
shuffle
=
False
):
"""
Audio data reader creator.
Instance reader creator for audio data. Creat a callable function to
produce instances of data.
Instance: a tuple of a numpy ndarray of audio spectrogram and a list of
tokeniz
ed transcription text.
tokenized and index
ed transcription text.
:param manifest_path: Filepath for Manifest of
audio clip files.
:param manifest_path: Filepath of manifest for
audio clip files.
:type manifest_path: basestring
:param sort_by_duration: Sort the audio clips by duration if set True.
For SortaGrad
.
:param sort_by_duration: Sort the audio clips by duration if set True
(for SortaGrad)
.
:type sort_by_duration: bool
:param shuffle: Shuffle the audio clips if set True.
:type shuffle: bool
:param max_duration: Audio clips with duration (in seconds) greater than
this will be discarded.
:type max_duration: float
:param min_duration: Audio clips with duration (in seconds) smaller than
this will be discarded.
:type min_duration: float
:return: Data reader function.
:rtype: callable
"""
...
...
@@ -141,75 +270,114 @@ def reader_creator(manifest_path,
sort_by_duration
=
False
logger
.
warn
(
"When shuffle set to true, "
"sort_by_duration is forced to set False."
)
vocab_dict
,
_
=
vocabulary_from_file
(
ENGLISH_CHAR_VOCAB_FILEPATH
)
def
reader
():
# read manifest
manifest_data
=
[]
for
json_line
in
open
(
manifest_path
):
try
:
json_data
=
json
.
loads
(
json_line
)
except
Exception
as
e
:
raise
ValueError
(
"Error reading manifest: %s"
%
str
(
e
))
if
(
json_data
[
"duration"
]
<=
max_duration
and
json_data
[
"duration"
]
>=
min_duration
):
manifest_data
.
append
(
json_data
)
manifest
=
self
.
__read_manifest__
(
manifest_path
=
manifest_path
,
max_duration
=
self
.
__max_duration__
,
min_duration
=
self
.
__min_duration__
)
# sort (by duration) or shuffle manifest
if
sort_by_duration
:
manifest_data
.
sort
(
key
=
lambda
x
:
x
[
"duration"
])
manifest
.
sort
(
key
=
lambda
x
:
x
[
"duration"
])
if
shuffle
:
random
.
shuffle
(
manifest_data
)
self
.
__random__
.
shuffle
(
manifest
)
# extract spectrogram feature
for
instance
in
manifest_data
:
spectrogram
=
spectrogram_from_file
(
instance
[
"audio_filepath"
])
text
=
parse_transcript
(
instance
[
"text"
],
vocab_dict
)
yield
(
spectrogram
,
text
)
for
instance
in
manifest
:
spectrogram
=
self
.
__audio_featurize__
(
instance
[
"audio_filepath"
])
transcript
=
self
.
__text_featurize__
(
instance
[
"text"
])
yield
(
spectrogram
,
transcript
)
return
reader
def
padding_batch_reader
(
batch_reader
,
padding
=
[
-
1
,
-
1
],
flatten
=
True
):
def
batch_reader_creator
(
self
,
manifest_path
,
batch_size
,
padding_to
=-
1
,
flatten
=
False
,
sort_by_duration
=
True
,
shuffle
=
False
):
"""
Padding for batches. Return a batch reader.
Batch data reader creator for audio data. Creat a callable function to
produce batches of data.
Each instance in a batch will be padded to be of a same target shape.
The target shape is the largest shape among all the batch instances and
'padding' argument. Therefore, if padding is set [-1, -1], instance will be
padded to have the same shape just within each batch and the shape will
be different across batches; if padding is set
[VERY_LARGE_NUM, VERY_LARGE_NUM], instances in all batches will be padded to
have the same shape of [VERY_LARGE_NUM, VERY_LARGE_NUM].
Audio features will be padded with zeros to make each instance in the
batch to share the same audio feature shape.
:param batch_reader: Input batch reader.
:type batch_reader: callable
:param padding: Padding pattern. Details please refer to the above.
:type padding: list
:param flatten: Flatten the tensor to be one dimension.
:param manifest_path: Filepath of manifest for audio clip files.
:type manifest_path: basestring
:param batch_size: Instance number in a batch.
:type batch_size: int
:param padding_to: If set -1, the maximun column numbers in the batch
will be used as the target size for padding.
Otherwise, `padding_to` will be the target size.
Default is -1.
:type padding_to: int
:param flatten: If set True, audio data will be flatten to be a 1-dim
ndarray. Otherwise, 2-dim ndarray. Default is False.
:type flatten: bool
:return: Batch reader function.
:param sort_by_duration: Sort the audio clips by duration if set True
(for SortaGrad).
:type sort_by_duration: bool
:param shuffle: Shuffle the audio clips if set True.
:type shuffle: bool
:return: Batch reader function, producing batches of data when called.
:rtype: callable
"""
def
padding_batch
(
batch
):
new_batch
=
[]
# get target shape within batch
nshape_list
=
[
padding
]
for
audio
,
text
in
batch
:
nshape_list
.
append
(
audio
.
shape
)
target_shape
=
np
.
array
(
nshape_list
).
max
(
axis
=
0
)
# padding
for
audio
,
text
in
batch
:
pad_shape
=
target_shape
-
audio
.
shape
assert
np
.
all
(
pad_shape
>=
0
)
padded_audio
=
np
.
pad
(
audio
,
[(
0
,
pad_shape
[
0
]),
(
0
,
pad_shape
[
1
])],
mode
=
"constant"
)
if
flatten
:
padded_audio
=
padded_audio
.
flatten
()
new_batch
.
append
((
padded_audio
,
text
))
return
new_batch
def
batch_reader
():
instance_reader
=
self
.
instance_reader_creator
(
manifest_path
=
manifest_path
,
sort_by_duration
=
sort_by_duration
,
shuffle
=
shuffle
)
batch
=
[]
for
instance
in
instance_reader
():
batch
.
append
(
instance
)
if
len
(
batch
)
==
batch_size
:
yield
self
.
__padding_batch__
(
batch
,
padding_to
,
flatten
)
batch
=
[]
if
len
(
batch
)
>
0
:
yield
self
.
__padding_batch__
(
batch
,
padding_to
,
flatten
)
return
batch_reader
def
vocabulary_size
(
self
):
"""
Get vocabulary size.
def
new_batch_reader
():
for
batch
in
batch_reader
():
yield
padding_batch
(
batch
)
:return: Vocabulary size.
:rtype: int
"""
return
len
(
self
.
__vocab_list__
)
return
new_batch_reader
def
vocabulary_dict
(
self
):
"""
Get vocabulary in dict.
:return: Vocabulary in dict.
:rtype: dict
"""
return
self
.
__vocab_dict__
def
vocabulary_list
(
self
):
"""
Get vocabulary in list.
:return: Vocabulary in list
:rtype: list
"""
return
self
.
__vocab_list__
def
data_name_feeding
(
self
):
"""
Get feeddings (data field name and corresponding field id).
:return: Feeding dict.
:rtype: dict
"""
feeding
=
{
"audio_spectrogram"
:
0
,
"transcript_text"
:
1
,
}
return
feeding
deep_speech_2/train.py
浏览文件 @
f6d820ed
...
...
@@ -5,16 +5,18 @@
import
paddle.v2
as
paddle
import
argparse
import
gzip
import
time
import
sys
from
model
import
deep_speech2
import
audio_data_utils
from
audio_data_utils
import
DataGenerator
import
numpy
as
np
#TODO: add WER metric
parser
=
argparse
.
ArgumentParser
(
description
=
'Simplified version of DeepSpeech2 trainer.'
)
parser
.
add_argument
(
"--batch_size"
,
default
=
51
2
,
type
=
int
,
help
=
"Minibatch size."
)
"--batch_size"
,
default
=
3
2
,
type
=
int
,
help
=
"Minibatch size."
)
parser
.
add_argument
(
"--trainer"
,
default
=
1
,
type
=
int
,
help
=
"Trainer number."
)
parser
.
add_argument
(
"--num_passes"
,
default
=
20
,
type
=
int
,
help
=
"Training pass number."
)
...
...
@@ -23,7 +25,7 @@ parser.add_argument(
parser
.
add_argument
(
"--num_rnn_layers"
,
default
=
5
,
type
=
int
,
help
=
"RNN layer number."
)
parser
.
add_argument
(
"--rnn_layer_size"
,
default
=
256
,
type
=
int
,
help
=
"RNN layer cell number."
)
"--rnn_layer_size"
,
default
=
512
,
type
=
int
,
help
=
"RNN layer cell number."
)
parser
.
add_argument
(
"--use_gpu"
,
default
=
True
,
type
=
bool
,
help
=
"Use gpu or not."
)
parser
.
add_argument
(
...
...
@@ -37,13 +39,45 @@ def train():
"""
DeepSpeech2 training.
"""
# create data readers
data_generator
=
DataGenerator
(
vocab_filepath
=
'eng_vocab.txt'
,
normalizer_manifest_path
=
'./libri.manifest.train'
,
normalizer_num_samples
=
200
,
max_duration
=
20.0
,
min_duration
=
0.0
,
stride_ms
=
10
,
window_ms
=
20
)
train_batch_reader_sortagrad
=
data_generator
.
batch_reader_creator
(
manifest_path
=
'./libri.manifest.dev.small'
,
batch_size
=
args
.
batch_size
//
args
.
trainer
,
padding_to
=
2000
,
flatten
=
True
,
sort_by_duration
=
True
,
shuffle
=
False
)
train_batch_reader_nosortagrad
=
data_generator
.
batch_reader_creator
(
manifest_path
=
'./libri.manifest.dev.small'
,
batch_size
=
args
.
batch_size
//
args
.
trainer
,
padding_to
=
2000
,
flatten
=
True
,
sort_by_duration
=
False
,
shuffle
=
True
)
test_batch_reader
=
data_generator
.
batch_reader_creator
(
manifest_path
=
'./libri.manifest.test'
,
batch_size
=
args
.
batch_size
//
args
.
trainer
,
padding_to
=
2000
,
flatten
=
True
,
sort_by_duration
=
False
,
shuffle
=
False
)
feeding
=
data_generator
.
data_name_feeding
()
# create network config
dict_size
=
audio_data_utils
.
get_
vocabulary_size
()
dict_size
=
data_generator
.
vocabulary_size
()
audio_data
=
paddle
.
layer
.
data
(
name
=
"audio_spectrogram"
,
height
=
161
,
width
=
1
000
,
type
=
paddle
.
data_type
.
dense_vector
(
161
000
))
width
=
2
000
,
type
=
paddle
.
data_type
.
dense_vector
(
322
000
))
text_data
=
paddle
.
layer
.
data
(
name
=
"transcript_text"
,
type
=
paddle
.
data_type
.
integer_value_sequence
(
dict_size
))
...
...
@@ -58,47 +92,26 @@ def train():
# create parameters and optimizer
parameters
=
paddle
.
parameters
.
create
(
cost
)
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
5e-
4
,
gradient_clipping_threshold
=
400
)
learning_rate
=
5e-
5
,
gradient_clipping_threshold
=
400
)
trainer
=
paddle
.
trainer
.
SGD
(
cost
=
cost
,
parameters
=
parameters
,
update_equation
=
optimizer
)
# create data readers
feeding
=
{
"audio_spectrogram"
:
0
,
"transcript_text"
:
1
,
}
train_batch_reader_with_sortagrad
=
audio_data_utils
.
padding_batch_reader
(
paddle
.
batch
(
audio_data_utils
.
reader_creator
(
manifest_path
=
"./libri.manifest.train"
,
sort_by_duration
=
True
),
batch_size
=
args
.
batch_size
//
args
.
trainer
),
padding
=
[
-
1
,
1000
])
train_batch_reader_without_sortagrad
=
audio_data_utils
.
padding_batch_reader
(
paddle
.
batch
(
audio_data_utils
.
reader_creator
(
manifest_path
=
"./libri.manifest.train"
,
sort_by_duration
=
False
,
shuffle
=
True
),
batch_size
=
args
.
batch_size
//
args
.
trainer
),
padding
=
[
-
1
,
1000
])
test_batch_reader
=
audio_data_utils
.
padding_batch_reader
(
paddle
.
batch
(
audio_data_utils
.
reader_creator
(
manifest_path
=
"./libri.manifest.dev"
,
sort_by_duration
=
False
),
batch_size
=
args
.
batch_size
//
args
.
trainer
),
padding
=
[
-
1
,
1000
])
# create event handler
def
event_handler
(
event
):
global
start_time
if
isinstance
(
event
,
paddle
.
event
.
EndIteration
):
if
event
.
batch_id
%
10
==
0
:
print
"
/
nPass: %d, Batch: %d, TrainCost: %f"
%
(
print
"
\
n
Pass: %d, Batch: %d, TrainCost: %f"
%
(
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
)
else
:
sys
.
stdout
.
write
(
'.'
)
sys
.
stdout
.
flush
()
if
isinstance
(
event
,
paddle
.
event
.
BeginPass
):
start_time
=
time
.
time
()
if
isinstance
(
event
,
paddle
.
event
.
EndPass
):
result
=
trainer
.
test
(
reader
=
test_batch_reader
,
feeding
=
feeding
)
print
"Pass: %d, TestCost: %s"
%
(
event
.
pass_id
,
result
.
cost
)
print
"
\n
------- Time: %d, Pass: %d, TestCost: %s"
%
(
time
.
time
()
-
start_time
,
event
.
pass_id
,
result
.
cost
)
with
gzip
.
open
(
"params.tar.gz"
,
'w'
)
as
f
:
parameters
.
to_tar
(
f
)
...
...
@@ -106,14 +119,14 @@ def train():
# first pass with sortagrad
if
args
.
use_sortagrad
:
trainer
.
train
(
reader
=
train_batch_reader_
with_
sortagrad
,
reader
=
train_batch_reader_sortagrad
,
event_handler
=
event_handler
,
num_passes
=
1
,
feeding
=
feeding
)
args
.
num_passes
-=
1
# other passes without sortagrad
trainer
.
train
(
reader
=
train_batch_reader_
without_
sortagrad
,
reader
=
train_batch_reader_
no
sortagrad
,
event_handler
=
event_handler
,
num_passes
=
args
.
num_passes
,
feeding
=
feeding
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录