Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
8939994d
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
206
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
8939994d
编写于
8月 17, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor augmentation interface
上级
5ae63919
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
160 addition
and
63 deletion
+160
-63
deepspeech/frontend/augmentor/augmentation.py
deepspeech/frontend/augmentor/augmentation.py
+114
-62
deepspeech/frontend/augmentor/base.py
deepspeech/frontend/augmentor/base.py
+4
-0
deepspeech/frontend/augmentor/impulse_response.py
deepspeech/frontend/augmentor/impulse_response.py
+5
-0
deepspeech/frontend/augmentor/noise_perturb.py
deepspeech/frontend/augmentor/noise_perturb.py
+5
-0
deepspeech/frontend/augmentor/online_bayesian_normalization.py
...peech/frontend/augmentor/online_bayesian_normalization.py
+5
-0
deepspeech/frontend/augmentor/resample.py
deepspeech/frontend/augmentor/resample.py
+5
-0
deepspeech/frontend/augmentor/shift_perturb.py
deepspeech/frontend/augmentor/shift_perturb.py
+5
-0
deepspeech/frontend/augmentor/spec_augment.py
deepspeech/frontend/augmentor/spec_augment.py
+5
-0
deepspeech/frontend/augmentor/speed_perturb.py
deepspeech/frontend/augmentor/speed_perturb.py
+5
-0
deepspeech/frontend/augmentor/volume_perturb.py
deepspeech/frontend/augmentor/volume_perturb.py
+5
-0
deepspeech/io/dataset.py
deepspeech/io/dataset.py
+1
-0
requirements.txt
requirements.txt
+1
-1
未找到文件。
deepspeech/frontend/augmentor/augmentation.py
浏览文件 @
8939994d
...
...
@@ -13,18 +13,27 @@
# limitations under the License.
"""Contains the data augmentation pipeline."""
import
json
from
collections.abc
import
Sequence
from
inspect
import
signature
import
numpy
as
np
from
deepspeech.frontend.augmentor.impulse_response
import
ImpulseResponseAugmentor
from
deepspeech.frontend.augmentor.noise_perturb
import
NoisePerturbAugmentor
from
deepspeech.frontend.augmentor.online_bayesian_normalization
import
\
OnlineBayesianNormalizationAugmentor
from
deepspeech.frontend.augmentor.resample
import
ResampleAugmentor
from
deepspeech.frontend.augmentor.shift_perturb
import
ShiftPerturbAugmentor
from
deepspeech.frontend.augmentor.spec_augment
import
SpecAugmentor
from
deepspeech.frontend.augmentor.speed_perturb
import
SpeedPerturbAugmentor
from
deepspeech.frontend.augmentor.volume_perturb
import
VolumePerturbAugmentor
from
deepspeech.utils.dynamic_import
import
dynamic_import
from
deepspeech.utils.log
import
Log
__all__
=
[
"AugmentationPipeline"
]
logger
=
Log
(
__name__
).
getlog
()
import_alias
=
dict
(
volume
=
"deepspeech.frontend.augmentor.impulse_response:VolumePerturbAugmentor"
,
shift
=
"deepspeech.frontend.augmentor.shift_perturb:ShiftPerturbAugmentor"
,
speed
=
"deepspeech.frontend.augmentor.speed_perturb:SpeedPerturbAugmentor"
,
resample
=
"deepspeech.frontend.augmentor.resample:ResampleAugmentor"
,
bayesian_normal
=
"deepspeech.frontend.augmentor.online_bayesian_normalization:OnlineBayesianNormalizationAugmentor"
,
noise
=
"deepspeech.frontend.augmentor.noise_perturb:NoisePerturbAugmentor"
,
impulse
=
"deepspeech.frontend.augmentor.impulse_response:ImpulseResponseAugmentor"
,
specaug
=
"deepspeech.frontend.augmentor.spec_augment:SpecAugmentor"
,
)
class
AugmentationPipeline
():
...
...
@@ -78,20 +87,74 @@ class AugmentationPipeline():
augmentor to take effect. If "prob" is zero, the augmentor does not take
effect.
:param augmentation_config: Augmentation configuration in json string.
:type augmentation_config: str
:param random_seed: Random seed.
:type random_seed: int
:raises ValueError: If the augmentation json config is in incorrect format".
Params:
augmentation_config(str): Augmentation configuration in json string.
random_seed(int): Random seed.
train(bool): whether is train mode.
Raises:
ValueError: If the augmentation json config is in incorrect format".
"""
def
__init__
(
self
,
augmentation_config
:
str
,
random_seed
=
0
):
def
__init__
(
self
,
augmentation_config
:
str
,
random_seed
:
int
=
0
):
self
.
_rng
=
np
.
random
.
RandomState
(
random_seed
)
self
.
_spec_types
=
(
'specaug'
)
self
.
_augmentors
,
self
.
_rates
=
self
.
_parse_pipeline_from
(
augmentation_config
,
'audio'
)
if
augmentation_config
is
None
:
self
.
conf
=
{}
else
:
self
.
conf
=
json
.
loads
(
augmentation_config
)
self
.
_augmentors
,
self
.
_rates
=
self
.
_parse_pipeline_from
(
'all'
)
self
.
_audio_augmentors
,
self
.
_audio_rates
=
self
.
_parse_pipeline_from
(
'audio'
)
self
.
_spec_augmentors
,
self
.
_spec_rates
=
self
.
_parse_pipeline_from
(
augmentation_config
,
'feature'
)
'feature'
)
def
__call__
(
self
,
xs
,
uttid_list
=
None
,
**
kwargs
):
if
not
isinstance
(
xs
,
Sequence
):
is_batch
=
False
xs
=
[
xs
]
else
:
is_batch
=
True
if
isinstance
(
uttid_list
,
str
):
uttid_list
=
[
uttid_list
for
_
in
range
(
len
(
xs
))]
if
self
.
conf
.
get
(
"mode"
,
"sequential"
)
==
"sequential"
:
for
idx
,
(
func
,
rate
)
in
enumerate
(
zip
(
self
.
_augmentors
,
self
.
_rates
),
0
):
if
self
.
_rng
.
uniform
(
0.
,
1.
)
>=
rate
:
continue
# Derive only the args which the func has
try
:
param
=
signature
(
func
).
parameters
except
ValueError
:
# Some function, e.g. built-in function, are failed
param
=
{}
_kwargs
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
k
in
param
}
try
:
if
uttid_list
is
not
None
and
"uttid"
in
param
:
xs
=
[
func
(
x
,
u
,
**
_kwargs
)
for
x
,
u
in
zip
(
xs
,
uttid_list
)
]
else
:
xs
=
[
func
(
x
,
**
_kwargs
)
for
x
in
xs
]
except
Exception
:
logger
.
fatal
(
"Catch a exception from {}th func: {}"
.
format
(
idx
,
func
))
raise
else
:
raise
NotImplementedError
(
"Not supporting mode={}"
.
format
(
self
.
conf
[
"mode"
]))
if
is_batch
:
return
xs
else
:
return
xs
[
0
]
def
transform_audio
(
self
,
audio_segment
):
"""Run the pre-processing pipeline for data augmentation.
...
...
@@ -101,7 +164,9 @@ class AugmentationPipeline():
:param audio_segment: Audio segment to process.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
for
augmentor
,
rate
in
zip
(
self
.
_augmentors
,
self
.
_rates
):
if
not
self
.
_train
:
return
for
augmentor
,
rate
in
zip
(
self
.
_audio_augmentors
,
self
.
_audio_rates
):
if
self
.
_rng
.
uniform
(
0.
,
1.
)
<
rate
:
augmentor
.
transform_audio
(
audio_segment
)
...
...
@@ -111,19 +176,21 @@ class AugmentationPipeline():
Args:
spec_segment (np.ndarray): audio feature, (D, T).
"""
if
not
self
.
_train
:
return
for
augmentor
,
rate
in
zip
(
self
.
_spec_augmentors
,
self
.
_spec_rates
):
if
self
.
_rng
.
uniform
(
0.
,
1.
)
<
rate
:
spec_segment
=
augmentor
.
transform_feature
(
spec_segment
)
return
spec_segment
def
_parse_pipeline_from
(
self
,
config_json
,
aug_type
=
'audio
'
):
def
_parse_pipeline_from
(
self
,
aug_type
=
'all
'
):
"""Parse the config json to build a augmentation pipelien."""
assert
aug_type
in
(
'audio'
,
'feature'
),
aug_type
try
:
configs
=
json
.
loads
(
config_json
)
assert
aug_type
in
(
'audio'
,
'feature'
,
'all'
),
aug_type
audio_confs
=
[]
feature_confs
=
[]
for
config
in
configs
:
all_confs
=
[]
for
config
in
self
.
conf
:
all_confs
.
append
(
config
)
if
config
[
"type"
]
in
self
.
_spec_types
:
feature_confs
.
append
(
config
)
else
:
...
...
@@ -133,35 +200,20 @@ class AugmentationPipeline():
aug_confs
=
audio_confs
elif
aug_type
==
'feature'
:
aug_confs
=
feature_confs
else
:
aug_confs
=
all_confs
augmentors
=
[
self
.
_get_augmentor
(
config
[
"type"
],
config
[
"params"
])
for
config
in
aug_confs
]
rates
=
[
config
[
"prob"
]
for
config
in
aug_confs
]
except
Exception
as
e
:
raise
ValueError
(
"Failed to parse the augmentation config json: "
"%s"
%
str
(
e
))
return
augmentors
,
rates
def
_get_augmentor
(
self
,
augmentor_type
,
params
):
"""Return an augmentation model by the type name, and pass in params."""
if
augmentor_type
==
"volume"
:
return
VolumePerturbAugmentor
(
self
.
_rng
,
**
params
)
elif
augmentor_type
==
"shift"
:
return
ShiftPerturbAugmentor
(
self
.
_rng
,
**
params
)
elif
augmentor_type
==
"speed"
:
return
SpeedPerturbAugmentor
(
self
.
_rng
,
**
params
)
elif
augmentor_type
==
"resample"
:
return
ResampleAugmentor
(
self
.
_rng
,
**
params
)
elif
augmentor_type
==
"bayesian_normal"
:
return
OnlineBayesianNormalizationAugmentor
(
self
.
_rng
,
**
params
)
elif
augmentor_type
==
"noise"
:
return
NoisePerturbAugmentor
(
self
.
_rng
,
**
params
)
elif
augmentor_type
==
"impulse"
:
return
ImpulseResponseAugmentor
(
self
.
_rng
,
**
params
)
elif
augmentor_type
==
"specaug"
:
return
SpecAugmentor
(
self
.
_rng
,
**
params
)
else
:
class_obj
=
dynamic_import
(
augmentor_type
,
import_alias
)
try
:
obj
=
class_obj
(
self
.
_rng
,
**
params
)
except
Exception
:
raise
ValueError
(
"Unknown augmentor type [%s]."
%
augmentor_type
)
deepspeech/frontend/augmentor/base.py
浏览文件 @
8939994d
...
...
@@ -28,6 +28,10 @@ class AugmentorBase():
def
__init__
(
self
):
pass
@
abstractmethod
def
__call__
(
self
,
xs
):
raise
NotImplementedError
@
abstractmethod
def
transform_audio
(
self
,
audio_segment
):
"""Adds various effects to the input audio segment. Such effects
...
...
deepspeech/frontend/augmentor/impulse_response.py
浏览文件 @
8939994d
...
...
@@ -30,6 +30,11 @@ class ImpulseResponseAugmentor(AugmentorBase):
self
.
_rng
=
rng
self
.
_impulse_manifest
=
read_manifest
(
impulse_manifest_path
)
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
self
.
transform_audio
(
x
)
def
transform_audio
(
self
,
audio_segment
):
"""Add impulse response effect.
...
...
deepspeech/frontend/augmentor/noise_perturb.py
浏览文件 @
8939994d
...
...
@@ -36,6 +36,11 @@ class NoisePerturbAugmentor(AugmentorBase):
self
.
_rng
=
rng
self
.
_noise_manifest
=
read_manifest
(
manifest_path
=
noise_manifest_path
)
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
self
.
transform_audio
(
x
)
def
transform_audio
(
self
,
audio_segment
):
"""Add background noise audio.
...
...
deepspeech/frontend/augmentor/online_bayesian_normalization.py
浏览文件 @
8939994d
...
...
@@ -44,6 +44,11 @@ class OnlineBayesianNormalizationAugmentor(AugmentorBase):
self
.
_rng
=
rng
self
.
_startup_delay
=
startup_delay
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
self
.
transform_audio
(
x
)
def
transform_audio
(
self
,
audio_segment
):
"""Normalizes the input audio using the online Bayesian approach.
...
...
deepspeech/frontend/augmentor/resample.py
浏览文件 @
8939994d
...
...
@@ -31,6 +31,11 @@ class ResampleAugmentor(AugmentorBase):
self
.
_new_sample_rate
=
new_sample_rate
self
.
_rng
=
rng
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
self
.
transform_audio
(
x
)
def
transform_audio
(
self
,
audio_segment
):
"""Resamples the input audio to a target sample rate.
...
...
deepspeech/frontend/augmentor/shift_perturb.py
浏览文件 @
8939994d
...
...
@@ -31,6 +31,11 @@ class ShiftPerturbAugmentor(AugmentorBase):
self
.
_max_shift_ms
=
max_shift_ms
self
.
_rng
=
rng
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
self
.
transform_audio
(
x
)
def
transform_audio
(
self
,
audio_segment
):
"""Shift audio.
...
...
deepspeech/frontend/augmentor/spec_augment.py
浏览文件 @
8939994d
...
...
@@ -157,6 +157,11 @@ class SpecAugmentor(AugmentorBase):
self
.
_time_mask
=
(
t_0
,
t_0
+
t
)
return
xs
def
__call__
(
self
,
x
,
train
=
True
):
if
not
train
:
return
self
.
transform_audio
(
x
)
def
transform_feature
(
self
,
xs
:
np
.
ndarray
):
"""
Args:
...
...
deepspeech/frontend/augmentor/speed_perturb.py
浏览文件 @
8939994d
...
...
@@ -79,6 +79,11 @@ class SpeedPerturbAugmentor(AugmentorBase):
self
.
_rates
=
np
.
linspace
(
self
.
_min_rate
,
self
.
_max_rate
,
self
.
_num_rates
,
endpoint
=
True
)
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
self
.
transform_audio
(
x
)
def
transform_audio
(
self
,
audio_segment
):
"""Sample a new speed rate from the given range and
changes the speed of the given audio clip.
...
...
deepspeech/frontend/augmentor/volume_perturb.py
浏览文件 @
8939994d
...
...
@@ -37,6 +37,11 @@ class VolumePerturbAugmentor(AugmentorBase):
self
.
_max_gain_dBFS
=
max_gain_dBFS
self
.
_rng
=
rng
def
__call__
(
self
,
x
,
uttid
=
None
,
train
=
True
):
if
not
train
:
return
self
.
transform_audio
(
x
)
def
transform_audio
(
self
,
audio_segment
):
"""Change audio loadness.
...
...
deepspeech/io/dataset.py
浏览文件 @
8939994d
...
...
@@ -16,6 +16,7 @@ from typing import Optional
from
paddle.io
import
Dataset
from
yacs.config
import
CfgNode
from
deepspeech.frontend.utility
import
read_manifest
from
deepspeech.utils.log
import
Log
__all__
=
[
"ManifestDataset"
,
"TripletManifestDataset"
,
"TransformDataset"
]
...
...
requirements.txt
浏览文件 @
8939994d
coverage
gpustat
kaldiio
pre-commit
pybind11
resampy
==0.2.2
...
...
@@ -13,4 +14,3 @@ tensorboardX
textgrid
typeguard
yacs
kaldiio
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录