Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
d94db47f
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
9 个月 前同步成功
通知
200
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
d94db47f
编写于
7月 17, 2023
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix rotary embeding
上级
596f7140
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
8 addition
and
8 deletion
+8
-8
paddlespeech/s2t/modules/attention.py
paddlespeech/s2t/modules/attention.py
+8
-8
未找到文件。
paddlespeech/s2t/modules/attention.py
浏览文件 @
d94db47f
...
...
@@ -459,6 +459,7 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
cache
:
paddle
.
Tensor
=
paddle
.
zeros
([
0
,
0
,
0
,
0
])
)
->
Tuple
[
paddle
.
Tensor
,
paddle
.
Tensor
]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Ref: https://github.com/facebookresearch/llama/blob/main/llama/model.py
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).
key (paddle.Tensor): Key tensor (#batch, time2, size).
...
...
@@ -476,10 +477,16 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q
,
k
,
v
=
self
.
forward_qkv
(
query
,
key
,
value
)
# q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
# f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is position index
# q_t always is chunk_size
q_t
=
q
.
shape
[
2
]
q
=
self
.
apply_rotary_position_embeddings
(
pos_emb
[:,
-
q_t
:,
:],
q
)
# k will increase when in streaming decoding.
k
=
self
.
apply_rotary_position_embeddings
(
pos_emb
[:,
-
q_t
:,
:],
k
)
# when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
...
...
@@ -504,13 +511,6 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
# non-trivial to calculate `next_cache_start` here.
new_cache
=
paddle
.
concat
((
k
,
v
),
axis
=-
1
)
# f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is position index
# q_t always is chunk_size
q_t
=
q
.
shape
[
2
]
q
=
self
.
apply_rotary_position_embeddings
(
pos_emb
[:,
-
q_t
:,
:],
q
)
# k will increase when in streaming decoding.
k
=
self
.
apply_rotary_position_embeddings
(
pos_emb
,
k
)
# dot(q, k)
scores
=
paddle
.
matmul
(
q
,
k
,
transpose_y
=
True
)
/
math
.
sqrt
(
self
.
d_k
)
return
self
.
forward_attention
(
v
,
scores
,
mask
),
new_cache
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录