Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
0e87037f
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0e87037f
编写于
3月 09, 2022
作者:
X
xiongxinlei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor to compilance paddleaudio
上级
4473405f
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
63 addition
and
753 deletion
+63
-753
examples/voxceleb/sv0/local/data_prepare.py
examples/voxceleb/sv0/local/data_prepare.py
+1
-14
examples/voxceleb/sv0/local/extract_speaker_embedding.py
examples/voxceleb/sv0/local/extract_speaker_embedding.py
+0
-129
examples/voxceleb/sv0/local/speaker_verification_cosine.py
examples/voxceleb/sv0/local/speaker_verification_cosine.py
+0
-264
examples/voxceleb/sv0/local/train.py
examples/voxceleb/sv0/local/train.py
+0
-326
examples/voxceleb/sv0/path.sh
examples/voxceleb/sv0/path.sh
+3
-0
examples/voxceleb/sv0/run.sh
examples/voxceleb/sv0/run.sh
+5
-7
paddleaudio/paddleaudio/datasets/__init__.py
paddleaudio/paddleaudio/datasets/__init__.py
+2
-0
paddleaudio/paddleaudio/datasets/rirs_noises.py
paddleaudio/paddleaudio/datasets/rirs_noises.py
+5
-5
paddleaudio/paddleaudio/datasets/voxceleb.py
paddleaudio/paddleaudio/datasets/voxceleb.py
+6
-6
paddleaudio/paddleaudio/metric/__init__.py
paddleaudio/paddleaudio/metric/__init__.py
+1
-0
paddleaudio/paddleaudio/metric/eer.py
paddleaudio/paddleaudio/metric/eer.py
+0
-0
paddlespeech/vector/io/augment.py
paddlespeech/vector/io/augment.py
+2
-2
paddlespeech/vector/io/batch.py
paddlespeech/vector/io/batch.py
+38
-0
paddlespeech/vector/training/scheduler.py
paddlespeech/vector/training/scheduler.py
+0
-0
未找到文件。
examples/voxceleb/sv0/local/data_prepare.py
浏览文件 @
0e87037f
...
@@ -3,24 +3,11 @@ import os
...
@@ -3,24 +3,11 @@ import os
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
from
paddle.io
import
BatchSampler
from
paddle.io
import
DataLoader
from
paddle.io
import
DistributedBatchSampler
from
paddleaudio.datasets.voxceleb
import
VoxCeleb1
from
paddleaudio.paddleaudio.datasets.voxceleb
import
VoxCeleb1
from
paddleaudio.features.core
import
melspectrogram
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.io.augment
import
build_augment_pipeline
from
paddlespeech.vector.io.augment
import
build_augment_pipeline
from
paddlespeech.vector.io.augment
import
waveform_augment
from
paddlespeech.vector.io.batch
import
feature_normalize
from
paddlespeech.vector.io.batch
import
waveform_collate_fn
from
paddlespeech.vector.models.ecapa_tdnn
import
EcapaTdnn
from
paddlespeech.vector.modules.loss
import
AdditiveAngularMargin
from
paddlespeech.vector.modules.loss
import
LogSoftmaxWrapper
from
paddlespeech.vector.modules.lr
import
CyclicLRScheduler
from
paddlespeech.vector.modules.sid_model
import
SpeakerIdetification
from
paddlespeech.vector.training.seeding
import
seed_everything
from
paddlespeech.vector.training.seeding
import
seed_everything
from
paddlespeech.vector.utils.time
import
Timer
logger
=
Log
(
__name__
).
getlog
()
logger
=
Log
(
__name__
).
getlog
()
...
...
examples/voxceleb/sv0/local/extract_speaker_embedding.py
已删除
100644 → 0
浏览文件 @
4473405f
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
ast
import
os
import
numpy
as
np
import
paddle
import
paddle.nn.functional
as
F
from
paddle.io
import
BatchSampler
from
paddle.io
import
DataLoader
from
tqdm
import
tqdm
from
paddleaudio.backends
import
load
as
load_audio
from
paddleaudio.datasets.voxceleb
import
VoxCeleb1
from
paddleaudio.features.core
import
melspectrogram
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.io.batch
import
feature_normalize
from
paddlespeech.vector.models.ecapa_tdnn
import
EcapaTdnn
from
paddlespeech.vector.modules.sid_model
import
SpeakerIdetification
from
paddlespeech.vector.training.metrics
import
compute_eer
from
paddlespeech.vector.training.seeding
import
seed_everything
logger
=
Log
(
__name__
).
getlog
()
# feat configuration
cpu_feat_conf
=
{
'n_mels'
:
80
,
'window_size'
:
400
,
#ms
'hop_length'
:
160
,
#ms
}
def
extract_audio_embedding
(
args
):
# stage 0: set the training device, cpu or gpu
paddle
.
set_device
(
args
.
device
)
# set the random seed, it is a must for multiprocess training
seed_everything
(
args
.
seed
)
# stage 1: build the dnn backbone model network
##"channels": [1024, 1024, 1024, 1024, 3072],
model_conf
=
{
"input_size"
:
80
,
"channels"
:
[
512
,
512
,
512
,
512
,
1536
],
"kernel_sizes"
:
[
5
,
3
,
3
,
3
,
1
],
"dilations"
:
[
1
,
2
,
3
,
4
,
1
],
"attention_channels"
:
128
,
"lin_neurons"
:
192
,
}
ecapa_tdnn
=
EcapaTdnn
(
**
model_conf
)
# stage4: build the speaker verification train instance with backbone model
model
=
SpeakerIdetification
(
backbone
=
ecapa_tdnn
,
num_class
=
1211
)
# stage 2: load the pre-trained model
args
.
load_checkpoint
=
os
.
path
.
abspath
(
os
.
path
.
expanduser
(
args
.
load_checkpoint
))
# load model checkpoint to sid model
state_dict
=
paddle
.
load
(
os
.
path
.
join
(
args
.
load_checkpoint
,
'model.pdparams'
))
model
.
set_state_dict
(
state_dict
)
logger
.
info
(
f
'Checkpoint loaded from
{
args
.
load_checkpoint
}
'
)
# stage 3: we must set the model to eval mode
model
.
eval
()
# stage 4: read the audio data and extract the embedding
# wavform is one dimension numpy array
waveform
,
sr
=
load_audio
(
args
.
audio_path
)
# feat type is numpy array, whose shape is [dim, time]
# we need convert the audio feat to one-batch shape [batch, dim, time], where the batch is one
# so the final shape is [1, dim, time]
feat
=
melspectrogram
(
x
=
waveform
,
**
cpu_feat_conf
)
feat
=
paddle
.
to_tensor
(
feat
).
unsqueeze
(
0
)
# in inference period, the lengths is all one without padding
lengths
=
paddle
.
ones
([
1
])
feat
=
feature_normalize
(
feat
,
mean_norm
=
True
,
std_norm
=
False
,
convert_to_numpy
=
True
)
# model backbone network forward the feats and get the embedding
embedding
=
model
.
backbone
(
feat
,
lengths
).
squeeze
().
numpy
()
# (1, emb_size, 1) -> (emb_size)
# stage 5: do global norm with external mean and std
# todo
# np.save("audio-embedding", embedding)
return
embedding
if
__name__
==
"__main__"
:
# yapf: disable
parser
=
argparse
.
ArgumentParser
(
__doc__
)
parser
.
add_argument
(
'--device'
,
choices
=
[
'cpu'
,
'gpu'
],
default
=
"gpu"
,
help
=
"Select which device to train model, defaults to gpu."
)
parser
.
add_argument
(
"--seed"
,
default
=
0
,
type
=
int
,
help
=
"random seed for paddle, numpy and python random package"
)
parser
.
add_argument
(
"--load-checkpoint"
,
type
=
str
,
default
=
''
,
help
=
"Directory to load model checkpoint to contiune trainning."
)
parser
.
add_argument
(
"--global-embedding-norm"
,
type
=
str
,
default
=
None
,
help
=
"Apply global normalization on speaker embeddings."
)
parser
.
add_argument
(
"--audio-path"
,
default
=
"./data/demo.wav"
,
type
=
str
,
help
=
"Single audio file path"
)
args
=
parser
.
parse_args
()
# yapf: enable
extract_audio_embedding
(
args
)
examples/voxceleb/sv0/local/speaker_verification_cosine.py
已删除
100644 → 0
浏览文件 @
4473405f
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
ast
import
os
import
numpy
as
np
import
paddle
import
paddle.nn.functional
as
F
from
paddle.io
import
BatchSampler
from
paddle.io
import
DataLoader
from
tqdm
import
tqdm
from
paddleaudio.datasets.voxceleb
import
VoxCeleb1
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.models.ecapa_tdnn
import
EcapaTdnn
from
paddlespeech.vector.modules.sid_model
import
SpeakerIdetification
from
paddlespeech.vector.training.metrics
import
compute_eer
from
paddlespeech.vector.training.seeding
import
seed_everything
logger
=
Log
(
__name__
).
getlog
()
def
pad_right_2d
(
x
,
target_length
,
axis
=-
1
,
mode
=
'constant'
,
**
kwargs
):
x
=
np
.
asarray
(
x
)
assert
len
(
x
.
shape
)
==
2
,
f
'Only 2D arrays supported, but got shape:
{
x
.
shape
}
'
w
=
target_length
-
x
.
shape
[
axis
]
assert
w
>=
0
,
f
'Target length
{
target_length
}
is less than origin length
{
x
.
shape
[
axis
]
}
'
if
axis
==
0
:
pad_width
=
[[
0
,
w
],
[
0
,
0
]]
else
:
pad_width
=
[[
0
,
0
],
[
0
,
w
]]
return
np
.
pad
(
x
,
pad_width
,
mode
=
mode
,
**
kwargs
)
def
feature_normalize
(
batch
,
mean_norm
:
bool
=
True
,
std_norm
:
bool
=
True
):
ids
=
[
item
[
'id'
]
for
item
in
batch
]
lengths
=
np
.
asarray
([
item
[
'feat'
].
shape
[
1
]
for
item
in
batch
])
feats
=
list
(
map
(
lambda
x
:
pad_right_2d
(
x
,
lengths
.
max
()),
[
item
[
'feat'
]
for
item
in
batch
]))
feats
=
np
.
stack
(
feats
)
# Features normalization if needed
for
i
in
range
(
len
(
feats
)):
feat
=
feats
[
i
][:,
:
lengths
[
i
]]
# Excluding pad values.
mean
=
feat
.
mean
(
axis
=-
1
,
keepdims
=
True
)
if
mean_norm
else
0
std
=
feat
.
std
(
axis
=-
1
,
keepdims
=
True
)
if
std_norm
else
1
feats
[
i
][:,
:
lengths
[
i
]]
=
(
feat
-
mean
)
/
std
assert
feats
[
i
][:,
lengths
[
i
]:].
sum
()
==
0
# Padding valus should all be 0.
# Converts into ratios.
lengths
=
(
lengths
/
lengths
.
max
()).
astype
(
np
.
float32
)
return
{
'ids'
:
ids
,
'feats'
:
feats
,
'lengths'
:
lengths
}
# feat configuration
cpu_feat_conf
=
{
'n_mels'
:
80
,
'window_size'
:
400
,
#ms
'hop_length'
:
160
,
#ms
}
def
main
(
args
):
# stage0: set the training device, cpu or gpu
paddle
.
set_device
(
args
.
device
)
# set the random seed, it is a must for multiprocess training
seed_everything
(
args
.
seed
)
# stage1: build the dnn backbone model network
##"channels": [1024, 1024, 1024, 1024, 3072],
model_conf
=
{
"input_size"
:
80
,
"channels"
:
[
512
,
512
,
512
,
512
,
1536
],
"kernel_sizes"
:
[
5
,
3
,
3
,
3
,
1
],
"dilations"
:
[
1
,
2
,
3
,
4
,
1
],
"attention_channels"
:
128
,
"lin_neurons"
:
192
,
}
ecapa_tdnn
=
EcapaTdnn
(
**
model_conf
)
# stage2: build the speaker verification eval instance with backbone model
model
=
SpeakerIdetification
(
backbone
=
ecapa_tdnn
,
num_class
=
VoxCeleb1
.
num_speakers
)
# stage3: load the pre-trained model
args
.
load_checkpoint
=
os
.
path
.
abspath
(
os
.
path
.
expanduser
(
args
.
load_checkpoint
))
# load model checkpoint to sid model
state_dict
=
paddle
.
load
(
os
.
path
.
join
(
args
.
load_checkpoint
,
'model.pdparams'
))
model
.
set_state_dict
(
state_dict
)
logger
.
info
(
f
'Checkpoint loaded from
{
args
.
load_checkpoint
}
'
)
# stage4: construct the enroll and test dataloader
enrol_ds
=
VoxCeleb1
(
subset
=
'enrol'
,
target_dir
=
args
.
data_dir
,
feat_type
=
'melspectrogram'
,
random_chunk
=
False
,
**
cpu_feat_conf
)
enrol_sampler
=
BatchSampler
(
enrol_ds
,
batch_size
=
args
.
batch_size
,
shuffle
=
True
)
# Shuffle to make embedding normalization more robust.
enrol_loader
=
DataLoader
(
enrol_ds
,
batch_sampler
=
enrol_sampler
,
collate_fn
=
lambda
x
:
feature_normalize
(
x
,
mean_norm
=
True
,
std_norm
=
False
),
num_workers
=
args
.
num_workers
,
return_list
=
True
,)
test_ds
=
VoxCeleb1
(
subset
=
'test'
,
target_dir
=
args
.
data_dir
,
feat_type
=
'melspectrogram'
,
random_chunk
=
False
,
**
cpu_feat_conf
)
test_sampler
=
BatchSampler
(
test_ds
,
batch_size
=
args
.
batch_size
,
shuffle
=
True
)
test_loader
=
DataLoader
(
test_ds
,
batch_sampler
=
test_sampler
,
collate_fn
=
lambda
x
:
feature_normalize
(
x
,
mean_norm
=
True
,
std_norm
=
False
),
num_workers
=
args
.
num_workers
,
return_list
=
True
,)
# stage6: we must set the model to eval mode
model
.
eval
()
# stage7: global embedding norm to imporve the performance
if
args
.
global_embedding_norm
:
global_embedding_mean
=
None
global_embedding_std
=
None
mean_norm_flag
=
args
.
embedding_mean_norm
std_norm_flag
=
args
.
embedding_std_norm
batch_count
=
0
# stage8: Compute embeddings of audios in enrol and test dataset from model.
id2embedding
=
{}
# Run multi times to make embedding normalization more stable.
for
i
in
range
(
2
):
for
dl
in
[
enrol_loader
,
test_loader
]:
logger
.
info
(
f
'Loop
{
[
i
+
1
]
}
: Computing embeddings on
{
dl
.
dataset
.
subset
}
dataset'
)
with
paddle
.
no_grad
():
for
batch_idx
,
batch
in
enumerate
(
tqdm
(
dl
)):
# stage 8-1: extrac the audio embedding
ids
,
feats
,
lengths
=
batch
[
'ids'
],
batch
[
'feats'
],
batch
[
'lengths'
]
embeddings
=
model
.
backbone
(
feats
,
lengths
).
squeeze
(
-
1
).
numpy
()
# (N, emb_size, 1) -> (N, emb_size)
# Global embedding normalization.
if
args
.
global_embedding_norm
:
batch_count
+=
1
current_mean
=
embeddings
.
mean
(
axis
=
0
)
if
mean_norm_flag
else
0
current_std
=
embeddings
.
std
(
axis
=
0
)
if
std_norm_flag
else
1
# Update global mean and std.
if
global_embedding_mean
is
None
and
global_embedding_std
is
None
:
global_embedding_mean
,
global_embedding_std
=
current_mean
,
current_std
else
:
weight
=
1
/
batch_count
# Weight decay by batches.
global_embedding_mean
=
(
1
-
weight
)
*
global_embedding_mean
+
weight
*
current_mean
global_embedding_std
=
(
1
-
weight
)
*
global_embedding_std
+
weight
*
current_std
# Apply global embedding normalization.
embeddings
=
(
embeddings
-
global_embedding_mean
)
/
global_embedding_std
# Update embedding dict.
id2embedding
.
update
(
dict
(
zip
(
ids
,
embeddings
)))
# stage 9: Compute cosine scores.
labels
=
[]
enrol_ids
=
[]
test_ids
=
[]
with
open
(
VoxCeleb1
.
veri_test_file
,
'r'
)
as
f
:
for
line
in
f
.
readlines
():
label
,
enrol_id
,
test_id
=
line
.
strip
().
split
(
' '
)
labels
.
append
(
int
(
label
))
enrol_ids
.
append
(
enrol_id
.
split
(
'.'
)[
0
].
replace
(
'/'
,
'-'
))
test_ids
.
append
(
test_id
.
split
(
'.'
)[
0
].
replace
(
'/'
,
'-'
))
cos_sim_func
=
paddle
.
nn
.
CosineSimilarity
(
axis
=
1
)
enrol_embeddings
,
test_embeddings
=
map
(
lambda
ids
:
paddle
.
to_tensor
(
np
.
asarray
([
id2embedding
[
id
]
for
id
in
ids
],
dtype
=
'float32'
)),
[
enrol_ids
,
test_ids
])
# (N, emb_size)
scores
=
cos_sim_func
(
enrol_embeddings
,
test_embeddings
)
EER
,
threshold
=
compute_eer
(
np
.
asarray
(
labels
),
scores
.
numpy
())
logger
.
info
(
f
'EER of verification test:
{
EER
*
100
:.
4
f
}
%, score threshold:
{
threshold
:.
5
f
}
'
)
if
__name__
==
"__main__"
:
# yapf: disable
parser
=
argparse
.
ArgumentParser
(
__doc__
)
parser
.
add_argument
(
'--device'
,
choices
=
[
'cpu'
,
'gpu'
],
default
=
"gpu"
,
help
=
"Select which device to train model, defaults to gpu."
)
parser
.
add_argument
(
"--seed"
,
default
=
0
,
type
=
int
,
help
=
"random seed for paddle, numpy and python random package"
)
parser
.
add_argument
(
"--data-dir"
,
default
=
"./data/"
,
type
=
str
,
help
=
"data directory"
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
default
=
16
,
help
=
"Total examples' number in batch for extract the embedding."
)
parser
.
add_argument
(
"--num-workers"
,
type
=
int
,
default
=
0
,
help
=
"Number of workers in dataloader."
)
parser
.
add_argument
(
"--load-checkpoint"
,
type
=
str
,
default
=
''
,
help
=
"Directory to load model checkpoint to contiune trainning."
)
parser
.
add_argument
(
"--global-embedding-norm"
,
type
=
bool
,
default
=
True
,
help
=
"Apply global normalization on speaker embeddings."
)
parser
.
add_argument
(
"--embedding-mean-norm"
,
type
=
bool
,
default
=
True
,
help
=
"Apply mean normalization on speaker embeddings."
)
parser
.
add_argument
(
"--embedding-std-norm"
,
type
=
bool
,
default
=
False
,
help
=
"Apply std normalization on speaker embeddings."
)
args
=
parser
.
parse_args
()
# yapf: enable
main
(
args
)
examples/voxceleb/sv0/local/train.py
已删除
100644 → 0
浏览文件 @
4473405f
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
os
import
numpy
as
np
import
paddle
from
paddle.io
import
BatchSampler
from
paddle.io
import
DataLoader
from
paddle.io
import
DistributedBatchSampler
from
paddleaudio.datasets.voxceleb
import
VoxCeleb1
from
paddleaudio.features.core
import
melspectrogram
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.io.augment
import
build_augment_pipeline
from
paddlespeech.vector.io.augment
import
waveform_augment
from
paddlespeech.vector.io.batch
import
feature_normalize
from
paddlespeech.vector.io.batch
import
waveform_collate_fn
from
paddlespeech.vector.models.ecapa_tdnn
import
EcapaTdnn
from
paddlespeech.vector.modules.loss
import
AdditiveAngularMargin
from
paddlespeech.vector.modules.loss
import
LogSoftmaxWrapper
from
paddlespeech.vector.modules.lr
import
CyclicLRScheduler
from
paddlespeech.vector.modules.sid_model
import
SpeakerIdetification
from
paddlespeech.vector.training.seeding
import
seed_everything
from
paddlespeech.vector.utils.time
import
Timer
logger
=
Log
(
__name__
).
getlog
()
# feat configuration
cpu_feat_conf
=
{
'n_mels'
:
80
,
'window_size'
:
400
,
#ms
'hop_length'
:
160
,
#ms
}
def
main
(
args
):
# stage0: set the training device, cpu or gpu
paddle
.
set_device
(
args
.
device
)
# stage1: we must call the paddle.distributed.init_parallel_env() api at the begining
paddle
.
distributed
.
init_parallel_env
()
nranks
=
paddle
.
distributed
.
get_world_size
()
local_rank
=
paddle
.
distributed
.
get_rank
()
# set the random seed, it is a must for multiprocess training
seed_everything
(
args
.
seed
)
# stage2: data prepare, such vox1 and vox2 data, and augment noise data and pipline
# note: some cmd must do in rank==0, so wo will refactor the data prepare code
train_dataset
=
VoxCeleb1
(
'train'
,
target_dir
=
args
.
data_dir
)
dev_dataset
=
VoxCeleb1
(
'dev'
,
target_dir
=
args
.
data_dir
)
if
args
.
augment
:
augment_pipeline
=
build_augment_pipeline
(
target_dir
=
args
.
data_dir
)
else
:
augment_pipeline
=
[]
# stage3: build the dnn backbone model network
#"channels": [1024, 1024, 1024, 1024, 3072],
model_conf
=
{
"input_size"
:
80
,
"channels"
:
[
512
,
512
,
512
,
512
,
1536
],
"kernel_sizes"
:
[
5
,
3
,
3
,
3
,
1
],
"dilations"
:
[
1
,
2
,
3
,
4
,
1
],
"attention_channels"
:
128
,
"lin_neurons"
:
192
,
}
ecapa_tdnn
=
EcapaTdnn
(
**
model_conf
)
# stage4: build the speaker verification train instance with backbone model
model
=
SpeakerIdetification
(
backbone
=
ecapa_tdnn
,
num_class
=
VoxCeleb1
.
num_speakers
)
# stage5: build the optimizer, we now only construct the AdamW optimizer
lr_schedule
=
CyclicLRScheduler
(
base_lr
=
args
.
learning_rate
,
max_lr
=
1e-3
,
step_size
=
140000
//
nranks
)
optimizer
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
lr_schedule
,
parameters
=
model
.
parameters
())
# stage6: build the loss function, we now only support LogSoftmaxWrapper
criterion
=
LogSoftmaxWrapper
(
loss_fn
=
AdditiveAngularMargin
(
margin
=
0.2
,
scale
=
30
))
# stage7: confirm training start epoch
# if pre-trained model exists, start epoch confirmed by the pre-trained model
start_epoch
=
0
if
args
.
load_checkpoint
:
logger
.
info
(
"load the check point"
)
args
.
load_checkpoint
=
os
.
path
.
abspath
(
os
.
path
.
expanduser
(
args
.
load_checkpoint
))
try
:
# load model checkpoint
state_dict
=
paddle
.
load
(
os
.
path
.
join
(
args
.
load_checkpoint
,
'model.pdparams'
))
model
.
set_state_dict
(
state_dict
)
# load optimizer checkpoint
state_dict
=
paddle
.
load
(
os
.
path
.
join
(
args
.
load_checkpoint
,
'model.pdopt'
))
optimizer
.
set_state_dict
(
state_dict
)
if
local_rank
==
0
:
logger
.
info
(
f
'Checkpoint loaded from
{
args
.
load_checkpoint
}
'
)
except
FileExistsError
:
if
local_rank
==
0
:
logger
.
info
(
'Train from scratch.'
)
try
:
start_epoch
=
int
(
args
.
load_checkpoint
[
-
1
])
logger
.
info
(
f
'Restore training from epoch
{
start_epoch
}
.'
)
except
ValueError
:
pass
# stage8: we build the batch sampler for paddle.DataLoader
train_sampler
=
DistributedBatchSampler
(
train_dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
True
,
drop_last
=
False
)
train_loader
=
DataLoader
(
train_dataset
,
batch_sampler
=
train_sampler
,
num_workers
=
args
.
num_workers
,
collate_fn
=
waveform_collate_fn
,
return_list
=
True
,
use_buffer_reader
=
True
,
)
# stage9: start to train
# we will comment the training process
steps_per_epoch
=
len
(
train_sampler
)
timer
=
Timer
(
steps_per_epoch
*
args
.
epochs
)
timer
.
start
()
for
epoch
in
range
(
start_epoch
+
1
,
args
.
epochs
+
1
):
# at the begining, model must set to train mode
model
.
train
()
avg_loss
=
0
num_corrects
=
0
num_samples
=
0
for
batch_idx
,
batch
in
enumerate
(
train_loader
):
# stage 9-1: batch data is audio sample points and speaker id label
waveforms
,
labels
=
batch
[
'waveforms'
],
batch
[
'labels'
]
# stage 9-2: audio sample augment method, which is done on the audio sample point
if
len
(
augment_pipeline
)
!=
0
:
waveforms
=
waveform_augment
(
waveforms
,
augment_pipeline
)
labels
=
paddle
.
concat
(
[
labels
for
i
in
range
(
len
(
augment_pipeline
)
+
1
)])
# stage 9-3: extract the audio feats,such fbank, mfcc, spectrogram
feats
=
[]
for
waveform
in
waveforms
.
numpy
():
feat
=
melspectrogram
(
x
=
waveform
,
**
cpu_feat_conf
)
feats
.
append
(
feat
)
feats
=
paddle
.
to_tensor
(
np
.
asarray
(
feats
))
# stage 9-4: feature normalize, which help converge and imporve the performance
feats
=
feature_normalize
(
feats
,
mean_norm
=
True
,
std_norm
=
False
)
# Features normalization
# stage 9-5: model forward, such ecapa-tdnn, x-vector
logits
=
model
(
feats
)
# stage 9-6: loss function criterion, such AngularMargin, AdditiveAngularMargin
loss
=
criterion
(
logits
,
labels
)
# stage 9-7: update the gradient and clear the gradient cache
loss
.
backward
()
optimizer
.
step
()
if
isinstance
(
optimizer
.
_learning_rate
,
paddle
.
optimizer
.
lr
.
LRScheduler
):
optimizer
.
_learning_rate
.
step
()
optimizer
.
clear_grad
()
# stage 9-8: Calculate average loss per batch
avg_loss
+=
loss
.
numpy
()[
0
]
# stage 9-9: Calculate metrics, which is one-best accuracy
preds
=
paddle
.
argmax
(
logits
,
axis
=
1
)
num_corrects
+=
(
preds
==
labels
).
numpy
().
sum
()
num_samples
+=
feats
.
shape
[
0
]
timer
.
count
()
# step plus one in timer
# stage 9-10: print the log information only on 0-rank per log-freq batchs
if
(
batch_idx
+
1
)
%
args
.
log_freq
==
0
and
local_rank
==
0
:
lr
=
optimizer
.
get_lr
()
avg_loss
/=
args
.
log_freq
avg_acc
=
num_corrects
/
num_samples
print_msg
=
'Train Epoch={}/{}, Step={}/{}'
.
format
(
epoch
,
args
.
epochs
,
batch_idx
+
1
,
steps_per_epoch
)
print_msg
+=
' loss={:.4f}'
.
format
(
avg_loss
)
print_msg
+=
' acc={:.4f}'
.
format
(
avg_acc
)
print_msg
+=
' lr={:.4E} step/sec={:.2f} | ETA {}'
.
format
(
lr
,
timer
.
timing
,
timer
.
eta
)
logger
.
info
(
print_msg
)
avg_loss
=
0
num_corrects
=
0
num_samples
=
0
# stage 9-11: save the model parameters only on 0-rank per save-freq batchs
if
epoch
%
args
.
save_freq
==
0
and
batch_idx
+
1
==
steps_per_epoch
:
if
local_rank
!=
0
:
paddle
.
distributed
.
barrier
(
)
# Wait for valid step in main process
continue
# Resume trainning on other process
# stage 9-12: construct the valid dataset dataloader
dev_sampler
=
BatchSampler
(
dev_dataset
,
batch_size
=
args
.
batch_size
//
4
,
shuffle
=
False
,
drop_last
=
False
)
dev_loader
=
DataLoader
(
dev_dataset
,
batch_sampler
=
dev_sampler
,
collate_fn
=
waveform_collate_fn
,
num_workers
=
args
.
num_workers
,
return_list
=
True
,
)
# set the model to eval mode
model
.
eval
()
num_corrects
=
0
num_samples
=
0
# stage 9-13: evaluation the valid dataset batch data
logger
.
info
(
'Evaluate on validation dataset'
)
with
paddle
.
no_grad
():
for
batch_idx
,
batch
in
enumerate
(
dev_loader
):
waveforms
,
labels
=
batch
[
'waveforms'
],
batch
[
'labels'
]
feats
=
[]
for
waveform
in
waveforms
.
numpy
():
feat
=
melspectrogram
(
x
=
waveform
,
**
cpu_feat_conf
)
feats
.
append
(
feat
)
feats
=
paddle
.
to_tensor
(
np
.
asarray
(
feats
))
feats
=
feature_normalize
(
feats
,
mean_norm
=
True
,
std_norm
=
False
)
logits
=
model
(
feats
)
preds
=
paddle
.
argmax
(
logits
,
axis
=
1
)
num_corrects
+=
(
preds
==
labels
).
numpy
().
sum
()
num_samples
+=
feats
.
shape
[
0
]
print_msg
=
'[Evaluation result]'
print_msg
+=
' dev_acc={:.4f}'
.
format
(
num_corrects
/
num_samples
)
logger
.
info
(
print_msg
)
# stage 9-14: Save model parameters
save_dir
=
os
.
path
.
join
(
args
.
checkpoint_dir
,
'epoch_{}'
.
format
(
epoch
))
logger
.
info
(
'Saving model checkpoint to {}'
.
format
(
save_dir
))
paddle
.
save
(
model
.
state_dict
(),
os
.
path
.
join
(
save_dir
,
'model.pdparams'
))
paddle
.
save
(
optimizer
.
state_dict
(),
os
.
path
.
join
(
save_dir
,
'model.pdopt'
))
if
nranks
>
1
:
paddle
.
distributed
.
barrier
()
# Main process
if
__name__
==
"__main__"
:
# yapf: disable
parser
=
argparse
.
ArgumentParser
(
__doc__
)
parser
.
add_argument
(
'--device'
,
choices
=
[
'cpu'
,
'gpu'
],
default
=
"cpu"
,
help
=
"Select which device to train model, defaults to gpu."
)
parser
.
add_argument
(
"--seed"
,
default
=
0
,
type
=
int
,
help
=
"random seed for paddle, numpy and python random package"
)
parser
.
add_argument
(
"--data-dir"
,
default
=
"./data/"
,
type
=
str
,
help
=
"data directory"
)
parser
.
add_argument
(
"--learning-rate"
,
type
=
float
,
default
=
1e-8
,
help
=
"Learning rate used to train with warmup."
)
parser
.
add_argument
(
"--load-checkpoint"
,
type
=
str
,
default
=
None
,
help
=
"Directory to load model checkpoint to contiune trainning."
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
default
=
64
,
help
=
"Total examples' number in batch for training."
)
parser
.
add_argument
(
"--num-workers"
,
type
=
int
,
default
=
0
,
help
=
"Number of workers in dataloader."
)
parser
.
add_argument
(
"--epochs"
,
type
=
int
,
default
=
50
,
help
=
"Number of epoches for fine-tuning."
)
parser
.
add_argument
(
"--log-freq"
,
type
=
int
,
default
=
10
,
help
=
"Log the training infomation every n steps."
)
parser
.
add_argument
(
"--save-freq"
,
type
=
int
,
default
=
1
,
help
=
"Save checkpoint every n epoch."
)
parser
.
add_argument
(
"--checkpoint-dir"
,
type
=
str
,
default
=
'./checkpoint'
,
help
=
"Directory to save model checkpoints."
)
parser
.
add_argument
(
"--augment"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Apply audio augments."
)
args
=
parser
.
parse_args
()
# yapf: enable
main
(
args
)
examples/voxceleb/sv0/path.sh
浏览文件 @
0e87037f
...
@@ -9,3 +9,6 @@ export PYTHONIOENCODING=UTF-8
...
@@ -9,3 +9,6 @@ export PYTHONIOENCODING=UTF-8
export
PYTHONPATH
=
${
MAIN_ROOT
}
:
${
PYTHONPATH
}
export
PYTHONPATH
=
${
MAIN_ROOT
}
:
${
PYTHONPATH
}
export
LD_LIBRARY_PATH
=
${
LD_LIBRARY_PATH
}
:/usr/local/lib/
export
LD_LIBRARY_PATH
=
${
LD_LIBRARY_PATH
}
:/usr/local/lib/
MODEL
=
ecapa-tdnn
export
BIN_DIR
=
${
MAIN_ROOT
}
/paddlespeech/vector/exps/
${
MODEL
}
\ No newline at end of file
examples/voxceleb/sv0/run.sh
浏览文件 @
0e87037f
...
@@ -30,23 +30,21 @@ if [ $stage -le 1 ]; then
...
@@ -30,23 +30,21 @@ if [ $stage -le 1 ]; then
# stage 1: train the speaker identification model
# stage 1: train the speaker identification model
python3
\
python3
\
-m
paddle.distributed.launch
--gpus
=
0,1,2,3
\
-m
paddle.distributed.launch
--gpus
=
0,1,2,3
\
local
/train.py
--device
"gpu"
--checkpoint-dir
${
exp_dir
}
--augment
\
${
BIN_DIR
}
/train.py
--device
"gpu"
--checkpoint-dir
${
exp_dir
}
--augment
\
--save-freq
10
--data-dir
${
dir
}
--batch-size
64
--epochs
100
--save-freq
10
--data-dir
${
dir
}
--batch-size
64
--epochs
100
fi
fi
if
[
$stage
-le
2
]
;
then
if
[
$stage
-le
2
]
;
then
# stage 1: train the speaker identification model
# stage 1: get the speaker verification scores with cosine function
# you can set the variable PPAUDIO_HOME to specifiy the downloaded the vox1 and vox2 dataset
python3
\
python3
\
local
/speaker_verification_cosine.py
\
${
BIN_DIR
}
/speaker_verification_cosine.py
\
--batch-size
4
--data-dir
${
dir
}
--load-checkpoint
${
exp_dir
}
/epoch_10/
--batch-size
4
--data-dir
${
dir
}
--load-checkpoint
${
exp_dir
}
/epoch_10/
fi
fi
if
[
$stage
-le
3
]
;
then
if
[
$stage
-le
3
]
;
then
# stage 1: train the speaker identification model
# stage 3: extract the audio embedding
# you can set the variable PPAUDIO_HOME to specifiy the downloaded the vox1 and vox2 dataset
python3
\
python3
\
local
/extract_speaker_embedding.py
\
${
BIN_DIR
}
/extract_speaker_embedding.py
\
--audio-path
"demo/csv/00001.wav"
--load-checkpoint
${
exp_dir
}
/epoch_60/
--audio-path
"demo/csv/00001.wav"
--load-checkpoint
${
exp_dir
}
/epoch_60/
fi
fi
...
...
paddleaudio/paddleaudio/datasets/__init__.py
浏览文件 @
0e87037f
...
@@ -15,3 +15,5 @@ from .esc50 import ESC50
...
@@ -15,3 +15,5 @@ from .esc50 import ESC50
from
.gtzan
import
GTZAN
from
.gtzan
import
GTZAN
from
.tess
import
TESS
from
.tess
import
TESS
from
.urban_sound
import
UrbanSound8K
from
.urban_sound
import
UrbanSound8K
from
.voxceleb
import
VoxCeleb1
from
.rirs_noises
import
OpenRIRNoise
paddleaudio/datasets/rirs_noises.py
→
paddleaudio/
paddleaudio/
datasets/rirs_noises.py
浏览文件 @
0e87037f
...
@@ -23,11 +23,11 @@ from typing import Tuple
...
@@ -23,11 +23,11 @@ from typing import Tuple
from
paddle.io
import
Dataset
from
paddle.io
import
Dataset
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
paddleaudio
.backends
import
load
as
load_audio
from
.
.backends
import
load
as
load_audio
from
paddleaudio.backends
import
save_wav
from
..backends
import
save
as
save_wav
from
paddleaudio.datasets
.dataset
import
feat_funcs
from
.dataset
import
feat_funcs
from
paddleaudio
.utils
import
DATA_HOME
from
.
.utils
import
DATA_HOME
from
paddleaudio
.utils
import
decompress
from
.
.utils
import
decompress
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.utils.download
import
download_and_decompress
from
paddlespeech.vector.utils.download
import
download_and_decompress
...
...
paddleaudio/datasets/voxceleb.py
→
paddleaudio/
paddleaudio/
datasets/voxceleb.py
浏览文件 @
0e87037f
...
@@ -25,10 +25,10 @@ from paddle.io import Dataset
...
@@ -25,10 +25,10 @@ from paddle.io import Dataset
from
pathos.multiprocessing
import
Pool
from
pathos.multiprocessing
import
Pool
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
paddleaudio.backends
import
load
as
load_audio
from
.dataset
import
feat_funcs
from
paddleaudio.datasets.dataset
import
feat_funcs
from
..backends
import
load
as
load_audio
from
paddleaudio
.utils
import
DATA_HOME
from
.
.utils
import
DATA_HOME
from
paddleaudio
.utils
import
decompress
from
.
.utils
import
decompress
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.utils.download
import
download_and_decompress
from
paddlespeech.vector.utils.download
import
download_and_decompress
from
utils.utility
import
download
from
utils.utility
import
download
...
@@ -83,7 +83,7 @@ class VoxCeleb1(Dataset):
...
@@ -83,7 +83,7 @@ class VoxCeleb1(Dataset):
meta_path
=
os
.
path
.
join
(
base_path
,
'meta'
)
meta_path
=
os
.
path
.
join
(
base_path
,
'meta'
)
veri_test_file
=
os
.
path
.
join
(
meta_path
,
'veri_test2.txt'
)
veri_test_file
=
os
.
path
.
join
(
meta_path
,
'veri_test2.txt'
)
csv_path
=
os
.
path
.
join
(
base_path
,
'csv'
)
csv_path
=
os
.
path
.
join
(
base_path
,
'csv'
)
subsets
=
[
'train'
,
'dev'
,
'enrol'
,
'test'
]
subsets
=
[
'train'
,
'dev'
,
'enrol
l
'
,
'test'
]
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -330,7 +330,7 @@ class VoxCeleb1(Dataset):
...
@@ -330,7 +330,7 @@ class VoxCeleb1(Dataset):
self
.
generate_csv
(
self
.
generate_csv
(
enroll_files
,
enroll_files
,
os
.
path
.
join
(
self
.
csv_path
,
'enrol.csv'
),
os
.
path
.
join
(
self
.
csv_path
,
'enrol
l
.csv'
),
split_chunks
=
False
)
split_chunks
=
False
)
self
.
generate_csv
(
self
.
generate_csv
(
test_files
,
test_files
,
...
...
paddleaudio/paddleaudio/metric/__init__.py
浏览文件 @
0e87037f
...
@@ -13,3 +13,4 @@
...
@@ -13,3 +13,4 @@
# limitations under the License.
# limitations under the License.
from
.dtw
import
dtw_distance
from
.dtw
import
dtw_distance
from
.mcd
import
mcd_distance
from
.mcd
import
mcd_distance
from
.eer
import
compute_eer
paddle
speech/vector/training/metrics
.py
→
paddle
audio/paddleaudio/metric/eer
.py
浏览文件 @
0e87037f
文件已移动
paddlespeech/vector/io/augment.py
浏览文件 @
0e87037f
...
@@ -20,8 +20,8 @@ import paddle
...
@@ -20,8 +20,8 @@ import paddle
import
paddle.nn
as
nn
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
import
paddle.nn.functional
as
F
from
paddleaudio.
backends
import
load
as
load_audio
from
paddleaudio.
paddleaudio
import
load
as
load_audio
from
paddleaudio.datasets.rirs_noises
import
OpenRIRNoise
from
paddleaudio.
paddleaudio.
datasets.rirs_noises
import
OpenRIRNoise
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.s2t.utils.log
import
Log
from
paddlespeech.vector.io.signal_processing
import
compute_amplitude
from
paddlespeech.vector.io.signal_processing
import
compute_amplitude
from
paddlespeech.vector.io.signal_processing
import
convolve1d
from
paddlespeech.vector.io.signal_processing
import
convolve1d
...
...
paddlespeech/vector/io/batch.py
浏览文件 @
0e87037f
...
@@ -40,3 +40,41 @@ def feature_normalize(feats: paddle.Tensor,
...
@@ -40,3 +40,41 @@ def feature_normalize(feats: paddle.Tensor,
feats
=
(
feats
-
mean
)
/
std
feats
=
(
feats
-
mean
)
/
std
return
feats
return
feats
def
pad_right_2d
(
x
,
target_length
,
axis
=-
1
,
mode
=
'constant'
,
**
kwargs
):
x
=
np
.
asarray
(
x
)
assert
len
(
x
.
shape
)
==
2
,
f
'Only 2D arrays supported, but got shape:
{
x
.
shape
}
'
w
=
target_length
-
x
.
shape
[
axis
]
assert
w
>=
0
,
f
'Target length
{
target_length
}
is less than origin length
{
x
.
shape
[
axis
]
}
'
if
axis
==
0
:
pad_width
=
[[
0
,
w
],
[
0
,
0
]]
else
:
pad_width
=
[[
0
,
0
],
[
0
,
w
]]
return
np
.
pad
(
x
,
pad_width
,
mode
=
mode
,
**
kwargs
)
def
batch_feature_normalize
(
batch
,
mean_norm
:
bool
=
True
,
std_norm
:
bool
=
True
):
ids
=
[
item
[
'id'
]
for
item
in
batch
]
lengths
=
np
.
asarray
([
item
[
'feat'
].
shape
[
1
]
for
item
in
batch
])
feats
=
list
(
map
(
lambda
x
:
pad_right_2d
(
x
,
lengths
.
max
()),
[
item
[
'feat'
]
for
item
in
batch
]))
feats
=
np
.
stack
(
feats
)
# Features normalization if needed
for
i
in
range
(
len
(
feats
)):
feat
=
feats
[
i
][:,
:
lengths
[
i
]]
# Excluding pad values.
mean
=
feat
.
mean
(
axis
=-
1
,
keepdims
=
True
)
if
mean_norm
else
0
std
=
feat
.
std
(
axis
=-
1
,
keepdims
=
True
)
if
std_norm
else
1
feats
[
i
][:,
:
lengths
[
i
]]
=
(
feat
-
mean
)
/
std
assert
feats
[
i
][:,
lengths
[
i
]:].
sum
()
==
0
# Padding valus should all be 0.
# Converts into ratios.
lengths
=
(
lengths
/
lengths
.
max
()).
astype
(
np
.
float32
)
return
{
'ids'
:
ids
,
'feats'
:
feats
,
'lengths'
:
lengths
}
\ No newline at end of file
paddlespeech/vector/
modules/l
r.py
→
paddlespeech/vector/
training/schedule
r.py
浏览文件 @
0e87037f
文件已移动
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录