提交 9db7262f 编写于 作者: C cuicheng01

fix LeViT_384 train bugs

上级 f6c5625f
......@@ -128,10 +128,10 @@ class Residual(nn.Layer):
def forward(self, x):
if self.training and self.drop > 0:
return paddle.add(x,
self.m(x) * paddle.rand(
x.size(0), 1, 1, device=x.device).ge_(
self.drop).div(1 - self.drop).detach())
y = paddle.rand(
shape=[x.shape[0], 1, 1]).__ge__(self.drop).astype("float32")
y = y.divide(paddle.full_like(y, 1 - self.drop))
return paddle.add(x, y)
else:
return paddle.add(x, self.m(x))
......@@ -221,8 +221,6 @@ class Subsample(nn.Layer):
def forward(self, x):
B, N, C = x.shape
#x = paddle.reshape(x, [B, self.resolution, self.resolution,
# C])[:, ::self.stride, ::self.stride]
x = paddle.reshape(x, [B, self.resolution, self.resolution, C])
end1, end2 = x.shape[1], x.shape[2]
x = x[:, 0:end1:self.stride, 0:end2:self.stride]
......@@ -428,7 +426,7 @@ class LeViT(nn.Layer):
x = paddle.transpose(x, perm=[0, 2, 1])
x = self.blocks(x)
x = x.mean(1)
x = paddle.reshape(x, [-1, 384])
x = paddle.reshape(x, [-1, x.shape[-1]])
if self.distillation:
x = self.head(x), self.head_dist(x)
if not self.training:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册