Unverified commit 36b48e9e, authored by wangguanzhong, committed by GitHub

clean param name (#3799)

Parent 4ea5b435
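The change is mechanical: explicit name= arguments are removed from every ParamAttr, the helper variables that built those names (norm_name, bn_name, ...) are deleted, and moving_mean_name / moving_variance_name are dropped from BatchNorm, so Paddle assigns parameter names automatically. Below is a minimal illustrative sketch of the resulting style, using a hypothetical ConvBN layer; it is not code from this repository.

# Sketch of the post-cleanup style (illustrative; ConvBN is a hypothetical layer,
# not a class in PaddleDetection). No name= is passed to ParamAttr, so Paddle
# auto-generates unique parameter names such as "conv2d_0.w_0".
import paddle
import paddle.nn as nn
from paddle import ParamAttr
from paddle.nn.initializer import KaimingNormal


class ConvBN(nn.Layer):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Old style (removed by this commit) would have added e.g.
        #   name="conv1_weights" inside ParamAttr and
        #   moving_mean_name="conv1_bn_mean" inside BatchNorm.
        self._conv = nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            weight_attr=ParamAttr(initializer=KaimingNormal()),
            bias_attr=False)
        self._batch_norm = nn.BatchNorm(
            out_channels, act=None, use_global_stats=False)

    def forward(self, x):
        return self._batch_norm(self._conv(x))


if __name__ == "__main__":
    layer = ConvBN(3, 16)
    # Parameter names are framework-generated, e.g. "conv2d_0.w_0", "batch_norm_0.w_0".
    print([p.name for p in layer.parameters()])
    out = layer(paddle.rand([1, 3, 32, 32]))
    print(out.shape)  # [1, 16, 32, 32]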
@@ -55,25 +55,14 @@ class ConvBNLayer(nn.Layer):
             padding=padding,
             groups=num_groups,
             weight_attr=ParamAttr(
-                learning_rate=conv_lr,
-                initializer=KaimingNormal(),
-                name=name + "_weights"),
+                learning_rate=conv_lr, initializer=KaimingNormal()),
             bias_attr=False)
 
-        param_attr = ParamAttr(name=name + "_bn_scale")
-        bias_attr = ParamAttr(name=name + "_bn_offset")
         if norm_type == 'sync_bn':
-            self._batch_norm = nn.SyncBatchNorm(
-                out_channels, weight_attr=param_attr, bias_attr=bias_attr)
+            self._batch_norm = nn.SyncBatchNorm(out_channels)
         else:
             self._batch_norm = nn.BatchNorm(
-                out_channels,
-                act=None,
-                param_attr=param_attr,
-                bias_attr=bias_attr,
-                use_global_stats=False,
-                moving_mean_name=name + '_bn_mean',
-                moving_variance_name=name + '_bn_variance')
+                out_channels, act=None, use_global_stats=False)
 
     def forward(self, x):
         x = self._conv(x)
...
@@ -100,21 +100,15 @@ class SEBlock(nn.Layer):
             num_channels,
             med_ch,
             weight_attr=ParamAttr(
-                learning_rate=lr_mult,
-                initializer=Uniform(-stdv, stdv),
-                name=name + "_1_weights"),
-            bias_attr=ParamAttr(
-                learning_rate=lr_mult, name=name + "_1_offset"))
+                learning_rate=lr_mult, initializer=Uniform(-stdv, stdv)),
+            bias_attr=ParamAttr(learning_rate=lr_mult))
         stdv = 1.0 / math.sqrt(med_ch * 1.0)
         self.excitation = Linear(
             med_ch,
             num_channels,
             weight_attr=ParamAttr(
-                learning_rate=lr_mult,
-                initializer=Uniform(-stdv, stdv),
-                name=name + "_2_weights"),
-            bias_attr=ParamAttr(
-                learning_rate=lr_mult, name=name + "_2_offset"))
+                learning_rate=lr_mult, initializer=Uniform(-stdv, stdv)),
+            bias_attr=ParamAttr(learning_rate=lr_mult))
 
     def forward(self, inputs):
         pool = self.pool2d_gap(inputs)
...
@@ -51,31 +51,23 @@ class ConvNormLayer(nn.Layer):
             stride=stride,
             padding=(filter_size - 1) // 2,
             groups=1,
-            weight_attr=ParamAttr(
-                name=name + "_weights", initializer=Normal(
-                    mean=0., std=0.01)),
+            weight_attr=ParamAttr(initializer=Normal(
+                mean=0., std=0.01)),
             bias_attr=False)
 
         norm_lr = 0. if freeze_norm else 1.
-        norm_name = name + '_bn'
         param_attr = ParamAttr(
-            name=norm_name + "_scale",
-            learning_rate=norm_lr,
-            regularizer=L2Decay(norm_decay))
+            learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
         bias_attr = ParamAttr(
-            name=norm_name + "_offset",
-            learning_rate=norm_lr,
-            regularizer=L2Decay(norm_decay))
+            learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
 
         global_stats = True if freeze_norm else False
         if norm_type in ['bn', 'sync_bn']:
             self.norm = nn.BatchNorm(
                 ch_out,
                 param_attr=param_attr,
                 bias_attr=bias_attr,
-                use_global_stats=global_stats,
-                moving_mean_name=norm_name + '_mean',
-                moving_variance_name=norm_name + '_variance')
+                use_global_stats=global_stats)
         elif norm_type == 'gn':
             self.norm = nn.GroupNorm(
                 num_groups=norm_groups,
@@ -375,17 +367,13 @@ class SELayer(nn.Layer):
         self.squeeze = Linear(
             num_channels,
             med_ch,
-            weight_attr=ParamAttr(
-                initializer=Uniform(-stdv, stdv), name=name + "_sqz_weights"),
-            bias_attr=ParamAttr(name=name + '_sqz_offset'))
+            weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
         stdv = 1.0 / math.sqrt(med_ch * 1.0)
         self.excitation = Linear(
             med_ch,
             num_filters,
-            weight_attr=ParamAttr(
-                initializer=Uniform(-stdv, stdv), name=name + "_exc_weights"),
-            bias_attr=ParamAttr(name=name + '_exc_offset'))
+            weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
 
     def forward(self, input):
         pool = self.pool2d_gap(input)
...
@@ -62,21 +62,17 @@ class ConvBNLayer(nn.Layer):
             padding=padding,
             groups=num_groups,
             weight_attr=ParamAttr(
-                learning_rate=lr_mult,
-                regularizer=L2Decay(conv_decay),
-                name=name + "_weights"),
+                learning_rate=lr_mult, regularizer=L2Decay(conv_decay)),
             bias_attr=False)
 
         norm_lr = 0. if freeze_norm else lr_mult
         param_attr = ParamAttr(
             learning_rate=norm_lr,
             regularizer=L2Decay(norm_decay),
-            name=name + "_bn_scale",
             trainable=False if freeze_norm else True)
         bias_attr = ParamAttr(
             learning_rate=norm_lr,
             regularizer=L2Decay(norm_decay),
-            name=name + "_bn_offset",
             trainable=False if freeze_norm else True)
         global_stats = True if freeze_norm else False
         if norm_type == 'sync_bn':
@@ -88,9 +84,7 @@ class ConvBNLayer(nn.Layer):
                 act=None,
                 param_attr=param_attr,
                 bias_attr=bias_attr,
-                use_global_stats=global_stats,
-                moving_mean_name=name + '_bn_mean',
-                moving_variance_name=name + '_bn_variance')
+                use_global_stats=global_stats)
         norm_params = self.bn.parameters()
         if freeze_norm:
             for param in norm_params:
@@ -203,13 +197,9 @@ class SEModule(nn.Layer):
             stride=1,
             padding=0,
             weight_attr=ParamAttr(
-                learning_rate=lr_mult,
-                regularizer=L2Decay(conv_decay),
-                name=name + "_1_weights"),
+                learning_rate=lr_mult, regularizer=L2Decay(conv_decay)),
             bias_attr=ParamAttr(
-                learning_rate=lr_mult,
-                regularizer=L2Decay(conv_decay),
-                name=name + "_1_offset"))
+                learning_rate=lr_mult, regularizer=L2Decay(conv_decay)))
         self.conv2 = nn.Conv2D(
             in_channels=mid_channels,
             out_channels=channel,
@@ -217,13 +207,9 @@ class SEModule(nn.Layer):
             stride=1,
             padding=0,
             weight_attr=ParamAttr(
-                learning_rate=lr_mult,
-                regularizer=L2Decay(conv_decay),
-                name=name + "_2_weights"),
+                learning_rate=lr_mult, regularizer=L2Decay(conv_decay)),
             bias_attr=ParamAttr(
-                learning_rate=lr_mult,
-                regularizer=L2Decay(conv_decay),
-                name=name + "_2_offset"))
+                learning_rate=lr_mult, regularizer=L2Decay(conv_decay)))
 
     def forward(self, inputs):
         outputs = self.avg_pool(inputs)
...
@@ -30,9 +30,7 @@ class ConvBlock(nn.Layer):
             out_channels=out_channels,
             kernel_size=3,
             stride=1,
-            padding=1,
-            weight_attr=ParamAttr(name=name + "1_weights"),
-            bias_attr=ParamAttr(name=name + "1_bias"))
+            padding=1)
         self.conv_out_list = []
         for i in range(1, groups):
             conv_out = self.add_sublayer(
@@ -42,10 +40,7 @@ class ConvBlock(nn.Layer):
                     out_channels=out_channels,
                     kernel_size=3,
                     stride=1,
-                    padding=1,
-                    weight_attr=ParamAttr(
-                        name=name + "{}_weights".format(i + 1)),
-                    bias_attr=ParamAttr(name=name + "{}_bias".format(i + 1))))
+                    padding=1))
             self.conv_out_list.append(conv_out)
 
         self.pool = MaxPool2D(
...
@@ -151,12 +151,9 @@ class FCOSHead(nn.Layer):
                 kernel_size=3,
                 stride=1,
                 padding=1,
-                weight_attr=ParamAttr(
-                    name=conv_cls_name + "_weights",
-                    initializer=Normal(
-                        mean=0., std=0.01)),
+                weight_attr=ParamAttr(initializer=Normal(
+                    mean=0., std=0.01)),
                 bias_attr=ParamAttr(
-                    name=conv_cls_name + "_bias",
                     initializer=Constant(value=bias_init_value))))
 
         conv_reg_name = "fcos_head_reg"
@@ -168,13 +165,9 @@ class FCOSHead(nn.Layer):
                 kernel_size=3,
                 stride=1,
                 padding=1,
-                weight_attr=ParamAttr(
-                    name=conv_reg_name + "_weights",
-                    initializer=Normal(
-                        mean=0., std=0.01)),
-                bias_attr=ParamAttr(
-                    name=conv_reg_name + "_bias",
-                    initializer=Constant(value=0))))
+                weight_attr=ParamAttr(initializer=Normal(
+                    mean=0., std=0.01)),
+                bias_attr=ParamAttr(initializer=Constant(value=0))))
 
         conv_centerness_name = "fcos_head_centerness"
         self.fcos_head_centerness = self.add_sublayer(
@@ -185,13 +178,9 @@ class FCOSHead(nn.Layer):
                 kernel_size=3,
                 stride=1,
                 padding=1,
-                weight_attr=ParamAttr(
-                    name=conv_centerness_name + "_weights",
-                    initializer=Normal(
-                        mean=0., std=0.01)),
-                bias_attr=ParamAttr(
-                    name=conv_centerness_name + "_bias",
-                    initializer=Constant(value=0))))
+                weight_attr=ParamAttr(initializer=Normal(
+                    mean=0., std=0.01)),
+                bias_attr=ParamAttr(initializer=Constant(value=0))))
 
         self.scales_regs = []
         for i in range(len(self.fpn_stride)):
...
@@ -51,25 +51,14 @@ class ConvBNLayer(nn.Layer):
             padding=padding,
             groups=num_groups,
             weight_attr=ParamAttr(
-                learning_rate=conv_lr,
-                initializer=KaimingNormal(),
-                name=name + "_weights"),
+                learning_rate=conv_lr, initializer=KaimingNormal()),
             bias_attr=False)
 
-        param_attr = ParamAttr(name=name + "_bn_scale")
-        bias_attr = ParamAttr(name=name + "_bn_offset")
         if norm_type == 'sync_bn':
-            self._batch_norm = nn.SyncBatchNorm(
-                out_channels, weight_attr=param_attr, bias_attr=bias_attr)
+            self._batch_norm = nn.SyncBatchNorm(out_channels)
         else:
             self._batch_norm = nn.BatchNorm(
-                out_channels,
-                act=None,
-                param_attr=param_attr,
-                bias_attr=bias_attr,
-                use_global_stats=False,
-                moving_mean_name=name + '_bn_mean',
-                moving_variance_name=name + '_bn_variance')
+                out_channels, act=None, use_global_stats=False)
 
     def forward(self, x):
         x = self._conv(x)
...
@@ -14,7 +14,6 @@
 import paddle
 import paddle.nn.functional as F
-from paddle import ParamAttr
 import paddle.nn as nn
 from ppdet.core.workspace import register
 from ..shape_spec import ShapeSpec
@@ -53,7 +52,6 @@ class HRFPN(nn.Layer):
             in_channels=in_channel,
             out_channels=out_channel,
             kernel_size=1,
-            weight_attr=ParamAttr(name='hrfpn_reduction_weights'),
             bias_attr=False)
 
         if share_conv:
@@ -62,7 +60,6 @@ class HRFPN(nn.Layer):
                 out_channels=out_channel,
                 kernel_size=3,
                 padding=1,
-                weight_attr=ParamAttr(name='fpn_conv_weights'),
                 bias_attr=False)
         else:
             self.fpn_conv = []
@@ -75,7 +72,6 @@ class HRFPN(nn.Layer):
                     out_channels=out_channel,
                     kernel_size=3,
                     padding=1,
-                    weight_attr=ParamAttr(name=conv_name + "_weights"),
                     bias_attr=False))
             self.fpn_conv.append(conv)
...
@@ -92,9 +92,7 @@ class JDEEmbeddingHead(nn.Layer):
                     kernel_size=3,
                     stride=1,
                     padding=1,
-                    weight_attr=ParamAttr(name=name + '.conv.weights'),
-                    bias_attr=ParamAttr(
-                        name=name + '.conv.bias', regularizer=L2Decay(0.))))
+                    bias_attr=ParamAttr(regularizer=L2Decay(0.))))
             self.identify_outputs.append(identify_output)
 
             loss_p_cls = self.add_sublayer('cls.{}'.format(i), LossParam(-4.15))
...
@@ -89,16 +89,12 @@ class PCBPyramid(nn.Layer):
                 if idx_branches >= sum(self.num_in_each_level[0:idx_levels + 1]):
                     idx_levels += 1
-                name = "Linear_branch_id_{}".format(idx_branches)
                 fc = nn.Linear(
                     in_features=num_conv_out_channels,
                     out_features=self.num_classes,
-                    weight_attr=ParamAttr(
-                        name=name + "_weights",
-                        initializer=Normal(
-                            mean=0., std=0.001)),
-                    bias_attr=ParamAttr(
-                        name=name + "_bias", initializer=Constant(value=0.)))
+                    weight_attr=ParamAttr(initializer=Normal(
+                        mean=0., std=0.001)),
+                    bias_attr=ParamAttr(initializer=Constant(value=0.)))
                 pyramid_fc_list.append(fc)
 
         return pyramid_conv_list, pyramid_fc_list
...
@@ -50,23 +50,13 @@ class ConvBNLayer(nn.Layer):
             dilation=dilation,
             groups=groups,
             weight_attr=ParamAttr(
-                name=name + "_weights",
                 learning_rate=lr_mult,
                 initializer=Normal(0, math.sqrt(2. / conv_stdv))),
             bias_attr=False,
             data_format=data_format)
 
-        if name == "conv1":
-            bn_name = "bn_" + name
-        else:
-            bn_name = "bn" + name[3:]
         self._batch_norm = nn.BatchNorm(
-            num_filters,
-            act=act,
-            param_attr=ParamAttr(name=bn_name + "_scale"),
-            bias_attr=ParamAttr(bn_name + "_offset"),
-            moving_mean_name=bn_name + "_mean",
-            moving_variance_name=bn_name + "_variance",
-            data_layout=data_format)
+            num_filters, act=act, data_layout=data_format)
 
     def forward(self, inputs):
         y = self._conv(inputs)
...