diff --git a/dygraph/paddleseg/models/backbones/hrnet.py b/dygraph/paddleseg/models/backbones/hrnet.py index 6a1d75040b9286af63d06a259c8b885d71338275..9b626a89e2417cec162ff7adc3c5284b79fc97fb 100644 --- a/dygraph/paddleseg/models/backbones/hrnet.py +++ b/dygraph/paddleseg/models/backbones/hrnet.py @@ -174,16 +174,12 @@ class HRNet(nn.Layer): return [x] def init_weight(self): - params = self.parameters() - for param in params: - param_name = param.name - if 'batch_norm' in param_name: - if 'w_0' in param_name: - param_init.constant_init(param, value=1.0) - elif 'b_0' in param_name: - param_init.constant_init(param, value=0.0) - if 'conv' in param_name and 'w_0' in param_name: - param_init.normal_init(param, scale=0.001) + for layer in self.sublayers(): + if isinstance(layer, nn.Conv2d): + param_init.normal_init(layer.weight, scale=0.001) + elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)): + param_init.constant_init(layer.weight, value=1.0) + param_init.constant_init(layer.bias, value=0.0) if self.pretrained is not None: utils.load_pretrained_model(self, self.pretrained) diff --git a/dygraph/paddleseg/models/fcn.py b/dygraph/paddleseg/models/fcn.py index 348d5f0767d9983a3fe33b45ff302569b1cdff65..6ba694c90fa1b616d905990c824b9c49b8a42855 100644 --- a/dygraph/paddleseg/models/fcn.py +++ b/dygraph/paddleseg/models/fcn.py @@ -110,16 +110,12 @@ class FCNHead(nn.Layer): return logit_list def init_weight(self): - params = self.parameters() - for param in params: - param_name = param.name - if 'batch_norm' in param_name: - if 'w_0' in param_name: - param_init.constant_init(param, value=1.0) - elif 'b_0' in param_name: - param_init.constant_init(param, value=0.0) - if 'conv' in param_name and 'w_0' in param_name: - param_init.normal_init(param, scale=0.001) + for layer in self.sublayers(): + if isinstance(layer, nn.Conv2d): + param_init.normal_init(layer.weight, scale=0.001) + elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)): + param_init.constant_init(layer.weight, value=1.0) + param_init.constant_init(layer.bias, value=0.0) @manager.MODELS.add_component