From 9ba49968857c3272fe98e94987704ae4a3d40b3c Mon Sep 17 00:00:00 2001
From: KP <109694228@qq.com>
Date: Thu, 30 Dec 2021 14:50:34 +0800
Subject: [PATCH] Add ecapa_tdnn_voxceleb.
---
.../ecapa_tdnn_voxceleb/README.md | 117 ++++++
.../ecapa_tdnn_voxceleb/__init__.py | 0
.../ecapa_tdnn_voxceleb/ecapa_tdnn.py | 392 ++++++++++++++++++
.../ecapa_tdnn_voxceleb/feature.py | 99 +++++
.../ecapa_tdnn_voxceleb/module.py | 93 +++++
.../ecapa_tdnn_voxceleb/requirements.txt | 1 +
6 files changed, 702 insertions(+)
create mode 100644 modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/README.md
create mode 100644 modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/__init__.py
create mode 100644 modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/ecapa_tdnn.py
create mode 100644 modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/feature.py
create mode 100644 modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/module.py
create mode 100644 modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/requirements.txt
diff --git a/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/README.md b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/README.md
new file mode 100644
index 00000000..735c6295
--- /dev/null
+++ b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/README.md
@@ -0,0 +1,117 @@
+# ecapa_tdnn_voxceleb
+
+|Model Name|ecapa_tdnn_voxceleb|
+| :--- | :---: |
+|Category|Speech / Speaker Recognition|
+|Network|ECAPA-TDNN|
+|Dataset|VoxCeleb|
+|Fine-tuning supported|No|
+|Model Size|79MB|
+|Last Updated|2021-12-30|
+|Metrics|EER 0.69%|
+
+## I. Basic Information
+
+### Module Introduction
+
+ecapa_tdnn_voxceleb adopts the [ECAPA-TDNN](https://arxiv.org/abs/2005.07143) architecture and is pre-trained on the [VoxCeleb](http://www.robots.ox.ac.uk/~vgg/data/voxceleb/) dataset. On the VoxCeleb1 speaker verification test set ([veri_test.txt](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test.txt)) it achieves an EER of 0.69%, state-of-the-art on this dataset (a sketch of how EER is computed follows the references below).
+
+For more details, please refer to:
+- [VoxCeleb: a large-scale speaker identification dataset](https://www.robots.ox.ac.uk/~vgg/publications/2017/Nagrani17/nagrani17.pdf)
+- [ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification](https://arxiv.org/pdf/2005.07143.pdf)
+- [The SpeechBrain Toolkit](https://github.com/speechbrain/speechbrain)
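+
+As a quick reference, EER (Equal Error Rate) is the operating point at which the false acceptance rate equals the false rejection rate. Below is a minimal, illustrative numpy sketch of computing EER from verification scores; the helper name and the threshold sweep are our own and not part of this module:
+
+```python
+import numpy as np
+
+def compute_eer(scores: np.ndarray, labels: np.ndarray) -> float:
+    """Sweep every observed score as a threshold and return the rate where FAR == FRR."""
+    labels = labels.astype(bool)
+    thresholds = np.sort(np.unique(scores))
+    # FAR: fraction of impostor pairs accepted; FRR: fraction of genuine pairs rejected.
+    fars = np.array([(scores[~labels] >= t).mean() for t in thresholds])
+    frrs = np.array([(scores[labels] < t).mean() for t in thresholds])
+    idx = int(np.argmin(np.abs(fars - frrs)))
+    return float((fars[idx] + frrs[idx]) / 2)
+```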
+
+
+## II. Installation
+
+- ### 1. Environment Dependencies
+
+ - paddlepaddle >= 2.2.0
+
+  - paddlehub >= 2.2.0 | [How to install PaddleHub](../../../../docs/docs_ch/get_start/installation.rst)
+
+- ### 2. Installation
+
+ - ```shell
+ $ hub install ecapa_tdnn_voxceleb
+ ```
+  - If you encounter problems during installation, please refer to: [Windows installation guide](../../../../docs/docs_ch/get_start/windows_quickstart.md)
+    | [Linux installation guide](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [macOS installation guide](../../../../docs/docs_ch/get_start/mac_quickstart.md)
+
+
+## III. Module API Prediction
+
+- ### 1. Prediction Code Example
+
+ ```python
+ import paddlehub as hub
+
+ model = hub.Module(
+ name='ecapa_tdnn_voxceleb',
+ threshold=0.25,
+ version='1.0.0')
+
+    # Sample audio files can be downloaded from the links below:
+ # https://paddlehub.bj.bcebos.com/hub_dev/sv1.wav
+ # https://paddlehub.bj.bcebos.com/hub_dev/sv2.wav
+
+ # Speaker Embedding
+ embedding = model.speaker_embedding('sv1.wav')
+ print(embedding.shape)
+ # (192,)
+
+ # Speaker Verification
+ score, pred = model.speaker_verify('sv1.wav', 'sv2.wav')
+ print(score, pred)
+ # [0.16354457], [False]
+ ```
+
+- ### 2. API
+ - ```python
+ def speaker_embedding(
+ wav: os.PathLike,
+ )
+ ```
+    - Extract the speaker embedding of the input audio.
+
+      - **Parameters**
+
+        - `wav`: Audio file of the speaker, in `*.wav` format.
+
+      - **Returns**
+
+        - A speaker embedding vector of shape (192,).
+
+ - ```python
+ def speaker_verify(
+ wav1: os.PathLike,
+ wav2: os.PathLike,
+ )
+ ```
+    - Compare two audio clips: compute the similarity score of their speaker embeddings and predict whether they belong to the same speaker.
+
+      - **Parameters**
+
+        - `wav1`: Audio file of speaker 1, in `*.wav` format.
+        - `wav2`: Audio file of speaker 2, in `*.wav` format.
+
+      - **Returns**
+
+        - Returns the speaker similarity score in [-1, 1] and the same-speaker prediction (see the sketch below).
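+
+  - Internally the prediction is a plain cosine-similarity threshold over the two embeddings (the module uses `paddle.nn.CosineSimilarity`). A minimal, illustrative numpy sketch of the same decision:
+
+    ```python
+    import numpy as np
+    import paddlehub as hub
+
+    model = hub.Module(name='ecapa_tdnn_voxceleb', threshold=0.25, version='1.0.0')
+    emb1 = model.speaker_embedding('sv1.wav')  # shape (192,)
+    emb2 = model.speaker_embedding('sv2.wav')
+
+    # Cosine similarity in [-1, 1]; higher means more likely the same speaker.
+    score = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2) + 1e-6)
+    pred = score > 0.25  # the module's default threshold
+    ```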
+
+
+## IV. Release Note
+
+* 1.0.0
+
+  First release
+
+ ```shell
+ $ hub install ecapa_tdnn_voxceleb
+ ```
diff --git a/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/__init__.py b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/ecapa_tdnn.py b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/ecapa_tdnn.py
new file mode 100644
index 00000000..59950860
--- /dev/null
+++ b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/ecapa_tdnn.py
@@ -0,0 +1,392 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+import os
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+def length_to_mask(length, max_len=None, dtype=None):
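+    # Convert a 1D tensor of sequence lengths into a binary mask of shape
+    # (len(length), max_len); positions before each length are marked valid.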
+ assert len(length.shape) == 1
+
+ if max_len is None:
+        max_len = length.max().astype('int').item()
+    # Use arange to generate the mask: positions before each length are True
+    mask = paddle.arange(max_len, dtype=length.dtype).expand((len(length), max_len)) < length.unsqueeze(1)
+
+ if dtype is None:
+ dtype = length.dtype
+
+ mask = paddle.to_tensor(mask, dtype=dtype)
+ return mask
+
+
+class Conv1d(nn.Layer):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding="same",
+ dilation=1,
+ groups=1,
+ bias=True,
+ padding_mode="reflect",
+ ):
+ super(Conv1d, self).__init__()
+
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.dilation = dilation
+ self.padding = padding
+ self.padding_mode = padding_mode
+
+ self.conv = nn.Conv1D(
+ in_channels,
+ out_channels,
+ self.kernel_size,
+ stride=self.stride,
+ padding=0,
+ dilation=self.dilation,
+ groups=groups,
+ bias_attr=bias,
+ )
+
+ def forward(self, x):
+ if self.padding == "same":
+ x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
+ else:
+            raise ValueError(f"Padding must be 'same'. Got {self.padding}")
+
+ return self.conv(x)
+
+ def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
+ L_in = x.shape[-1] # Detecting input shape
+ padding = self._get_padding_elem(L_in, stride, kernel_size, dilation) # Time padding
+ x = F.pad(x, padding, mode=self.padding_mode, data_format="NCL") # Applying padding
+ return x
+
+ def _get_padding_elem(self, L_in: int, stride: int, kernel_size: int, dilation: int):
+ if stride > 1:
+ n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
+ L_out = stride * (n_steps - 1) + kernel_size * dilation
+ padding = [kernel_size // 2, kernel_size // 2]
+ else:
+ L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
+
+ padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
+
+ return padding
+
+
+class BatchNorm1d(nn.Layer):
+ def __init__(
+ self,
+ input_size,
+ eps=1e-05,
+ momentum=0.9,
+ weight_attr=None,
+ bias_attr=None,
+ data_format='NCL',
+ use_global_stats=None,
+ ):
+ super(BatchNorm1d, self).__init__()
+
+ self.norm = nn.BatchNorm1D(
+ input_size,
+ epsilon=eps,
+ momentum=momentum,
+ weight_attr=weight_attr,
+ bias_attr=bias_attr,
+ data_format=data_format,
+ use_global_stats=use_global_stats,
+ )
+
+ def forward(self, x):
+ x_n = self.norm(x)
+ return x_n
+
+
+class TDNNBlock(nn.Layer):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ dilation,
+ activation=nn.ReLU,
+ ):
+ super(TDNNBlock, self).__init__()
+ self.conv = Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ dilation=dilation,
+ )
+ self.activation = activation()
+ self.norm = BatchNorm1d(input_size=out_channels)
+
+ def forward(self, x):
+ return self.norm(self.activation(self.conv(x)))
+
+
+class Res2NetBlock(nn.Layer):
+ def __init__(self, in_channels, out_channels, scale=8, dilation=1):
+ super(Res2NetBlock, self).__init__()
+ assert in_channels % scale == 0
+ assert out_channels % scale == 0
+
+ in_channel = in_channels // scale
+ hidden_channel = out_channels // scale
+
+ self.blocks = nn.LayerList(
+ [TDNNBlock(in_channel, hidden_channel, kernel_size=3, dilation=dilation) for i in range(scale - 1)])
+ self.scale = scale
+
+ def forward(self, x):
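+        # Res2Net hierarchical connections: the input is chunked into `scale`
+        # channel groups; the first passes through unchanged, the second is
+        # convolved directly, and each subsequent group is summed with the
+        # previous output before its TDNN block.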
+ y = []
+ for i, x_i in enumerate(paddle.chunk(x, self.scale, axis=1)):
+ if i == 0:
+ y_i = x_i
+ elif i == 1:
+ y_i = self.blocks[i - 1](x_i)
+ else:
+ y_i = self.blocks[i - 1](x_i + y_i)
+ y.append(y_i)
+ y = paddle.concat(y, axis=1)
+ return y
+
+
+class SEBlock(nn.Layer):
+ def __init__(self, in_channels, se_channels, out_channels):
+ super(SEBlock, self).__init__()
+
+ self.conv1 = Conv1d(in_channels=in_channels, out_channels=se_channels, kernel_size=1)
+ self.relu = paddle.nn.ReLU()
+ self.conv2 = Conv1d(in_channels=se_channels, out_channels=out_channels, kernel_size=1)
+ self.sigmoid = paddle.nn.Sigmoid()
+
+ def forward(self, x, lengths=None):
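+        # Squeeze-and-Excitation: average channel statistics over time (with
+        # padding masked out when `lengths` is given), then rescale the input
+        # channels by the resulting sigmoid gates.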
+ L = x.shape[-1]
+ if lengths is not None:
+ mask = length_to_mask(lengths * L, max_len=L)
+ mask = mask.unsqueeze(1)
+ total = mask.sum(axis=2, keepdim=True)
+ s = (x * mask).sum(axis=2, keepdim=True) / total
+ else:
+ s = x.mean(axis=2, keepdim=True)
+
+ s = self.relu(self.conv1(s))
+ s = self.sigmoid(self.conv2(s))
+
+ return s * x
+
+
+class AttentiveStatisticsPooling(nn.Layer):
+ def __init__(self, channels, attention_channels=128, global_context=True):
+ super().__init__()
+
+ self.eps = 1e-12
+ self.global_context = global_context
+ if global_context:
+ self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
+ else:
+ self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
+ self.tanh = nn.Tanh()
+ self.conv = Conv1d(in_channels=attention_channels, out_channels=channels, kernel_size=1)
+
+ def forward(self, x, lengths=None):
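+        # Attentive statistics pooling: a small TDNN scores each frame, the
+        # scores are softmax-normalized over time, and an attention-weighted
+        # mean and std are concatenated into a single utterance-level vector.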
+        C, L = x.shape[1], x.shape[2]  # x: (N, C, L)
+
+ def _compute_statistics(x, m, axis=2, eps=self.eps):
+ mean = (m * x).sum(axis)
+ std = paddle.sqrt((m * (x - mean.unsqueeze(axis)).pow(2)).sum(axis).clip(eps))
+ return mean, std
+
+ if lengths is None:
+ lengths = paddle.ones([x.shape[0]])
+
+ # Make binary mask of shape [N, 1, L]
+ mask = length_to_mask(lengths * L, max_len=L)
+ mask = mask.unsqueeze(1)
+
+ # Expand the temporal context of the pooling layer by allowing the
+ # self-attention to look at global properties of the utterance.
+ if self.global_context:
+ total = mask.sum(axis=2, keepdim=True).astype('float32')
+ mean, std = _compute_statistics(x, mask / total)
+ mean = mean.unsqueeze(2).tile((1, 1, L))
+ std = std.unsqueeze(2).tile((1, 1, L))
+ attn = paddle.concat([x, mean, std], axis=1)
+ else:
+ attn = x
+
+ # Apply layers
+ attn = self.conv(self.tanh(self.tdnn(attn)))
+
+ # Filter out zero-paddings
+ attn = paddle.where(mask.tile((1, C, 1)) == 0, paddle.ones_like(attn) * float("-inf"), attn)
+
+ attn = F.softmax(attn, axis=2)
+ mean, std = _compute_statistics(x, attn)
+
+ # Append mean and std of the batch
+ pooled_stats = paddle.concat((mean, std), axis=1)
+ pooled_stats = pooled_stats.unsqueeze(2)
+
+ return pooled_stats
+
+
+class SERes2NetBlock(nn.Layer):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ res2net_scale=8,
+ se_channels=128,
+ kernel_size=1,
+ dilation=1,
+ activation=nn.ReLU,
+ ):
+ super(SERes2NetBlock, self).__init__()
+ self.out_channels = out_channels
+ self.tdnn1 = TDNNBlock(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ dilation=1,
+ activation=activation,
+ )
+ self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, dilation)
+ self.tdnn2 = TDNNBlock(
+ out_channels,
+ out_channels,
+ kernel_size=1,
+ dilation=1,
+ activation=activation,
+ )
+ self.se_block = SEBlock(out_channels, se_channels, out_channels)
+
+ self.shortcut = None
+ if in_channels != out_channels:
+ self.shortcut = Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x, lengths=None):
+ residual = x
+ if self.shortcut:
+ residual = self.shortcut(x)
+
+ x = self.tdnn1(x)
+ x = self.res2net_block(x)
+ x = self.tdnn2(x)
+ x = self.se_block(x, lengths)
+
+ return x + residual
+
+
+class ECAPA_TDNN(nn.Layer):
+ def __init__(
+ self,
+ input_size,
+ lin_neurons=192,
+ activation=nn.ReLU,
+ channels=[512, 512, 512, 512, 1536],
+ kernel_sizes=[5, 3, 3, 3, 1],
+ dilations=[1, 2, 3, 4, 1],
+ attention_channels=128,
+ res2net_scale=8,
+ se_channels=128,
+ global_context=True,
+ ):
+
+ super(ECAPA_TDNN, self).__init__()
+ assert len(channels) == len(kernel_sizes)
+ assert len(channels) == len(dilations)
+ self.channels = channels
+ self.blocks = nn.LayerList()
+ self.emb_size = lin_neurons
+
+ # The initial TDNN layer
+ self.blocks.append(TDNNBlock(
+ input_size,
+ channels[0],
+ kernel_sizes[0],
+ dilations[0],
+ activation,
+ ))
+
+ # SE-Res2Net layers
+ for i in range(1, len(channels) - 1):
+ self.blocks.append(
+ SERes2NetBlock(
+ channels[i - 1],
+ channels[i],
+ res2net_scale=res2net_scale,
+ se_channels=se_channels,
+ kernel_size=kernel_sizes[i],
+ dilation=dilations[i],
+ activation=activation,
+ ))
+
+ # Multi-layer feature aggregation
+ self.mfa = TDNNBlock(
+ channels[-1],
+ channels[-1],
+ kernel_sizes[-1],
+ dilations[-1],
+ activation,
+ )
+
+ # Attentive Statistical Pooling
+ self.asp = AttentiveStatisticsPooling(
+ channels[-1],
+ attention_channels=attention_channels,
+ global_context=global_context,
+ )
+ self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)
+
+ # Final linear transformation
+ self.fc = Conv1d(
+ in_channels=channels[-1] * 2,
+ out_channels=self.emb_size,
+ kernel_size=1,
+ )
+
+ def forward(self, x, lengths=None):
+ xl = []
+ for layer in self.blocks:
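+            # SERes2NetBlock consumes the `lengths` kwarg; the plain TDNN
+            # blocks do not, hence the TypeError fallback below.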
+ try:
+ x = layer(x, lengths=lengths)
+ except TypeError:
+ x = layer(x)
+ xl.append(x)
+
+ # Multi-layer feature aggregation
+ x = paddle.concat(xl[1:], axis=1)
+ x = self.mfa(x)
+
+ # Attentive Statistical Pooling
+ x = self.asp(x, lengths=lengths)
+ x = self.asp_bn(x)
+
+ # Final linear transformation
+ x = self.fc(x)
+
+ return x
diff --git a/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/feature.py b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/feature.py
new file mode 100644
index 00000000..5feee8f2
--- /dev/null
+++ b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/feature.py
@@ -0,0 +1,99 @@
+import paddle
+import paddleaudio
+from paddleaudio.features.spectrum import hz_to_mel
+from paddleaudio.features.spectrum import mel_to_hz
+from paddleaudio.features.spectrum import power_to_db
+from paddleaudio.features.spectrum import Spectrogram
+from paddleaudio.features.window import get_window
+
+
+def compute_fbank_matrix(sample_rate: int = 16000,
+ n_fft: int = 400,
+ n_mels: int = 80,
+                          f_min: float = 0.0,
+                          f_max: float = 8000.0):
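+    # Build a bank of `n_mels` triangular filters whose peaks are evenly
+    # spaced on the (HTK) mel scale between f_min and f_max.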
+ mel = paddle.linspace(hz_to_mel(f_min, htk=True), hz_to_mel(f_max, htk=True), n_mels + 2, dtype=paddle.float32)
+ hz = mel_to_hz(mel, htk=True)
+
+ band = hz[1:] - hz[:-1]
+ band = band[:-1]
+ f_central = hz[1:-1]
+
+ n_stft = n_fft // 2 + 1
+ all_freqs = paddle.linspace(0, sample_rate // 2, n_stft)
+ all_freqs_mat = all_freqs.tile([f_central.shape[0], 1])
+
+ f_central_mat = f_central.tile([all_freqs_mat.shape[1], 1]).transpose([1, 0])
+ band_mat = band.tile([all_freqs_mat.shape[1], 1]).transpose([1, 0])
+
+ slope = (all_freqs_mat - f_central_mat) / band_mat
+ left_side = slope + 1.0
+ right_side = -slope + 1.0
+
+ fbank_matrix = paddle.maximum(paddle.zeros_like(left_side), paddle.minimum(left_side, right_side))
+
+ return fbank_matrix
+
+
+def compute_log_fbank(
+ x: paddle.Tensor,
+ sample_rate: int = 16000,
+ n_fft: int = 400,
+ hop_length: int = 160,
+ win_length: int = 400,
+ n_mels: int = 80,
+ window: str = 'hamming',
+ center: bool = True,
+ pad_mode: str = 'constant',
+ f_min: float = 0.0,
+ f_max: float = None,
+ top_db: float = 80.0,
+):
+
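+    # x: waveform tensor of shape (B, T). Returns log-mel filterbank
+    # features of shape (B, num_frames, n_mels).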
+ if f_max is None:
+ f_max = sample_rate / 2
+
+ spect = Spectrogram(
+ n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, center=center, pad_mode=pad_mode)(x)
+
+ fbank_matrix = compute_fbank_matrix(
+ sample_rate=sample_rate,
+ n_fft=n_fft,
+ n_mels=n_mels,
+ f_min=f_min,
+ f_max=f_max,
+ )
+ fbank = paddle.matmul(fbank_matrix, spect)
+ log_fbank = power_to_db(fbank, top_db=top_db).transpose([0, 2, 1])
+ return log_fbank
+
+
+def compute_stats(x: paddle.Tensor, mean_norm: bool = True, std_norm: bool = False, eps: float = 1e-10):
+ if mean_norm:
+ current_mean = paddle.mean(x, axis=0)
+ else:
+ current_mean = paddle.to_tensor([0.0])
+
+ if std_norm:
+ current_std = paddle.std(x, axis=0)
+ else:
+ current_std = paddle.to_tensor([1.0])
+
+ current_std = paddle.maximum(current_std, eps * paddle.ones_like(current_std))
+
+ return current_mean, current_std
+
+
+def normalize(
+ x: paddle.Tensor,
+ global_mean: paddle.Tensor = None,
+ global_std: paddle.Tensor = None,
+):
+
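+    # With no global statistics, apply per-utterance mean normalization
+    # (std defaults to 1); otherwise normalize each utterance with the
+    # provided global mean/std.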
+ for i in range(x.shape[0]): # (B, ...)
+ if global_mean is None and global_std is None:
+ mean, std = compute_stats(x[i])
+ x[i] = (x[i] - mean) / std
+ else:
+ x[i] = (x[i] - global_mean) / global_std
+ return x
diff --git a/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/module.py b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/module.py
new file mode 100644
index 00000000..11f7121a
--- /dev/null
+++ b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/module.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import re
+from typing import List
+from typing import Union
+
+import numpy as np
+import paddle
+import paddleaudio
+
+from .ecapa_tdnn import ECAPA_TDNN
+from .feature import compute_log_fbank
+from .feature import normalize
+from paddlehub.module.module import moduleinfo
+from paddlehub.utils.log import logger
+
+
+@moduleinfo(
+ name="ecapa_tdnn_voxceleb",
+ version="1.0.0",
+ summary="",
+ author="paddlepaddle",
+ author_email="",
+ type="audio/speaker_recognition")
+class SpeakerRecognition(paddle.nn.Layer):
+ def __init__(self, threshold=0.25):
+ super(SpeakerRecognition, self).__init__()
+ global_stats_path = os.path.join(self.directory, 'assets', 'global_embedding_stats.npy')
+ ckpt_path = os.path.join(self.directory, 'assets', 'model.pdparams')
+
+ self.sr = 16000
+ self.threshold = threshold
+ model_conf = {
+ 'input_size': 80,
+ 'channels': [1024, 1024, 1024, 1024, 3072],
+ 'kernel_sizes': [5, 3, 3, 3, 1],
+ 'dilations': [1, 2, 3, 4, 1],
+ 'attention_channels': 128,
+ 'lin_neurons': 192
+ }
+ self.model = ECAPA_TDNN(**model_conf)
+ self.model.set_state_dict(paddle.load(ckpt_path))
+ self.model.eval()
+
+ global_embedding_stats = np.load(global_stats_path, allow_pickle=True)
+ self.global_emb_mean = paddle.to_tensor(global_embedding_stats.item().get('global_emb_mean'))
+ self.global_emb_std = paddle.to_tensor(global_embedding_stats.item().get('global_emb_std'))
+
+ self.similarity = paddle.nn.CosineSimilarity(axis=-1, eps=1e-6)
+
+ def load_audio(self, wav):
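+        # Load the audio file and resample it to 16 kHz mono.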
+ wav = os.path.abspath(os.path.expanduser(wav))
+ assert os.path.isfile(wav), 'Please check wav file: {}'.format(wav)
+ waveform, _ = paddleaudio.load(wav, sr=self.sr, mono=True, normal=False)
+ return waveform
+
+ def speaker_embedding(self, wav):
+ waveform = self.load_audio(wav)
+ embedding = self(paddle.to_tensor(waveform)).reshape([-1])
+ return embedding.numpy()
+
+ def speaker_verify(self, wav1, wav2):
+ waveform1 = self.load_audio(wav1)
+ embedding1 = self(paddle.to_tensor(waveform1)).reshape([-1])
+
+ waveform2 = self.load_audio(wav2)
+ embedding2 = self(paddle.to_tensor(waveform2)).reshape([-1])
+
+ score = self.similarity(embedding1, embedding2).numpy()
+ return score, score > self.threshold
+
+ def forward(self, x):
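+        # Pipeline: waveform (B, T) -> log-mel fbank -> per-utterance mean
+        # normalization -> ECAPA-TDNN -> embedding normalized with the
+        # pre-computed global statistics.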
+ if len(x.shape) == 1:
+ x = x.unsqueeze(0)
+
+ fbank = compute_log_fbank(x) # x: waveform tensors with (B, T) shape
+ norm_fbank = normalize(fbank)
+ embedding = self.model(norm_fbank.transpose([0, 2, 1])).transpose([0, 2, 1])
+ norm_embedding = normalize(x=embedding, global_mean=self.global_emb_mean, global_std=self.global_emb_std)
+
+ return norm_embedding
diff --git a/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/requirements.txt b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/requirements.txt
new file mode 100644
index 00000000..defe617f
--- /dev/null
+++ b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/requirements.txt
@@ -0,0 +1 @@
+paddleaudio==0.1.0
--
GitLab