提交 1bc4acfd 编写于 作者: H Hui Zhang

x.shape bug when is -1

上级 70142d0c
......@@ -67,10 +67,10 @@ class PositionalEncoding(nn.Layer):
paddle.Tensor: Encoded tensor. Its shape is (batch, time, ...)
paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...)
"""
T = x.shape[1]
assert offset + x.shape[1] < self.max_len
#TODO(Hui Zhang): using T = x.size(1), __getitem__ not support Tensor
pos_emb = self.pe[:, offset:offset + T]
# when frozen graph, the x.shape[1] is -1, offset=0
# result pos_emb is [1, 4999, D] not [1, 5000, D]
pos_emb = self.pe[:, offset:offset + x.shape[1]]
x = x * self.xscale + pos_emb
return self.dropout(x), self.dropout(pos_emb)
......@@ -116,6 +116,7 @@ class RelPositionalEncoding(PositionalEncoding):
"""
assert offset + x.shape[1] < self.max_len
x = x * self.xscale
#TODO(Hui Zhang): using x.size(1), __getitem__ not support Tensor
# when frozen graph, the x.shape[1] is -1, offset=0
# result pos_emb is [1, 4999, D] not [1, 5000, D]
pos_emb = self.pe[:, offset:offset + x.shape[1]]
return self.dropout(x), self.dropout(pos_emb)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册