From 62e3a25275fd127ea6acf3ae556ec36bfb5e2102 Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Fri, 21 Aug 2020 14:48:30 +0800 Subject: [PATCH] update SyncBatchNorm --- dygraph/models/hrnet.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/dygraph/models/hrnet.py b/dygraph/models/hrnet.py index 3c3d7cd4..2019900d 100644 --- a/dygraph/models/hrnet.py +++ b/dygraph/models/hrnet.py @@ -216,26 +216,25 @@ class ConvBNLayer(fluid.dygraph.Layer): stride=stride, padding=(filter_size - 1) // 2, groups=groups, - act=None, param_attr=ParamAttr( initializer=Normal(scale=0.001), name=name + "_weights"), bias_attr=False) bn_name = name + '_bn' self._batch_norm = BatchNorm( num_filters, - act=act, - param_attr=ParamAttr( + weight_attr=ParamAttr( name=bn_name + '_scale', initializer=fluid.initializer.Constant(1.0)), bias_attr=ParamAttr( bn_name + '_offset', - initializer=fluid.initializer.Constant(0.0)), - moving_mean_name=bn_name + '_mean', - moving_variance_name=bn_name + '_variance') + initializer=fluid.initializer.Constant(0.0))) + self.act = act def forward(self, input): y = self._conv(input) y = self._batch_norm(y) + if self.act == 'relu': + y = fluid.layers.relu(y) return y -- GitLab