PaddlePaddle / DeepSpeech
Commit 52a8b2f3 (unverified)
Authored Jan 11, 2022 by KP; committed by GitHub on Jan 11, 2022
Add ECAPA_TDNN. (#1301)
Parent: 010aa65b
Showing 1 changed file with 18 additions and 26 deletions (+18, -26)
paddlespeech/vector/models/ecapa_tdnn.py @ 52a8b2f3
@@ -47,7 +47,7 @@ class Conv1d(nn.Layer):
             groups=1,
             bias=True,
             padding_mode="reflect", ):
-        super(Conv1d, self).__init__()
+        super().__init__()
         self.kernel_size = kernel_size
         self.stride = stride
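The change in this hunk, repeated in every class below, swaps the explicit `super(Conv1d, self).__init__()` call for the Python 3 zero-argument form. As a minimal sketch of why the two are interchangeable here, the following uses a hypothetical `Layer` base class as a stand-in for `nn.Layer` so it runs without Paddle installed; the `LegacyConv1d` and `ModernConv1d` names are illustrative only and do not appear in the commit.

```python
class Layer:
    """Hypothetical stand-in for nn.Layer so the sketch runs without Paddle."""

    def __init__(self):
        self.initialized = True


class LegacyConv1d(Layer):
    def __init__(self, kernel_size):
        # Older form: the class and the instance are named explicitly.
        super(LegacyConv1d, self).__init__()
        self.kernel_size = kernel_size


class ModernConv1d(Layer):
    def __init__(self, kernel_size):
        # Python 3 zero-argument form: the enclosing class and instance
        # are resolved automatically, so renaming the class needs no edit here.
        super().__init__()
        self.kernel_size = kernel_size


legacy = LegacyConv1d(kernel_size=3)
modern = ModernConv1d(kernel_size=3)
```

Both objects end up in the same state. The practical benefit of the zero-argument form shows up in the later hunk that renames `ECAPA_TDNN`: an explicit `super(ECAPA_TDNN, self)` would have to be renamed in step with the class.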
@@ -110,7 +110,7 @@ class BatchNorm1d(nn.Layer):
             bias_attr=None,
             data_format='NCL',
             use_global_stats=None, ):
-        super(BatchNorm1d, self).__init__()
+        super().__init__()
         self.norm = nn.BatchNorm1D(
             input_size,
@@ -134,7 +134,7 @@ class TDNNBlock(nn.Layer):
             kernel_size,
             dilation,
             activation=nn.ReLU, ):
-        super(TDNNBlock, self).__init__()
+        super().__init__()
         self.conv = Conv1d(
             in_channels=in_channels,
             out_channels=out_channels,
@@ -149,7 +149,7 @@ class TDNNBlock(nn.Layer):
 class Res2NetBlock(nn.Layer):
     def __init__(self, in_channels, out_channels, scale=8, dilation=1):
-        super(Res2NetBlock, self).__init__()
+        super().__init__()
         assert in_channels % scale == 0
         assert out_channels % scale == 0
@@ -179,7 +179,7 @@ class Res2NetBlock(nn.Layer):
 class SEBlock(nn.Layer):
     def __init__(self, in_channels, se_channels, out_channels):
-        super(SEBlock, self).__init__()
+        super().__init__()
         self.conv1 = Conv1d(
             in_channels=in_channels, out_channels=se_channels, kernel_size=1)
@@ -275,7 +275,7 @@ class SERes2NetBlock(nn.Layer):
             kernel_size=1,
             dilation=1,
             activation=nn.ReLU, ):
-        super(SERes2NetBlock, self).__init__()
+        super().__init__()
         self.out_channels = out_channels
         self.tdnn1 = TDNNBlock(
             in_channels,
@@ -313,7 +313,7 @@ class SERes2NetBlock(nn.Layer):
         return x + residual
-class ECAPA_TDNN(nn.Layer):
+class EcapaTdnn(nn.Layer):
     def __init__(
             self,
             input_size,
@@ -327,7 +327,7 @@ class ECAPA_TDNN(nn.Layer):
             se_channels=128,
             global_context=True, ):
-        super(ECAPA_TDNN, self).__init__()
+        super().__init__()
         assert len(channels) == len(kernel_sizes)
         assert len(channels) == len(dilations)
         self.channels = channels
@@ -377,6 +377,16 @@ class ECAPA_TDNN(nn.Layer):
             kernel_size=1, )
     def forward(self, x, lengths=None):
+        """
+        Compute embeddings.
+        Args:
+            x (paddle.Tensor): Input log-fbanks with shape (N, n_mels, T).
+            lengths (paddle.Tensor, optional): Length proportions of batch length with shape (N). Defaults to None.
+        Returns:
+            paddle.Tensor: Output embeddings with shape (N, self.emb_size, 1)
+        """
         xl = []
         for layer in self.blocks:
             try:
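The new docstring describes `lengths` as "length proportions of batch length" with shape `(N)`. One plausible reading, offered here as an assumption rather than something the commit confirms, is that each entry is an utterance's frame count divided by the longest frame count in the padded batch. The helper name `length_proportions` below is hypothetical and the sketch uses plain Python lists instead of `paddle.Tensor` so it runs without Paddle.

```python
def length_proportions(frame_counts):
    """Return each utterance length as a fraction of the longest one.

    Assumption: this mirrors the "length proportions" the docstring mentions;
    the commit itself does not show how the values are produced.
    """
    longest = max(frame_counts)
    return [count / longest for count in frame_counts]


# Three utterances padded to T = 200 frames.
lengths = length_proportions([200, 150, 100])
```

Under this reading, a value of 1.0 marks an utterance that fills the whole padded time axis `T`, and smaller values mark how much of the axis holds real frames.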
@@ -397,21 +407,3 @@ class ECAPA_TDNN(nn.Layer):
         x = self.fc(x)
         return x
-class Classifier(nn.Layer):
-    def __init__(self, backbone, num_class, dtype=paddle.float32):
-        super(Classifier, self).__init__()
-        self.backbone = backbone
-        self.params = nn.ParameterList([
-            paddle.create_parameter(
-                shape=[num_class, self.backbone.emb_size], dtype=dtype)
-        ])
-    def forward(self, x):
-        emb = self.backbone(x.transpose([0, 2, 1])).transpose([0, 2, 1])
-        logits = F.linear(
-            F.normalize(emb.squeeze(1)),
-            F.normalize(self.params[0]).transpose([1, 0]))
-        return logits
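The removed `Classifier.forward` L2-normalizes both the embedding and the per-class weight rows before the linear projection, so each logit is the cosine similarity between the embedding and one class vector. The following is a minimal pure-Python sketch of that computation for a single embedding; the function names `l2_normalize` and `cosine_logits` are hypothetical, and the `eps` floor is an assumption modeled on the small epsilon that normalization routines typically use to avoid division by zero.

```python
import math


def l2_normalize(vec, eps=1e-12):
    """Scale vec to unit L2 norm; eps guards against a zero vector."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / max(norm, eps) for v in vec]


def cosine_logits(embedding, class_weights):
    """Return one cosine-similarity logit per row of class_weights."""
    emb = l2_normalize(embedding)
    return [
        sum(e * w for e, w in zip(emb, l2_normalize(row)))
        for row in class_weights
    ]


# One 2-dimensional embedding scored against three class vectors:
# same direction, opposite direction, and orthogonal.
logits = cosine_logits([3.0, 4.0], [[3.0, 4.0], [-3.0, -4.0], [4.0, -3.0]])
```

Because both sides are unit length, every logit lies in the range from -1 to 1 regardless of the raw embedding scale, which is the usual reason for pairing such a head with a margin-based loss.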