Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
7324d41e
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7324d41e
编写于
6月 26, 2017
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/models
into ctc_decoder_dev
上级
80338456
cdd52ac2
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
180 addition
and
83 deletion
+180
-83
README.md
README.md
+2
-2
data_utils/audio.py
data_utils/audio.py
+91
-66
data_utils/augmentor/augmentation.py
data_utils/augmentor/augmentation.py
+3
-0
data_utils/augmentor/volume_perturb.py
data_utils/augmentor/volume_perturb.py
+1
-1
data_utils/data.py
data_utils/data.py
+6
-1
data_utils/featurizer/audio_featurizer.py
data_utils/featurizer/audio_featurizer.py
+40
-2
data_utils/featurizer/speech_featurizer.py
data_utils/featurizer/speech_featurizer.py
+21
-3
infer.py
infer.py
+1
-1
setup.sh
setup.sh
+3
-0
train.py
train.py
+12
-7
未找到文件。
README.md
浏览文件 @
7324d41e
...
@@ -51,13 +51,13 @@ python compute_mean_std.py --help
...
@@ -51,13 +51,13 @@ python compute_mean_std.py --help
For GPU Training:
For GPU Training:
```
```
CUDA_VISIBLE_DEVICES=0,1,2,3
python train.py --trainer_count 4
CUDA_VISIBLE_DEVICES=0,1,2,3
,4,5,6,7 python train.py
```
```
For CPU Training:
For CPU Training:
```
```
python train.py --
trainer_count 8 --
use_gpu False
python train.py --use_gpu False
```
```
More help for arguments:
More help for arguments:
...
...
data_utils/audio.py
浏览文件 @
7324d41e
...
@@ -66,6 +66,54 @@ class AudioSegment(object):
...
@@ -66,6 +66,54 @@ class AudioSegment(object):
samples
,
sample_rate
=
soundfile
.
read
(
file
,
dtype
=
'float32'
)
samples
,
sample_rate
=
soundfile
.
read
(
file
,
dtype
=
'float32'
)
return
cls
(
samples
,
sample_rate
)
return
cls
(
samples
,
sample_rate
)
@
classmethod
def
slice_from_file
(
cls
,
file
,
start
=
None
,
end
=
None
):
"""Loads a small section of an audio without having to load
the entire file into the memory which can be incredibly wasteful.
:param file: Input audio filepath or file object.
:type file: basestring|file
:param start: Start time in seconds. If start is negative, it wraps
around from the end. If not provided, this function
reads from the very beginning.
:type start: float
:param end: End time in seconds. If end is negative, it wraps around
from the end. If not provided, the default behavior is
to read to the end of the file.
:type end: float
:return: AudioSegment instance of the specified slice of the input
audio file.
:rtype: AudioSegment
:raise ValueError: If start or end is incorrectly set, e.g. out of
bounds in time.
"""
sndfile
=
soundfile
.
SoundFile
(
file
)
sample_rate
=
sndfile
.
samplerate
duration
=
float
(
len
(
sndfile
))
/
sample_rate
start
=
0.
if
start
is
None
else
start
end
=
0.
if
end
is
None
else
end
if
start
<
0.0
:
start
+=
duration
if
end
<
0.0
:
end
+=
duration
if
start
<
0.0
:
raise
ValueError
(
"The slice start position (%f s) is out of "
"bounds."
%
start
)
if
end
<
0.0
:
raise
ValueError
(
"The slice end position (%f s) is out of bounds."
%
end
)
if
start
>
end
:
raise
ValueError
(
"The slice start position (%f s) is later than "
"the slice end position (%f s)."
%
(
start
,
end
))
if
end
>
duration
:
raise
ValueError
(
"The slice end position (%f s) is out of bounds "
"(> %f s)"
%
(
end
,
duration
))
start_frame
=
int
(
start
*
sample_rate
)
end_frame
=
int
(
end
*
sample_rate
)
sndfile
.
seek
(
start_frame
)
data
=
sndfile
.
read
(
frames
=
end_frame
-
start_frame
,
dtype
=
'float32'
)
return
cls
(
data
,
sample_rate
)
@
classmethod
@
classmethod
def
from_bytes
(
cls
,
bytes
):
def
from_bytes
(
cls
,
bytes
):
"""Create audio segment from a byte string containing audio samples.
"""Create audio segment from a byte string containing audio samples.
...
@@ -105,6 +153,20 @@ class AudioSegment(object):
...
@@ -105,6 +153,20 @@ class AudioSegment(object):
samples
=
np
.
concatenate
([
seg
.
samples
for
seg
in
segments
])
samples
=
np
.
concatenate
([
seg
.
samples
for
seg
in
segments
])
return
cls
(
samples
,
sample_rate
)
return
cls
(
samples
,
sample_rate
)
@
classmethod
def
make_silence
(
cls
,
duration
,
sample_rate
):
"""Creates a silent audio segment of the given duration and sample rate.
:param duration: Length of silence in seconds.
:type duration: float
:param sample_rate: Sample rate.
:type sample_rate: float
:return: Silent AudioSegment instance of the given duration.
:rtype: AudioSegment
"""
samples
=
np
.
zeros
(
int
(
duration
*
sample_rate
))
return
cls
(
samples
,
sample_rate
)
def
to_wav_file
(
self
,
filepath
,
dtype
=
'float32'
):
def
to_wav_file
(
self
,
filepath
,
dtype
=
'float32'
):
"""Save audio segment to disk as wav file.
"""Save audio segment to disk as wav file.
...
@@ -130,68 +192,6 @@ class AudioSegment(object):
...
@@ -130,68 +192,6 @@ class AudioSegment(object):
format
=
'WAV'
,
format
=
'WAV'
,
subtype
=
subtype_map
[
dtype
])
subtype
=
subtype_map
[
dtype
])
@
classmethod
def
slice_from_file
(
cls
,
file
,
start
=
None
,
end
=
None
):
"""Loads a small section of an audio without having to load
the entire file into the memory which can be incredibly wasteful.
:param file: Input audio filepath or file object.
:type file: basestring|file
:param start: Start time in seconds. If start is negative, it wraps
around from the end. If not provided, this function
reads from the very beginning.
:type start: float
:param end: End time in seconds. If end is negative, it wraps around
from the end. If not provided, the default behavior is
to read to the end of the file.
:type end: float
:return: AudioSegment instance of the specified slice of the input
audio file.
:rtype: AudioSegment
:raise ValueError: If start or end is incorrectly set, e.g. out of
bounds in time.
"""
sndfile
=
soundfile
.
SoundFile
(
file
)
sample_rate
=
sndfile
.
samplerate
duration
=
float
(
len
(
sndfile
))
/
sample_rate
start
=
0.
if
start
is
None
else
start
end
=
0.
if
end
is
None
else
end
if
start
<
0.0
:
start
+=
duration
if
end
<
0.0
:
end
+=
duration
if
start
<
0.0
:
raise
ValueError
(
"The slice start position (%f s) is out of "
"bounds."
%
start
)
if
end
<
0.0
:
raise
ValueError
(
"The slice end position (%f s) is out of bounds."
%
end
)
if
start
>
end
:
raise
ValueError
(
"The slice start position (%f s) is later than "
"the slice end position (%f s)."
%
(
start
,
end
))
if
end
>
duration
:
raise
ValueError
(
"The slice end position (%f s) is out of bounds "
"(> %f s)"
%
(
end
,
duration
))
start_frame
=
int
(
start
*
sample_rate
)
end_frame
=
int
(
end
*
sample_rate
)
sndfile
.
seek
(
start_frame
)
data
=
sndfile
.
read
(
frames
=
end_frame
-
start_frame
,
dtype
=
'float32'
)
return
cls
(
data
,
sample_rate
)
@
classmethod
def
make_silence
(
cls
,
duration
,
sample_rate
):
"""Creates a silent audio segment of the given duration and sample rate.
:param duration: Length of silence in seconds.
:type duration: float
:param sample_rate: Sample rate.
:type sample_rate: float
:return: Silent AudioSegment instance of the given duration.
:rtype: AudioSegment
"""
samples
=
np
.
zeros
(
int
(
duration
*
sample_rate
))
return
cls
(
samples
,
sample_rate
)
def
superimpose
(
self
,
other
):
def
superimpose
(
self
,
other
):
"""Add samples from another segment to those of this segment
"""Add samples from another segment to those of this segment
(sample-wise addition, not segment concatenation).
(sample-wise addition, not segment concatenation).
...
@@ -225,7 +225,7 @@ class AudioSegment(object):
...
@@ -225,7 +225,7 @@ class AudioSegment(object):
samples
=
self
.
_convert_samples_from_float32
(
self
.
_samples
,
dtype
)
samples
=
self
.
_convert_samples_from_float32
(
self
.
_samples
,
dtype
)
return
samples
.
tostring
()
return
samples
.
tostring
()
def
apply_gain
(
self
,
gain
):
def
gain_db
(
self
,
gain
):
"""Apply gain in decibels to samples.
"""Apply gain in decibels to samples.
Note that this is an in-place transformation.
Note that this is an in-place transformation.
...
@@ -278,7 +278,7 @@ class AudioSegment(object):
...
@@ -278,7 +278,7 @@ class AudioSegment(object):
"Unable to normalize segment to %f dB because the "
"Unable to normalize segment to %f dB because the "
"the probable gain have exceeds max_gain_db (%f dB)"
%
"the probable gain have exceeds max_gain_db (%f dB)"
%
(
target_db
,
max_gain_db
))
(
target_db
,
max_gain_db
))
self
.
apply_gain
(
min
(
max_gain_db
,
target_db
-
self
.
rms_db
))
self
.
gain_db
(
min
(
max_gain_db
,
target_db
-
self
.
rms_db
))
def
normalize_online_bayesian
(
self
,
def
normalize_online_bayesian
(
self
,
target_db
,
target_db
,
...
@@ -319,7 +319,7 @@ class AudioSegment(object):
...
@@ -319,7 +319,7 @@ class AudioSegment(object):
rms_estimate_db
=
10
*
np
.
log10
(
mean_squared_estimate
)
rms_estimate_db
=
10
*
np
.
log10
(
mean_squared_estimate
)
# Compute required time-varying gain.
# Compute required time-varying gain.
gain_db
=
target_db
-
rms_estimate_db
gain_db
=
target_db
-
rms_estimate_db
self
.
apply_gain
(
gain_db
)
self
.
gain_db
(
gain_db
)
def
resample
(
self
,
target_sample_rate
,
quality
=
'sinc_medium'
):
def
resample
(
self
,
target_sample_rate
,
quality
=
'sinc_medium'
):
"""Resample the audio to a target sample rate.
"""Resample the audio to a target sample rate.
...
@@ -366,6 +366,31 @@ class AudioSegment(object):
...
@@ -366,6 +366,31 @@ class AudioSegment(object):
raise
ValueError
(
"Unknown value for the sides %s"
%
sides
)
raise
ValueError
(
"Unknown value for the sides %s"
%
sides
)
self
.
_samples
=
padded
.
_samples
self
.
_samples
=
padded
.
_samples
def
shift
(
self
,
shift_ms
):
"""Shift the audio in time. If `shift_ms` is positive, shift with time
advance; if negative, shift with time delay. Silence is padded to
keep the duration unchanged.
Note that this is an in-place transformation.
:param shift_ms: Shift time in milliseconds. If positive, shift with
time advance; if negative, shift with time delay.
:type shift_ms: float
:raises ValueError: If shift_ms is longer than audio duration.
"""
if
abs
(
shift_ms
)
/
1000.0
>
self
.
duration
:
raise
ValueError
(
"Absolute value of shift_ms should be smaller "
"than audio duration."
)
shift_samples
=
int
(
shift_ms
*
self
.
_sample_rate
/
1000
)
if
shift_samples
>
0
:
# time advance
self
.
_samples
[:
-
shift_samples
]
=
self
.
_samples
[
shift_samples
:]
self
.
_samples
[
-
shift_samples
:]
=
0
elif
shift_samples
<
0
:
# time delay
self
.
_samples
[
-
shift_samples
:]
=
self
.
_samples
[:
shift_samples
]
self
.
_samples
[:
-
shift_samples
]
=
0
def
subsegment
(
self
,
start_sec
=
None
,
end_sec
=
None
):
def
subsegment
(
self
,
start_sec
=
None
,
end_sec
=
None
):
"""Cut the AudioSegment between given boundaries.
"""Cut the AudioSegment between given boundaries.
...
@@ -505,7 +530,7 @@ class AudioSegment(object):
...
@@ -505,7 +530,7 @@ class AudioSegment(object):
noise_gain_db
=
min
(
self
.
rms_db
-
noise
.
rms_db
-
snr_dB
,
max_gain_db
)
noise_gain_db
=
min
(
self
.
rms_db
-
noise
.
rms_db
-
snr_dB
,
max_gain_db
)
noise_new
=
copy
.
deepcopy
(
noise
)
noise_new
=
copy
.
deepcopy
(
noise
)
noise_new
.
random_subsegment
(
self
.
duration
,
rng
=
rng
)
noise_new
.
random_subsegment
(
self
.
duration
,
rng
=
rng
)
noise_new
.
apply_gain
(
noise_gain_db
)
noise_new
.
gain_db
(
noise_gain_db
)
self
.
superimpose
(
noise_new
)
self
.
superimpose
(
noise_new
)
@
property
@
property
...
...
data_utils/augmentor/augmentation.py
浏览文件 @
7324d41e
...
@@ -6,6 +6,7 @@ from __future__ import print_function
...
@@ -6,6 +6,7 @@ from __future__ import print_function
import
json
import
json
import
random
import
random
from
data_utils.augmentor.volume_perturb
import
VolumePerturbAugmentor
from
data_utils.augmentor.volume_perturb
import
VolumePerturbAugmentor
from
data_utils.augmentor.shift_perturb
import
ShiftPerturbAugmentor
class
AugmentationPipeline
(
object
):
class
AugmentationPipeline
(
object
):
...
@@ -76,5 +77,7 @@ class AugmentationPipeline(object):
...
@@ -76,5 +77,7 @@ class AugmentationPipeline(object):
"""Return an augmentation model by the type name, and pass in params."""
"""Return an augmentation model by the type name, and pass in params."""
if
augmentor_type
==
"volume"
:
if
augmentor_type
==
"volume"
:
return
VolumePerturbAugmentor
(
self
.
_rng
,
**
params
)
return
VolumePerturbAugmentor
(
self
.
_rng
,
**
params
)
elif
augmentor_type
==
"shift"
:
return
ShiftPerturbAugmentor
(
self
.
_rng
,
**
params
)
else
:
else
:
raise
ValueError
(
"Unknown augmentor type [%s]."
%
augmentor_type
)
raise
ValueError
(
"Unknown augmentor type [%s]."
%
augmentor_type
)
data_utils/augmentor/volume_perturb.py
浏览文件 @
7324d41e
...
@@ -36,5 +36,5 @@ class VolumePerturbAugmentor(AugmentorBase):
...
@@ -36,5 +36,5 @@ class VolumePerturbAugmentor(AugmentorBase):
:param audio_segment: Audio segment to add effects to.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegment|SpeechSegment
:type audio_segment: AudioSegment|SpeechSegment
"""
"""
gain
=
self
.
_rng
.
uniform
(
min_gain_dBFS
,
max_gain_dBFS
)
gain
=
self
.
_rng
.
uniform
(
self
.
_min_gain_dBFS
,
self
.
_
max_gain_dBFS
)
audio_segment
.
apply_gain
(
gain
)
audio_segment
.
apply_gain
(
gain
)
data_utils/data.py
浏览文件 @
7324d41e
...
@@ -45,6 +45,9 @@ class DataGenerator(object):
...
@@ -45,6 +45,9 @@ class DataGenerator(object):
:types max_freq: None|float
:types max_freq: None|float
:param specgram_type: Specgram feature type. Options: 'linear'.
:param specgram_type: Specgram feature type. Options: 'linear'.
:type specgram_type: str
:type specgram_type: str
:param use_dB_normalization: Whether to normalize the audio to -20 dB
before extracting the features.
:type use_dB_normalization: bool
:param num_threads: Number of CPU threads for processing data.
:param num_threads: Number of CPU threads for processing data.
:type num_threads: int
:type num_threads: int
:param random_seed: Random seed.
:param random_seed: Random seed.
...
@@ -61,6 +64,7 @@ class DataGenerator(object):
...
@@ -61,6 +64,7 @@ class DataGenerator(object):
window_ms
=
20.0
,
window_ms
=
20.0
,
max_freq
=
None
,
max_freq
=
None
,
specgram_type
=
'linear'
,
specgram_type
=
'linear'
,
use_dB_normalization
=
True
,
num_threads
=
multiprocessing
.
cpu_count
(),
num_threads
=
multiprocessing
.
cpu_count
(),
random_seed
=
0
):
random_seed
=
0
):
self
.
_max_duration
=
max_duration
self
.
_max_duration
=
max_duration
...
@@ -73,7 +77,8 @@ class DataGenerator(object):
...
@@ -73,7 +77,8 @@ class DataGenerator(object):
specgram_type
=
specgram_type
,
specgram_type
=
specgram_type
,
stride_ms
=
stride_ms
,
stride_ms
=
stride_ms
,
window_ms
=
window_ms
,
window_ms
=
window_ms
,
max_freq
=
max_freq
)
max_freq
=
max_freq
,
use_dB_normalization
=
use_dB_normalization
)
self
.
_num_threads
=
num_threads
self
.
_num_threads
=
num_threads
self
.
_rng
=
random
.
Random
(
random_seed
)
self
.
_rng
=
random
.
Random
(
random_seed
)
self
.
_epoch
=
0
self
.
_epoch
=
0
...
...
data_utils/featurizer/audio_featurizer.py
浏览文件 @
7324d41e
...
@@ -24,26 +24,64 @@ class AudioFeaturizer(object):
...
@@ -24,26 +24,64 @@ class AudioFeaturizer(object):
corresponding to frequencies between [0, max_freq] are
corresponding to frequencies between [0, max_freq] are
returned.
returned.
:types max_freq: None|float
:types max_freq: None|float
:param target_sample_rate: Audio is resampled (if upsampling or
downsampling is allowed) to this before
extracting spectrogram features.
:type target_sample_rate: float
:param use_dB_normalization: Whether to normalize the audio to a certain
decibels before extracting the features.
:type use_dB_normalization: bool
:param target_dB: Target audio decibels for normalization.
:type target_dB: float
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
specgram_type
=
'linear'
,
specgram_type
=
'linear'
,
stride_ms
=
10.0
,
stride_ms
=
10.0
,
window_ms
=
20.0
,
window_ms
=
20.0
,
max_freq
=
None
):
max_freq
=
None
,
target_sample_rate
=
16000
,
use_dB_normalization
=
True
,
target_dB
=-
20
):
self
.
_specgram_type
=
specgram_type
self
.
_specgram_type
=
specgram_type
self
.
_stride_ms
=
stride_ms
self
.
_stride_ms
=
stride_ms
self
.
_window_ms
=
window_ms
self
.
_window_ms
=
window_ms
self
.
_max_freq
=
max_freq
self
.
_max_freq
=
max_freq
self
.
_target_sample_rate
=
target_sample_rate
self
.
_use_dB_normalization
=
use_dB_normalization
self
.
_target_dB
=
target_dB
def
featurize
(
self
,
audio_segment
):
def
featurize
(
self
,
audio_segment
,
allow_downsampling
=
True
,
allow_upsamplling
=
True
):
"""Extract audio features from AudioSegment or SpeechSegment.
"""Extract audio features from AudioSegment or SpeechSegment.
:param audio_segment: Audio/speech segment to extract features from.
:param audio_segment: Audio/speech segment to extract features from.
:type audio_segment: AudioSegment|SpeechSegment
:type audio_segment: AudioSegment|SpeechSegment
:param allow_downsampling: Whether to allow audio downsampling before
featurizing.
:type allow_downsampling: bool
:param allow_upsampling: Whether to allow audio upsampling before
featurizing.
:type allow_upsampling: bool
:return: Spectrogram audio feature in 2darray.
:return: Spectrogram audio feature in 2darray.
:rtype: ndarray
:rtype: ndarray
:raises ValueError: If audio sample rate is not supported.
"""
"""
# upsampling or downsampling
if
((
audio_segment
.
sample_rate
>
self
.
_target_sample_rate
and
allow_downsampling
)
or
(
audio_segment
.
sample_rate
<
self
.
_target_sample_rate
and
allow_upsampling
)):
audio_segment
.
resample
(
self
.
_target_sample_rate
)
if
audio_segment
.
sample_rate
!=
self
.
_target_sample_rate
:
raise
ValueError
(
"Audio sample rate is not supported. "
"Turn allow_downsampling or allow up_sampling on."
)
# decibel normalization
if
self
.
_use_dB_normalization
:
audio_segment
.
normalize
(
target_db
=
self
.
_target_dB
)
# extract spectrogram
return
self
.
_compute_specgram
(
audio_segment
.
samples
,
return
self
.
_compute_specgram
(
audio_segment
.
samples
,
audio_segment
.
sample_rate
)
audio_segment
.
sample_rate
)
...
...
data_utils/featurizer/speech_featurizer.py
浏览文件 @
7324d41e
...
@@ -29,6 +29,15 @@ class SpeechFeaturizer(object):
...
@@ -29,6 +29,15 @@ class SpeechFeaturizer(object):
corresponding to frequencies between [0, max_freq] are
corresponding to frequencies between [0, max_freq] are
returned.
returned.
:types max_freq: None|float
:types max_freq: None|float
:param target_sample_rate: Speech is resampled (if upsampling or
downsampling is allowed) to this before
extracting spectrogram features.
:type target_sample_rate: float
:param use_dB_normalization: Whether to normalize the audio to a certain
decibels before extracting the features.
:type use_dB_normalization: bool
:param target_dB: Target audio decibels for normalization.
:type target_dB: float
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -36,9 +45,18 @@ class SpeechFeaturizer(object):
...
@@ -36,9 +45,18 @@ class SpeechFeaturizer(object):
specgram_type
=
'linear'
,
specgram_type
=
'linear'
,
stride_ms
=
10.0
,
stride_ms
=
10.0
,
window_ms
=
20.0
,
window_ms
=
20.0
,
max_freq
=
None
):
max_freq
=
None
,
self
.
_audio_featurizer
=
AudioFeaturizer
(
specgram_type
,
stride_ms
,
target_sample_rate
=
16000
,
window_ms
,
max_freq
)
use_dB_normalization
=
True
,
target_dB
=-
20
):
self
.
_audio_featurizer
=
AudioFeaturizer
(
specgram_type
=
specgram_type
,
stride_ms
=
stride_ms
,
window_ms
=
window_ms
,
max_freq
=
max_freq
,
target_sample_rate
=
target_sample_rate
,
use_dB_normalization
=
use_dB_normalization
,
target_dB
=
target_dB
)
self
.
_text_featurizer
=
TextFeaturizer
(
vocab_filepath
)
self
.
_text_featurizer
=
TextFeaturizer
(
vocab_filepath
)
def
featurize
(
self
,
speech_segment
):
def
featurize
(
self
,
speech_segment
):
...
...
infer.py
浏览文件 @
7324d41e
...
@@ -58,7 +58,7 @@ parser.add_argument(
...
@@ -58,7 +58,7 @@ parser.add_argument(
help
=
"Manifest path for decoding. (default: %(default)s)"
)
help
=
"Manifest path for decoding. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--model_filepath"
,
"--model_filepath"
,
default
=
'
./params
.tar.gz'
,
default
=
'
checkpoints/params.latest
.tar.gz'
,
type
=
str
,
type
=
str
,
help
=
"Model filepath. (default: %(default)s)"
)
help
=
"Model filepath. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
...
setup.sh
浏览文件 @
7324d41e
...
@@ -27,4 +27,7 @@ if [ $? != 0 ]; then
...
@@ -27,4 +27,7 @@ if [ $? != 0 ]; then
exit
1
exit
1
fi
fi
# prepare ./checkpoints
mkdir
checkpoints
echo
"Install all dependencies successfully."
echo
"Install all dependencies successfully."
train.py
浏览文件 @
7324d41e
...
@@ -17,10 +17,10 @@ import utils
...
@@ -17,10 +17,10 @@ import utils
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
parser
.
add_argument
(
"--batch_size"
,
default
=
32
,
type
=
int
,
help
=
"Minibatch size."
)
"--batch_size"
,
default
=
256
,
type
=
int
,
help
=
"Minibatch size."
)
parser
.
add_argument
(
parser
.
add_argument
(
"--num_passes"
,
"--num_passes"
,
default
=
20
,
default
=
20
0
,
type
=
int
,
type
=
int
,
help
=
"Training pass number. (default: %(default)s)"
)
help
=
"Training pass number. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -55,7 +55,7 @@ parser.add_argument(
...
@@ -55,7 +55,7 @@ parser.add_argument(
help
=
"Use sortagrad or not. (default: %(default)s)"
)
help
=
"Use sortagrad or not. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--max_duration"
,
"--max_duration"
,
default
=
100
.0
,
default
=
27
.0
,
type
=
float
,
type
=
float
,
help
=
"Audios with duration larger than this will be discarded. "
help
=
"Audios with duration larger than this will be discarded. "
"(default: %(default)s)"
)
"(default: %(default)s)"
)
...
@@ -67,13 +67,13 @@ parser.add_argument(
...
@@ -67,13 +67,13 @@ parser.add_argument(
"(default: %(default)s)"
)
"(default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--shuffle_method"
,
"--shuffle_method"
,
default
=
'
instance_shuffle
'
,
default
=
'
batch_shuffle_clipped
'
,
type
=
str
,
type
=
str
,
help
=
"Shuffle method: 'instance_shuffle', 'batch_shuffle', "
help
=
"Shuffle method: 'instance_shuffle', 'batch_shuffle', "
"'batch_shuffle_batch'. (default: %(default)s)"
)
"'batch_shuffle_batch'. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--trainer_count"
,
"--trainer_count"
,
default
=
4
,
default
=
8
,
type
=
int
,
type
=
int
,
help
=
"Trainer number. (default: %(default)s)"
)
help
=
"Trainer number. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -110,7 +110,9 @@ parser.add_argument(
...
@@ -110,7 +110,9 @@ parser.add_argument(
"the existing model of this path. (default: %(default)s)"
)
"the existing model of this path. (default: %(default)s)"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--augmentation_config"
,
"--augmentation_config"
,
default
=
'{}'
,
default
=
'[{"type": "shift", '
'"params": {"min_shift_ms": -5, "max_shift_ms": 5},'
'"prob": 1.0}]'
,
type
=
str
,
type
=
str
,
help
=
"Augmentation configuration in json-format. "
help
=
"Augmentation configuration in json-format. "
"(default: %(default)s)"
)
"(default: %(default)s)"
)
...
@@ -189,7 +191,7 @@ def train():
...
@@ -189,7 +191,7 @@ def train():
print
(
"
\n
Pass: %d, Batch: %d, TrainCost: %f"
%
(
print
(
"
\n
Pass: %d, Batch: %d, TrainCost: %f"
%
(
event
.
pass_id
,
event
.
batch_id
+
1
,
cost_sum
/
cost_counter
))
event
.
pass_id
,
event
.
batch_id
+
1
,
cost_sum
/
cost_counter
))
cost_sum
,
cost_counter
=
0.0
,
0
cost_sum
,
cost_counter
=
0.0
,
0
with
gzip
.
open
(
"
params
.tar.gz"
,
'w'
)
as
f
:
with
gzip
.
open
(
"
checkpoints/params.latest
.tar.gz"
,
'w'
)
as
f
:
parameters
.
to_tar
(
f
)
parameters
.
to_tar
(
f
)
else
:
else
:
sys
.
stdout
.
write
(
'.'
)
sys
.
stdout
.
write
(
'.'
)
...
@@ -202,6 +204,9 @@ def train():
...
@@ -202,6 +204,9 @@ def train():
reader
=
test_batch_reader
,
feeding
=
test_generator
.
feeding
)
reader
=
test_batch_reader
,
feeding
=
test_generator
.
feeding
)
print
(
"
\n
------- Time: %d sec, Pass: %d, ValidationCost: %s"
%
print
(
"
\n
------- Time: %d sec, Pass: %d, ValidationCost: %s"
%
(
time
.
time
()
-
start_time
,
event
.
pass_id
,
result
.
cost
))
(
time
.
time
()
-
start_time
,
event
.
pass_id
,
result
.
cost
))
with
gzip
.
open
(
"checkpoints/params.pass-%d.tar.gz"
%
event
.
pass_id
,
'w'
)
as
f
:
parameters
.
to_tar
(
f
)
# run train
# run train
trainer
.
train
(
trainer
.
train
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录