提交 70c0212a 编写于 作者: C chenguowei01

update param_init

上级 20e46282
...@@ -174,16 +174,12 @@ class HRNet(nn.Layer): ...@@ -174,16 +174,12 @@ class HRNet(nn.Layer):
return [x] return [x]
def init_weight(self):
    """Initialize the weights of all sublayers, then optionally load pretrained weights.

    Conv2d weights are drawn from a normal distribution (scale=0.001);
    BatchNorm / SyncBatchNorm layers get weight=1.0 and bias=0.0 (identity
    normalization at start of training). If ``self.pretrained`` is set, the
    pretrained checkpoint is loaded afterwards, overwriting these values.
    """
    for layer in self.sublayers():
        if isinstance(layer, nn.Conv2d):
            # Small-scale normal init keeps early activations stable.
            param_init.normal_init(layer.weight, scale=0.001)
        elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
            # Start BN as an identity transform: scale=1, shift=0.
            param_init.constant_init(layer.weight, value=1.0)
            param_init.constant_init(layer.bias, value=0.0)
    if self.pretrained is not None:
        # Pretrained weights take precedence over the random init above.
        utils.load_pretrained_model(self, self.pretrained)
......
...@@ -110,16 +110,12 @@ class FCNHead(nn.Layer): ...@@ -110,16 +110,12 @@ class FCNHead(nn.Layer):
return logit_list return logit_list
def init_weight(self):
    """Initialize the weights of all sublayers of the FCN head.

    Conv2d weights are drawn from a normal distribution (scale=0.001);
    BatchNorm / SyncBatchNorm layers get weight=1.0 and bias=0.0 so that
    normalization starts as an identity transform. Unlike the backbone,
    the head does not load pretrained weights here.
    """
    for layer in self.sublayers():
        if isinstance(layer, nn.Conv2d):
            # Small-scale normal init for the prediction convolutions.
            param_init.normal_init(layer.weight, scale=0.001)
        elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
            # BN starts as identity: scale=1, shift=0.
            param_init.constant_init(layer.weight, value=1.0)
            param_init.constant_init(layer.bias, value=0.0)
@manager.MODELS.add_component @manager.MODELS.add_component
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册