未验证 提交 8dd6418a 编写于 作者: L littletomatodonkey 提交者: GitHub

fix reshape to flatten (#513)

* fix reshape to flatten

* fix reshape
上级 b401e10b
......@@ -278,7 +278,7 @@ class DenseNet(nn.Layer):
conv = self.batch_norm(conv)
y = self.pool2d_avg(conv)
y = paddle.reshape(y, shape=[-1, y.shape[1]])
y = paddle.flatten(y, start_axis=1, stop_axis=-1)
y = self.out(y)
return y
......
......@@ -328,7 +328,7 @@ class DPN(nn.Layer):
conv5_x_x = self.conv5_x_x_bn(conv5_x_x)
y = self.pool2d_avg(conv5_x_x)
y = paddle.reshape(y, shape=[0, -1])
y = paddle.flatten(y, start_axis=1, stop_axis=-1)
y = self.out(y)
return y
......
......@@ -241,7 +241,7 @@ class MobileNet(nn.Layer):
for block in self.block_list:
y = block(y)
y = self.pool2d_avg(y)
y = paddle.reshape(y, shape=[-1, int(1024 * self.scale)])
y = paddle.flatten(y, start_axis=1, stop_axis=-1)
y = self.out(y)
return y
......
......@@ -213,7 +213,7 @@ class MobileNet(nn.Layer):
y = block(y)
y = self.conv9(y, if_act=True)
y = self.pool2d_avg(y)
y = paddle.reshape(y, shape=[-1, self.out_c])
y = paddle.flatten(y, start_axis=1, stop_axis=-1)
y = self.out(y)
return y
......
......@@ -169,7 +169,7 @@ class MobileNetV3(nn.Layer):
x = self.last_conv(x)
x = hard_swish(x)
x = self.dropout(x)
x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
x = paddle.flatten(x, start_axis=1, stop_axis=-1)
x = self.out(x)
return x
......
......@@ -279,7 +279,7 @@ class ShuffleNet(Layer):
y = inv(y)
y = self._last_conv(y)
y = self._pool2d_avg(y)
y = reshape(y, shape=[-1, self._out_c])
y = paddle.flatten(y, start_axis=1, stop_axis=-1)
y = self._fc(y)
return y
......
......@@ -112,9 +112,7 @@ class VGGNet(nn.Layer):
x = self._conv_block_3(x)
x = self._conv_block_4(x)
x = self._conv_block_5(x)
_, c, h, w = list(x.shape)
x = paddle.reshape(x, [-1, c * h * w])
x = paddle.flatten(x, start_axis=1, stop_axis=-1)
x = self._fc1(x)
x = F.relu(x)
x = self._drop(x)
......
......@@ -305,7 +305,7 @@ class ExitFlow(nn.Layer):
conv2 = self._conv_2(conv1)
conv2 = F.relu(conv2)
pool = self._pool(conv2)
pool = paddle.reshape(pool, [-1, pool.shape[1]])
pool = paddle.flatten(pool, start_axis=1, stop_axis=-1)
out = self._out(pool)
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册