提交 894429c9 编写于 作者: S songyouwei 提交者: hong

update build_once for mnist (#4103)

test=develop
上级 66f6039f
...@@ -21,7 +21,7 @@ import os ...@@ -21,7 +21,7 @@ import os
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.optimizer import AdamOptimizer from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
...@@ -41,7 +41,6 @@ def parse_args(): ...@@ -41,7 +41,6 @@ def parse_args():
class SimpleImgConvPool(fluid.dygraph.Layer): class SimpleImgConvPool(fluid.dygraph.Layer):
def __init__(self, def __init__(self,
name_scope,
num_channels, num_channels,
num_filters, num_filters,
filter_size, filter_size,
...@@ -58,10 +57,10 @@ class SimpleImgConvPool(fluid.dygraph.Layer): ...@@ -58,10 +57,10 @@ class SimpleImgConvPool(fluid.dygraph.Layer):
use_cudnn=False, use_cudnn=False,
param_attr=None, param_attr=None,
bias_attr=None): bias_attr=None):
super(SimpleImgConvPool, self).__init__(name_scope) super(SimpleImgConvPool, self).__init__()
self._conv2d = Conv2D( self._conv2d = Conv2D(
self.full_name(), num_channels=num_channels,
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=conv_stride, stride=conv_stride,
...@@ -74,7 +73,6 @@ class SimpleImgConvPool(fluid.dygraph.Layer): ...@@ -74,7 +73,6 @@ class SimpleImgConvPool(fluid.dygraph.Layer):
use_cudnn=use_cudnn) use_cudnn=use_cudnn)
self._pool2d = Pool2D( self._pool2d = Pool2D(
self.full_name(),
pool_size=pool_size, pool_size=pool_size,
pool_type=pool_type, pool_type=pool_type,
pool_stride=pool_stride, pool_stride=pool_stride,
...@@ -89,20 +87,19 @@ class SimpleImgConvPool(fluid.dygraph.Layer): ...@@ -89,20 +87,19 @@ class SimpleImgConvPool(fluid.dygraph.Layer):
class MNIST(fluid.dygraph.Layer): class MNIST(fluid.dygraph.Layer):
def __init__(self, name_scope): def __init__(self):
super(MNIST, self).__init__(name_scope) super(MNIST, self).__init__()
self._simple_img_conv_pool_1 = SimpleImgConvPool( self._simple_img_conv_pool_1 = SimpleImgConvPool(
self.full_name(), 1, 20, 5, 2, 2, act="relu") 1, 20, 5, 2, 2, act="relu")
self._simple_img_conv_pool_2 = SimpleImgConvPool( self._simple_img_conv_pool_2 = SimpleImgConvPool(
self.full_name(), 20, 50, 5, 2, 2, act="relu") 20, 50, 5, 2, 2, act="relu")
pool_2_shape = 50 * 4 * 4 self.pool_2_shape = 50 * 4 * 4
SIZE = 10 SIZE = 10
scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5 scale = (2.0 / (self.pool_2_shape**2 * SIZE))**0.5
self._fc = FC(self.full_name(), self._fc = Linear(self.pool_2_shape, 10,
10,
param_attr=fluid.param_attr.ParamAttr( param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer( initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)), loc=0.0, scale=scale)),
...@@ -111,6 +108,7 @@ class MNIST(fluid.dygraph.Layer): ...@@ -111,6 +108,7 @@ class MNIST(fluid.dygraph.Layer):
def forward(self, inputs, label=None): def forward(self, inputs, label=None):
x = self._simple_img_conv_pool_1(inputs) x = self._simple_img_conv_pool_1(inputs)
x = self._simple_img_conv_pool_2(x) x = self._simple_img_conv_pool_2(x)
x = fluid.layers.reshape(x, shape=[-1, self.pool_2_shape])
x = self._fc(x) x = self._fc(x)
if label is not None: if label is not None:
acc = fluid.layers.accuracy(input=x, label=label) acc = fluid.layers.accuracy(input=x, label=label)
...@@ -148,7 +146,7 @@ def inference_mnist(): ...@@ -148,7 +146,7 @@ def inference_mnist():
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \ place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0) if args.use_data_parallel else fluid.CUDAPlace(0)
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
mnist_infer = MNIST("mnist") mnist_infer = MNIST()
# load checkpoint # load checkpoint
model_dict, _ = fluid.load_dygraph("save_temp") model_dict, _ = fluid.load_dygraph("save_temp")
mnist_infer.set_dict(model_dict) mnist_infer.set_dict(model_dict)
...@@ -188,7 +186,7 @@ def train_mnist(args): ...@@ -188,7 +186,7 @@ def train_mnist(args):
if args.use_data_parallel: if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context() strategy = fluid.dygraph.parallel.prepare_context()
mnist = MNIST("mnist") mnist = MNIST()
adam = AdamOptimizer(learning_rate=0.001) adam = AdamOptimizer(learning_rate=0.001)
if args.use_data_parallel: if args.use_data_parallel:
mnist = fluid.dygraph.parallel.DataParallel(mnist, strategy) mnist = fluid.dygraph.parallel.DataParallel(mnist, strategy)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册