PaddlePaddle / models
Commit 68caa8ca

Authored on Jun 26, 2017 by Xinghai Sun; committed via GitHub on Jun 26, 2017.

Merge pull request #114 from xinghai-sun/ds2_feature

Improve audio featurizer and add shift augmentor for DS2.

Parents: c47f940f, d8348e24

Showing 10 changed files with 180 additions and 83 deletions.
deep_speech_2/README.md                                    +2   -2
deep_speech_2/data_utils/audio.py                          +91  -66
deep_speech_2/data_utils/augmentor/augmentation.py         +3   -0
deep_speech_2/data_utils/augmentor/volume_perturb.py       +1   -1
deep_speech_2/data_utils/data.py                           +6   -1
deep_speech_2/data_utils/featurizer/audio_featurizer.py    +40  -2
deep_speech_2/data_utils/featurizer/speech_featurizer.py   +21  -3
deep_speech_2/infer.py                                     +1   -1
deep_speech_2/setup.sh                                     +3   -0
deep_speech_2/train.py                                     +12  -7
deep_speech_2/README.md

@@ -51,13 +51,13 @@ python compute_mean_std.py --help
 For GPU Training:
 ```
-CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --trainer_count 4
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py
 ```
 For CPU Training:
 ```
-python train.py --trainer_count 8 --use_gpu False
+python train.py --use_gpu False
 ```
 More help for arguments:
deep_speech_2/data_utils/audio.py

@@ -66,6 +66,54 @@ class AudioSegment(object):
         samples, sample_rate = soundfile.read(file, dtype='float32')
         return cls(samples, sample_rate)
 
+    @classmethod
+    def slice_from_file(cls, file, start=None, end=None):
+        """Loads a small section of an audio without having to load
+        the entire file into the memory which can be incredibly wasteful.
+
+        :param file: Input audio filepath or file object.
+        :type file: basestring|file
+        :param start: Start time in seconds. If start is negative, it wraps
+                      around from the end. If not provided, this function
+                      reads from the very beginning.
+        :type start: float
+        :param end: End time in seconds. If end is negative, it wraps around
+                    from the end. If not provided, the default behvaior is
+                    to read to the end of the file.
+        :type end: float
+        :return: AudioSegment instance of the specified slice of the input
+                 audio file.
+        :rtype: AudioSegment
+        :raise ValueError: If start or end is incorrectly set, e.g. out of
+                           bounds in time.
+        """
+        sndfile = soundfile.SoundFile(file)
+        sample_rate = sndfile.samplerate
+        duration = float(len(sndfile)) / sample_rate
+        start = 0. if start is None else start
+        end = 0. if end is None else end
+        if start < 0.0:
+            start += duration
+        if end < 0.0:
+            end += duration
+        if start < 0.0:
+            raise ValueError("The slice start position (%f s) is out of "
+                             "bounds." % start)
+        if end < 0.0:
+            raise ValueError(
+                "The slice end position (%f s) is out of bounds." % end)
+        if start > end:
+            raise ValueError("The slice start position (%f s) is later than "
+                             "the slice end position (%f s)." % (start, end))
+        if end > duration:
+            raise ValueError("The slice end position (%f s) is out of bounds "
+                             "(> %f s)" % (end, duration))
+        start_frame = int(start * sample_rate)
+        end_frame = int(end * sample_rate)
+        sndfile.seek(start_frame)
+        data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
+        return cls(data, sample_rate)
+
     @classmethod
     def from_bytes(cls, bytes):
         """Create audio segment from a byte string containing audio samples.

@@ -105,6 +153,20 @@ class AudioSegment(object):
         samples = np.concatenate([seg.samples for seg in segments])
         return cls(samples, sample_rate)
 
+    @classmethod
+    def make_silence(cls, duration, sample_rate):
+        """Creates a silent audio segment of the given duration and sample rate.
+
+        :param duration: Length of silence in seconds.
+        :type duration: float
+        :param sample_rate: Sample rate.
+        :type sample_rate: float
+        :return: Silent AudioSegment instance of the given duration.
+        :rtype: AudioSegment
+        """
+        samples = np.zeros(int(duration * sample_rate))
+        return cls(samples, sample_rate)
+
     def to_wav_file(self, filepath, dtype='float32'):
         """Save audio segment to disk as wav file.

@@ -130,68 +192,6 @@ class AudioSegment(object):
                         format='WAV',
                         subtype=subtype_map[dtype])
 
-    @classmethod
-    def slice_from_file(cls, file, start=None, end=None):
-        """Loads a small section of an audio without having to load
-        the entire file into the memory which can be incredibly wasteful.
-
-        :param file: Input audio filepath or file object.
-        :type file: basestring|file
-        :param start: Start time in seconds. If start is negative, it wraps
-                      around from the end. If not provided, this function
-                      reads from the very beginning.
-        :type start: float
-        :param end: End time in seconds. If end is negative, it wraps around
-                    from the end. If not provided, the default behvaior is
-                    to read to the end of the file.
-        :type end: float
-        :return: AudioSegment instance of the specified slice of the input
-                 audio file.
-        :rtype: AudioSegment
-        :raise ValueError: If start or end is incorrectly set, e.g. out of
-                           bounds in time.
-        """
-        sndfile = soundfile.SoundFile(file)
-        sample_rate = sndfile.samplerate
-        duration = float(len(sndfile)) / sample_rate
-        start = 0. if start is None else start
-        end = 0. if end is None else end
-        if start < 0.0:
-            start += duration
-        if end < 0.0:
-            end += duration
-        if start < 0.0:
-            raise ValueError("The slice start position (%f s) is out of "
-                             "bounds." % start)
-        if end < 0.0:
-            raise ValueError(
-                "The slice end position (%f s) is out of bounds." % end)
-        if start > end:
-            raise ValueError("The slice start position (%f s) is later than "
-                             "the slice end position (%f s)." % (start, end))
-        if end > duration:
-            raise ValueError("The slice end position (%f s) is out of bounds "
-                             "(> %f s)" % (end, duration))
-        start_frame = int(start * sample_rate)
-        end_frame = int(end * sample_rate)
-        sndfile.seek(start_frame)
-        data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
-        return cls(data, sample_rate)
-
-    @classmethod
-    def make_silence(cls, duration, sample_rate):
-        """Creates a silent audio segment of the given duration and sample rate.
-
-        :param duration: Length of silence in seconds.
-        :type duration: float
-        :param sample_rate: Sample rate.
-        :type sample_rate: float
-        :return: Silent AudioSegment instance of the given duration.
-        :rtype: AudioSegment
-        """
-        samples = np.zeros(int(duration * sample_rate))
-        return cls(samples, sample_rate)
-
     def superimpose(self, other):
         """Add samples from another segment to those of this segment
         (sample-wise addition, not segment concatenation).

@@ -225,7 +225,7 @@ class AudioSegment(object):
         samples = self._convert_samples_from_float32(self._samples, dtype)
         return samples.tostring()
 
-    def apply_gain(self, gain):
+    def gain_db(self, gain):
         """Apply gain in decibels to samples.
 
         Note that this is an in-place transformation.

@@ -278,7 +278,7 @@ class AudioSegment(object):
                 "Unable to normalize segment to %f dB because the "
                 "the probable gain have exceeds max_gain_db (%f dB)" %
                 (target_db, max_gain_db))
-        self.apply_gain(min(max_gain_db, target_db - self.rms_db))
+        self.gain_db(min(max_gain_db, target_db - self.rms_db))
 
     def normalize_online_bayesian(self,
                                   target_db,

@@ -319,7 +319,7 @@ class AudioSegment(object):
             rms_estimate_db = 10 * np.log10(mean_squared_estimate)
             # Compute required time-varying gain.
             gain_db = target_db - rms_estimate_db
-            self.apply_gain(gain_db)
+            self.gain_db(gain_db)
 
     def resample(self, target_sample_rate, quality='sinc_medium'):
         """Resample the audio to a target sample rate.

@@ -366,6 +366,31 @@ class AudioSegment(object):
             raise ValueError("Unknown value for the sides %s" % sides)
         self._samples = padded._samples
 
+    def shift(self, shift_ms):
+        """Shift the audio in time. If `shift_ms` is positive, shift with time
+        advance; if negative, shift with time delay. Silence are padded to
+        keep the duration unchanged.
+
+        Note that this is an in-place transformation.
+
+        :param shift_ms: Shift time in millseconds. If positive, shift with
+                         time advance; if negative; shift with time delay.
+        :type shift_ms: float
+        :raises ValueError: If shift_ms is longer than audio duration.
+        """
+        if abs(shift_ms) / 1000.0 > self.duration:
+            raise ValueError("Absolute value of shift_ms should be smaller "
+                             "than audio duration.")
+        shift_samples = int(shift_ms * self._sample_rate / 1000)
+        if shift_samples > 0:
+            # time advance
+            self._samples[:-shift_samples] = self._samples[shift_samples:]
+            self._samples[-shift_samples:] = 0
+        elif shift_samples < 0:
+            # time delay
+            self._samples[-shift_samples:] = self._samples[:shift_samples]
+            self._samples[:-shift_samples] = 0
+
     def subsegment(self, start_sec=None, end_sec=None):
         """Cut the AudioSegment between given boundaries.

@@ -505,7 +530,7 @@ class AudioSegment(object):
         noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db)
         noise_new = copy.deepcopy(noise)
         noise_new.random_subsegment(self.duration, rng=rng)
-        noise_new.apply_gain(noise_gain_db)
+        noise_new.gain_db(noise_gain_db)
         self.superimpose(noise_new)
 
     @property
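In short, this file moves slice_from_file and make_silence earlier in the class, renames apply_gain to gain_db, and adds the in-place shift method used by the new shift augmentor. A minimal usage sketch of these methods follows; the path example.wav and the concrete numbers are placeholders, not part of the commit.

```python
# Illustrative only: exercises the AudioSegment methods added or renamed above.
# 'example.wav' is a placeholder path, not a file shipped with the commit.
from data_utils.audio import AudioSegment

# Read just the 1.0 s - 2.5 s window instead of loading the whole file.
segment = AudioSegment.slice_from_file('example.wav', start=1.0, end=2.5)

# One second of silence at the same sample rate.
silence = AudioSegment.make_silence(duration=1.0,
                                    sample_rate=segment.sample_rate)

# apply_gain() is now gain_db(); it still applies decibel gain in place.
segment.gain_db(-6.0)

# Shift 5 ms forward in time (time advance); the freed tail is zero-padded so
# the duration stays unchanged. A negative value delays the audio instead.
segment.shift(shift_ms=5.0)
```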
deep_speech_2/data_utils/augmentor/augmentation.py

@@ -6,6 +6,7 @@ from __future__ import print_function
 import json
 import random
 from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor
+from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor
 
 
 class AugmentationPipeline(object):

@@ -76,5 +77,7 @@ class AugmentationPipeline(object):
         """Return an augmentation model by the type name, and pass in params."""
         if augmentor_type == "volume":
             return VolumePerturbAugmentor(self._rng, **params)
+        elif augmentor_type == "shift":
+            return ShiftPerturbAugmentor(self._rng, **params)
         else:
             raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
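The shift_perturb module imported above is not among the files shown on this page. Below is only a rough, hypothetical sketch of what it likely contains, modeled on VolumePerturbAugmentor; the base-class import path, the transform_audio hook, and the parameter names are assumptions, with the parameter names taken from the default config in train.py further down.

```python
# Hypothetical sketch of data_utils/augmentor/shift_perturb.py, which is not
# shown in this diff. Import path, class layout, and the transform_audio hook
# are assumed to mirror volume_perturb.py; parameter names follow the default
# --augmentation_config in train.py.
from data_utils.augmentor.base import AugmentorBase


class ShiftPerturbAugmentor(AugmentorBase):
    """Augmentor that adds a random time shift to each audio segment."""

    def __init__(self, rng, min_shift_ms, max_shift_ms):
        self._min_shift_ms = min_shift_ms
        self._max_shift_ms = max_shift_ms
        self._rng = rng

    def transform_audio(self, audio_segment):
        """Shift the segment in place by a random offset in milliseconds."""
        shift_ms = self._rng.uniform(self._min_shift_ms, self._max_shift_ms)
        audio_segment.shift(shift_ms)
```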
deep_speech_2/data_utils/augmentor/volume_perturb.py

@@ -36,5 +36,5 @@ class VolumePerturbAugmentor(AugmentorBase):
         :param audio_segment: Audio segment to add effects to.
         :type audio_segment: AudioSegmenet|SpeechSegment
         """
-        gain = self._rng.uniform(min_gain_dBFS, max_gain_dBFS)
+        gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
         audio_segment.apply_gain(gain)
deep_speech_2/data_utils/data.py

@@ -45,6 +45,9 @@ class DataGenerator(object):
     :types max_freq: None|float
     :param specgram_type: Specgram feature type. Options: 'linear'.
     :type specgram_type: str
+    :param use_dB_normalization: Whether to normalize the audio to -20 dB
+                                 before extracting the features.
+    :type use_dB_normalization: bool
     :param num_threads: Number of CPU threads for processing data.
     :type num_threads: int
     :param random_seed: Random seed.

@@ -61,6 +64,7 @@ class DataGenerator(object):
                  window_ms=20.0,
                  max_freq=None,
                  specgram_type='linear',
+                 use_dB_normalization=True,
                  num_threads=multiprocessing.cpu_count(),
                  random_seed=0):
         self._max_duration = max_duration

@@ -73,7 +77,8 @@ class DataGenerator(object):
             specgram_type=specgram_type,
             stride_ms=stride_ms,
             window_ms=window_ms,
-            max_freq=max_freq)
+            max_freq=max_freq,
+            use_dB_normalization=use_dB_normalization)
         self._num_threads = num_threads
         self._rng = random.Random(random_seed)
         self._epoch = 0
deep_speech_2/data_utils/featurizer/audio_featurizer.py

@@ -24,26 +24,64 @@ class AudioFeaturizer(object):
                      corresponding to frequencies between [0, max_freq] are
                      returned.
     :types max_freq: None|float
+    :param target_sample_rate: Audio are resampled (if upsampling or
+                               downsampling is allowed) to this before
+                               extracting spectrogram features.
+    :type target_sample_rate: float
+    :param use_dB_normalization: Whether to normalize the audio to a certain
+                                 decibels before extracting the features.
+    :type use_dB_normalization: bool
+    :param target_dB: Target audio decibels for normalization.
+    :type target_dB: float
     """
 
     def __init__(self,
                  specgram_type='linear',
                  stride_ms=10.0,
                  window_ms=20.0,
-                 max_freq=None):
+                 max_freq=None,
+                 target_sample_rate=16000,
+                 use_dB_normalization=True,
+                 target_dB=-20):
         self._specgram_type = specgram_type
         self._stride_ms = stride_ms
         self._window_ms = window_ms
         self._max_freq = max_freq
+        self._target_sample_rate = target_sample_rate
+        self._use_dB_normalization = use_dB_normalization
+        self._target_dB = target_dB
 
-    def featurize(self, audio_segment):
+    def featurize(self,
+                  audio_segment,
+                  allow_downsampling=True,
+                  allow_upsamplling=True):
         """Extract audio features from AudioSegment or SpeechSegment.
 
         :param audio_segment: Audio/speech segment to extract features from.
         :type audio_segment: AudioSegment|SpeechSegment
+        :param allow_downsampling: Whether to allow audio downsampling before
+                                   featurizing.
+        :type allow_downsampling: bool
+        :param allow_upsampling: Whether to allow audio upsampling before
+                                 featurizing.
+        :type allow_upsampling: bool
         :return: Spectrogram audio feature in 2darray.
         :rtype: ndarray
+        :raises ValueError: If audio sample rate is not supported.
         """
+        # upsampling or downsampling
+        if ((audio_segment.sample_rate > self._target_sample_rate and
+             allow_downsampling) or
+            (audio_segment.sample_rate < self._target_sample_rate and
+             allow_upsampling)):
+            audio_segment.resample(self._target_sample_rate)
+        if audio_segment.sample_rate != self._target_sample_rate:
+            raise ValueError("Audio sample rate is not supported. "
+                             "Turn allow_downsampling or allow up_sampling on.")
+        # decibel normalization
+        if self._use_dB_normalization:
+            audio_segment.normalize(target_db=self._target_dB)
+        # extract spectrogram
         return self._compute_specgram(audio_segment.samples,
                                       audio_segment.sample_rate)
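With the new arguments, featurize() now optionally resamples the segment to target_sample_rate and normalizes it to target_dB before computing the spectrogram. A small sketch of the flow; the random samples and the 8 kHz source rate are placeholders, and the AudioSegment constructor call mirrors the cls(samples, sample_rate) usage shown in audio.py above.

```python
# Illustrative sketch of the new featurize() flow; not part of the commit.
import numpy as np

from data_utils.audio import AudioSegment
from data_utils.featurizer.audio_featurizer import AudioFeaturizer

# Two seconds of low-level noise at 8 kHz stands in for a real utterance.
samples = np.random.uniform(-0.1, 0.1, 8000 * 2).astype('float32')
segment = AudioSegment(samples, 8000)

featurizer = AudioFeaturizer(
    specgram_type='linear',
    target_sample_rate=16000,
    use_dB_normalization=True,
    target_dB=-20)

# With the allow_* flags left at their defaults, the 8 kHz segment is first
# upsampled to 16 kHz and normalized to -20 dB; with both flags off, a
# mismatched sample rate raises ValueError instead.
spectrogram = featurizer.featurize(segment)
```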
deep_speech_2/data_utils/featurizer/speech_featurizer.py

@@ -29,6 +29,15 @@ class SpeechFeaturizer(object):
                      corresponding to frequencies between [0, max_freq] are
                      returned.
     :types max_freq: None|float
+    :param target_sample_rate: Speech are resampled (if upsampling or
+                               downsampling is allowed) to this before
+                               extracting spectrogram features.
+    :type target_sample_rate: float
+    :param use_dB_normalization: Whether to normalize the audio to a certain
+                                 decibels before extracting the features.
+    :type use_dB_normalization: bool
+    :param target_dB: Target audio decibels for normalization.
+    :type target_dB: float
     """
 
     def __init__(self,

@@ -36,9 +45,18 @@ class SpeechFeaturizer(object):
                  specgram_type='linear',
                  stride_ms=10.0,
                  window_ms=20.0,
-                 max_freq=None):
-        self._audio_featurizer = AudioFeaturizer(specgram_type, stride_ms,
-                                                 window_ms, max_freq)
+                 max_freq=None,
+                 target_sample_rate=16000,
+                 use_dB_normalization=True,
+                 target_dB=-20):
+        self._audio_featurizer = AudioFeaturizer(
+            specgram_type=specgram_type,
+            stride_ms=stride_ms,
+            window_ms=window_ms,
+            max_freq=max_freq,
+            target_sample_rate=target_sample_rate,
+            use_dB_normalization=use_dB_normalization,
+            target_dB=target_dB)
         self._text_featurizer = TextFeaturizer(vocab_filepath)
 
     def featurize(self, speech_segment):
deep_speech_2/infer.py

@@ -56,7 +56,7 @@ parser.add_argument(
     help="Manifest path for decoding. (default: %(default)s)")
 parser.add_argument(
     "--model_filepath",
-    default='./params.tar.gz',
+    default='checkpoints/params.latest.tar.gz',
     type=str,
     help="Model filepath. (default: %(default)s)")
 parser.add_argument(
deep_speech_2/setup.sh

@@ -27,4 +27,7 @@ if [ $? != 0 ]; then
     exit 1
 fi
 
+# prepare ./checkpoints
+mkdir checkpoints
+
 echo "Install all dependencies successfully."
deep_speech_2/train.py

@@ -17,10 +17,10 @@ import utils
 parser = argparse.ArgumentParser(description=__doc__)
 parser.add_argument(
-    "--batch_size", default=32, type=int, help="Minibatch size.")
+    "--batch_size", default=256, type=int, help="Minibatch size.")
 parser.add_argument(
     "--num_passes",
-    default=20,
+    default=200,
     type=int,
     help="Training pass number. (default: %(default)s)")
 parser.add_argument(

@@ -55,7 +55,7 @@ parser.add_argument(
     help="Use sortagrad or not. (default: %(default)s)")
 parser.add_argument(
     "--max_duration",
-    default=100.0,
+    default=27.0,
     type=float,
     help="Audios with duration larger than this will be discarded. "
     "(default: %(default)s)")

@@ -67,13 +67,13 @@ parser.add_argument(
     "(default: %(default)s)")
 parser.add_argument(
     "--shuffle_method",
-    default='instance_shuffle',
+    default='batch_shuffle_clipped',
     type=str,
     help="Shuffle method: 'instance_shuffle', 'batch_shuffle', "
     "'batch_shuffle_batch'. (default: %(default)s)")
 parser.add_argument(
     "--trainer_count",
-    default=4,
+    default=8,
     type=int,
     help="Trainer number. (default: %(default)s)")
 parser.add_argument(

@@ -110,7 +110,9 @@ parser.add_argument(
     "the existing model of this path. (default: %(default)s)")
 parser.add_argument(
     "--augmentation_config",
-    default='{}',
+    default='[{"type": "shift", '
+    '"params": {"min_shift_ms": -5, "max_shift_ms": 5},'
+    '"prob": 1.0}]',
     type=str,
     help="Augmentation configuration in json-format. "
     "(default: %(default)s)")

@@ -189,7 +191,7 @@ def train():
                 print("\nPass: %d, Batch: %d, TrainCost: %f" % (
                     event.pass_id, event.batch_id + 1, cost_sum / cost_counter))
                 cost_sum, cost_counter = 0.0, 0
-                with gzip.open("params.tar.gz", 'w') as f:
+                with gzip.open("checkpoints/params.latest.tar.gz", 'w') as f:
                     parameters.to_tar(f)
             else:
                 sys.stdout.write('.')

@@ -202,6 +204,9 @@ def train():
                 reader=test_batch_reader, feeding=test_generator.feeding)
             print("\n------- Time: %d sec, Pass: %d, ValidationCost: %s" %
                   (time.time() - start_time, event.pass_id, result.cost))
+            with gzip.open("checkpoints/params.pass-%d.tar.gz" % event.pass_id,
+                           'w') as f:
+                parameters.to_tar(f)
     # run train
     trainer.train(
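The default --augmentation_config changes from an empty JSON object to a JSON list with one shift augmentor. Parsing the new default makes its structure explicit; this snippet is illustrative and not part of the commit.

```python
# Parse the new default --augmentation_config from train.py.
import json

augmentation_config = (
    '[{"type": "shift", '
    '"params": {"min_shift_ms": -5, "max_shift_ms": 5},'
    '"prob": 1.0}]')

for conf in json.loads(augmentation_config):
    # One shift augmentor, applied to every utterance (prob 1.0), shifting the
    # audio by a random offset between -5 ms and +5 ms.
    print(conf["type"], conf["params"], conf["prob"])
```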