From 7a64fd68bfe81d675259a3cc8c1c49647afa08dd Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Tue, 31 Dec 2019 14:08:57 +0800 Subject: [PATCH] Replace `FC` with `Linear` --- mnist.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/mnist.py b/mnist.py index 29094f6..787499a 100644 --- a/mnist.py +++ b/mnist.py @@ -19,7 +19,7 @@ import numpy as np import paddle from paddle import fluid from paddle.fluid.optimizer import MomentumOptimizer -from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC +from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear from model import Model, shape_hints @@ -83,21 +83,24 @@ class MNIST(Model): pool_2_shape = 50 * 4 * 4 SIZE = 10 scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5 - self._fc = FC(self.full_name(), - 10, - param_attr=fluid.param_attr.ParamAttr( - initializer=fluid.initializer.NormalInitializer( - loc=0.0, scale=scale)), - act="softmax") + self._fc = Linear(800, + 10, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.NormalInitializer( + loc=0.0, scale=scale)), + act="softmax") @shape_hints(inputs=[None, 1, 28, 28]) def forward(self, inputs): if self.mode == 'test': # XXX demo purpose x = self._simple_img_conv_pool_1(inputs) x = self._simple_img_conv_pool_2(x) + x = fluid.layers.flatten(x, axis=1) + x = self._fc(x) else: x = self._simple_img_conv_pool_1(inputs) x = self._simple_img_conv_pool_2(x) + x = fluid.layers.flatten(x, axis=1) x = self._fc(x) return x -- GitLab