提交 b91b1c9b 编写于 作者: H Hui Zhang

support position interpolation for langer attention context windown length.

上级 3b6b6807
......@@ -167,3 +167,40 @@ class RelPositionalEncoding(PositionalEncoding):
x = x * self.xscale
pos_emb = self.pe[:, offset:offset + x.shape[1]]
return self.dropout(x), self.dropout(pos_emb)
# RotaryRelPositionalEncoding is same to RelPositionalEncoding
class ScaledRotaryRelPositionalEncoding(RelPositionalEncoding):
"""Scaled Rotary Relative positional encoding module.
POSITION INTERPOLATION: : https://arxiv.org/pdf/2306.15595v2.pdf
"""
def __init__(self,
d_model: int,
dropout_rate: float,
max_len: int=5000,
scale=1):
"""
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int, optional): [Maximum input length.]. Defaults to 5000.
scale (int): Interpolation max input length to `scale * max_len` positions.
"""
super().__init__(d_model, dropout_rate, max_len, reverse=True)
self.scale = scale
self.max_len = max_len * scale
position = paddle.arange(
0, self.max_len, dtype=paddle.float32).unsqueeze(1) #[T, 1]
# position interpoloation
position *= 1.0 / self.scale
# base^{-2(i-1)/d)}, i \in (1,2...,d/2)
div_term = paddle.exp(
-paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
(math.log(self.base) / self.d_model))
# [B,T,D]
self.pe[:, :, 0::2] = paddle.sin(position * div_term)
self.pe[:, :, 1::2] = paddle.cos(position * div_term)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册