Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
8939994d
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8939994d
编写于
8月 17, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor augmentation interface
上级
5ae63919
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
160 addition
and
63 deletion
+160
-63
deepspeech/frontend/augmentor/augmentation.py
deepspeech/frontend/augmentor/augmentation.py
+114
-62
deepspeech/frontend/augmentor/base.py
deepspeech/frontend/augmentor/base.py
+4
-0
deepspeech/frontend/augmentor/impulse_response.py
deepspeech/frontend/augmentor/impulse_response.py
+5
-0
deepspeech/frontend/augmentor/noise_perturb.py
deepspeech/frontend/augmentor/noise_perturb.py
+5
-0
deepspeech/frontend/augmentor/online_bayesian_normalization.py
...peech/frontend/augmentor/online_bayesian_normalization.py
+5
-0
deepspeech/frontend/augmentor/resample.py
deepspeech/frontend/augmentor/resample.py
+5
-0
deepspeech/frontend/augmentor/shift_perturb.py
deepspeech/frontend/augmentor/shift_perturb.py
+5
-0
deepspeech/frontend/augmentor/spec_augment.py
deepspeech/frontend/augmentor/spec_augment.py
+5
-0
deepspeech/frontend/augmentor/speed_perturb.py
deepspeech/frontend/augmentor/speed_perturb.py
+5
-0
deepspeech/frontend/augmentor/volume_perturb.py
deepspeech/frontend/augmentor/volume_perturb.py
+5
-0
deepspeech/io/dataset.py
deepspeech/io/dataset.py
+1
-0
requirements.txt
requirements.txt
+1
-1
未找到文件。
deepspeech/frontend/augmentor/augmentation.py
浏览文件 @
8939994d
...
@@ -13,18 +13,27 @@
...
@@ -13,18 +13,27 @@
# limitations under the License.
# limitations under the License.
"""Contains the data augmentation pipeline."""
"""Contains the data augmentation pipeline."""
import
json
import
json
from
collections.abc
import
Sequence
from
inspect
import
signature
import
numpy
as
np
import
numpy
as
np
from
deepspeech.frontend.augmentor.impulse_response
import
ImpulseResponseAugmentor
from
deepspeech.utils.dynamic_import
import
dynamic_import
from
deepspeech.frontend.augmentor.noise_perturb
import
NoisePerturbAugmentor
from
deepspeech.utils.log
import
Log
from
deepspeech.frontend.augmentor.online_bayesian_normalization
import
\
OnlineBayesianNormalizationAugmentor
__all__
=
[
"AugmentationPipeline"
]
from
deepspeech.frontend.augmentor.resample
import
ResampleAugmentor
from
deepspeech.frontend.augmentor.shift_perturb
import
ShiftPerturbAugmentor
logger
=
Log
(
__name__
).
getlog
()
from
deepspeech.frontend.augmentor.spec_augment
import
SpecAugmentor
from
deepspeech.frontend.augmentor.speed_perturb
import
SpeedPerturbAugmentor
import_alias
=
dict
(
from
deepspeech.frontend.augmentor.volume_perturb
import
VolumePerturbAugmentor
volume
=
"deepspeech.frontend.augmentor.impulse_response:VolumePerturbAugmentor"
,
shift
=
"deepspeech.frontend.augmentor.shift_perturb:ShiftPerturbAugmentor"
,
speed
=
"deepspeech.frontend.augmentor.speed_perturb:SpeedPerturbAugmentor"
,
resample
=
"deepspeech.frontend.augmentor.resample:ResampleAugmentor"
,
bayesian_normal
=
"deepspeech.frontend.augmentor.online_bayesian_normalization:OnlineBayesianNormalizationAugmentor"
,
noise
=
"deepspeech.frontend.augmentor.noise_perturb:NoisePerturbAugmentor"
,
impulse
=
"deepspeech.frontend.augmentor.impulse_response:ImpulseResponseAugmentor"
,
specaug
=
"deepspeech.frontend.augmentor.spec_augment:SpecAugmentor"
,
)
class
AugmentationPipeline
():
class
AugmentationPipeline
():
...
@@ -78,20 +87,74 @@ class AugmentationPipeline():
...
@@ -78,20 +87,74 @@ class AugmentationPipeline():
augmentor to take effect. If "prob" is zero, the augmentor does not take
augmentor to take effect. If "prob" is zero, the augmentor does not take
effect.
effect.
:param augmentation_config: Augmentation configuration in json string.
Params:
:type augmentation_config: str
augmentation_config(str): Augmentation configuration in json string.
:param random_seed: Random seed.
random_seed(int): Random seed.
:type random_seed: int
train(bool): whether is train mode.
:raises ValueError: If the augmentation json config is in incorrect format".
Raises:
ValueError: If the augmentation json config is in incorrect format".
"""
"""
def
__init__
(
self
,
augmentation_config
:
str
,
random_seed
=
0
):
def
__init__
(
self
,
augmentation_config
:
str
,
random_seed
:
int
=
0
):
self
.
_rng
=
np
.
random
.
RandomState
(
random_seed
)
self
.
_rng
=
np
.
random
.
RandomState
(
random_seed
)
self
.
_spec_types
=
(
'specaug'
)
self
.
_spec_types
=
(
'specaug'
)
self
.
_augmentors
,
self
.
_rates
=
self
.
_parse_pipeline_from
(
augmentation_config
,
'audio'
)
if
augmentation_config
is
None
:
self
.
conf
=
{}
else
:
self
.
conf
=
json
.
loads
(
augmentation_config
)
self
.
_augmentors
,
self
.
_rates
=
self
.
_parse_pipeline_from
(
'all'
)
self
.
_audio_augmentors
,
self
.
_audio_rates
=
self
.
_parse_pipeline_from
(
'audio'
)
self
.
_spec_augmentors
,
self
.
_spec_rates
=
self
.
_parse_pipeline_from
(
self
.
_spec_augmentors
,
self
.
_spec_rates
=
self
.
_parse_pipeline_from
(
augmentation_config
,
'feature'
)
'feature'
)
def
__call__
(
self
,
xs
,
uttid_list
=
None
,
**
kwargs
):
if
not
isinstance
(
xs
,
Sequence
):
is_batch
=
False
xs
=
[
xs
]
else
:
is_batch
=
True
if
isinstance
(
uttid_list
,
str
):
uttid_list
=
[
uttid_list
for
_
in
range
(
len
(
xs
))]
if
self
.
conf
.
get
(
"mode"
,
"sequential"
)
==
"sequential"
:
for
idx
,
(
func
,
rate
)
in
enumerate
(
zip
(
self
.
_augmentors
,
self
.
_rates
),
0
):
if
self
.
_rng
.
uniform
(
0.
,
1.
)
>=
rate
:
continue
# Derive only the args which the func has
try
:
param
=
signature
(
func
).
parameters
except
ValueError
:
# Some function, e.g. built-in function, are failed
param
=
{}
_kwargs
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
k
in
param
}
try
:
if
uttid_list
is
not
None
and
"uttid"
in
param
:
xs
=
[
func
(
x
,
u
,
**
_kwargs
)
for
x
,
u
in
zip
(
xs
,
uttid_list
)
]
else
:
xs
=
[
func
(
x
,
**
_kwargs
)
for
x
in
xs
]
except
Exception
:
logger
.
fatal
(
"Catch a exception from {}th func: {}"
.
format
(
idx
,
func
))
raise
else
:
raise
NotImplementedError
(
"Not supporting mode={}"
.
format
(
self
.
conf
[
"mode"
]))
if
is_batch
:
return
xs
else
:
return
xs
[
0
]
def
transform_audio
(
self
,
audio_segment
):
def
transform_audio
(
self
,
audio_segment
):
"""Run the pre-processing pipeline for data augmentation.
"""Run the pre-processing pipeline for data augmentation.
...
@@ -101,7 +164,9 @@ class AugmentationPipeline():
...
@@ -101,7 +164,9 @@ class AugmentationPipeline():
:param audio_segment: Audio segment to process.
:param audio_segment: Audio segment to process.
:type audio_segment: AudioSegmenet|SpeechSegment
:type audio_segment: AudioSegmenet|SpeechSegment
"""
"""
for
augmentor
,
rate
in
zip
(
self
.
_augmentors
,
self
.
_rates
):
if
not
self
.
_train
:
return
for
augmentor
,
rate
in
zip
(
self
.
_audio_augmentors
,
self
.
_audio_rates
):
if
self
.
_rng
.
uniform
(
0.
,
1.
)
<
rate
:
if
self
.
_rng
.
uniform
(
0.
,
1.
)
<
rate
:
augmentor
.
transform_audio
(
audio_segment
)
augmentor
.
transform_audio
(
audio_segment
)
...
@@ -111,57 +176,44 @@ class AugmentationPipeline():
...
@@ -111,57 +176,44 @@ class AugmentationPipeline():
Args:
Args:
spec_segment (np.ndarray): audio feature, (D, T).
spec_segment (np.ndarray): audio feature, (D, T).
"""
"""
if
not
self
.
_train
:
return
for
augmentor
,
rate
in
zip
(
self
.
_spec_augmentors
,
self
.
_spec_rates
):
for
augmentor
,
rate
in
zip
(
self
.
_spec_augmentors
,
self
.
_spec_rates
):
if
self
.
_rng
.
uniform
(
0.
,
1.
)
<
rate
:
if
self
.
_rng
.
uniform
(
0.
,
1.
)
<
rate
:
spec_segment
=
augmentor
.
transform_feature
(
spec_segment
)
spec_segment
=
augmentor
.
transform_feature
(
spec_segment
)
return
spec_segment
return
spec_segment
def
_parse_pipeline_from
(
self
,
config_json
,
aug_type
=
'audio
'
):
def
_parse_pipeline_from
(
self
,
aug_type
=
'all
'
):
"""Parse the config json to build a augmentation pipelien."""
"""Parse the config json to build a augmentation pipelien."""
assert
aug_type
in
(
'audio'
,
'feature'
),
aug_type
assert
aug_type
in
(
'audio'
,
'feature'
,
'all'
),
aug_type
try
:
audio_confs
=
[]
configs
=
json
.
loads
(
config_json
)
feature_confs
=
[]
audio_confs
=
[]
all_confs
=
[]
feature_confs
=
[]
for
config
in
self
.
conf
:
for
config
in
configs
:
all_confs
.
append
(
config
)
if
config
[
"type"
]
in
self
.
_spec_types
:
if
config
[
"type"
]
in
self
.
_spec_types
:
feature_confs
.
append
(
config
)
feature_confs
.
append
(
config
)
else
:
else
:
audio_confs
.
append
(
config
)
audio_confs
.
append
(
config
)
if
aug_type
==
'audio'
:
if
aug_type
==
'audio'
:
aug_confs
=
audio_confs
aug_confs
=
audio_confs
elif
aug_type
==
'feature'
:
elif
aug_type
==
'feature'
:
aug_confs
=
feature_confs
aug_confs
=
feature_confs
else
:
augmentors
=
[
aug_confs
=
all_confs
self
.
_get_augmentor
(
config
[
"type"
],
config
[
"params"
])
for
config
in
aug_confs
augmentors
=
[
]
self
.
_get_augmentor
(
config
[
"type"
],
config
[
"params"
])
rates
=
[
config
[
"prob"
]
for
config
in
aug_confs
]
for
config
in
aug_confs
]
except
Exception
as
e
:
rates
=
[
config
[
"prob"
]
for
config
in
aug_confs
]
raise
ValueError
(
"Failed to parse the augmentation config json: "
"%s"
%
str
(
e
))
return
augmentors
,
rates
return
augmentors
,
rates
def
_get_augmentor
(
self
,
augmentor_type
,
params
):
def
_get_augmentor
(
self
,
augmentor_type
,
params
):
"""Return an augmentation model by the type name, and pass in params."""
"""Return an augmentation model by the type name, and pass in params."""
if
augmentor_type
==
"volume"
:
class_obj
=
dynamic_import
(
augmentor_type
,
import_alias
)
return
VolumePerturbAugmentor
(
self
.
_rng
,
**
params
)
try
:
elif
augmentor_type
==
"shift"
:
obj
=
class_obj
(
self
.
_rng
,
**
params
)
return
ShiftPerturbAugmentor
(
self
.
_rng
,
**
params
)
except
Exception
:
elif
augmentor_type
==
"speed"
:
return
SpeedPerturbAugmentor
(
self
.
_rng
,
**
params
)
elif
augmentor_type
==
"resample"
:
return
ResampleAugmentor
(
self
.
_rng
,
**
params
)
elif
augmentor_type
==
"bayesian_normal"
:
return
OnlineBayesianNormalizationAugmentor
(
self
.
_rng
,
**
params
)
elif
augmentor_type
==
"noise"
:
return
NoisePerturbAugmentor
(
self
.
_rng
,
**
params
)
elif
augmentor_type
==
"impulse"
:
return
ImpulseResponseAugmentor
(
self
.
_rng
,
**
params
)
elif
augmentor_type
==
"specaug"
:
return
SpecAugmentor
(
self
.
_rng
,
**
params
)
else
:
raise
ValueError
(
"Unknown augmentor type [%s]."
%
augmentor_type
)
raise
ValueError
(
"Unknown augmentor type [%s]."
%
augmentor_type
)
deepspeech/frontend/augmentor/base.py
浏览文件 @
8939994d
...
@@ -28,6 +28,10 @@ class AugmentorBase():
...
@@ -28,6 +28,10 @@ class AugmentorBase():
def
__init__
(
self
):
def
__init__
(
self
):
pass
pass
@
abstractmethod
def
__call__
(
self
,
xs
):
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
transform_audio
(
self
,
audio_segment
):
def
transform_audio
(
self
,
audio_segment
):
"""Adds various effects to the input audio segment. Such effects
"""Adds various effects to the input audio segment. Such effects
...
...
deepspeech/frontend/augmentor/impulse_response.py
浏览文件 @
8939994d
...
@@ -30,6 +30,11 @@ class ImpulseResponseAugmentor(AugmentorBase):
...
@@ -30,6 +30,11 @@ class ImpulseResponseAugmentor(AugmentorBase):
self
.
_rng
=
rng
self
.
_rng
=
rng
self
.
_impulse_manifest
=
read_manifest
(
impulse_manifest_path
)
self
.
_impulse_manifest
=
read_manifest
(
impulse_manifest_path
)
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
self
.
transform_audio
(
x
)
def
transform_audio
(
self
,
audio_segment
):
def
transform_audio
(
self
,
audio_segment
):
"""Add impulse response effect.
"""Add impulse response effect.
...
...
deepspeech/frontend/augmentor/noise_perturb.py
浏览文件 @
8939994d
...
@@ -36,6 +36,11 @@ class NoisePerturbAugmentor(AugmentorBase):
...
@@ -36,6 +36,11 @@ class NoisePerturbAugmentor(AugmentorBase):
self
.
_rng
=
rng
self
.
_rng
=
rng
self
.
_noise_manifest
=
read_manifest
(
manifest_path
=
noise_manifest_path
)
self
.
_noise_manifest
=
read_manifest
(
manifest_path
=
noise_manifest_path
)
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
self
.
transform_audio
(
x
)
def
transform_audio
(
self
,
audio_segment
):
def
transform_audio
(
self
,
audio_segment
):
"""Add background noise audio.
"""Add background noise audio.
...
...
deepspeech/frontend/augmentor/online_bayesian_normalization.py
浏览文件 @
8939994d
...
@@ -44,6 +44,11 @@ class OnlineBayesianNormalizationAugmentor(AugmentorBase):
...
@@ -44,6 +44,11 @@ class OnlineBayesianNormalizationAugmentor(AugmentorBase):
self
.
_rng
=
rng
self
.
_rng
=
rng
self
.
_startup_delay
=
startup_delay
self
.
_startup_delay
=
startup_delay
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
self
.
transform_audio
(
x
)
def
transform_audio
(
self
,
audio_segment
):
def
transform_audio
(
self
,
audio_segment
):
"""Normalizes the input audio using the online Bayesian approach.
"""Normalizes the input audio using the online Bayesian approach.
...
...
deepspeech/frontend/augmentor/resample.py
浏览文件 @
8939994d
...
@@ -31,6 +31,11 @@ class ResampleAugmentor(AugmentorBase):
...
@@ -31,6 +31,11 @@ class ResampleAugmentor(AugmentorBase):
self
.
_new_sample_rate
=
new_sample_rate
self
.
_new_sample_rate
=
new_sample_rate
self
.
_rng
=
rng
self
.
_rng
=
rng
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
self
.
transform_audio
(
x
)
def
transform_audio
(
self
,
audio_segment
):
def
transform_audio
(
self
,
audio_segment
):
"""Resamples the input audio to a target sample rate.
"""Resamples the input audio to a target sample rate.
...
...
deepspeech/frontend/augmentor/shift_perturb.py
浏览文件 @
8939994d
...
@@ -31,6 +31,11 @@ class ShiftPerturbAugmentor(AugmentorBase):
...
@@ -31,6 +31,11 @@ class ShiftPerturbAugmentor(AugmentorBase):
self
.
_max_shift_ms
=
max_shift_ms
self
.
_max_shift_ms
=
max_shift_ms
self
.
_rng
=
rng
self
.
_rng
=
rng
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
self
.
transform_audio
(
x
)
def
transform_audio
(
self
,
audio_segment
):
def
transform_audio
(
self
,
audio_segment
):
"""Shift audio.
"""Shift audio.
...
...
deepspeech/frontend/augmentor/spec_augment.py
浏览文件 @
8939994d
...
@@ -157,6 +157,11 @@ class SpecAugmentor(AugmentorBase):
...
@@ -157,6 +157,11 @@ class SpecAugmentor(AugmentorBase):
self
.
_time_mask
=
(
t_0
,
t_0
+
t
)
self
.
_time_mask
=
(
t_0
,
t_0
+
t
)
return
xs
return
xs
def
__call__
(
self
,
x
,
train
=
True
):
if
not
train
:
return
self
.
transform_audio
(
x
)
def
transform_feature
(
self
,
xs
:
np
.
ndarray
):
def
transform_feature
(
self
,
xs
:
np
.
ndarray
):
"""
"""
Args:
Args:
...
...
deepspeech/frontend/augmentor/speed_perturb.py
浏览文件 @
8939994d
...
@@ -79,6 +79,11 @@ class SpeedPerturbAugmentor(AugmentorBase):
...
@@ -79,6 +79,11 @@ class SpeedPerturbAugmentor(AugmentorBase):
self
.
_rates
=
np
.
linspace
(
self
.
_rates
=
np
.
linspace
(
self
.
_min_rate
,
self
.
_max_rate
,
self
.
_num_rates
,
endpoint
=
True
)
self
.
_min_rate
,
self
.
_max_rate
,
self
.
_num_rates
,
endpoint
=
True
)
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
self
.
transform_audio
(
x
)
def
transform_audio
(
self
,
audio_segment
):
def
transform_audio
(
self
,
audio_segment
):
"""Sample a new speed rate from the given range and
"""Sample a new speed rate from the given range and
changes the speed of the given audio clip.
changes the speed of the given audio clip.
...
...
deepspeech/frontend/augmentor/volume_perturb.py
浏览文件 @
8939994d
...
@@ -37,6 +37,11 @@ class VolumePerturbAugmentor(AugmentorBase):
...
@@ -37,6 +37,11 @@ class VolumePerturbAugmentor(AugmentorBase):
self
.
_max_gain_dBFS
=
max_gain_dBFS
self
.
_max_gain_dBFS
=
max_gain_dBFS
self
.
_rng
=
rng
self
.
_rng
=
rng
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
self
.
transform_audio
(
x
)
def
transform_audio
(
self
,
audio_segment
):
def
transform_audio
(
self
,
audio_segment
):
"""Change audio loadness.
"""Change audio loadness.
...
...
deepspeech/io/dataset.py
浏览文件 @
8939994d
...
@@ -16,6 +16,7 @@ from typing import Optional
...
@@ -16,6 +16,7 @@ from typing import Optional
from
paddle.io
import
Dataset
from
paddle.io
import
Dataset
from
yacs.config
import
CfgNode
from
yacs.config
import
CfgNode
from
deepspeech.frontend.utility
import
read_manifest
from
deepspeech.utils.log
import
Log
from
deepspeech.utils.log
import
Log
__all__
=
[
"ManifestDataset"
,
"TripletManifestDataset"
,
"TransformDataset"
]
__all__
=
[
"ManifestDataset"
,
"TripletManifestDataset"
,
"TransformDataset"
]
...
...
requirements.txt
浏览文件 @
8939994d
coverage
coverage
gpustat
gpustat
kaldiio
pre-commit
pre-commit
pybind11
pybind11
resampy
==0.2.2
resampy
==0.2.2
...
@@ -13,4 +14,3 @@ tensorboardX
...
@@ -13,4 +14,3 @@ tensorboardX
textgrid
textgrid
typeguard
typeguard
yacs
yacs
kaldiio
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录