提交 9db7262f 编写于 作者: C cuicheng01

fix LeViT_384 train bugs

上级 f6c5625f
...@@ -128,10 +128,10 @@ class Residual(nn.Layer): ...@@ -128,10 +128,10 @@ class Residual(nn.Layer):
def forward(self, x): def forward(self, x):
if self.training and self.drop > 0: if self.training and self.drop > 0:
return paddle.add(x, y = paddle.rand(
self.m(x) * paddle.rand( shape=[x.shape[0], 1, 1]).__ge__(self.drop).astype("float32")
x.size(0), 1, 1, device=x.device).ge_( y = y.divide(paddle.full_like(y, 1 - self.drop))
self.drop).div(1 - self.drop).detach()) return paddle.add(x, y)
else: else:
return paddle.add(x, self.m(x)) return paddle.add(x, self.m(x))
...@@ -221,8 +221,6 @@ class Subsample(nn.Layer): ...@@ -221,8 +221,6 @@ class Subsample(nn.Layer):
def forward(self, x): def forward(self, x):
B, N, C = x.shape B, N, C = x.shape
#x = paddle.reshape(x, [B, self.resolution, self.resolution,
# C])[:, ::self.stride, ::self.stride]
x = paddle.reshape(x, [B, self.resolution, self.resolution, C]) x = paddle.reshape(x, [B, self.resolution, self.resolution, C])
end1, end2 = x.shape[1], x.shape[2] end1, end2 = x.shape[1], x.shape[2]
x = x[:, 0:end1:self.stride, 0:end2:self.stride] x = x[:, 0:end1:self.stride, 0:end2:self.stride]
...@@ -428,7 +426,7 @@ class LeViT(nn.Layer): ...@@ -428,7 +426,7 @@ class LeViT(nn.Layer):
x = paddle.transpose(x, perm=[0, 2, 1]) x = paddle.transpose(x, perm=[0, 2, 1])
x = self.blocks(x) x = self.blocks(x)
x = x.mean(1) x = x.mean(1)
x = paddle.reshape(x, [-1, 384]) x = paddle.reshape(x, [-1, x.shape[-1]])
if self.distillation: if self.distillation:
x = self.head(x), self.head_dist(x) x = self.head(x), self.head_dist(x)
if not self.training: if not self.training:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册