Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
2d89c80e
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
2d89c80e
编写于
3月 07, 2022
作者:
X
xiongxinlei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add waveform augment pipeline, test=doc
上级
ac4967e2
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
1543 addition
and
43 deletion
+1543
-43
examples/voxceleb/sv0/local/speaker_verification_cosine.py
examples/voxceleb/sv0/local/speaker_verification_cosine.py
+50
-26
examples/voxceleb/sv0/local/train.py
examples/voxceleb/sv0/local/train.py
+35
-11
paddleaudio/datasets/rirs_noises.py
paddleaudio/datasets/rirs_noises.py
+207
-0
paddleaudio/datasets/voxceleb.py
paddleaudio/datasets/voxceleb.py
+12
-6
paddlespeech/vector/io/augment.py
paddlespeech/vector/io/augment.py
+899
-0
paddlespeech/vector/io/signal_processing.py
paddlespeech/vector/io/signal_processing.py
+219
-0
paddlespeech/vector/models/ecapa_tdnn.py
paddlespeech/vector/models/ecapa_tdnn.py
+93
-0
paddlespeech/vector/training/seeding.py
paddlespeech/vector/training/seeding.py
+28
-0
未找到文件。
examples/voxceleb/sv0/local/speaker_verification_cosine.py
浏览文件 @
2d89c80e
...
...
@@ -23,9 +23,13 @@ from paddle.io import DataLoader
from
tqdm
import
tqdm
from
paddleaudio.datasets.voxceleb
import
VoxCeleb1
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.models.ecapa_tdnn
import
EcapaTdnn
from
paddlespeech.vector.modules.sid_model
import
SpeakerIdetification
from
paddlespeech.vector.training.metrics
import
compute_eer
from
paddlespeech.vector.training.seeding
import
seed_everything
logger
=
Log
(
__name__
).
getlog
()
def
pad_right_2d
(
x
,
target_length
,
axis
=-
1
,
mode
=
'constant'
,
**
kwargs
):
...
...
@@ -67,9 +71,19 @@ def feature_normalize(batch, mean_norm: bool=True, std_norm: bool=True):
return
{
'ids'
:
ids
,
'feats'
:
feats
,
'lengths'
:
lengths
}
# feat configuration
cpu_feat_conf
=
{
'n_mels'
:
80
,
'window_size'
:
400
,
#ms
'hop_length'
:
160
,
#ms
}
def
main
(
args
):
# stage0: set the training device, cpu or gpu
paddle
.
set_device
(
args
.
device
)
# set the random seed, it is a must for multiprocess training
seed_everything
(
args
.
seed
)
# stage1: build the dnn backbone model network
##"channels": [1024, 1024, 1024, 1024, 3072],
...
...
@@ -95,19 +109,18 @@ def main(args):
state_dict
=
paddle
.
load
(
os
.
path
.
join
(
args
.
load_checkpoint
,
'model.pdparams'
))
model
.
set_state_dict
(
state_dict
)
print
(
f
'Checkpoint loaded from
{
args
.
load_checkpoint
}
'
)
logger
.
info
(
f
'Checkpoint loaded from
{
args
.
load_checkpoint
}
'
)
# stage4: construct the enroll and test dataloader
enrol_ds
=
VoxCeleb1
(
subset
=
'enrol'
,
target_dir
=
args
.
data_dir
,
feat_type
=
'melspectrogram'
,
random_chunk
=
False
,
n_mels
=
80
,
window_size
=
400
,
hop_length
=
160
)
**
cpu_feat_conf
)
enrol_sampler
=
BatchSampler
(
enrol_ds
,
batch_size
=
args
.
batch_size
,
shuffle
=
Tru
e
)
# Shuffle to make embedding normalization more robust.
shuffle
=
Fals
e
)
# Shuffle to make embedding normalization more robust.
enrol_loader
=
DataLoader
(
enrol_ds
,
batch_sampler
=
enrol_sampler
,
collate_fn
=
lambda
x
:
feature_normalize
(
...
...
@@ -117,14 +130,13 @@ def main(args):
test_ds
=
VoxCeleb1
(
subset
=
'test'
,
target_dir
=
args
.
data_dir
,
feat_type
=
'melspectrogram'
,
random_chunk
=
False
,
n_mels
=
80
,
window_size
=
400
,
hop_length
=
160
)
**
cpu_feat_conf
)
test_sampler
=
BatchSampler
(
test_ds
,
batch_size
=
args
.
batch_size
,
shuffle
=
Tru
e
)
test_ds
,
batch_size
=
args
.
batch_size
,
shuffle
=
Fals
e
)
test_loader
=
DataLoader
(
test_ds
,
batch_sampler
=
test_sampler
,
collate_fn
=
lambda
x
:
feature_normalize
(
...
...
@@ -136,10 +148,10 @@ def main(args):
# stage7: global embedding norm to imporve the performance
if
args
.
global_embedding_norm
:
embedding_mean
=
None
embedding_std
=
None
mean_norm
=
args
.
embedding_mean_norm
std_norm
=
args
.
embedding_std_norm
global_
embedding_mean
=
None
global_
embedding_std
=
None
mean_norm
_flag
=
args
.
embedding_mean_norm
std_norm
_flag
=
args
.
embedding_std_norm
batch_count
=
0
# stage8: Compute embeddings of audios in enrol and test dataset from model.
...
...
@@ -147,7 +159,7 @@ def main(args):
# Run multi times to make embedding normalization more stable.
for
i
in
range
(
2
):
for
dl
in
[
enrol_loader
,
test_loader
]:
print
(
logger
.
info
(
f
'Loop
{
[
i
+
1
]
}
: Computing embeddings on
{
dl
.
dataset
.
subset
}
dataset'
)
with
paddle
.
no_grad
():
...
...
@@ -162,20 +174,24 @@ def main(args):
# Global embedding normalization.
if
args
.
global_embedding_norm
:
batch_count
+=
1
mean
=
embeddings
.
mean
(
axis
=
0
)
if
mean_norm
else
0
std
=
embeddings
.
std
(
axis
=
0
)
if
std_norm
else
1
current_mean
=
embeddings
.
mean
(
axis
=
0
)
if
mean_norm_flag
else
0
current_std
=
embeddings
.
std
(
axis
=
0
)
if
std_norm_flag
else
1
# Update global mean and std.
if
embedding_mean
is
None
and
embedding_std
is
None
:
embedding_mean
,
embedding_std
=
mean
,
std
if
global_embedding_mean
is
None
and
global_
embedding_std
is
None
:
global_embedding_mean
,
global_embedding_std
=
current_mean
,
current_
std
else
:
weight
=
1
/
batch_count
# Weight decay by batches.
embedding_mean
=
(
1
-
weight
)
*
embedding_mean
+
weight
*
mean
embedding_std
=
(
1
-
weight
)
*
embedding_std
+
weight
*
std
global_embedding_mean
=
(
1
-
weight
)
*
global_embedding_mean
+
weight
*
current_mean
global_embedding_std
=
(
1
-
weight
)
*
global_embedding_std
+
weight
*
current_std
# Apply global embedding normalization.
embeddings
=
(
embeddings
-
embedding_mean
)
/
embedding_std
embeddings
=
(
embeddings
-
global_embedding_mean
)
/
global_
embedding_std
# Update embedding dict.
id2embedding
.
update
(
dict
(
zip
(
ids
,
embeddings
)))
...
...
@@ -198,7 +214,7 @@ def main(args):
])
# (N, emb_size)
scores
=
cos_sim_func
(
enrol_embeddings
,
test_embeddings
)
EER
,
threshold
=
compute_eer
(
np
.
asarray
(
labels
),
scores
.
numpy
())
print
(
logger
.
info
(
f
'EER of verification test:
{
EER
*
100
:.
4
f
}
%, score threshold:
{
threshold
:.
5
f
}
'
)
...
...
@@ -210,10 +226,18 @@ if __name__ == "__main__":
choices
=
[
'cpu'
,
'gpu'
],
default
=
"gpu"
,
help
=
"Select which device to train model, defaults to gpu."
)
parser
.
add_argument
(
"--seed"
,
default
=
0
,
type
=
int
,
help
=
"random seed for paddle, numpy and python random package"
)
parser
.
add_argument
(
"--data-dir"
,
default
=
"./data/"
,
type
=
str
,
help
=
"data directory"
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
default
=
16
,
help
=
"Total examples' number in batch for
train
ing."
)
help
=
"Total examples' number in batch for
extract the embedd
ing."
)
parser
.
add_argument
(
"--num-workers"
,
type
=
int
,
default
=
0
,
...
...
examples/voxceleb/sv0/local/train.py
浏览文件 @
2d89c80e
...
...
@@ -22,6 +22,9 @@ from paddle.io import DistributedBatchSampler
from
paddleaudio.datasets.voxceleb
import
VoxCeleb1
from
paddleaudio.features.core
import
melspectrogram
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.io.augment
import
build_augment_pipeline
from
paddlespeech.vector.io.augment
import
waveform_augment
from
paddlespeech.vector.io.batch
import
feature_normalize
from
paddlespeech.vector.io.batch
import
waveform_collate_fn
from
paddlespeech.vector.models.ecapa_tdnn
import
EcapaTdnn
...
...
@@ -29,8 +32,11 @@ from paddlespeech.vector.modules.loss import AdditiveAngularMargin
from
paddlespeech.vector.modules.loss
import
LogSoftmaxWrapper
from
paddlespeech.vector.modules.lr
import
CyclicLRScheduler
from
paddlespeech.vector.modules.sid_model
import
SpeakerIdetification
from
paddlespeech.vector.training.seeding
import
seed_everything
from
paddlespeech.vector.utils.time
import
Timer
logger
=
Log
(
__name__
).
getlog
()
# feat configuration
cpu_feat_conf
=
{
'n_mels'
:
80
,
...
...
@@ -47,12 +53,19 @@ def main(args):
paddle
.
distributed
.
init_parallel_env
()
nranks
=
paddle
.
distributed
.
get_world_size
()
local_rank
=
paddle
.
distributed
.
get_rank
()
# set the random seed, it is a must for multiprocess training
seed_everything
(
args
.
seed
)
# stage2: data prepare
# note: some cmd must do in rank==0
# stage2: data prepare
, such vox1 and vox2 data, and augment data and pipline
# note: some cmd must do in rank==0
, so wo will refactor the data prepare code
train_ds
=
VoxCeleb1
(
'train'
,
target_dir
=
args
.
data_dir
)
dev_ds
=
VoxCeleb1
(
'dev'
,
target_dir
=
args
.
data_dir
)
if
args
.
augment
:
augment_pipeline
=
build_augment_pipeline
(
target_dir
=
args
.
data_dir
)
else
:
augment_pipeline
=
[]
# stage3: build the dnn backbone model network
#"channels": [1024, 1024, 1024, 1024, 3072],
model_conf
=
{
...
...
@@ -83,7 +96,7 @@ def main(args):
# if pre-trained model exists, start epoch confirmed by the pre-trained model
start_epoch
=
0
if
args
.
load_checkpoint
:
print
(
"load the check point"
)
logger
.
info
(
"load the check point"
)
args
.
load_checkpoint
=
os
.
path
.
abspath
(
os
.
path
.
expanduser
(
args
.
load_checkpoint
))
try
:
...
...
@@ -97,14 +110,14 @@ def main(args):
os
.
path
.
join
(
args
.
load_checkpoint
,
'model.pdopt'
))
optimizer
.
set_state_dict
(
state_dict
)
if
local_rank
==
0
:
print
(
f
'Checkpoint loaded from
{
args
.
load_checkpoint
}
'
)
logger
.
info
(
f
'Checkpoint loaded from
{
args
.
load_checkpoint
}
'
)
except
FileExistsError
:
if
local_rank
==
0
:
print
(
'Train from scratch.'
)
logger
.
info
(
'Train from scratch.'
)
try
:
start_epoch
=
int
(
args
.
load_checkpoint
[
-
1
])
print
(
f
'Restore training from epoch
{
start_epoch
}
.'
)
logger
.
info
(
f
'Restore training from epoch
{
start_epoch
}
.'
)
except
ValueError
:
pass
...
...
@@ -137,7 +150,10 @@ def main(args):
waveforms
,
labels
=
batch
[
'waveforms'
],
batch
[
'labels'
]
# stage 9-2: audio sample augment method, which is done on the audio sample point
# todo
if
len
(
augment_pipeline
)
!=
0
:
waveforms
=
waveform_augment
(
waveforms
,
augment_pipeline
)
labels
=
paddle
.
concat
(
[
labels
for
i
in
range
(
len
(
augment_pipeline
)
+
1
)])
# stage 9-3: extract the audio feats,such fbank, mfcc, spectrogram
feats
=
[]
...
...
@@ -185,7 +201,7 @@ def main(args):
print_msg
+=
' acc={:.4f}'
.
format
(
avg_acc
)
print_msg
+=
' lr={:.4E} step/sec={:.2f} | ETA {}'
.
format
(
lr
,
timer
.
timing
,
timer
.
eta
)
print
(
print_msg
)
logger
.
info
(
print_msg
)
avg_loss
=
0
num_corrects
=
0
...
...
@@ -217,7 +233,7 @@ def main(args):
num_samples
=
0
# stage 9-13: evaluation the valid dataset batch data
print
(
'Evaluate on validation dataset'
)
logger
.
info
(
'Evaluate on validation dataset'
)
with
paddle
.
no_grad
():
for
batch_idx
,
batch
in
enumerate
(
dev_loader
):
waveforms
,
labels
=
batch
[
'waveforms'
],
batch
[
'labels'
]
...
...
@@ -238,12 +254,12 @@ def main(args):
print_msg
=
'[Evaluation result]'
print_msg
+=
' dev_acc={:.4f}'
.
format
(
num_corrects
/
num_samples
)
print
(
print_msg
)
logger
.
info
(
print_msg
)
# stage 9-14: Save model parameters
save_dir
=
os
.
path
.
join
(
args
.
checkpoint_dir
,
'epoch_{}'
.
format
(
epoch
))
print
(
'Saving model checkpoint to {}'
.
format
(
save_dir
))
logger
.
info
(
'Saving model checkpoint to {}'
.
format
(
save_dir
))
paddle
.
save
(
model
.
state_dict
(),
os
.
path
.
join
(
save_dir
,
'model.pdparams'
))
paddle
.
save
(
optimizer
.
state_dict
(),
...
...
@@ -260,6 +276,10 @@ if __name__ == "__main__":
choices
=
[
'cpu'
,
'gpu'
],
default
=
"cpu"
,
help
=
"Select which device to train model, defaults to gpu."
)
parser
.
add_argument
(
"--seed"
,
default
=
0
,
type
=
int
,
help
=
"random seed for paddle, numpy and python random package"
)
parser
.
add_argument
(
"--data-dir"
,
default
=
"./data/"
,
type
=
str
,
...
...
@@ -295,6 +315,10 @@ if __name__ == "__main__":
type
=
str
,
default
=
'./checkpoint'
,
help
=
"Directory to save model checkpoints."
)
parser
.
add_argument
(
"--augment"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Apply audio augments."
)
args
=
parser
.
parse_args
()
# yapf: enable
...
...
paddleaudio/datasets/rirs_noises.py
0 → 100644
浏览文件 @
2d89c80e
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
collections
import
csv
import
glob
import
os
import
random
from
typing
import
Dict
from
typing
import
List
from
typing
import
Tuple
from
paddle.io
import
Dataset
from
tqdm
import
tqdm
from
paddleaudio.backends
import
load
as
load_audio
from
paddleaudio.backends
import
save_wav
from
paddleaudio.datasets.dataset
import
feat_funcs
from
paddleaudio.utils
import
DATA_HOME
from
paddleaudio.utils
import
decompress
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.utils.download
import
download_and_decompress
logger
=
Log
(
__name__
).
getlog
()
__all__
=
[
'OpenRIRNoise'
]
class
OpenRIRNoise
(
Dataset
):
archieves
=
[
{
'url'
:
'http://www.openslr.org/resources/28/rirs_noises.zip'
,
'md5'
:
'e6f48e257286e05de56413b4779d8ffb'
,
},
]
sample_rate
=
16000
meta_info
=
collections
.
namedtuple
(
'META_INFO'
,
(
'id'
,
'duration'
,
'wav'
))
base_path
=
os
.
path
.
join
(
DATA_HOME
,
'open_rir_noise'
)
wav_path
=
os
.
path
.
join
(
base_path
,
'RIRS_NOISES'
)
csv_path
=
os
.
path
.
join
(
base_path
,
'csv'
)
subsets
=
[
'rir'
,
'noise'
]
def
__init__
(
self
,
subset
:
str
=
'rir'
,
feat_type
:
str
=
'raw'
,
target_dir
=
None
,
random_chunk
:
bool
=
True
,
chunk_duration
:
float
=
3.0
,
seed
:
int
=
0
,
**
kwargs
):
assert
subset
in
self
.
subsets
,
\
'Dataset subset must be one in {}, but got {}'
.
format
(
self
.
subsets
,
subset
)
self
.
subset
=
subset
self
.
feat_type
=
feat_type
self
.
feat_config
=
kwargs
self
.
random_chunk
=
random_chunk
self
.
chunk_duration
=
chunk_duration
self
.
csv_path
=
os
.
path
.
join
(
target_dir
,
"open_rir_noise"
,
"csv"
)
if
target_dir
else
self
.
csv_path
self
.
_data
=
self
.
_get_data
()
super
(
OpenRIRNoise
,
self
).
__init__
()
# Set up a seed to reproduce training or predicting result.
# random.seed(seed)
def
_get_data
(
self
):
# Download audio files.
logger
.
info
(
f
"rirs noises base path:
{
self
.
base_path
}
"
)
if
not
os
.
path
.
isdir
(
self
.
base_path
):
download_and_decompress
(
self
.
archieves
,
self
.
base_path
,
decompress
=
True
)
else
:
logger
.
info
(
f
"
{
self
.
base_path
}
already exists, we will not download and decompress again"
)
# Data preparation.
logger
.
info
(
f
"prepare the csv to
{
self
.
csv_path
}
"
)
if
not
os
.
path
.
isdir
(
self
.
csv_path
):
os
.
makedirs
(
self
.
csv_path
)
self
.
prepare_data
()
data
=
[]
with
open
(
os
.
path
.
join
(
self
.
csv_path
,
f
'
{
self
.
subset
}
.csv'
),
'r'
)
as
rf
:
for
line
in
rf
.
readlines
()[
1
:]:
audio_id
,
duration
,
wav
=
line
.
strip
().
split
(
','
)
data
.
append
(
self
.
meta_info
(
audio_id
,
float
(
duration
),
wav
))
random
.
shuffle
(
data
)
return
data
def
_convert_to_record
(
self
,
idx
:
int
):
sample
=
self
.
_data
[
idx
]
record
=
{}
# To show all fields in a namedtuple: `type(sample)._fields`
for
field
in
type
(
sample
).
_fields
:
record
[
field
]
=
getattr
(
sample
,
field
)
waveform
,
sr
=
load_audio
(
record
[
'wav'
])
assert
self
.
feat_type
in
feat_funcs
.
keys
(),
\
f
"Unknown feat_type:
{
self
.
feat_type
}
, it must be one in
{
list
(
feat_funcs
.
keys
())
}
"
feat_func
=
feat_funcs
[
self
.
feat_type
]
feat
=
feat_func
(
waveform
,
sr
=
sr
,
**
self
.
feat_config
)
if
feat_func
else
waveform
record
.
update
({
'feat'
:
feat
})
return
record
@
staticmethod
def
_get_chunks
(
seg_dur
,
audio_id
,
audio_duration
):
num_chunks
=
int
(
audio_duration
/
seg_dur
)
# all in milliseconds
chunk_lst
=
[
audio_id
+
"_"
+
str
(
i
*
seg_dur
)
+
"_"
+
str
(
i
*
seg_dur
+
seg_dur
)
for
i
in
range
(
num_chunks
)
]
return
chunk_lst
def
_get_audio_info
(
self
,
wav_file
:
str
,
split_chunks
:
bool
)
->
List
[
List
[
str
]]:
waveform
,
sr
=
load_audio
(
wav_file
)
audio_id
=
wav_file
.
split
(
"/open_rir_noise/"
)[
-
1
].
split
(
"."
)[
0
]
audio_duration
=
waveform
.
shape
[
0
]
/
sr
ret
=
[]
if
split_chunks
and
audio_duration
>
self
.
chunk_duration
:
# Split into pieces of self.chunk_duration seconds.
uniq_chunks_list
=
self
.
_get_chunks
(
self
.
chunk_duration
,
audio_id
,
audio_duration
)
for
idx
,
chunk
in
enumerate
(
uniq_chunks_list
):
s
,
e
=
chunk
.
split
(
"_"
)[
-
2
:]
# Timestamps of start and end
start_sample
=
int
(
float
(
s
)
*
sr
)
end_sample
=
int
(
float
(
e
)
*
sr
)
new_wav_file
=
os
.
path
.
join
(
self
.
base_path
,
audio_id
+
f
'_chunk_
{
idx
+
1
:
02
}
.wav'
)
save_wav
(
waveform
[
start_sample
:
end_sample
],
sr
,
new_wav_file
)
# id, duration, new_wav
ret
.
append
([
chunk
,
self
.
chunk_duration
,
new_wav_file
])
else
:
# Keep whole audio.
ret
.
append
([
audio_id
,
audio_duration
,
wav_file
])
return
ret
def
generate_csv
(
self
,
wav_files
:
List
[
str
],
output_file
:
str
,
split_chunks
:
bool
=
True
):
logger
.
info
(
f
'Generating csv:
{
output_file
}
'
)
header
=
[
"id"
,
"duration"
,
"wav"
]
infos
=
list
(
tqdm
(
map
(
self
.
_get_audio_info
,
wav_files
,
[
split_chunks
]
*
len
(
wav_files
)),
total
=
len
(
wav_files
)))
csv_lines
=
[]
for
info
in
infos
:
csv_lines
.
extend
(
info
)
with
open
(
output_file
,
mode
=
"w"
)
as
csv_f
:
csv_writer
=
csv
.
writer
(
csv_f
,
delimiter
=
","
,
quotechar
=
'"'
,
quoting
=
csv
.
QUOTE_MINIMAL
)
csv_writer
.
writerow
(
header
)
for
line
in
csv_lines
:
csv_writer
.
writerow
(
line
)
def
prepare_data
(
self
):
rir_list
=
os
.
path
.
join
(
self
.
wav_path
,
"real_rirs_isotropic_noises"
,
"rir_list"
)
rir_files
=
[]
with
open
(
rir_list
,
'r'
)
as
f
:
for
line
in
f
.
readlines
():
rir_file
=
line
.
strip
().
split
(
' '
)[
-
1
]
rir_files
.
append
(
os
.
path
.
join
(
self
.
base_path
,
rir_file
))
noise_list
=
os
.
path
.
join
(
self
.
wav_path
,
"pointsource_noises"
,
"noise_list"
)
noise_files
=
[]
with
open
(
noise_list
,
'r'
)
as
f
:
for
line
in
f
.
readlines
():
noise_file
=
line
.
strip
().
split
(
' '
)[
-
1
]
noise_files
.
append
(
os
.
path
.
join
(
self
.
base_path
,
noise_file
))
self
.
generate_csv
(
rir_files
,
os
.
path
.
join
(
self
.
csv_path
,
'rir.csv'
))
self
.
generate_csv
(
noise_files
,
os
.
path
.
join
(
self
.
csv_path
,
'noise.csv'
))
def
__getitem__
(
self
,
idx
):
return
self
.
_convert_to_record
(
idx
)
def
__len__
(
self
):
return
len
(
self
.
_data
)
paddleaudio/datasets/voxceleb.py
浏览文件 @
2d89c80e
...
...
@@ -29,9 +29,12 @@ from paddleaudio.datasets.dataset import feat_funcs
from
paddleaudio.utils
import
DATA_HOME
from
paddleaudio.utils
import
decompress
from
paddleaudio.utils
import
download_and_decompress
from
paddlespeech.s2t.utils.log
import
Log
from
utils.utility
import
download
from
utils.utility
import
unpack
logger
=
Log
(
__name__
).
getlog
()
__all__
=
[
'VoxCeleb1'
]
...
...
@@ -121,9 +124,9 @@ class VoxCeleb1(Dataset):
# Download audio files.
# We need the users to decompress all vox1/dev/wav and vox1/test/wav/ to vox1/wav/ dir
# so, we check the vox1/wav dir status
print
(
"wav base path: {}"
.
format
(
self
.
wav_path
)
)
logger
.
info
(
f
"wav base path:
{
self
.
wav_path
}
"
)
if
not
os
.
path
.
isdir
(
self
.
wav_path
):
print
(
"start to download the voxceleb1 dataset"
)
logger
.
info
(
f
"start to download the voxceleb1 dataset"
)
download_and_decompress
(
# multi-zip parts concatenate to vox1_dev_wav.zip
self
.
archieves_audio_dev
,
self
.
base_path
,
...
...
@@ -135,7 +138,7 @@ class VoxCeleb1(Dataset):
# Download all parts and concatenate the files into one zip file.
dev_zipfile
=
os
.
path
.
join
(
self
.
base_path
,
'vox1_dev_wav.zip'
)
print
(
f
'Concatenating all parts to:
{
dev_zipfile
}
'
)
logger
.
info
(
f
'Concatenating all parts to:
{
dev_zipfile
}
'
)
os
.
system
(
f
'cat
{
os
.
path
.
join
(
self
.
base_path
,
"vox1_dev_wav_parta*"
)
}
>
{
dev_zipfile
}
'
)
...
...
@@ -154,6 +157,9 @@ class VoxCeleb1(Dataset):
self
.
prepare_data
()
data
=
[]
logger
.
info
(
f
"read the
{
self
.
subset
}
from
{
os
.
path
.
join
(
self
.
csv_path
,
f
'
{
self
.
subset
}
.
csv
')
}
"
)
with
open
(
os
.
path
.
join
(
self
.
csv_path
,
f
'
{
self
.
subset
}
.csv'
),
'r'
)
as
rf
:
for
line
in
rf
.
readlines
()[
1
:]:
audio_id
,
duration
,
wav
,
start
,
stop
,
spk_id
=
line
.
strip
(
...
...
@@ -246,7 +252,7 @@ class VoxCeleb1(Dataset):
wav_files
:
List
[
str
],
output_file
:
str
,
split_chunks
:
bool
=
True
):
print
(
f
'Generating csv:
{
output_file
}
'
)
logger
.
info
(
f
'Generating csv:
{
output_file
}
'
)
header
=
[
"id"
,
"duration"
,
"wav"
,
"start"
,
"stop"
,
"spk_id"
]
with
Pool
(
64
)
as
p
:
...
...
@@ -269,7 +275,7 @@ class VoxCeleb1(Dataset):
def
prepare_data
(
self
):
# Audio of speakers in veri_test_file should not be included in training set.
print
(
"start to prepare the data csv file"
)
logger
.
info
(
"start to prepare the data csv file"
)
enrol_files
=
set
()
test_files
=
set
()
# get the enroll and test audio file path
...
...
@@ -299,7 +305,7 @@ class VoxCeleb1(Dataset):
speakers
.
add
(
spk
)
audio_files
.
append
(
file
)
print
(
"start to generate the {}"
.
format
(
logger
.
info
(
"start to generate the {}"
.
format
(
os
.
path
.
join
(
self
.
meta_path
,
'spk_id2label.txt'
)))
# encode the train and dev speakers label to spk_id2label.txt
with
open
(
os
.
path
.
join
(
self
.
meta_path
,
'spk_id2label.txt'
),
'w'
)
as
f
:
...
...
paddlespeech/vector/io/augment.py
0 → 100644
浏览文件 @
2d89c80e
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
os
from
typing
import
List
import
numpy
as
np
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddleaudio.backends
import
load
as
load_audio
from
paddleaudio.datasets.rirs_noises
import
OpenRIRNoise
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.io.signal_processing
import
compute_amplitude
from
paddlespeech.vector.io.signal_processing
import
convolve1d
from
paddlespeech.vector.io.signal_processing
import
dB_to_amplitude
from
paddlespeech.vector.io.signal_processing
import
notch_filter
from
paddlespeech.vector.io.signal_processing
import
reverberate
logger
=
Log
(
__name__
).
getlog
()
# TODO: Complete type-hint and doc string.
class
DropFreq
(
nn
.
Layer
):
def
__init__
(
self
,
drop_freq_low
=
1e-14
,
drop_freq_high
=
1
,
drop_count_low
=
1
,
drop_count_high
=
2
,
drop_width
=
0.05
,
drop_prob
=
1
,
):
super
(
DropFreq
,
self
).
__init__
()
self
.
drop_freq_low
=
drop_freq_low
self
.
drop_freq_high
=
drop_freq_high
self
.
drop_count_low
=
drop_count_low
self
.
drop_count_high
=
drop_count_high
self
.
drop_width
=
drop_width
self
.
drop_prob
=
drop_prob
def
forward
(
self
,
waveforms
):
# Don't drop (return early) 1-`drop_prob` portion of the batches
dropped_waveform
=
waveforms
.
clone
()
if
paddle
.
rand
([
1
])
>
self
.
drop_prob
:
return
dropped_waveform
# Add channels dimension
if
len
(
waveforms
.
shape
)
==
2
:
dropped_waveform
=
dropped_waveform
.
unsqueeze
(
-
1
)
# Pick number of frequencies to drop
drop_count
=
paddle
.
randint
(
low
=
self
.
drop_count_low
,
high
=
self
.
drop_count_high
+
1
,
shape
=
[
1
])
# Pick a frequency to drop
drop_range
=
self
.
drop_freq_high
-
self
.
drop_freq_low
drop_frequency
=
(
paddle
.
rand
([
drop_count
])
*
drop_range
+
self
.
drop_freq_low
)
# Filter parameters
filter_length
=
101
pad
=
filter_length
//
2
# Start with delta function
drop_filter
=
paddle
.
zeros
([
1
,
filter_length
,
1
])
drop_filter
[
0
,
pad
,
0
]
=
1
# Subtract each frequency
for
frequency
in
drop_frequency
:
notch_kernel
=
notch_filter
(
frequency
,
filter_length
,
self
.
drop_width
)
drop_filter
=
convolve1d
(
drop_filter
,
notch_kernel
,
pad
)
# Apply filter
dropped_waveform
=
convolve1d
(
dropped_waveform
,
drop_filter
,
pad
)
# Remove channels dimension if added
return
dropped_waveform
.
squeeze
(
-
1
)
class
DropChunk
(
nn
.
Layer
):
def
__init__
(
self
,
drop_length_low
=
100
,
drop_length_high
=
1000
,
drop_count_low
=
1
,
drop_count_high
=
10
,
drop_start
=
0
,
drop_end
=
None
,
drop_prob
=
1
,
noise_factor
=
0.0
,
):
super
(
DropChunk
,
self
).
__init__
()
self
.
drop_length_low
=
drop_length_low
self
.
drop_length_high
=
drop_length_high
self
.
drop_count_low
=
drop_count_low
self
.
drop_count_high
=
drop_count_high
self
.
drop_start
=
drop_start
self
.
drop_end
=
drop_end
self
.
drop_prob
=
drop_prob
self
.
noise_factor
=
noise_factor
# Validate low < high
if
drop_length_low
>
drop_length_high
:
raise
ValueError
(
"Low limit must not be more than high limit"
)
if
drop_count_low
>
drop_count_high
:
raise
ValueError
(
"Low limit must not be more than high limit"
)
# Make sure the length doesn't exceed end - start
if
drop_end
is
not
None
and
drop_end
>=
0
:
if
drop_start
>
drop_end
:
raise
ValueError
(
"Low limit must not be more than high limit"
)
drop_range
=
drop_end
-
drop_start
self
.
drop_length_low
=
min
(
drop_length_low
,
drop_range
)
self
.
drop_length_high
=
min
(
drop_length_high
,
drop_range
)
def
forward
(
self
,
waveforms
,
lengths
):
# Reading input list
lengths
=
(
lengths
*
waveforms
.
shape
[
1
]).
astype
(
'int64'
)
batch_size
=
waveforms
.
shape
[
0
]
dropped_waveform
=
waveforms
.
clone
()
# Don't drop (return early) 1-`drop_prob` portion of the batches
if
paddle
.
rand
([
1
])
>
self
.
drop_prob
:
return
dropped_waveform
# Store original amplitude for computing white noise amplitude
clean_amplitude
=
compute_amplitude
(
waveforms
,
lengths
.
unsqueeze
(
1
))
# Pick a number of times to drop
drop_times
=
paddle
.
randint
(
low
=
self
.
drop_count_low
,
high
=
self
.
drop_count_high
+
1
,
shape
=
[
batch_size
],
)
# Iterate batch to set mask
for
i
in
range
(
batch_size
):
if
drop_times
[
i
]
==
0
:
continue
# Pick lengths
length
=
paddle
.
randint
(
low
=
self
.
drop_length_low
,
high
=
self
.
drop_length_high
+
1
,
shape
=
[
drop_times
[
i
]],
)
# Compute range of starting locations
start_min
=
self
.
drop_start
if
start_min
<
0
:
start_min
+=
lengths
[
i
]
start_max
=
self
.
drop_end
if
start_max
is
None
:
start_max
=
lengths
[
i
]
if
start_max
<
0
:
start_max
+=
lengths
[
i
]
start_max
=
max
(
0
,
start_max
-
length
.
max
())
# Pick starting locations
start
=
paddle
.
randint
(
low
=
start_min
,
high
=
start_max
+
1
,
shape
=
[
drop_times
[
i
]],
)
end
=
start
+
length
# Update waveform
if
not
self
.
noise_factor
:
for
j
in
range
(
drop_times
[
i
]):
dropped_waveform
[
i
,
start
[
j
]:
end
[
j
]]
=
0.0
else
:
# Uniform distribution of -2 to +2 * avg amplitude should
# preserve the average for normalization
noise_max
=
2
*
clean_amplitude
[
i
]
*
self
.
noise_factor
for
j
in
range
(
drop_times
[
i
]):
# zero-center the noise distribution
noise_vec
=
paddle
.
rand
([
length
[
j
]],
dtype
=
'float32'
)
noise_vec
=
2
*
noise_max
*
noise_vec
-
noise_max
dropped_waveform
[
i
,
int
(
start
[
j
]):
int
(
end
[
j
])]
=
noise_vec
return
dropped_waveform
class
Resample
(
nn
.
Layer
):
def
__init__
(
self
,
orig_freq
=
16000
,
new_freq
=
16000
,
lowpass_filter_width
=
6
,
):
super
(
Resample
,
self
).
__init__
()
self
.
orig_freq
=
orig_freq
self
.
new_freq
=
new_freq
self
.
lowpass_filter_width
=
lowpass_filter_width
# Compute rate for striding
self
.
_compute_strides
()
assert
self
.
orig_freq
%
self
.
conv_stride
==
0
assert
self
.
new_freq
%
self
.
conv_transpose_stride
==
0
def
_compute_strides
(
self
):
# Compute new unit based on ratio of in/out frequencies
base_freq
=
math
.
gcd
(
self
.
orig_freq
,
self
.
new_freq
)
input_samples_in_unit
=
self
.
orig_freq
//
base_freq
self
.
output_samples
=
self
.
new_freq
//
base_freq
# Store the appropriate stride based on the new units
self
.
conv_stride
=
input_samples_in_unit
self
.
conv_transpose_stride
=
self
.
output_samples
def
forward
(
self
,
waveforms
):
if
not
hasattr
(
self
,
"first_indices"
):
self
.
_indices_and_weights
(
waveforms
)
# Don't do anything if the frequencies are the same
if
self
.
orig_freq
==
self
.
new_freq
:
return
waveforms
unsqueezed
=
False
if
len
(
waveforms
.
shape
)
==
2
:
waveforms
=
waveforms
.
unsqueeze
(
1
)
unsqueezed
=
True
elif
len
(
waveforms
.
shape
)
==
3
:
waveforms
=
waveforms
.
transpose
([
0
,
2
,
1
])
else
:
raise
ValueError
(
"Input must be 2 or 3 dimensions"
)
# Do resampling
resampled_waveform
=
self
.
_perform_resample
(
waveforms
)
if
unsqueezed
:
resampled_waveform
=
resampled_waveform
.
squeeze
(
1
)
else
:
resampled_waveform
=
resampled_waveform
.
transpose
([
0
,
2
,
1
])
return
resampled_waveform
def
_perform_resample
(
self
,
waveforms
):
# Compute output size and initialize
batch_size
,
num_channels
,
wave_len
=
waveforms
.
shape
window_size
=
self
.
weights
.
shape
[
1
]
tot_output_samp
=
self
.
_output_samples
(
wave_len
)
resampled_waveform
=
paddle
.
zeros
((
batch_size
,
num_channels
,
tot_output_samp
))
# eye size: (num_channels, num_channels, 1)
eye
=
paddle
.
eye
(
num_channels
).
unsqueeze
(
2
)
# Iterate over the phases in the polyphase filter
for
i
in
range
(
self
.
first_indices
.
shape
[
0
]):
wave_to_conv
=
waveforms
first_index
=
int
(
self
.
first_indices
[
i
].
item
())
if
first_index
>=
0
:
# trim the signal as the filter will not be applied
# before the first_index
wave_to_conv
=
wave_to_conv
[:,
:,
first_index
:]
# pad the right of the signal to allow partial convolutions
# meaning compute values for partial windows (e.g. end of the
# window is outside the signal length)
max_index
=
(
tot_output_samp
-
1
)
//
self
.
output_samples
end_index
=
max_index
*
self
.
conv_stride
+
window_size
current_wave_len
=
wave_len
-
first_index
right_padding
=
max
(
0
,
end_index
+
1
-
current_wave_len
)
left_padding
=
max
(
0
,
-
first_index
)
wave_to_conv
=
paddle
.
nn
.
functional
.
pad
(
wave_to_conv
,
[
left_padding
,
right_padding
],
data_format
=
'NCL'
)
conv_wave
=
paddle
.
nn
.
functional
.
conv1d
(
x
=
wave_to_conv
,
# weight=self.weights[i].repeat(num_channels, 1, 1),
weight
=
self
.
weights
[
i
].
expand
((
num_channels
,
1
,
-
1
)),
stride
=
self
.
conv_stride
,
groups
=
num_channels
,
)
# we want conv_wave[:, i] to be at
# output[:, i + n*conv_transpose_stride]
dilated_conv_wave
=
paddle
.
nn
.
functional
.
conv1d_transpose
(
conv_wave
,
eye
,
stride
=
self
.
conv_transpose_stride
)
# pad dilated_conv_wave so it reaches the output length if needed.
left_padding
=
i
previous_padding
=
left_padding
+
dilated_conv_wave
.
shape
[
-
1
]
right_padding
=
max
(
0
,
tot_output_samp
-
previous_padding
)
dilated_conv_wave
=
paddle
.
nn
.
functional
.
pad
(
dilated_conv_wave
,
[
left_padding
,
right_padding
],
data_format
=
'NCL'
)
dilated_conv_wave
=
dilated_conv_wave
[:,
:,
:
tot_output_samp
]
resampled_waveform
+=
dilated_conv_wave
return
resampled_waveform
def
_output_samples
(
self
,
input_num_samp
):
samp_in
=
int
(
self
.
orig_freq
)
samp_out
=
int
(
self
.
new_freq
)
tick_freq
=
abs
(
samp_in
*
samp_out
)
//
math
.
gcd
(
samp_in
,
samp_out
)
ticks_per_input_period
=
tick_freq
//
samp_in
# work out the number of ticks in the time interval
# [ 0, input_num_samp/samp_in ).
interval_length
=
input_num_samp
*
ticks_per_input_period
if
interval_length
<=
0
:
return
0
ticks_per_output_period
=
tick_freq
//
samp_out
# Get the last output-sample in the closed interval,
# i.e. replacing [ ) with [ ]. Note: integer division rounds down.
# See http://en.wikipedia.org/wiki/Interval_(mathematics) for an
# explanation of the notation.
last_output_samp
=
interval_length
//
ticks_per_output_period
# We need the last output-sample in the open interval, so if it
# takes us to the end of the interval exactly, subtract one.
if
last_output_samp
*
ticks_per_output_period
==
interval_length
:
last_output_samp
-=
1
# First output-sample index is zero, so the number of output samples
# is the last output-sample plus one.
num_output_samp
=
last_output_samp
+
1
return
num_output_samp
def
_indices_and_weights
(
self
,
waveforms
):
# Lowpass filter frequency depends on smaller of two frequencies
min_freq
=
min
(
self
.
orig_freq
,
self
.
new_freq
)
lowpass_cutoff
=
0.99
*
0.5
*
min_freq
assert
lowpass_cutoff
*
2
<=
min_freq
window_width
=
self
.
lowpass_filter_width
/
(
2.0
*
lowpass_cutoff
)
assert
lowpass_cutoff
<
min
(
self
.
orig_freq
,
self
.
new_freq
)
/
2
output_t
=
paddle
.
arange
(
start
=
0.0
,
end
=
self
.
output_samples
)
output_t
/=
self
.
new_freq
min_t
=
output_t
-
window_width
max_t
=
output_t
+
window_width
min_input_index
=
paddle
.
ceil
(
min_t
*
self
.
orig_freq
)
max_input_index
=
paddle
.
floor
(
max_t
*
self
.
orig_freq
)
num_indices
=
max_input_index
-
min_input_index
+
1
max_weight_width
=
num_indices
.
max
()
j
=
paddle
.
arange
(
max_weight_width
,
dtype
=
'float32'
)
input_index
=
min_input_index
.
unsqueeze
(
1
)
+
j
.
unsqueeze
(
0
)
delta_t
=
(
input_index
/
self
.
orig_freq
)
-
output_t
.
unsqueeze
(
1
)
weights
=
paddle
.
zeros_like
(
delta_t
)
inside_window_indices
=
delta_t
.
abs
().
less_than
(
paddle
.
to_tensor
(
window_width
))
# raised-cosine (Hanning) window with width `window_width`
weights
[
inside_window_indices
]
=
0.5
*
(
1
+
paddle
.
cos
(
2
*
math
.
pi
*
lowpass_cutoff
/
self
.
lowpass_filter_width
*
delta_t
.
masked_select
(
inside_window_indices
)))
t_eq_zero_indices
=
delta_t
.
equal
(
paddle
.
zeros_like
(
delta_t
))
t_not_eq_zero_indices
=
delta_t
.
not_equal
(
paddle
.
zeros_like
(
delta_t
))
# sinc filter function
weights
=
paddle
.
where
(
t_not_eq_zero_indices
,
weights
*
paddle
.
sin
(
2
*
math
.
pi
*
lowpass_cutoff
*
delta_t
)
/
(
math
.
pi
*
delta_t
),
weights
)
# limit of the function at t = 0
weights
=
paddle
.
where
(
t_eq_zero_indices
,
weights
*
2
*
lowpass_cutoff
,
weights
)
# size (output_samples, max_weight_width)
weights
/=
self
.
orig_freq
self
.
first_indices
=
min_input_index
self
.
weights
=
weights
class
SpeedPerturb
(
nn
.
Layer
):
def
__init__
(
self
,
orig_freq
,
speeds
=
[
90
,
100
,
110
],
perturb_prob
=
1.0
,
):
super
(
SpeedPerturb
,
self
).
__init__
()
self
.
orig_freq
=
orig_freq
self
.
speeds
=
speeds
self
.
perturb_prob
=
perturb_prob
# Initialize index of perturbation
self
.
samp_index
=
0
# Initialize resamplers
self
.
resamplers
=
[]
for
speed
in
self
.
speeds
:
config
=
{
"orig_freq"
:
self
.
orig_freq
,
"new_freq"
:
self
.
orig_freq
*
speed
//
100
,
}
self
.
resamplers
.
append
(
Resample
(
**
config
))
def
forward
(
self
,
waveform
):
# Don't perturb (return early) 1-`perturb_prob` portion of the batches
if
paddle
.
rand
([
1
])
>
self
.
perturb_prob
:
return
waveform
.
clone
()
# Perform a random perturbation
self
.
samp_index
=
paddle
.
randint
(
len
(
self
.
speeds
),
shape
=
[
1
]).
item
()
perturbed_waveform
=
self
.
resamplers
[
self
.
samp_index
](
waveform
)
return
perturbed_waveform
class
AddNoise
(
nn
.
Layer
):
def
__init__
(
self
,
noise_dataset
=
None
,
# None for white noise
num_workers
=
0
,
snr_low
=
0
,
snr_high
=
0
,
mix_prob
=
1.0
,
start_index
=
None
,
normalize
=
False
,
):
super
(
AddNoise
,
self
).
__init__
()
self
.
num_workers
=
num_workers
self
.
snr_low
=
snr_low
self
.
snr_high
=
snr_high
self
.
mix_prob
=
mix_prob
self
.
start_index
=
start_index
self
.
normalize
=
normalize
self
.
noise_dataset
=
noise_dataset
self
.
noise_dataloader
=
None
def
forward
(
self
,
waveforms
,
lengths
=
None
):
if
lengths
is
None
:
lengths
=
paddle
.
ones
([
len
(
waveforms
)])
# Copy clean waveform to initialize noisy waveform
noisy_waveform
=
waveforms
.
clone
()
lengths
=
(
lengths
*
waveforms
.
shape
[
1
]).
astype
(
'int64'
).
unsqueeze
(
1
)
# Don't add noise (return early) 1-`mix_prob` portion of the batches
if
paddle
.
rand
([
1
])
>
self
.
mix_prob
:
return
noisy_waveform
# Compute the average amplitude of the clean waveforms
clean_amplitude
=
compute_amplitude
(
waveforms
,
lengths
)
# Pick an SNR and use it to compute the mixture amplitude factors
SNR
=
paddle
.
rand
((
len
(
waveforms
),
1
))
SNR
=
SNR
*
(
self
.
snr_high
-
self
.
snr_low
)
+
self
.
snr_low
noise_amplitude_factor
=
1
/
(
dB_to_amplitude
(
SNR
)
+
1
)
new_noise_amplitude
=
noise_amplitude_factor
*
clean_amplitude
# Scale clean signal appropriately
noisy_waveform
*=
1
-
noise_amplitude_factor
# Loop through clean samples and create mixture
if
self
.
noise_dataset
is
None
:
white_noise
=
paddle
.
normal
(
shape
=
waveforms
.
shape
)
noisy_waveform
+=
new_noise_amplitude
*
white_noise
else
:
tensor_length
=
waveforms
.
shape
[
1
]
noise_waveform
,
noise_length
=
self
.
_load_noise
(
lengths
,
tensor_length
,
)
# Rescale and add
noise_amplitude
=
compute_amplitude
(
noise_waveform
,
noise_length
)
noise_waveform
*=
new_noise_amplitude
/
(
noise_amplitude
+
1e-14
)
noisy_waveform
+=
noise_waveform
# Normalizing to prevent clipping
if
self
.
normalize
:
abs_max
,
_
=
paddle
.
max
(
paddle
.
abs
(
noisy_waveform
),
axis
=
1
,
keepdim
=
True
)
noisy_waveform
=
noisy_waveform
/
abs_max
.
clip
(
min
=
1.0
)
return
noisy_waveform
def
_load_noise
(
self
,
lengths
,
max_length
):
"""
Load a batch of noises
args
lengths(Paddle.Tensor): Num samples of waveforms with shape (N, 1).
max_length(int): Width of a batch.
"""
lengths
=
lengths
.
squeeze
(
1
)
batch_size
=
len
(
lengths
)
# Load a noise batch
if
self
.
noise_dataloader
is
None
:
def
noise_collate_fn
(
batch
):
def
pad
(
x
,
target_length
,
mode
=
'constant'
,
**
kwargs
):
x
=
np
.
asarray
(
x
)
w
=
target_length
-
x
.
shape
[
0
]
assert
w
>=
0
,
f
'Target length
{
target_length
}
is less than origin length
{
x
.
shape
[
0
]
}
'
return
np
.
pad
(
x
,
[
0
,
w
],
mode
=
mode
,
**
kwargs
)
ids
=
[
item
[
'id'
]
for
item
in
batch
]
lengths
=
np
.
asarray
([
item
[
'feat'
].
shape
[
0
]
for
item
in
batch
])
waveforms
=
list
(
map
(
lambda
x
:
pad
(
x
,
max
(
max_length
,
lengths
.
max
().
item
())),
[
item
[
'feat'
]
for
item
in
batch
]))
waveforms
=
np
.
stack
(
waveforms
)
return
{
'ids'
:
ids
,
'feats'
:
waveforms
,
'lengths'
:
lengths
}
# Create noise data loader.
self
.
noise_dataloader
=
paddle
.
io
.
DataLoader
(
self
.
noise_dataset
,
batch_size
=
batch_size
,
shuffle
=
True
,
num_workers
=
self
.
num_workers
,
collate_fn
=
noise_collate_fn
,
return_list
=
True
,
)
self
.
noise_data
=
iter
(
self
.
noise_dataloader
)
noise_batch
,
noise_len
=
self
.
_load_noise_batch_of_size
(
batch_size
)
# Select a random starting location in the waveform
start_index
=
self
.
start_index
if
self
.
start_index
is
None
:
start_index
=
0
max_chop
=
(
noise_len
-
lengths
).
min
().
clip
(
min
=
1
)
start_index
=
paddle
.
randint
(
high
=
max_chop
,
shape
=
[
1
])
# Truncate noise_batch to max_length
noise_batch
=
noise_batch
[:,
start_index
:
start_index
+
max_length
]
noise_len
=
(
noise_len
-
start_index
).
clip
(
max
=
max_length
).
unsqueeze
(
1
)
return
noise_batch
,
noise_len
def
_load_noise_batch_of_size
(
self
,
batch_size
):
"""Concatenate noise batches, then chop to correct size"""
noise_batch
,
noise_lens
=
self
.
_load_noise_batch
()
# Expand
while
len
(
noise_batch
)
<
batch_size
:
noise_batch
=
paddle
.
concat
((
noise_batch
,
noise_batch
))
noise_lens
=
paddle
.
concat
((
noise_lens
,
noise_lens
))
# Contract
if
len
(
noise_batch
)
>
batch_size
:
noise_batch
=
noise_batch
[:
batch_size
]
noise_lens
=
noise_lens
[:
batch_size
]
return
noise_batch
,
noise_lens
def
_load_noise_batch
(
self
):
"""Load a batch of noises, restarting iteration if necessary."""
try
:
batch
=
next
(
self
.
noise_data
)
except
StopIteration
:
self
.
noise_data
=
iter
(
self
.
noise_dataloader
)
batch
=
next
(
self
.
noise_data
)
noises
,
lens
=
batch
[
'feats'
],
batch
[
'lengths'
]
return
noises
,
lens
class
AddReverb
(
nn
.
Layer
):
def
__init__
(
self
,
rir_dataset
,
reverb_prob
=
1.0
,
rir_scale_factor
=
1.0
,
num_workers
=
0
,
):
super
(
AddReverb
,
self
).
__init__
()
self
.
rir_dataset
=
rir_dataset
self
.
reverb_prob
=
reverb_prob
self
.
rir_scale_factor
=
rir_scale_factor
# Create rir data loader.
def
rir_collate_fn
(
batch
):
def
pad
(
x
,
target_length
,
mode
=
'constant'
,
**
kwargs
):
x
=
np
.
asarray
(
x
)
w
=
target_length
-
x
.
shape
[
0
]
assert
w
>=
0
,
f
'Target length
{
target_length
}
is less than origin length
{
x
.
shape
[
0
]
}
'
return
np
.
pad
(
x
,
[
0
,
w
],
mode
=
mode
,
**
kwargs
)
ids
=
[
item
[
'id'
]
for
item
in
batch
]
lengths
=
np
.
asarray
([
item
[
'feat'
].
shape
[
0
]
for
item
in
batch
])
waveforms
=
list
(
map
(
lambda
x
:
pad
(
x
,
lengths
.
max
().
item
()),
[
item
[
'feat'
]
for
item
in
batch
]))
waveforms
=
np
.
stack
(
waveforms
)
return
{
'ids'
:
ids
,
'feats'
:
waveforms
,
'lengths'
:
lengths
}
self
.
rir_dataloader
=
paddle
.
io
.
DataLoader
(
self
.
rir_dataset
,
collate_fn
=
rir_collate_fn
,
num_workers
=
num_workers
,
shuffle
=
True
,
return_list
=
True
,
)
self
.
rir_data
=
iter
(
self
.
rir_dataloader
)
def
forward
(
self
,
waveforms
,
lengths
=
None
):
"""
Arguments
---------
waveforms : tensor
Shape should be `[batch, time]` or `[batch, time, channels]`.
lengths : tensor
Shape should be a single dimension, `[batch]`.
Returns
-------
Tensor of shape `[batch, time]` or `[batch, time, channels]`.
"""
if
lengths
is
None
:
lengths
=
paddle
.
ones
([
len
(
waveforms
)])
# Don't add reverb (return early) 1-`reverb_prob` portion of the time
if
paddle
.
rand
([
1
])
>
self
.
reverb_prob
:
return
waveforms
.
clone
()
# Add channels dimension if necessary
channel_added
=
False
if
len
(
waveforms
.
shape
)
==
2
:
waveforms
=
waveforms
.
unsqueeze
(
-
1
)
channel_added
=
True
# Load and prepare RIR
rir_waveform
=
self
.
_load_rir
()
# Compress or dilate RIR
if
self
.
rir_scale_factor
!=
1
:
rir_waveform
=
F
.
interpolate
(
rir_waveform
.
transpose
([
0
,
2
,
1
]),
scale_factor
=
self
.
rir_scale_factor
,
mode
=
"linear"
,
align_corners
=
False
,
data_format
=
'NCW'
,
)
# (N, C, L) -> (N, L, C)
rir_waveform
=
rir_waveform
.
transpose
([
0
,
2
,
1
])
rev_waveform
=
reverberate
(
waveforms
,
rir_waveform
,
self
.
rir_dataset
.
sample_rate
,
rescale_amp
=
"avg"
)
# Remove channels dimension if added
if
channel_added
:
return
rev_waveform
.
squeeze
(
-
1
)
return
rev_waveform
def
_load_rir
(
self
):
try
:
batch
=
next
(
self
.
rir_data
)
except
StopIteration
:
self
.
rir_data
=
iter
(
self
.
rir_dataloader
)
batch
=
next
(
self
.
rir_data
)
rir_waveform
=
batch
[
'feats'
]
# Make sure RIR has correct channels
if
len
(
rir_waveform
.
shape
)
==
2
:
rir_waveform
=
rir_waveform
.
unsqueeze
(
-
1
)
return
rir_waveform
class
AddBabble
(
nn
.
Layer
):
def
__init__
(
self
,
speaker_count
=
3
,
snr_low
=
0
,
snr_high
=
0
,
mix_prob
=
1
,
):
super
(
AddBabble
,
self
).
__init__
()
self
.
speaker_count
=
speaker_count
self
.
snr_low
=
snr_low
self
.
snr_high
=
snr_high
self
.
mix_prob
=
mix_prob
def
forward
(
self
,
waveforms
,
lengths
=
None
):
if
lengths
is
None
:
lengths
=
paddle
.
ones
([
len
(
waveforms
)])
babbled_waveform
=
waveforms
.
clone
()
lengths
=
(
lengths
*
waveforms
.
shape
[
1
]).
unsqueeze
(
1
)
batch_size
=
len
(
waveforms
)
# Don't mix (return early) 1-`mix_prob` portion of the batches
if
paddle
.
rand
([
1
])
>
self
.
mix_prob
:
return
babbled_waveform
# Pick an SNR and use it to compute the mixture amplitude factors
clean_amplitude
=
compute_amplitude
(
waveforms
,
lengths
)
SNR
=
paddle
.
rand
((
batch_size
,
1
))
SNR
=
SNR
*
(
self
.
snr_high
-
self
.
snr_low
)
+
self
.
snr_low
noise_amplitude_factor
=
1
/
(
dB_to_amplitude
(
SNR
)
+
1
)
new_noise_amplitude
=
noise_amplitude_factor
*
clean_amplitude
# Scale clean signal appropriately
babbled_waveform
*=
1
-
noise_amplitude_factor
# For each speaker in the mixture, roll and add
babble_waveform
=
waveforms
.
roll
((
1
,
),
axis
=
0
)
babble_len
=
lengths
.
roll
((
1
,
),
axis
=
0
)
for
i
in
range
(
1
,
self
.
speaker_count
):
babble_waveform
+=
waveforms
.
roll
((
1
+
i
,
),
axis
=
0
)
babble_len
=
paddle
.
concat
(
[
babble_len
,
babble_len
.
roll
((
1
,
),
axis
=
0
)],
axis
=-
1
).
max
(
axis
=-
1
,
keepdim
=
True
)
# Rescale and add to mixture
babble_amplitude
=
compute_amplitude
(
babble_waveform
,
babble_len
)
babble_waveform
*=
new_noise_amplitude
/
(
babble_amplitude
+
1e-14
)
babbled_waveform
+=
babble_waveform
return
babbled_waveform
class
TimeDomainSpecAugment
(
nn
.
Layer
):
def
__init__
(
self
,
perturb_prob
=
1.0
,
drop_freq_prob
=
1.0
,
drop_chunk_prob
=
1.0
,
speeds
=
[
95
,
100
,
105
],
sample_rate
=
16000
,
drop_freq_count_low
=
0
,
drop_freq_count_high
=
3
,
drop_chunk_count_low
=
0
,
drop_chunk_count_high
=
5
,
drop_chunk_length_low
=
1000
,
drop_chunk_length_high
=
2000
,
drop_chunk_noise_factor
=
0
,
):
super
(
TimeDomainSpecAugment
,
self
).
__init__
()
self
.
speed_perturb
=
SpeedPerturb
(
perturb_prob
=
perturb_prob
,
orig_freq
=
sample_rate
,
speeds
=
speeds
,
)
self
.
drop_freq
=
DropFreq
(
drop_prob
=
drop_freq_prob
,
drop_count_low
=
drop_freq_count_low
,
drop_count_high
=
drop_freq_count_high
,
)
self
.
drop_chunk
=
DropChunk
(
drop_prob
=
drop_chunk_prob
,
drop_count_low
=
drop_chunk_count_low
,
drop_count_high
=
drop_chunk_count_high
,
drop_length_low
=
drop_chunk_length_low
,
drop_length_high
=
drop_chunk_length_high
,
noise_factor
=
drop_chunk_noise_factor
,
)
def
forward
(
self
,
waveforms
,
lengths
=
None
):
if
lengths
is
None
:
lengths
=
paddle
.
ones
([
len
(
waveforms
)])
with
paddle
.
no_grad
():
# Augmentation
waveforms
=
self
.
speed_perturb
(
waveforms
)
waveforms
=
self
.
drop_freq
(
waveforms
)
waveforms
=
self
.
drop_chunk
(
waveforms
,
lengths
)
return
waveforms
class
EnvCorrupt
(
nn
.
Layer
):
def
__init__
(
self
,
reverb_prob
=
1.0
,
babble_prob
=
1.0
,
noise_prob
=
1.0
,
rir_dataset
=
None
,
noise_dataset
=
None
,
num_workers
=
0
,
babble_speaker_count
=
0
,
babble_snr_low
=
0
,
babble_snr_high
=
0
,
noise_snr_low
=
0
,
noise_snr_high
=
0
,
rir_scale_factor
=
1.0
,
):
super
(
EnvCorrupt
,
self
).
__init__
()
# Initialize corrupters
if
rir_dataset
is
not
None
and
reverb_prob
>
0.0
:
self
.
add_reverb
=
AddReverb
(
rir_dataset
=
rir_dataset
,
num_workers
=
num_workers
,
reverb_prob
=
reverb_prob
,
rir_scale_factor
=
rir_scale_factor
,
)
if
babble_speaker_count
>
0
and
babble_prob
>
0.0
:
self
.
add_babble
=
AddBabble
(
speaker_count
=
babble_speaker_count
,
snr_low
=
babble_snr_low
,
snr_high
=
babble_snr_high
,
mix_prob
=
babble_prob
,
)
if
noise_dataset
is
not
None
and
noise_prob
>
0.0
:
self
.
add_noise
=
AddNoise
(
noise_dataset
=
noise_dataset
,
num_workers
=
num_workers
,
snr_low
=
noise_snr_low
,
snr_high
=
noise_snr_high
,
mix_prob
=
noise_prob
,
)
def
forward
(
self
,
waveforms
,
lengths
=
None
):
if
lengths
is
None
:
lengths
=
paddle
.
ones
([
len
(
waveforms
)])
# Augmentation
with
paddle
.
no_grad
():
if
hasattr
(
self
,
"add_reverb"
):
try
:
waveforms
=
self
.
add_reverb
(
waveforms
,
lengths
)
except
Exception
:
pass
if
hasattr
(
self
,
"add_babble"
):
waveforms
=
self
.
add_babble
(
waveforms
,
lengths
)
if
hasattr
(
self
,
"add_noise"
):
waveforms
=
self
.
add_noise
(
waveforms
,
lengths
)
return
waveforms
def
build_augment_pipeline
(
target_dir
=
None
)
->
List
[
paddle
.
nn
.
Layer
]:
"""build augment pipeline
Note: this pipeline cannot be used in the paddle.DataLoader
Returns:
List[paddle.nn.Layer]: all augment process
"""
logger
.
info
(
"start to build the augment pipeline"
)
noise_dataset
=
OpenRIRNoise
(
'noise'
,
target_dir
=
target_dir
)
rir_dataset
=
OpenRIRNoise
(
'rir'
)
wavedrop
=
TimeDomainSpecAugment
(
sample_rate
=
16000
,
speeds
=
[
100
],
)
speed_perturb
=
TimeDomainSpecAugment
(
sample_rate
=
16000
,
speeds
=
[
95
,
100
,
105
],
)
add_noise
=
EnvCorrupt
(
noise_dataset
=
noise_dataset
,
reverb_prob
=
0.0
,
noise_prob
=
1.0
,
noise_snr_low
=
0
,
noise_snr_high
=
15
,
rir_scale_factor
=
1.0
,
)
add_rev
=
EnvCorrupt
(
rir_dataset
=
rir_dataset
,
reverb_prob
=
1.0
,
noise_prob
=
0.0
,
rir_scale_factor
=
1.0
,
)
add_rev_noise
=
EnvCorrupt
(
noise_dataset
=
noise_dataset
,
rir_dataset
=
rir_dataset
,
reverb_prob
=
1.0
,
noise_prob
=
1.0
,
noise_snr_low
=
0
,
noise_snr_high
=
15
,
rir_scale_factor
=
1.0
,
)
return
[
wavedrop
,
speed_perturb
,
add_noise
,
add_rev
,
add_rev_noise
]
def
waveform_augment
(
waveforms
:
paddle
.
Tensor
,
augment_pipeline
:
List
[
paddle
.
nn
.
Layer
])
->
paddle
.
Tensor
:
"""process the augment pipeline and return all the waveforms
Args:
waveforms (paddle.Tensor): _description_
augment_pipeline (List[paddle.nn.Layer]): _description_
Returns:
paddle.Tensor: _description_
"""
waveforms_aug_list
=
[
waveforms
]
for
aug
in
augment_pipeline
:
waveforms_aug
=
aug
(
waveforms
)
# (N, L)
if
waveforms_aug
.
shape
[
1
]
>=
waveforms
.
shape
[
1
]:
# Trunc
waveforms_aug
=
waveforms_aug
[:,
:
waveforms
.
shape
[
1
]]
else
:
# Pad
lengths_to_pad
=
waveforms
.
shape
[
1
]
-
waveforms_aug
.
shape
[
1
]
waveforms_aug
=
F
.
pad
(
waveforms_aug
.
unsqueeze
(
-
1
),
[
0
,
lengths_to_pad
],
data_format
=
'NLC'
).
squeeze
(
-
1
)
waveforms_aug_list
.
append
(
waveforms_aug
)
return
paddle
.
concat
(
waveforms_aug_list
,
axis
=
0
)
paddlespeech/vector/io/signal_processing.py
0 → 100644
浏览文件 @
2d89c80e
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
numpy
as
np
import
paddle
# TODO: Complete type-hint and doc string.
def
blackman_window
(
win_len
,
dtype
=
np
.
float32
):
arcs
=
np
.
pi
*
np
.
arange
(
win_len
)
/
float
(
win_len
)
win
=
np
.
asarray
(
[
0.42
-
0.5
*
np
.
cos
(
2
*
arc
)
+
0.08
*
np
.
cos
(
4
*
arc
)
for
arc
in
arcs
],
dtype
=
dtype
)
return
paddle
.
to_tensor
(
win
)
def
compute_amplitude
(
waveforms
,
lengths
=
None
,
amp_type
=
"avg"
,
scale
=
"linear"
):
if
len
(
waveforms
.
shape
)
==
1
:
waveforms
=
waveforms
.
unsqueeze
(
0
)
assert
amp_type
in
[
"avg"
,
"peak"
]
assert
scale
in
[
"linear"
,
"dB"
]
if
amp_type
==
"avg"
:
if
lengths
is
None
:
out
=
paddle
.
mean
(
paddle
.
abs
(
waveforms
),
axis
=
1
,
keepdim
=
True
)
else
:
wav_sum
=
paddle
.
sum
(
paddle
.
abs
(
waveforms
),
axis
=
1
,
keepdim
=
True
)
out
=
wav_sum
/
lengths
elif
amp_type
==
"peak"
:
out
=
paddle
.
max
(
paddle
.
abs
(
waveforms
),
axis
=
1
,
keepdim
=
True
)
else
:
raise
NotImplementedError
if
scale
==
"linear"
:
return
out
elif
scale
==
"dB"
:
return
paddle
.
clip
(
20
*
paddle
.
log10
(
out
),
min
=-
80
)
else
:
raise
NotImplementedError
def
dB_to_amplitude
(
SNR
):
return
10
**
(
SNR
/
20
)
def
convolve1d
(
waveform
,
kernel
,
padding
=
0
,
pad_type
=
"constant"
,
stride
=
1
,
groups
=
1
,
):
if
len
(
waveform
.
shape
)
!=
3
:
raise
ValueError
(
"Convolve1D expects a 3-dimensional tensor"
)
# Padding can be a tuple (left_pad, right_pad) or an int
if
isinstance
(
padding
,
list
):
waveform
=
paddle
.
nn
.
functional
.
pad
(
x
=
waveform
,
pad
=
padding
,
mode
=
pad_type
,
data_format
=
'NLC'
,
)
# Move time dimension last, which pad and fft and conv expect.
# (N, L, C) -> (N, C, L)
waveform
=
waveform
.
transpose
([
0
,
2
,
1
])
kernel
=
kernel
.
transpose
([
0
,
2
,
1
])
convolved
=
paddle
.
nn
.
functional
.
conv1d
(
x
=
waveform
,
weight
=
kernel
,
stride
=
stride
,
groups
=
groups
,
padding
=
padding
if
not
isinstance
(
padding
,
list
)
else
0
,
)
# Return time dimension to the second dimension.
return
convolved
.
transpose
([
0
,
2
,
1
])
def
notch_filter
(
notch_freq
,
filter_width
=
101
,
notch_width
=
0.05
):
# Check inputs
assert
0
<
notch_freq
<=
1
assert
filter_width
%
2
!=
0
pad
=
filter_width
//
2
inputs
=
paddle
.
arange
(
filter_width
,
dtype
=
'float32'
)
-
pad
# Avoid frequencies that are too low
notch_freq
+=
notch_width
# Define sinc function, avoiding division by zero
def
sinc
(
x
):
def
_sinc
(
x
):
return
paddle
.
sin
(
x
)
/
x
# The zero is at the middle index
res
=
paddle
.
concat
(
[
_sinc
(
x
[:
pad
]),
paddle
.
ones
([
1
]),
_sinc
(
x
[
pad
+
1
:])])
return
res
# Compute a low-pass filter with cutoff frequency notch_freq.
hlpf
=
sinc
(
3
*
(
notch_freq
-
notch_width
)
*
inputs
)
# import torch
# hlpf *= paddle.to_tensor(torch.blackman_window(filter_width).detach().numpy())
hlpf
*=
blackman_window
(
filter_width
)
hlpf
/=
paddle
.
sum
(
hlpf
)
# Compute a high-pass filter with cutoff frequency notch_freq.
hhpf
=
sinc
(
3
*
(
notch_freq
+
notch_width
)
*
inputs
)
# hhpf *= paddle.to_tensor(torch.blackman_window(filter_width).detach().numpy())
hhpf
*=
blackman_window
(
filter_width
)
hhpf
/=
-
paddle
.
sum
(
hhpf
)
hhpf
[
pad
]
+=
1
# Adding filters creates notch filter
return
(
hlpf
+
hhpf
).
reshape
([
1
,
-
1
,
1
])
def
reverberate
(
waveforms
,
rir_waveform
,
sample_rate
,
impulse_duration
=
0.3
,
rescale_amp
=
"avg"
):
orig_shape
=
waveforms
.
shape
if
len
(
waveforms
.
shape
)
>
3
or
len
(
rir_waveform
.
shape
)
>
3
:
raise
NotImplementedError
# if inputs are mono tensors we reshape to 1, samples
if
len
(
waveforms
.
shape
)
==
1
:
waveforms
=
waveforms
.
unsqueeze
(
0
).
unsqueeze
(
-
1
)
elif
len
(
waveforms
.
shape
)
==
2
:
waveforms
=
waveforms
.
unsqueeze
(
-
1
)
if
len
(
rir_waveform
.
shape
)
==
1
:
# convolve1d expects a 3d tensor !
rir_waveform
=
rir_waveform
.
unsqueeze
(
0
).
unsqueeze
(
-
1
)
elif
len
(
rir_waveform
.
shape
)
==
2
:
rir_waveform
=
rir_waveform
.
unsqueeze
(
-
1
)
# Compute the average amplitude of the clean
orig_amplitude
=
compute_amplitude
(
waveforms
,
waveforms
.
shape
[
1
],
rescale_amp
)
# Compute index of the direct signal, so we can preserve alignment
impulse_index_start
=
rir_waveform
.
abs
().
argmax
(
axis
=
1
).
item
()
impulse_index_end
=
min
(
impulse_index_start
+
int
(
sample_rate
*
impulse_duration
),
rir_waveform
.
shape
[
1
])
rir_waveform
=
rir_waveform
[:,
impulse_index_start
:
impulse_index_end
,
:]
rir_waveform
=
rir_waveform
/
paddle
.
norm
(
rir_waveform
,
p
=
2
)
rir_waveform
=
paddle
.
flip
(
rir_waveform
,
[
1
])
waveforms
=
convolve1d
(
waveform
=
waveforms
,
kernel
=
rir_waveform
,
padding
=
[
rir_waveform
.
shape
[
1
]
-
1
,
0
],
)
# Rescale to the peak amplitude of the clean waveform
waveforms
=
rescale
(
waveforms
,
waveforms
.
shape
[
1
],
orig_amplitude
,
rescale_amp
)
if
len
(
orig_shape
)
==
1
:
waveforms
=
waveforms
.
squeeze
(
0
).
squeeze
(
-
1
)
if
len
(
orig_shape
)
==
2
:
waveforms
=
waveforms
.
squeeze
(
-
1
)
return
waveforms
def
rescale
(
waveforms
,
lengths
,
target_lvl
,
amp_type
=
"avg"
,
scale
=
"linear"
):
assert
amp_type
in
[
"peak"
,
"avg"
]
assert
scale
in
[
"linear"
,
"dB"
]
batch_added
=
False
if
len
(
waveforms
.
shape
)
==
1
:
batch_added
=
True
waveforms
=
waveforms
.
unsqueeze
(
0
)
waveforms
=
normalize
(
waveforms
,
lengths
,
amp_type
)
if
scale
==
"linear"
:
out
=
target_lvl
*
waveforms
elif
scale
==
"dB"
:
out
=
dB_to_amplitude
(
target_lvl
)
*
waveforms
else
:
raise
NotImplementedError
(
"Invalid scale, choose between dB and linear"
)
if
batch_added
:
out
=
out
.
squeeze
(
0
)
return
out
def
normalize
(
waveforms
,
lengths
=
None
,
amp_type
=
"avg"
,
eps
=
1e-14
):
assert
amp_type
in
[
"avg"
,
"peak"
]
batch_added
=
False
if
len
(
waveforms
.
shape
)
==
1
:
batch_added
=
True
waveforms
=
waveforms
.
unsqueeze
(
0
)
den
=
compute_amplitude
(
waveforms
,
lengths
,
amp_type
)
+
eps
if
batch_added
:
waveforms
=
waveforms
.
squeeze
(
0
)
return
waveforms
/
den
paddlespeech/vector/models/ecapa_tdnn.py
浏览文件 @
2d89c80e
...
...
@@ -19,6 +19,16 @@ import paddle.nn.functional as F
def
length_to_mask
(
length
,
max_len
=
None
,
dtype
=
None
):
"""_summary_
Args:
length (_type_): _description_
max_len (_type_, optional): _description_. Defaults to None.
dtype (_type_, optional): _description_. Defaults to None.
Returns:
_type_: _description_
"""
assert
len
(
length
.
shape
)
==
1
if
max_len
is
None
:
...
...
@@ -47,6 +57,19 @@ class Conv1d(nn.Layer):
groups
=
1
,
bias
=
True
,
padding_mode
=
"reflect"
,
):
"""_summary_
Args:
in_channels (_type_): _description_
out_channels (_type_): _description_
kernel_size (_type_): _description_
stride (int, optional): _description_. Defaults to 1.
padding (str, optional): _description_. Defaults to "same".
dilation (int, optional): _description_. Defaults to 1.
groups (int, optional): _description_. Defaults to 1.
bias (bool, optional): _description_. Defaults to True.
padding_mode (str, optional): _description_. Defaults to "reflect".
"""
super
().
__init__
()
self
.
kernel_size
=
kernel_size
...
...
@@ -66,6 +89,17 @@ class Conv1d(nn.Layer):
bias_attr
=
bias
,
)
def
forward
(
self
,
x
):
"""_summary_
Args:
x (_type_): _description_
Raises:
ValueError: _description_
Returns:
_type_: _description_
"""
if
self
.
padding
==
"same"
:
x
=
self
.
_manage_padding
(
x
,
self
.
kernel_size
,
self
.
dilation
,
self
.
stride
)
...
...
@@ -75,6 +109,17 @@ class Conv1d(nn.Layer):
return
self
.
conv
(
x
)
def
_manage_padding
(
self
,
x
,
kernel_size
:
int
,
dilation
:
int
,
stride
:
int
):
"""_summary_
Args:
x (_type_): _description_
kernel_size (int): _description_
dilation (int): _description_
stride (int): _description_
Returns:
_type_: _description_
"""
L_in
=
x
.
shape
[
-
1
]
# Detecting input shape
padding
=
self
.
_get_padding_elem
(
L_in
,
stride
,
kernel_size
,
dilation
)
# Time padding
...
...
@@ -88,6 +133,17 @@ class Conv1d(nn.Layer):
stride
:
int
,
kernel_size
:
int
,
dilation
:
int
):
"""_summary_
Args:
L_in (int): _description_
stride (int): _description_
kernel_size (int): _description_
dilation (int): _description_
Returns:
_type_: _description_
"""
if
stride
>
1
:
n_steps
=
math
.
ceil
(((
L_in
-
kernel_size
*
dilation
)
/
stride
)
+
1
)
L_out
=
stride
*
(
n_steps
-
1
)
+
kernel_size
*
dilation
...
...
@@ -134,6 +190,15 @@ class TDNNBlock(nn.Layer):
kernel_size
,
dilation
,
activation
=
nn
.
ReLU
,
):
"""Implementation of TDNN network
Args:
in_channels (int): input channels or input embedding dimensions
out_channels (int): output channels or output embedding dimensions
kernel_size (int): the kernel size of the TDNN network block
dilation (int): the dilation of the TDNN network block
activation (paddle class, optional): the activation layers. Defaults to nn.ReLU.
"""
super
().
__init__
()
self
.
conv
=
Conv1d
(
in_channels
=
in_channels
,
...
...
@@ -149,6 +214,15 @@ class TDNNBlock(nn.Layer):
class
Res2NetBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
scale
=
8
,
dilation
=
1
):
"""Implementation of Res2Net Block with dilation
The paper is refered as "Res2Net: A New Multi-scale Backbone Architecture",
whose url is https://arxiv.org/abs/1904.01169
Args:
in_channels (int): input channels or input dimensions
out_channels (int): output channels or output dimensions
scale (int, optional): _description_. Defaults to 8.
dilation (int, optional): _description_. Defaults to 1.
"""
super
().
__init__
()
assert
in_channels
%
scale
==
0
assert
out_channels
%
scale
==
0
...
...
@@ -179,6 +253,14 @@ class Res2NetBlock(nn.Layer):
class
SEBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
se_channels
,
out_channels
):
"""Implementation of SEBlock
The paper is refered as "Squeeze-and-Excitation Networks"
whose url is https://arxiv.org/abs/1709.01507
Args:
in_channels (int): input channels or input data dimensions
se_channels (_type_): _description_
out_channels (int): output channels or output data dimensions
"""
super
().
__init__
()
self
.
conv1
=
Conv1d
(
...
...
@@ -275,6 +357,17 @@ class SERes2NetBlock(nn.Layer):
kernel_size
=
1
,
dilation
=
1
,
activation
=
nn
.
ReLU
,
):
"""Implementation of Squeeze-Extraction Res2Blocks in ECAPA-TDNN network model
Args:
in_channels (int): input channels or input data dimensions
out_channels (_type_): _description_
res2net_scale (int, optional): _description_. Defaults to 8.
se_channels (int, optional): _description_. Defaults to 128.
kernel_size (int, optional): _description_. Defaults to 1.
dilation (int, optional): _description_. Defaults to 1.
activation (_type_, optional): _description_. Defaults to nn.ReLU.
"""
super
().
__init__
()
self
.
out_channels
=
out_channels
self
.
tdnn1
=
TDNNBlock
(
...
...
paddlespeech/vector/training/seeding.py
0 → 100644
浏览文件 @
2d89c80e
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
paddlespeech.s2t.utils.log
import
Log
logger
=
Log
(
__name__
).
getlog
()
import
random
import
numpy
as
np
import
paddle
def
seed_everything
(
seed
:
int
):
"""Seed paddle, random and np.random to help reproductivity."""
paddle
.
seed
(
seed
)
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
logger
.
info
(
f
"Set the seed of paddle, random, np.random to
{
seed
}
."
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录