diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/README.md b/modules/audio/keyword_spotting/kwmlp_speech_commands/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3e3357a09341435e312e2c12314e1e85b30cff53
--- /dev/null
+++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/README.md
@@ -0,0 +1,98 @@
+# kwmlp_speech_commands
+
+|Model name|kwmlp_speech_commands|
+| :--- | :---: |
+|Category|Audio/Keyword Spotting|
+|Network|Keyword-MLP|
+|Dataset|Google Speech Commands V2|
+|Fine-tuning supported|No|
+|Model size|1.6 MB|
+|Last updated|2022-01-04|
+|Metric|ACC 97.56%|
+
+## I. Basic Information
+
+### Introduction
+
+kwmlp_speech_commands uses the lightweight [Keyword-MLP](https://arxiv.org/pdf/2110.07749v1.pdf) architecture and is pre-trained on the [Google Speech Commands V2](https://arxiv.org/abs/1804.03209) dataset, reaching ACC 97.56% on its test set.
+
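+Note: the module scores fixed one-second windows of 16 kHz mono audio; `keyword_recognize` truncates longer input and zero-pads shorter input before prediction. A sketch of this preprocessing, mirroring what `module.py` does internally:
+
+```python
+import numpy as np
+
+def fix_length(waveform: np.ndarray, sr: int = 16000) -> np.ndarray:
+    """Clip or zero-pad a 1-D waveform to exactly one second."""
+    if len(waveform) > sr:
+        return waveform[:sr]
+    return np.pad(waveform, (0, sr - len(waveform)))
+```
+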
+For more details, please refer to:
+- [Speech Commands: A Dataset for Limited-Vocabulary Speech Recognition](https://arxiv.org/abs/1804.03209)
+- [ATTENTION-FREE KEYWORD SPOTTING](https://arxiv.org/pdf/2110.07749v1.pdf)
+- [Keyword-MLP](https://github.com/AI-Research-BD/Keyword-MLP)
+
+
+## II. Installation
+
+- ### 1. Environment Dependencies
+
+  - paddlepaddle >= 2.2.0
+
+  - paddlehub >= 2.2.0 | [How to install PaddleHub](../../../../docs/docs_ch/get_start/installation.rst)
+
+- ### 2. Installation
+
+  - ```shell
+    $ hub install kwmlp_speech_commands
+    ```
+  - If you encounter problems during installation, please refer to: [Windows installation guide](../../../../docs/docs_ch/get_start/windows_quickstart.md)
+    | [Linux installation guide](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [macOS installation guide](../../../../docs/docs_ch/get_start/mac_quickstart.md)
+
+
+## III. Module API Prediction
+
+- ### 1. Prediction Code Example
+
+    ```python
+    import paddlehub as hub
+
+    model = hub.Module(
+        name='kwmlp_speech_commands',
+        version='1.0.0')
+
+    # A sample audio clip can be downloaded from the link below
+    # (no.wav and one.wav are assumed to be available under the same path)
+    # https://paddlehub.bj.bcebos.com/paddlehub_dev/go.wav
+
+    # Keyword spotting
+    score, label = model.keyword_recognize('no.wav')
+    print(score, label)
+    # [0.89498246] no
+    score, label = model.keyword_recognize('go.wav')
+    print(score, label)
+    # [0.8997176] go
+    score, label = model.keyword_recognize('one.wav')
+    print(score, label)
+    # [0.88598305] one
+    ```
+
+- ### 2. API
+  - ```python
+    def keyword_recognize(
+        wav: os.PathLike,
+    )
+    ```
+    - Detect the keyword in the input audio.
+
+    - **Parameters**
+
+      - `wav`: input audio file containing a keyword, in `*.wav` format.
+
+    - **Returns**
+
+      - The prediction score and the corresponding keyword label.
+
+
+## IV. Release Note
+
+* 1.0.0
+
+  First release
+
+  ```shell
+  $ hub install kwmlp_speech_commands
+  ```
diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/__init__.py b/modules/audio/keyword_spotting/kwmlp_speech_commands/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..185a92b8d94d3426d616c0624f0f2ee04339349e
--- /dev/null
+++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/feature.py b/modules/audio/keyword_spotting/kwmlp_speech_commands/feature.py
new file mode 100644
index 0000000000000000000000000000000000000000..900a2eab26e4414b487d6d7858381ee302a107e8
--- /dev/null
+++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/feature.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
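+# Usage sketch for the helpers below (a minimal example, assuming the
+# 1-second, 16 kHz mono clips that module.py feeds in):
+#
+#     import paddle
+#     x = paddle.randn([1, 16000])   # (B, T) waveform batch
+#     mfcc = compute_mfcc(x)         # -> [1, 40, 98] == (B, n_mels, frames)
+#
+# With n_fft=480, hop_length=160 and center=False, a 16000-sample input yields
+# 1 + (16000 - 480) // 160 = 98 frames, which matches KW_MLP's input_res.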
+import math
+
+import numpy as np
+import paddle
+import paddleaudio
+
+
+def create_dct(n_mfcc: int, n_mels: int, norm: str = 'ortho'):
+    n = paddle.arange(float(n_mels))
+    k = paddle.arange(float(n_mfcc)).unsqueeze(1)
+    dct = paddle.cos(math.pi / float(n_mels) * (n + 0.5) * k)  # size (n_mfcc, n_mels)
+    if norm is None:
+        dct *= 2.0
+    else:
+        assert norm == "ortho"
+        dct[0] *= 1.0 / math.sqrt(2.0)
+        dct *= math.sqrt(2.0 / float(n_mels))
+    return dct.t()
+
+
+def compute_mfcc(
+        x: paddle.Tensor,
+        sr: int = 16000,
+        n_mels: int = 40,
+        n_fft: int = 480,
+        win_length: int = 480,
+        hop_length: int = 160,
+        f_min: float = 0.0,
+        f_max: float = None,
+        center: bool = False,
+        top_db: float = 80.0,
+        norm: str = 'ortho',
+):
+    fbank = paddleaudio.features.spectrum.MelSpectrogram(
+        sr=sr,
+        n_mels=n_mels,
+        n_fft=n_fft,
+        win_length=win_length,
+        hop_length=hop_length,
+        f_min=f_min,
+        f_max=f_max,
+        center=center)(x)  # waveforms batch ~ (B, T)
+    log_fbank = paddleaudio.features.spectrum.power_to_db(fbank, top_db=top_db)
+    dct_matrix = create_dct(n_mfcc=n_mels, n_mels=n_mels, norm=norm)
+    mfcc = paddle.matmul(log_fbank.transpose((0, 2, 1)), dct_matrix).transpose((0, 2, 1))  # (B, n_mels, L)
+    return mfcc
diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/kwmlp.py b/modules/audio/keyword_spotting/kwmlp_speech_commands/kwmlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..df8c37e6fb14d1f5c43b0080410d3288406cfa77
--- /dev/null
+++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/kwmlp.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
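+# Shape contract for KW_MLP below, as configured by module.py: the model takes
+# MFCC "patches" of shape (B, 1, 40, 98) and returns logits over the 35 Speech
+# Commands classes. A quick smoke test, as a sketch:
+#
+#     import paddle
+#     model = KW_MLP()                              # default config
+#     logits = model(paddle.randn([1, 1, 40, 98]))  # -> shape [1, 35]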
+import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class Residual(nn.Layer): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x): + return self.fn(x) + x + + +class PreNorm(nn.Layer): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = nn.LayerNorm(dim) + + def forward(self, x, **kwargs): + x = self.norm(x) + return self.fn(x, **kwargs) + + +class PostNorm(nn.Layer): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x, **kwargs): + return self.norm(self.fn(x, **kwargs)) + + +class SpatialGatingUnit(nn.Layer): + def __init__(self, dim, dim_seq, act=nn.Identity(), init_eps=1e-3): + super().__init__() + dim_out = dim // 2 + + self.norm = nn.LayerNorm(dim_out) + self.proj = nn.Conv1D(dim_seq, dim_seq, 1) + + self.act = act + + init_eps /= dim_seq + + def forward(self, x): + res, gate = x.split(2, axis=-1) + gate = self.norm(gate) + + weight, bias = self.proj.weight, self.proj.bias + gate = F.conv1d(gate, weight, bias) + + return self.act(gate) * res + + +class gMLPBlock(nn.Layer): + def __init__(self, *, dim, dim_ff, seq_len, act=nn.Identity()): + super().__init__() + self.proj_in = nn.Sequential(nn.Linear(dim, dim_ff), nn.GELU()) + + self.sgu = SpatialGatingUnit(dim_ff, seq_len, act) + self.proj_out = nn.Linear(dim_ff // 2, dim) + + def forward(self, x): + x = self.proj_in(x) + x = self.sgu(x) + x = self.proj_out(x) + return x + + +class Rearrange(nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, x): + x = x.transpose([0, 1, 3, 2]).squeeze(1) + return x + + +class Reduce(nn.Layer): + def __init__(self, axis=1): + super().__init__() + self.axis = axis + + def forward(self, x): + x = x.mean(axis=self.axis, keepdim=False) + return x + + +class KW_MLP(nn.Layer): + """Keyword-MLP.""" + + def __init__(self, + input_res=[40, 98], + patch_res=[40, 1], + num_classes=35, + dim=64, + depth=12, + ff_mult=4, + channels=1, + prob_survival=0.9, + pre_norm=False, + **kwargs): + super().__init__() + image_height, image_width = input_res + patch_height, patch_width = patch_res + assert (image_height % patch_height) == 0 and ( + image_width % patch_width) == 0, 'image height and width must be divisible by patch size' + num_patches = (image_height // patch_height) * (image_width // patch_width) + + P_Norm = PreNorm if pre_norm else PostNorm + + dim_ff = dim * ff_mult + + self.to_patch_embed = nn.Sequential(Rearrange(), nn.Linear(channels * patch_height * patch_width, dim)) + + self.prob_survival = prob_survival + + self.layers = nn.LayerList( + [Residual(P_Norm(dim, gMLPBlock(dim=dim, dim_ff=dim_ff, seq_len=num_patches))) for i in range(depth)]) + + self.to_logits = nn.Sequential(nn.LayerNorm(dim), Reduce(axis=1), nn.Linear(dim, num_classes)) + + def forward(self, x): + x = self.to_patch_embed(x) + layers = self.layers + x = nn.Sequential(*layers)(x) + return self.to_logits(x) diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/module.py b/modules/audio/keyword_spotting/kwmlp_speech_commands/module.py new file mode 100644 index 0000000000000000000000000000000000000000..34342de360f2927236429baaa41789993038bd5a --- /dev/null +++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/module.py @@ -0,0 +1,86 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import numpy as np
+import paddle
+import paddleaudio
+
+from .feature import compute_mfcc
+from .kwmlp import KW_MLP
+from paddlehub.module.module import moduleinfo
+from paddlehub.utils.log import logger
+
+
+@moduleinfo(
+    name="kwmlp_speech_commands",
+    version="1.0.0",
+    summary="",
+    author="paddlepaddle",
+    author_email="",
+    type="audio/keyword_spotting")
+class KWS(paddle.nn.Layer):
+    def __init__(self):
+        super(KWS, self).__init__()
+        ckpt_path = os.path.join(self.directory, 'assets', 'model.pdparams')
+        label_path = os.path.join(self.directory, 'assets', 'label.txt')
+
+        self.label_list = []
+        with open(label_path, 'r') as f:
+            for l in f:
+                self.label_list.append(l.strip())
+
+        self.sr = 16000
+        model_conf = {
+            'input_res': [40, 98],
+            'patch_res': [40, 1],
+            'num_classes': 35,
+            'channels': 1,
+            'dim': 64,
+            'depth': 12,
+            'pre_norm': False,
+            'prob_survival': 0.9,
+        }
+        self.model = KW_MLP(**model_conf)
+        self.model.set_state_dict(paddle.load(ckpt_path))
+        self.model.eval()
+
+    def load_audio(self, wav):
+        wav = os.path.abspath(os.path.expanduser(wav))
+        assert os.path.isfile(wav), 'Please check wav file: {}'.format(wav)
+        waveform, _ = paddleaudio.load(wav, sr=self.sr, mono=True, normal=False)
+        return waveform
+
+    def keyword_recognize(self, wav):
+        waveform = self.load_audio(wav)
+
+        # fix length to 1 second: clip long input, zero-pad short input
+        if len(waveform) > self.sr:
+            waveform = waveform[:self.sr]
+        else:
+            waveform = np.pad(waveform, (0, self.sr - len(waveform)))
+
+        logits = self(paddle.to_tensor(waveform)).reshape([-1])
+        probs = paddle.nn.functional.softmax(logits)
+        idx = paddle.argmax(probs)
+        return probs[idx].numpy(), self.label_list[idx]
+
+    def forward(self, x):
+        if len(x.shape) == 1:  # promote a single (T,) waveform to a (B, T) batch
+            x = x.unsqueeze(0)
+
+        mfcc = compute_mfcc(x).unsqueeze(1)  # (B, C, n_mels, L)
+        logits = self.model(mfcc).squeeze(1)
+
+        return logits
diff --git a/modules/audio/keyword_spotting/kwmlp_speech_commands/requirements.txt b/modules/audio/keyword_spotting/kwmlp_speech_commands/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..defe617fa36bc5ab7b72438034c785ee2b3ac3c9
--- /dev/null
+++ b/modules/audio/keyword_spotting/kwmlp_speech_commands/requirements.txt
@@ -0,0 +1 @@
+paddleaudio==0.1.0
diff --git a/modules/audio/language_identification/ecapa_tdnn_common_language/README.md b/modules/audio/language_identification/ecapa_tdnn_common_language/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f648202e4c97c1a1707bb1f0a0d98949735f047d
--- /dev/null
+++ b/modules/audio/language_identification/ecapa_tdnn_common_language/README.md
@@ -0,0 +1,100 @@
+# ecapa_tdnn_common_language
+
+|Model name|ecapa_tdnn_common_language|
+| :--- | :---: |
+|Category|Audio/Language Identification|
+|Network|ECAPA-TDNN|
+|Dataset|CommonLanguage|
+|Fine-tuning supported|No|
+|Model size|79 MB|
+|Last updated|2021-12-30|
+|Metric|ACC 84.9%|
+
+## I. Basic Information
+
+### Introduction
+
+ecapa_tdnn_common_language uses the [ECAPA-TDNN](https://arxiv.org/abs/2005.07143) architecture and is pre-trained on the [CommonLanguage](https://zenodo.org/record/5036977/) dataset, reaching ACC 84.9% on its test set.
+
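+Note: the score returned by `language_identify` is the highest cosine-similarity
+logit between the utterance embedding and the class weights (both are
+L2-normalized inside the `Classifier` head in `ecapa_tdnn.py`), so it lies in
+[-1, 1] and is not a calibrated probability. A sketch of ranking the top-k
+candidate languages (`top_k_languages` is a hypothetical helper, not part of
+the module API):
+
+```python
+import paddle
+import paddlehub as hub
+
+def top_k_languages(model, wav_path, k=3):
+    waveform = model.load_audio(wav_path)
+    logits = model(paddle.to_tensor(waveform)).reshape([-1])
+    scores, idx = paddle.topk(logits, k=k)
+    return [(model.label_list[i], float(s)) for i, s in zip(idx.numpy(), scores.numpy())]
+
+model = hub.Module(name='ecapa_tdnn_common_language', version='1.0.0')
+print(top_k_languages(model, 'zh.wav'))
+```
+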
+For more details, please refer to:
+- [CommonLanguage](https://zenodo.org/record/5036977#.Yc19b5Mzb0o)
+- [ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification](https://arxiv.org/pdf/2005.07143.pdf)
+- [The SpeechBrain Toolkit](https://github.com/speechbrain/speechbrain)
+
+
+## II. Installation
+
+- ### 1. Environment Dependencies
+
+  - paddlepaddle >= 2.2.0
+
+  - paddlehub >= 2.2.0 | [How to install PaddleHub](../../../../docs/docs_ch/get_start/installation.rst)
+
+- ### 2. Installation
+
+  - ```shell
+    $ hub install ecapa_tdnn_common_language
+    ```
+  - If you encounter problems during installation, please refer to: [Windows installation guide](../../../../docs/docs_ch/get_start/windows_quickstart.md)
+    | [Linux installation guide](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [macOS installation guide](../../../../docs/docs_ch/get_start/mac_quickstart.md)
+
+
+## III. Module API Prediction
+
+- ### 1. Prediction Code Example
+
+    ```python
+    import paddlehub as hub
+
+    model = hub.Module(
+        name='ecapa_tdnn_common_language',
+        version='1.0.0')
+
+    # Sample audio clips can be downloaded from the links below
+    # https://paddlehub.bj.bcebos.com/paddlehub_dev/zh.wav
+    # https://paddlehub.bj.bcebos.com/paddlehub_dev/en.wav
+    # https://paddlehub.bj.bcebos.com/paddlehub_dev/it.wav
+
+    # Language Identification
+    score, label = model.language_identify('zh.wav')
+    print(score, label)
+    # array([0.6214552], dtype=float32), 'Chinese_China'
+    score, label = model.language_identify('en.wav')
+    print(score, label)
+    # array([0.37193954], dtype=float32), 'English'
+    score, label = model.language_identify('it.wav')
+    print(score, label)
+    # array([0.46913534], dtype=float32), 'Italian'
+    ```
+
+- ### 2. API
+  - ```python
+    def language_identify(
+        wav: os.PathLike,
+    )
+    ```
+    - Identify the language of the input speech audio.
+
+    - **Parameters**
+
+      - `wav`: input speech audio file, in `*.wav` format.
+
+    - **Returns**
+
+      - The prediction score and the corresponding language label.
+
+
+## IV. Release Note
+
+* 1.0.0
+
+  First release
+
+  ```shell
+  $ hub install ecapa_tdnn_common_language
+  ```
diff --git a/modules/audio/language_identification/ecapa_tdnn_common_language/__init__.py b/modules/audio/language_identification/ecapa_tdnn_common_language/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..185a92b8d94d3426d616c0624f0f2ee04339349e
--- /dev/null
+++ b/modules/audio/language_identification/ecapa_tdnn_common_language/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/modules/audio/language_identification/ecapa_tdnn_common_language/ecapa_tdnn.py b/modules/audio/language_identification/ecapa_tdnn_common_language/ecapa_tdnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..950a9df7dd465abf56b30b5594e9b16adb49e573
--- /dev/null
+++ b/modules/audio/language_identification/ecapa_tdnn_common_language/ecapa_tdnn.py
@@ -0,0 +1,406 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import os + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +def length_to_mask(length, max_len=None, dtype=None): + assert len(length.shape) == 1 + + if max_len is None: + max_len = length.max().astype('int').item() # using arange to generate mask + mask = paddle.arange(max_len, dtype=length.dtype).expand((len(length), max_len)) < length.unsqueeze(1) + + if dtype is None: + dtype = length.dtype + + mask = paddle.to_tensor(mask, dtype=dtype) + return mask + + +class Conv1d(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding="same", + dilation=1, + groups=1, + bias=True, + padding_mode="reflect", + ): + super(Conv1d, self).__init__() + + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.padding = padding + self.padding_mode = padding_mode + + self.conv = nn.Conv1D( + in_channels, + out_channels, + self.kernel_size, + stride=self.stride, + padding=0, + dilation=self.dilation, + groups=groups, + bias_attr=bias, + ) + + def forward(self, x): + if self.padding == "same": + x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride) + else: + raise ValueError("Padding must be 'same'. Got {self.padding}") + + return self.conv(x) + + def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int): + L_in = x.shape[-1] # Detecting input shape + padding = self._get_padding_elem(L_in, stride, kernel_size, dilation) # Time padding + x = F.pad(x, padding, mode=self.padding_mode, data_format="NCL") # Applying padding + return x + + def _get_padding_elem(self, L_in: int, stride: int, kernel_size: int, dilation: int): + if stride > 1: + n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1) + L_out = stride * (n_steps - 1) + kernel_size * dilation + padding = [kernel_size // 2, kernel_size // 2] + else: + L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1 + + padding = [(L_in - L_out) // 2, (L_in - L_out) // 2] + + return padding + + +class BatchNorm1d(nn.Layer): + def __init__( + self, + input_size, + eps=1e-05, + momentum=0.9, + weight_attr=None, + bias_attr=None, + data_format='NCL', + use_global_stats=None, + ): + super(BatchNorm1d, self).__init__() + + self.norm = nn.BatchNorm1D( + input_size, + epsilon=eps, + momentum=momentum, + weight_attr=weight_attr, + bias_attr=bias_attr, + data_format=data_format, + use_global_stats=use_global_stats, + ) + + def forward(self, x): + x_n = self.norm(x) + return x_n + + +class TDNNBlock(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation, + activation=nn.ReLU, + ): + super(TDNNBlock, self).__init__() + self.conv = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + dilation=dilation, + ) + self.activation = activation() + self.norm = BatchNorm1d(input_size=out_channels) + + def forward(self, x): + return self.norm(self.activation(self.conv(x))) + + +class Res2NetBlock(nn.Layer): + def __init__(self, in_channels, out_channels, scale=8, dilation=1): + super(Res2NetBlock, self).__init__() + assert 
in_channels % scale == 0 + assert out_channels % scale == 0 + + in_channel = in_channels // scale + hidden_channel = out_channels // scale + + self.blocks = nn.LayerList( + [TDNNBlock(in_channel, hidden_channel, kernel_size=3, dilation=dilation) for i in range(scale - 1)]) + self.scale = scale + + def forward(self, x): + y = [] + for i, x_i in enumerate(paddle.chunk(x, self.scale, axis=1)): + if i == 0: + y_i = x_i + elif i == 1: + y_i = self.blocks[i - 1](x_i) + else: + y_i = self.blocks[i - 1](x_i + y_i) + y.append(y_i) + y = paddle.concat(y, axis=1) + return y + + +class SEBlock(nn.Layer): + def __init__(self, in_channels, se_channels, out_channels): + super(SEBlock, self).__init__() + + self.conv1 = Conv1d(in_channels=in_channels, out_channels=se_channels, kernel_size=1) + self.relu = paddle.nn.ReLU() + self.conv2 = Conv1d(in_channels=se_channels, out_channels=out_channels, kernel_size=1) + self.sigmoid = paddle.nn.Sigmoid() + + def forward(self, x, lengths=None): + L = x.shape[-1] + if lengths is not None: + mask = length_to_mask(lengths * L, max_len=L) + mask = mask.unsqueeze(1) + total = mask.sum(axis=2, keepdim=True) + s = (x * mask).sum(axis=2, keepdim=True) / total + else: + s = x.mean(axis=2, keepdim=True) + + s = self.relu(self.conv1(s)) + s = self.sigmoid(self.conv2(s)) + + return s * x + + +class AttentiveStatisticsPooling(nn.Layer): + def __init__(self, channels, attention_channels=128, global_context=True): + super().__init__() + + self.eps = 1e-12 + self.global_context = global_context + if global_context: + self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1) + else: + self.tdnn = TDNNBlock(channels, attention_channels, 1, 1) + self.tanh = nn.Tanh() + self.conv = Conv1d(in_channels=attention_channels, out_channels=channels, kernel_size=1) + + def forward(self, x, lengths=None): + C, L = x.shape[1], x.shape[2] # KP: (N, C, L) + + def _compute_statistics(x, m, axis=2, eps=self.eps): + mean = (m * x).sum(axis) + std = paddle.sqrt((m * (x - mean.unsqueeze(axis)).pow(2)).sum(axis).clip(eps)) + return mean, std + + if lengths is None: + lengths = paddle.ones([x.shape[0]]) + + # Make binary mask of shape [N, 1, L] + mask = length_to_mask(lengths * L, max_len=L) + mask = mask.unsqueeze(1) + + # Expand the temporal context of the pooling layer by allowing the + # self-attention to look at global properties of the utterance. 
+ if self.global_context: + total = mask.sum(axis=2, keepdim=True).astype('float32') + mean, std = _compute_statistics(x, mask / total) + mean = mean.unsqueeze(2).tile((1, 1, L)) + std = std.unsqueeze(2).tile((1, 1, L)) + attn = paddle.concat([x, mean, std], axis=1) + else: + attn = x + + # Apply layers + attn = self.conv(self.tanh(self.tdnn(attn))) + + # Filter out zero-paddings + attn = paddle.where(mask.tile((1, C, 1)) == 0, paddle.ones_like(attn) * float("-inf"), attn) + + attn = F.softmax(attn, axis=2) + mean, std = _compute_statistics(x, attn) + + # Append mean and std of the batch + pooled_stats = paddle.concat((mean, std), axis=1) + pooled_stats = pooled_stats.unsqueeze(2) + + return pooled_stats + + +class SERes2NetBlock(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + res2net_scale=8, + se_channels=128, + kernel_size=1, + dilation=1, + activation=nn.ReLU, + ): + super(SERes2NetBlock, self).__init__() + self.out_channels = out_channels + self.tdnn1 = TDNNBlock( + in_channels, + out_channels, + kernel_size=1, + dilation=1, + activation=activation, + ) + self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, dilation) + self.tdnn2 = TDNNBlock( + out_channels, + out_channels, + kernel_size=1, + dilation=1, + activation=activation, + ) + self.se_block = SEBlock(out_channels, se_channels, out_channels) + + self.shortcut = None + if in_channels != out_channels: + self.shortcut = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + ) + + def forward(self, x, lengths=None): + residual = x + if self.shortcut: + residual = self.shortcut(x) + + x = self.tdnn1(x) + x = self.res2net_block(x) + x = self.tdnn2(x) + x = self.se_block(x, lengths) + + return x + residual + + +class ECAPA_TDNN(nn.Layer): + def __init__( + self, + input_size, + lin_neurons=192, + activation=nn.ReLU, + channels=[512, 512, 512, 512, 1536], + kernel_sizes=[5, 3, 3, 3, 1], + dilations=[1, 2, 3, 4, 1], + attention_channels=128, + res2net_scale=8, + se_channels=128, + global_context=True, + ): + + super(ECAPA_TDNN, self).__init__() + assert len(channels) == len(kernel_sizes) + assert len(channels) == len(dilations) + self.channels = channels + self.blocks = nn.LayerList() + self.emb_size = lin_neurons + + # The initial TDNN layer + self.blocks.append(TDNNBlock( + input_size, + channels[0], + kernel_sizes[0], + dilations[0], + activation, + )) + + # SE-Res2Net layers + for i in range(1, len(channels) - 1): + self.blocks.append( + SERes2NetBlock( + channels[i - 1], + channels[i], + res2net_scale=res2net_scale, + se_channels=se_channels, + kernel_size=kernel_sizes[i], + dilation=dilations[i], + activation=activation, + )) + + # Multi-layer feature aggregation + self.mfa = TDNNBlock( + channels[-1], + channels[-1], + kernel_sizes[-1], + dilations[-1], + activation, + ) + + # Attentive Statistical Pooling + self.asp = AttentiveStatisticsPooling( + channels[-1], + attention_channels=attention_channels, + global_context=global_context, + ) + self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2) + + # Final linear transformation + self.fc = Conv1d( + in_channels=channels[-1] * 2, + out_channels=self.emb_size, + kernel_size=1, + ) + + def forward(self, x, lengths=None): + xl = [] + for layer in self.blocks: + try: + x = layer(x, lengths=lengths) + except TypeError: + x = layer(x) + xl.append(x) + + # Multi-layer feature aggregation + x = paddle.concat(xl[1:], axis=1) + x = self.mfa(x) + + # Attentive Statistical Pooling + x = self.asp(x, lengths=lengths) + x = 
self.asp_bn(x) + + # Final linear transformation + x = self.fc(x) + + return x + + +class Classifier(nn.Layer): + def __init__(self, backbone, num_class, dtype=paddle.float32): + super(Classifier, self).__init__() + self.backbone = backbone + self.params = nn.ParameterList( + [paddle.create_parameter(shape=[num_class, self.backbone.emb_size], dtype=dtype)]) + + def forward(self, x): + emb = self.backbone(x.transpose([0, 2, 1])).transpose([0, 2, 1]) + logits = F.linear(F.normalize(emb.squeeze(1)), F.normalize(self.params[0]).transpose([1, 0])) + + return logits diff --git a/modules/audio/language_identification/ecapa_tdnn_common_language/feature.py b/modules/audio/language_identification/ecapa_tdnn_common_language/feature.py new file mode 100644 index 0000000000000000000000000000000000000000..09b930ebfd4cd56c9be1bc107f4ca6fc5f948027 --- /dev/null +++ b/modules/audio/language_identification/ecapa_tdnn_common_language/feature.py @@ -0,0 +1,112 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle +import paddleaudio +from paddleaudio.features.spectrum import hz_to_mel +from paddleaudio.features.spectrum import mel_to_hz +from paddleaudio.features.spectrum import power_to_db +from paddleaudio.features.spectrum import Spectrogram +from paddleaudio.features.window import get_window + + +def compute_fbank_matrix(sample_rate: int = 16000, + n_fft: int = 400, + n_mels: int = 80, + f_min: int = 0.0, + f_max: int = 8000.0): + mel = paddle.linspace(hz_to_mel(f_min, htk=True), hz_to_mel(f_max, htk=True), n_mels + 2, dtype=paddle.float32) + hz = mel_to_hz(mel, htk=True) + + band = hz[1:] - hz[:-1] + band = band[:-1] + f_central = hz[1:-1] + + n_stft = n_fft // 2 + 1 + all_freqs = paddle.linspace(0, sample_rate // 2, n_stft) + all_freqs_mat = all_freqs.tile([f_central.shape[0], 1]) + + f_central_mat = f_central.tile([all_freqs_mat.shape[1], 1]).transpose([1, 0]) + band_mat = band.tile([all_freqs_mat.shape[1], 1]).transpose([1, 0]) + + slope = (all_freqs_mat - f_central_mat) / band_mat + left_side = slope + 1.0 + right_side = -slope + 1.0 + + fbank_matrix = paddle.maximum(paddle.zeros_like(left_side), paddle.minimum(left_side, right_side)) + + return fbank_matrix + + +def compute_log_fbank( + x: paddle.Tensor, + sample_rate: int = 16000, + n_fft: int = 400, + hop_length: int = 160, + win_length: int = 400, + n_mels: int = 80, + window: str = 'hamming', + center: bool = True, + pad_mode: str = 'constant', + f_min: float = 0.0, + f_max: float = None, + top_db: float = 80.0, +): + + if f_max is None: + f_max = sample_rate / 2 + + spect = Spectrogram( + n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, center=center, pad_mode=pad_mode)(x) + + fbank_matrix = compute_fbank_matrix( + sample_rate=sample_rate, + n_fft=n_fft, + n_mels=n_mels, + f_min=f_min, + f_max=f_max, + ) + fbank = paddle.matmul(fbank_matrix, spect) + log_fbank = power_to_db(fbank, top_db=top_db).transpose([0, 2, 1]) + return log_fbank + + +def 
compute_stats(x: paddle.Tensor, mean_norm: bool = True, std_norm: bool = False, eps: float = 1e-10): + if mean_norm: + current_mean = paddle.mean(x, axis=0) + else: + current_mean = paddle.to_tensor([0.0]) + + if std_norm: + current_std = paddle.std(x, axis=0) + else: + current_std = paddle.to_tensor([1.0]) + + current_std = paddle.maximum(current_std, eps * paddle.ones_like(current_std)) + + return current_mean, current_std + + +def normalize( + x: paddle.Tensor, + global_mean: paddle.Tensor = None, + global_std: paddle.Tensor = None, +): + + for i in range(x.shape[0]): # (B, ...) + if global_mean is None and global_std is None: + mean, std = compute_stats(x[i]) + x[i] = (x[i] - mean) / std + else: + x[i] = (x[i] - global_mean) / global_std + return x diff --git a/modules/audio/language_identification/ecapa_tdnn_common_language/module.py b/modules/audio/language_identification/ecapa_tdnn_common_language/module.py new file mode 100644 index 0000000000000000000000000000000000000000..1950deaf1b5843c5f69269bb6982691739b0332e --- /dev/null +++ b/modules/audio/language_identification/ecapa_tdnn_common_language/module.py @@ -0,0 +1,85 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import re +from typing import List +from typing import Union + +import numpy as np +import paddle +import paddleaudio + +from .ecapa_tdnn import Classifier +from .ecapa_tdnn import ECAPA_TDNN +from .feature import compute_log_fbank +from .feature import normalize +from paddlehub.module.module import moduleinfo +from paddlehub.utils.log import logger + + +@moduleinfo( + name="ecapa_tdnn_common_language", + version="1.0.0", + summary="", + author="paddlepaddle", + author_email="", + type="audio/language_identification") +class LanguageIdentification(paddle.nn.Layer): + def __init__(self): + super(LanguageIdentification, self).__init__() + ckpt_path = os.path.join(self.directory, 'assets', 'model.pdparams') + label_path = os.path.join(self.directory, 'assets', 'label.txt') + + self.label_list = [] + with open(label_path, 'r') as f: + for l in f: + self.label_list.append(l.strip()) + + self.sr = 16000 + model_conf = { + 'input_size': 80, + 'channels': [1024, 1024, 1024, 1024, 3072], + 'kernel_sizes': [5, 3, 3, 3, 1], + 'dilations': [1, 2, 3, 4, 1], + 'attention_channels': 128, + 'lin_neurons': 192 + } + self.model = Classifier( + backbone=ECAPA_TDNN(**model_conf), + num_class=45, + ) + self.model.set_state_dict(paddle.load(ckpt_path)) + self.model.eval() + + def load_audio(self, wav): + wav = os.path.abspath(os.path.expanduser(wav)) + assert os.path.isfile(wav), 'Please check wav file: {}'.format(wav) + waveform, _ = paddleaudio.load(wav, sr=self.sr, mono=True, normal=False) + return waveform + + def language_identify(self, wav): + waveform = self.load_audio(wav) + logits = self(paddle.to_tensor(waveform)).reshape([-1]) + idx = paddle.argmax(logits) + return logits[idx].numpy(), self.label_list[idx] + + def forward(self, x): + if len(x.shape) == 
1:
+            x = x.unsqueeze(0)
+
+        fbank = compute_log_fbank(x)  # x: waveform tensors with (B, T) shape
+        norm_fbank = normalize(fbank)
+        logits = self.model(norm_fbank).squeeze(1)
+
+        return logits
diff --git a/modules/audio/language_identification/ecapa_tdnn_common_language/requirements.txt b/modules/audio/language_identification/ecapa_tdnn_common_language/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..defe617fa36bc5ab7b72438034c785ee2b3ac3c9
--- /dev/null
+++ b/modules/audio/language_identification/ecapa_tdnn_common_language/requirements.txt
@@ -0,0 +1 @@
+paddleaudio==0.1.0
diff --git a/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/README.md b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..70da7371cc411e535a4b53fd74a46c9a2521a016
--- /dev/null
+++ b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/README.md
@@ -0,0 +1,128 @@
+# ecapa_tdnn_voxceleb
+
+|Model name|ecapa_tdnn_voxceleb|
+| :--- | :---: |
+|Category|Audio/Speaker Recognition|
+|Network|ECAPA-TDNN|
+|Dataset|VoxCeleb|
+|Fine-tuning supported|No|
+|Model size|79 MB|
+|Last updated|2021-12-30|
+|Metric|EER 0.69%|
+
+## I. Basic Information
+
+### Introduction
+
+ecapa_tdnn_voxceleb uses the [ECAPA-TDNN](https://arxiv.org/abs/2005.07143) architecture and is pre-trained on the [VoxCeleb](http://www.robots.ox.ac.uk/~vgg/data/voxceleb/) dataset. On the VoxCeleb1 speaker verification test list ([veri_test.txt](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test.txt)) it achieves EER 0.69%, a state-of-the-art result on this dataset.
+
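+Note: verification scores are cosine similarities between 192-dimensional
+ECAPA-TDNN embeddings, so they fall in [-1, 1]; `threshold` (default 0.25) sets
+the decision boundary. A sketch of verifying one enrollment clip against
+several probe clips (the file names here are hypothetical):
+
+```python
+import numpy as np
+import paddlehub as hub
+
+model = hub.Module(name='ecapa_tdnn_voxceleb', threshold=0.25, version='1.0.0')
+
+enroll = model.speaker_embedding('enroll.wav')  # (192,) numpy vector
+for probe in ['probe_a.wav', 'probe_b.wav']:
+    emb = model.speaker_embedding(probe)
+    # cosine similarity between the two embeddings
+    score = float(np.dot(enroll, emb) / (np.linalg.norm(enroll) * np.linalg.norm(emb) + 1e-6))
+    print(probe, score, score > model.threshold)
+```
+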
+For more details, please refer to:
+- [VoxCeleb: a large-scale speaker identification dataset](https://www.robots.ox.ac.uk/~vgg/publications/2017/Nagrani17/nagrani17.pdf)
+- [ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification](https://arxiv.org/pdf/2005.07143.pdf)
+- [The SpeechBrain Toolkit](https://github.com/speechbrain/speechbrain)
+
+
+## II. Installation
+
+- ### 1. Environment Dependencies
+
+  - paddlepaddle >= 2.2.0
+
+  - paddlehub >= 2.2.0 | [How to install PaddleHub](../../../../docs/docs_ch/get_start/installation.rst)
+
+- ### 2. Installation
+
+  - ```shell
+    $ hub install ecapa_tdnn_voxceleb
+    ```
+  - If you encounter problems during installation, please refer to: [Windows installation guide](../../../../docs/docs_ch/get_start/windows_quickstart.md)
+    | [Linux installation guide](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [macOS installation guide](../../../../docs/docs_ch/get_start/mac_quickstart.md)
+
+
+## III. Module API Prediction
+
+- ### 1. Prediction Code Example
+
+    ```python
+    import paddlehub as hub
+
+    model = hub.Module(
+        name='ecapa_tdnn_voxceleb',
+        threshold=0.25,
+        version='1.0.0')
+
+    # Sample audio clips can be downloaded from the links below
+    # https://paddlehub.bj.bcebos.com/paddlehub_dev/sv1.wav
+    # https://paddlehub.bj.bcebos.com/paddlehub_dev/sv2.wav
+
+    # Speaker Embedding
+    embedding = model.speaker_embedding('sv1.wav')
+    print(embedding.shape)
+    # (192,)
+
+    # Speaker Verification
+    score, pred = model.speaker_verify('sv1.wav', 'sv2.wav')
+    print(score, pred)
+    # [0.16354457], [False]
+    ```
+
+- ### 2. API
+  - ```python
+    def __init__(
+        threshold: float,
+    )
+    ```
+    - Initialize the speaker verification model and set the decision threshold.
+
+    - **Parameters**
+
+      - `threshold`: score threshold above which two voices are judged to come from the same speaker; defaults to 0.25.
+
+  - ```python
+    def speaker_embedding(
+        wav: os.PathLike,
+    )
+    ```
+    - Extract the speaker embedding of the input audio.
+
+    - **Parameters**
+
+      - `wav`: input speaker audio file, in `*.wav` format.
+
+    - **Returns**
+
+      - A speaker embedding vector of dimension (192,).
+
+  - ```python
+    def speaker_verify(
+        wav1: os.PathLike,
+        wav2: os.PathLike,
+    )
+    ```
+    - Compare two audio clips by the similarity score of their speaker embeddings and decide whether they belong to the same speaker.
+
+    - **Parameters**
+
+      - `wav1`: audio file of speaker 1, in `*.wav` format.
+      - `wav2`: audio file of speaker 2, in `*.wav` format.
+
+    - **Returns**
+
+      - The speaker-similarity score in [-1, 1] and the boolean prediction.
+
+
+## IV. Release Note
+
+* 1.0.0
+
+  First release
+
+  ```shell
+  $ hub install ecapa_tdnn_voxceleb
+  ```
diff --git a/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/__init__.py b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..185a92b8d94d3426d616c0624f0f2ee04339349e
--- /dev/null
+++ b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/ecapa_tdnn.py b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/ecapa_tdnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..59950860985414aaca3a46657cd11cd9645c223c
--- /dev/null
+++ b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/ecapa_tdnn.py
@@ -0,0 +1,392 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import os + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +def length_to_mask(length, max_len=None, dtype=None): + assert len(length.shape) == 1 + + if max_len is None: + max_len = length.max().astype('int').item() # using arange to generate mask + mask = paddle.arange(max_len, dtype=length.dtype).expand((len(length), max_len)) < length.unsqueeze(1) + + if dtype is None: + dtype = length.dtype + + mask = paddle.to_tensor(mask, dtype=dtype) + return mask + + +class Conv1d(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding="same", + dilation=1, + groups=1, + bias=True, + padding_mode="reflect", + ): + super(Conv1d, self).__init__() + + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.padding = padding + self.padding_mode = padding_mode + + self.conv = nn.Conv1D( + in_channels, + out_channels, + self.kernel_size, + stride=self.stride, + padding=0, + dilation=self.dilation, + groups=groups, + bias_attr=bias, + ) + + def forward(self, x): + if self.padding == "same": + x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride) + else: + raise ValueError("Padding must be 'same'. 
Got {self.padding}") + + return self.conv(x) + + def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int): + L_in = x.shape[-1] # Detecting input shape + padding = self._get_padding_elem(L_in, stride, kernel_size, dilation) # Time padding + x = F.pad(x, padding, mode=self.padding_mode, data_format="NCL") # Applying padding + return x + + def _get_padding_elem(self, L_in: int, stride: int, kernel_size: int, dilation: int): + if stride > 1: + n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1) + L_out = stride * (n_steps - 1) + kernel_size * dilation + padding = [kernel_size // 2, kernel_size // 2] + else: + L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1 + + padding = [(L_in - L_out) // 2, (L_in - L_out) // 2] + + return padding + + +class BatchNorm1d(nn.Layer): + def __init__( + self, + input_size, + eps=1e-05, + momentum=0.9, + weight_attr=None, + bias_attr=None, + data_format='NCL', + use_global_stats=None, + ): + super(BatchNorm1d, self).__init__() + + self.norm = nn.BatchNorm1D( + input_size, + epsilon=eps, + momentum=momentum, + weight_attr=weight_attr, + bias_attr=bias_attr, + data_format=data_format, + use_global_stats=use_global_stats, + ) + + def forward(self, x): + x_n = self.norm(x) + return x_n + + +class TDNNBlock(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation, + activation=nn.ReLU, + ): + super(TDNNBlock, self).__init__() + self.conv = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + dilation=dilation, + ) + self.activation = activation() + self.norm = BatchNorm1d(input_size=out_channels) + + def forward(self, x): + return self.norm(self.activation(self.conv(x))) + + +class Res2NetBlock(nn.Layer): + def __init__(self, in_channels, out_channels, scale=8, dilation=1): + super(Res2NetBlock, self).__init__() + assert in_channels % scale == 0 + assert out_channels % scale == 0 + + in_channel = in_channels // scale + hidden_channel = out_channels // scale + + self.blocks = nn.LayerList( + [TDNNBlock(in_channel, hidden_channel, kernel_size=3, dilation=dilation) for i in range(scale - 1)]) + self.scale = scale + + def forward(self, x): + y = [] + for i, x_i in enumerate(paddle.chunk(x, self.scale, axis=1)): + if i == 0: + y_i = x_i + elif i == 1: + y_i = self.blocks[i - 1](x_i) + else: + y_i = self.blocks[i - 1](x_i + y_i) + y.append(y_i) + y = paddle.concat(y, axis=1) + return y + + +class SEBlock(nn.Layer): + def __init__(self, in_channels, se_channels, out_channels): + super(SEBlock, self).__init__() + + self.conv1 = Conv1d(in_channels=in_channels, out_channels=se_channels, kernel_size=1) + self.relu = paddle.nn.ReLU() + self.conv2 = Conv1d(in_channels=se_channels, out_channels=out_channels, kernel_size=1) + self.sigmoid = paddle.nn.Sigmoid() + + def forward(self, x, lengths=None): + L = x.shape[-1] + if lengths is not None: + mask = length_to_mask(lengths * L, max_len=L) + mask = mask.unsqueeze(1) + total = mask.sum(axis=2, keepdim=True) + s = (x * mask).sum(axis=2, keepdim=True) / total + else: + s = x.mean(axis=2, keepdim=True) + + s = self.relu(self.conv1(s)) + s = self.sigmoid(self.conv2(s)) + + return s * x + + +class AttentiveStatisticsPooling(nn.Layer): + def __init__(self, channels, attention_channels=128, global_context=True): + super().__init__() + + self.eps = 1e-12 + self.global_context = global_context + if global_context: + self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1) + else: + self.tdnn = TDNNBlock(channels, 
attention_channels, 1, 1) + self.tanh = nn.Tanh() + self.conv = Conv1d(in_channels=attention_channels, out_channels=channels, kernel_size=1) + + def forward(self, x, lengths=None): + C, L = x.shape[1], x.shape[2] # KP: (N, C, L) + + def _compute_statistics(x, m, axis=2, eps=self.eps): + mean = (m * x).sum(axis) + std = paddle.sqrt((m * (x - mean.unsqueeze(axis)).pow(2)).sum(axis).clip(eps)) + return mean, std + + if lengths is None: + lengths = paddle.ones([x.shape[0]]) + + # Make binary mask of shape [N, 1, L] + mask = length_to_mask(lengths * L, max_len=L) + mask = mask.unsqueeze(1) + + # Expand the temporal context of the pooling layer by allowing the + # self-attention to look at global properties of the utterance. + if self.global_context: + total = mask.sum(axis=2, keepdim=True).astype('float32') + mean, std = _compute_statistics(x, mask / total) + mean = mean.unsqueeze(2).tile((1, 1, L)) + std = std.unsqueeze(2).tile((1, 1, L)) + attn = paddle.concat([x, mean, std], axis=1) + else: + attn = x + + # Apply layers + attn = self.conv(self.tanh(self.tdnn(attn))) + + # Filter out zero-paddings + attn = paddle.where(mask.tile((1, C, 1)) == 0, paddle.ones_like(attn) * float("-inf"), attn) + + attn = F.softmax(attn, axis=2) + mean, std = _compute_statistics(x, attn) + + # Append mean and std of the batch + pooled_stats = paddle.concat((mean, std), axis=1) + pooled_stats = pooled_stats.unsqueeze(2) + + return pooled_stats + + +class SERes2NetBlock(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + res2net_scale=8, + se_channels=128, + kernel_size=1, + dilation=1, + activation=nn.ReLU, + ): + super(SERes2NetBlock, self).__init__() + self.out_channels = out_channels + self.tdnn1 = TDNNBlock( + in_channels, + out_channels, + kernel_size=1, + dilation=1, + activation=activation, + ) + self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, dilation) + self.tdnn2 = TDNNBlock( + out_channels, + out_channels, + kernel_size=1, + dilation=1, + activation=activation, + ) + self.se_block = SEBlock(out_channels, se_channels, out_channels) + + self.shortcut = None + if in_channels != out_channels: + self.shortcut = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + ) + + def forward(self, x, lengths=None): + residual = x + if self.shortcut: + residual = self.shortcut(x) + + x = self.tdnn1(x) + x = self.res2net_block(x) + x = self.tdnn2(x) + x = self.se_block(x, lengths) + + return x + residual + + +class ECAPA_TDNN(nn.Layer): + def __init__( + self, + input_size, + lin_neurons=192, + activation=nn.ReLU, + channels=[512, 512, 512, 512, 1536], + kernel_sizes=[5, 3, 3, 3, 1], + dilations=[1, 2, 3, 4, 1], + attention_channels=128, + res2net_scale=8, + se_channels=128, + global_context=True, + ): + + super(ECAPA_TDNN, self).__init__() + assert len(channels) == len(kernel_sizes) + assert len(channels) == len(dilations) + self.channels = channels + self.blocks = nn.LayerList() + self.emb_size = lin_neurons + + # The initial TDNN layer + self.blocks.append(TDNNBlock( + input_size, + channels[0], + kernel_sizes[0], + dilations[0], + activation, + )) + + # SE-Res2Net layers + for i in range(1, len(channels) - 1): + self.blocks.append( + SERes2NetBlock( + channels[i - 1], + channels[i], + res2net_scale=res2net_scale, + se_channels=se_channels, + kernel_size=kernel_sizes[i], + dilation=dilations[i], + activation=activation, + )) + + # Multi-layer feature aggregation + self.mfa = TDNNBlock( + channels[-1], + channels[-1], + kernel_sizes[-1], + 
dilations[-1], + activation, + ) + + # Attentive Statistical Pooling + self.asp = AttentiveStatisticsPooling( + channels[-1], + attention_channels=attention_channels, + global_context=global_context, + ) + self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2) + + # Final linear transformation + self.fc = Conv1d( + in_channels=channels[-1] * 2, + out_channels=self.emb_size, + kernel_size=1, + ) + + def forward(self, x, lengths=None): + xl = [] + for layer in self.blocks: + try: + x = layer(x, lengths=lengths) + except TypeError: + x = layer(x) + xl.append(x) + + # Multi-layer feature aggregation + x = paddle.concat(xl[1:], axis=1) + x = self.mfa(x) + + # Attentive Statistical Pooling + x = self.asp(x, lengths=lengths) + x = self.asp_bn(x) + + # Final linear transformation + x = self.fc(x) + + return x diff --git a/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/feature.py b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/feature.py new file mode 100644 index 0000000000000000000000000000000000000000..09b930ebfd4cd56c9be1bc107f4ca6fc5f948027 --- /dev/null +++ b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/feature.py @@ -0,0 +1,112 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
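+# Usage sketch for the helpers below (a minimal example, assuming a (B, T)
+# batch of 16 kHz waveforms):
+#
+#     import paddle
+#     x = paddle.randn([1, 16000])
+#     feats = compute_log_fbank(x)   # -> [1, frames, 80] log-mel features
+#     feats = normalize(feats)       # per-utterance mean normalization
+#
+# normalize() subtracts per-sample statistics by default (std_norm is False in
+# compute_stats); module.py also calls it with global_mean/global_std to apply
+# the corpus-level embedding normalization loaded from assets.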
+import paddle +import paddleaudio +from paddleaudio.features.spectrum import hz_to_mel +from paddleaudio.features.spectrum import mel_to_hz +from paddleaudio.features.spectrum import power_to_db +from paddleaudio.features.spectrum import Spectrogram +from paddleaudio.features.window import get_window + + +def compute_fbank_matrix(sample_rate: int = 16000, + n_fft: int = 400, + n_mels: int = 80, + f_min: int = 0.0, + f_max: int = 8000.0): + mel = paddle.linspace(hz_to_mel(f_min, htk=True), hz_to_mel(f_max, htk=True), n_mels + 2, dtype=paddle.float32) + hz = mel_to_hz(mel, htk=True) + + band = hz[1:] - hz[:-1] + band = band[:-1] + f_central = hz[1:-1] + + n_stft = n_fft // 2 + 1 + all_freqs = paddle.linspace(0, sample_rate // 2, n_stft) + all_freqs_mat = all_freqs.tile([f_central.shape[0], 1]) + + f_central_mat = f_central.tile([all_freqs_mat.shape[1], 1]).transpose([1, 0]) + band_mat = band.tile([all_freqs_mat.shape[1], 1]).transpose([1, 0]) + + slope = (all_freqs_mat - f_central_mat) / band_mat + left_side = slope + 1.0 + right_side = -slope + 1.0 + + fbank_matrix = paddle.maximum(paddle.zeros_like(left_side), paddle.minimum(left_side, right_side)) + + return fbank_matrix + + +def compute_log_fbank( + x: paddle.Tensor, + sample_rate: int = 16000, + n_fft: int = 400, + hop_length: int = 160, + win_length: int = 400, + n_mels: int = 80, + window: str = 'hamming', + center: bool = True, + pad_mode: str = 'constant', + f_min: float = 0.0, + f_max: float = None, + top_db: float = 80.0, +): + + if f_max is None: + f_max = sample_rate / 2 + + spect = Spectrogram( + n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, center=center, pad_mode=pad_mode)(x) + + fbank_matrix = compute_fbank_matrix( + sample_rate=sample_rate, + n_fft=n_fft, + n_mels=n_mels, + f_min=f_min, + f_max=f_max, + ) + fbank = paddle.matmul(fbank_matrix, spect) + log_fbank = power_to_db(fbank, top_db=top_db).transpose([0, 2, 1]) + return log_fbank + + +def compute_stats(x: paddle.Tensor, mean_norm: bool = True, std_norm: bool = False, eps: float = 1e-10): + if mean_norm: + current_mean = paddle.mean(x, axis=0) + else: + current_mean = paddle.to_tensor([0.0]) + + if std_norm: + current_std = paddle.std(x, axis=0) + else: + current_std = paddle.to_tensor([1.0]) + + current_std = paddle.maximum(current_std, eps * paddle.ones_like(current_std)) + + return current_mean, current_std + + +def normalize( + x: paddle.Tensor, + global_mean: paddle.Tensor = None, + global_std: paddle.Tensor = None, +): + + for i in range(x.shape[0]): # (B, ...) + if global_mean is None and global_std is None: + mean, std = compute_stats(x[i]) + x[i] = (x[i] - mean) / std + else: + x[i] = (x[i] - global_mean) / global_std + return x diff --git a/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/module.py b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/module.py new file mode 100644 index 0000000000000000000000000000000000000000..11f7121a5f0a7eb2b330ffeedec821171bb30bef --- /dev/null +++ b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/module.py @@ -0,0 +1,93 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import re +from typing import List +from typing import Union + +import numpy as np +import paddle +import paddleaudio + +from .ecapa_tdnn import ECAPA_TDNN +from .feature import compute_log_fbank +from .feature import normalize +from paddlehub.module.module import moduleinfo +from paddlehub.utils.log import logger + + +@moduleinfo( + name="ecapa_tdnn_voxceleb", + version="1.0.0", + summary="", + author="paddlepaddle", + author_email="", + type="audio/speaker_recognition") +class SpeakerRecognition(paddle.nn.Layer): + def __init__(self, threshold=0.25): + super(SpeakerRecognition, self).__init__() + global_stats_path = os.path.join(self.directory, 'assets', 'global_embedding_stats.npy') + ckpt_path = os.path.join(self.directory, 'assets', 'model.pdparams') + + self.sr = 16000 + self.threshold = threshold + model_conf = { + 'input_size': 80, + 'channels': [1024, 1024, 1024, 1024, 3072], + 'kernel_sizes': [5, 3, 3, 3, 1], + 'dilations': [1, 2, 3, 4, 1], + 'attention_channels': 128, + 'lin_neurons': 192 + } + self.model = ECAPA_TDNN(**model_conf) + self.model.set_state_dict(paddle.load(ckpt_path)) + self.model.eval() + + global_embedding_stats = np.load(global_stats_path, allow_pickle=True) + self.global_emb_mean = paddle.to_tensor(global_embedding_stats.item().get('global_emb_mean')) + self.global_emb_std = paddle.to_tensor(global_embedding_stats.item().get('global_emb_std')) + + self.similarity = paddle.nn.CosineSimilarity(axis=-1, eps=1e-6) + + def load_audio(self, wav): + wav = os.path.abspath(os.path.expanduser(wav)) + assert os.path.isfile(wav), 'Please check wav file: {}'.format(wav) + waveform, _ = paddleaudio.load(wav, sr=self.sr, mono=True, normal=False) + return waveform + + def speaker_embedding(self, wav): + waveform = self.load_audio(wav) + embedding = self(paddle.to_tensor(waveform)).reshape([-1]) + return embedding.numpy() + + def speaker_verify(self, wav1, wav2): + waveform1 = self.load_audio(wav1) + embedding1 = self(paddle.to_tensor(waveform1)).reshape([-1]) + + waveform2 = self.load_audio(wav2) + embedding2 = self(paddle.to_tensor(waveform2)).reshape([-1]) + + score = self.similarity(embedding1, embedding2).numpy() + return score, score > self.threshold + + def forward(self, x): + if len(x.shape) == 1: + x = x.unsqueeze(0) + + fbank = compute_log_fbank(x) # x: waveform tensors with (B, T) shape + norm_fbank = normalize(fbank) + embedding = self.model(norm_fbank.transpose([0, 2, 1])).transpose([0, 2, 1]) + norm_embedding = normalize(x=embedding, global_mean=self.global_emb_mean, global_std=self.global_emb_std) + + return norm_embedding diff --git a/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/requirements.txt b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..defe617fa36bc5ab7b72438034c785ee2b3ac3c9 --- /dev/null +++ b/modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/requirements.txt @@ -0,0 +1 @@ +paddleaudio==0.1.0
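A quick end-to-end check of the three modules added in this diff (a sketch; it assumes the sample clips referenced in the READMEs above have been downloaded into the working directory):

```python
import paddlehub as hub

kws = hub.Module(name='kwmlp_speech_commands', version='1.0.0')
lid = hub.Module(name='ecapa_tdnn_common_language', version='1.0.0')
sv = hub.Module(name='ecapa_tdnn_voxceleb', threshold=0.25, version='1.0.0')

print(kws.keyword_recognize('go.wav'))          # e.g. [0.8997176] go
print(lid.language_identify('zh.wav'))          # e.g. [0.6214552] Chinese_China
print(sv.speaker_verify('sv1.wav', 'sv2.wav'))  # e.g. [0.16354457] [False]
```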