From a01113c3380c3d21ea483f58ea4ab345ad962541 Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Thu, 18 Jun 2020 12:09:57 +0800 Subject: [PATCH] Add relu layer for lenet (#24874) * add relu for lenet, test=develop * fix test model, test=develop --- python/paddle/incubate/hapi/tests/test_model.py | 5 +++-- python/paddle/incubate/hapi/vision/models/lenet.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/python/paddle/incubate/hapi/tests/test_model.py b/python/paddle/incubate/hapi/tests/test_model.py index e49ec5651f..9753c1838d 100644 --- a/python/paddle/incubate/hapi/tests/test_model.py +++ b/python/paddle/incubate/hapi/tests/test_model.py @@ -23,8 +23,7 @@ import shutil import tempfile from paddle import fluid -from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear -from paddle.fluid.dygraph.container import Sequential +from paddle.nn import Conv2D, Pool2D, Linear, ReLU, Sequential from paddle.fluid.dygraph.base import to_variable from paddle.incubate.hapi.model import Model, Input, set_device @@ -42,9 +41,11 @@ class LeNetDygraph(fluid.dygraph.Layer): self.features = Sequential( Conv2D( 1, 6, 3, stride=1, padding=1), + ReLU(), Pool2D(2, 'max', 2), Conv2D( 6, 16, 5, stride=1, padding=0), + ReLU(), Pool2D(2, 'max', 2)) if num_classes > 0: diff --git a/python/paddle/incubate/hapi/vision/models/lenet.py b/python/paddle/incubate/hapi/vision/models/lenet.py index c49addcb1f..45094119f0 100644 --- a/python/paddle/incubate/hapi/vision/models/lenet.py +++ b/python/paddle/incubate/hapi/vision/models/lenet.py @@ -13,8 +13,7 @@ #limitations under the License. import paddle.fluid as fluid -from paddle.fluid.dygraph.nn import Conv2D, BatchNorm, Pool2D, Linear -from paddle.fluid.dygraph.container import Sequential +from paddle.nn import Conv2D, Pool2D, Linear, ReLU, Sequential from ...model import Model @@ -44,9 +43,11 @@ class LeNet(Model): self.features = Sequential( Conv2D( 1, 6, 3, stride=1, padding=1), + ReLU(), Pool2D(2, 'max', 2), Conv2D( 6, 16, 5, stride=1, padding=0), + ReLU(), Pool2D(2, 'max', 2)) if num_classes > 0: -- GitLab