Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
e6a34999
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e6a34999
编写于
5月 30, 2017
作者:
X
Xinghai Sun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor data utils into a class and add feature normalization.
上级
9c3cd3c7
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
389 addition
and
208 deletion
+389
-208
audio_data_utils.py
audio_data_utils.py
+340
-172
train.py
train.py
+49
-36
未找到文件。
audio_data_utils.py
浏览文件 @
e6a34999
"""
"""
Audio data preprocessing tools and reader creators.
Providing basic audio data preprocessing pipeline, and offering
both instance-level and batch-level data reader interfaces.
"""
"""
import
paddle.v2
as
paddle
import
paddle.v2
as
paddle
import
logging
import
logging
...
@@ -9,143 +10,201 @@ import soundfile
...
@@ -9,143 +10,201 @@ import soundfile
import
numpy
as
np
import
numpy
as
np
import
os
import
os
# TODO: add z-score normalization.
RANDOM_SEED
=
0
ENGLISH_CHAR_VOCAB_FILEPATH
=
"eng_vocab.txt"
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
def
spectrogram_from_file
(
filename
,
class
DataGenerator
(
object
):
stride_ms
=
10
,
window_ms
=
20
,
max_freq
=
None
,
eps
=
1e-14
):
"""
Calculate the log of linear spectrogram from FFT energy
Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
"""
audio
,
sample_rate
=
soundfile
.
read
(
filename
)
if
audio
.
ndim
>=
2
:
audio
=
np
.
mean
(
audio
,
1
)
if
max_freq
is
None
:
max_freq
=
sample_rate
/
2
if
max_freq
>
sample_rate
/
2
:
raise
ValueError
(
"max_freq must be greater than half of "
"sample rate."
)
if
stride_ms
>
window_ms
:
raise
ValueError
(
"Stride size must not be greater than window size."
)
stride_size
=
int
(
0.001
*
sample_rate
*
stride_ms
)
window_size
=
int
(
0.001
*
sample_rate
*
window_ms
)
spectrogram
,
freqs
=
extract_spectrogram
(
audio
,
window_size
=
window_size
,
stride_size
=
stride_size
,
sample_rate
=
sample_rate
)
ind
=
np
.
where
(
freqs
<=
max_freq
)[
0
][
-
1
]
+
1
return
np
.
log
(
spectrogram
[:
ind
,
:]
+
eps
)
def
extract_spectrogram
(
samples
,
window_size
,
stride_size
,
sample_rate
):
"""
"""
Compute the spectrogram for a real discrete signal.
DataGenerator provides basic audio data preprocessing pipeline, and offer
Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
both instance-level and batch-level data reader interfaces.
"""
Normalized FFT are used as audio features here.
# extract strided windows
truncate_size
=
(
len
(
samples
)
-
window_size
)
%
stride_size
:param vocab_filepath: Vocabulary file path for indexing tokenized
samples
=
samples
[:
len
(
samples
)
-
truncate_size
]
transcriptions.
nshape
=
(
window_size
,
(
len
(
samples
)
-
window_size
)
//
stride_size
+
1
)
:type vocab_filepath: basestring
nstrides
=
(
samples
.
strides
[
0
],
samples
.
strides
[
0
]
*
stride_size
)
:param normalizer_manifest_path: Manifest filepath for collecting feature
windows
=
np
.
lib
.
stride_tricks
.
as_strided
(
normalization statistics, e.g. mean, std.
samples
,
shape
=
nshape
,
strides
=
nstrides
)
:type normalizer_manifest_path: basestring
assert
np
.
all
(
:param normalizer_num_samples: Number of instances sampled for collecting
windows
[:,
1
]
==
samples
[
stride_size
:(
stride_size
+
window_size
)])
feature normalization statistics.
# window weighting, compute squared Fast Fourier Transform (fft), scaling
Default is 100.
weighting
=
np
.
hanning
(
window_size
)[:,
None
]
:type normalizer_num_samples: int
fft
=
np
.
fft
.
rfft
(
windows
*
weighting
,
axis
=
0
)
:param max_duration: Audio clips with duration (in seconds) greater than
fft
=
np
.
absolute
(
fft
)
**
2
this will be discarded. Default is 20.0.
scale
=
np
.
sum
(
weighting
**
2
)
*
sample_rate
:type max_duration: float
fft
[
1
:
-
1
,
:]
*=
(
2.0
/
scale
)
:param min_duration: Audio clips with duration (in seconds) smaller than
fft
[(
0
,
-
1
),
:]
/=
scale
this will be discarded. Default is 0.0.
# prepare fft frequency list
:type min_duration: float
freqs
=
float
(
sample_rate
)
/
window_size
*
np
.
arange
(
fft
.
shape
[
0
])
:param stride_ms: Striding size (in milliseconds) for generating frames.
return
fft
,
freqs
Default is 10.0.
:type stride_ms: float
:param window_ms: Window size (in milliseconds) for frames. Default is 20.0.
def
vocabulary_from_file
(
vocabulary_path
):
:type window_ms: float
"""
:param max_frequency: Maximun frequency for FFT features. FFT features of
Load vocabulary from file.
frequency larger than this will be discarded.
If set None, all features will be kept.
Default is None.
:type max_frequency: float
"""
"""
if
os
.
path
.
exists
(
vocabulary_path
):
vocab_lines
=
[]
with
open
(
vocabulary_path
,
'r'
)
as
file
:
vocab_lines
.
extend
(
file
.
readlines
())
vocab_list
=
[
line
[:
-
1
]
for
line
in
vocab_lines
]
vocab_dict
=
dict
(
[(
token
,
id
)
for
(
id
,
token
)
in
enumerate
(
vocab_list
)])
return
vocab_dict
,
vocab_list
else
:
raise
ValueError
(
"Vocabulary file %s not found."
,
vocabulary_path
)
def
__init__
(
self
,
vocab_filepath
,
normalizer_manifest_path
,
normalizer_num_samples
=
100
,
max_duration
=
20.0
,
min_duration
=
0.0
,
stride_ms
=
10.0
,
window_ms
=
20.0
,
max_frequency
=
None
):
self
.
__max_duration__
=
max_duration
self
.
__min_duration__
=
min_duration
self
.
__stride_ms__
=
stride_ms
self
.
__window_ms__
=
window_ms
self
.
__max_frequency__
=
max_frequency
self
.
__random__
=
random
.
Random
(
RANDOM_SEED
)
# load vocabulary (dictionary)
self
.
__vocab_dict__
,
self
.
__vocab_list__
=
\
self
.
__load_vocabulary_from_file__
(
vocab_filepath
)
# collect normalizer statistics
self
.
__mean__
,
self
.
__std__
=
self
.
__collect_normalizer_statistics__
(
manifest_path
=
normalizer_manifest_path
,
num_samples
=
normalizer_num_samples
)
def
get_vocabulary_size
(
):
def
__audio_featurize__
(
self
,
audio_filename
):
"""
"""
Get vocabulary size
.
Preprocess audio data, including feature extraction, normalization etc.
.
"""
"""
vocab_dict
,
_
=
vocabulary_from_file
(
ENGLISH_CHAR_VOCAB_FILEPATH
)
features
=
self
.
__audio_basic_featurize__
(
audio_filename
)
return
len
(
vocab_dict
)
return
self
.
__normalize__
(
features
)
def
__text_featurize__
(
self
,
text
):
"""
Preprocess text data, including tokenizing and token indexing etc..
"""
return
self
.
__convert_text_to_char_index__
(
text
=
text
,
vocabulary
=
self
.
__vocab_dict__
)
def
get_vocabulary
():
def
__audio_basic_featurize__
(
self
,
audio_filename
):
"""
"""
Get vocabulary.
Compute basic (without normalization etc.) features for audio data.
"""
"""
return
vocabulary_from_file
(
ENGLISH_CHAR_VOCAB_FILEPATH
)
return
self
.
__spectrogram_from_file__
(
filename
=
audio_filename
,
stride_ms
=
self
.
__stride_ms__
,
window_ms
=
self
.
__window_ms__
,
max_freq
=
self
.
__max_frequency__
)
def
__collect_normalizer_statistics__
(
self
,
manifest_path
,
num_samples
=
100
):
"""
Compute feature normalization statistics, i.e. mean and stddev.
"""
# read manifest
manifest
=
self
.
__read_manifest__
(
manifest_path
=
manifest_path
,
max_duration
=
self
.
__max_duration__
,
min_duration
=
self
.
__min_duration__
)
# sample for statistics
sampled_manifest
=
self
.
__random__
.
sample
(
manifest
,
num_samples
)
# extract spectrogram feature
features
=
[]
for
instance
in
sampled_manifest
:
spectrogram
=
self
.
__audio_basic_featurize__
(
instance
[
"audio_filepath"
])
features
.
append
(
spectrogram
)
features
=
np
.
hstack
(
features
)
mean
=
np
.
mean
(
features
,
axis
=
1
).
reshape
([
-
1
,
1
])
std
=
np
.
std
(
features
,
axis
=
1
).
reshape
([
-
1
,
1
])
return
mean
,
std
def
parse_transcript
(
text
,
vocabulary
):
def
__normalize__
(
self
,
features
,
eps
=
1e-14
):
"""
"""
Convert the transcript text string to list of token index integers
.
Normalize features to be of zero mean and unit stddev
.
"""
"""
return
[
vocabulary
[
w
]
for
w
in
text
]
return
(
features
-
self
.
__mean__
)
/
(
self
.
__std__
+
eps
)
def
__spectrogram_from_file__
(
self
,
filename
,
stride_ms
=
10.0
,
window_ms
=
20.0
,
max_freq
=
None
,
eps
=
1e-14
):
"""
Laod audio data and calculate the log of spectrogram by FFT.
Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
"""
audio
,
sample_rate
=
soundfile
.
read
(
filename
)
if
audio
.
ndim
>=
2
:
audio
=
np
.
mean
(
audio
,
1
)
if
max_freq
is
None
:
max_freq
=
sample_rate
/
2
if
max_freq
>
sample_rate
/
2
:
raise
ValueError
(
"max_freq must be greater than half of "
"sample rate."
)
if
stride_ms
>
window_ms
:
raise
ValueError
(
"Stride size must not be greater than "
"window size."
)
stride_size
=
int
(
0.001
*
sample_rate
*
stride_ms
)
window_size
=
int
(
0.001
*
sample_rate
*
window_ms
)
spectrogram
,
freqs
=
self
.
__extract_spectrogram__
(
audio
,
window_size
=
window_size
,
stride_size
=
stride_size
,
sample_rate
=
sample_rate
)
ind
=
np
.
where
(
freqs
<=
max_freq
)[
0
][
-
1
]
+
1
return
np
.
log
(
spectrogram
[:
ind
,
:]
+
eps
)
def
reader_creator
(
manifest_path
,
def
__extract_spectrogram__
(
self
,
samples
,
window_size
,
stride_size
,
sort_by_duration
=
True
,
sample_rate
):
shuffle
=
False
,
"""
max_duration
=
10.0
,
Compute the spectrogram by FFT for a discrete real signal.
min_duration
=
0.0
):
Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
"""
"""
Audio data reader creator.
# extract strided windows
truncate_size
=
(
len
(
samples
)
-
window_size
)
%
stride_size
Instance: a tuple of a numpy ndarray of audio spectrogram and a list of
samples
=
samples
[:
len
(
samples
)
-
truncate_size
]
tokenized transcription text.
nshape
=
(
window_size
,
(
len
(
samples
)
-
window_size
)
//
stride_size
+
1
)
nstrides
=
(
samples
.
strides
[
0
],
samples
.
strides
[
0
]
*
stride_size
)
:param manifest_path: Filepath for Manifest of audio clip files.
windows
=
np
.
lib
.
stride_tricks
.
as_strided
(
:type manifest_path: basestring
samples
,
shape
=
nshape
,
strides
=
nstrides
)
:param sort_by_duration: Sort the audio clips by duration if set True.
assert
np
.
all
(
For SortaGrad.
windows
[:,
1
]
==
samples
[
stride_size
:(
stride_size
+
window_size
)])
:type sort_by_duration: bool
# window weighting, squared Fast Fourier Transform (fft), scaling
:param shuffle: Shuffle the audio clips if set True.
weighting
=
np
.
hanning
(
window_size
)[:,
None
]
:type shuffle: bool
fft
=
np
.
fft
.
rfft
(
windows
*
weighting
,
axis
=
0
)
:param max_duration: Audio clips with duration (in seconds) greater than
fft
=
np
.
absolute
(
fft
)
**
2
this will be discarded.
scale
=
np
.
sum
(
weighting
**
2
)
*
sample_rate
:type max_duration: float
fft
[
1
:
-
1
,
:]
*=
(
2.0
/
scale
)
:param min_duration: Audio clips with duration (in seconds) smaller than
fft
[(
0
,
-
1
),
:]
/=
scale
this will be discarded.
# prepare fft frequency list
:type min_duration: float
freqs
=
float
(
sample_rate
)
/
window_size
*
np
.
arange
(
fft
.
shape
[
0
])
:return: Data reader function.
return
fft
,
freqs
:rtype: callable
"""
if
sort_by_duration
and
shuffle
:
sort_by_duration
=
False
logger
.
warn
(
"When shuffle set to true, "
"sort_by_duration is forced to set False."
)
vocab_dict
,
_
=
vocabulary_from_file
(
ENGLISH_CHAR_VOCAB_FILEPATH
)
def
reader
():
def
__load_vocabulary_from_file__
(
self
,
vocabulary_path
):
# read manifest
"""
manifest_data
=
[]
Load vocabulary from file.
"""
if
not
os
.
path
.
exists
(
vocabulary_path
):
raise
ValueError
(
"Vocabulary file %s not found."
,
vocabulary_path
)
vocab_lines
=
[]
with
open
(
vocabulary_path
,
'r'
)
as
file
:
vocab_lines
.
extend
(
file
.
readlines
())
vocab_list
=
[
line
[:
-
1
]
for
line
in
vocab_lines
]
vocab_dict
=
dict
(
[(
token
,
id
)
for
(
id
,
token
)
in
enumerate
(
vocab_list
)])
return
vocab_dict
,
vocab_list
def
__convert_text_to_char_index__
(
self
,
text
,
vocabulary
):
"""
Convert text string to a list of character index integers.
"""
return
[
vocabulary
[
w
]
for
w
in
text
]
def
__read_manifest__
(
self
,
manifest_path
,
max_duration
,
min_duration
):
"""
Load and parse manifest file.
"""
manifest
=
[]
for
json_line
in
open
(
manifest_path
):
for
json_line
in
open
(
manifest_path
):
try
:
try
:
json_data
=
json
.
loads
(
json_line
)
json_data
=
json
.
loads
(
json_line
)
...
@@ -153,63 +212,172 @@ def reader_creator(manifest_path,
...
@@ -153,63 +212,172 @@ def reader_creator(manifest_path,
raise
ValueError
(
"Error reading manifest: %s"
%
str
(
e
))
raise
ValueError
(
"Error reading manifest: %s"
%
str
(
e
))
if
(
json_data
[
"duration"
]
<=
max_duration
and
if
(
json_data
[
"duration"
]
<=
max_duration
and
json_data
[
"duration"
]
>=
min_duration
):
json_data
[
"duration"
]
>=
min_duration
):
manifest_data
.
append
(
json_data
)
manifest
.
append
(
json_data
)
# sort (by duration) or shuffle manifest
return
manifest
if
sort_by_duration
:
manifest_data
.
sort
(
key
=
lambda
x
:
x
[
"duration"
])
if
shuffle
:
random
.
shuffle
(
manifest_data
)
# extract spectrogram feature
for
instance
in
manifest_data
:
spectrogram
=
spectrogram_from_file
(
instance
[
"audio_filepath"
])
text
=
parse_transcript
(
instance
[
"text"
],
vocab_dict
)
yield
(
spectrogram
,
text
)
return
reader
def
__padding_batch__
(
self
,
batch
,
padding_to
=-
1
,
flatten
=
False
):
"""
Padding audio part of features (only in the time axis -- column axis)
with zeros, to make each instance in the batch share the same
audio feature shape.
If `padding_to` is set -1, the maximun column numbers in the batch will
be used as the target size. Otherwise, `padding_to` will be the target
size. Default is -1.
def
padding_batch_reader
(
batch_reader
,
padding
=
[
-
1
,
-
1
],
flatten
=
True
):
If `flatten` is set True, audio data will be flatten to be a 1-dim
"""
ndarray. Default is False.
Padding for batches. Return a batch reader.
"""
Each instance in a batch will be padded to be of a same target shape.
The target shape is the largest shape among all the batch instances and
'padding' argument. Therefore, if padding is set [-1, -1], instance will be
padded to have the same shape just within each batch and the shape will
be different across batches; if padding is set
[VERY_LARGE_NUM, VERY_LARGE_NUM], instances in all batches will be padded to
have the same shape of [VERY_LARGE_NUM, VERY_LARGE_NUM].
:param batch_reader: Input batch reader.
:type batch_reader: callable
:param padding: Padding pattern. Details please refer to the above.
:type padding: list
:param flatten: Flatten the tensor to be one dimension.
:type flatten: bool
:return: Batch reader function.
:rtype: callable
"""
def
padding_batch
(
batch
):
new_batch
=
[]
new_batch
=
[]
# get target shape within batch
# get target shape
nshape_list
=
[
padding
]
max_length
=
max
([
audio
.
shape
[
1
]
for
audio
,
text
in
batch
])
for
audio
,
text
in
batch
:
if
padding_to
!=
-
1
:
nshape_list
.
append
(
audio
.
shape
)
if
padding_to
<
max_length
:
target_shape
=
np
.
array
(
nshape_list
).
max
(
axis
=
0
)
raise
ValueError
(
"If padding_to is not -1, it should be greater"
" or equal to the original instance length."
)
max_length
=
padding_to
# padding
# padding
for
audio
,
text
in
batch
:
for
audio
,
text
in
batch
:
pad_shape
=
target_shape
-
audio
.
shape
padded_audio
=
np
.
zeros
([
audio
.
shape
[
0
],
max_length
])
assert
np
.
all
(
pad_shape
>=
0
)
padded_audio
[:,
:
audio
.
shape
[
1
]]
=
audio
padded_audio
=
np
.
pad
(
audio
,
[(
0
,
pad_shape
[
0
]),
(
0
,
pad_shape
[
1
])],
mode
=
"constant"
)
if
flatten
:
if
flatten
:
padded_audio
=
padded_audio
.
flatten
()
padded_audio
=
padded_audio
.
flatten
()
new_batch
.
append
((
padded_audio
,
text
))
new_batch
.
append
((
padded_audio
,
text
))
return
new_batch
return
new_batch
def
new_batch_reader
():
def
instance_reader_creator
(
self
,
for
batch
in
batch_reader
():
manifest_path
,
yield
padding_batch
(
batch
)
sort_by_duration
=
True
,
shuffle
=
False
):
"""
Instance reader creator for audio data. Creat a callable function to
produce instances of data.
Instance: a tuple of a numpy ndarray of audio spectrogram and a list of
tokenized and indexed transcription text.
:param manifest_path: Filepath of manifest for audio clip files.
:type manifest_path: basestring
:param sort_by_duration: Sort the audio clips by duration if set True
(for SortaGrad).
:type sort_by_duration: bool
:param shuffle: Shuffle the audio clips if set True.
:type shuffle: bool
:return: Data reader function.
:rtype: callable
"""
if
sort_by_duration
and
shuffle
:
sort_by_duration
=
False
logger
.
warn
(
"When shuffle set to true, "
"sort_by_duration is forced to set False."
)
def
reader
():
# read manifest
manifest
=
self
.
__read_manifest__
(
manifest_path
=
manifest_path
,
max_duration
=
self
.
__max_duration__
,
min_duration
=
self
.
__min_duration__
)
# sort (by duration) or shuffle manifest
if
sort_by_duration
:
manifest
.
sort
(
key
=
lambda
x
:
x
[
"duration"
])
if
shuffle
:
self
.
__random__
.
shuffle
(
manifest
)
# extract spectrogram feature
for
instance
in
manifest
:
spectrogram
=
self
.
__audio_featurize__
(
instance
[
"audio_filepath"
])
transcript
=
self
.
__text_featurize__
(
instance
[
"text"
])
yield
(
spectrogram
,
transcript
)
return
reader
def
batch_reader_creator
(
self
,
manifest_path
,
batch_size
,
padding_to
=-
1
,
flatten
=
False
,
sort_by_duration
=
True
,
shuffle
=
False
):
"""
Batch data reader creator for audio data. Creat a callable function to
produce batches of data.
Audio features will be padded with zeros to make each instance in the
batch to share the same audio feature shape.
:param manifest_path: Filepath of manifest for audio clip files.
:type manifest_path: basestring
:param batch_size: Instance number in a batch.
:type batch_size: int
:param padding_to: If set -1, the maximun column numbers in the batch
will be used as the target size for padding.
Otherwise, `padding_to` will be the target size.
Default is -1.
:type padding_to: int
:param flatten: If set True, audio data will be flatten to be a 1-dim
ndarray. Otherwise, 2-dim ndarray. Default is False.
:type flatten: bool
:param sort_by_duration: Sort the audio clips by duration if set True
(for SortaGrad).
:type sort_by_duration: bool
:param shuffle: Shuffle the audio clips if set True.
:type shuffle: bool
:return: Batch reader function, producing batches of data when called.
:rtype: callable
"""
def
batch_reader
():
instance_reader
=
self
.
instance_reader_creator
(
manifest_path
=
manifest_path
,
sort_by_duration
=
sort_by_duration
,
shuffle
=
shuffle
)
batch
=
[]
for
instance
in
instance_reader
():
batch
.
append
(
instance
)
if
len
(
batch
)
==
batch_size
:
yield
self
.
__padding_batch__
(
batch
,
padding_to
,
flatten
)
batch
=
[]
if
len
(
batch
)
>
0
:
yield
self
.
__padding_batch__
(
batch
,
padding_to
,
flatten
)
return
batch_reader
def
vocabulary_size
(
self
):
"""
Get vocabulary size.
:return: Vocabulary size.
:rtype: int
"""
return
len
(
self
.
__vocab_list__
)
def
vocabulary_dict
(
self
):
"""
Get vocabulary in dict.
:return: Vocabulary in dict.
:rtype: dict
"""
return
self
.
__vocab_dict__
def
vocabulary_list
(
self
):
"""
Get vocabulary in list.
:return: Vocabulary in list
:rtype: list
"""
return
self
.
__vocab_list__
def
data_name_feeding
(
self
):
"""
Get feeddings (data field name and corresponding field id).
return
new_batch_reader
:return: Feeding dict.
:rtype: dict
"""
feeding
=
{
"audio_spectrogram"
:
0
,
"transcript_text"
:
1
,
}
return
feeding
train.py
浏览文件 @
e6a34999
...
@@ -5,16 +5,18 @@
...
@@ -5,16 +5,18 @@
import
paddle.v2
as
paddle
import
paddle.v2
as
paddle
import
argparse
import
argparse
import
gzip
import
gzip
import
time
import
sys
import
sys
from
model
import
deep_speech2
from
model
import
deep_speech2
import
audio_data_utils
from
audio_data_utils
import
DataGenerator
import
numpy
as
np
#TODO: add WER metric
#TODO: add WER metric
parser
=
argparse
.
ArgumentParser
(
parser
=
argparse
.
ArgumentParser
(
description
=
'Simplified version of DeepSpeech2 trainer.'
)
description
=
'Simplified version of DeepSpeech2 trainer.'
)
parser
.
add_argument
(
parser
.
add_argument
(
"--batch_size"
,
default
=
51
2
,
type
=
int
,
help
=
"Minibatch size."
)
"--batch_size"
,
default
=
3
2
,
type
=
int
,
help
=
"Minibatch size."
)
parser
.
add_argument
(
"--trainer"
,
default
=
1
,
type
=
int
,
help
=
"Trainer number."
)
parser
.
add_argument
(
"--trainer"
,
default
=
1
,
type
=
int
,
help
=
"Trainer number."
)
parser
.
add_argument
(
parser
.
add_argument
(
"--num_passes"
,
default
=
20
,
type
=
int
,
help
=
"Training pass number."
)
"--num_passes"
,
default
=
20
,
type
=
int
,
help
=
"Training pass number."
)
...
@@ -23,7 +25,7 @@ parser.add_argument(
...
@@ -23,7 +25,7 @@ parser.add_argument(
parser
.
add_argument
(
parser
.
add_argument
(
"--num_rnn_layers"
,
default
=
5
,
type
=
int
,
help
=
"RNN layer number."
)
"--num_rnn_layers"
,
default
=
5
,
type
=
int
,
help
=
"RNN layer number."
)
parser
.
add_argument
(
parser
.
add_argument
(
"--rnn_layer_size"
,
default
=
256
,
type
=
int
,
help
=
"RNN layer cell number."
)
"--rnn_layer_size"
,
default
=
512
,
type
=
int
,
help
=
"RNN layer cell number."
)
parser
.
add_argument
(
parser
.
add_argument
(
"--use_gpu"
,
default
=
True
,
type
=
bool
,
help
=
"Use gpu or not."
)
"--use_gpu"
,
default
=
True
,
type
=
bool
,
help
=
"Use gpu or not."
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -37,13 +39,45 @@ def train():
...
@@ -37,13 +39,45 @@ def train():
"""
"""
DeepSpeech2 training.
DeepSpeech2 training.
"""
"""
# create data readers
data_generator
=
DataGenerator
(
vocab_filepath
=
'eng_vocab.txt'
,
normalizer_manifest_path
=
'./libri.manifest.train'
,
normalizer_num_samples
=
200
,
max_duration
=
20.0
,
min_duration
=
0.0
,
stride_ms
=
10
,
window_ms
=
20
)
train_batch_reader_sortagrad
=
data_generator
.
batch_reader_creator
(
manifest_path
=
'./libri.manifest.dev.small'
,
batch_size
=
args
.
batch_size
//
args
.
trainer
,
padding_to
=
2000
,
flatten
=
True
,
sort_by_duration
=
True
,
shuffle
=
False
)
train_batch_reader_nosortagrad
=
data_generator
.
batch_reader_creator
(
manifest_path
=
'./libri.manifest.dev.small'
,
batch_size
=
args
.
batch_size
//
args
.
trainer
,
padding_to
=
2000
,
flatten
=
True
,
sort_by_duration
=
False
,
shuffle
=
True
)
test_batch_reader
=
data_generator
.
batch_reader_creator
(
manifest_path
=
'./libri.manifest.test'
,
batch_size
=
args
.
batch_size
//
args
.
trainer
,
padding_to
=
2000
,
flatten
=
True
,
sort_by_duration
=
False
,
shuffle
=
False
)
feeding
=
data_generator
.
data_name_feeding
()
# create network config
# create network config
dict_size
=
audio_data_utils
.
get_
vocabulary_size
()
dict_size
=
data_generator
.
vocabulary_size
()
audio_data
=
paddle
.
layer
.
data
(
audio_data
=
paddle
.
layer
.
data
(
name
=
"audio_spectrogram"
,
name
=
"audio_spectrogram"
,
height
=
161
,
height
=
161
,
width
=
1
000
,
width
=
2
000
,
type
=
paddle
.
data_type
.
dense_vector
(
161
000
))
type
=
paddle
.
data_type
.
dense_vector
(
322
000
))
text_data
=
paddle
.
layer
.
data
(
text_data
=
paddle
.
layer
.
data
(
name
=
"transcript_text"
,
name
=
"transcript_text"
,
type
=
paddle
.
data_type
.
integer_value_sequence
(
dict_size
))
type
=
paddle
.
data_type
.
integer_value_sequence
(
dict_size
))
...
@@ -58,47 +92,26 @@ def train():
...
@@ -58,47 +92,26 @@ def train():
# create parameters and optimizer
# create parameters and optimizer
parameters
=
paddle
.
parameters
.
create
(
cost
)
parameters
=
paddle
.
parameters
.
create
(
cost
)
optimizer
=
paddle
.
optimizer
.
Adam
(
optimizer
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
5e-
4
,
gradient_clipping_threshold
=
400
)
learning_rate
=
5e-
5
,
gradient_clipping_threshold
=
400
)
trainer
=
paddle
.
trainer
.
SGD
(
trainer
=
paddle
.
trainer
.
SGD
(
cost
=
cost
,
parameters
=
parameters
,
update_equation
=
optimizer
)
cost
=
cost
,
parameters
=
parameters
,
update_equation
=
optimizer
)
# create data readers
feeding
=
{
"audio_spectrogram"
:
0
,
"transcript_text"
:
1
,
}
train_batch_reader_with_sortagrad
=
audio_data_utils
.
padding_batch_reader
(
paddle
.
batch
(
audio_data_utils
.
reader_creator
(
manifest_path
=
"./libri.manifest.train"
,
sort_by_duration
=
True
),
batch_size
=
args
.
batch_size
//
args
.
trainer
),
padding
=
[
-
1
,
1000
])
train_batch_reader_without_sortagrad
=
audio_data_utils
.
padding_batch_reader
(
paddle
.
batch
(
audio_data_utils
.
reader_creator
(
manifest_path
=
"./libri.manifest.train"
,
sort_by_duration
=
False
,
shuffle
=
True
),
batch_size
=
args
.
batch_size
//
args
.
trainer
),
padding
=
[
-
1
,
1000
])
test_batch_reader
=
audio_data_utils
.
padding_batch_reader
(
paddle
.
batch
(
audio_data_utils
.
reader_creator
(
manifest_path
=
"./libri.manifest.dev"
,
sort_by_duration
=
False
),
batch_size
=
args
.
batch_size
//
args
.
trainer
),
padding
=
[
-
1
,
1000
])
# create event handler
# create event handler
def
event_handler
(
event
):
def
event_handler
(
event
):
global
start_time
if
isinstance
(
event
,
paddle
.
event
.
EndIteration
):
if
isinstance
(
event
,
paddle
.
event
.
EndIteration
):
if
event
.
batch_id
%
10
==
0
:
if
event
.
batch_id
%
10
==
0
:
print
"
/
nPass: %d, Batch: %d, TrainCost: %f"
%
(
print
"
\
n
Pass: %d, Batch: %d, TrainCost: %f"
%
(
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
)
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
)
else
:
else
:
sys
.
stdout
.
write
(
'.'
)
sys
.
stdout
.
write
(
'.'
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
if
isinstance
(
event
,
paddle
.
event
.
BeginPass
):
start_time
=
time
.
time
()
if
isinstance
(
event
,
paddle
.
event
.
EndPass
):
if
isinstance
(
event
,
paddle
.
event
.
EndPass
):
result
=
trainer
.
test
(
reader
=
test_batch_reader
,
feeding
=
feeding
)
result
=
trainer
.
test
(
reader
=
test_batch_reader
,
feeding
=
feeding
)
print
"Pass: %d, TestCost: %s"
%
(
event
.
pass_id
,
result
.
cost
)
print
"
\n
------- Time: %d, Pass: %d, TestCost: %s"
%
(
time
.
time
()
-
start_time
,
event
.
pass_id
,
result
.
cost
)
with
gzip
.
open
(
"params.tar.gz"
,
'w'
)
as
f
:
with
gzip
.
open
(
"params.tar.gz"
,
'w'
)
as
f
:
parameters
.
to_tar
(
f
)
parameters
.
to_tar
(
f
)
...
@@ -106,14 +119,14 @@ def train():
...
@@ -106,14 +119,14 @@ def train():
# first pass with sortagrad
# first pass with sortagrad
if
args
.
use_sortagrad
:
if
args
.
use_sortagrad
:
trainer
.
train
(
trainer
.
train
(
reader
=
train_batch_reader_
with_
sortagrad
,
reader
=
train_batch_reader_sortagrad
,
event_handler
=
event_handler
,
event_handler
=
event_handler
,
num_passes
=
1
,
num_passes
=
1
,
feeding
=
feeding
)
feeding
=
feeding
)
args
.
num_passes
-=
1
args
.
num_passes
-=
1
# other passes without sortagrad
# other passes without sortagrad
trainer
.
train
(
trainer
.
train
(
reader
=
train_batch_reader_
without_
sortagrad
,
reader
=
train_batch_reader_
no
sortagrad
,
event_handler
=
event_handler
,
event_handler
=
event_handler
,
num_passes
=
args
.
num_passes
,
num_passes
=
args
.
num_passes
,
feeding
=
feeding
)
feeding
=
feeding
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录