提交 da1b5884 编写于 作者: C cuicheng01

Update levit.py

上级 18592f5b
...@@ -426,7 +426,8 @@ class LeViT(nn.Layer): ...@@ -426,7 +426,8 @@ class LeViT(nn.Layer):
x = paddle.transpose(x, perm=[0, 2, 1]) x = paddle.transpose(x, perm=[0, 2, 1])
x = self.blocks(x) x = self.blocks(x)
x = x.mean(1) x = x.mean(1)
x = paddle.reshape(x, [-1, x.shape[-1]])
x = paddle.reshape(x, [-1, self.embed_dim[-1]])
if self.distillation: if self.distillation:
x = self.head(x), self.head_dist(x) x = self.head(x), self.head_dist(x)
if not self.training: if not self.training:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册