Unverified  Commit a7aa1452 authored by littletomatodonkey, committed by GitHub

fix repvgg eval (#677)

* fix repvgg eval

* fix dp training

* fix single card train
Parent 2e62e2e2
@@ -4,19 +4,39 @@ import numpy as np
 
 __all__ = [
     'RepVGG',
-    'RepVGG_A0', 'RepVGG_A1', 'RepVGG_A2',
-    'RepVGG_B0', 'RepVGG_B1', 'RepVGG_B2', 'RepVGG_B3',
-    'RepVGG_B1g2', 'RepVGG_B1g4',
-    'RepVGG_B2g2', 'RepVGG_B2g4',
-    'RepVGG_B3g2', 'RepVGG_B3g4',
+    'RepVGG_A0',
+    'RepVGG_A1',
+    'RepVGG_A2',
+    'RepVGG_B0',
+    'RepVGG_B1',
+    'RepVGG_B2',
+    'RepVGG_B3',
+    'RepVGG_B1g2',
+    'RepVGG_B1g4',
+    'RepVGG_B2g2',
+    'RepVGG_B2g4',
+    'RepVGG_B3g2',
+    'RepVGG_B3g4',
 ]
 
 
 class ConvBN(nn.Layer):
-    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups=1):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride,
+                 padding,
+                 groups=1):
         super(ConvBN, self).__init__()
-        self.conv = nn.Conv2D(in_channels=in_channels, out_channels=out_channels,
-                              kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias_attr=False)
+        self.conv = nn.Conv2D(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            groups=groups,
+            bias_attr=False)
         self.bn = nn.BatchNorm2D(num_features=out_channels)
 
     def forward(self, x):
@@ -26,9 +46,15 @@ class ConvBN(nn.Layer):
 
 
 class RepVGGBlock(nn.Layer):
-    def __init__(self, in_channels, out_channels, kernel_size,
-                 stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros'):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 padding_mode='zeros'):
         super(RepVGGBlock, self).__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -47,11 +73,22 @@ class RepVGGBlock(nn.Layer):
 
         self.nonlinearity = nn.ReLU()
         self.rbr_identity = nn.BatchNorm2D(
-            num_features=in_channels) if out_channels == in_channels and stride == 1 else None
-        self.rbr_dense = ConvBN(in_channels=in_channels, out_channels=out_channels,
-                                kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
-        self.rbr_1x1 = ConvBN(in_channels=in_channels, out_channels=out_channels,
-                              kernel_size=1, stride=stride, padding=padding_11, groups=groups)
+            num_features=in_channels
+        ) if out_channels == in_channels and stride == 1 else None
+        self.rbr_dense = ConvBN(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            groups=groups)
+        self.rbr_1x1 = ConvBN(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=1,
+            stride=stride,
+            padding=padding_11,
+            groups=groups)
 
     def forward(self, inputs):
         if not self.training:
@@ -61,12 +98,20 @@ class RepVGGBlock(nn.Layer):
             id_out = 0
         else:
             id_out = self.rbr_identity(inputs)
-        return self.nonlinearity(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
+        return self.nonlinearity(
+            self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
 
     def eval(self):
         if not hasattr(self, 'rbr_reparam'):
-            self.rbr_reparam = nn.Conv2D(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=self.kernel_size, stride=self.stride,
-                                         padding=self.padding, dilation=self.dilation, groups=self.groups, padding_mode=self.padding_mode)
+            self.rbr_reparam = nn.Conv2D(
+                in_channels=self.in_channels,
+                out_channels=self.out_channels,
+                kernel_size=self.kernel_size,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                groups=self.groups,
+                padding_mode=self.padding_mode)
         self.training = False
         kernel, bias = self.get_equivalent_kernel_bias()
         self.rbr_reparam.weight.set_value(kernel)
@@ -78,7 +123,8 @@ class RepVGGBlock(nn.Layer):
         kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
         kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
         kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
-        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
+        return kernel3x3 + self._pad_1x1_to_3x3_tensor(
+            kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
 
     def _pad_1x1_to_3x3_tensor(self, kernel1x1):
         if kernel1x1 is None:
@@ -117,8 +163,11 @@ class RepVGGBlock(nn.Layer):
 
 
 class RepVGG(nn.Layer):
-    def __init__(self, num_blocks, width_multiplier=None, override_groups_map=None, class_dim=1000):
+    def __init__(self,
+                 num_blocks,
+                 width_multiplier=None,
+                 override_groups_map=None,
+                 class_dim=1000):
         super(RepVGG, self).__init__()
 
         assert len(width_multiplier) == 4
@@ -129,7 +178,11 @@ class RepVGG(nn.Layer):
 
         self.in_planes = min(64, int(64 * width_multiplier[0]))
         self.stage0 = RepVGGBlock(
-            in_channels=3, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1)
+            in_channels=3,
+            out_channels=self.in_planes,
+            kernel_size=3,
+            stride=2,
+            padding=1)
         self.cur_layer_idx = 1
         self.stage1 = self._make_stage(
             int(64 * width_multiplier[0]), num_blocks[0], stride=2)
@@ -143,16 +196,28 @@ class RepVGG(nn.Layer):
         self.linear = nn.Linear(int(512 * width_multiplier[3]), class_dim)
 
     def _make_stage(self, planes, num_blocks, stride):
-        strides = [stride] + [1]*(num_blocks-1)
+        strides = [stride] + [1] * (num_blocks - 1)
         blocks = []
         for stride in strides:
             cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
-            blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3,
-                                      stride=stride, padding=1, groups=cur_groups))
+            blocks.append(
+                RepVGGBlock(
+                    in_channels=self.in_planes,
+                    out_channels=planes,
+                    kernel_size=3,
+                    stride=stride,
+                    padding=1,
+                    groups=cur_groups))
             self.in_planes = planes
             self.cur_layer_idx += 1
         return nn.Sequential(*blocks)
 
+    def eval(self):
+        self.training = False
+        for layer in self.sublayers():
+            layer.training = False
+            layer.eval()
+
     def forward(self, x):
         out = self.stage0(x)
         out = self.stage1(out)
@@ -171,65 +236,104 @@ g4_map = {l: 4 for l in optional_groupwise_layers}
 
 
 def RepVGG_A0(**kwargs):
-    return RepVGG(num_blocks=[2, 4, 14, 1],
-                  width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None, **kwargs)
+    return RepVGG(
+        num_blocks=[2, 4, 14, 1],
+        width_multiplier=[0.75, 0.75, 0.75, 2.5],
+        override_groups_map=None,
+        **kwargs)
 
 
 def RepVGG_A1(**kwargs):
-    return RepVGG(num_blocks=[2, 4, 14, 1],
-                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, **kwargs)
+    return RepVGG(
+        num_blocks=[2, 4, 14, 1],
+        width_multiplier=[1, 1, 1, 2.5],
+        override_groups_map=None,
+        **kwargs)
 
 
 def RepVGG_A2(**kwargs):
-    return RepVGG(num_blocks=[2, 4, 14, 1],
-                  width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None, **kwargs)
+    return RepVGG(
+        num_blocks=[2, 4, 14, 1],
+        width_multiplier=[1.5, 1.5, 1.5, 2.75],
+        override_groups_map=None,
+        **kwargs)
 
 
 def RepVGG_B0(**kwargs):
-    return RepVGG(num_blocks=[4, 6, 16, 1],
-                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, **kwargs)
+    return RepVGG(
+        num_blocks=[4, 6, 16, 1],
+        width_multiplier=[1, 1, 1, 2.5],
+        override_groups_map=None,
+        **kwargs)
 
 
 def RepVGG_B1(**kwargs):
-    return RepVGG(num_blocks=[4, 6, 16, 1],
-                  width_multiplier=[2, 2, 2, 4], override_groups_map=None, **kwargs)
+    return RepVGG(
+        num_blocks=[4, 6, 16, 1],
+        width_multiplier=[2, 2, 2, 4],
+        override_groups_map=None,
+        **kwargs)
 
 
 def RepVGG_B1g2(**kwargs):
-    return RepVGG(num_blocks=[4, 6, 16, 1],
-                  width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map, **kwargs)
+    return RepVGG(
+        num_blocks=[4, 6, 16, 1],
+        width_multiplier=[2, 2, 2, 4],
+        override_groups_map=g2_map,
+        **kwargs)
 
 
 def RepVGG_B1g4(**kwargs):
-    return RepVGG(num_blocks=[4, 6, 16, 1],
-                  width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map, **kwargs)
+    return RepVGG(
+        num_blocks=[4, 6, 16, 1],
+        width_multiplier=[2, 2, 2, 4],
+        override_groups_map=g4_map,
+        **kwargs)
 
 
 def RepVGG_B2(**kwargs):
-    return RepVGG(num_blocks=[4, 6, 16, 1],
-                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, **kwargs)
+    return RepVGG(
+        num_blocks=[4, 6, 16, 1],
+        width_multiplier=[2.5, 2.5, 2.5, 5],
+        override_groups_map=None,
+        **kwargs)
 
 
 def RepVGG_B2g2(**kwargs):
-    return RepVGG(num_blocks=[4, 6, 16, 1],
-                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map, **kwargs)
+    return RepVGG(
+        num_blocks=[4, 6, 16, 1],
+        width_multiplier=[2.5, 2.5, 2.5, 5],
+        override_groups_map=g2_map,
+        **kwargs)
 
 
 def RepVGG_B2g4(**kwargs):
-    return RepVGG(num_blocks=[4, 6, 16, 1],
-                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map, **kwargs)
+    return RepVGG(
+        num_blocks=[4, 6, 16, 1],
+        width_multiplier=[2.5, 2.5, 2.5, 5],
+        override_groups_map=g4_map,
+        **kwargs)
 
 
 def RepVGG_B3(**kwargs):
-    return RepVGG(num_blocks=[4, 6, 16, 1],
-                  width_multiplier=[3, 3, 3, 5], override_groups_map=None, **kwargs)
+    return RepVGG(
+        num_blocks=[4, 6, 16, 1],
+        width_multiplier=[3, 3, 3, 5],
+        override_groups_map=None,
+        **kwargs)
 
 
 def RepVGG_B3g2(**kwargs):
-    return RepVGG(num_blocks=[4, 6, 16, 1],
-                  width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map, **kwargs)
+    return RepVGG(
+        num_blocks=[4, 6, 16, 1],
+        width_multiplier=[3, 3, 3, 5],
+        override_groups_map=g2_map,
+        **kwargs)
 
 
 def RepVGG_B3g4(**kwargs):
-    return RepVGG(num_blocks=[4, 6, 16, 1],
-                  width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map, **kwargs)
+    return RepVGG(
+        num_blocks=[4, 6, 16, 1],
+        width_multiplier=[3, 3, 3, 5],
+        override_groups_map=g4_map,
+        **kwargs)
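For readers following the model diff: `RepVGGBlock.eval()` builds the fused `rbr_reparam` convolution from the three training-time branches via `get_equivalent_kernel_bias()`, which folds each branch's BatchNorm into its conv and zero-pads the 1x1 kernel to 3x3 before summing. Below is a standalone NumPy sketch of that arithmetic; the helper names are illustrative, not the repository's own.

import numpy as np

def fuse_conv_bn(kernel, gamma, beta, mean, var, eps=1e-5):
    # Fold BatchNorm into the preceding conv:
    #   y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta
    # becomes conv with kernel * (gamma/std) and bias (beta - mean*gamma/std).
    std = np.sqrt(var + eps)
    scale = (gamma / std).reshape(-1, 1, 1, 1)  # one scale per output channel
    return kernel * scale, beta - mean * gamma / std

def pad_1x1_to_3x3(kernel1x1):
    # Embed the 1x1 kernel at the centre of a zero 3x3 kernel so the
    # 3x3, 1x1, and identity branches can be summed into a single conv.
    return np.pad(kernel1x1, [(0, 0), (0, 0), (1, 1), (1, 1)])

# Tiny self-check: the padded 1x1 kernel sits at the 3x3 centre.
k3 = np.zeros((4, 4, 3, 3), dtype=np.float32)
k1 = np.random.randn(4, 4, 1, 1).astype(np.float32)
fused = k3 + pad_1x1_to_3x3(k1)
assert np.allclose(fused[:, :, 1, 1], k1[:, :, 0, 0])

The identity branch fuses the same way once its BatchNorm is first expressed as a 3x3 conv with an identity kernel, which is the case `_fuse_bn_tensor` handles for `rbr_identity`.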
@@ -69,24 +69,23 @@ def main(args, return_dict={}):
         paddle.distributed.init_parallel_env()
 
     net = program.create_model(config.ARCHITECTURE, config.classes_num)
-    if config["use_data_parallel"]:
-        net = paddle.DataParallel(net)
 
     init_model(config, net, optimizer=None)
     valid_dataloader = Reader(config, 'valid', places=place)()
     net.eval()
     with paddle.no_grad():
         if not multilabel:
-            top1_acc = program.run(valid_dataloader, config, net, None, None, 0,
-                                   'valid')
+            top1_acc = program.run(valid_dataloader, config, net, None, None,
+                                   0, 'valid')
             return_dict["top1_acc"] = top1_acc
             return top1_acc
         else:
             all_outs = []
             targets = []
-            for idx, batch in enumerate(valid_dataloader()):
-                feeds = program.create_feeds(batch, False, config.classes_num, multilabel)
+            for _, batch in enumerate(valid_dataloader()):
+                feeds = program.create_feeds(batch, False, config.classes_num,
                                             multilabel)
                 out = net(feeds["image"])
                 out = F.sigmoid(out)
......
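The evaluation change above drops the `paddle.DataParallel` wrap before validation. That matters because `RepVGG` overrides `eval()` to propagate `eval()` into every sublayer and trigger the reparameterisation; calling `eval()` through a wrapper would only flip training flags and leave the fused `rbr_reparam` branches unbuilt. A minimal usage sketch follows; the import path is an assumption about the repository layout.

import paddle

# Assumed module path for illustration; adjust to the repo's actual layout.
from ppcls.modeling.architectures.repvgg import RepVGG_A0

net = RepVGG_A0(class_dim=1000)
net.eval()                         # RepVGG.eval(): fuses every block's branches
with paddle.no_grad():             # validation needs no gradients
    out = net(paddle.randn([1, 3, 224, 224]))
print(out.shape)                   # [1, 1000]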
@@ -69,9 +69,10 @@ def main(args):
     optimizer, lr_scheduler = program.create_optimizer(
         config, parameter_list=net.parameters())
 
+    dp_net = net
     if config["use_data_parallel"]:
         find_unused_parameters = config.get("find_unused_parameters", False)
-        net = paddle.DataParallel(
+        dp_net = paddle.DataParallel(
             net, find_unused_parameters=find_unused_parameters)
 
     # load model from checkpoint or pretrained model
@@ -96,8 +97,8 @@ def main(args):
     for epoch_id in range(last_epoch_id + 1, config.epochs):
         net.train()
         # 1. train with train dataset
-        program.run(train_dataloader, config, net, optimizer, lr_scheduler,
-                    epoch_id, 'train', vdl_writer)
+        program.run(train_dataloader, config, dp_net, optimizer,
+                    lr_scheduler, epoch_id, 'train', vdl_writer)
 
         # 2. validate with validate dataset
         if config.validate and epoch_id % config.valid_interval == 0:
......
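The training script now keeps two handles on the model: `net`, the raw layer used for mode switches such as `net.train()` and for checkpoint handling, and `dp_net`, which is the `DataParallel` wrapper on multiple cards and simply `net` on a single card. This is what fixes single-card training, where no wrapper exists. A sketch of the pattern with a stand-in model; the `use_data_parallel` flag here mirrors the config key.

import paddle
import paddle.nn as nn

use_data_parallel = False                 # stands in for config["use_data_parallel"]
net = nn.Linear(8, 2)                     # stand-in for the real architecture

dp_net = net                              # single-card default: same object
if use_data_parallel:
    paddle.distributed.init_parallel_env()
    dp_net = paddle.DataParallel(net)     # wrapper used only for forward/backward

opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=net.parameters())
net.train()                               # mode flag set on the raw model
loss = dp_net(paddle.randn([4, 8])).mean()  # gradients sync through the wrapper
loss.backward()
opt.step()
opt.clear_grad()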