Unverified commit 888b3caf, authored by R ranchlai, committed by GitHub

Add Speaker model examples (#5337)

* added speaker

* update readme
Parent 9cab6c61
repos:
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
rev: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks:
- id: yapf
files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: a11d9314b22d8f8c7556443875b731ef05965464
hooks:
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
files: (?!.*paddle)^.*$
- id: end-of-file-fixer
files: \.md$
- id: trailing-whitespace
files: \.md$
- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.0.1
hooks:
- id: forbid-crlf
files: \.md$
- id: remove-crlf
files: \.md$
- id: forbid-tabs
files: \.md$
- id: remove-tabs
files: \.md$
- repo: local
hooks:
- id: clang-format
name: clang-format
description: Format files with ClangFormat
entry: bash .clang_format.hook -i
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
[style]
based_on_style = pep8
column_limit = 80
# Speaker verification using ResnetSE and ECAPA-TDNN
## Introduction
In this example, we demonstrate how to use PaddleAudio to train two types of networks for speaker verification.
The networks supported here are
- Resnet34 with Squeeze-and-excite block \[1\] to adaptively re-weight the feature maps.
- ECAPA-TDNN \[2\]
## Requirements
Install the requirements via
``` bash
# install paddleaudio
git clone https://github.com/PaddlePaddle/models.git
cd models/PaddleAudio
pip install -e .
```
Then install additional requirements by
``` bash
cd examples/speaker
pip install -r requirements.txt
```
## Training
### Training datasets
Following common practice, we use the dev split of [VoxCeleb 1](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html), which consists of `1,211` speakers, and the dev split of [VoxCeleb 2](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox2.html), which consists of `5,994` speakers, for training. There are therefore `7,205` speakers in total in our training set.
Please download the two datasets from the [official website](https://www.robots.ox.ac.uk/~vgg/data/voxceleb) and unzip all audio into a folder, e.g., `./data/voxceleb/`. Make sure there are `7,205` speaker subfolders (each named with an `id****` prefix) under the folder. You don't need to process the data any further, because all processing such as adding noise / reverberation / speed perturbation is done on the fly. However, to speed up audio decoding, you can manually convert the m4a files in VoxCeleb 2 to wav format, at the expense of extra storage.
Finally, create a text file that lists all audio files for training:
``` bash
cd ./data/voxceleb/
find `pwd`/ -type f > vox_files.txt
```
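You can sanity-check the resulting list with a few lines of Python (a sketch; the path below is an assumption, and `read_list` in `dataset.py` expects absolute paths of the form `.../<speaker_id>/<session_id>/<utterance>.wav`):
``` python
# Count files and unique speakers in the training list (hypothetical path).
with open('./data/voxceleb/vox_files.txt') as f:
    paths = [line for line in f.read().split('\n') if line.startswith('/')]
# the third-from-last path component is the speaker id (see read_list in dataset.py)
speakers = {p.split('/')[-3] for p in paths}
print(f'{len(paths)} files from {len(speakers)} speakers')  # expect 7,205 speakers
```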
### Augmentation datasets
The following datasets are required for dataset augmentation
- [Room Impulse Response and Noise Database](https://openslr.org/28/)
- [MUSAN](https://openslr.org/17/)
For the RIR dataset, list all audio files under the folder `RIRS_NOISES/simulated_rirs/` in a text file, e.g., `data/rir.list`, and set it as `rir_path` in the `config.yaml` file.
Likewise, you have to set the following fields in the config file for noise augmentation (a helper sketch for generating these list files follows the snippet below):
``` yaml
muse_speech: <musan_split/speech.list> #replace with your actual path
muse_speech_srn_high: 15.0
muse_speech_srn_low: 12.0
muse_music: <musan_split/music.list> #replace with your actual path
muse_music_srn_high: 15.0
muse_music_srn_low: 5.0
muse_noise: <musan_split/noise.list> #replace with your actual path
muse_noise_srn_high: 15
muse_noise_srn_low: 5.0
```
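The list files above can be generated with a short helper, sketched here under the assumption that both datasets were unpacked beneath `./data/` with their default layouts:
``` python
# Sketch: write one path per line, ending with a trailing newline because
# train.py reads these lists via open(...).read().split('\n')[:-1].
import glob

def write_list(pattern, out_file):
    files = sorted(glob.glob(pattern, recursive=True))
    with open(out_file, 'w') as f:
        f.write('\n'.join(files) + '\n')

write_list('./data/RIRS_NOISES/simulated_rirs/**/*.wav', './data/rir.list')
for subset in ('speech', 'music', 'noise'):
    write_list(f'./data/musan_split/{subset}/**/*.wav',
               f'./data/musan_split/{subset}.list')
```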
To train your model from scratch, first create a workspace folder:
``` bash
cd egs
mkdir <your_example>
cd <your_example>
cp ../resnet/config.yaml . #Copy an example config to your workspace
```
Then edit the config file so that all audio files (including those used for data augmentation) can be correctly located. You can also change the training and model hyper-parameters to suit your needs.
Finally start your training by
``` bash
python ../../train.py -c config.yaml -d gpu:0
```
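Training can be resumed from a saved checkpoint with `python ../../train.py -c config.yaml -r <epoch>`, or warm-started from a weights file via `-w <weights.pdparams>`; see the argument parser in `train.py` for the full set of options.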
## Testing
### <a name="test_dataset"></a>Testing datasets
The test split of VoxCeleb 1 is used to measure speaker verification performance during training and after training completes. You will need to download the data and unzip it into a folder, e.g., `./data/voxceleb/test/`.
Then download the text file that lists the utterance pairs to compare, together with the ground-truth labels indicating whether the two utterances come from the same speaker. There are multiple trial lists; we use [veri_test2](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt).
To start testing, first download one of the checkpoints below (ResnetSE or ECAPA-TDNN):
| checkpoint | size | EER |
| --------------- | --------------- | --------------- |
| [ResnetSE34 + SAP + CMSoftmax](https://bj.bcebos.com/paddleaudio/models/speaker/resnetse34_epoch92_eer0.00931.pdparams) | 26MB | 0.93% |
| [ecapa-tdnn + AAMSoftmax](https://bj.bcebos.com/paddleaudio/models/speaker/tdnn_amsoftmax_epoch51_eer0.011.pdparams) | 80MB | 1.10% |
Then prepare the test dataset as described in [Testing datasets](#test_dataset) and set the following paths in the config file:
``` yaml
mean_std_file: ../../data/stat.pd
test_list: ../../data/veri_test2.txt
test_folder: ../../data/voxceleb1/
```
To compute the EER with the ResnetSE checkpoint, run:
``` bash
cd egs/resnet/
python ../../test.py -w <checkpoint path> -c config.yaml -d gpu:0
```
which should give an EER of 0.00931 (0.93%).
For ECAPA-TDNN, run:
``` bash
cd egs/ecapa-tdnn/
python ../../test.py -w <checkpoint path> -c config.yaml -d gpu:0
```
which gives an EER of 0.0105 (1.05%).
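Under the hood, each trial is scored by the cosine similarity of two L2-normalized embeddings (see `get_feature` and `get_score` in `test.py`); a minimal NumPy equivalent:
``` python
import numpy as np

def cosine_score(emb1, emb2):
    # L2-normalize both embeddings; their dot product is then the cosine
    # similarity that metrics.compute_eer thresholds to obtain the EER.
    emb1 = emb1 / np.sqrt(np.sum(emb1 ** 2))
    emb2 = emb2 / np.sqrt(np.sum(emb2 ** 2))
    return float(np.dot(emb1, emb2))
```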
## Results
We compare our results with [voxceleb_trainer](https://github.com/clovaai/voxceleb_trainer).
### Pretrained model of voxceleb_trainer
The test list is `veri_test2.txt`, which can be downloaded here: [VoxCeleb1 (cleaned)](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt).
| model | config | checkpoint | eval frames | EER |
| --------------- | --------------- | --------------- | --------------- | --------------- |
| ResnetSE34 + ASP + softmaxproto| - | [baseline_v2_ap](http://www.robots.ox.ac.uk/~joon/data/baseline_v2_ap.model)|400|1.06%|
| ResnetSE34 + ASP + softmaxproto| - | [baseline_v2_ap](http://www.robots.ox.ac.uk/~joon/data/baseline_v2_ap.model)|all|1.18%|
### This example
| model | config | checkpoint | eval frames | EER |
| --------------- | --------------- | --------------- | --------------- | --------------- |
| ResnetSE34 + SAP + CMSoftmax | [config.yaml](./egs/resnet/config.yaml) | [checkpoint](https://bj.bcebos.com/paddleaudio/models/speaker/resnetse34_epoch92_eer0.00931.pdparams) | all | 0.93% |
| ECAPA-TDNN + AAMSoftmax | [config.yaml](./egs/ecapa-tdnn/config.yaml) | [checkpoint](https://bj.bcebos.com/paddleaudio/models/speaker/tdnn_amsoftmax_epoch51_eer0.011.pdparams) | all | 1.10% |
## References
[1] Hu J, Shen L, Sun G. Squeeze-and-excitation networks[C]//Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018: 7132-7141.
[2] Desplanques B, Thienpondt J, Demuynck K. ECAPA-TDNN: Emphasized channel attention, propagation and aggregation in TDNN based speaker verification[J]. arXiv preprint arXiv:2005.07143, 2020.
This diff is collapsed.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import json
import os
import pickle
import random
import subprocess
import time
import warnings
import numpy as np
import paddle
import paddleaudio
import yaml
#from paddle.io import DataLoader, Dataset, IterableDataset
from paddle.utils import download
from paddleaudio.utils import augments, get_logger
logger = get_logger(__file__)
def random_choice(a):
i = np.random.randint(0, high=len(a))
return a[int(i)]
def read_scp(file):
lines = open(file).read().split('\n')
keys = [l.split()[0] for l in lines if l.startswith('id')]
speakers = [l.split()[0].split('-')[0] for l in lines if l.startswith('id')]
files = [l.split()[1] for l in lines if l.startswith('id')]
return keys, speakers, files
def read_list(file):
lines = open(file).read().split('\n')
keys = [
'-'.join(l.split('/')[-3:]).split('.')[0] for l in lines
if l.startswith('/')
]
speakers = [k.split('-')[0] for k in keys]
files = [l for l in lines if l.startswith('/')]
return keys, speakers, files
class Dataset(paddle.io.Dataset):
"""
Dataset class for Audioset, with mel features stored in multiple hdf5 files.
The h5 files store mel-spectrogram features pre-extracted from wav files.
Use wav2mel.py to do feature extraction.
"""
def __init__(self,
scp,
keys=None,
sample_rate=16000,
duration=None,
augment=True,
speaker_set=None,
augment_prob=0.5,
training=True,
balanced_sampling=False):
super(Dataset, self).__init__()
self.keys, self.speakers, self.files = read_list(scp)
self.key2file = {k: f for k, f in zip(self.keys, self.files)}
self.n_files = len(self.files)
if speaker_set:
if isinstance(speaker_set, str):
with open(speaker_set) as f:
self.speaker_set = f.read().split('\n')
print(self.speaker_set[:10])
else:
self.speaker_set = speaker_set
else:
self.speaker_set = list(set(self.speakers))
self.speaker_set.sort()
self.spk2cls = {s: i for i, s in enumerate(self.speaker_set)}
self.n_class = len(self.speaker_set)
logger.info(f'speaker size: {self.n_class}')
logger.info(f'file size: {self.n_files}')
self.augment = augment
self.augment_prob = augment_prob
self.training = training
self.sample_rate = sample_rate
self.balanced_sampling = balanced_sampling
self.duration = duration
if augment:
assert duration, 'if augment is True, duration must not be None'
if self.duration:
self.duration = int(self.sample_rate * self.duration)
if keys is not None:
if isinstance(keys, list):
self.keys = keys
elif isinstance(keys, str):
with open(keys) as f:
self.keys = f.read().split('\n')
self.keys = [k for k in self.keys if k.startswith('id')]
logger.info(f'using {len(self.keys)} keys')
def __getitem__(self, idx):
idx = idx % len(self.keys)
key = self.keys[idx]
spk = key.split('-')[0]
cls_idx = self.spk2cls[spk]
file = self.key2file[key]
file_duration = None
if not self.augment and self.duration:
file_duration = self.duration
        while True:
            try:
                wav, sr = paddleaudio.load(file,
                                           sr=self.sample_rate,
                                           duration=file_duration)
                break
            except Exception:
                logger.error(f'error loading file {file}, sampling another one')
                # avoid retrying the same broken file forever: draw a new index
                idx = np.random.randint(0, len(self.keys))
                key = self.keys[idx]
                spk = key.split('-')[0]
                cls_idx = self.spk2cls[spk]
                file = self.key2file[key]
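        # Speed perturbation as augmentation: resample to 0.9x or 1.1x speed and
        # assign each speaker three distinct labels (cls_idx * 3 + speed), which
        # triples the effective class count (cf. n_classes in the loss config).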
speed = random.choice([0, 1, 2])
if speed == 1:
wav = paddleaudio.resample(wav, 16000, 16000 * 0.9)
cls_idx = cls_idx * 3 + 1
elif speed == 2:
wav = paddleaudio.resample(wav, 16000, 16000 * 1.1)
cls_idx = cls_idx * 3 + 2
else:
cls_idx = cls_idx * 3
if self.augment:
wav = augments.random_crop_or_pad1d(wav, self.duration)
elif self.duration:
wav = augments.center_crop_or_pad1d(wav, self.duration)
return wav, cls_idx
def __len__(self):
return len(self.keys)
def worker_init(worker_id):
time.sleep(worker_id / 32)
seed = int(time.time()) % 10000 + worker_id
np.random.seed(seed)
random.seed(seed)
paddle.seed(seed)
def get_train_loader(config):
dataset = Dataset(config['spk_scp'],
keys=config['train_keys'],
speaker_set=config['speaker_set'],
augment=True,
duration=config['duration'])
train_loader = paddle.io.DataLoader(dataset,
shuffle=True,
batch_size=config['batch_size'],
drop_last=True,
num_workers=config['num_workers'],
use_buffer_reader=True,
use_shared_memory=True,
worker_init_fn=worker_init)
return train_loader
def get_val_loader(config):
dataset = Dataset(config['spk_scp'],
keys=config['val_keys'],
speaker_set=config['speaker_set'],
augment=False,
duration=config['duration'])
val_loader = paddle.io.DataLoader(dataset,
shuffle=False,
batch_size=config['val_batch_size'],
drop_last=False,
num_workers=config['num_workers'])
return val_loader
if __name__ == '__main__':
# do some testing here
with open('config.yaml') as f:
config = yaml.safe_load(f)
train_loader = get_train_loader(config)
# val_loader = get_val_loader(config)
for i, (x, y) in enumerate(train_loader()):
print(x, y)
break
# for i, (x, y) in enumerate(val_loader()):
# print(x, y)
# break
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
fbank:
sr: 16000 # sample rate
n_fft: 512
win_length: 400 #25ms
hop_length: 160 #10ms
n_mels: 80
f_min: 20
f_max: 7600
window: hann
amin: !!float 1e-5
top_db: 75
augment_with_sox: False
augment_mel: False
augment_wav: True
rir_path: ../../data/RIRS_NOISES/rir.list
muse_speech: ../../data/musan_split/speech.list
muse_speech_srn_high: 15.0
muse_speech_srn_low: 12.0
muse_music: ../../data/musan_split/music.list
muse_music_srn_high: 10.0
muse_music_srn_low: 5.0
muse_noise: ../../data/musan_split/noise.list
muse_noise_srn_high: 10
muse_noise_srn_low: 5.0
batch_size: 64
val_batch_size: 16
num_workers: 16
num_classes: 7205
duration: 5
balanced_sampling: False
epoch_num: 500
max_lr: !!float 1e-05
base_lr: !!float 1e-04
reverse_lr: True
half_cycle: 10 # epoch
# set the data path accordingly
spk_scp: ../../data/voxceleb1and2_list.txt
mean_std_file: ../../data/stat.pd
#for testing
test_list: ../../data/veri_test2.txt
test_folder: ../../data/voxceleb1/
speaker_set: ../../data/speaker_set_vox12.txt
train_keys: ~
model_dir : ./checkpoints/
model_prefix: 'tdnn'
log_dir : ./log/
log_file : ./log.txt
log_step: 10
checkpoint_step : 5000
eval_step: 10000
max_time_mask: 2
max_freq_mask: 1
max_time_mask_width: 20
max_freq_mask_width: 10
model:
name: EcapaTDNN
params:
input_size: 80 # should be the same as in the fbank config
normalize: True
loss:
name: AdditiveAngularMargin
params:
margin: 0.3
scale: 30.0
easy_margin: False
feature_dim: 192
n_classes: 22506
# loss:
# name: CMSoftmax
# params:
# margin: 0.10
# margin2: 0.10
# scale: 30.0
# feature_dim: 256
# n_classes: 22506
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
fbank:
sr: 16000 # sample rate
n_fft: 512
win_length: 400 #25ms
hop_length: 160 #10ms
n_mels: 80
f_min: 20
f_max: 7600
window: hann
amin: !!float 1e-5
top_db: 75
augment_with_sox: False
augment_mel: False
augment_wav: True
rir_path: ../../data/RIRS_NOISES/rir.list
muse_speech: ../../data/musan_split/speech.list
muse_speech_srn_high: 15.0
muse_speech_srn_low: 12.0
muse_music: ../../data/musan_split/music.list
muse_music_srn_high: 15.0
muse_music_srn_low: 5.0
muse_noise: ../../data/musan_split/noise.list
muse_noise_srn_high: 15
muse_noise_srn_low: 5.0
freeze_param: False
freezed_layers: -8
batch_size: 64
val_batch_size: 16
num_workers: 32
num_classes: 7205
duration: 5
balanced_sampling: False
epoch_num: 500
max_lr: !!float 1e-06
base_lr: !!float 1e-04
reverse_lr: False
half_cycle: 10 #epoch
# set the data path accordingly
spk_scp: ../../data/voxceleb1and2_list.txt
speaker_set: ../../data/speaker_set.txt
train_keys: ~
# for testing
mean_std_file: ../../data/stat.pd
test_list: ../../data/veri_test2.txt
test_folder: ../../data/voxceleb1/
model_dir : ./checkpoints/
model_prefix: 'resnet'
log_dir : ./log/
log_file : ./log.txt
log_step: 10
checkpoint_step : 5000
eval_step: 600
max_time_mask: 2
max_freq_mask: 1
max_time_mask_width: 20
max_freq_mask_width: 10
model:
name: ResNetSE34V2 # or ResNetSE34
params:
feature_dim: 256
scale_factor: 1
encoder_type: SAP #ASP #ASP with attention
n_mels: 80 # should be the same as in the fbank config
normalize: True
# loss:
# name: AdditiveAngularMargin
# params:
# margin: 0.35
# scale: 30.0
# easy_margin: False
# feature_dim: 192
# n_classes: 22506
loss:
name: CMSoftmax
params:
margin: 0.10
margin2: 0.10
scale: 30.0
feature_dim: 256
n_classes: 22506
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
__all__ = ['ProtoTypical', 'AMSoftmaxLoss', 'CMSoftmax']
class AMSoftmaxLoss(nn.Layer):
"""Additive margin softmax loss.
    Additive margin softmax loss is useful for training neural networks for speaker recognition/verification.
    Notes:
        The loss itself contains parameters that need to be passed to the optimizer for gradient descent.
References:
Wang, Feng, et al. “Additive Margin Softmax for Face Verification.”
IEEE Signal Processing Letters, vol. 25, no. 7, 2018, pp. 926–930.
"""
def __init__(self,
feature_dim: int,
n_classes: int,
eps: float = 1e-5,
margin: float = 0.3,
scale: float = 30.0):
super(AMSoftmaxLoss, self).__init__()
self.w = paddle.create_parameter((feature_dim, n_classes), 'float32')
self.eps = eps
self.scale = scale
self.margin = margin
self.nll_loss = nn.NLLLoss()
self.n_classes = n_classes
def forward(self, logits, label):
logits = F.normalize(logits, p=2, axis=1, epsilon=self.eps)
wn = F.normalize(self.w, p=2, axis=0, epsilon=self.eps)
cosine = paddle.matmul(logits, wn)
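        # AM-Softmax: subtract the margin m from the target-class cosine only,
        # then scale, i.e. s * (cos(theta_y) - m) vs. s * cos(theta_j) for j != y.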
y = paddle.zeros((logits.shape[0], self.n_classes))
for i in range(logits.shape[0]):
y[i, label[i]] = self.margin
pred = F.log_softmax((cosine - y) * self.scale, -1)
return self.nll_loss(pred, label), pred
class ProtoTypical(nn.Layer):
"""Proto-typical loss as described in [1].
Reference:
[1] Chung, Joon Son, et al. “In Defence of Metric Learning for Speaker Recognition.”
Interspeech 2020, 2020, pp. 2977–2981.
"""
def __init__(self, s=20.0, eps=1e-8):
super(ProtoTypical, self).__init__()
self.nll_loss = nn.NLLLoss()
self.eps = eps
self.s = s
def forward(self, logits):
assert logits.ndim == 3, (
f'the input logits must be a ' +
f'3d tensor of shape [n_spk,n_uttns,emb_dim],' +
f'but received logits.ndim = {logits.ndim}')
logits = F.normalize(logits, p=2, axis=-1, epsilon=self.eps)
proto = paddle.mean(logits[:, 1:, :], axis=1, keepdim=False).transpose(
(1, 0)) # [emb_dim, n_spk]
query = logits[:, 0, :] # [n_spk, emb_dim]
similarity = paddle.matmul(query, proto) * self.s #[n_spk,n_spk]
label = paddle.arange(0, similarity.shape[0])
log_sim = F.log_softmax(similarity, -1)
return self.nll_loss(log_sim, label), log_sim
class AngularMargin(nn.Layer):
def __init__(self, margin=0.0, scale=1.0):
super(AngularMargin, self).__init__()
self.margin = margin
self.scale = scale
def forward(self, outputs, targets):
outputs = outputs - self.margin * targets
return self.scale * outputs
class LogSoftmaxWrapper(nn.Layer):
def __init__(self, loss_fn):
super(LogSoftmaxWrapper, self).__init__()
self.loss_fn = loss_fn
self.criterion = paddle.nn.KLDivLoss(reduction="sum")
def forward(self, outputs, targets, length=None):
targets = F.one_hot(targets, outputs.shape[1])
try:
predictions = self.loss_fn(outputs, targets)
except TypeError:
predictions = self.loss_fn(outputs)
predictions = F.log_softmax(predictions, axis=1)
loss = self.criterion(predictions, targets) / targets.sum()
return loss
class AdditiveAngularMargin(AngularMargin):
def __init__(self,
margin=0.0,
scale=1.0,
feature_dim=256,
n_classes=1000,
easy_margin=False):
super(AdditiveAngularMargin, self).__init__(margin, scale)
self.easy_margin = easy_margin
self.w = paddle.create_parameter((feature_dim, n_classes), 'float32')
self.cos_m = math.cos(self.margin)
self.sin_m = math.sin(self.margin)
self.th = math.cos(math.pi - self.margin)
self.mm = math.sin(math.pi - self.margin) * self.margin
self.nll_loss = nn.NLLLoss()
self.n_classes = n_classes
def forward(self, logits, targets):
# logits = self.drop(logits)
logits = F.normalize(logits, p=2, axis=1, epsilon=1e-8)
wn = F.normalize(self.w, p=2, axis=0, epsilon=1e-8)
cosine = logits @ wn
#cosine = outputs.astype('float32')
sine = paddle.sqrt(1.0 - paddle.square(cosine))
phi = cosine * self.cos_m - sine * self.sin_m # cos(theta + m)
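        # When theta + m would exceed pi (cosine <= cos(pi - m) = th), cos is no
        # longer monotone in theta, so fall back to the linear penalty cosine - mm.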
if self.easy_margin:
phi = paddle.where(cosine > 0, phi, cosine)
else:
phi = paddle.where(cosine > self.th, phi, cosine - self.mm)
target_one_hot = F.one_hot(targets, self.n_classes)
outputs = (target_one_hot * phi) + ((1.0 - target_one_hot) * cosine)
outputs = self.scale * outputs
pred = F.log_softmax(outputs, axis=-1)
return self.nll_loss(pred, targets), pred
class CMSoftmax(AngularMargin):
def __init__(self,
margin=0.0,
margin2=0.0,
scale=1.0,
feature_dim=256,
n_classes=1000,
easy_margin=False):
super(CMSoftmax, self).__init__(margin, scale)
self.easy_margin = easy_margin
self.w = paddle.create_parameter((feature_dim, n_classes), 'float32')
self.cos_m = math.cos(self.margin)
self.sin_m = math.sin(self.margin)
self.th = math.cos(math.pi - self.margin)
self.mm = math.sin(math.pi - self.margin) * self.margin
self.nll_loss = nn.NLLLoss()
self.n_classes = n_classes
self.margin2 = margin2
def forward(self, logits, targets):
logits = F.normalize(logits, p=2, axis=1, epsilon=1e-8)
wn = F.normalize(self.w, p=2, axis=0, epsilon=1e-8)
cosine = logits @ wn
sine = paddle.sqrt(1.0 - paddle.square(cosine))
phi = cosine * self.cos_m - sine * self.sin_m # cos(theta + m)
if self.easy_margin:
phi = paddle.where(cosine > 0, phi, cosine)
else:
phi = paddle.where(cosine > self.th, phi, cosine - self.mm)
target_one_hot = F.one_hot(targets, self.n_classes)
outputs = (target_one_hot * phi) + (
(1.0 - target_one_hot) * cosine) - target_one_hot * self.margin2
outputs = self.scale * outputs
pred = F.log_softmax(outputs, axis=-1)
return self.nll_loss(pred, targets), pred
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from typing import List, Tuple, Union
import numpy as np
def compute_eer(
scores: Union[np.ndarray, List[float]], labels: Union[np.ndarray, List[int]]
) -> Tuple[float, float, np.ndarray, np.ndarray]:
"""Compute equal error rate(EER) given matching scores and corresponding labels
Parameters:
scores(np.ndarray,list): the cosine similarity between two speaker embeddings.
labels(np.ndarray,list): the labels of the speaker pairs, with value 1 indicates same speaker and 0 otherwise.
Returns:
eer(float): the equal error rate.
thresh_for_eer(float): the thresh value at which false acceptance rate equals to false rejection rate.
fr_rate(np.ndarray): the false rejection rate as a function of increasing thresholds.
fa_rate(np.ndarray): the false acceptance rate as a function of increasing thresholds.
"""
if isinstance(labels, list):
labels = np.array(labels)
if isinstance(scores, list):
scores = np.array(scores)
label_set = list(np.unique(labels))
    assert len(label_set) == 2, (
        f'the input labels must contain exactly two distinct values, '
        f'but received set(labels) = {label_set}')
label_set.sort()
assert label_set == [
0, 1
], 'the input labels must contain 0 and 1 for distinct and identical id. '
eps = 1e-8
#assert np.min(scores) >= -1.0 - eps and np.max(
# scores
# ) < 1.0 + eps, 'the score must be in the range between -1.0 and 1.0'
same_id_scores = scores[labels == 1]
diff_id_scores = scores[labels == 0]
thresh = np.linspace(np.min(diff_id_scores), np.max(same_id_scores), 1000)
thresh = np.expand_dims(thresh, 1)
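    # Broadcasting: thresh is [1000, 1] and the score vectors are [n_trials], so
    # the comparisons below yield [1000, n_trials] boolean matrices whose row
    # means are the false-rejection / false-acceptance rates per threshold.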
fr_matrix = same_id_scores < thresh
fa_matrix = diff_id_scores >= thresh
fr_rate = np.mean(fr_matrix, 1)
fa_rate = np.mean(fa_matrix, 1)
thresh_idx = np.argmin(np.abs(fa_rate - fr_rate))
    Result = namedtuple('speaker', ('eer', 'thresh', 'fa', 'fr'))
    return Result(eer=(fr_rate[thresh_idx] + fa_rate[thresh_idx]) / 2,
                  thresh=thresh[thresh_idx, 0],
                  fa=fa_rate,
                  fr=fr_rate)
def compute_min_dcf(fr_rate, fa_rate, p_target=0.05, c_miss=1.0, c_fa=1.0):
""" Compute normalized minimum detection cost function (minDCF) given
the costs for false accepts and false rejects as well as a priori
probability for target speakers
Parameters:
fr_rate(np.ndarray): the false rejection rate as a function of increasing thresholds.
fa_rate(np.ndarray): the false acceptance rate as a function of increasing thresholds.
p_target(float): the prior probability of being a target.
c_miss(float): cost of miss detection(false rejects).
        c_fa(float): cost of false alarm (false accepts).
Returns:
        min_dcf(float): the normalized minimum detection cost function (minDCF)
    """
    dcf = c_miss * fr_rate * p_target + c_fa * fa_rate * (1 - p_target)
    c_det = np.min(dcf)
    c_def = min(c_miss * p_target, c_fa * (1 - p_target))
    min_dcf = c_det / c_def
    return min_dcf
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .ecapa_tdnn import EcapaTDNN
from .resnet_se34 import ResNetSE34
from .resnet_se34v2 import ResNetSE34V2
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
def length_to_mask(length, max_len=None, dtype=None):
assert len(length.shape) == 1
if max_len is None:
max_len = length.max().astype(
'int').item() # using arange to generate mask
mask = paddle.arange(max_len, dtype=length.dtype).expand(
(len(length), max_len)) < length.unsqueeze(1)
if dtype is None:
dtype = length.dtype
mask = paddle.to_tensor(mask, dtype=dtype)
return mask
class Conv1d(nn.Layer):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding="same",
dilation=1,
groups=1,
bias=True,
padding_mode="reflect",
):
super(Conv1d, self).__init__()
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.padding = padding
self.padding_mode = padding_mode
self.conv = nn.Conv1D(
in_channels,
out_channels,
self.kernel_size,
stride=self.stride,
padding=0,
dilation=self.dilation,
groups=groups,
bias_attr=bias,
)
def forward(self, x):
if self.padding == "same":
x = self._manage_padding(x, self.kernel_size, self.dilation,
self.stride)
else:
raise ValueError("Padding must be 'same'. Got {self.padding}")
return self.conv(x)
def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
L_in = x.shape[-1] # Detecting input shape
padding = self._get_padding_elem(L_in, stride, kernel_size,
dilation) # Time padding
x = F.pad(x, padding, mode=self.padding_mode,
data_format="NCL") # Applying padding
return x
def _get_padding_elem(self, L_in: int, stride: int, kernel_size: int,
dilation: int):
if stride > 1:
n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
L_out = stride * (n_steps - 1) + kernel_size * dilation
padding = [kernel_size // 2, kernel_size // 2]
else:
L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
return padding
class BatchNorm1d(nn.Layer):
def __init__(
self,
input_size,
eps=1e-05,
momentum=0.9,
weight_attr=None,
bias_attr=None,
data_format='NCL',
use_global_stats=None,
):
super(BatchNorm1d, self).__init__()
self.norm = nn.BatchNorm1D(
input_size,
epsilon=eps,
momentum=momentum,
weight_attr=weight_attr,
bias_attr=bias_attr,
data_format=data_format,
use_global_stats=use_global_stats,
)
def forward(self, x):
x_n = self.norm(x)
return x_n
class TDNNBlock(nn.Layer):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
dilation,
activation=nn.ReLU,
):
super(TDNNBlock, self).__init__()
self.conv = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
dilation=dilation,
)
self.activation = activation()
self.norm = BatchNorm1d(input_size=out_channels)
def forward(self, x):
return self.norm(self.activation(self.conv(x)))
class Res2NetBlock(nn.Layer):
def __init__(self, in_channels, out_channels, scale=8, dilation=1):
super(Res2NetBlock, self).__init__()
assert in_channels % scale == 0
assert out_channels % scale == 0
in_channel = in_channels // scale
hidden_channel = out_channels // scale
self.blocks = nn.LayerList([
TDNNBlock(in_channel,
hidden_channel,
kernel_size=3,
dilation=dilation) for i in range(scale - 1)
])
self.scale = scale
def forward(self, x):
y = []
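        # Res2Net hierarchy: split channels into `scale` groups; group 0 is an
        # identity path, and each later group is convolved after adding the
        # previous group's output, enlarging the receptive field in one block.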
for i, x_i in enumerate(paddle.chunk(x, self.scale, axis=1)):
if i == 0:
y_i = x_i
elif i == 1:
y_i = self.blocks[i - 1](x_i)
else:
y_i = self.blocks[i - 1](x_i + y_i)
y.append(y_i)
y = paddle.concat(y, axis=1)
return y
class SEBlock(nn.Layer):
def __init__(self, in_channels, se_channels, out_channels):
super(SEBlock, self).__init__()
self.conv1 = Conv1d(in_channels=in_channels,
out_channels=se_channels,
kernel_size=1)
self.relu = paddle.nn.ReLU()
self.conv2 = Conv1d(in_channels=se_channels,
out_channels=out_channels,
kernel_size=1)
self.sigmoid = paddle.nn.Sigmoid()
def forward(self, x, lengths=None):
L = x.shape[-1]
if lengths is not None:
mask = length_to_mask(lengths * L, max_len=L)
mask = mask.unsqueeze(1)
total = mask.sum(axis=2, keepdim=True)
s = (x * mask).sum(axis=2, keepdim=True) / total
else:
s = x.mean(axis=2, keepdim=True)
s = self.relu(self.conv1(s))
s = self.sigmoid(self.conv2(s))
return s * x
class AttentiveStatisticsPooling(nn.Layer):
def __init__(self, channels, attention_channels=128, global_context=True):
super().__init__()
self.eps = 1e-12
self.global_context = global_context
if global_context:
self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
else:
self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
self.tanh = nn.Tanh()
self.conv = Conv1d(in_channels=attention_channels,
out_channels=channels,
kernel_size=1)
def forward(self, x, lengths=None):
C, L = x.shape[1], x.shape[2] # KP: (N, C, L)
def _compute_statistics(x, m, axis=2, eps=self.eps):
mean = (m * x).sum(axis)
std = paddle.sqrt(
(m * (x - mean.unsqueeze(axis)).pow(2)).sum(axis).clip(eps))
return mean, std
if lengths is None:
lengths = paddle.ones([x.shape[0]])
# Make binary mask of shape [N, 1, L]
mask = length_to_mask(lengths * L, max_len=L)
mask = mask.unsqueeze(1)
# Expand the temporal context of the pooling layer by allowing the
# self-attention to look at global properties of the utterance.
if self.global_context:
total = mask.sum(axis=2, keepdim=True).astype('float32')
mean, std = _compute_statistics(x, mask / total)
mean = mean.unsqueeze(2).tile((1, 1, L))
std = std.unsqueeze(2).tile((1, 1, L))
attn = paddle.concat([x, mean, std], axis=1)
else:
attn = x
# Apply layers
attn = self.conv(self.tanh(self.tdnn(attn)))
# Filter out zero-paddings
attn = paddle.where(
mask.tile((1, C, 1)) == 0,
paddle.ones_like(attn) * float("-inf"), attn)
attn = F.softmax(attn, axis=2)
mean, std = _compute_statistics(x, attn)
# Append mean and std of the batch
pooled_stats = paddle.concat((mean, std), axis=1)
pooled_stats = pooled_stats.unsqueeze(2)
return pooled_stats
class SERes2NetBlock(nn.Layer):
def __init__(
self,
in_channels,
out_channels,
res2net_scale=8,
se_channels=128,
kernel_size=1,
dilation=1,
activation=nn.ReLU,
):
super(SERes2NetBlock, self).__init__()
self.out_channels = out_channels
self.tdnn1 = TDNNBlock(
in_channels,
out_channels,
kernel_size=1,
dilation=1,
activation=activation,
)
self.res2net_block = Res2NetBlock(out_channels, out_channels,
res2net_scale, dilation)
self.tdnn2 = TDNNBlock(
out_channels,
out_channels,
kernel_size=1,
dilation=1,
activation=activation,
)
self.se_block = SEBlock(out_channels, se_channels, out_channels)
self.shortcut = None
if in_channels != out_channels:
self.shortcut = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
)
def forward(self, x, lengths=None):
residual = x
if self.shortcut:
residual = self.shortcut(x)
x = self.tdnn1(x)
x = self.res2net_block(x)
x = self.tdnn2(x)
x = self.se_block(x, lengths)
return x + residual
class EcapaTDNN(nn.Layer):
def __init__(self,
input_size,
lin_neurons=192,
activation=nn.ReLU,
channels=[1024, 1024, 1024, 1024, 3072],
kernel_sizes=[5, 3, 3, 3, 1],
dilations=[1, 2, 3, 4, 1],
attention_channels=128,
res2net_scale=8,
se_channels=128,
global_context=True):
super(EcapaTDNN, self).__init__()
assert len(channels) == len(kernel_sizes)
assert len(channels) == len(dilations)
self.channels = channels
self.blocks = nn.LayerList()
self.emb_size = lin_neurons
# The initial TDNN layer
self.blocks.append(
TDNNBlock(
input_size,
channels[0],
kernel_sizes[0],
dilations[0],
activation,
))
# SE-Res2Net layers
for i in range(1, len(channels) - 1):
self.blocks.append(
SERes2NetBlock(
channels[i - 1],
channels[i],
res2net_scale=res2net_scale,
se_channels=se_channels,
kernel_size=kernel_sizes[i],
dilation=dilations[i],
activation=activation,
))
# Multi-layer feature aggregation
self.mfa = TDNNBlock(
channels[-1],
channels[-1],
kernel_sizes[-1],
dilations[-1],
activation,
)
# Attentive Statistical Pooling
self.asp = AttentiveStatisticsPooling(
channels[-1],
attention_channels=attention_channels,
global_context=global_context,
)
self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)
# Final linear transformation
self.fc = Conv1d(
in_channels=channels[-1] * 2,
out_channels=self.emb_size,
kernel_size=1,
)
self.drop = nn.Dropout(0.25)
def forward(self, x, lengths=None):
xl = []
for layer in self.blocks:
try:
x = layer(x, lengths=lengths)
except TypeError:
x = layer(x)
xl.append(x)
# Multi-layer feature aggregation
x = paddle.concat(xl[1:], axis=1)
x = self.mfa(x)
# Attentive Statistical Pooling
x = self.asp(x, lengths=lengths)
x = self.asp_bn(x)
# Final linear transformation
x = self.drop(x)
x = self.fc(x)
x = x[:, :, 0]
return x
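if __name__ == '__main__':
    # Quick smoke test added for illustration (mirrors the __main__ blocks of
    # the ResNet variants); assumes 80-band fbank input, ~3 s at a 10 ms hop.
    model = EcapaTDNN(input_size=80)
    x = paddle.randn((2, 80, 300))
    print(model(x).shape)  # expect [2, 192]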
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
class SEBasicBlock(nn.Layer):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
reduction=8):
super(SEBasicBlock, self).__init__()
self.conv1 = nn.Conv2D(inplanes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False)
self.bn1 = nn.BatchNorm2D(planes)
self.conv2 = nn.Conv2D(planes,
planes,
kernel_size=3,
padding=1,
bias_attr=False)
self.bn2 = nn.BatchNorm2D(planes)
self.relu = nn.ReLU()
self.se = SELayer(planes, reduction)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.relu(out)
out = self.bn1(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.se(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class SEBottleneck(nn.Layer):
expansion = 4
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
reduction=8):
super(SEBottleneck, self).__init__()
self.conv1 = nn.Conv2D(inplanes, planes, kernel_size=1, bias_attr=False)
self.bn1 = nn.BatchNorm2D(planes)
self.conv2 = nn.Conv2D(planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False)
self.bn2 = nn.BatchNorm2D(planes)
self.conv3 = nn.Conv2D(planes,
planes * 4,
kernel_size=1,
bias_attr=False)
self.bn3 = nn.BatchNorm2D(planes * 4)
self.relu = nn.ReLU()
self.se = SELayer(planes * 4, reduction)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
out = self.se(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class SELayer(nn.Layer):
def __init__(self, channel, reduction=8):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
self.fc = nn.Sequential(nn.Linear(channel, channel // reduction),
nn.ReLU(),
nn.Linear(channel // reduction, channel),
nn.Sigmoid())
def forward(self, x):
b, c, _, _ = x.shape
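        # Squeeze: global average pooling to one value per channel; excitation:
        # a bottleneck MLP with sigmoid yields channel weights that rescale x.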
y = self.avg_pool(x).reshape((b, c))
y = self.fc(y).reshape((b, c, 1, 1))
return x * y
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleaudio.transforms import LogMelSpectrogram, MelSpectrogram
from .resnet_blocks import SEBasicBlock, SEBottleneck
class ResNetSE(nn.Layer):
def __init__(self,
block,
layers,
num_filters,
feature_dim,
feature_config,
encoder_type='SAP',
n_mels=80,
log_input=True,
**kwargs):
super(ResNetSE, self).__init__()
print('Embedding size is %d, encoder %s.' % (feature_dim, encoder_type))
self.inplanes = num_filters[0]
self.encoder_type = encoder_type
self.n_mels = n_mels
self.log_input = log_input
self.conv1 = nn.Conv2D(1,
num_filters[0],
kernel_size=3,
stride=1,
padding=1)
self.relu = nn.ReLU()
self.bn1 = nn.BatchNorm2D(num_filters[0])
self.layer1 = self._make_layer(block, num_filters[0], layers[0])
self.layer2 = self._make_layer(block,
num_filters[1],
layers[1],
stride=(2, 2))
self.layer3 = self._make_layer(block,
num_filters[2],
layers[2],
stride=(2, 2))
self.layer4 = self._make_layer(block,
num_filters[3],
layers[3],
stride=(2, 2))
# self.instancenorm = nn.InstanceNorm1D(n_mels)
outmap_size = int(self.n_mels / 8)
self.attention = nn.Sequential(
nn.Conv1D(num_filters[3] * outmap_size, 128, kernel_size=1),
nn.ReLU(),
nn.BatchNorm1D(128),
nn.Conv1D(128, num_filters[3] * outmap_size, kernel_size=1),
nn.Softmax(axis=2),
)
if self.encoder_type == "SAP":
out_dim = num_filters[3] * outmap_size
elif self.encoder_type == "ASP":
out_dim = num_filters[3] * outmap_size * 2
else:
raise ValueError('Undefined encoder')
self.fc = nn.Linear(out_dim, feature_dim)
self.melspectrogram = LogMelSpectrogram(**feature_config)
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2D(self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias_attr=False),
nn.BatchNorm2D(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def new_parameter(self, size):
out = paddle.create_parameter(size, 'float32')
nn.initializer.XavierNormal(out)
return out
def forward(self, x):
x = x.unsqueeze(1)
x = self.conv1(x)
x = self.relu(x)
x = self.bn1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = x.reshape((x.shape[0], -1, x.shape[-1]))
w = self.attention(x)
if self.encoder_type == "SAP":
x = paddle.sum(x * w, axis=2)
elif self.encoder_type == "ASP":
mu = paddle.sum(x * w, axis=2)
sg = paddle.sum((x**2) * w, axis=2) - mu**2
sg = paddle.clip(sg, min=1e-5)
sg = paddle.sqrt(sg)
x = paddle.concat((mu, sg), 1)
x = x.reshape((x.shape[0], -1))
x = self.fc(x)
return x
def ResNetSE34(feature_dim=256, scale_factor=1, **kwargs):
# Number of filters
num_filters = [
32 * scale_factor, 64 * scale_factor, 128 * scale_factor,
256 * scale_factor
]
model = ResNetSE(SEBasicBlock, [3, 4, 6, 3], num_filters, feature_dim,
**kwargs)
return model
if __name__ == '__main__':
print(ResNetSE34())
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddleaudio
from paddleaudio.transforms import LogMelSpectrogram, MelSpectrogram
from .resnet_blocks import SEBasicBlock, SEBottleneck
class ResNetSE(nn.Layer):
def __init__(self,
block,
layers,
num_filters,
feature_dim,
encoder_type='SAP',
n_mels=40,
log_input=True,
**kwargs):
super(ResNetSE, self).__init__()
print('Embedding size is %d, encoder %s.' % (feature_dim, encoder_type))
self.inplanes = num_filters[0]
self.encoder_type = encoder_type
self.n_mels = n_mels
self.log_input = log_input
self.conv1 = nn.Conv2D(1,
num_filters[0],
kernel_size=3,
stride=1,
padding=1)
self.relu = nn.ReLU()
self.bn1 = nn.BatchNorm2D(num_filters[0])
self.layer1 = self._make_layer(block, num_filters[0], layers[0])
self.layer2 = self._make_layer(block,
num_filters[1],
layers[1],
stride=(2, 2))
self.layer3 = self._make_layer(block,
num_filters[2],
layers[2],
stride=(2, 2))
self.layer4 = self._make_layer(block,
num_filters[3],
layers[3],
stride=(2, 2))
outmap_size = int(self.n_mels / 8)
self.attention = nn.Sequential(
nn.Conv1D(num_filters[3] * outmap_size, 128, kernel_size=1),
nn.ReLU(),
nn.BatchNorm1D(128),
nn.Conv1D(128, num_filters[3] * outmap_size, kernel_size=1),
nn.Softmax(axis=2),
)
if self.encoder_type == "SAP":
out_dim = num_filters[3] * outmap_size
elif self.encoder_type == "ASP":
out_dim = num_filters[3] * outmap_size * 2
else:
raise ValueError('Undefined encoder')
self.fc = nn.Linear(out_dim, feature_dim)
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2D(self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias_attr=False),
nn.BatchNorm2D(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def new_parameter(self, *size):
out = paddle.create_parameter(size, 'float32')
nn.initializer.XavierNormal(out)
return out
def forward(self, x):
x = x.unsqueeze(1)
x = self.conv1(x)
x = self.relu(x)
x = self.bn1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = x.reshape((x.shape[0], -1, x.shape[-1]))
w = self.attention(x)
if self.encoder_type == "SAP":
x = paddle.sum(x * w, axis=2)
elif self.encoder_type == "ASP":
mu = paddle.sum(x * w, axis=2)
sg = paddle.sum((x**2) * w, axis=2) - mu**2
sg = paddle.clip(sg, min=1e-5)
sg = paddle.sqrt(sg)
x = paddle.concat((mu, sg), 1)
x = x.reshape((x.shape[0], -1))
x = self.fc(x)
return x
def ResNetSE34V2(feature_dim=256, scale_factor=1, **kwargs):
# Number of filters
num_filters = [
32 * scale_factor, 64 * scale_factor, 128 * scale_factor,
256 * scale_factor
]
model = ResNetSE(SEBasicBlock, [3, 4, 6, 3], num_filters, feature_dim,
**kwargs)
return model
if __name__ == '__main__':
print(ResNetSE34V2())
paddlepaddle-gpu==2.1.1
PyYAML==5.4.1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import metrics
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddleaudio
import yaml
from dataset import get_val_loader
from paddleaudio.transforms import *
from paddleaudio.utils import get_logger
from models import EcapaTDNN, ResNetSE34, ResNetSE34V2
logger = get_logger()
file2feature = {}
def get_feature(file, model, melspectrogram, random_sampling=False):
global file2feature
if file in file2feature:
return file2feature[file]
s0, _ = paddleaudio.load(file, sr=16000) #, norm_type='gaussian')
s = paddle.to_tensor(s0[None, :])
s = melspectrogram(s).astype('float32')
with paddle.no_grad():
feature = model(s) #.squeeze()
feature = feature / paddle.sqrt(paddle.sum(feature**2))
file2feature.update({file: feature})
return feature
class Normalize:
def __init__(self, mean_file, eps=1e-5):
self.eps = eps
mean = paddle.load(mean_file)['mean']
self.mean = mean.unsqueeze((0, 2))
def __call__(self, x):
assert x.ndim == 3
return x - self.mean
def get_score(features1, features2): # feature mean
score = float(paddle.dot(features1.squeeze(), features2.squeeze()))
return score
def compute_eer(config, model):
transforms = []
melspectrogram = LogMelSpectrogram(**config['fbank'])
transforms += [melspectrogram]
if config['normalize']:
transforms += [Normalize(config['mean_std_file'])]
transforms = Compose(transforms)
global file2feature # to avoid repeated computation
file2feature = {}
test_list = config['test_list']
test_folder = config['test_folder']
model.eval()
with open(test_list) as f:
lines = f.read().split('\n')
label_wav_pairs = [l.split() for l in lines if len(l) > 0]
logger.info(f'{len(label_wav_pairs)} test pairs listed')
labels = []
scores = []
for i, (label, f1, f2) in enumerate(label_wav_pairs):
full_path1 = os.path.join(test_folder, f1)
full_path2 = os.path.join(test_folder, f2)
feature1 = get_feature(full_path1, model, transforms)
feature2 = get_feature(full_path2, model, transforms)
score = get_score(feature1, feature2)
labels.append(label)
scores.append(score)
if i % (len(label_wav_pairs) // 10) == 0:
logger.info(f'processed {i}|{len(label_wav_pairs)}')
scores = np.array(scores)
labels = np.array([int(l) for l in labels])
result = metrics.compute_eer(scores, labels)
min_dcf = metrics.compute_min_dcf(result.fr, result.fa)
logger.info(f'eer={result.eer}, thresh={result.thresh}, minDCF={min_dcf}')
return result, min_dcf
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-c',
'--config',
type=str,
required=False,
default='config.yaml')
parser.add_argument(
'-d',
'--device',
default="gpu",
help="Select which device to train model, defaults to gpu.")
parser.add_argument('-w', '--weight', type=str, required=True)
args = parser.parse_args()
with open(args.config) as f:
config = yaml.safe_load(f)
paddle.set_device(args.device)
logger.info('model:' + config['model']['name'])
logger.info('device: ' + args.device)
logger.info(f'using ' + config['model']['name'])
ModelClass = eval(config['model']['name'])
model = ModelClass(**config['model']['params'])
state_dict = paddle.load(args.weight)
if 'model' in state_dict.keys():
state_dict = state_dict['model']
model.load_dict(state_dict)
result, min_dcf = compute_eer(config, model)
logger.info(f'eer={result.eer}, thresh={result.thresh}, minDCF={min_dcf}')
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import os
import random
import time
from test import compute_eer
import numpy as np
import paddle
import paddle.distributed as dist
import paddle.nn as nn
import paddle.nn.functional as F
import yaml
from dataset import get_train_loader
from losses import AdditiveAngularMargin, AMSoftmaxLoss, CMSoftmax
from paddle.optimizer import SGD, Adam
from paddle.utils import download
from paddleaudio.transforms import *
from paddleaudio.utils import get_logger
from utils import NoiseSource, Normalize, RIRSource
from models import *
def get_lr(step, base_lr, max_lr, half_cycle=5000, reverse=False):
    """Triangular cyclic learning rate: ramp from base_lr up to max_lr over one
    half-cycle, then back down over the next. `reverse` is accepted for config
    compatibility but is not used here."""
    if int(step / half_cycle) % 2 == 0:  # ascending half-cycle
        lr = (step % half_cycle) / half_cycle * (max_lr - base_lr)
        lr = base_lr + lr
    else:  # descending half-cycle
        lr = (step % half_cycle / half_cycle) * (max_lr - base_lr)
        lr = max_lr - lr
    return lr
def freeze_bn(layer):
if isinstance(layer, paddle.nn.BatchNorm1D):
layer._momentum = 0.8
print(layer._momentum)
if isinstance(layer, paddle.nn.BatchNorm2D):
layer._momentum = 0.8
print(layer._momentum)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, required=True)
parser.add_argument(
'-d',
'--device',
default="gpu",
help='Select which device to train model, defaults to gpu.')
parser.add_argument(
'-r',
'--restore',
type=int,
required=False,
default=-1,
help=
        'the epoch number to restore from (the checkpoint contains weights for model/loss/optimizer)'
)
parser.add_argument('-w',
'--weight',
type=str,
required=False,
default='',
                        help='the model weight to restore from')
parser.add_argument('-e',
'--eval_at_begin',
type=bool,
choices=[True, False],
required=False,
default=False)
parser.add_argument('--distributed',
type=bool,
choices=[True, False],
required=False,
default=False)
args = parser.parse_args()
with open(args.config) as f:
config = yaml.safe_load(f)
os.makedirs(config['log_dir'], exist_ok=True)
logger = get_logger(__file__,
log_dir=config['log_dir'],
log_file_name=config['log_file'])
prefix = config['model_prefix']
if args.distributed:
dist.init_parallel_env()
local_rank = dist.get_rank()
print(local_rank)
else:
paddle.set_device(args.device)
local_rank = 0
logger.info(f'using ' + config['model']['name'])
ModelClass = eval(config['model']['name'])
model = ModelClass(**config['model']['params'])
#define loss and lr
LossClass = eval(config['loss']['name'])
loss_fn = LossClass(**config['loss']['params'])
loss_fn.train()
params = model.parameters() + loss_fn.parameters()
transforms = []
if config['augment_wav']:
noise_source1 = NoiseSource(open(
config['muse_speech']).read().split('\n')[:-1],
sample_rate=16000,
duration=config['duration'],
batch_size=config['batch_size'])
noisify1 = Noisify(noise_source1,
snr_high=config['muse_speech_srn_high'],
snr_low=config['muse_speech_srn_low'],
random=True)
noise_source2 = NoiseSource(open(
config['muse_music']).read().split('\n')[:-1],
sample_rate=16000,
duration=config['duration'],
batch_size=config['batch_size'])
noisify2 = Noisify(noise_source2,
snr_high=config['muse_music_srn_high'],
snr_low=config['muse_music_srn_low'],
random=True)
noise_source3 = NoiseSource(open(
config['muse_noise']).read().split('\n')[:-1],
sample_rate=16000,
duration=config['duration'],
batch_size=config['batch_size'])
noisify3 = Noisify(noise_source3,
snr_high=config['muse_noise_srn_high'],
snr_low=config['muse_noise_srn_low'],
random=True)
rir_files = open(config['rir_path']).read().split('\n')[:-1]
random_rir_reader = RIRSource(rir_files, random=True, sample_rate=16000)
reverb = Reverberate(rir_source=random_rir_reader)
muse_augment = RandomChoice([noisify1, noisify2, noisify3])
wav_augments = RandomApply([muse_augment, reverb], 0.25)
transforms += [wav_augments]
melspectrogram = LogMelSpectrogram(**config['fbank'])
transforms += [melspectrogram]
if config['normalize']:
transforms += [Normalize(config['mean_std_file'])]
if config['augment_mel']:
#define spectrogram masking
time_masking = RandomMasking(
max_mask_count=config['max_time_mask'],
max_mask_width=config['max_time_mask_width'],
axis=-1)
freq_masking = RandomMasking(
max_mask_count=config['max_freq_mask'],
max_mask_width=config['max_freq_mask_width'],
axis=-2)
mel_augments = RandomApply([freq_masking, time_masking], p=0.25)
transforms += [mel_augments]
transforms = Compose(transforms)
if args.restore != -1:
logger.info(f'restoring from checkpoint {args.restore}')
fn = os.path.join(config['model_dir'],
f'{prefix}_checkpoint_epoch{args.restore}.tar')
ckpt = paddle.load(fn)
model.load_dict(ckpt['model'])
optimizer = Adam(learning_rate=config['max_lr'], parameters=params)
opti_state_dict = ckpt['opti']
try:
optimizer.set_state_dict(opti_state_dict)
except:
logger.error('failed to load state dict for optimizers')
try:
loss_fn.load_dict(ckpt['loss'])
except:
logger.error('failed to load state dict for loss')
start_epoch = args.restore + 1
else:
start_epoch = 0
optimizer = Adam(learning_rate=config['max_lr'], parameters=params)
if args.weight != '':
logger.info(f'loading weight from {args.weight}')
sd = paddle.load(args.weight)
model.load_dict(sd)
os.makedirs(config['model_dir'], exist_ok=True)
if args.distributed:
model = paddle.DataParallel(model)
train_loader = get_train_loader(config)
epoch_num = config['epoch_num']
if args.restore != -1 and local_rank == 0 and args.eval_at_begin:
result, min_dcf = compute_eer(config, model)
        best_eer = result.eer
logger.info(f'eer: {best_eer}')
else:
best_eer = 1.0
step = start_epoch * len(train_loader)
    if config.get('freeze_param', None):
        for p in list(model.parameters())[:config['freezed_layers']]:
            if not isinstance(p, nn.BatchNorm1D):
                p.stop_gradient = True
for epoch in range(start_epoch, epoch_num):
avg_loss = 0.0
avg_acc = 0.0
model.train()
model.clear_gradients()
t0 = time.time()
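    # cyclic learning-rate schedule between base_lr and max_lr; get_lr is
    # assumed to be defined earlier in this script, with half_cycle giving
    # the half-period of the cycle in epochs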
if config['max_lr'] > config['base_lr']:
lr = get_lr(epoch - start_epoch, config['base_lr'],
config['max_lr'], config['half_cycle'],
config['reverse_lr'])
optimizer.set_lr(lr)
logger.info(f'Setting lr to {lr}')
for batch_id, (x, y) in enumerate(train_loader()):
x_mel = transforms(x)
logits = model(x_mel)
loss, pred = loss_fn(logits, y)
loss.backward()
optimizer.step()
model.clear_gradients()
acc = np.mean(np.argmax(pred.numpy(), axis=1) == y.numpy())
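        # report the instantaneous loss/accuracy for the first 100 batches,
        # then an exponential moving average with decay 0.999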
if batch_id < 100:
avg_acc = acc
avg_loss = loss.numpy()[0]
else:
factor = 0.999
avg_acc = avg_acc * factor + acc * (1 - factor)
avg_loss = avg_loss * factor + loss.numpy()[0] * (1 - factor)
elapsed = (time.time() - t0) / 3600
remain = elapsed / (1 + batch_id) * (len(train_loader) - batch_id)
msg = f'epoch:{epoch}, batch:{batch_id}'
msg += f'|{len(train_loader)}'
msg += f', loss:{avg_loss:.3}'
msg += f', acc:{avg_acc:.3}'
msg += f', lr:{optimizer.get_lr():.2}'
msg += f', elapsed:{elapsed:.3}h'
        msg += f', remaining:{remain:.3}h'
if batch_id % config['log_step'] == 0 and local_rank == 0:
logger.info(msg)
if step % config['checkpoint_step'] == 0 and local_rank == 0:
fn = os.path.join(config['model_dir'],
f'{prefix}_checkpoint_epoch{epoch}.tar')
obj = {
'model': model.state_dict(),
'loss': loss_fn.state_dict(),
'opti': optimizer.state_dict(),
'lr': optimizer.get_lr()
}
paddle.save(obj, fn)
if step != 0 and step % config['eval_step'] == 0 and local_rank == 0:
result, min_dcf = compute_eer(config, model)
eer = result.eer
model.train()
model.clear_gradients()
if eer < best_eer:
logger.info('eer improved from {} to {}'.format(
best_eer, eer))
best_eer = eer
fn = os.path.join(config['model_dir'],
f'{prefix}_epoch{epoch}_eer{eer:.3}')
paddle.save(model.state_dict(), fn + '.pdparams')
else:
logger.info(f'eer {eer} did not improve from {best_eer}')
step += 1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import json
import os
import pickle
import random
from typing import Any, List, Optional, Union
import numpy as np
import paddle
import paddle.nn as nn
import paddleaudio
import paddleaudio.functional as F
from paddle import Tensor
from paddle.utils import download
__all__ = [
'NoiseSource',
'RIRSource',
'Normalize',
]
class NoiseSource:
"""Read audio files randomly or sequentially from disk and pack them as a tensor.
Parameters:
audio_path_or_files(os.PathLike|List[os.PathLike]]): the audio folder or the audio file list.
        sample_rate(int): the target audio sample rate. If it differs from the native sample rate,
            the audio will be resampled on loading.
        duration(float): the duration of the audio after loading. Padding or random cropping takes place
            depending on whether the actual audio length is shorter or longer than int(sample_rate*duration).
The audio tensor will have shape [batch_size, int(sample_rate*duration)]
batch_size(int): the number of audio files contained in the returned tensor.
random(bool): whether to read audio file randomly. If False, will read them sequentially.
Default: True.
Notes:
        In sequential mode, once the end of the audio list is reached, the reader starts over,
        so the NoiseSource object can be called endlessly.
    Shapes:
        - output: 2-D tensor with shape [batch_size, int(sample_rate*duration)]
    Examples:
        .. code-block:: python
        import paddle
        import paddleaudio.transforms as T
        reader = T.NoiseSource(<audio_folder>, sample_rate=16000, duration=3.0, batch_size=2)
        audio = reader()
        print(audio.shape)
        >> [2, 48000]
"""
def __init__(self,
audio_path_or_files: Union[os.PathLike, List[os.PathLike]],
sample_rate: int,
duration: float,
batch_size: int,
random: bool = True):
if isinstance(audio_path_or_files, list):
self.audio_files = audio_path_or_files
elif os.path.isdir(audio_path_or_files):
self.audio_files = glob.glob(audio_path_or_files + '/*.wav',
recursive=True)
if len(self.audio_files) == 0:
raise FileNotFoundError(
f'no files were found in {audio_path_or_files}')
elif os.path.isfile(audio_path_or_files):
self.audio_files = [audio_path_or_files]
else:
raise ValueError(
f'rir_path_or_files={audio_path_or_files} is invalid')
self.n_files = len(self.audio_files)
self.idx = 0
self.random = random
self.batch_size = batch_size
self.sample_rate = sample_rate
self.duration = int(duration * sample_rate)
self._data = paddle.zeros((self.batch_size, self.duration),
dtype='float32')
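    # load one audio file at the target sample rate, then random-crop or
    # center-pad it to exactly `self.duration` samples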
def load_wav(self, file: os.PathLike):
s, _ = paddleaudio.load(file, sr=self.sample_rate)
s = paddle.to_tensor(s)
s = F.random_cropping(s, target_size=self.duration)
s = F.center_padding(s, target_size=self.duration)
return s
def __call__(self) -> Tensor:
if self.random:
files = [
random.choice(self.audio_files) for _ in range(self.batch_size)
]
else:
files = []
for _ in range(self.batch_size):
file = self.audio_files[self.idx]
self.idx += 1
if self.idx >= self.n_files:
self.idx = 0
files += [file]
for i, f in enumerate(files):
self._data[i, :] = self.load_wav(f)
return self._data
def __repr__(self):
return (
self.__class__.__name__ +
f'(n_files={self.n_files}, random={self.random}, sample_rate={self.sample_rate})'
)
class RIRSource(nn.Layer):
"""Gererate RIR filter coefficients from local file sources.
Parameters:
        rir_path_or_files(os.PathLike|List[os.PathLike]): the directory that contains rir files directly
            (without subfolders) or the list of rir files.
        sample_rate(int): the sample rate at which the impulse responses are loaded.
        random(bool): whether to pick rir files randomly. If False, files are read sequentially.
            Default: True.
Examples:
.. code-block:: python
import paddle
import paddleaudio.transforms as T
reader = T.RIRSource(<rir_folder>, sample_rate=16000, random=True)
weight = reader()
"""
def __init__(self,
rir_path_or_files: Union[os.PathLike, List[os.PathLike]],
sample_rate: int,
random: bool = True):
super(RIRSource, self).__init__()
if isinstance(rir_path_or_files, list):
self.rir_files = rir_path_or_files
elif os.path.isdir(rir_path_or_files):
self.rir_files = glob.glob(rir_path_or_files + '/*.wav',
recursive=True)
if len(self.rir_files) == 0:
raise FileNotFoundError(
f'no files were found in {rir_path_or_files}')
elif os.path.isfile(rir_path_or_files):
self.rir_files = [rir_path_or_files]
else:
raise ValueError(
f'rir_path_or_files={rir_path_or_files} is invalid')
self.n_files = len(self.rir_files)
self.idx = 0
self.random = random
self.sample_rate = sample_rate
def forward(self) -> Tensor:
if self.random:
file = random.choice(self.rir_files)
else:
i = self.idx % self.n_files
file = self.rir_files[i]
self.idx += 1
if self.idx >= self.n_files:
self.idx = 0
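        # load the impulse response, time-reverse it so that conv1d performs a
        # true convolution, and L2-normalize it to unit energy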
rir, _ = paddleaudio.load(file, sr=self.sample_rate, mono=True)
rir_weight = paddle.to_tensor(rir[None, None, ::-1])
rir_weight = paddle.nn.functional.normalize(rir_weight, p=2, axis=-1)
return rir_weight
def __repr__(self):
return (
self.__class__.__name__ +
f'(n_files={self.n_files}, random={self.random}, sample_rate={self.sample_rate})'
)
class Normalize:
    """Normalize spectrogram features using pre-computed per-bin mean/std statistics."""
    def __init__(self, mean_file, eps=1e-5):
        self.eps = eps
        stats = paddle.load(mean_file)
        self.mean = stats['mean'].unsqueeze((0, 2))
        self.std = stats['std'].unsqueeze((0, 2))
    def __call__(self, x):
        assert x.ndim == 3
        return (x - self.mean) / (self.std + self.eps)
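# A minimal sketch of how a stats file compatible with Normalize could be
# produced offline; the feature shape [N, n_mels, T], the variable name
# `all_logmel_batches` and the file name 'mean_std.pd' are hypothetical:
#
#     feats = paddle.concat(all_logmel_batches)            # [N, n_mels, T]
#     stats = {'mean': feats.mean(axis=(0, 2)),
#              'std': feats.std(axis=(0, 2))}
#     paddle.save(stats, 'mean_std.pd')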