Commit d94db47f authored by Hui Zhang

fix rotary embedding

Parent 596f7140
@@ -459,6 +459,7 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Ref: https://github.com/facebookresearch/llama/blob/main/llama/model.py
Args:
query (paddle.Tensor): Query tensor (#batch, time1, size).
key (paddle.Tensor): Key tensor (#batch, time2, size).
@@ -476,10 +477,16 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
# q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
# f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is position index
# q_t is always chunk_size
q_t = q.shape[2]
q = self.apply_rotary_position_embeddings(pos_emb[:, -q_t:, :], q)
# k grows during streaming decoding.
k = self.apply_rotary_position_embeddings(pos_emb[:, -q_t:, :], k)
# when exporting the onnx model, for the 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
@@ -504,13 +511,6 @@ class RoPERelPositionMultiHeadedAttention(MultiHeadedAttention):
# non-trivial to calculate `next_cache_start` here.
new_cache = paddle.concat((k, v), axis=-1)
# f{q,k}(x_m, m) = R^d_{\theta, m} W_{q,k} x_m, m is position index
# q_t always is chunk_size
q_t = q.shape[2]
q = self.apply_rotary_position_embeddings(pos_emb[:, -q_t:, :], q)
# k will increase when in streaming decoding.
k = self.apply_rotary_position_embeddings(pos_emb, k)
# dot(q, k)
scores = paddle.matmul(q, k, transpose_y=True) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
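
Reading the hunks, the commit moves the rotary position embedding application so that q and k are rotated right after `forward_qkv` using only the last `q_t` positions of `pos_emb` (`pos_emb[:, -q_t:, :]`), before k and v are concatenated into `new_cache`; the removed block rotated k with the full `pos_emb` after the cache concatenation. Below is a minimal, illustrative sketch of one common way to apply RoPE to a `(batch, head, time, d_k)` tensor. It is not the code in this commit: the `sin`/`cos` layout and the `rotate_half` convention are assumptions, and the repo's `apply_rotary_position_embeddings` may pack `pos_emb` differently.

```python
# Illustrative RoPE sketch (assumed sin/cos layout, not PaddleSpeech's exact code).
import paddle


def rotate_half(x: paddle.Tensor) -> paddle.Tensor:
    # Split the last dim in two halves and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = paddle.chunk(x, 2, axis=-1)
    return paddle.concat([-x2, x1], axis=-1)


def apply_rope(x: paddle.Tensor, sin: paddle.Tensor,
               cos: paddle.Tensor) -> paddle.Tensor:
    # x: (batch, head, time, d_k); sin/cos: (1, time, d_k), broadcast over
    # batch and head. Each feature pair is rotated by its position-dependent
    # angle theta_m, i.e. x * cos + rotate_half(x) * sin.
    return x * cos.unsqueeze(1) + rotate_half(x) * sin.unsqueeze(1)
```

Because the rotation is applied per position of the current chunk before caching, keys stored in `new_cache` already carry their rotation and are not rotated again when later chunks reuse them.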