提交 894429c9 编写于 作者: S songyouwei 提交者: hong

update build_once for mnist (#4103)

test=develop
上级 66f6039f
......@@ -21,7 +21,7 @@ import os
import paddle
import paddle.fluid as fluid
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.dygraph.base import to_variable
......@@ -41,7 +41,6 @@ def parse_args():
class SimpleImgConvPool(fluid.dygraph.Layer):
def __init__(self,
name_scope,
num_channels,
num_filters,
filter_size,
......@@ -58,10 +57,10 @@ class SimpleImgConvPool(fluid.dygraph.Layer):
use_cudnn=False,
param_attr=None,
bias_attr=None):
super(SimpleImgConvPool, self).__init__(name_scope)
super(SimpleImgConvPool, self).__init__()
self._conv2d = Conv2D(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=conv_stride,
......@@ -74,7 +73,6 @@ class SimpleImgConvPool(fluid.dygraph.Layer):
use_cudnn=use_cudnn)
self._pool2d = Pool2D(
self.full_name(),
pool_size=pool_size,
pool_type=pool_type,
pool_stride=pool_stride,
......@@ -89,20 +87,19 @@ class SimpleImgConvPool(fluid.dygraph.Layer):
class MNIST(fluid.dygraph.Layer):
def __init__(self, name_scope):
super(MNIST, self).__init__(name_scope)
def __init__(self):
super(MNIST, self).__init__()
self._simple_img_conv_pool_1 = SimpleImgConvPool(
self.full_name(), 1, 20, 5, 2, 2, act="relu")
1, 20, 5, 2, 2, act="relu")
self._simple_img_conv_pool_2 = SimpleImgConvPool(
self.full_name(), 20, 50, 5, 2, 2, act="relu")
20, 50, 5, 2, 2, act="relu")
pool_2_shape = 50 * 4 * 4
self.pool_2_shape = 50 * 4 * 4
SIZE = 10
scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5
self._fc = FC(self.full_name(),
10,
scale = (2.0 / (self.pool_2_shape**2 * SIZE))**0.5
self._fc = Linear(self.pool_2_shape, 10,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)),
......@@ -111,6 +108,7 @@ class MNIST(fluid.dygraph.Layer):
def forward(self, inputs, label=None):
x = self._simple_img_conv_pool_1(inputs)
x = self._simple_img_conv_pool_2(x)
x = fluid.layers.reshape(x, shape=[-1, self.pool_2_shape])
x = self._fc(x)
if label is not None:
acc = fluid.layers.accuracy(input=x, label=label)
......@@ -148,7 +146,7 @@ def inference_mnist():
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
mnist_infer = MNIST("mnist")
mnist_infer = MNIST()
# load checkpoint
model_dict, _ = fluid.load_dygraph("save_temp")
mnist_infer.set_dict(model_dict)
......@@ -188,7 +186,7 @@ def train_mnist(args):
if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
mnist = MNIST("mnist")
mnist = MNIST()
adam = AdamOptimizer(learning_rate=0.001)
if args.use_data_parallel:
mnist = fluid.dygraph.parallel.DataParallel(mnist, strategy)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册