提交 0b3b621a 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix concat error when fp16

上级 4e988692
......@@ -105,7 +105,7 @@ class Mlp(nn.Layer):
x = self.fc1(x)
x = self.act(x)
# x = self.drop(x)
# comment this for the original BERT implementation
# comment this for the original BERT implementation
x = self.fc2(x)
x = self.drop(x)
return x
......@@ -319,7 +319,7 @@ class PatchEmbed(nn.Layer):
fan_out = self.out_chans
fan_in = self.patch_size[0] * self.patch_size[1] * self.in_chans
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.XavierUniform(fan_in, fan_out)) # MAE
initializer=nn.initializer.XavierUniform(fan_in, fan_out)) # MAE
bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(0.0))
return weight_attr, bias_attr
......@@ -566,9 +566,9 @@ class VisionTransformer(nn.Layer):
x = self.patch_embed(x)
batch_size, seq_len, _ = x.shape
cls_tokens = self.cls_token.expand(
[batch_size, -1,
-1]) # stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand([
batch_size, -1, -1
]).astype(x.dtype) # stole cls_tokens impl from Phil Wang, thanks
x = paddle.concat((cls_tokens, x), axis=1)
if self.pos_embed is not None:
if self.use_abs_pos_emb:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册