PaddlePaddle / DeepSpeech, commit 7988d0ff (unverified)

Authored on Apr 12, 2022 by Honei_X; committed by GitHub on Apr 12, 2022.

Merge pull request #1690 from Honei/v0.3
[vec]add vector necessary note, test=doc

Parents: fdc189a3, d1935d85

Showing 5 changed files with 78 additions and 25 deletions (+78, -25):

examples/voxceleb/sv0/conf/ecapa_tdnn.yaml        +9   -1
examples/voxceleb/sv0/conf/ecapa_tdnn_small.yaml  +8   -1
paddlespeech/vector/exps/ecapa_tdnn/test.py       +41  -16
paddlespeech/vector/exps/ecapa_tdnn/train.py      +15  -5
paddlespeech/vector/io/embedding_norm.py          +5   -2

examples/voxceleb/sv0/conf/ecapa_tdnn.yaml

@@ -4,7 +4,7 @@
 augment: True
 batch_size: 32
 num_workers: 2
-num_speakers: 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
+num_speakers: 7205 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
 shuffle: True
 skip_prep: False
 split_ratio: 0.9

@@ -42,8 +42,16 @@ epochs: 10
 save_interval: 10
 log_interval: 10
 learning_rate: 1e-8
+max_lr: 1e-3
+step_size: 140000
+###########################################
+# loss #
+###########################################
+margin: 0.2
+scale: 30
 ###########################################
 # Testing #
 ###########################################

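The new loss block exposes the margin and scale that train.py previously hard-coded when building AdditiveAngularMargin (see the train.py diff below). As a rough sketch of what these two numbers control, here is the standard additive-angular-margin (ArcFace-style) logit adjustment in NumPy; this is an assumption about the general technique, not the repo's own implementation.

import numpy as np

def aam_logits(cosine, labels, margin=0.2, scale=30.0):
    """Additive angular margin: add `margin` to the angle of the true class,
    then multiply every logit by `scale` before softmax/cross-entropy."""
    theta = np.arccos(np.clip(cosine, -1.0 + 1e-7, 1.0 - 1e-7))
    one_hot = np.zeros_like(cosine)
    one_hot[np.arange(len(labels)), labels] = 1.0
    adjusted = np.where(one_hot > 0, np.cos(theta + margin), cosine)
    return scale * adjusted

# toy check: 3 utterances, 5 speakers, random cosine similarities in [-1, 1]
cosine = np.random.uniform(-1, 1, size=(3, 5))
print(aam_logits(cosine, labels=np.array([0, 3, 1])).shape)  # (3, 5)
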
examples/voxceleb/sv0/conf/ecapa_tdnn_small.yaml

@@ -2,7 +2,7 @@
 ###########################################
 # Data #
 ###########################################
 augment: True
-batch_size: 16
+batch_size: 32
 num_workers: 2
 num_speakers: 1211 # 1211 vox1, 5994 vox2, 7205 vox1+2, test speakers: 41
 shuffle: True

@@ -42,7 +42,14 @@ epochs: 100
 save_interval: 10
 log_interval: 10
 learning_rate: 1e-8
+max_lr: 1e-3
+step_size: 140000
+###########################################
+# loss #
+###########################################
+margin: 0.2
+scale: 30
 ###########################################
 # Testing #

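max_lr and step_size are consumed by CyclicLRScheduler in train.py (see that diff below), which moves the learning rate between learning_rate and max_lr. Assuming the standard triangular cyclic policy (the scheduler's exact shape is not shown in this diff), the schedule can be sketched as:

import numpy as np

def triangular_lr(step, base_lr=1e-8, max_lr=1e-3, step_size=140000):
    """Triangular cyclic LR: climb from base_lr to max_lr over step_size
    optimizer steps, fall back over the next step_size, then repeat."""
    cycle = np.floor(1 + step / (2 * step_size))
    x = np.abs(step / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * np.maximum(0.0, 1.0 - x)

# step 0 -> base_lr, step 140000 -> max_lr, step 280000 -> back to base_lr
print(triangular_lr(0), triangular_lr(140000), triangular_lr(280000))
# train.py passes step_size=config.step_size // nranks, so with several GPUs the
# cycle finishes in proportionally fewer optimizer steps (roughly the same data).
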
paddlespeech/vector/exps/ecapa_tdnn/test.py

@@ -38,10 +38,10 @@ def compute_dataset_embedding(data_loader, model, mean_var_norm_emb, config,
     """compute the dataset embeddings
     Args:
-        data_loader (_type_): _description_
-        model (_type_): _description_
-        mean_var_norm_emb (_type_): _description_
-        config (_type_): _description_
+        data_loader (paddle.io.Dataloader): the dataset loader to be compute the embedding
+        model (paddle.nn.Layer): the speaker verification model
+        mean_var_norm_emb: compute the embedding mean and std norm
+        config (yacs.config.CfgNode): the yaml config
     """
     logger.info(
         f'Computing embeddings on {data_loader.dataset.csv_path} dataset')

@@ -65,6 +65,17 @@ def compute_dataset_embedding(data_loader, model, mean_var_norm_emb, config,
 def compute_verification_scores(id2embedding, train_cohort, config):
+    """Compute the verification trial scores
+    Args:
+        id2embedding (dict): the utterance embedding
+        train_cohort (paddle.tensor): the cohort dataset embedding
+        config (yacs.config.CfgNode): the yaml config
+    Returns:
+        the scores and the trial labels,
+        1 refers the target and 0 refers the nontarget in labels
+    """
     labels = []
     enroll_ids = []
     test_ids = []

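The new docstring pins down the contract of compute_verification_scores: it pairs enroll and test embeddings from id2embedding over the trial list and returns cosine scores plus 1/0 target labels. A minimal sketch of that kind of cosine trial scoring, with a hypothetical trial format and embedding size (the repo reads its own trial file, not this one):

import numpy as np

def cosine_score(emb1, emb2):
    """Plain cosine similarity between two embedding vectors."""
    return float(np.dot(emb1, emb2) /
                 (np.linalg.norm(emb1) * np.linalg.norm(emb2) + 1e-12))

# Hypothetical trials: (label, enroll_utt_id, test_utt_id), 1 = same speaker.
trials = [(1, "enroll-0001", "test-0007"), (0, "enroll-0001", "test-0042")]
id2embedding = {k: np.random.randn(192) for _, e, t in trials for k in (e, t)}

labels, scores = [], []
for label, enroll_id, test_id in trials:
    labels.append(label)
    scores.append(cosine_score(id2embedding[enroll_id], id2embedding[test_id]))
print(labels, scores)
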
@@ -119,20 +130,32 @@ def compute_verification_scores(id2embedding, train_cohort, config):
 def main(args, config):
+    """The main process for test the speaker verification model
+    Args:
+        args (argparse.Namespace): the command line args namespace
+        config (yacs.config.CfgNode): the yaml config
+    """
     # stage0: set the training device, cpu or gpu
+    # if set the gpu, paddlespeech will select a gpu according the env CUDA_VISIBLE_DEVICES
     paddle.set_device(args.device)
-    # set the random seed, it is a must for multiprocess training
+    # set the random seed, it is the necessary measures for multiprocess training
     seed_everything(config.seed)
     # stage1: build the dnn backbone model network
+    # we will extract the audio embedding from the backbone model
     ecapa_tdnn = EcapaTdnn(**config.model)
     # stage2: build the speaker verification eval instance with backbone model
+    # because the checkpoint dict name has the SpeakerIdetification prefix
+    # so we need to create the SpeakerIdetification instance
+    # but we acutally use the backbone model to extact the audio embedding
     model = SpeakerIdetification(
         backbone=ecapa_tdnn, num_class=config.num_speakers)
     # stage3: load the pre-trained model
-    # we get the last model from the epoch and save_interval
+    # generally, we get the last model from the epoch
     args.load_checkpoint = os.path.abspath(
         os.path.expanduser(args.load_checkpoint))

@@ -143,7 +166,8 @@ def main(args, config):
     logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
     # stage4: construct the enroll and test dataloader
+    # Now, wo think the enroll dataset is in the {args.data_dir}/vox/csv/enroll.csv,
+    # and the test dataset is in the {args.data_dir}/vox/csv/test.csv
     enroll_dataset = CSVDataset(
         os.path.join(args.data_dir, "vox/csv/enroll.csv"),
         feat_type='melspectrogram',

@@ -152,14 +176,14 @@ def main(args, config):
         window_size=config.window_size,
         hop_length=config.hop_size)
     enroll_sampler = BatchSampler(
         enroll_dataset, batch_size=config.batch_size,
-        shuffle=False)
+        shuffle=False)  # Shuffle to make embedding normalization more robust.
     enroll_loader = DataLoader(enroll_dataset,
         batch_sampler=enroll_sampler,
         collate_fn=lambda x: batch_feature_normalize(
             x, mean_norm=True, std_norm=False),
         num_workers=config.num_workers,
         return_list=True,)
     test_dataset = CSVDataset(
         os.path.join(args.data_dir, "vox/csv/test.csv"),
         feat_type='melspectrogram',

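batch_feature_normalize itself is not part of this diff; judging only from the call site (mean_norm=True, std_norm=False) and from the relative-length convention described in the InputNormalization docstring further down, a toy stand-in for such a collate function might look like the following. Everything here, including the return format, is an assumption for illustration.

import numpy as np

def toy_batch_feature_normalize(feats, mean_norm=True, std_norm=False):
    """Pad a list of (frames, n_mels) features into one batch and normalize
    each utterance with its own statistics. Purely illustrative."""
    max_len = max(f.shape[0] for f in feats)
    batch = np.zeros((len(feats), max_len, feats[0].shape[1]), dtype="float32")
    lengths = np.zeros(len(feats), dtype="float32")
    for i, f in enumerate(feats):
        if mean_norm:
            f = f - f.mean(axis=0, keepdims=True)
        if std_norm:
            f = f / (f.std(axis=0, keepdims=True) + 1e-8)
        batch[i, :f.shape[0]] = f
        lengths[i] = f.shape[0] / max_len   # relative length, e.g. [0.7, 1.0]
    return batch, lengths

feats = [np.random.randn(70, 80), np.random.randn(100, 80)]
batch, lengths = toy_batch_feature_normalize(feats)
print(batch.shape, lengths)  # (2, 100, 80) [0.7 1. ]
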
@@ -167,7 +191,6 @@ def main(args, config):
         n_mels=config.n_mels,
         window_size=config.window_size,
         hop_length=config.hop_size)
     test_sampler = BatchSampler(
         test_dataset, batch_size=config.batch_size, shuffle=False)
     test_loader = DataLoader(test_dataset,

@@ -180,16 +203,17 @@ def main(args, config):
     model.eval()
     # stage6: global embedding norm to imporve the performance
+    # and we create the InputNormalization instance to process the embedding mean and std norm
     logger.info(f"global embedding norm: {config.global_embedding_norm}")
-    # stage7: Compute embeddings of audios in enrol and test dataset from model.
     if config.global_embedding_norm:
         mean_var_norm_emb = InputNormalization(
             norm_type="global",
             mean_norm=config.embedding_mean_norm,
             std_norm=config.embedding_std_norm)
+    # stage 7: score norm need the imposters dataset
+    # we select the train dataset as the idea imposters dataset
+    # and we select the config.n_train_snts utterance to as the final imposters dataset
     if "score_norm" in config:
         logger.info(f"we will do score norm: {config.score_norm}")
         train_dataset = CSVDataset(

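The new comments explain that the training set acts as the imposter cohort for score normalization (train_cohort is stacked from config.n_train_snts training embeddings further down). The exact normalization inside compute_verification_scores is not shown in this diff; as a hedged sketch of the usual symmetric score normalization (S-norm) of one trial against a cohort:

import numpy as np

def snorm(raw_score, enroll_emb, test_emb, cohort):
    """Symmetric score normalization of a single cosine trial score against a
    cohort of imposter embeddings. Illustrative only."""
    def cos(a, b):
        return a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12)

    enroll_vs_cohort = np.array([cos(enroll_emb, c) for c in cohort])
    test_vs_cohort = np.array([cos(test_emb, c) for c in cohort])
    z = (raw_score - enroll_vs_cohort.mean()) / (enroll_vs_cohort.std() + 1e-12)
    t = (raw_score - test_vs_cohort.mean()) / (test_vs_cohort.std() + 1e-12)
    return 0.5 * (z + t)

cohort = np.random.randn(300, 192)   # e.g. n_train_snts imposter embeddings
enroll, test = np.random.randn(192), np.random.randn(192)
raw = enroll @ test / (np.linalg.norm(enroll) * np.linalg.norm(test))
print(snorm(raw, enroll, test, cohort))
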
@@ -209,6 +233,7 @@ def main(args, config):
         num_workers=config.num_workers,
         return_list=True,)
+    # stage 8: Compute embeddings of audios in enrol and test dataset from model.
     id2embedding = {}
     # Run multi times to make embedding normalization more stable.
     logger.info("First loop for enroll and test dataset")

@@ -225,7 +250,7 @@ def main(args, config):
         mean_var_norm_emb.save(
             os.path.join(args.load_checkpoint, "mean_var_norm_emb"))
-    # stage 8: Compute cosine scores.
+    # stage 9: Compute cosine scores.
     train_cohort = None
     if "score_norm" in config:
         train_embeddings = {}

@@ -234,11 +259,11 @@ def main(args, config):
             train_embeddings)
         train_cohort = paddle.stack(list(train_embeddings.values()))
-    # compute the scores
+    # stage 10: compute the scores
     scores, labels = compute_verification_scores(id2embedding, train_cohort,
                                                  config)
-    # compute the EER and threshold
+    # stage 11: compute the EER and threshold
     scores = paddle.to_tensor(scores)
     EER, threshold = compute_eer(np.asarray(labels), scores.numpy())
     logger.info(

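compute_eer itself is outside this diff; one common way to obtain the equal error rate from trial labels and scores uses the ROC curve and picks the point where the false-acceptance and false-rejection rates meet. A sketch using scikit-learn (the repo's own implementation may differ):

import numpy as np
from sklearn.metrics import roc_curve

def compute_eer_sketch(labels, scores):
    """Equal error rate: the operating point where the false-acceptance rate
    equals the false-rejection rate. Returns (eer, threshold)."""
    fpr, tpr, thresholds = roc_curve(labels, scores, pos_label=1)
    fnr = 1.0 - tpr
    idx = np.nanargmin(np.abs(fnr - fpr))
    eer = (fpr[idx] + fnr[idx]) / 2.0
    return eer, thresholds[idx]

labels = np.array([1, 1, 0, 0, 1, 0])
scores = np.array([0.82, 0.70, 0.40, 0.55, 0.65, 0.30])
print(compute_eer_sketch(labels, scores))
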
paddlespeech/vector/exps/ecapa_tdnn/train.py

@@ -42,6 +42,12 @@ logger = Log(__name__).getlog()
 def main(args, config):
+    """The main process for test the speaker verification model
+    Args:
+        args (argparse.Namespace): the command line args namespace
+        config (yacs.config.CfgNode): the yaml config
+    """
     # stage0: set the training device, cpu or gpu
     paddle.set_device(args.device)

@@ -49,11 +55,11 @@ def main(args, config):
     paddle.distributed.init_parallel_env()
     nranks = paddle.distributed.get_world_size()
     local_rank = paddle.distributed.get_rank()
-    # set the random seed, it is a must for multiprocess training
+    # set the random seed, it is the necessary measures for multiprocess training
     seed_everything(config.seed)
     # stage2: data prepare, such vox1 and vox2 data, and augment noise data and pipline
-    # note: some cmd must do in rank==0, so wo will refactor the data prepare code
+    # note: some operations must be done in rank==0
     train_dataset = CSVDataset(
         csv_path=os.path.join(args.data_dir, "vox/csv/train.csv"),
         label2id_path=os.path.join(args.data_dir, "vox/meta/label2id.txt"))

@@ -61,12 +67,14 @@ def main(args, config):
         csv_path=os.path.join(args.data_dir, "vox/csv/dev.csv"),
         label2id_path=os.path.join(args.data_dir, "vox/meta/label2id.txt"))
+    # we will build the augment pipeline process list
     if config.augment:
         augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
     else:
         augment_pipeline = []
     # stage3: build the dnn backbone model network
+    # in speaker verification period, we use the backbone mode to extract the audio embedding
     ecapa_tdnn = EcapaTdnn(**config.model)
     # stage4: build the speaker verification train instance with backbone model

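seed_everything is defined elsewhere in paddlespeech and is not part of this diff; the note above only records why it is called on every rank. A typical helper of this kind, sketched here as an assumption about what such a function usually does, seeds every RNG so all processes start from the same state:

import os
import random

import numpy as np
import paddle

def seed_everything_sketch(seed: int):
    """Seed the Python, NumPy and Paddle RNGs (a rough stand-in for
    paddlespeech's own seed_everything; the real helper may do more)."""
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

seed_everything_sketch(2022)
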
@@ -77,13 +85,15 @@ def main(args, config):
     # 140000 is single gpu steps
     # so, in multi-gpu mode, wo reduce the step_size to 140000//nranks to enable CyclicLRScheduler
     lr_schedule = CyclicLRScheduler(
-        base_lr=config.learning_rate, max_lr=1e-3, step_size=140000 // nranks)
+        base_lr=config.learning_rate, max_lr=config.max_lr, step_size=config.step_size // nranks)
     optimizer = paddle.optimizer.AdamW(
         learning_rate=lr_schedule, parameters=model.parameters())
     # stage6: build the loss function, we now only support LogSoftmaxWrapper
     criterion = LogSoftmaxWrapper(
-        loss_fn=AdditiveAngularMargin(margin=0.2, scale=30))
+        loss_fn=AdditiveAngularMargin(margin=config.margin, scale=config.scale))
     # stage7: confirm training start epoch
     # if pre-trained model exists, start epoch confirmed by the pre-trained model

@@ -225,7 +235,7 @@ def main(args, config):
     print_msg += ' avg_train_cost: {:.5f} sec,'.format(
         train_run_cost / config.log_interval)
-    print_msg += ' lr={:.4E} step/sec={:.2f} ips: {:.5f}| ETA {}'.format(
+    print_msg += ' lr={:.4E} step/sec={:.2f} ips={:.5f}| ETA {}'.format(
         lr, timer.timing, timer.ips, timer.eta)
     logger.info(print_msg)

paddlespeech/vector/io/embedding_norm.py

@@ -57,14 +57,14 @@ class InputNormalization:
             lengths (paddle.Tensor): A batch of tensors containing the relative length of each
                 sentence (e.g, [0.7, 0.9, 1.0]). It is used to avoid
                 computing stats on zero-padded steps.
-            spk_ids (_type_, optional): tensor containing the ids of each speaker (e.g, [0 10 6]).
+            spk_ids (paddle.Tensor, optional): tensor containing the ids of each speaker (e.g, [0 10 6]).
                 It is used to perform per-speaker normalization when
                 norm_type='speaker'. Defaults to paddle.to_tensor([], dtype="float32").
         Returns:
             paddle.Tensor: The normalized feature or embedding
         """
         N_batches = x.shape[0]
+        # print(f"x shape: {x.shape[1]}")
         current_means = []
         current_stds = []

@@ -75,6 +75,9 @@ class InputNormalization:
             actual_size = paddle.round(lengths[snt_id] *
                                        x.shape[1]).astype("int32")
             # computing actual time data statistics
+            # we extract the snt_id embedding from the x
+            # and the target paddle.Tensor will reduce an 0-axis
+            # so we need unsqueeze operation to recover the all axis
             current_mean, current_std = self._compute_current_stats(
                 x[snt_id, 0:actual_size, ...].unsqueeze(0))
             current_means.append(current_mean)

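The comments added in the second hunk explain why the slice is unsqueezed: indexing one utterance out of the batch drops the 0-axis, and _compute_current_stats expects a batch dimension. A quick Paddle sketch of that shape behaviour (the tensor sizes here are made up):

import paddle

x = paddle.randn([4, 120, 80])      # (batch, frames, feat), sizes are illustrative
actual_size = 84                    # e.g. round(lengths[snt_id] * 120)
one_utt = x[1, 0:actual_size, ...]  # indexing drops the 0-axis -> shape [84, 80]
recovered = one_utt.unsqueeze(0)    # back to a batch of one -> shape [1, 84, 80]
print(one_utt.shape, recovered.shape)
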