PaddlePaddle / PaddleHub

Commit 9ba49968
Authored on Dec 30, 2021 by KP
Add ecapa_tdnn_voxceleb.
Parent: 96cb5498
Showing 6 changed files with 702 additions and 0 deletions (+702 -0).
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/README.md            +117  -0
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/__init__.py            +0  -0
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/ecapa_tdnn.py        +392  -0
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/feature.py            +99  -0
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/module.py             +93  -0
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/requirements.txt       +1  -0
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/README.md
0 → 100644
# ecapa_tdnn_voxceleb

|Model Name|ecapa_tdnn_voxceleb|
| :--- | :---: |
|Category|Speech - Speaker Recognition|
|Network|ECAPA-TDNN|
|Dataset|VoxCeleb|
|Fine-tuning supported|No|
|Model Size|79MB|
|Latest update date|2021-12-30|
|Data metric|EER 0.69%|

## I. Basic Information

### Module Introduction

ecapa_tdnn_voxceleb adopts the [ECAPA-TDNN](https://arxiv.org/abs/2005.07143) architecture and is pre-trained on the [VoxCeleb](http://www.robots.ox.ac.uk/~vgg/data/voxceleb/) dataset. On the VoxCeleb1 speaker verification test set ([veri_test.txt](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test.txt)) it achieves an EER of 0.69%, the state of the art on this dataset.

<p align="center">
<img src="https://d3i71xaburhd42.cloudfront.net/9609f4817a7e769f5e3e07084db35e46696e82cd/3-Figure2-1.png" hspace='10' height="550"/> <br />
</p>

For more details, please refer to:

- [VoxCeleb: a large-scale speaker identification dataset](https://www.robots.ox.ac.uk/~vgg/publications/2017/Nagrani17/nagrani17.pdf)
- [ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification](https://arxiv.org/pdf/2005.07143.pdf)
- [The SpeechBrain Toolkit](https://github.com/speechbrain/speechbrain)
## II. Installation

- ### 1. Environment Dependencies

  - paddlepaddle >= 2.2.0

  - paddlehub >= 2.2.0 | [How to install PaddleHub](../../../../docs/docs_ch/get_start/installation.rst)

- ### 2. Installation

  - ```shell
    $ hub install ecapa_tdnn_voxceleb
    ```
  - If you run into problems during installation, please refer to: [Beginner's Windows installation guide](../../../../docs/docs_ch/get_start/windows_quickstart.md) | [Beginner's Linux installation guide](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [Beginner's MacOS installation guide](../../../../docs/docs_ch/get_start/mac_quickstart.md)
## III. Module API Prediction

- ### 1. Prediction Code Example

  ```python
  import paddlehub as hub

  model = hub.Module(
      name='ecapa_tdnn_voxceleb',
      threshold=0.25,
      version='1.0.0')

  # The sample audio files can be downloaded from the following links:
  # https://paddlehub.bj.bcebos.com/hub_dev/sv1.wav
  # https://paddlehub.bj.bcebos.com/hub_dev/sv2.wav

  # Speaker Embedding
  embedding = model.speaker_embedding('sv1.wav')
  print(embedding.shape)
  # (192,)

  # Speaker Verification
  score, pred = model.speaker_verify('sv1.wav', 'sv2.wav')
  print(score, pred)
  # [0.16354457], [False]
  ```

- ### 2. API

  - ```python
    def speaker_embedding(
        wav: os.PathLike,
    )
    ```
    - Extracts the speaker embedding of the input audio.

    - **Parameters**

      - `wav`: the input speaker audio file, in `*.wav` format.

    - **Return**

      - A speaker embedding vector of shape (192,).

  - ```python
    def speaker_verify(
        wav1: os.PathLike,
        wav2: os.PathLike,
    )
    ```
    - Compares two audio clips: computes the similarity score between their speaker embeddings and decides whether they belong to the same speaker.

    - **Parameters**

      - `wav1`: the audio file of speaker 1, in `*.wav` format.
      - `wav2`: the audio file of speaker 2, in `*.wav` format.

    - **Return**

      - The speaker similarity score in [-1, 1] and the verification decision (a sketch of this computation follows below).
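  - The verification decision is simply a threshold on the cosine similarity of the two speaker embeddings. Below is a minimal numpy sketch of what `speaker_verify` computes internally (the module itself uses `paddle.nn.CosineSimilarity`); `emb1` and `emb2` stand for vectors returned by `speaker_embedding`:

    ```python
    import numpy as np

    def cosine_score(emb1: np.ndarray, emb2: np.ndarray) -> float:
        # Cosine similarity of two (192,)-dim speaker embeddings, with a small
        # floor on the norm product to avoid division by zero.
        denom = max(float(np.linalg.norm(emb1) * np.linalg.norm(emb2)), 1e-6)
        return float(np.dot(emb1, emb2) / denom)

    score = cosine_score(emb1, emb2)
    pred = score > 0.25  # default `threshold` passed to hub.Module
    ```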
## IV. Release Note

* 1.0.0

  First release

  ```shell
  $ hub install ecapa_tdnn_voxceleb
  ```
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/__init__.py
0 → 100644
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/ecapa_tdnn.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


def length_to_mask(length, max_len=None, dtype=None):
    assert len(length.shape) == 1

    if max_len is None:
        max_len = length.max().astype('int').item()

    # using arange to generate mask
    mask = paddle.arange(max_len, dtype=length.dtype).expand((len(length), max_len)) < length.unsqueeze(1)

    if dtype is None:
        dtype = length.dtype

    mask = paddle.to_tensor(mask, dtype=dtype)
    return mask


class Conv1d(nn.Layer):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding="same",
            dilation=1,
            groups=1,
            bias=True,
            padding_mode="reflect",
    ):
        super(Conv1d, self).__init__()

        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode

        self.conv = nn.Conv1D(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=self.stride,
            padding=0,
            dilation=self.dilation,
            groups=groups,
            bias_attr=bias,
        )

    def forward(self, x):
        if self.padding == "same":
            x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
        else:
            raise ValueError("Padding must be 'same'. Got {self.padding}")

        return self.conv(x)

    def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
        L_in = x.shape[-1]  # Detecting input shape
        padding = self._get_padding_elem(L_in, stride, kernel_size, dilation)  # Time padding
        x = F.pad(x, padding, mode=self.padding_mode, data_format="NCL")  # Applying padding

        return x

    def _get_padding_elem(self, L_in: int, stride: int, kernel_size: int, dilation: int):
        if stride > 1:
            n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
            L_out = stride * (n_steps - 1) + kernel_size * dilation
            padding = [kernel_size // 2, kernel_size // 2]
        else:
            L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
            padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]

        return padding


class BatchNorm1d(nn.Layer):
    def __init__(
            self,
            input_size,
            eps=1e-05,
            momentum=0.9,
            weight_attr=None,
            bias_attr=None,
            data_format='NCL',
            use_global_stats=None,
    ):
        super(BatchNorm1d, self).__init__()

        self.norm = nn.BatchNorm1D(
            input_size,
            epsilon=eps,
            momentum=momentum,
            weight_attr=weight_attr,
            bias_attr=bias_attr,
            data_format=data_format,
            use_global_stats=use_global_stats,
        )

    def forward(self, x):
        x_n = self.norm(x)
        return x_n


class TDNNBlock(nn.Layer):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            dilation,
            activation=nn.ReLU,
    ):
        super(TDNNBlock, self).__init__()
        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
        )
        self.activation = activation()
        self.norm = BatchNorm1d(input_size=out_channels)

    def forward(self, x):
        return self.norm(self.activation(self.conv(x)))


class Res2NetBlock(nn.Layer):
    def __init__(self, in_channels, out_channels, scale=8, dilation=1):
        super(Res2NetBlock, self).__init__()
        assert in_channels % scale == 0
        assert out_channels % scale == 0

        in_channel = in_channels // scale
        hidden_channel = out_channels // scale

        self.blocks = nn.LayerList(
            [TDNNBlock(in_channel, hidden_channel, kernel_size=3, dilation=dilation) for i in range(scale - 1)])

        self.scale = scale

    def forward(self, x):
        y = []
        for i, x_i in enumerate(paddle.chunk(x, self.scale, axis=1)):
            if i == 0:
                y_i = x_i
            elif i == 1:
                y_i = self.blocks[i - 1](x_i)
            else:
                y_i = self.blocks[i - 1](x_i + y_i)
            y.append(y_i)

        y = paddle.concat(y, axis=1)
        return y


class SEBlock(nn.Layer):
    def __init__(self, in_channels, se_channels, out_channels):
        super(SEBlock, self).__init__()

        self.conv1 = Conv1d(in_channels=in_channels, out_channels=se_channels, kernel_size=1)
        self.relu = paddle.nn.ReLU()
        self.conv2 = Conv1d(in_channels=se_channels, out_channels=out_channels, kernel_size=1)
        self.sigmoid = paddle.nn.Sigmoid()

    def forward(self, x, lengths=None):
        L = x.shape[-1]
        if lengths is not None:
            mask = length_to_mask(lengths * L, max_len=L)
            mask = mask.unsqueeze(1)
            total = mask.sum(axis=2, keepdim=True)
            s = (x * mask).sum(axis=2, keepdim=True) / total
        else:
            s = x.mean(axis=2, keepdim=True)

        s = self.relu(self.conv1(s))
        s = self.sigmoid(self.conv2(s))

        return s * x


class AttentiveStatisticsPooling(nn.Layer):
    def __init__(self, channels, attention_channels=128, global_context=True):
        super().__init__()

        self.eps = 1e-12
        self.global_context = global_context
        if global_context:
            self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
        else:
            self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = Conv1d(in_channels=attention_channels, out_channels=channels, kernel_size=1)

    def forward(self, x, lengths=None):
        C, L = x.shape[1], x.shape[2]  # KP: (N, C, L)

        def _compute_statistics(x, m, axis=2, eps=self.eps):
            mean = (m * x).sum(axis)
            std = paddle.sqrt((m * (x - mean.unsqueeze(axis)).pow(2)).sum(axis).clip(eps))
            return mean, std

        if lengths is None:
            lengths = paddle.ones([x.shape[0]])

        # Make binary mask of shape [N, 1, L]
        mask = length_to_mask(lengths * L, max_len=L)
        mask = mask.unsqueeze(1)

        # Expand the temporal context of the pooling layer by allowing the
        # self-attention to look at global properties of the utterance.
        if self.global_context:
            total = mask.sum(axis=2, keepdim=True).astype('float32')
            mean, std = _compute_statistics(x, mask / total)
            mean = mean.unsqueeze(2).tile((1, 1, L))
            std = std.unsqueeze(2).tile((1, 1, L))
            attn = paddle.concat([x, mean, std], axis=1)
        else:
            attn = x

        # Apply layers
        attn = self.conv(self.tanh(self.tdnn(attn)))

        # Filter out zero-paddings
        attn = paddle.where(mask.tile((1, C, 1)) == 0, paddle.ones_like(attn) * float("-inf"), attn)

        attn = F.softmax(attn, axis=2)
        mean, std = _compute_statistics(x, attn)

        # Append mean and std of the batch
        pooled_stats = paddle.concat((mean, std), axis=1)
        pooled_stats = pooled_stats.unsqueeze(2)

        return pooled_stats


class SERes2NetBlock(nn.Layer):
    def __init__(
            self,
            in_channels,
            out_channels,
            res2net_scale=8,
            se_channels=128,
            kernel_size=1,
            dilation=1,
            activation=nn.ReLU,
    ):
        super(SERes2NetBlock, self).__init__()
        self.out_channels = out_channels
        self.tdnn1 = TDNNBlock(
            in_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
        )
        self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, dilation)
        self.tdnn2 = TDNNBlock(
            out_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
        )
        self.se_block = SEBlock(out_channels, se_channels, out_channels)

        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x, lengths=None):
        residual = x
        if self.shortcut:
            residual = self.shortcut(x)

        x = self.tdnn1(x)
        x = self.res2net_block(x)
        x = self.tdnn2(x)
        x = self.se_block(x, lengths)

        return x + residual


class ECAPA_TDNN(nn.Layer):
    def __init__(
            self,
            input_size,
            lin_neurons=192,
            activation=nn.ReLU,
            channels=[512, 512, 512, 512, 1536],
            kernel_sizes=[5, 3, 3, 3, 1],
            dilations=[1, 2, 3, 4, 1],
            attention_channels=128,
            res2net_scale=8,
            se_channels=128,
            global_context=True,
    ):
        super(ECAPA_TDNN, self).__init__()

        assert len(channels) == len(kernel_sizes)
        assert len(channels) == len(dilations)
        self.channels = channels
        self.blocks = nn.LayerList()
        self.emb_size = lin_neurons

        # The initial TDNN layer
        self.blocks.append(TDNNBlock(
            input_size,
            channels[0],
            kernel_sizes[0],
            dilations[0],
            activation,
        ))

        # SE-Res2Net layers
        for i in range(1, len(channels) - 1):
            self.blocks.append(
                SERes2NetBlock(
                    channels[i - 1],
                    channels[i],
                    res2net_scale=res2net_scale,
                    se_channels=se_channels,
                    kernel_size=kernel_sizes[i],
                    dilation=dilations[i],
                    activation=activation,
                ))

        # Multi-layer feature aggregation
        self.mfa = TDNNBlock(
            channels[-1],
            channels[-1],
            kernel_sizes[-1],
            dilations[-1],
            activation,
        )

        # Attentive Statistical Pooling
        self.asp = AttentiveStatisticsPooling(
            channels[-1],
            attention_channels=attention_channels,
            global_context=global_context,
        )
        self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)

        # Final linear transformation
        self.fc = Conv1d(
            in_channels=channels[-1] * 2,
            out_channels=self.emb_size,
            kernel_size=1,
        )

    def forward(self, x, lengths=None):
        xl = []
        for layer in self.blocks:
            try:
                x = layer(x, lengths=lengths)
            except TypeError:
                x = layer(x)
            xl.append(x)

        # Multi-layer feature aggregation
        x = paddle.concat(xl[1:], axis=1)
        x = self.mfa(x)

        # Attentive Statistical Pooling
        x = self.asp(x, lengths=lengths)
        x = self.asp_bn(x)

        # Final linear transformation
        x = self.fc(x)

        return x
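A minimal shape-check sketch for the `ECAPA_TDNN` defined above (an illustration, not part of the committed file). It runs random log-fbank-style input in NCL layout through the default configuration; note that the pretrained PaddleHub module in `module.py` uses a larger channel configuration.

```python
import paddle

# Default configuration: channels=[512, 512, 512, 512, 1536], lin_neurons=192.
model = ECAPA_TDNN(input_size=80)
model.eval()

x = paddle.randn([2, 80, 200])  # (batch, n_mels, frames), the NCL layout expected by Conv1d
with paddle.no_grad():
    emb = model(x)
print(emb.shape)  # expected [2, 192, 1]: one 192-dim embedding per utterance
```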
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/feature.py
0 → 100644
import paddle
import paddleaudio
from paddleaudio.features.spectrum import hz_to_mel
from paddleaudio.features.spectrum import mel_to_hz
from paddleaudio.features.spectrum import power_to_db
from paddleaudio.features.spectrum import Spectrogram
from paddleaudio.features.window import get_window


def compute_fbank_matrix(sample_rate: int = 16000,
                         n_fft: int = 400,
                         n_mels: int = 80,
                         f_min: int = 0.0,
                         f_max: int = 8000.0):
    mel = paddle.linspace(hz_to_mel(f_min, htk=True), hz_to_mel(f_max, htk=True), n_mels + 2, dtype=paddle.float32)
    hz = mel_to_hz(mel, htk=True)

    band = hz[1:] - hz[:-1]
    band = band[:-1]
    f_central = hz[1:-1]

    n_stft = n_fft // 2 + 1
    all_freqs = paddle.linspace(0, sample_rate // 2, n_stft)
    all_freqs_mat = all_freqs.tile([f_central.shape[0], 1])

    f_central_mat = f_central.tile([all_freqs_mat.shape[1], 1]).transpose([1, 0])
    band_mat = band.tile([all_freqs_mat.shape[1], 1]).transpose([1, 0])

    slope = (all_freqs_mat - f_central_mat) / band_mat
    left_side = slope + 1.0
    right_side = -slope + 1.0

    fbank_matrix = paddle.maximum(paddle.zeros_like(left_side), paddle.minimum(left_side, right_side))

    return fbank_matrix


def compute_log_fbank(
        x: paddle.Tensor,
        sample_rate: int = 16000,
        n_fft: int = 400,
        hop_length: int = 160,
        win_length: int = 400,
        n_mels: int = 80,
        window: str = 'hamming',
        center: bool = True,
        pad_mode: str = 'constant',
        f_min: float = 0.0,
        f_max: float = None,
        top_db: float = 80.0,
):
    if f_max is None:
        f_max = sample_rate / 2

    spect = Spectrogram(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        pad_mode=pad_mode)(x)

    fbank_matrix = compute_fbank_matrix(
        sample_rate=sample_rate,
        n_fft=n_fft,
        n_mels=n_mels,
        f_min=f_min,
        f_max=f_max,
    )
    fbank = paddle.matmul(fbank_matrix, spect)
    log_fbank = power_to_db(fbank, top_db=top_db).transpose([0, 2, 1])
    return log_fbank


def compute_stats(x: paddle.Tensor, mean_norm: bool = True, std_norm: bool = False, eps: float = 1e-10):
    if mean_norm:
        current_mean = paddle.mean(x, axis=0)
    else:
        current_mean = paddle.to_tensor([0.0])

    if std_norm:
        current_std = paddle.std(x, axis=0)
    else:
        current_std = paddle.to_tensor([1.0])

    current_std = paddle.maximum(current_std, eps * paddle.ones_like(current_std))

    return current_mean, current_std


def normalize(
        x: paddle.Tensor,
        global_mean: paddle.Tensor = None,
        global_std: paddle.Tensor = None,
):
    for i in range(x.shape[0]):  # (B, ...)
        if global_mean is None and global_std is None:
            mean, std = compute_stats(x[i])
            x[i] = (x[i] - mean) / std
        else:
            x[i] = (x[i] - global_mean) / global_std

    return x
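A brief sketch of how the helpers above chain together (an illustration, not part of the committed file; the random waveforms only stand in for real 16 kHz audio):

```python
import paddle

waveforms = paddle.randn([2, 16000])      # (batch, samples), assumed 16 kHz sample rate
log_fbank = compute_log_fbank(waveforms)  # (batch, frames, n_mels=80) log mel filterbank features
norm_fbank = normalize(log_fbank)         # per-utterance mean normalization (std_norm is off by default)
print(log_fbank.shape, norm_fbank.shape)
```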
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/module.py
0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from typing import List
from typing import Union

import numpy as np
import paddle
import paddleaudio

from .ecapa_tdnn import ECAPA_TDNN
from .feature import compute_log_fbank
from .feature import normalize
from paddlehub.module.module import moduleinfo
from paddlehub.utils.log import logger


@moduleinfo(
    name="ecapa_tdnn_voxceleb",
    version="1.0.0",
    summary="",
    author="paddlepaddle",
    author_email="",
    type="audio/speaker_recognition")
class SpeakerRecognition(paddle.nn.Layer):
    def __init__(self, threshold=0.25):
        super(SpeakerRecognition, self).__init__()
        global_stats_path = os.path.join(self.directory, 'assets', 'global_embedding_stats.npy')
        ckpt_path = os.path.join(self.directory, 'assets', 'model.pdparams')

        self.sr = 16000
        self.threshold = threshold
        model_conf = {
            'input_size': 80,
            'channels': [1024, 1024, 1024, 1024, 3072],
            'kernel_sizes': [5, 3, 3, 3, 1],
            'dilations': [1, 2, 3, 4, 1],
            'attention_channels': 128,
            'lin_neurons': 192
        }
        self.model = ECAPA_TDNN(**model_conf)
        self.model.set_state_dict(paddle.load(ckpt_path))
        self.model.eval()

        global_embedding_stats = np.load(global_stats_path, allow_pickle=True)
        self.global_emb_mean = paddle.to_tensor(global_embedding_stats.item().get('global_emb_mean'))
        self.global_emb_std = paddle.to_tensor(global_embedding_stats.item().get('global_emb_std'))

        self.similarity = paddle.nn.CosineSimilarity(axis=-1, eps=1e-6)

    def load_audio(self, wav):
        wav = os.path.abspath(os.path.expanduser(wav))
        assert os.path.isfile(wav), 'Please check wav file: {}'.format(wav)
        waveform, _ = paddleaudio.load(wav, sr=self.sr, mono=True, normal=False)
        return waveform

    def speaker_embedding(self, wav):
        waveform = self.load_audio(wav)
        embedding = self(paddle.to_tensor(waveform)).reshape([-1])
        return embedding.numpy()

    def speaker_verify(self, wav1, wav2):
        waveform1 = self.load_audio(wav1)
        embedding1 = self(paddle.to_tensor(waveform1)).reshape([-1])

        waveform2 = self.load_audio(wav2)
        embedding2 = self(paddle.to_tensor(waveform2)).reshape([-1])

        score = self.similarity(embedding1, embedding2).numpy()
        return score, score > self.threshold

    def forward(self, x):
        if len(x.shape) == 1:
            x = x.unsqueeze(0)

        fbank = compute_log_fbank(x)  # x: waveform tensors with (B, T) shape
        norm_fbank = normalize(fbank)
        embedding = self.model(norm_fbank.transpose([0, 2, 1])).transpose([0, 2, 1])
        norm_embedding = normalize(x=embedding, global_mean=self.global_emb_mean, global_std=self.global_emb_std)

        return norm_embedding
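A rough usage sketch of the module's forward path (an illustration, not part of the committed file; it assumes the module has been installed with `hub install ecapa_tdnn_voxceleb`, and the random waveform is only for checking shapes):

```python
import paddle
import paddlehub as hub

model = hub.Module(name='ecapa_tdnn_voxceleb', threshold=0.25)

# forward() accepts raw waveforms: a 1-D tensor is treated as a single utterance
# and unsqueezed to (1, T); the output is the globally normalized embedding.
waveform = paddle.randn([3 * 16000])  # ~3 s of (random) 16 kHz audio
emb = model(waveform)
print(emb.shape)  # expected [1, 1, 192]
```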
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/requirements.txt
0 → 100644
paddleaudio==0.1.0