From 70c0212a3f3a8221003715ec72d0c2871e552ffb Mon Sep 17 00:00:00 2001
From: chenguowei01
Date: Wed, 23 Sep 2020 12:05:24 +0800
Subject: [PATCH] update param_init

---
 dygraph/paddleseg/models/backbones/hrnet.py | 16 ++++++----------
 dygraph/paddleseg/models/fcn.py             | 16 ++++++----------
 2 files changed, 12 insertions(+), 20 deletions(-)

diff --git a/dygraph/paddleseg/models/backbones/hrnet.py b/dygraph/paddleseg/models/backbones/hrnet.py
index 6a1d7504..9b626a89 100644
--- a/dygraph/paddleseg/models/backbones/hrnet.py
+++ b/dygraph/paddleseg/models/backbones/hrnet.py
@@ -174,16 +174,12 @@ class HRNet(nn.Layer):
         return [x]
 
     def init_weight(self):
-        params = self.parameters()
-        for param in params:
-            param_name = param.name
-            if 'batch_norm' in param_name:
-                if 'w_0' in param_name:
-                    param_init.constant_init(param, value=1.0)
-                elif 'b_0' in param_name:
-                    param_init.constant_init(param, value=0.0)
-            if 'conv' in param_name and 'w_0' in param_name:
-                param_init.normal_init(param, scale=0.001)
+        for layer in self.sublayers():
+            if isinstance(layer, nn.Conv2d):
+                param_init.normal_init(layer.weight, scale=0.001)
+            elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
+                param_init.constant_init(layer.weight, value=1.0)
+                param_init.constant_init(layer.bias, value=0.0)
 
         if self.pretrained is not None:
             utils.load_pretrained_model(self, self.pretrained)
diff --git a/dygraph/paddleseg/models/fcn.py b/dygraph/paddleseg/models/fcn.py
index 348d5f07..6ba694c9 100644
--- a/dygraph/paddleseg/models/fcn.py
+++ b/dygraph/paddleseg/models/fcn.py
@@ -110,16 +110,12 @@ class FCNHead(nn.Layer):
         return logit_list
 
     def init_weight(self):
-        params = self.parameters()
-        for param in params:
-            param_name = param.name
-            if 'batch_norm' in param_name:
-                if 'w_0' in param_name:
-                    param_init.constant_init(param, value=1.0)
-                elif 'b_0' in param_name:
-                    param_init.constant_init(param, value=0.0)
-            if 'conv' in param_name and 'w_0' in param_name:
-                param_init.normal_init(param, scale=0.001)
+        for layer in self.sublayers():
+            if isinstance(layer, nn.Conv2d):
+                param_init.normal_init(layer.weight, scale=0.001)
+            elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
+                param_init.constant_init(layer.weight, value=1.0)
+                param_init.constant_init(layer.bias, value=0.0)
 
 
 @manager.MODELS.add_component
-- 
GitLab
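
Note: the patch replaces string matching on autogenerated parameter names (e.g. 'batch_norm_0.w_0') with type-based dispatch over self.sublayers(), which stays correct even if Paddle changes its parameter-naming scheme. Below is a minimal, self-contained sketch of that pattern. The TinyHead class and the normal_init/constant_init helpers are hypothetical stand-ins for paddleseg's param_init utilities, and the sketch uses current Paddle 2.x layer names (nn.Conv2D, nn.BatchNorm2D), which postdate the nn.Conv2d/nn.BatchNorm spellings in the patch itself.

    import paddle
    import paddle.nn as nn


    def normal_init(param, std=0.001):
        # Stand-in for param_init.normal_init: overwrite the parameter
        # in place with N(0, std^2) noise.
        param.set_value(paddle.normal(mean=0.0, std=std, shape=param.shape))


    def constant_init(param, value=0.0):
        # Stand-in for param_init.constant_init: overwrite the parameter
        # in place with a constant.
        param.set_value(paddle.full_like(param, value))


    class TinyHead(nn.Layer):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2D(3, 8, kernel_size=3, padding=1)
            self.bn = nn.BatchNorm2D(8)

        def init_weight(self):
            # Dispatch on layer type rather than parameter-name strings,
            # mirroring the init_weight bodies added by this patch.
            for layer in self.sublayers():
                if isinstance(layer, nn.Conv2D):
                    normal_init(layer.weight, std=0.001)
                elif isinstance(layer, (nn.BatchNorm2D, nn.SyncBatchNorm)):
                    constant_init(layer.weight, value=1.0)
                    constant_init(layer.bias, value=0.0)


    head = TinyHead()
    head.init_weight()
    print(float(head.bn.weight.mean()))  # 1.0 after init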