提交 584a2c0e 编写于 作者: X xiongxinlei

add ecapa-tdnn config yaml file

上级 993d6783
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
# currently, we only support fbank
feature:
n_mels: 80
window_size: 400 #25ms, sample rate 16000, 25 * 16000 / 1000 = 400
hop_length: 160 #10ms, sample rate 16000, 10 * 16000 / 1000 = 160
###########################################################
# MODEL SETTING #
###########################################################
# currently, we only support ecapa-tdnn in the ecapa_tdnn.yaml
# if we want use another model, please choose another configuration yaml file
model:
input_size: 80
##"channels": [1024, 1024, 1024, 1024, 3072],
# "channels": [512, 512, 512, 512, 1536],
channels: [512, 512, 512, 512, 1536]
kernel_sizes: [5, 3, 3, 3, 1]
dilations: [1, 2, 3, 4, 1]
attention_channels: 128
lin_neurons: 192
###########################################
# Training #
###########################################
seed: 0
epochs: 10
batch_size: 32
num_workers: 2
save_freq: 10
log_freq: 10
learning_rate: 1e-8
......@@ -31,20 +31,22 @@ if [ $stage -le 1 ]; then
python3 \
-m paddle.distributed.launch --gpus=0,1,2,3 \
${BIN_DIR}/train.py --device "gpu" --checkpoint-dir ${exp_dir} --augment \
--save-freq 10 --data-dir ${dir} --batch-size 64 --epochs 100
--data-dir ${dir} --config conf/ecapa_tdnn.yaml
fi
if [ $stage -le 2 ]; then
# stage 1: get the speaker verification scores with cosine function
python3 \
${BIN_DIR}/speaker_verification_cosine.py\
--batch-size 4 --data-dir ${dir} --load-checkpoint ${exp_dir}/epoch_10/
--config conf/ecapa_tdnn.yaml \
--data-dir ${dir} --load-checkpoint ${exp_dir}/epoch_10/
fi
if [ $stage -le 3 ]; then
# stage 3: extract the audio embedding
python3 \
${BIN_DIR}/extract_speaker_embedding.py\
--config conf/ecapa_tdnn.yaml \
--audio-path "demo/csv/00001.wav" --load-checkpoint ${exp_dir}/epoch_60/
fi
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import numpy as np
import paddle
from yacs.config import CfgNode
from paddleaudio.paddleaudio.backends import load as load_audio
from paddleaudio.paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything
logger = Log(__name__).getlog()
def extract_audio_embedding(args, config):
# stage 0: set the training device, cpu or gpu
paddle.set_device(args.device)
# set the random seed, it is a must for multiprocess training
seed_everything(config.seed)
# stage 1: build the dnn backbone model network
ecapa_tdnn = EcapaTdnn(**config.model)
# stage4: build the speaker verification train instance with backbone model
model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=1211)
# stage 2: load the pre-trained model
args.load_checkpoint = os.path.abspath(
os.path.expanduser(args.load_checkpoint))
# load model checkpoint to sid model
state_dict = paddle.load(
os.path.join(args.load_checkpoint, 'model.pdparams'))
model.set_state_dict(state_dict)
logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
# stage 3: we must set the model to eval mode
model.eval()
# stage 4: read the audio data and extract the embedding
# wavform is one dimension numpy array
waveform, sr = load_audio(args.audio_path)
# feat type is numpy array, whose shape is [dim, time]
# we need convert the audio feat to one-batch shape [batch, dim, time], where the batch is one
# so the final shape is [1, dim, time]
feat = melspectrogram(x=waveform, **config.feature)
feat = paddle.to_tensor(feat).unsqueeze(0)
# in inference period, the lengths is all one without padding
lengths = paddle.ones([1])
feat = feature_normalize(
feat, mean_norm=True, std_norm=False, convert_to_numpy=True)
# model backbone network forward the feats and get the embedding
embedding = model.backbone(
feat, lengths).squeeze().numpy() # (1, emb_size, 1) -> (emb_size)
# stage 5: do global norm with external mean and std
# todo
return embedding
if __name__ == "__main__":
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument('--device',
choices=['cpu', 'gpu'],
default="gpu",
help="Select which device to train model, defaults to gpu.")
parser.add_argument("--config",
default=None,
type=str,
help="configuration file")
parser.add_argument("--load-checkpoint",
type=str,
default='',
help="Directory to load model checkpoint to contiune trainning.")
parser.add_argument("--global-embedding-norm",
type=str,
default=None,
help="Apply global normalization on speaker embeddings.")
parser.add_argument("--audio-path",
default="./data/demo.wav",
type=str,
help="Single audio file path")
args = parser.parse_args()
# yapf: enable
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
config.freeze()
print(config)
extract_audio_embedding(args, config)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
import numpy as np
import paddle
from yacs.config import CfgNode
import paddle.nn.functional as F
from paddle.io import BatchSampler
from paddle.io import DataLoader
from tqdm import tqdm
from paddleaudio.paddleaudio.datasets import VoxCeleb1
from paddlespeech.s2t.utils.log import Log
from paddleaudio.paddleaudio.metric import compute_eer
from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything
logger = Log(__name__).getlog()
def main(args, config):
# stage0: set the training device, cpu or gpu
paddle.set_device(args.device)
# set the random seed, it is a must for multiprocess training
seed_everything(config.seed)
# stage1: build the dnn backbone model network
ecapa_tdnn = EcapaTdnn(**config.model)
# stage2: build the speaker verification eval instance with backbone model
model = SpeakerIdetification(
backbone=ecapa_tdnn, num_class=VoxCeleb1.num_speakers)
# stage3: load the pre-trained model
args.load_checkpoint = os.path.abspath(
os.path.expanduser(args.load_checkpoint))
# load model checkpoint to sid model
state_dict = paddle.load(
os.path.join(args.load_checkpoint, 'model.pdparams'))
model.set_state_dict(state_dict)
logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
# stage4: construct the enroll and test dataloader
enroll_dataset = VoxCeleb1(
subset='enroll',
target_dir=args.data_dir,
feat_type='melspectrogram',
random_chunk=False,
**config.feature)
enroll_sampler = BatchSampler(
enroll_dataset, batch_size=config.batch_size,
shuffle=True) # Shuffle to make embedding normalization more robust.
enrol_loader = DataLoader(enroll_dataset,
batch_sampler=enroll_sampler,
collate_fn=lambda x: batch_feature_normalize(
x, mean_norm=True, std_norm=False),
num_workers=config.num_workers,
return_list=True,)
test_dataset = VoxCeleb1(
subset='test',
target_dir=args.data_dir,
feat_type='melspectrogram',
random_chunk=False,
**config.feature)
test_sampler = BatchSampler(
test_dataset, batch_size=config.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset,
batch_sampler=test_sampler,
collate_fn=lambda x: batch_feature_normalize(
x, mean_norm=True, std_norm=False),
num_workers=config.num_workers,
return_list=True,)
# stage6: we must set the model to eval mode
model.eval()
# stage7: global embedding norm to imporve the performance
if args.global_embedding_norm:
global_embedding_mean = None
global_embedding_std = None
mean_norm_flag = args.embedding_mean_norm
std_norm_flag = args.embedding_std_norm
batch_count = 0
# stage8: Compute embeddings of audios in enrol and test dataset from model.
id2embedding = {}
# Run multi times to make embedding normalization more stable.
for i in range(2):
for dl in [enrol_loader, test_loader]:
logger.info(
f'Loop {[i+1]}: Computing embeddings on {dl.dataset.subset} dataset'
)
with paddle.no_grad():
for batch_idx, batch in enumerate(tqdm(dl)):
# stage 8-1: extrac the audio embedding
ids, feats, lengths = batch['ids'], batch['feats'], batch[
'lengths']
embeddings = model.backbone(feats, lengths).squeeze(
-1).numpy() # (N, emb_size, 1) -> (N, emb_size)
# Global embedding normalization.
if args.global_embedding_norm:
batch_count += 1
current_mean = embeddings.mean(
axis=0) if mean_norm_flag else 0
current_std = embeddings.std(
axis=0) if std_norm_flag else 1
# Update global mean and std.
if global_embedding_mean is None and global_embedding_std is None:
global_embedding_mean, global_embedding_std = current_mean, current_std
else:
weight = 1 / batch_count # Weight decay by batches.
global_embedding_mean = (
1 - weight
) * global_embedding_mean + weight * current_mean
global_embedding_std = (
1 - weight
) * global_embedding_std + weight * current_std
# Apply global embedding normalization.
embeddings = (embeddings - global_embedding_mean
) / global_embedding_std
# Update embedding dict.
id2embedding.update(dict(zip(ids, embeddings)))
# stage 9: Compute cosine scores.
labels = []
enrol_ids = []
test_ids = []
with open(VoxCeleb1.veri_test_file, 'r') as f:
for line in f.readlines():
label, enrol_id, test_id = line.strip().split(' ')
labels.append(int(label))
enrol_ids.append(enrol_id.split('.')[0].replace('/', '-'))
test_ids.append(test_id.split('.')[0].replace('/', '-'))
cos_sim_func = paddle.nn.CosineSimilarity(axis=1)
enrol_embeddings, test_embeddings = map(lambda ids: paddle.to_tensor(
np.asarray([id2embedding[id] for id in ids], dtype='float32')),
[enrol_ids, test_ids
]) # (N, emb_size)
scores = cos_sim_func(enrol_embeddings, test_embeddings)
EER, threshold = compute_eer(np.asarray(labels), scores.numpy())
logger.info(
f'EER of verification test: {EER*100:.4f}%, score threshold: {threshold:.5f}'
)
if __name__ == "__main__":
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument('--device',
choices=['cpu', 'gpu'],
default="gpu",
help="Select which device to train model, defaults to gpu.")
parser.add_argument("--config",
default=None,
type=str,
help="configuration file")
parser.add_argument("--data-dir",
default="./data/",
type=str,
help="data directory")
parser.add_argument("--load-checkpoint",
type=str,
default='',
help="Directory to load model checkpoint to contiune trainning.")
parser.add_argument("--global-embedding-norm",
type=bool,
default=True,
help="Apply global normalization on speaker embeddings.")
parser.add_argument("--embedding-mean-norm",
type=bool,
default=True,
help="Apply mean normalization on speaker embeddings.")
parser.add_argument("--embedding-std-norm",
type=bool,
default=False,
help="Apply std normalization on speaker embeddings.")
args = parser.parse_args()
# yapf: enable
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
config.freeze()
print(config)
main(args, config)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import numpy as np
import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from yacs.config import CfgNode
from paddleaudio.paddleaudio.compliance.librosa import melspectrogram
from paddleaudio.paddleaudio.datasets.voxceleb import VoxCeleb1
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.augment import build_augment_pipeline
from paddlespeech.vector.io.augment import waveform_augment
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.io.batch import waveform_collate_fn
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.loss import AdditiveAngularMargin
from paddlespeech.vector.modules.loss import LogSoftmaxWrapper
from paddlespeech.vector.training.scheduler import CyclicLRScheduler
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything
from paddlespeech.vector.utils.time import Timer
logger = Log(__name__).getlog()
def main(args, config):
# stage0: set the training device, cpu or gpu
paddle.set_device(args.device)
# stage1: we must call the paddle.distributed.init_parallel_env() api at the begining
paddle.distributed.init_parallel_env()
nranks = paddle.distributed.get_world_size()
local_rank = paddle.distributed.get_rank()
# set the random seed, it is a must for multiprocess training
seed_everything(config.seed)
# stage2: data prepare, such vox1 and vox2 data, and augment noise data and pipline
# note: some cmd must do in rank==0, so wo will refactor the data prepare code
train_dataset = VoxCeleb1('train', target_dir=args.data_dir)
dev_dataset = VoxCeleb1('dev', target_dir=args.data_dir)
if args.augment:
augment_pipeline = build_augment_pipeline(target_dir=args.data_dir)
else:
augment_pipeline = []
# stage3: build the dnn backbone model network
ecapa_tdnn = EcapaTdnn(**config.model)
# stage4: build the speaker verification train instance with backbone model
model = SpeakerIdetification(
backbone=ecapa_tdnn, num_class=VoxCeleb1.num_speakers)
# stage5: build the optimizer, we now only construct the AdamW optimizer
lr_schedule = CyclicLRScheduler(
base_lr=config.learning_rate, max_lr=1e-3, step_size=140000 // nranks)
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_schedule, parameters=model.parameters())
# stage6: build the loss function, we now only support LogSoftmaxWrapper
criterion = LogSoftmaxWrapper(
loss_fn=AdditiveAngularMargin(margin=0.2, scale=30))
# stage7: confirm training start epoch
# if pre-trained model exists, start epoch confirmed by the pre-trained model
start_epoch = 0
if args.load_checkpoint:
logger.info("load the check point")
args.load_checkpoint = os.path.abspath(
os.path.expanduser(args.load_checkpoint))
try:
# load model checkpoint
state_dict = paddle.load(
os.path.join(args.load_checkpoint, 'model.pdparams'))
model.set_state_dict(state_dict)
# load optimizer checkpoint
state_dict = paddle.load(
os.path.join(args.load_checkpoint, 'model.pdopt'))
optimizer.set_state_dict(state_dict)
if local_rank == 0:
logger.info(f'Checkpoint loaded from {args.load_checkpoint}')
except FileExistsError:
if local_rank == 0:
logger.info('Train from scratch.')
try:
start_epoch = int(args.load_checkpoint[-1])
logger.info(f'Restore training from epoch {start_epoch}.')
except ValueError:
pass
# stage8: we build the batch sampler for paddle.DataLoader
train_sampler = DistributedBatchSampler(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
drop_last=False)
train_loader = DataLoader(
train_dataset,
batch_sampler=train_sampler,
num_workers=config.num_workers,
collate_fn=waveform_collate_fn,
return_list=True,
use_buffer_reader=True, )
# stage9: start to train
# we will comment the training process
steps_per_epoch = len(train_sampler)
timer = Timer(steps_per_epoch * config.epochs)
timer.start()
for epoch in range(start_epoch + 1, config.epochs + 1):
# at the begining, model must set to train mode
model.train()
avg_loss = 0
num_corrects = 0
num_samples = 0
for batch_idx, batch in enumerate(train_loader):
# stage 9-1: batch data is audio sample points and speaker id label
waveforms, labels = batch['waveforms'], batch['labels']
# stage 9-2: audio sample augment method, which is done on the audio sample point
if len(augment_pipeline) != 0:
waveforms = waveform_augment(waveforms, augment_pipeline)
labels = paddle.concat(
[labels for i in range(len(augment_pipeline) + 1)])
# stage 9-3: extract the audio feats,such fbank, mfcc, spectrogram
feats = []
for waveform in waveforms.numpy():
feat = melspectrogram(x=waveform, **config.feature)
feats.append(feat)
feats = paddle.to_tensor(np.asarray(feats))
# stage 9-4: feature normalize, which help converge and imporve the performance
feats = feature_normalize(
feats, mean_norm=True, std_norm=False) # Features normalization
# stage 9-5: model forward, such ecapa-tdnn, x-vector
logits = model(feats)
# stage 9-6: loss function criterion, such AngularMargin, AdditiveAngularMargin
loss = criterion(logits, labels)
# stage 9-7: update the gradient and clear the gradient cache
loss.backward()
optimizer.step()
if isinstance(optimizer._learning_rate,
paddle.optimizer.lr.LRScheduler):
optimizer._learning_rate.step()
optimizer.clear_grad()
# stage 9-8: Calculate average loss per batch
avg_loss += loss.numpy()[0]
# stage 9-9: Calculate metrics, which is one-best accuracy
preds = paddle.argmax(logits, axis=1)
num_corrects += (preds == labels).numpy().sum()
num_samples += feats.shape[0]
timer.count() # step plus one in timer
# stage 9-10: print the log information only on 0-rank per log-freq batchs
if (batch_idx + 1) % config.log_freq == 0 and local_rank == 0:
lr = optimizer.get_lr()
avg_loss /= config.log_freq
avg_acc = num_corrects / num_samples
print_msg = 'Train Epoch={}/{}, Step={}/{}'.format(
epoch, config.epochs, batch_idx + 1, steps_per_epoch)
print_msg += ' loss={:.4f}'.format(avg_loss)
print_msg += ' acc={:.4f}'.format(avg_acc)
print_msg += ' lr={:.4E} step/sec={:.2f} | ETA {}'.format(
lr, timer.timing, timer.eta)
logger.info(print_msg)
avg_loss = 0
num_corrects = 0
num_samples = 0
# stage 9-11: save the model parameters only on 0-rank per save-freq batchs
if epoch % config.save_freq == 0 and batch_idx + 1 == steps_per_epoch:
if local_rank != 0:
paddle.distributed.barrier(
) # Wait for valid step in main process
continue # Resume trainning on other process
# stage 9-12: construct the valid dataset dataloader
dev_sampler = BatchSampler(
dev_dataset,
batch_size=config.batch_size // 4,
shuffle=False,
drop_last=False)
dev_loader = DataLoader(
dev_dataset,
batch_sampler=dev_sampler,
collate_fn=waveform_collate_fn,
num_workers=config.num_workers,
return_list=True, )
# set the model to eval mode
model.eval()
num_corrects = 0
num_samples = 0
# stage 9-13: evaluation the valid dataset batch data
logger.info('Evaluate on validation dataset')
with paddle.no_grad():
for batch_idx, batch in enumerate(dev_loader):
waveforms, labels = batch['waveforms'], batch['labels']
feats = []
for waveform in waveforms.numpy():
# feat = melspectrogram(x=waveform, **cpu_feat_conf)
feat = melspectrogram(x=waveform, **config.feature)
feats.append(feat)
feats = paddle.to_tensor(np.asarray(feats))
feats = feature_normalize(
feats, mean_norm=True, std_norm=False)
logits = model(feats)
preds = paddle.argmax(logits, axis=1)
num_corrects += (preds == labels).numpy().sum()
num_samples += feats.shape[0]
print_msg = '[Evaluation result]'
print_msg += ' dev_acc={:.4f}'.format(num_corrects / num_samples)
logger.info(print_msg)
# stage 9-14: Save model parameters
save_dir = os.path.join(args.checkpoint_dir,
'epoch_{}'.format(epoch))
logger.info('Saving model checkpoint to {}'.format(save_dir))
paddle.save(model.state_dict(),
os.path.join(save_dir, 'model.pdparams'))
paddle.save(optimizer.state_dict(),
os.path.join(save_dir, 'model.pdopt'))
if nranks > 1:
paddle.distributed.barrier() # Main process
if __name__ == "__main__":
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument('--device',
choices=['cpu', 'gpu'],
default="cpu",
help="Select which device to train model, defaults to gpu.")
parser.add_argument("--config",
default=None,
type=str,
help="configuration file")
parser.add_argument("--data-dir",
default="./data/",
type=str,
help="data directory")
parser.add_argument("--load-checkpoint",
type=str,
default=None,
help="Directory to load model checkpoint to contiune trainning.")
parser.add_argument("--checkpoint-dir",
type=str,
default='./checkpoint',
help="Directory to save model checkpoints.")
parser.add_argument("--augment",
action="store_true",
default=False,
help="Apply audio augments.")
args = parser.parse_args()
# yapf: enable
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
config.freeze()
print(config)
main(args, config)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册