Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
ba874db5
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ba874db5
编写于
5月 30, 2023
作者:
J
jiamingkong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fixed the transpose usages ignored before
上级
0e2068e2
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
7 addition
and
6 deletion
+7
-6
paddlespeech/s2t/models/wavlm/modules/modules.py
paddlespeech/s2t/models/wavlm/modules/modules.py
+7
-6
未找到文件。
paddlespeech/s2t/models/wavlm/modules/modules.py
浏览文件 @
ba874db5
...
...
@@ -555,19 +555,19 @@ class MultiheadAttention(nn.Layer):
q
=
(
q
.
contiguous
()
.
view
(
tgt_len
,
bsz
*
self
.
num_heads
,
self
.
q_head_dim
)
.
transpose
(
0
,
1
)
.
transpose
(
[
1
,
0
,
2
]
)
)
if
k
is
not
None
:
k
=
(
k
.
contiguous
()
.
view
(
-
1
,
bsz
*
self
.
num_heads
,
self
.
k_head_dim
)
.
transpose
(
0
,
1
)
.
transpose
(
[
1
,
0
,
2
]
)
)
if
v
is
not
None
:
v
=
(
v
.
contiguous
()
.
view
(
-
1
,
bsz
*
self
.
num_heads
,
self
.
head_dim
)
.
transpose
(
0
,
1
)
.
transpose
(
[
1
,
0
,
2
]
)
)
if
saved_state
is
not
None
:
...
...
@@ -643,7 +643,8 @@ class MultiheadAttention(nn.Layer):
)
attn_weights
=
paddle
.
bmm
(
q
,
k
.
transpose
(
1
,
2
))
attn_weights
=
paddle
.
matmul
(
q
,
k
.
transpose
([
0
,
2
,
1
]))
attn_weights
=
self
.
apply_sparse_mask
(
attn_weights
,
tgt_len
,
src_len
,
bsz
)
assert
list
(
attn_weights
.
shape
)
==
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
]
...
...
@@ -687,13 +688,13 @@ class MultiheadAttention(nn.Layer):
assert
v
is
not
None
attn
=
paddle
.
bmm
(
attn_probs
,
v
)
assert
list
(
attn
.
shape
)
==
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
]
attn
=
attn
.
transpose
(
0
,
1
).
contiguous
().
view
(
tgt_len
,
bsz
,
embed_dim
)
attn
=
attn
.
transpose
(
[
1
,
0
,
2
]).
reshape
([
tgt_len
,
bsz
,
embed_dim
]
)
attn
=
self
.
out_proj
(
attn
)
attn_weights
:
Optional
[
Tensor
]
=
None
if
need_weights
:
attn_weights
=
attn_weights_float
.
view
(
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
).
transpose
(
1
,
0
)
).
transpose
(
[
1
,
0
,
2
,
3
]
)
if
not
need_head_weights
:
# average attention weights over heads
attn_weights
=
attn_weights
.
mean
(
dim
=
0
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录