PaddlePaddle / DeepSpeech · Commit 584a2c0e
Authored Mar 09, 2022 by xiongxinlei

add ecapa-tdnn config yaml file

Parent 993d6783
Showing 5 changed files with 656 additions and 2 deletions.
examples/voxceleb/sv0/conf/ecapa_tdnn.yaml (+35 −0)
examples/voxceleb/sv0/run.sh (+4 −2)
paddlespeech/vector/exps/ecapa-tdnn/extract_speaker_embedding.py (+112 −0)
paddlespeech/vector/exps/ecapa-tdnn/speaker_verification_cosine.py (+207 −0)
paddlespeech/vector/exps/ecapa-tdnn/train.py (+298 −0)

examples/voxceleb/sv0/conf/ecapa_tdnn.yaml (new file, mode 100644)
###########################################################
#                FEATURE EXTRACTION SETTING               #
###########################################################
# currently, we only support fbank
feature:
  n_mels: 80
  window_size: 400  # 25ms, sample rate 16000, 25 * 16000 / 1000 = 400
  hop_length: 160  # 10ms, sample rate 16000, 10 * 16000 / 1000 = 160

###########################################################
#                      MODEL SETTING                      #
###########################################################
# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml
# if we want to use another model, please choose another configuration yaml file
model:
  input_size: 80
  # "channels": [1024, 1024, 1024, 1024, 3072],
  # "channels": [512, 512, 512, 512, 1536],
  channels: [512, 512, 512, 512, 1536]
  kernel_sizes: [5, 3, 3, 3, 1]
  dilations: [1, 2, 3, 4, 1]
  attention_channels: 128
  lin_neurons: 192

###########################################
#                 Training                #
###########################################
seed: 0
epochs: 10
batch_size: 32
num_workers: 2
save_freq: 10
log_freq: 10
learning_rate: 1e-8
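
The three scripts in this commit all read this file through a yacs CfgNode, so each top-level key becomes an attribute and the feature/model sub-dicts can be splatted straight into melspectrogram() and EcapaTdnn(). A minimal sketch of that pattern, assuming only that yacs is installed and the yaml above is saved as conf/ecapa_tdnn.yaml:

from yacs.config import CfgNode

# Start from an empty node; new_allowed=True lets merge_from_file add any keys.
config = CfgNode(new_allowed=True)
config.merge_from_file('conf/ecapa_tdnn.yaml')
config.freeze()

print(config.feature.n_mels)  # 80
print(config.model.channels)  # [512, 512, 512, 512, 1536]
# Caveat: YAML 1.1 resolvers may load a dot-less exponent such as 1e-8 as a
# string rather than a float; that pitfall is why each script links
# https://yaml.org/type/float.html next to its CfgNode setup.
print(repr(config.learning_rate))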

examples/voxceleb/sv0/run.sh (modified)
@@ -31,20 +31,22 @@ if [ $stage -le 1 ]; then
     python3 \
     -m paddle.distributed.launch --gpus=0,1,2,3 \
     ${BIN_DIR}/train.py --device "gpu" --checkpoint-dir ${exp_dir} --augment \
-    --save-freq 10 --data-dir ${dir} --batch-size 64 --epochs 100
+    --data-dir ${dir} --config conf/ecapa_tdnn.yaml
 fi

 if [ $stage -le 2 ]; then
     # stage 2: get the speaker verification scores with the cosine function
     python3 \
         ${BIN_DIR}/speaker_verification_cosine.py \
-        --data-dir ${dir} --load-checkpoint ${exp_dir}/epoch_10/
+        --config conf/ecapa_tdnn.yaml \
+        --batch-size 4 --data-dir ${dir} --load-checkpoint ${exp_dir}/epoch_10/
 fi

 if [ $stage -le 3 ]; then
     # stage 3: extract the audio embedding
     python3 \
         ${BIN_DIR}/extract_speaker_embedding.py \
+        --config conf/ecapa_tdnn.yaml \
         --audio-path "demo/csv/00001.wav" --load-checkpoint ${exp_dir}/epoch_60/
 fi

paddlespeech/vector/exps/ecapa-tdnn/extract_speaker_embedding.py (new file, mode 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os

import numpy as np
import paddle
from yacs.config import CfgNode

from paddleaudio.paddleaudio.backends import load as load_audio
from paddleaudio.paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything

logger = Log(__name__).getlog()


def extract_audio_embedding(args, config):
    # stage 0: set the device, cpu or gpu
    paddle.set_device(args.device)
    # set the random seed; it is a must for multiprocess training
    seed_everything(config.seed)

    # stage 1: build the dnn backbone model network
    ecapa_tdnn = EcapaTdnn(**config.model)

    # stage 2: build the speaker verification instance with the backbone model
    model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=1211)

    # stage 3: load the pre-trained model checkpoint into the sid model
    args.load_checkpoint = os.path.abspath(
        os.path.expanduser(args.load_checkpoint))
    state_dict = paddle.load(
        os.path.join(args.load_checkpoint, 'model.pdparams'))
    model.set_state_dict(state_dict)
    logger.info(f'Checkpoint loaded from {args.load_checkpoint}')

    # stage 4: we must set the model to eval mode
    model.eval()

    # stage 5: read the audio data and extract the embedding
    # the waveform is a one-dimensional numpy array
    waveform, sr = load_audio(args.audio_path)

    # the feat is a numpy array whose shape is [dim, time];
    # we need to convert it to the one-batch shape [batch, dim, time],
    # where batch is one, so the final shape is [1, dim, time]
    feat = melspectrogram(x=waveform, **config.feature)
    feat = paddle.to_tensor(feat).unsqueeze(0)

    # at inference time the lengths are all ones without padding
    lengths = paddle.ones([1])
    feat = feature_normalize(
        feat, mean_norm=True, std_norm=False, convert_to_numpy=True)

    # the model backbone network forwards the feats and gets the embedding
    embedding = model.backbone(
        feat, lengths).squeeze().numpy()  # (1, emb_size, 1) -> (emb_size)

    # stage 6: do global norm with external mean and std
    # todo
    return embedding


if __name__ == "__main__":
    # yapf: disable
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument('--device',
                        choices=['cpu', 'gpu'],
                        default="gpu",
                        help="Select which device to run the model, defaults to gpu.")
    parser.add_argument("--config",
                        default=None,
                        type=str,
                        help="configuration file")
    parser.add_argument("--load-checkpoint",
                        type=str,
                        default='',
                        help="Directory to load the model checkpoint from to continue training.")
    parser.add_argument("--global-embedding-norm",
                        type=str,
                        default=None,
                        help="Apply global normalization on speaker embeddings.")
    parser.add_argument("--audio-path",
                        default="./data/demo.wav",
                        type=str,
                        help="Single audio file path")
    args = parser.parse_args()
    # yapf: enable

    # https://yaml.org/type/float.html
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)
    config.freeze()
    print(config)

    extract_audio_embedding(args, config)
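
The embedding returned above is a plain (emb_size,) numpy vector, so two utterances can be compared with ordinary cosine similarity; as a hedged sketch (the helper below is not part of this commit; speaker_verification_cosine.py further down does the batched equivalent with paddle.nn.CosineSimilarity):

import numpy as np

def cosine_score(emb1: np.ndarray, emb2: np.ndarray) -> float:
    # Cosine similarity of two 1-D embeddings; the epsilon guards against
    # zero-norm vectors.
    return float(np.dot(emb1, emb2) /
                 (np.linalg.norm(emb1) * np.linalg.norm(emb2) + 1e-12))

# same_speaker = cosine_score(emb_enroll, emb_test) > threshold,
# with the threshold tuned on a held-out trial list.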

paddlespeech/vector/exps/ecapa-tdnn/speaker_verification_cosine.py (new file, mode 100644)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os

import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.io import BatchSampler
from paddle.io import DataLoader
from tqdm import tqdm
from yacs.config import CfgNode

from paddleaudio.paddleaudio.datasets import VoxCeleb1
from paddleaudio.paddleaudio.metric import compute_eer
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything

logger = Log(__name__).getlog()


def main(args, config):
    # stage 0: set the device, cpu or gpu
    paddle.set_device(args.device)
    # set the random seed; it is a must for multiprocess training
    seed_everything(config.seed)

    # stage 1: build the dnn backbone model network
    ecapa_tdnn = EcapaTdnn(**config.model)

    # stage 2: build the speaker verification eval instance with the backbone model
    model = SpeakerIdetification(
        backbone=ecapa_tdnn, num_class=VoxCeleb1.num_speakers)

    # stage 3: load the pre-trained model checkpoint into the sid model
    args.load_checkpoint = os.path.abspath(
        os.path.expanduser(args.load_checkpoint))
    state_dict = paddle.load(
        os.path.join(args.load_checkpoint, 'model.pdparams'))
    model.set_state_dict(state_dict)
    logger.info(f'Checkpoint loaded from {args.load_checkpoint}')

    # stage 4: construct the enroll and test dataloaders
    enroll_dataset = VoxCeleb1(
        subset='enroll',
        target_dir=args.data_dir,
        feat_type='melspectrogram',
        random_chunk=False,
        **config.feature)
    enroll_sampler = BatchSampler(
        enroll_dataset, batch_size=config.batch_size,
        shuffle=True)  # Shuffle to make embedding normalization more robust.
    enrol_loader = DataLoader(
        enroll_dataset,
        batch_sampler=enroll_sampler,
        collate_fn=lambda x: batch_feature_normalize(
            x, mean_norm=True, std_norm=False),
        num_workers=config.num_workers,
        return_list=True)

    test_dataset = VoxCeleb1(
        subset='test',
        target_dir=args.data_dir,
        feat_type='melspectrogram',
        random_chunk=False,
        **config.feature)
    test_sampler = BatchSampler(
        test_dataset, batch_size=config.batch_size, shuffle=True)
    test_loader = DataLoader(
        test_dataset,
        batch_sampler=test_sampler,
        collate_fn=lambda x: batch_feature_normalize(
            x, mean_norm=True, std_norm=False),
        num_workers=config.num_workers,
        return_list=True)

    # stage 5: we must set the model to eval mode
    model.eval()

    # stage 6: global embedding norm to improve the performance
    if args.global_embedding_norm:
        global_embedding_mean = None
        global_embedding_std = None
        mean_norm_flag = args.embedding_mean_norm
        std_norm_flag = args.embedding_std_norm
        batch_count = 0

    # stage 7: compute embeddings of the audios in the enroll and test datasets
    id2embedding = {}
    # Run multiple times to make embedding normalization more stable.
    for i in range(2):
        for dl in [enrol_loader, test_loader]:
            logger.info(
                f'Loop {i + 1}: Computing embeddings on {dl.dataset.subset} dataset')
            with paddle.no_grad():
                for batch_idx, batch in enumerate(tqdm(dl)):
                    # stage 7-1: extract the audio embedding
                    ids, feats, lengths = batch['ids'], batch['feats'], batch[
                        'lengths']
                    embeddings = model.backbone(feats, lengths).squeeze(
                        -1).numpy()  # (N, emb_size, 1) -> (N, emb_size)

                    # Global embedding normalization.
                    if args.global_embedding_norm:
                        batch_count += 1
                        current_mean = embeddings.mean(
                            axis=0) if mean_norm_flag else 0
                        current_std = embeddings.std(
                            axis=0) if std_norm_flag else 1
                        # Update the global mean and std.
                        if global_embedding_mean is None and global_embedding_std is None:
                            global_embedding_mean, global_embedding_std = current_mean, current_std
                        else:
                            weight = 1 / batch_count  # Weight decays by batches.
                            global_embedding_mean = (
                                1 - weight
                            ) * global_embedding_mean + weight * current_mean
                            global_embedding_std = (
                                1 - weight
                            ) * global_embedding_std + weight * current_std
                        # Apply global embedding normalization.
                        embeddings = (embeddings - global_embedding_mean
                                      ) / global_embedding_std

                    # Update the embedding dict.
                    id2embedding.update(dict(zip(ids, embeddings)))

    # stage 8: compute the cosine scores
    labels = []
    enrol_ids = []
    test_ids = []
    with open(VoxCeleb1.veri_test_file, 'r') as f:
        for line in f.readlines():
            label, enrol_id, test_id = line.strip().split(' ')
            labels.append(int(label))
            enrol_ids.append(enrol_id.split('.')[0].replace('/', '-'))
            test_ids.append(test_id.split('.')[0].replace('/', '-'))

    cos_sim_func = paddle.nn.CosineSimilarity(axis=1)
    enrol_embeddings, test_embeddings = map(
        lambda ids: paddle.to_tensor(
            np.asarray([id2embedding[id] for id in ids], dtype='float32')),
        [enrol_ids, test_ids])  # (N, emb_size)

    scores = cos_sim_func(enrol_embeddings, test_embeddings)
    EER, threshold = compute_eer(np.asarray(labels), scores.numpy())
    logger.info(
        f'EER of verification test: {EER * 100:.4f}%, score threshold: {threshold:.5f}'
    )


if __name__ == "__main__":
    # yapf: disable
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument('--device',
                        choices=['cpu', 'gpu'],
                        default="gpu",
                        help="Select which device to run the model, defaults to gpu.")
    parser.add_argument("--config",
                        default=None,
                        type=str,
                        help="configuration file")
    parser.add_argument("--data-dir",
                        default="./data/",
                        type=str,
                        help="data directory")
    parser.add_argument("--load-checkpoint",
                        type=str,
                        default='',
                        help="Directory to load the model checkpoint from to continue training.")
    parser.add_argument("--global-embedding-norm",
                        type=bool,
                        default=True,
                        help="Apply global normalization on speaker embeddings.")
    parser.add_argument("--embedding-mean-norm",
                        type=bool,
                        default=True,
                        help="Apply mean normalization on speaker embeddings.")
    parser.add_argument("--embedding-std-norm",
                        type=bool,
                        default=False,
                        help="Apply std normalization on speaker embeddings.")
    args = parser.parse_args()
    # yapf: enable

    # https://yaml.org/type/float.html
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)
    config.freeze()
    print(config)

    main(args, config)
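
One detail worth noting in the global embedding normalization above: with weight = 1 / batch_count, the running update is an incremental cumulative average of the per-batch statistics, not an exponential decay. A small self-contained numpy check of that identity (illustrative only, not part of the commit):

import numpy as np

batch_means = [np.array([1.0]), np.array([3.0]), np.array([5.0])]

running = None
for batch_count, current in enumerate(batch_means, start=1):
    if running is None:
        running = current
    else:
        weight = 1 / batch_count
        running = (1 - weight) * running + weight * current

# The incremental form reproduces the plain mean of all batch statistics.
assert np.allclose(running, np.mean(batch_means, axis=0))  # -> [3.0]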

paddlespeech/vector/exps/ecapa-tdnn/train.py (new file, mode 100644)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os

import numpy as np
import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from yacs.config import CfgNode

from paddleaudio.paddleaudio.compliance.librosa import melspectrogram
from paddleaudio.paddleaudio.datasets.voxceleb import VoxCeleb1
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.io.augment import waveform_augment
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.io.batch import waveform_collate_fn
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.loss import AdditiveAngularMargin
from paddlespeech.vector.modules.loss import LogSoftmaxWrapper
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.scheduler import CyclicLRScheduler
from paddlespeech.vector.training.seeding import seed_everything
from paddlespeech.vector.utils.time import Timer

logger = Log(__name__).getlog()


def main(args, config):
    # stage0: set the training device, cpu or gpu
    paddle.set_device(args.device)

    # stage1: we must call the paddle.distributed.init_parallel_env() api at the beginning
    paddle.distributed.init_parallel_env()
    nranks = paddle.distributed.get_world_size()
    local_rank = paddle.distributed.get_rank()
    # set the random seed; it is a must for multiprocess training
    seed_everything(config.seed)

    # stage2: data prepare, such as the vox1 and vox2 data and the augment noise data and pipeline
    # note: some cmds must be done in rank==0, so we will refactor the data prepare code
    train_dataset = VoxCeleb1('train', target_dir=args.data_dir)
    dev_dataset = VoxCeleb1('dev', target_dir=args.data_dir)

    if args.augment:
        augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
    else:
        augment_pipeline = []

    # stage3: build the dnn backbone model network
    ecapa_tdnn = EcapaTdnn(**config.model)

    # stage4: build the speaker verification train instance with the backbone model
    model = SpeakerIdetification(
        backbone=ecapa_tdnn, num_class=VoxCeleb1.num_speakers)

    # stage5: build the optimizer; we now only construct the AdamW optimizer
    lr_schedule = CyclicLRScheduler(
        base_lr=config.learning_rate, max_lr=1e-3, step_size=140000 // nranks)
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_schedule, parameters=model.parameters())

    # stage6: build the loss function; we now only support LogSoftmaxWrapper
    criterion = LogSoftmaxWrapper(
        loss_fn=AdditiveAngularMargin(margin=0.2, scale=30))

    # stage7: confirm the training start epoch;
    #         if a pre-trained model exists, the start epoch is confirmed by it
    start_epoch = 0
    if args.load_checkpoint:
        logger.info("load the checkpoint")
        args.load_checkpoint = os.path.abspath(
            os.path.expanduser(args.load_checkpoint))
        try:
            # load model checkpoint
            state_dict = paddle.load(
                os.path.join(args.load_checkpoint, 'model.pdparams'))
            model.set_state_dict(state_dict)

            # load optimizer checkpoint
            state_dict = paddle.load(
                os.path.join(args.load_checkpoint, 'model.pdopt'))
            optimizer.set_state_dict(state_dict)
            if local_rank == 0:
                logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
        except FileExistsError:
            if local_rank == 0:
                logger.info('Train from scratch.')

        try:
            start_epoch = int(args.load_checkpoint[-1])
            logger.info(f'Restore training from epoch {start_epoch}.')
        except ValueError:
            pass

    # stage8: we build the batch sampler for paddle.DataLoader
    train_sampler = DistributedBatchSampler(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=False)
    train_loader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=config.num_workers,
        collate_fn=waveform_collate_fn,
        return_list=True,
        use_buffer_reader=True)

    # stage9: start to train
    #         we will comment on the training process step by step
    steps_per_epoch = len(train_sampler)
    timer = Timer(steps_per_epoch * config.epochs)
    timer.start()

    for epoch in range(start_epoch + 1, config.epochs + 1):
        # at the beginning, the model must be set to train mode
        model.train()

        avg_loss = 0
        num_corrects = 0
        num_samples = 0
        for batch_idx, batch in enumerate(train_loader):
            # stage 9-1: batch data is audio sample points and speaker id labels
            waveforms, labels = batch['waveforms'], batch['labels']

            # stage 9-2: audio sample augment method, which is done on the audio sample points
            if len(augment_pipeline) != 0:
                waveforms = waveform_augment(waveforms, augment_pipeline)
                labels = paddle.concat(
                    [labels for i in range(len(augment_pipeline) + 1)])

            # stage 9-3: extract the audio feats, such as fbank, mfcc, spectrogram
            feats = []
            for waveform in waveforms.numpy():
                feat = melspectrogram(x=waveform, **config.feature)
                feats.append(feat)
            feats = paddle.to_tensor(np.asarray(feats))

            # stage 9-4: feature normalization, which helps convergence and improves the performance
            feats = feature_normalize(feats, mean_norm=True, std_norm=False)

            # stage 9-5: model forward, such as ecapa-tdnn, x-vector
            logits = model(feats)

            # stage 9-6: loss function criterion, such as AngularMargin, AdditiveAngularMargin
            loss = criterion(logits, labels)

            # stage 9-7: update the gradients and clear the gradient cache
            loss.backward()
            optimizer.step()
            if isinstance(optimizer._learning_rate,
                          paddle.optimizer.lr.LRScheduler):
                optimizer._learning_rate.step()
            optimizer.clear_grad()

            # stage 9-8: calculate the average loss per batch
            avg_loss += loss.numpy()[0]

            # stage 9-9: calculate metrics, which is one-best accuracy
            preds = paddle.argmax(logits, axis=1)
            num_corrects += (preds == labels).numpy().sum()
            num_samples += feats.shape[0]
            timer.count()  # step plus one in timer

            # stage 9-10: print the log information only on rank 0, once per log-freq batches
            if (batch_idx + 1) % config.log_freq == 0 and local_rank == 0:
                lr = optimizer.get_lr()
                avg_loss /= config.log_freq
                avg_acc = num_corrects / num_samples

                print_msg = 'Train Epoch={}/{}, Step={}/{}'.format(
                    epoch, config.epochs, batch_idx + 1, steps_per_epoch)
                print_msg += ' loss={:.4f}'.format(avg_loss)
                print_msg += ' acc={:.4f}'.format(avg_acc)
                print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
                    lr, timer.timing, timer.eta)
                logger.info(print_msg)

                avg_loss = 0
                num_corrects = 0
                num_samples = 0

            # stage 9-11: save the model parameters only on rank 0, once per save-freq epochs
            if epoch % config.save_freq == 0 and batch_idx + 1 == steps_per_epoch:
                if local_rank != 0:
                    paddle.distributed.barrier()  # Wait for the valid step in the main process.
                    continue  # Resume training on the other processes.

                # stage 9-12: construct the valid dataset dataloader
                dev_sampler = BatchSampler(
                    dev_dataset,
                    batch_size=config.batch_size // 4,
                    shuffle=False,
                    drop_last=False)
                dev_loader = DataLoader(
                    dev_dataset,
                    batch_sampler=dev_sampler,
                    collate_fn=waveform_collate_fn,
                    num_workers=config.num_workers,
                    return_list=True)

                # set the model to eval mode
                model.eval()
                num_corrects = 0
                num_samples = 0

                # stage 9-13: evaluate on the valid dataset batch data
                logger.info('Evaluate on validation dataset')
                with paddle.no_grad():
                    for batch_idx, batch in enumerate(dev_loader):
                        waveforms, labels = batch['waveforms'], batch['labels']

                        feats = []
                        for waveform in waveforms.numpy():
                            # feat = melspectrogram(x=waveform, **cpu_feat_conf)
                            feat = melspectrogram(x=waveform, **config.feature)
                            feats.append(feat)

                        feats = paddle.to_tensor(np.asarray(feats))
                        feats = feature_normalize(
                            feats, mean_norm=True, std_norm=False)
                        logits = model(feats)

                        preds = paddle.argmax(logits, axis=1)
                        num_corrects += (preds == labels).numpy().sum()
                        num_samples += feats.shape[0]

                print_msg = '[Evaluation result]'
                print_msg += ' dev_acc={:.4f}'.format(num_corrects / num_samples)
                logger.info(print_msg)

                # stage 9-14: save the model parameters
                save_dir = os.path.join(args.checkpoint_dir,
                                        'epoch_{}'.format(epoch))
                logger.info('Saving model checkpoint to {}'.format(save_dir))
                paddle.save(model.state_dict(),
                            os.path.join(save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(save_dir, 'model.pdopt'))

                if nranks > 1:
                    paddle.distributed.barrier()  # Main process.


if __name__ == "__main__":
    # yapf: disable
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument('--device',
                        choices=['cpu', 'gpu'],
                        default="cpu",
                        help="Select which device to train the model, defaults to cpu.")
    parser.add_argument("--config",
                        default=None,
                        type=str,
                        help="configuration file")
    parser.add_argument("--data-dir",
                        default="./data/",
                        type=str,
                        help="data directory")
    parser.add_argument("--load-checkpoint",
                        type=str,
                        default=None,
                        help="Directory to load the model checkpoint from to continue training.")
    parser.add_argument("--checkpoint-dir",
                        type=str,
                        default='./checkpoint',
                        help="Directory to save model checkpoints.")
    parser.add_argument("--augment",
                        action="store_true",
                        default=False,
                        help="Apply audio augments.")
    args = parser.parse_args()
    # yapf: enable

    # https://yaml.org/type/float.html
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)
    config.freeze()
    print(config)

    main(args, config)
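
For intuition about the criterion assembled in stage6 above, here is a hedged numpy sketch of the additive angular margin (ArcFace-style) idea: the target class's cosine score cos(theta) is replaced by cos(theta + m) before scaling, which forces a wider angular gap between speakers. The real AdditiveAngularMargin and LogSoftmaxWrapper classes live in paddlespeech.vector.modules.loss and may differ in detail:

import numpy as np

def aam_logits(cosine: np.ndarray, labels: np.ndarray,
               margin: float = 0.2, scale: float = 30.0) -> np.ndarray:
    # cosine: (N, num_class) cosines between normalized embeddings and class
    # weights; labels: (N,) integer speaker ids.
    n = np.arange(len(labels))
    theta = np.arccos(np.clip(cosine[n, labels], -1.0, 1.0))
    out = cosine.copy()
    out[n, labels] = np.cos(theta + margin)  # penalize the target class only
    return scale * out  # these logits then go through log-softmax + NLL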