Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
600276b3
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
600276b3
编写于
3月 04, 2022
作者:
K
KP
提交者:
GitHub
3月 04, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1753 from KPatr1ck/speech
Add 3 speech models.
上级
c66be094
415ea430
变更
18
显示空白变更内容
内联
并排
Showing
18 changed file
with
1856 addition
and
0 deletion
+1856
-0
modules/audio/keyword_spotting/kwmlp_speech_commands/README.md
...es/audio/keyword_spotting/kwmlp_speech_commands/README.md
+98
-0
modules/audio/keyword_spotting/kwmlp_speech_commands/__init__.py
.../audio/keyword_spotting/kwmlp_speech_commands/__init__.py
+13
-0
modules/audio/keyword_spotting/kwmlp_speech_commands/feature.py
...s/audio/keyword_spotting/kwmlp_speech_commands/feature.py
+59
-0
modules/audio/keyword_spotting/kwmlp_speech_commands/kwmlp.py
...les/audio/keyword_spotting/kwmlp_speech_commands/kwmlp.py
+143
-0
modules/audio/keyword_spotting/kwmlp_speech_commands/module.py
...es/audio/keyword_spotting/kwmlp_speech_commands/module.py
+86
-0
modules/audio/keyword_spotting/kwmlp_speech_commands/requirements.txt
...o/keyword_spotting/kwmlp_speech_commands/requirements.txt
+1
-0
modules/audio/language_identification/ecapa_tdnn_common_language/README.md
...guage_identification/ecapa_tdnn_common_language/README.md
+100
-0
modules/audio/language_identification/ecapa_tdnn_common_language/__init__.py
...age_identification/ecapa_tdnn_common_language/__init__.py
+13
-0
modules/audio/language_identification/ecapa_tdnn_common_language/ecapa_tdnn.py
...e_identification/ecapa_tdnn_common_language/ecapa_tdnn.py
+406
-0
modules/audio/language_identification/ecapa_tdnn_common_language/feature.py
...uage_identification/ecapa_tdnn_common_language/feature.py
+112
-0
modules/audio/language_identification/ecapa_tdnn_common_language/module.py
...guage_identification/ecapa_tdnn_common_language/module.py
+85
-0
modules/audio/language_identification/ecapa_tdnn_common_language/requirements.txt
...dentification/ecapa_tdnn_common_language/requirements.txt
+1
-0
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/README.md
...s/audio/speaker_recognition/ecapa_tdnn_voxceleb/README.md
+128
-0
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/__init__.py
...audio/speaker_recognition/ecapa_tdnn_voxceleb/__init__.py
+13
-0
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/ecapa_tdnn.py
...dio/speaker_recognition/ecapa_tdnn_voxceleb/ecapa_tdnn.py
+392
-0
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/feature.py
.../audio/speaker_recognition/ecapa_tdnn_voxceleb/feature.py
+112
-0
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/module.py
...s/audio/speaker_recognition/ecapa_tdnn_voxceleb/module.py
+93
-0
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/requirements.txt
.../speaker_recognition/ecapa_tdnn_voxceleb/requirements.txt
+1
-0
未找到文件。
modules/audio/keyword_spotting/kwmlp_speech_commands/README.md
0 → 100644
浏览文件 @
600276b3
# kwmlp_speech_commands
|模型名称|kwmlp_speech_commands|
| :--- | :---: |
|类别|语音-语言识别|
|网络|Keyword-MLP|
|数据集|Google Speech Commands V2|
|是否支持Fine-tuning|否|
|模型大小|1.6MB|
|最新更新日期|2022-01-04|
|数据指标|ACC 97.56%|
## 一、模型基本信息
### 模型介绍
kwmlp_speech_commands采用了
[
Keyword-MLP
](
https://arxiv.org/pdf/2110.07749v1.pdf
)
的轻量级模型结构,并在
[
Google Speech Commands V2
](
https://arxiv.org/abs/1804.03209
)
数据集上进行了预训练,在其测试集的测试结果为 ACC 97.56%。
<p
align=
"center"
>
<img
src=
"https://d3i71xaburhd42.cloudfront.net/fa690a97f76ba119ca08fb02fa524a546c47f031/2-Figure1-1.png"
hspace=
'10'
height=
"550"
/>
<br
/>
</p>
更多详情请参考
-
[
Speech Commands: A Dataset for Limited-Vocabulary Speech Recognition
](
https://arxiv.org/abs/1804.03209
)
-
[
ATTENTION-FREE KEYWORD SPOTTING
](
https://arxiv.org/pdf/2110.07749v1.pdf
)
-
[
Keyword-MLP
](
https://github.com/AI-Research-BD/Keyword-MLP
)
## 二、安装
-
### 1、环境依赖
-
paddlepaddle >= 2.2.0
-
paddlehub >= 2.2.0 |
[
如何安装PaddleHub
](
../../../../docs/docs_ch/get_start/installation.rst
)
-
### 2、安装
-
```shell
$ hub install kwmlp_speech_commands
```
-
如您安装时遇到问题,可参考:
[
零基础windows安装
](
../../../../docs/docs_ch/get_start/windows_quickstart.md
)
|
[
零基础Linux安装
](
../../../../docs/docs_ch/get_start/linux_quickstart.md
)
|
[
零基础MacOS安装
](
../../../../docs/docs_ch/get_start/mac_quickstart.md
)
## 三、模型API预测
-
### 1、预测代码示例
```python
import paddlehub as hub
model = hub.Module(
name='kwmlp_speech_commands',
version='1.0.0')
# 通过下列链接可下载示例音频
# https://paddlehub.bj.bcebos.com/paddlehub_dev/go.wav
# Keyword spotting
score, label = model.keyword_recognize('no.wav')
print(score, label)
# [0.89498246] no
score, label = model.keyword_recognize('go.wav')
print(score, label)
# [0.8997176] go
score, label = model.keyword_recognize('one.wav')
print(score, label)
# [0.88598305] one
```
-
### 2、API
-
```python
def keyword_recognize(
wav: os.PathLike,
)
```
-
检测音频中包含的关键词。
- **参数**
- `wav`:输入的包含关键词的音频文件,格式为`*.wav`。
- **返回**
- 输出结果的得分和对应的关键词标签。
## 四、更新历史
*
1.0.0
初始发布
```
shell
$
hub
install
kwmlp_speech_commands
```
modules/audio/keyword_spotting/kwmlp_speech_commands/__init__.py
0 → 100644
浏览文件 @
600276b3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
modules/audio/keyword_spotting/kwmlp_speech_commands/feature.py
0 → 100644
浏览文件 @
600276b3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
numpy
as
np
import
paddle
import
paddleaudio
def
create_dct
(
n_mfcc
:
int
,
n_mels
:
int
,
norm
:
str
=
'ortho'
):
n
=
paddle
.
arange
(
float
(
n_mels
))
k
=
paddle
.
arange
(
float
(
n_mfcc
)).
unsqueeze
(
1
)
dct
=
paddle
.
cos
(
math
.
pi
/
float
(
n_mels
)
*
(
n
+
0.5
)
*
k
)
# size (n_mfcc, n_mels)
if
norm
is
None
:
dct
*=
2.0
else
:
assert
norm
==
"ortho"
dct
[
0
]
*=
1.0
/
math
.
sqrt
(
2.0
)
dct
*=
math
.
sqrt
(
2.0
/
float
(
n_mels
))
return
dct
.
t
()
def
compute_mfcc
(
x
:
paddle
.
Tensor
,
sr
:
int
=
16000
,
n_mels
:
int
=
40
,
n_fft
:
int
=
480
,
win_length
:
int
=
480
,
hop_length
:
int
=
160
,
f_min
:
float
=
0.0
,
f_max
:
float
=
None
,
center
:
bool
=
False
,
top_db
:
float
=
80.0
,
norm
:
str
=
'ortho'
,
):
fbank
=
paddleaudio
.
features
.
spectrum
.
MelSpectrogram
(
sr
=
sr
,
n_mels
=
n_mels
,
n_fft
=
n_fft
,
win_length
=
win_length
,
hop_length
=
hop_length
,
f_min
=
0.0
,
f_max
=
f_max
,
center
=
center
)(
x
)
# waveforms batch ~ (B, T)
log_fbank
=
paddleaudio
.
features
.
spectrum
.
power_to_db
(
fbank
,
top_db
=
top_db
)
dct_matrix
=
create_dct
(
n_mfcc
=
n_mels
,
n_mels
=
n_mels
,
norm
=
norm
)
mfcc
=
paddle
.
matmul
(
log_fbank
.
transpose
((
0
,
2
,
1
)),
dct_matrix
).
transpose
((
0
,
2
,
1
))
# (B, n_mels, L)
return
mfcc
modules/audio/keyword_spotting/kwmlp_speech_commands/kwmlp.py
0 → 100644
浏览文件 @
600276b3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
class
Residual
(
nn
.
Layer
):
def
__init__
(
self
,
fn
):
super
().
__init__
()
self
.
fn
=
fn
def
forward
(
self
,
x
):
return
self
.
fn
(
x
)
+
x
class
PreNorm
(
nn
.
Layer
):
def
__init__
(
self
,
dim
,
fn
):
super
().
__init__
()
self
.
fn
=
fn
self
.
norm
=
nn
.
LayerNorm
(
dim
)
def
forward
(
self
,
x
,
**
kwargs
):
x
=
self
.
norm
(
x
)
return
self
.
fn
(
x
,
**
kwargs
)
class
PostNorm
(
nn
.
Layer
):
def
__init__
(
self
,
dim
,
fn
):
super
().
__init__
()
self
.
norm
=
nn
.
LayerNorm
(
dim
)
self
.
fn
=
fn
def
forward
(
self
,
x
,
**
kwargs
):
return
self
.
norm
(
self
.
fn
(
x
,
**
kwargs
))
class
SpatialGatingUnit
(
nn
.
Layer
):
def
__init__
(
self
,
dim
,
dim_seq
,
act
=
nn
.
Identity
(),
init_eps
=
1e-3
):
super
().
__init__
()
dim_out
=
dim
//
2
self
.
norm
=
nn
.
LayerNorm
(
dim_out
)
self
.
proj
=
nn
.
Conv1D
(
dim_seq
,
dim_seq
,
1
)
self
.
act
=
act
init_eps
/=
dim_seq
def
forward
(
self
,
x
):
res
,
gate
=
x
.
split
(
2
,
axis
=-
1
)
gate
=
self
.
norm
(
gate
)
weight
,
bias
=
self
.
proj
.
weight
,
self
.
proj
.
bias
gate
=
F
.
conv1d
(
gate
,
weight
,
bias
)
return
self
.
act
(
gate
)
*
res
class
gMLPBlock
(
nn
.
Layer
):
def
__init__
(
self
,
*
,
dim
,
dim_ff
,
seq_len
,
act
=
nn
.
Identity
()):
super
().
__init__
()
self
.
proj_in
=
nn
.
Sequential
(
nn
.
Linear
(
dim
,
dim_ff
),
nn
.
GELU
())
self
.
sgu
=
SpatialGatingUnit
(
dim_ff
,
seq_len
,
act
)
self
.
proj_out
=
nn
.
Linear
(
dim_ff
//
2
,
dim
)
def
forward
(
self
,
x
):
x
=
self
.
proj_in
(
x
)
x
=
self
.
sgu
(
x
)
x
=
self
.
proj_out
(
x
)
return
x
class
Rearrange
(
nn
.
Layer
):
def
__init__
(
self
):
super
().
__init__
()
def
forward
(
self
,
x
):
x
=
x
.
transpose
([
0
,
1
,
3
,
2
]).
squeeze
(
1
)
return
x
class
Reduce
(
nn
.
Layer
):
def
__init__
(
self
,
axis
=
1
):
super
().
__init__
()
self
.
axis
=
axis
def
forward
(
self
,
x
):
x
=
x
.
mean
(
axis
=
self
.
axis
,
keepdim
=
False
)
return
x
class
KW_MLP
(
nn
.
Layer
):
"""Keyword-MLP."""
def
__init__
(
self
,
input_res
=
[
40
,
98
],
patch_res
=
[
40
,
1
],
num_classes
=
35
,
dim
=
64
,
depth
=
12
,
ff_mult
=
4
,
channels
=
1
,
prob_survival
=
0.9
,
pre_norm
=
False
,
**
kwargs
):
super
().
__init__
()
image_height
,
image_width
=
input_res
patch_height
,
patch_width
=
patch_res
assert
(
image_height
%
patch_height
)
==
0
and
(
image_width
%
patch_width
)
==
0
,
'image height and width must be divisible by patch size'
num_patches
=
(
image_height
//
patch_height
)
*
(
image_width
//
patch_width
)
P_Norm
=
PreNorm
if
pre_norm
else
PostNorm
dim_ff
=
dim
*
ff_mult
self
.
to_patch_embed
=
nn
.
Sequential
(
Rearrange
(),
nn
.
Linear
(
channels
*
patch_height
*
patch_width
,
dim
))
self
.
prob_survival
=
prob_survival
self
.
layers
=
nn
.
LayerList
(
[
Residual
(
P_Norm
(
dim
,
gMLPBlock
(
dim
=
dim
,
dim_ff
=
dim_ff
,
seq_len
=
num_patches
)))
for
i
in
range
(
depth
)])
self
.
to_logits
=
nn
.
Sequential
(
nn
.
LayerNorm
(
dim
),
Reduce
(
axis
=
1
),
nn
.
Linear
(
dim
,
num_classes
))
def
forward
(
self
,
x
):
x
=
self
.
to_patch_embed
(
x
)
layers
=
self
.
layers
x
=
nn
.
Sequential
(
*
layers
)(
x
)
return
self
.
to_logits
(
x
)
modules/audio/keyword_spotting/kwmlp_speech_commands/module.py
0 → 100644
浏览文件 @
600276b3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
numpy
as
np
import
paddle
import
paddleaudio
from
.feature
import
compute_mfcc
from
.kwmlp
import
KW_MLP
from
paddlehub.module.module
import
moduleinfo
from
paddlehub.utils.log
import
logger
@
moduleinfo
(
name
=
"kwmlp_speech_commands"
,
version
=
"1.0.0"
,
summary
=
""
,
author
=
"paddlepaddle"
,
author_email
=
""
,
type
=
"audio/language_identification"
)
class
KWS
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
):
super
(
KWS
,
self
).
__init__
()
ckpt_path
=
os
.
path
.
join
(
self
.
directory
,
'assets'
,
'model.pdparams'
)
label_path
=
os
.
path
.
join
(
self
.
directory
,
'assets'
,
'label.txt'
)
self
.
label_list
=
[]
with
open
(
label_path
,
'r'
)
as
f
:
for
l
in
f
:
self
.
label_list
.
append
(
l
.
strip
())
self
.
sr
=
16000
model_conf
=
{
'input_res'
:
[
40
,
98
],
'patch_res'
:
[
40
,
1
],
'num_classes'
:
35
,
'channels'
:
1
,
'dim'
:
64
,
'depth'
:
12
,
'pre_norm'
:
False
,
'prob_survival'
:
0.9
,
}
self
.
model
=
KW_MLP
(
**
model_conf
)
self
.
model
.
set_state_dict
(
paddle
.
load
(
ckpt_path
))
self
.
model
.
eval
()
def
load_audio
(
self
,
wav
):
wav
=
os
.
path
.
abspath
(
os
.
path
.
expanduser
(
wav
))
assert
os
.
path
.
isfile
(
wav
),
'Please check wav file: {}'
.
format
(
wav
)
waveform
,
_
=
paddleaudio
.
load
(
wav
,
sr
=
self
.
sr
,
mono
=
True
,
normal
=
False
)
return
waveform
def
keyword_recognize
(
self
,
wav
):
waveform
=
self
.
load_audio
(
wav
)
# fix_length to 1s
if
len
(
waveform
)
>
self
.
sr
:
waveform
=
waveform
[:
self
.
sr
]
else
:
waveform
=
np
.
pad
(
waveform
,
(
0
,
self
.
sr
-
len
(
waveform
)))
logits
=
self
(
paddle
.
to_tensor
(
waveform
)).
reshape
([
-
1
])
probs
=
paddle
.
nn
.
functional
.
softmax
(
logits
)
idx
=
paddle
.
argmax
(
probs
)
return
probs
[
idx
].
numpy
(),
self
.
label_list
[
idx
]
def
forward
(
self
,
x
):
if
len
(
x
.
shape
)
==
1
:
# x: waveform tensors with (B, T) shape
x
=
x
.
unsqueeze
(
0
)
mfcc
=
compute_mfcc
(
x
).
unsqueeze
(
1
)
# (B, C, n_mels, L)
logits
=
self
.
model
(
mfcc
).
squeeze
(
1
)
return
logits
modules/audio/keyword_spotting/kwmlp_speech_commands/requirements.txt
0 → 100644
浏览文件 @
600276b3
paddleaudio==0.1.0
modules/audio/language_identification/ecapa_tdnn_common_language/README.md
0 → 100644
浏览文件 @
600276b3
# ecapa_tdnn_common_language
|模型名称|ecapa_tdnn_common_language|
| :--- | :---: |
|类别|语音-语言识别|
|网络|ECAPA-TDNN|
|数据集|CommonLanguage|
|是否支持Fine-tuning|否|
|模型大小|79MB|
|最新更新日期|2021-12-30|
|数据指标|ACC 84.9%|
## 一、模型基本信息
### 模型介绍
ecapa_tdnn_common_language采用了
[
ECAPA-TDNN
](
https://arxiv.org/abs/2005.07143
)
的模型结构,并在
[
CommonLanguage
](
https://zenodo.org/record/5036977/
)
数据集上进行了预训练,在其测试集的测试结果为 ACC 84.9%。
<p
align=
"center"
>
<img
src=
"https://d3i71xaburhd42.cloudfront.net/9609f4817a7e769f5e3e07084db35e46696e82cd/3-Figure2-1.png"
hspace=
'10'
height=
"550"
/>
<br
/>
</p>
更多详情请参考
-
[
CommonLanguage
](
https://zenodo.org/record/5036977#.Yc19b5Mzb0o
)
-
[
ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification
](
https://arxiv.org/pdf/2005.07143.pdf
)
-
[
The SpeechBrain Toolkit
](
https://github.com/speechbrain/speechbrain
)
## 二、安装
-
### 1、环境依赖
-
paddlepaddle >= 2.2.0
-
paddlehub >= 2.2.0 |
[
如何安装PaddleHub
](
../../../../docs/docs_ch/get_start/installation.rst
)
-
### 2、安装
-
```shell
$ hub install ecapa_tdnn_common_language
```
-
如您安装时遇到问题,可参考:
[
零基础windows安装
](
../../../../docs/docs_ch/get_start/windows_quickstart.md
)
|
[
零基础Linux安装
](
../../../../docs/docs_ch/get_start/linux_quickstart.md
)
|
[
零基础MacOS安装
](
../../../../docs/docs_ch/get_start/mac_quickstart.md
)
## 三、模型API预测
-
### 1、预测代码示例
```python
import paddlehub as hub
model = hub.Module(
name='ecapa_tdnn_common_language',
version='1.0.0')
# 通过下列链接可下载示例音频
# https://paddlehub.bj.bcebos.com/paddlehub_dev/zh.wav
# https://paddlehub.bj.bcebos.com/paddlehub_dev/en.wav
# https://paddlehub.bj.bcebos.com/paddlehub_dev/it.wav
# Language Identification
score, label = model.speaker_verify('zh.wav')
print(score, label)
# array([0.6214552], dtype=float32), 'Chinese_China'
score, label = model.speaker_verify('en.wav')
print(score, label)
# array([0.37193954], dtype=float32), 'English'
score, label = model.speaker_verify('it.wav')
print(score, label)
# array([0.46913534], dtype=float32), 'Italian'
```
-
### 2、API
-
```python
def language_identify(
wav: os.PathLike,
)
```
-
判断输入人声音频的语言类别。
- **参数**
- `wav`:输入的说话人的音频文件,格式为`*.wav`。
- **返回**
- 输出结果的得分和对应的语言类别。
## 四、更新历史
*
1.0.0
初始发布
```
shell
$
hub
install
ecapa_tdnn_common_language
```
modules/audio/language_identification/ecapa_tdnn_common_language/__init__.py
0 → 100644
浏览文件 @
600276b3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
modules/audio/language_identification/ecapa_tdnn_common_language/ecapa_tdnn.py
0 → 100644
浏览文件 @
600276b3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
os
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
def
length_to_mask
(
length
,
max_len
=
None
,
dtype
=
None
):
assert
len
(
length
.
shape
)
==
1
if
max_len
is
None
:
max_len
=
length
.
max
().
astype
(
'int'
).
item
()
# using arange to generate mask
mask
=
paddle
.
arange
(
max_len
,
dtype
=
length
.
dtype
).
expand
((
len
(
length
),
max_len
))
<
length
.
unsqueeze
(
1
)
if
dtype
is
None
:
dtype
=
length
.
dtype
mask
=
paddle
.
to_tensor
(
mask
,
dtype
=
dtype
)
return
mask
class
Conv1d
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
padding
=
"same"
,
dilation
=
1
,
groups
=
1
,
bias
=
True
,
padding_mode
=
"reflect"
,
):
super
(
Conv1d
,
self
).
__init__
()
self
.
kernel_size
=
kernel_size
self
.
stride
=
stride
self
.
dilation
=
dilation
self
.
padding
=
padding
self
.
padding_mode
=
padding_mode
self
.
conv
=
nn
.
Conv1D
(
in_channels
,
out_channels
,
self
.
kernel_size
,
stride
=
self
.
stride
,
padding
=
0
,
dilation
=
self
.
dilation
,
groups
=
groups
,
bias_attr
=
bias
,
)
def
forward
(
self
,
x
):
if
self
.
padding
==
"same"
:
x
=
self
.
_manage_padding
(
x
,
self
.
kernel_size
,
self
.
dilation
,
self
.
stride
)
else
:
raise
ValueError
(
"Padding must be 'same'. Got {self.padding}"
)
return
self
.
conv
(
x
)
def
_manage_padding
(
self
,
x
,
kernel_size
:
int
,
dilation
:
int
,
stride
:
int
):
L_in
=
x
.
shape
[
-
1
]
# Detecting input shape
padding
=
self
.
_get_padding_elem
(
L_in
,
stride
,
kernel_size
,
dilation
)
# Time padding
x
=
F
.
pad
(
x
,
padding
,
mode
=
self
.
padding_mode
,
data_format
=
"NCL"
)
# Applying padding
return
x
def
_get_padding_elem
(
self
,
L_in
:
int
,
stride
:
int
,
kernel_size
:
int
,
dilation
:
int
):
if
stride
>
1
:
n_steps
=
math
.
ceil
(((
L_in
-
kernel_size
*
dilation
)
/
stride
)
+
1
)
L_out
=
stride
*
(
n_steps
-
1
)
+
kernel_size
*
dilation
padding
=
[
kernel_size
//
2
,
kernel_size
//
2
]
else
:
L_out
=
(
L_in
-
dilation
*
(
kernel_size
-
1
)
-
1
)
//
stride
+
1
padding
=
[(
L_in
-
L_out
)
//
2
,
(
L_in
-
L_out
)
//
2
]
return
padding
class
BatchNorm1d
(
nn
.
Layer
):
def
__init__
(
self
,
input_size
,
eps
=
1e-05
,
momentum
=
0.9
,
weight_attr
=
None
,
bias_attr
=
None
,
data_format
=
'NCL'
,
use_global_stats
=
None
,
):
super
(
BatchNorm1d
,
self
).
__init__
()
self
.
norm
=
nn
.
BatchNorm1D
(
input_size
,
epsilon
=
eps
,
momentum
=
momentum
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
,
data_format
=
data_format
,
use_global_stats
=
use_global_stats
,
)
def
forward
(
self
,
x
):
x_n
=
self
.
norm
(
x
)
return
x_n
class
TDNNBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
dilation
,
activation
=
nn
.
ReLU
,
):
super
(
TDNNBlock
,
self
).
__init__
()
self
.
conv
=
Conv1d
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
kernel_size
,
dilation
=
dilation
,
)
self
.
activation
=
activation
()
self
.
norm
=
BatchNorm1d
(
input_size
=
out_channels
)
def
forward
(
self
,
x
):
return
self
.
norm
(
self
.
activation
(
self
.
conv
(
x
)))
class
Res2NetBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
scale
=
8
,
dilation
=
1
):
super
(
Res2NetBlock
,
self
).
__init__
()
assert
in_channels
%
scale
==
0
assert
out_channels
%
scale
==
0
in_channel
=
in_channels
//
scale
hidden_channel
=
out_channels
//
scale
self
.
blocks
=
nn
.
LayerList
(
[
TDNNBlock
(
in_channel
,
hidden_channel
,
kernel_size
=
3
,
dilation
=
dilation
)
for
i
in
range
(
scale
-
1
)])
self
.
scale
=
scale
def
forward
(
self
,
x
):
y
=
[]
for
i
,
x_i
in
enumerate
(
paddle
.
chunk
(
x
,
self
.
scale
,
axis
=
1
)):
if
i
==
0
:
y_i
=
x_i
elif
i
==
1
:
y_i
=
self
.
blocks
[
i
-
1
](
x_i
)
else
:
y_i
=
self
.
blocks
[
i
-
1
](
x_i
+
y_i
)
y
.
append
(
y_i
)
y
=
paddle
.
concat
(
y
,
axis
=
1
)
return
y
class
SEBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
se_channels
,
out_channels
):
super
(
SEBlock
,
self
).
__init__
()
self
.
conv1
=
Conv1d
(
in_channels
=
in_channels
,
out_channels
=
se_channels
,
kernel_size
=
1
)
self
.
relu
=
paddle
.
nn
.
ReLU
()
self
.
conv2
=
Conv1d
(
in_channels
=
se_channels
,
out_channels
=
out_channels
,
kernel_size
=
1
)
self
.
sigmoid
=
paddle
.
nn
.
Sigmoid
()
def
forward
(
self
,
x
,
lengths
=
None
):
L
=
x
.
shape
[
-
1
]
if
lengths
is
not
None
:
mask
=
length_to_mask
(
lengths
*
L
,
max_len
=
L
)
mask
=
mask
.
unsqueeze
(
1
)
total
=
mask
.
sum
(
axis
=
2
,
keepdim
=
True
)
s
=
(
x
*
mask
).
sum
(
axis
=
2
,
keepdim
=
True
)
/
total
else
:
s
=
x
.
mean
(
axis
=
2
,
keepdim
=
True
)
s
=
self
.
relu
(
self
.
conv1
(
s
))
s
=
self
.
sigmoid
(
self
.
conv2
(
s
))
return
s
*
x
class
AttentiveStatisticsPooling
(
nn
.
Layer
):
def
__init__
(
self
,
channels
,
attention_channels
=
128
,
global_context
=
True
):
super
().
__init__
()
self
.
eps
=
1e-12
self
.
global_context
=
global_context
if
global_context
:
self
.
tdnn
=
TDNNBlock
(
channels
*
3
,
attention_channels
,
1
,
1
)
else
:
self
.
tdnn
=
TDNNBlock
(
channels
,
attention_channels
,
1
,
1
)
self
.
tanh
=
nn
.
Tanh
()
self
.
conv
=
Conv1d
(
in_channels
=
attention_channels
,
out_channels
=
channels
,
kernel_size
=
1
)
def
forward
(
self
,
x
,
lengths
=
None
):
C
,
L
=
x
.
shape
[
1
],
x
.
shape
[
2
]
# KP: (N, C, L)
def
_compute_statistics
(
x
,
m
,
axis
=
2
,
eps
=
self
.
eps
):
mean
=
(
m
*
x
).
sum
(
axis
)
std
=
paddle
.
sqrt
((
m
*
(
x
-
mean
.
unsqueeze
(
axis
)).
pow
(
2
)).
sum
(
axis
).
clip
(
eps
))
return
mean
,
std
if
lengths
is
None
:
lengths
=
paddle
.
ones
([
x
.
shape
[
0
]])
# Make binary mask of shape [N, 1, L]
mask
=
length_to_mask
(
lengths
*
L
,
max_len
=
L
)
mask
=
mask
.
unsqueeze
(
1
)
# Expand the temporal context of the pooling layer by allowing the
# self-attention to look at global properties of the utterance.
if
self
.
global_context
:
total
=
mask
.
sum
(
axis
=
2
,
keepdim
=
True
).
astype
(
'float32'
)
mean
,
std
=
_compute_statistics
(
x
,
mask
/
total
)
mean
=
mean
.
unsqueeze
(
2
).
tile
((
1
,
1
,
L
))
std
=
std
.
unsqueeze
(
2
).
tile
((
1
,
1
,
L
))
attn
=
paddle
.
concat
([
x
,
mean
,
std
],
axis
=
1
)
else
:
attn
=
x
# Apply layers
attn
=
self
.
conv
(
self
.
tanh
(
self
.
tdnn
(
attn
)))
# Filter out zero-paddings
attn
=
paddle
.
where
(
mask
.
tile
((
1
,
C
,
1
))
==
0
,
paddle
.
ones_like
(
attn
)
*
float
(
"-inf"
),
attn
)
attn
=
F
.
softmax
(
attn
,
axis
=
2
)
mean
,
std
=
_compute_statistics
(
x
,
attn
)
# Append mean and std of the batch
pooled_stats
=
paddle
.
concat
((
mean
,
std
),
axis
=
1
)
pooled_stats
=
pooled_stats
.
unsqueeze
(
2
)
return
pooled_stats
class
SERes2NetBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
res2net_scale
=
8
,
se_channels
=
128
,
kernel_size
=
1
,
dilation
=
1
,
activation
=
nn
.
ReLU
,
):
super
(
SERes2NetBlock
,
self
).
__init__
()
self
.
out_channels
=
out_channels
self
.
tdnn1
=
TDNNBlock
(
in_channels
,
out_channels
,
kernel_size
=
1
,
dilation
=
1
,
activation
=
activation
,
)
self
.
res2net_block
=
Res2NetBlock
(
out_channels
,
out_channels
,
res2net_scale
,
dilation
)
self
.
tdnn2
=
TDNNBlock
(
out_channels
,
out_channels
,
kernel_size
=
1
,
dilation
=
1
,
activation
=
activation
,
)
self
.
se_block
=
SEBlock
(
out_channels
,
se_channels
,
out_channels
)
self
.
shortcut
=
None
if
in_channels
!=
out_channels
:
self
.
shortcut
=
Conv1d
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
1
,
)
def
forward
(
self
,
x
,
lengths
=
None
):
residual
=
x
if
self
.
shortcut
:
residual
=
self
.
shortcut
(
x
)
x
=
self
.
tdnn1
(
x
)
x
=
self
.
res2net_block
(
x
)
x
=
self
.
tdnn2
(
x
)
x
=
self
.
se_block
(
x
,
lengths
)
return
x
+
residual
class
ECAPA_TDNN
(
nn
.
Layer
):
def
__init__
(
self
,
input_size
,
lin_neurons
=
192
,
activation
=
nn
.
ReLU
,
channels
=
[
512
,
512
,
512
,
512
,
1536
],
kernel_sizes
=
[
5
,
3
,
3
,
3
,
1
],
dilations
=
[
1
,
2
,
3
,
4
,
1
],
attention_channels
=
128
,
res2net_scale
=
8
,
se_channels
=
128
,
global_context
=
True
,
):
super
(
ECAPA_TDNN
,
self
).
__init__
()
assert
len
(
channels
)
==
len
(
kernel_sizes
)
assert
len
(
channels
)
==
len
(
dilations
)
self
.
channels
=
channels
self
.
blocks
=
nn
.
LayerList
()
self
.
emb_size
=
lin_neurons
# The initial TDNN layer
self
.
blocks
.
append
(
TDNNBlock
(
input_size
,
channels
[
0
],
kernel_sizes
[
0
],
dilations
[
0
],
activation
,
))
# SE-Res2Net layers
for
i
in
range
(
1
,
len
(
channels
)
-
1
):
self
.
blocks
.
append
(
SERes2NetBlock
(
channels
[
i
-
1
],
channels
[
i
],
res2net_scale
=
res2net_scale
,
se_channels
=
se_channels
,
kernel_size
=
kernel_sizes
[
i
],
dilation
=
dilations
[
i
],
activation
=
activation
,
))
# Multi-layer feature aggregation
self
.
mfa
=
TDNNBlock
(
channels
[
-
1
],
channels
[
-
1
],
kernel_sizes
[
-
1
],
dilations
[
-
1
],
activation
,
)
# Attentive Statistical Pooling
self
.
asp
=
AttentiveStatisticsPooling
(
channels
[
-
1
],
attention_channels
=
attention_channels
,
global_context
=
global_context
,
)
self
.
asp_bn
=
BatchNorm1d
(
input_size
=
channels
[
-
1
]
*
2
)
# Final linear transformation
self
.
fc
=
Conv1d
(
in_channels
=
channels
[
-
1
]
*
2
,
out_channels
=
self
.
emb_size
,
kernel_size
=
1
,
)
def
forward
(
self
,
x
,
lengths
=
None
):
xl
=
[]
for
layer
in
self
.
blocks
:
try
:
x
=
layer
(
x
,
lengths
=
lengths
)
except
TypeError
:
x
=
layer
(
x
)
xl
.
append
(
x
)
# Multi-layer feature aggregation
x
=
paddle
.
concat
(
xl
[
1
:],
axis
=
1
)
x
=
self
.
mfa
(
x
)
# Attentive Statistical Pooling
x
=
self
.
asp
(
x
,
lengths
=
lengths
)
x
=
self
.
asp_bn
(
x
)
# Final linear transformation
x
=
self
.
fc
(
x
)
return
x
class
Classifier
(
nn
.
Layer
):
def
__init__
(
self
,
backbone
,
num_class
,
dtype
=
paddle
.
float32
):
super
(
Classifier
,
self
).
__init__
()
self
.
backbone
=
backbone
self
.
params
=
nn
.
ParameterList
(
[
paddle
.
create_parameter
(
shape
=
[
num_class
,
self
.
backbone
.
emb_size
],
dtype
=
dtype
)])
def
forward
(
self
,
x
):
emb
=
self
.
backbone
(
x
.
transpose
([
0
,
2
,
1
])).
transpose
([
0
,
2
,
1
])
logits
=
F
.
linear
(
F
.
normalize
(
emb
.
squeeze
(
1
)),
F
.
normalize
(
self
.
params
[
0
]).
transpose
([
1
,
0
]))
return
logits
modules/audio/language_identification/ecapa_tdnn_common_language/feature.py
0 → 100644
浏览文件 @
600276b3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddleaudio
from
paddleaudio.features.spectrum
import
hz_to_mel
from
paddleaudio.features.spectrum
import
mel_to_hz
from
paddleaudio.features.spectrum
import
power_to_db
from
paddleaudio.features.spectrum
import
Spectrogram
from
paddleaudio.features.window
import
get_window
def
compute_fbank_matrix
(
sample_rate
:
int
=
16000
,
n_fft
:
int
=
400
,
n_mels
:
int
=
80
,
f_min
:
int
=
0.0
,
f_max
:
int
=
8000.0
):
mel
=
paddle
.
linspace
(
hz_to_mel
(
f_min
,
htk
=
True
),
hz_to_mel
(
f_max
,
htk
=
True
),
n_mels
+
2
,
dtype
=
paddle
.
float32
)
hz
=
mel_to_hz
(
mel
,
htk
=
True
)
band
=
hz
[
1
:]
-
hz
[:
-
1
]
band
=
band
[:
-
1
]
f_central
=
hz
[
1
:
-
1
]
n_stft
=
n_fft
//
2
+
1
all_freqs
=
paddle
.
linspace
(
0
,
sample_rate
//
2
,
n_stft
)
all_freqs_mat
=
all_freqs
.
tile
([
f_central
.
shape
[
0
],
1
])
f_central_mat
=
f_central
.
tile
([
all_freqs_mat
.
shape
[
1
],
1
]).
transpose
([
1
,
0
])
band_mat
=
band
.
tile
([
all_freqs_mat
.
shape
[
1
],
1
]).
transpose
([
1
,
0
])
slope
=
(
all_freqs_mat
-
f_central_mat
)
/
band_mat
left_side
=
slope
+
1.0
right_side
=
-
slope
+
1.0
fbank_matrix
=
paddle
.
maximum
(
paddle
.
zeros_like
(
left_side
),
paddle
.
minimum
(
left_side
,
right_side
))
return
fbank_matrix
def
compute_log_fbank
(
x
:
paddle
.
Tensor
,
sample_rate
:
int
=
16000
,
n_fft
:
int
=
400
,
hop_length
:
int
=
160
,
win_length
:
int
=
400
,
n_mels
:
int
=
80
,
window
:
str
=
'hamming'
,
center
:
bool
=
True
,
pad_mode
:
str
=
'constant'
,
f_min
:
float
=
0.0
,
f_max
:
float
=
None
,
top_db
:
float
=
80.0
,
):
if
f_max
is
None
:
f_max
=
sample_rate
/
2
spect
=
Spectrogram
(
n_fft
=
n_fft
,
hop_length
=
hop_length
,
win_length
=
win_length
,
window
=
window
,
center
=
center
,
pad_mode
=
pad_mode
)(
x
)
fbank_matrix
=
compute_fbank_matrix
(
sample_rate
=
sample_rate
,
n_fft
=
n_fft
,
n_mels
=
n_mels
,
f_min
=
f_min
,
f_max
=
f_max
,
)
fbank
=
paddle
.
matmul
(
fbank_matrix
,
spect
)
log_fbank
=
power_to_db
(
fbank
,
top_db
=
top_db
).
transpose
([
0
,
2
,
1
])
return
log_fbank
def
compute_stats
(
x
:
paddle
.
Tensor
,
mean_norm
:
bool
=
True
,
std_norm
:
bool
=
False
,
eps
:
float
=
1e-10
):
if
mean_norm
:
current_mean
=
paddle
.
mean
(
x
,
axis
=
0
)
else
:
current_mean
=
paddle
.
to_tensor
([
0.0
])
if
std_norm
:
current_std
=
paddle
.
std
(
x
,
axis
=
0
)
else
:
current_std
=
paddle
.
to_tensor
([
1.0
])
current_std
=
paddle
.
maximum
(
current_std
,
eps
*
paddle
.
ones_like
(
current_std
))
return
current_mean
,
current_std
def
normalize
(
x
:
paddle
.
Tensor
,
global_mean
:
paddle
.
Tensor
=
None
,
global_std
:
paddle
.
Tensor
=
None
,
):
for
i
in
range
(
x
.
shape
[
0
]):
# (B, ...)
if
global_mean
is
None
and
global_std
is
None
:
mean
,
std
=
compute_stats
(
x
[
i
])
x
[
i
]
=
(
x
[
i
]
-
mean
)
/
std
else
:
x
[
i
]
=
(
x
[
i
]
-
global_mean
)
/
global_std
return
x
modules/audio/language_identification/ecapa_tdnn_common_language/module.py
0 → 100644
浏览文件 @
600276b3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
re
from
typing
import
List
from
typing
import
Union
import
numpy
as
np
import
paddle
import
paddleaudio
from
.ecapa_tdnn
import
Classifier
from
.ecapa_tdnn
import
ECAPA_TDNN
from
.feature
import
compute_log_fbank
from
.feature
import
normalize
from
paddlehub.module.module
import
moduleinfo
from
paddlehub.utils.log
import
logger
@
moduleinfo
(
name
=
"ecapa_tdnn_common_language"
,
version
=
"1.0.0"
,
summary
=
""
,
author
=
"paddlepaddle"
,
author_email
=
""
,
type
=
"audio/language_identification"
)
class
LanguageIdentification
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
):
super
(
LanguageIdentification
,
self
).
__init__
()
ckpt_path
=
os
.
path
.
join
(
self
.
directory
,
'assets'
,
'model.pdparams'
)
label_path
=
os
.
path
.
join
(
self
.
directory
,
'assets'
,
'label.txt'
)
self
.
label_list
=
[]
with
open
(
label_path
,
'r'
)
as
f
:
for
l
in
f
:
self
.
label_list
.
append
(
l
.
strip
())
self
.
sr
=
16000
model_conf
=
{
'input_size'
:
80
,
'channels'
:
[
1024
,
1024
,
1024
,
1024
,
3072
],
'kernel_sizes'
:
[
5
,
3
,
3
,
3
,
1
],
'dilations'
:
[
1
,
2
,
3
,
4
,
1
],
'attention_channels'
:
128
,
'lin_neurons'
:
192
}
self
.
model
=
Classifier
(
backbone
=
ECAPA_TDNN
(
**
model_conf
),
num_class
=
45
,
)
self
.
model
.
set_state_dict
(
paddle
.
load
(
ckpt_path
))
self
.
model
.
eval
()
def
load_audio
(
self
,
wav
):
wav
=
os
.
path
.
abspath
(
os
.
path
.
expanduser
(
wav
))
assert
os
.
path
.
isfile
(
wav
),
'Please check wav file: {}'
.
format
(
wav
)
waveform
,
_
=
paddleaudio
.
load
(
wav
,
sr
=
self
.
sr
,
mono
=
True
,
normal
=
False
)
return
waveform
def
language_identify
(
self
,
wav
):
waveform
=
self
.
load_audio
(
wav
)
logits
=
self
(
paddle
.
to_tensor
(
waveform
)).
reshape
([
-
1
])
idx
=
paddle
.
argmax
(
logits
)
return
logits
[
idx
].
numpy
(),
self
.
label_list
[
idx
]
def
forward
(
self
,
x
):
if
len
(
x
.
shape
)
==
1
:
x
=
x
.
unsqueeze
(
0
)
fbank
=
compute_log_fbank
(
x
)
# x: waveform tensors with (B, T) shape
norm_fbank
=
normalize
(
fbank
)
logits
=
self
.
model
(
norm_fbank
).
squeeze
(
1
)
return
logits
modules/audio/language_identification/ecapa_tdnn_common_language/requirements.txt
0 → 100644
浏览文件 @
600276b3
paddleaudio==0.1.0
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/README.md
0 → 100644
浏览文件 @
600276b3
# ecapa_tdnn_voxceleb
|模型名称|ecapa_tdnn_voxceleb|
| :--- | :---: |
|类别|语音-声纹识别|
|网络|ECAPA-TDNN|
|数据集|VoxCeleb|
|是否支持Fine-tuning|否|
|模型大小|79MB|
|最新更新日期|2021-12-30|
|数据指标|EER 0.69%|
## 一、模型基本信息
### 模型介绍
ecapa_tdnn_voxceleb采用了
[
ECAPA-TDNN
](
https://arxiv.org/abs/2005.07143
)
的模型结构,并在
[
VoxCeleb
](
http://www.robots.ox.ac.uk/~vgg/data/voxceleb/
)
数据集上进行了预训练,在VoxCeleb1的声纹识别测试集(
[
veri_test.txt
](
https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test.txt
)
)上的测试结果为 EER 0.69%,达到了该数据集的SOTA。
<p
align=
"center"
>
<img
src=
"https://d3i71xaburhd42.cloudfront.net/9609f4817a7e769f5e3e07084db35e46696e82cd/3-Figure2-1.png"
hspace=
'10'
height=
"550"
/>
<br
/>
</p>
更多详情请参考
-
[
VoxCeleb: a large-scale speaker identification dataset
](
https://www.robots.ox.ac.uk/~vgg/publications/2017/Nagrani17/nagrani17.pdf
)
-
[
ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification
](
https://arxiv.org/pdf/2005.07143.pdf
)
-
[
The SpeechBrain Toolkit
](
https://github.com/speechbrain/speechbrain
)
## 二、安装
-
### 1、环境依赖
-
paddlepaddle >= 2.2.0
-
paddlehub >= 2.2.0 |
[
如何安装PaddleHub
](
../../../../docs/docs_ch/get_start/installation.rst
)
-
### 2、安装
-
```shell
$ hub install ecapa_tdnn_voxceleb
```
-
如您安装时遇到问题,可参考:
[
零基础windows安装
](
../../../../docs/docs_ch/get_start/windows_quickstart.md
)
|
[
零基础Linux安装
](
../../../../docs/docs_ch/get_start/linux_quickstart.md
)
|
[
零基础MacOS安装
](
../../../../docs/docs_ch/get_start/mac_quickstart.md
)
## 三、模型API预测
-
### 1、预测代码示例
```python
import paddlehub as hub
model = hub.Module(
name='ecapa_tdnn_voxceleb',
threshold=0.25,
version='1.0.0')
# 通过下列链接可下载示例音频
# https://paddlehub.bj.bcebos.com/paddlehub_dev/sv1.wav
# https://paddlehub.bj.bcebos.com/paddlehub_dev/sv2.wav
# Speaker Embedding
embedding = model.speaker_embedding('sv1.wav')
print(embedding.shape)
# (192,)
# Speaker Verification
score, pred = model.speaker_verify('sv1.wav', 'sv2.wav')
print(score, pred)
# [0.16354457], [False]
```
-
### 2、API
-
```python
def __init__(
threshold: float,
)
```
-
初始化声纹模型,确定判别阈值。
- **参数**
- `threshold`:设定模型判别声纹相似度的得分阈值,默认为 0.25。
-
```python
def speaker_embedding(
wav: os.PathLike,
)
```
-
获取输入音频的声纹特征
- **参数**
- `wav`:输入的说话人的音频文件,格式为`*.wav`。
- **返回**
- 输出纬度为 (192,) 的声纹特征向量。
-
```python
def speaker_verify(
wav1: os.PathLike,
wav2: os.PathLike,
)
```
-
对比两段音频,分别计算其声纹特征的相似度得分,并判断是否为同一说话人。
- **参数**
- `wav1`:输入的说话人1的音频文件,格式为`*.wav`。
- `wav2`:输入的说话人2的音频文件,格式为`*.wav`。
- **返回**
- 返回声纹相似度得分[-1, 1]和预测结果。
## 四、更新历史
*
1.0.0
初始发布
```
shell
$
hub
install
ecapa_tdnn_voxceleb
```
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/__init__.py
0 → 100644
浏览文件 @
600276b3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/ecapa_tdnn.py
0 → 100644
浏览文件 @
600276b3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
os
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
def
length_to_mask
(
length
,
max_len
=
None
,
dtype
=
None
):
assert
len
(
length
.
shape
)
==
1
if
max_len
is
None
:
max_len
=
length
.
max
().
astype
(
'int'
).
item
()
# using arange to generate mask
mask
=
paddle
.
arange
(
max_len
,
dtype
=
length
.
dtype
).
expand
((
len
(
length
),
max_len
))
<
length
.
unsqueeze
(
1
)
if
dtype
is
None
:
dtype
=
length
.
dtype
mask
=
paddle
.
to_tensor
(
mask
,
dtype
=
dtype
)
return
mask
class
Conv1d
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
=
1
,
padding
=
"same"
,
dilation
=
1
,
groups
=
1
,
bias
=
True
,
padding_mode
=
"reflect"
,
):
super
(
Conv1d
,
self
).
__init__
()
self
.
kernel_size
=
kernel_size
self
.
stride
=
stride
self
.
dilation
=
dilation
self
.
padding
=
padding
self
.
padding_mode
=
padding_mode
self
.
conv
=
nn
.
Conv1D
(
in_channels
,
out_channels
,
self
.
kernel_size
,
stride
=
self
.
stride
,
padding
=
0
,
dilation
=
self
.
dilation
,
groups
=
groups
,
bias_attr
=
bias
,
)
def
forward
(
self
,
x
):
if
self
.
padding
==
"same"
:
x
=
self
.
_manage_padding
(
x
,
self
.
kernel_size
,
self
.
dilation
,
self
.
stride
)
else
:
raise
ValueError
(
"Padding must be 'same'. Got {self.padding}"
)
return
self
.
conv
(
x
)
def
_manage_padding
(
self
,
x
,
kernel_size
:
int
,
dilation
:
int
,
stride
:
int
):
L_in
=
x
.
shape
[
-
1
]
# Detecting input shape
padding
=
self
.
_get_padding_elem
(
L_in
,
stride
,
kernel_size
,
dilation
)
# Time padding
x
=
F
.
pad
(
x
,
padding
,
mode
=
self
.
padding_mode
,
data_format
=
"NCL"
)
# Applying padding
return
x
def
_get_padding_elem
(
self
,
L_in
:
int
,
stride
:
int
,
kernel_size
:
int
,
dilation
:
int
):
if
stride
>
1
:
n_steps
=
math
.
ceil
(((
L_in
-
kernel_size
*
dilation
)
/
stride
)
+
1
)
L_out
=
stride
*
(
n_steps
-
1
)
+
kernel_size
*
dilation
padding
=
[
kernel_size
//
2
,
kernel_size
//
2
]
else
:
L_out
=
(
L_in
-
dilation
*
(
kernel_size
-
1
)
-
1
)
//
stride
+
1
padding
=
[(
L_in
-
L_out
)
//
2
,
(
L_in
-
L_out
)
//
2
]
return
padding
class
BatchNorm1d
(
nn
.
Layer
):
def
__init__
(
self
,
input_size
,
eps
=
1e-05
,
momentum
=
0.9
,
weight_attr
=
None
,
bias_attr
=
None
,
data_format
=
'NCL'
,
use_global_stats
=
None
,
):
super
(
BatchNorm1d
,
self
).
__init__
()
self
.
norm
=
nn
.
BatchNorm1D
(
input_size
,
epsilon
=
eps
,
momentum
=
momentum
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
,
data_format
=
data_format
,
use_global_stats
=
use_global_stats
,
)
def
forward
(
self
,
x
):
x_n
=
self
.
norm
(
x
)
return
x_n
class
TDNNBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
dilation
,
activation
=
nn
.
ReLU
,
):
super
(
TDNNBlock
,
self
).
__init__
()
self
.
conv
=
Conv1d
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
kernel_size
,
dilation
=
dilation
,
)
self
.
activation
=
activation
()
self
.
norm
=
BatchNorm1d
(
input_size
=
out_channels
)
def
forward
(
self
,
x
):
return
self
.
norm
(
self
.
activation
(
self
.
conv
(
x
)))
class
Res2NetBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
scale
=
8
,
dilation
=
1
):
super
(
Res2NetBlock
,
self
).
__init__
()
assert
in_channels
%
scale
==
0
assert
out_channels
%
scale
==
0
in_channel
=
in_channels
//
scale
hidden_channel
=
out_channels
//
scale
self
.
blocks
=
nn
.
LayerList
(
[
TDNNBlock
(
in_channel
,
hidden_channel
,
kernel_size
=
3
,
dilation
=
dilation
)
for
i
in
range
(
scale
-
1
)])
self
.
scale
=
scale
def
forward
(
self
,
x
):
y
=
[]
for
i
,
x_i
in
enumerate
(
paddle
.
chunk
(
x
,
self
.
scale
,
axis
=
1
)):
if
i
==
0
:
y_i
=
x_i
elif
i
==
1
:
y_i
=
self
.
blocks
[
i
-
1
](
x_i
)
else
:
y_i
=
self
.
blocks
[
i
-
1
](
x_i
+
y_i
)
y
.
append
(
y_i
)
y
=
paddle
.
concat
(
y
,
axis
=
1
)
return
y
class
SEBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
se_channels
,
out_channels
):
super
(
SEBlock
,
self
).
__init__
()
self
.
conv1
=
Conv1d
(
in_channels
=
in_channels
,
out_channels
=
se_channels
,
kernel_size
=
1
)
self
.
relu
=
paddle
.
nn
.
ReLU
()
self
.
conv2
=
Conv1d
(
in_channels
=
se_channels
,
out_channels
=
out_channels
,
kernel_size
=
1
)
self
.
sigmoid
=
paddle
.
nn
.
Sigmoid
()
def
forward
(
self
,
x
,
lengths
=
None
):
L
=
x
.
shape
[
-
1
]
if
lengths
is
not
None
:
mask
=
length_to_mask
(
lengths
*
L
,
max_len
=
L
)
mask
=
mask
.
unsqueeze
(
1
)
total
=
mask
.
sum
(
axis
=
2
,
keepdim
=
True
)
s
=
(
x
*
mask
).
sum
(
axis
=
2
,
keepdim
=
True
)
/
total
else
:
s
=
x
.
mean
(
axis
=
2
,
keepdim
=
True
)
s
=
self
.
relu
(
self
.
conv1
(
s
))
s
=
self
.
sigmoid
(
self
.
conv2
(
s
))
return
s
*
x
class
AttentiveStatisticsPooling
(
nn
.
Layer
):
def
__init__
(
self
,
channels
,
attention_channels
=
128
,
global_context
=
True
):
super
().
__init__
()
self
.
eps
=
1e-12
self
.
global_context
=
global_context
if
global_context
:
self
.
tdnn
=
TDNNBlock
(
channels
*
3
,
attention_channels
,
1
,
1
)
else
:
self
.
tdnn
=
TDNNBlock
(
channels
,
attention_channels
,
1
,
1
)
self
.
tanh
=
nn
.
Tanh
()
self
.
conv
=
Conv1d
(
in_channels
=
attention_channels
,
out_channels
=
channels
,
kernel_size
=
1
)
def
forward
(
self
,
x
,
lengths
=
None
):
C
,
L
=
x
.
shape
[
1
],
x
.
shape
[
2
]
# KP: (N, C, L)
def
_compute_statistics
(
x
,
m
,
axis
=
2
,
eps
=
self
.
eps
):
mean
=
(
m
*
x
).
sum
(
axis
)
std
=
paddle
.
sqrt
((
m
*
(
x
-
mean
.
unsqueeze
(
axis
)).
pow
(
2
)).
sum
(
axis
).
clip
(
eps
))
return
mean
,
std
if
lengths
is
None
:
lengths
=
paddle
.
ones
([
x
.
shape
[
0
]])
# Make binary mask of shape [N, 1, L]
mask
=
length_to_mask
(
lengths
*
L
,
max_len
=
L
)
mask
=
mask
.
unsqueeze
(
1
)
# Expand the temporal context of the pooling layer by allowing the
# self-attention to look at global properties of the utterance.
if
self
.
global_context
:
total
=
mask
.
sum
(
axis
=
2
,
keepdim
=
True
).
astype
(
'float32'
)
mean
,
std
=
_compute_statistics
(
x
,
mask
/
total
)
mean
=
mean
.
unsqueeze
(
2
).
tile
((
1
,
1
,
L
))
std
=
std
.
unsqueeze
(
2
).
tile
((
1
,
1
,
L
))
attn
=
paddle
.
concat
([
x
,
mean
,
std
],
axis
=
1
)
else
:
attn
=
x
# Apply layers
attn
=
self
.
conv
(
self
.
tanh
(
self
.
tdnn
(
attn
)))
# Filter out zero-paddings
attn
=
paddle
.
where
(
mask
.
tile
((
1
,
C
,
1
))
==
0
,
paddle
.
ones_like
(
attn
)
*
float
(
"-inf"
),
attn
)
attn
=
F
.
softmax
(
attn
,
axis
=
2
)
mean
,
std
=
_compute_statistics
(
x
,
attn
)
# Append mean and std of the batch
pooled_stats
=
paddle
.
concat
((
mean
,
std
),
axis
=
1
)
pooled_stats
=
pooled_stats
.
unsqueeze
(
2
)
return
pooled_stats
class
SERes2NetBlock
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
res2net_scale
=
8
,
se_channels
=
128
,
kernel_size
=
1
,
dilation
=
1
,
activation
=
nn
.
ReLU
,
):
super
(
SERes2NetBlock
,
self
).
__init__
()
self
.
out_channels
=
out_channels
self
.
tdnn1
=
TDNNBlock
(
in_channels
,
out_channels
,
kernel_size
=
1
,
dilation
=
1
,
activation
=
activation
,
)
self
.
res2net_block
=
Res2NetBlock
(
out_channels
,
out_channels
,
res2net_scale
,
dilation
)
self
.
tdnn2
=
TDNNBlock
(
out_channels
,
out_channels
,
kernel_size
=
1
,
dilation
=
1
,
activation
=
activation
,
)
self
.
se_block
=
SEBlock
(
out_channels
,
se_channels
,
out_channels
)
self
.
shortcut
=
None
if
in_channels
!=
out_channels
:
self
.
shortcut
=
Conv1d
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
1
,
)
def
forward
(
self
,
x
,
lengths
=
None
):
residual
=
x
if
self
.
shortcut
:
residual
=
self
.
shortcut
(
x
)
x
=
self
.
tdnn1
(
x
)
x
=
self
.
res2net_block
(
x
)
x
=
self
.
tdnn2
(
x
)
x
=
self
.
se_block
(
x
,
lengths
)
return
x
+
residual
class
ECAPA_TDNN
(
nn
.
Layer
):
def
__init__
(
self
,
input_size
,
lin_neurons
=
192
,
activation
=
nn
.
ReLU
,
channels
=
[
512
,
512
,
512
,
512
,
1536
],
kernel_sizes
=
[
5
,
3
,
3
,
3
,
1
],
dilations
=
[
1
,
2
,
3
,
4
,
1
],
attention_channels
=
128
,
res2net_scale
=
8
,
se_channels
=
128
,
global_context
=
True
,
):
super
(
ECAPA_TDNN
,
self
).
__init__
()
assert
len
(
channels
)
==
len
(
kernel_sizes
)
assert
len
(
channels
)
==
len
(
dilations
)
self
.
channels
=
channels
self
.
blocks
=
nn
.
LayerList
()
self
.
emb_size
=
lin_neurons
# The initial TDNN layer
self
.
blocks
.
append
(
TDNNBlock
(
input_size
,
channels
[
0
],
kernel_sizes
[
0
],
dilations
[
0
],
activation
,
))
# SE-Res2Net layers
for
i
in
range
(
1
,
len
(
channels
)
-
1
):
self
.
blocks
.
append
(
SERes2NetBlock
(
channels
[
i
-
1
],
channels
[
i
],
res2net_scale
=
res2net_scale
,
se_channels
=
se_channels
,
kernel_size
=
kernel_sizes
[
i
],
dilation
=
dilations
[
i
],
activation
=
activation
,
))
# Multi-layer feature aggregation
self
.
mfa
=
TDNNBlock
(
channels
[
-
1
],
channels
[
-
1
],
kernel_sizes
[
-
1
],
dilations
[
-
1
],
activation
,
)
# Attentive Statistical Pooling
self
.
asp
=
AttentiveStatisticsPooling
(
channels
[
-
1
],
attention_channels
=
attention_channels
,
global_context
=
global_context
,
)
self
.
asp_bn
=
BatchNorm1d
(
input_size
=
channels
[
-
1
]
*
2
)
# Final linear transformation
self
.
fc
=
Conv1d
(
in_channels
=
channels
[
-
1
]
*
2
,
out_channels
=
self
.
emb_size
,
kernel_size
=
1
,
)
def
forward
(
self
,
x
,
lengths
=
None
):
xl
=
[]
for
layer
in
self
.
blocks
:
try
:
x
=
layer
(
x
,
lengths
=
lengths
)
except
TypeError
:
x
=
layer
(
x
)
xl
.
append
(
x
)
# Multi-layer feature aggregation
x
=
paddle
.
concat
(
xl
[
1
:],
axis
=
1
)
x
=
self
.
mfa
(
x
)
# Attentive Statistical Pooling
x
=
self
.
asp
(
x
,
lengths
=
lengths
)
x
=
self
.
asp_bn
(
x
)
# Final linear transformation
x
=
self
.
fc
(
x
)
return
x
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/feature.py
0 → 100644
浏览文件 @
600276b3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddleaudio
from
paddleaudio.features.spectrum
import
hz_to_mel
from
paddleaudio.features.spectrum
import
mel_to_hz
from
paddleaudio.features.spectrum
import
power_to_db
from
paddleaudio.features.spectrum
import
Spectrogram
from
paddleaudio.features.window
import
get_window
def
compute_fbank_matrix
(
sample_rate
:
int
=
16000
,
n_fft
:
int
=
400
,
n_mels
:
int
=
80
,
f_min
:
int
=
0.0
,
f_max
:
int
=
8000.0
):
mel
=
paddle
.
linspace
(
hz_to_mel
(
f_min
,
htk
=
True
),
hz_to_mel
(
f_max
,
htk
=
True
),
n_mels
+
2
,
dtype
=
paddle
.
float32
)
hz
=
mel_to_hz
(
mel
,
htk
=
True
)
band
=
hz
[
1
:]
-
hz
[:
-
1
]
band
=
band
[:
-
1
]
f_central
=
hz
[
1
:
-
1
]
n_stft
=
n_fft
//
2
+
1
all_freqs
=
paddle
.
linspace
(
0
,
sample_rate
//
2
,
n_stft
)
all_freqs_mat
=
all_freqs
.
tile
([
f_central
.
shape
[
0
],
1
])
f_central_mat
=
f_central
.
tile
([
all_freqs_mat
.
shape
[
1
],
1
]).
transpose
([
1
,
0
])
band_mat
=
band
.
tile
([
all_freqs_mat
.
shape
[
1
],
1
]).
transpose
([
1
,
0
])
slope
=
(
all_freqs_mat
-
f_central_mat
)
/
band_mat
left_side
=
slope
+
1.0
right_side
=
-
slope
+
1.0
fbank_matrix
=
paddle
.
maximum
(
paddle
.
zeros_like
(
left_side
),
paddle
.
minimum
(
left_side
,
right_side
))
return
fbank_matrix
def
compute_log_fbank
(
x
:
paddle
.
Tensor
,
sample_rate
:
int
=
16000
,
n_fft
:
int
=
400
,
hop_length
:
int
=
160
,
win_length
:
int
=
400
,
n_mels
:
int
=
80
,
window
:
str
=
'hamming'
,
center
:
bool
=
True
,
pad_mode
:
str
=
'constant'
,
f_min
:
float
=
0.0
,
f_max
:
float
=
None
,
top_db
:
float
=
80.0
,
):
if
f_max
is
None
:
f_max
=
sample_rate
/
2
spect
=
Spectrogram
(
n_fft
=
n_fft
,
hop_length
=
hop_length
,
win_length
=
win_length
,
window
=
window
,
center
=
center
,
pad_mode
=
pad_mode
)(
x
)
fbank_matrix
=
compute_fbank_matrix
(
sample_rate
=
sample_rate
,
n_fft
=
n_fft
,
n_mels
=
n_mels
,
f_min
=
f_min
,
f_max
=
f_max
,
)
fbank
=
paddle
.
matmul
(
fbank_matrix
,
spect
)
log_fbank
=
power_to_db
(
fbank
,
top_db
=
top_db
).
transpose
([
0
,
2
,
1
])
return
log_fbank
def
compute_stats
(
x
:
paddle
.
Tensor
,
mean_norm
:
bool
=
True
,
std_norm
:
bool
=
False
,
eps
:
float
=
1e-10
):
if
mean_norm
:
current_mean
=
paddle
.
mean
(
x
,
axis
=
0
)
else
:
current_mean
=
paddle
.
to_tensor
([
0.0
])
if
std_norm
:
current_std
=
paddle
.
std
(
x
,
axis
=
0
)
else
:
current_std
=
paddle
.
to_tensor
([
1.0
])
current_std
=
paddle
.
maximum
(
current_std
,
eps
*
paddle
.
ones_like
(
current_std
))
return
current_mean
,
current_std
def
normalize
(
x
:
paddle
.
Tensor
,
global_mean
:
paddle
.
Tensor
=
None
,
global_std
:
paddle
.
Tensor
=
None
,
):
for
i
in
range
(
x
.
shape
[
0
]):
# (B, ...)
if
global_mean
is
None
and
global_std
is
None
:
mean
,
std
=
compute_stats
(
x
[
i
])
x
[
i
]
=
(
x
[
i
]
-
mean
)
/
std
else
:
x
[
i
]
=
(
x
[
i
]
-
global_mean
)
/
global_std
return
x
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/module.py
0 → 100644
浏览文件 @
600276b3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
re
from
typing
import
List
from
typing
import
Union
import
numpy
as
np
import
paddle
import
paddleaudio
from
.ecapa_tdnn
import
ECAPA_TDNN
from
.feature
import
compute_log_fbank
from
.feature
import
normalize
from
paddlehub.module.module
import
moduleinfo
from
paddlehub.utils.log
import
logger
@
moduleinfo
(
name
=
"ecapa_tdnn_voxceleb"
,
version
=
"1.0.0"
,
summary
=
""
,
author
=
"paddlepaddle"
,
author_email
=
""
,
type
=
"audio/speaker_recognition"
)
class
SpeakerRecognition
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
threshold
=
0.25
):
super
(
SpeakerRecognition
,
self
).
__init__
()
global_stats_path
=
os
.
path
.
join
(
self
.
directory
,
'assets'
,
'global_embedding_stats.npy'
)
ckpt_path
=
os
.
path
.
join
(
self
.
directory
,
'assets'
,
'model.pdparams'
)
self
.
sr
=
16000
self
.
threshold
=
threshold
model_conf
=
{
'input_size'
:
80
,
'channels'
:
[
1024
,
1024
,
1024
,
1024
,
3072
],
'kernel_sizes'
:
[
5
,
3
,
3
,
3
,
1
],
'dilations'
:
[
1
,
2
,
3
,
4
,
1
],
'attention_channels'
:
128
,
'lin_neurons'
:
192
}
self
.
model
=
ECAPA_TDNN
(
**
model_conf
)
self
.
model
.
set_state_dict
(
paddle
.
load
(
ckpt_path
))
self
.
model
.
eval
()
global_embedding_stats
=
np
.
load
(
global_stats_path
,
allow_pickle
=
True
)
self
.
global_emb_mean
=
paddle
.
to_tensor
(
global_embedding_stats
.
item
().
get
(
'global_emb_mean'
))
self
.
global_emb_std
=
paddle
.
to_tensor
(
global_embedding_stats
.
item
().
get
(
'global_emb_std'
))
self
.
similarity
=
paddle
.
nn
.
CosineSimilarity
(
axis
=-
1
,
eps
=
1e-6
)
def
load_audio
(
self
,
wav
):
wav
=
os
.
path
.
abspath
(
os
.
path
.
expanduser
(
wav
))
assert
os
.
path
.
isfile
(
wav
),
'Please check wav file: {}'
.
format
(
wav
)
waveform
,
_
=
paddleaudio
.
load
(
wav
,
sr
=
self
.
sr
,
mono
=
True
,
normal
=
False
)
return
waveform
def
speaker_embedding
(
self
,
wav
):
waveform
=
self
.
load_audio
(
wav
)
embedding
=
self
(
paddle
.
to_tensor
(
waveform
)).
reshape
([
-
1
])
return
embedding
.
numpy
()
def
speaker_verify
(
self
,
wav1
,
wav2
):
waveform1
=
self
.
load_audio
(
wav1
)
embedding1
=
self
(
paddle
.
to_tensor
(
waveform1
)).
reshape
([
-
1
])
waveform2
=
self
.
load_audio
(
wav2
)
embedding2
=
self
(
paddle
.
to_tensor
(
waveform2
)).
reshape
([
-
1
])
score
=
self
.
similarity
(
embedding1
,
embedding2
).
numpy
()
return
score
,
score
>
self
.
threshold
def
forward
(
self
,
x
):
if
len
(
x
.
shape
)
==
1
:
x
=
x
.
unsqueeze
(
0
)
fbank
=
compute_log_fbank
(
x
)
# x: waveform tensors with (B, T) shape
norm_fbank
=
normalize
(
fbank
)
embedding
=
self
.
model
(
norm_fbank
.
transpose
([
0
,
2
,
1
])).
transpose
([
0
,
2
,
1
])
norm_embedding
=
normalize
(
x
=
embedding
,
global_mean
=
self
.
global_emb_mean
,
global_std
=
self
.
global_emb_std
)
return
norm_embedding
modules/audio/speaker_recognition/ecapa_tdnn_voxceleb/requirements.txt
0 → 100644
浏览文件 @
600276b3
paddleaudio==0.1.0
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录