Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
e6a34999
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e6a34999
编写于
5月 30, 2017
作者:
X
Xinghai Sun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor data utils into a class and add feature normalization.
上级
9c3cd3c7
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
389 addition
and
208 deletion
+389
-208
audio_data_utils.py
audio_data_utils.py
+340
-172
train.py
train.py
+49
-36
未找到文件。
audio_data_utils.py
浏览文件 @
e6a34999
"""
"""
Audio data preprocessing tools and reader creators.
Providing basic audio data preprocessing pipeline, and offering
both instance-level and batch-level data reader interfaces.
"""
"""
import
paddle.v2
as
paddle
import
paddle.v2
as
paddle
import
logging
import
logging
...
@@ -9,20 +10,127 @@ import soundfile
...
@@ -9,20 +10,127 @@ import soundfile
import
numpy
as
np
import
numpy
as
np
import
os
import
os
# TODO: add z-score normalization.
RANDOM_SEED
=
0
logger
=
logging
.
getLogger
(
__name__
)
ENGLISH_CHAR_VOCAB_FILEPATH
=
"eng_vocab.txt"
logger
=
logging
.
getLogger
(
__name__
)
class
DataGenerator
(
object
):
"""
DataGenerator provides a basic audio data preprocessing pipeline, and offers
both instance-level and batch-level data reader interfaces.
Normalized FFT are used as audio features here.
:param vocab_filepath: Vocabulary file path for indexing tokenized
transcriptions.
:type vocab_filepath: basestring
:param normalizer_manifest_path: Manifest filepath for collecting feature
normalization statistics, e.g. mean, std.
:type normalizer_manifest_path: basestring
:param normalizer_num_samples: Number of instances sampled for collecting
feature normalization statistics.
Default is 100.
:type normalizer_num_samples: int
:param max_duration: Audio clips with duration (in seconds) greater than
this will be discarded. Default is 20.0.
:type max_duration: float
:param min_duration: Audio clips with duration (in seconds) smaller than
this will be discarded. Default is 0.0.
:type min_duration: float
:param stride_ms: Striding size (in milliseconds) for generating frames.
Default is 10.0.
:type stride_ms: float
:param window_ms: Window size (in milliseconds) for frames. Default is 20.0.
:type window_ms: float
:param max_frequency: Maximum frequency for FFT features. FFT features of
frequency larger than this will be discarded.
If set None, all features will be kept.
Default is None.
:type max_frequency: float
"""
def
__init__
(
self
,
vocab_filepath
,
normalizer_manifest_path
,
normalizer_num_samples
=
100
,
max_duration
=
20.0
,
min_duration
=
0.0
,
stride_ms
=
10.0
,
window_ms
=
20.0
,
max_frequency
=
None
):
self
.
__max_duration__
=
max_duration
self
.
__min_duration__
=
min_duration
self
.
__stride_ms__
=
stride_ms
self
.
__window_ms__
=
window_ms
self
.
__max_frequency__
=
max_frequency
self
.
__random__
=
random
.
Random
(
RANDOM_SEED
)
# load vocabulary (dictionary)
self
.
__vocab_dict__
,
self
.
__vocab_list__
=
\
self
.
__load_vocabulary_from_file__
(
vocab_filepath
)
# collect normalizer statistics
self
.
__mean__
,
self
.
__std__
=
self
.
__collect_normalizer_statistics__
(
manifest_path
=
normalizer_manifest_path
,
num_samples
=
normalizer_num_samples
)
def
__audio_featurize__
(
self
,
audio_filename
):
"""
Preprocess audio data, including feature extraction, normalization, etc.
"""
features
=
self
.
__audio_basic_featurize__
(
audio_filename
)
return
self
.
__normalize__
(
features
)
def
__text_featurize__
(
self
,
text
):
"""
Preprocess text data, including tokenizing, token indexing, etc.
"""
return
self
.
__convert_text_to_char_index__
(
text
=
text
,
vocabulary
=
self
.
__vocab_dict__
)
def
__audio_basic_featurize__
(
self
,
audio_filename
):
"""
Compute basic (without normalization etc.) features for audio data.
"""
return
self
.
__spectrogram_from_file__
(
filename
=
audio_filename
,
stride_ms
=
self
.
__stride_ms__
,
window_ms
=
self
.
__window_ms__
,
max_freq
=
self
.
__max_frequency__
)
def
__collect_normalizer_statistics__
(
self
,
manifest_path
,
num_samples
=
100
):
"""
Compute feature normalization statistics, i.e. mean and stddev.
"""
# read manifest
manifest
=
self
.
__read_manifest__
(
manifest_path
=
manifest_path
,
max_duration
=
self
.
__max_duration__
,
min_duration
=
self
.
__min_duration__
)
# sample for statistics
sampled_manifest
=
self
.
__random__
.
sample
(
manifest
,
num_samples
)
# extract spectrogram feature
features
=
[]
for
instance
in
sampled_manifest
:
spectrogram
=
self
.
__audio_basic_featurize__
(
instance
[
"audio_filepath"
])
features
.
append
(
spectrogram
)
features
=
np
.
hstack
(
features
)
mean
=
np
.
mean
(
features
,
axis
=
1
).
reshape
([
-
1
,
1
])
std
=
np
.
std
(
features
,
axis
=
1
).
reshape
([
-
1
,
1
])
return
mean
,
std
def
spectrogram_from_file
(
filename
,
def
__normalize__
(
self
,
features
,
eps
=
1e-14
):
stride_ms
=
10
,
"""
window_ms
=
20
,
Normalize features to be of zero mean and unit stddev.
"""
return
(
features
-
self
.
__mean__
)
/
(
self
.
__std__
+
eps
)
def
__spectrogram_from_file__
(
self
,
filename
,
stride_ms
=
10.0
,
window_ms
=
20.0
,
max_freq
=
None
,
max_freq
=
None
,
eps
=
1e-14
):
eps
=
1e-14
):
"""
"""
Calculate the log of linear spectrogram from FFT energy
Load audio data and calculate the log of spectrogram by FFT.
Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
"""
"""
audio
,
sample_rate
=
soundfile
.
read
(
filename
)
audio
,
sample_rate
=
soundfile
.
read
(
filename
)
...
@@ -34,10 +142,11 @@ def spectrogram_from_file(filename,
...
@@ -34,10 +142,11 @@ def spectrogram_from_file(filename,
raise
ValueError
(
"max_freq must be greater than half of "
raise
ValueError
(
"max_freq must be greater than half of "
"sample rate."
)
"sample rate."
)
if
stride_ms
>
window_ms
:
if
stride_ms
>
window_ms
:
raise
ValueError
(
"Stride size must not be greater than window size."
)
raise
ValueError
(
"Stride size must not be greater than "
"window size."
)
stride_size
=
int
(
0.001
*
sample_rate
*
stride_ms
)
stride_size
=
int
(
0.001
*
sample_rate
*
stride_ms
)
window_size
=
int
(
0.001
*
sample_rate
*
window_ms
)
window_size
=
int
(
0.001
*
sample_rate
*
window_ms
)
spectrogram
,
freqs
=
extract_spectrogram
(
spectrogram
,
freqs
=
self
.
__extract_spectrogram__
(
audio
,
audio
,
window_size
=
window_size
,
window_size
=
window_size
,
stride_size
=
stride_size
,
stride_size
=
stride_size
,
...
@@ -45,10 +154,10 @@ def spectrogram_from_file(filename,
...
@@ -45,10 +154,10 @@ def spectrogram_from_file(filename,
ind
=
np
.
where
(
freqs
<=
max_freq
)[
0
][
-
1
]
+
1
ind
=
np
.
where
(
freqs
<=
max_freq
)[
0
][
-
1
]
+
1
return
np
.
log
(
spectrogram
[:
ind
,
:]
+
eps
)
return
np
.
log
(
spectrogram
[:
ind
,
:]
+
eps
)
def
__extract_spectrogram__
(
self
,
samples
,
window_size
,
stride_size
,
def
extract_spectrogram
(
samples
,
window_size
,
stride_size
,
sample_rate
):
sample_rate
):
"""
"""
Compute the spectrogram for a real discrete
signal.
Compute the spectrogram by FFT for a discrete real
signal.
Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
"""
"""
# extract strided windows
# extract strided windows
...
@@ -60,7 +169,7 @@ def extract_spectrogram(samples, window_size, stride_size, sample_rate):
...
@@ -60,7 +169,7 @@ def extract_spectrogram(samples, window_size, stride_size, sample_rate):
samples
,
shape
=
nshape
,
strides
=
nstrides
)
samples
,
shape
=
nshape
,
strides
=
nstrides
)
assert
np
.
all
(
assert
np
.
all
(
windows
[:,
1
]
==
samples
[
stride_size
:(
stride_size
+
window_size
)])
windows
[:,
1
]
==
samples
[
stride_size
:(
stride_size
+
window_size
)])
# window weighting, compute
squared Fast Fourier Transform (fft), scaling
# window weighting,
squared Fast Fourier Transform (fft), scaling
weighting
=
np
.
hanning
(
window_size
)[:,
None
]
weighting
=
np
.
hanning
(
window_size
)[:,
None
]
fft
=
np
.
fft
.
rfft
(
windows
*
weighting
,
axis
=
0
)
fft
=
np
.
fft
.
rfft
(
windows
*
weighting
,
axis
=
0
)
fft
=
np
.
absolute
(
fft
)
**
2
fft
=
np
.
absolute
(
fft
)
**
2
...
@@ -71,12 +180,12 @@ def extract_spectrogram(samples, window_size, stride_size, sample_rate):
...
@@ -71,12 +180,12 @@ def extract_spectrogram(samples, window_size, stride_size, sample_rate):
freqs
=
float
(
sample_rate
)
/
window_size
*
np
.
arange
(
fft
.
shape
[
0
])
freqs
=
float
(
sample_rate
)
/
window_size
*
np
.
arange
(
fft
.
shape
[
0
])
return
fft
,
freqs
return
fft
,
freqs
def
__load_vocabulary_from_file__
(
self
,
vocabulary_path
):
def
vocabulary_from_file
(
vocabulary_path
):
"""
"""
Load vocabulary from file.
Load vocabulary from file.
"""
"""
if
os
.
path
.
exists
(
vocabulary_path
):
if
not
os
.
path
.
exists
(
vocabulary_path
):
raise
ValueError
(
"Vocabulary file %s not found."
,
vocabulary_path
)
vocab_lines
=
[]
vocab_lines
=
[]
with
open
(
vocabulary_path
,
'r'
)
as
file
:
with
open
(
vocabulary_path
,
'r'
)
as
file
:
vocab_lines
.
extend
(
file
.
readlines
())
vocab_lines
.
extend
(
file
.
readlines
())
...
@@ -84,56 +193,76 @@ def vocabulary_from_file(vocabulary_path):
...
@@ -84,56 +193,76 @@ def vocabulary_from_file(vocabulary_path):
vocab_dict
=
dict
(
vocab_dict
=
dict
(
[(
token
,
id
)
for
(
id
,
token
)
in
enumerate
(
vocab_list
)])
[(
token
,
id
)
for
(
id
,
token
)
in
enumerate
(
vocab_list
)])
return
vocab_dict
,
vocab_list
return
vocab_dict
,
vocab_list
else
:
raise
ValueError
(
"Vocabulary file %s not found."
,
vocabulary_path
)
def
__convert_text_to_char_index__
(
self
,
text
,
vocabulary
):
def
get_vocabulary_size
():
"""
"""
Get vocabulary size
.
Convert text string to a list of character index integers
.
"""
"""
vocab_dict
,
_
=
vocabulary_from_file
(
ENGLISH_CHAR_VOCAB_FILEPATH
)
return
[
vocabulary
[
w
]
for
w
in
text
]
return
len
(
vocab_dict
)
def
get_vocabulary
(
):
def
__read_manifest__
(
self
,
manifest_path
,
max_duration
,
min_duration
):
"""
"""
Get vocabulary
.
Load and parse manifest file
.
"""
"""
return
vocabulary_from_file
(
ENGLISH_CHAR_VOCAB_FILEPATH
)
manifest
=
[]
for
json_line
in
open
(
manifest_path
):
try
:
json_data
=
json
.
loads
(
json_line
)
except
Exception
as
e
:
raise
ValueError
(
"Error reading manifest: %s"
%
str
(
e
))
if
(
json_data
[
"duration"
]
<=
max_duration
and
json_data
[
"duration"
]
>=
min_duration
):
manifest
.
append
(
json_data
)
return
manifest
def
parse_transcript
(
text
,
vocabulary
):
def
__padding_batch__
(
self
,
batch
,
padding_to
=-
1
,
flatten
=
False
):
"""
"""
Convert the transcript text string to list of token index integers.
Padding audio part of features (only in the time axis -- column axis)
"""
with zeros, to make each instance in the batch share the same
return
[
vocabulary
[
w
]
for
w
in
text
]
audio feature shape.
If `padding_to` is set -1, the maximum column numbers in the batch will
be used as the target size. Otherwise, `padding_to` will be the target
size. Default is -1.
If `flatten` is set True, audio data will be flattened to be a 1-dim
ndarray. Default is False.
"""
new_batch
=
[]
# get target shape
max_length
=
max
([
audio
.
shape
[
1
]
for
audio
,
text
in
batch
])
if
padding_to
!=
-
1
:
if
padding_to
<
max_length
:
raise
ValueError
(
"If padding_to is not -1, it should be greater"
" or equal to the original instance length."
)
max_length
=
padding_to
# padding
for
audio
,
text
in
batch
:
padded_audio
=
np
.
zeros
([
audio
.
shape
[
0
],
max_length
])
padded_audio
[:,
:
audio
.
shape
[
1
]]
=
audio
if
flatten
:
padded_audio
=
padded_audio
.
flatten
()
new_batch
.
append
((
padded_audio
,
text
))
return
new_batch
def
reader_creator
(
manifest_path
,
def
instance_reader_creator
(
self
,
manifest_path
,
sort_by_duration
=
True
,
sort_by_duration
=
True
,
shuffle
=
False
,
shuffle
=
False
):
max_duration
=
10.0
,
min_duration
=
0.0
):
"""
"""
Audio data reader creator.
Instance reader creator for audio data. Create a callable function to
produce instances of data.
Instance: a tuple of a numpy ndarray of audio spectrogram and a list of
Instance: a tuple of a numpy ndarray of audio spectrogram and a list of
tokeniz
ed transcription text.
tokenized and index
ed transcription text.
:param manifest_path: Filepath for Manifest of
audio clip files.
:param manifest_path: Filepath of manifest for
audio clip files.
:type manifest_path: basestring
:type manifest_path: basestring
:param sort_by_duration: Sort the audio clips by duration if set True.
:param sort_by_duration: Sort the audio clips by duration if set True
For SortaGrad
.
(for SortaGrad)
.
:type sort_by_duration: bool
:type sort_by_duration: bool
:param shuffle: Shuffle the audio clips if set True.
:param shuffle: Shuffle the audio clips if set True.
:type shuffle: bool
:type shuffle: bool
:param max_duration: Audio clips with duration (in seconds) greater than
this will be discarded.
:type max_duration: float
:param min_duration: Audio clips with duration (in seconds) smaller than
this will be discarded.
:type min_duration: float
:return: Data reader function.
:return: Data reader function.
:rtype: callable
:rtype: callable
"""
"""
...
@@ -141,75 +270,114 @@ def reader_creator(manifest_path,
...
@@ -141,75 +270,114 @@ def reader_creator(manifest_path,
sort_by_duration
=
False
sort_by_duration
=
False
logger
.
warn
(
"When shuffle set to true, "
logger
.
warn
(
"When shuffle set to true, "
"sort_by_duration is forced to set False."
)
"sort_by_duration is forced to set False."
)
vocab_dict
,
_
=
vocabulary_from_file
(
ENGLISH_CHAR_VOCAB_FILEPATH
)
def
reader
():
def
reader
():
# read manifest
# read manifest
manifest_data
=
[]
manifest
=
self
.
__read_manifest__
(
for
json_line
in
open
(
manifest_path
):
manifest_path
=
manifest_path
,
try
:
max_duration
=
self
.
__max_duration__
,
json_data
=
json
.
loads
(
json_line
)
min_duration
=
self
.
__min_duration__
)
except
Exception
as
e
:
raise
ValueError
(
"Error reading manifest: %s"
%
str
(
e
))
if
(
json_data
[
"duration"
]
<=
max_duration
and
json_data
[
"duration"
]
>=
min_duration
):
manifest_data
.
append
(
json_data
)
# sort (by duration) or shuffle manifest
# sort (by duration) or shuffle manifest
if
sort_by_duration
:
if
sort_by_duration
:
manifest_data
.
sort
(
key
=
lambda
x
:
x
[
"duration"
])
manifest
.
sort
(
key
=
lambda
x
:
x
[
"duration"
])
if
shuffle
:
if
shuffle
:
random
.
shuffle
(
manifest_data
)
self
.
__random__
.
shuffle
(
manifest
)
# extract spectrogram feature
# extract spectrogram feature
for
instance
in
manifest_data
:
for
instance
in
manifest
:
spectrogram
=
spectrogram_from_file
(
instance
[
"audio_filepath"
])
spectrogram
=
self
.
__audio_featurize__
(
text
=
parse_transcript
(
instance
[
"text"
],
vocab_dict
)
instance
[
"audio_filepath"
])
yield
(
spectrogram
,
text
)
transcript
=
self
.
__text_featurize__
(
instance
[
"text"
])
yield
(
spectrogram
,
transcript
)
return
reader
return
reader
def
batch_reader_creator
(
self
,
def
padding_batch_reader
(
batch_reader
,
padding
=
[
-
1
,
-
1
],
flatten
=
True
):
manifest_path
,
batch_size
,
padding_to
=-
1
,
flatten
=
False
,
sort_by_duration
=
True
,
shuffle
=
False
):
"""
"""
Padding for batches. Return a batch reader.
Batch data reader creator for audio data. Create a callable function to
produce batches of data.
Each instance in a batch will be padded to be of a same target shape.
Audio features will be padded with zeros to make each instance in the
The target shape is the largest shape among all the batch instances and
batch to share the same audio feature shape.
'padding' argument. Therefore, if padding is set [-1, -1], instance will be
padded to have the same shape just within each batch and the shape will
be different across batches; if padding is set
[VERY_LARGE_NUM, VERY_LARGE_NUM], instances in all batches will be padded to
have the same shape of [VERY_LARGE_NUM, VERY_LARGE_NUM].
:param batch_reader: Input batch reader.
:param manifest_path: Filepath of manifest for audio clip files.
:type batch_reader: callable
:type manifest_path: basestring
:param padding: Padding pattern. Details please refer to the above.
:param batch_size: Instance number in a batch.
:type padding: list
:type batch_size: int
:param flatten: Flatten the tensor to be one dimension.
:param padding_to: If set -1, the maximum column numbers in the batch
will be used as the target size for padding.
Otherwise, `padding_to` will be the target size.
Default is -1.
:type padding_to: int
:param flatten: If set True, audio data will be flattened to be a 1-dim
ndarray. Otherwise, 2-dim ndarray. Default is False.
:type flatten: bool
:type flatten: bool
:return: Batch reader function.
:param sort_by_duration: Sort the audio clips by duration if set True
(for SortaGrad).
:type sort_by_duration: bool
:param shuffle: Shuffle the audio clips if set True.
:type shuffle: bool
:return: Batch reader function, producing batches of data when called.
:rtype: callable
:rtype: callable
"""
"""
def
padding_batch
(
batch
):
def
batch_reader
():
new_batch
=
[]
instance_reader
=
self
.
instance_reader_creator
(
# get target shape within batch
manifest_path
=
manifest_path
,
nshape_list
=
[
padding
]
sort_by_duration
=
sort_by_duration
,
for
audio
,
text
in
batch
:
shuffle
=
shuffle
)
nshape_list
.
append
(
audio
.
shape
)
batch
=
[]
target_shape
=
np
.
array
(
nshape_list
).
max
(
axis
=
0
)
for
instance
in
instance_reader
():
# padding
batch
.
append
(
instance
)
for
audio
,
text
in
batch
:
if
len
(
batch
)
==
batch_size
:
pad_shape
=
target_shape
-
audio
.
shape
yield
self
.
__padding_batch__
(
batch
,
padding_to
,
flatten
)
assert
np
.
all
(
pad_shape
>=
0
)
batch
=
[]
padded_audio
=
np
.
pad
(
if
len
(
batch
)
>
0
:
audio
,
[(
0
,
pad_shape
[
0
]),
(
0
,
pad_shape
[
1
])],
mode
=
"constant"
)
yield
self
.
__padding_batch__
(
batch
,
padding_to
,
flatten
)
if
flatten
:
padded_audio
=
padded_audio
.
flatten
()
return
batch_reader
new_batch
.
append
((
padded_audio
,
text
))
return
new_batch
def
vocabulary_size
(
self
):
"""
Get vocabulary size.
def
new_batch_reader
():
:return: Vocabulary size.
for
batch
in
batch_reader
():
:rtype: int
yield
padding_batch
(
batch
)
"""
return
len
(
self
.
__vocab_list__
)
return
new_batch_reader
def
vocabulary_dict
(
self
):
"""
Get vocabulary in dict.
:return: Vocabulary in dict.
:rtype: dict
"""
return
self
.
__vocab_dict__
def
vocabulary_list
(
self
):
"""
Get vocabulary in list.
:return: Vocabulary in list
:rtype: list
"""
return
self
.
__vocab_list__
def
data_name_feeding
(
self
):
"""
Get feedings (data field name and corresponding field id).
:return: Feeding dict.
:rtype: dict
"""
feeding
=
{
"audio_spectrogram"
:
0
,
"transcript_text"
:
1
,
}
return
feeding
train.py
浏览文件 @
e6a34999
...
@@ -5,16 +5,18 @@
...
@@ -5,16 +5,18 @@
import
paddle.v2
as
paddle
import
paddle.v2
as
paddle
import
argparse
import
argparse
import
gzip
import
gzip
import
time
import
sys
import
sys
from
model
import
deep_speech2
from
model
import
deep_speech2
import
audio_data_utils
from
audio_data_utils
import
DataGenerator
import
numpy
as
np
#TODO: add WER metric
#TODO: add WER metric
parser
=
argparse
.
ArgumentParser
(
parser
=
argparse
.
ArgumentParser
(
description
=
'Simplified version of DeepSpeech2 trainer.'
)
description
=
'Simplified version of DeepSpeech2 trainer.'
)
parser
.
add_argument
(
parser
.
add_argument
(
"--batch_size"
,
default
=
51
2
,
type
=
int
,
help
=
"Minibatch size."
)
"--batch_size"
,
default
=
3
2
,
type
=
int
,
help
=
"Minibatch size."
)
parser
.
add_argument
(
"--trainer"
,
default
=
1
,
type
=
int
,
help
=
"Trainer number."
)
parser
.
add_argument
(
"--trainer"
,
default
=
1
,
type
=
int
,
help
=
"Trainer number."
)
parser
.
add_argument
(
parser
.
add_argument
(
"--num_passes"
,
default
=
20
,
type
=
int
,
help
=
"Training pass number."
)
"--num_passes"
,
default
=
20
,
type
=
int
,
help
=
"Training pass number."
)
...
@@ -23,7 +25,7 @@ parser.add_argument(
...
@@ -23,7 +25,7 @@ parser.add_argument(
parser
.
add_argument
(
parser
.
add_argument
(
"--num_rnn_layers"
,
default
=
5
,
type
=
int
,
help
=
"RNN layer number."
)
"--num_rnn_layers"
,
default
=
5
,
type
=
int
,
help
=
"RNN layer number."
)
parser
.
add_argument
(
parser
.
add_argument
(
"--rnn_layer_size"
,
default
=
256
,
type
=
int
,
help
=
"RNN layer cell number."
)
"--rnn_layer_size"
,
default
=
512
,
type
=
int
,
help
=
"RNN layer cell number."
)
parser
.
add_argument
(
parser
.
add_argument
(
"--use_gpu"
,
default
=
True
,
type
=
bool
,
help
=
"Use gpu or not."
)
"--use_gpu"
,
default
=
True
,
type
=
bool
,
help
=
"Use gpu or not."
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -37,13 +39,45 @@ def train():
...
@@ -37,13 +39,45 @@ def train():
"""
"""
DeepSpeech2 training.
DeepSpeech2 training.
"""
"""
# create data readers
data_generator
=
DataGenerator
(
vocab_filepath
=
'eng_vocab.txt'
,
normalizer_manifest_path
=
'./libri.manifest.train'
,
normalizer_num_samples
=
200
,
max_duration
=
20.0
,
min_duration
=
0.0
,
stride_ms
=
10
,
window_ms
=
20
)
train_batch_reader_sortagrad
=
data_generator
.
batch_reader_creator
(
manifest_path
=
'./libri.manifest.dev.small'
,
batch_size
=
args
.
batch_size
//
args
.
trainer
,
padding_to
=
2000
,
flatten
=
True
,
sort_by_duration
=
True
,
shuffle
=
False
)
train_batch_reader_nosortagrad
=
data_generator
.
batch_reader_creator
(
manifest_path
=
'./libri.manifest.dev.small'
,
batch_size
=
args
.
batch_size
//
args
.
trainer
,
padding_to
=
2000
,
flatten
=
True
,
sort_by_duration
=
False
,
shuffle
=
True
)
test_batch_reader
=
data_generator
.
batch_reader_creator
(
manifest_path
=
'./libri.manifest.test'
,
batch_size
=
args
.
batch_size
//
args
.
trainer
,
padding_to
=
2000
,
flatten
=
True
,
sort_by_duration
=
False
,
shuffle
=
False
)
feeding
=
data_generator
.
data_name_feeding
()
# create network config
# create network config
dict_size
=
audio_data_utils
.
get_
vocabulary_size
()
dict_size
=
data_generator
.
vocabulary_size
()
audio_data
=
paddle
.
layer
.
data
(
audio_data
=
paddle
.
layer
.
data
(
name
=
"audio_spectrogram"
,
name
=
"audio_spectrogram"
,
height
=
161
,
height
=
161
,
width
=
1
000
,
width
=
2
000
,
type
=
paddle
.
data_type
.
dense_vector
(
161
000
))
type
=
paddle
.
data_type
.
dense_vector
(
322
000
))
text_data
=
paddle
.
layer
.
data
(
text_data
=
paddle
.
layer
.
data
(
name
=
"transcript_text"
,
name
=
"transcript_text"
,
type
=
paddle
.
data_type
.
integer_value_sequence
(
dict_size
))
type
=
paddle
.
data_type
.
integer_value_sequence
(
dict_size
))
...
@@ -58,47 +92,26 @@ def train():
...
@@ -58,47 +92,26 @@ def train():
# create parameters and optimizer
# create parameters and optimizer
parameters
=
paddle
.
parameters
.
create
(
cost
)
parameters
=
paddle
.
parameters
.
create
(
cost
)
optimizer
=
paddle
.
optimizer
.
Adam
(
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
5e-
4
,
gradient_clipping_threshold
=
400
)
learning_rate
=
5e-
5
,
gradient_clipping_threshold
=
400
)
trainer
=
paddle
.
trainer
.
SGD
(
trainer
=
paddle
.
trainer
.
SGD
(
cost
=
cost
,
parameters
=
parameters
,
update_equation
=
optimizer
)
cost
=
cost
,
parameters
=
parameters
,
update_equation
=
optimizer
)
# create data readers
feeding
=
{
"audio_spectrogram"
:
0
,
"transcript_text"
:
1
,
}
train_batch_reader_with_sortagrad
=
audio_data_utils
.
padding_batch_reader
(
paddle
.
batch
(
audio_data_utils
.
reader_creator
(
manifest_path
=
"./libri.manifest.train"
,
sort_by_duration
=
True
),
batch_size
=
args
.
batch_size
//
args
.
trainer
),
padding
=
[
-
1
,
1000
])
train_batch_reader_without_sortagrad
=
audio_data_utils
.
padding_batch_reader
(
paddle
.
batch
(
audio_data_utils
.
reader_creator
(
manifest_path
=
"./libri.manifest.train"
,
sort_by_duration
=
False
,
shuffle
=
True
),
batch_size
=
args
.
batch_size
//
args
.
trainer
),
padding
=
[
-
1
,
1000
])
test_batch_reader
=
audio_data_utils
.
padding_batch_reader
(
paddle
.
batch
(
audio_data_utils
.
reader_creator
(
manifest_path
=
"./libri.manifest.dev"
,
sort_by_duration
=
False
),
batch_size
=
args
.
batch_size
//
args
.
trainer
),
padding
=
[
-
1
,
1000
])
# create event handler
# create event handler
def
event_handler
(
event
):
def
event_handler
(
event
):
global
start_time
if
isinstance
(
event
,
paddle
.
event
.
EndIteration
):
if
isinstance
(
event
,
paddle
.
event
.
EndIteration
):
if
event
.
batch_id
%
10
==
0
:
if
event
.
batch_id
%
10
==
0
:
print
"
/
nPass: %d, Batch: %d, TrainCost: %f"
%
(
print
"
\
n
Pass: %d, Batch: %d, TrainCost: %f"
%
(
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
)
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
)
else
:
else
:
sys
.
stdout
.
write
(
'.'
)
sys
.
stdout
.
write
(
'.'
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
if
isinstance
(
event
,
paddle
.
event
.
BeginPass
):
start_time
=
time
.
time
()
if
isinstance
(
event
,
paddle
.
event
.
EndPass
):
if
isinstance
(
event
,
paddle
.
event
.
EndPass
):
result
=
trainer
.
test
(
reader
=
test_batch_reader
,
feeding
=
feeding
)
result
=
trainer
.
test
(
reader
=
test_batch_reader
,
feeding
=
feeding
)
print
"Pass: %d, TestCost: %s"
%
(
event
.
pass_id
,
result
.
cost
)
print
"
\n
------- Time: %d, Pass: %d, TestCost: %s"
%
(
time
.
time
()
-
start_time
,
event
.
pass_id
,
result
.
cost
)
with
gzip
.
open
(
"params.tar.gz"
,
'w'
)
as
f
:
with
gzip
.
open
(
"params.tar.gz"
,
'w'
)
as
f
:
parameters
.
to_tar
(
f
)
parameters
.
to_tar
(
f
)
...
@@ -106,14 +119,14 @@ def train():
...
@@ -106,14 +119,14 @@ def train():
# first pass with sortagrad
# first pass with sortagrad
if
args
.
use_sortagrad
:
if
args
.
use_sortagrad
:
trainer
.
train
(
trainer
.
train
(
reader
=
train_batch_reader_
with_
sortagrad
,
reader
=
train_batch_reader_sortagrad
,
event_handler
=
event_handler
,
event_handler
=
event_handler
,
num_passes
=
1
,
num_passes
=
1
,
feeding
=
feeding
)
feeding
=
feeding
)
args
.
num_passes
-=
1
args
.
num_passes
-=
1
# other passes without sortagrad
# other passes without sortagrad
trainer
.
train
(
trainer
.
train
(
reader
=
train_batch_reader_
without_
sortagrad
,
reader
=
train_batch_reader_
no
sortagrad
,
event_handler
=
event_handler
,
event_handler
=
event_handler
,
num_passes
=
args
.
num_passes
,
num_passes
=
args
.
num_passes
,
feeding
=
feeding
)
feeding
=
feeding
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录