未验证 提交 8dd6418a 编写于 作者: L littletomatodonkey 提交者: GitHub

fix reshape to flatten (#513)

* fix reshape to flatten

* fix reshape
上级 b401e10b
...@@ -278,7 +278,7 @@ class DenseNet(nn.Layer): ...@@ -278,7 +278,7 @@ class DenseNet(nn.Layer):
conv = self.batch_norm(conv) conv = self.batch_norm(conv)
y = self.pool2d_avg(conv) y = self.pool2d_avg(conv)
y = paddle.reshape(y, shape=[-1, y.shape[1]]) y = paddle.flatten(y, start_axis=1, stop_axis=-1)
y = self.out(y) y = self.out(y)
return y return y
......
...@@ -328,7 +328,7 @@ class DPN(nn.Layer): ...@@ -328,7 +328,7 @@ class DPN(nn.Layer):
conv5_x_x = self.conv5_x_x_bn(conv5_x_x) conv5_x_x = self.conv5_x_x_bn(conv5_x_x)
y = self.pool2d_avg(conv5_x_x) y = self.pool2d_avg(conv5_x_x)
y = paddle.reshape(y, shape=[0, -1]) y = paddle.flatten(y, start_axis=1, stop_axis=-1)
y = self.out(y) y = self.out(y)
return y return y
......
...@@ -241,7 +241,7 @@ class MobileNet(nn.Layer): ...@@ -241,7 +241,7 @@ class MobileNet(nn.Layer):
for block in self.block_list: for block in self.block_list:
y = block(y) y = block(y)
y = self.pool2d_avg(y) y = self.pool2d_avg(y)
y = paddle.reshape(y, shape=[-1, int(1024 * self.scale)]) y = paddle.flatten(y, start_axis=1, stop_axis=-1)
y = self.out(y) y = self.out(y)
return y return y
......
...@@ -213,7 +213,7 @@ class MobileNet(nn.Layer): ...@@ -213,7 +213,7 @@ class MobileNet(nn.Layer):
y = block(y) y = block(y)
y = self.conv9(y, if_act=True) y = self.conv9(y, if_act=True)
y = self.pool2d_avg(y) y = self.pool2d_avg(y)
y = paddle.reshape(y, shape=[-1, self.out_c]) y = paddle.flatten(y, start_axis=1, stop_axis=-1)
y = self.out(y) y = self.out(y)
return y return y
......
...@@ -169,7 +169,7 @@ class MobileNetV3(nn.Layer): ...@@ -169,7 +169,7 @@ class MobileNetV3(nn.Layer):
x = self.last_conv(x) x = self.last_conv(x)
x = hard_swish(x) x = hard_swish(x)
x = self.dropout(x) x = self.dropout(x)
x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]]) x = paddle.flatten(x, start_axis=1, stop_axis=-1)
x = self.out(x) x = self.out(x)
return x return x
......
...@@ -279,7 +279,7 @@ class ShuffleNet(Layer): ...@@ -279,7 +279,7 @@ class ShuffleNet(Layer):
y = inv(y) y = inv(y)
y = self._last_conv(y) y = self._last_conv(y)
y = self._pool2d_avg(y) y = self._pool2d_avg(y)
y = reshape(y, shape=[-1, self._out_c]) y = paddle.flatten(y, start_axis=1, stop_axis=-1)
y = self._fc(y) y = self._fc(y)
return y return y
......
...@@ -112,9 +112,7 @@ class VGGNet(nn.Layer): ...@@ -112,9 +112,7 @@ class VGGNet(nn.Layer):
x = self._conv_block_3(x) x = self._conv_block_3(x)
x = self._conv_block_4(x) x = self._conv_block_4(x)
x = self._conv_block_5(x) x = self._conv_block_5(x)
x = paddle.flatten(x, start_axis=1, stop_axis=-1)
_, c, h, w = list(x.shape)
x = paddle.reshape(x, [-1, c * h * w])
x = self._fc1(x) x = self._fc1(x)
x = F.relu(x) x = F.relu(x)
x = self._drop(x) x = self._drop(x)
......
...@@ -305,7 +305,7 @@ class ExitFlow(nn.Layer): ...@@ -305,7 +305,7 @@ class ExitFlow(nn.Layer):
conv2 = self._conv_2(conv1) conv2 = self._conv_2(conv1)
conv2 = F.relu(conv2) conv2 = F.relu(conv2)
pool = self._pool(conv2) pool = self._pool(conv2)
pool = paddle.reshape(pool, [-1, pool.shape[1]]) pool = paddle.flatten(pool, start_axis=1, stop_axis=-1)
out = self._out(pool) out = self._out(pool)
return out return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册