Unverified commit ef83ab8a, authored by shangliang Xu, committed by GitHub

Add PP-YOLOv3 code (#5281)

* [ppyolov3] add ppyolov3 base code

* add ppyolov3 s/m/x

* modify ema

* modify code to convert onnx successfully

* support arbitrary shape

* update config to use amp by default

* refine ppyolo_head code

* modify re-parameterization code

* refine act layer

* adapt pico_head and tood_head code

* remove ppyolov3 yaml

* fix codestyle
Co-authored-by: wangxinxin08 <wangxinxin08@baidu.com>
Parent 629e1533
@@ -1747,7 +1747,7 @@ class Mixup(BaseOperator):
             gt_score2 = np.ones_like(sample[1]['gt_class'])
             gt_score = np.concatenate(
                 (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
-            result['gt_score'] = gt_score
+            result['gt_score'] = gt_score.astype('float32')
             if 'is_crowd' in sample[0]:
                 is_crowd1 = sample[0]['is_crowd']
                 is_crowd2 = sample[1]['is_crowd']
...
@@ -32,7 +32,10 @@ from ppdet.metrics import get_infer_results
 from ppdet.utils.logger import setup_logger
 logger = setup_logger('ppdet.engine')
-__all__ = ['Callback', 'ComposeCallback', 'LogPrinter', 'Checkpointer', 'VisualDLWriter', 'SniperProposalsGenerator']
+__all__ = [
+    'Callback', 'ComposeCallback', 'LogPrinter', 'Checkpointer',
+    'VisualDLWriter', 'SniperProposalsGenerator'
+]
 class Callback(object):
@@ -202,8 +205,14 @@ class Checkpointer(Callback):
                 logger.info("Best test {} ap is {:0.3f}.".format(
                     key, self.best_ap))
             if weight:
-                save_model(weight, self.model.optimizer, self.save_dir,
-                           save_name, epoch_id + 1)
+                if self.model.use_ema:
+                    save_model(status['weight'], self.save_dir, save_name,
+                               epoch_id + 1, self.model.optimizer)
+                    save_model(weight, self.save_dir,
+                               '{}_ema'.format(save_name), epoch_id + 1)
+                else:
+                    save_model(weight, self.save_dir, save_name, epoch_id + 1,
+                               self.model.optimizer)
 class WiferFaceEval(Callback):
...
@@ -339,7 +339,8 @@ class Trainer(object):
             self.start_epoch = load_weight(self.model.student_model, weights,
                                            self.optimizer)
         else:
-            self.start_epoch = load_weight(self.model, weights, self.optimizer)
+            self.start_epoch = load_weight(self.model, weights, self.optimizer,
+                                           self.ema if self.use_ema else None)
         logger.debug("Resume weights of epoch {}".format(self.start_epoch))
     def train(self, validate=False):
@@ -432,21 +433,23 @@ class Trainer(object):
                 self.status['batch_time'].update(time.time() - iter_tic)
                 self._compose_callback.on_step_end(self.status)
                 if self.use_ema:
-                    self.ema.update(self.model)
+                    self.ema.update()
                 iter_tic = time.time()
-            # apply ema weight on model
-            if self.use_ema:
+            if self.cfg.get('unstructured_prune'):
+                self.pruner.update_params()
+            is_snapshot = (self._nranks < 2 or self._local_rank == 0) \
+                   and ((epoch_id + 1) % self.cfg.snapshot_epoch == 0 or epoch_id == self.end_epoch - 1)
+            if is_snapshot and self.use_ema:
+                # apply ema weight on model
                 weight = copy.deepcopy(self.model.state_dict())
                 self.model.set_dict(self.ema.apply())
-            if self.cfg.get('unstructured_prune'):
-                self.pruner.update_params()
+                self.status['weight'] = weight
             self._compose_callback.on_epoch_end(self.status)
-            if validate and (self._nranks < 2 or self._local_rank == 0) \
-                    and ((epoch_id + 1) % self.cfg.snapshot_epoch == 0 \
-                             or epoch_id == self.end_epoch - 1):
+            if validate and is_snapshot:
                 if not hasattr(self, '_eval_loader'):
                     # build evaluation dataset and loader
                     self._eval_dataset = self.cfg.EvalDataset
@@ -467,13 +470,15 @@ class Trainer(object):
                     Init_mark = True
                     self._init_metrics(validate=validate)
                     self._reset_metrics()
                 with paddle.no_grad():
                     self.status['save_best_model'] = True
                     self._eval_with_loader(self._eval_loader)
-            # restore origin weight on model
-            if self.use_ema:
+            if is_snapshot and self.use_ema:
+                # reset original weight
                 self.model.set_dict(weight)
+                self.status.pop('weight')
         self._compose_callback.on_train_end(self.status)
@@ -634,6 +639,11 @@ class Trainer(object):
         if hasattr(self.model, 'deploy'):
             self.model.deploy = True
+        for layer in self.model.sublayers():
+            if hasattr(layer, 'convert_to_deploy'):
+                layer.convert_to_deploy()
         export_post_process = self.cfg.get('export_post_process', False)
         if hasattr(self.model, 'export_post_process'):
             self.model.export_post_process = export_post_process
...
@@ -109,10 +109,14 @@ class YOLOv3(BaseArch):
             if self.return_idx:
                 _, bbox, bbox_num, _ = self.post_process(
                     yolo_head_outs, self.yolo_head.mask_anchors)
-            else:
+            elif self.post_process is not None:
                 bbox, bbox_num = self.post_process(
                     yolo_head_outs, self.yolo_head.mask_anchors,
                     self.inputs['im_shape'], self.inputs['scale_factor'])
+            else:
+                bbox, bbox_num = self.yolo_head.post_process(
+                    yolo_head_outs, self.inputs['im_shape'],
+                    self.inputs['scale_factor'])
             output = {'bbox': bbox, 'bbox_num': bbox_num}
             return output
...
@@ -107,11 +107,15 @@ def gather_topk_anchors(metrics, topk, largest=True, topk_mask=None, eps=1e-9):
     return is_in_topk.astype(metrics.dtype)
-def check_points_inside_bboxes(points, bboxes, eps=1e-9):
+def check_points_inside_bboxes(points,
+                               bboxes,
+                               center_radius_tensor=None,
+                               eps=1e-9):
     r"""
     Args:
         points (Tensor, float32): shape[L, 2], "xy" format, L: num_anchors
         bboxes (Tensor, float32): shape[B, n, 4], "xmin, ymin, xmax, ymax" format
+        center_radius_tensor (Tensor, float32): shape [L, 1]. Default: None.
         eps (float): Default: 1e-9
     Returns:
         is_in_bboxes (Tensor, float32): shape[B, n, L], value=1. means selected
@@ -119,6 +123,19 @@ def check_points_inside_bboxes(points, bboxes, eps=1e-9):
     points = points.unsqueeze([0, 1])
     x, y = points.chunk(2, axis=-1)
     xmin, ymin, xmax, ymax = bboxes.unsqueeze(2).chunk(4, axis=-1)
+    if center_radius_tensor is not None:
+        center_radius_tensor = center_radius_tensor.unsqueeze([0, 1])
+        bboxes_cx = (xmin + xmax) / 2.
+        bboxes_cy = (ymin + ymax) / 2.
+        xmin_sampling = bboxes_cx - center_radius_tensor
+        ymin_sampling = bboxes_cy - center_radius_tensor
+        xmax_sampling = bboxes_cx + center_radius_tensor
+        ymax_sampling = bboxes_cy + center_radius_tensor
+        xmin = paddle.maximum(xmin, xmin_sampling)
+        ymin = paddle.maximum(ymin, ymin_sampling)
+        xmax = paddle.minimum(xmax, xmax_sampling)
+        ymax = paddle.minimum(ymax, ymax_sampling)
     l = x - xmin
     t = y - ymin
     r = xmax - x
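# Illustrative sketch of the center-radius behaviour added above (assumed toy
# values, not part of the patch): with center_radius_tensor set, a point must
# lie inside the intersection of the gt box and a (2r x 2r) window around the
# box center.
#
#   points = paddle.to_tensor([[16., 16.]])              # shape [L=1, 2]
#   bboxes = paddle.to_tensor([[[0., 0., 100., 100.]]])  # shape [B=1, n=1, 4]
#   radius = paddle.full([1, 1], 2.5)                    # shape [L, 1]
#   check_points_inside_bboxes(points, bboxes, center_radius_tensor=radius)
#   # -> 0.: (16, 16) is inside the box but outside the 5x5 center window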
@@ -167,14 +184,16 @@ def generate_anchors_for_grid_cell(feats,
         grid_cell_size (float): anchor size
         grid_cell_offset (float): The range is between 0 and 1.
     Returns:
-        anchors (List[Tensor]): shape[s, (l, 4)]
-        num_anchors_list (List[int]): shape[s]
-        stride_tensor_list (List[Tensor]): shape[s, (l, 1)]
+        anchors (Tensor): shape[l, 4], "xmin, ymin, xmax, ymax" format.
+        anchor_points (Tensor): shape[l, 2], "x, y" format.
+        num_anchors_list (List[int]): shape[s], contains [s_1, s_2, ...].
+        stride_tensor (Tensor): shape[l, 1], contains the stride for each scale.
     """
     assert len(feats) == len(fpn_strides)
     anchors = []
+    anchor_points = []
     num_anchors_list = []
-    stride_tensor_list = []
+    stride_tensor = []
     for feat, stride in zip(feats, fpn_strides):
         _, _, h, w = feat.shape
         cell_half_size = grid_cell_size * stride * 0.5
@@ -187,8 +206,19 @@ def generate_anchors_for_grid_cell(feats,
                 shift_x + cell_half_size, shift_y + cell_half_size
             ],
             axis=-1).astype(feat.dtype)
+        anchor_point = paddle.stack(
+            [shift_x, shift_y], axis=-1).astype(feat.dtype)
         anchors.append(anchor.reshape([-1, 4]))
+        anchor_points.append(anchor_point.reshape([-1, 2]))
         num_anchors_list.append(len(anchors[-1]))
-        stride_tensor_list.append(
-            paddle.full([num_anchors_list[-1], 1], stride))
-    return anchors, num_anchors_list, stride_tensor_list
+        stride_tensor.append(
+            paddle.full(
+                [num_anchors_list[-1], 1], stride, dtype=feat.dtype))
+    anchors = paddle.concat(anchors)
+    anchors.stop_gradient = True
+    anchor_points = paddle.concat(anchor_points)
+    anchor_points.stop_gradient = True
+    stride_tensor = paddle.concat(stride_tensor)
+    stride_tensor.stop_gradient = True
+    return anchors, anchor_points, num_anchors_list, stride_tensor
@@ -29,6 +29,7 @@ from . import swin_transformer
 from . import lcnet
 from . import hardnet
 from . import esnet
+from . import cspresnet
 from .vgg import *
 from .resnet import *
@@ -47,3 +48,4 @@ from .swin_transformer import *
 from .lcnet import *
 from .hardnet import *
 from .esnet import *
+from .cspresnet import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from ppdet.modeling.ops import get_act_fn
from ppdet.core.workspace import register, serializable
from ..shape_spec import ShapeSpec
__all__ = ['CSPResNet', 'BasicBlock', 'EffectiveSELayer', 'ConvBNLayer']
class ConvBNLayer(nn.Layer):
def __init__(self,
ch_in,
ch_out,
filter_size=3,
stride=1,
groups=1,
padding=0,
act=None):
super(ConvBNLayer, self).__init__()
self.conv = nn.Conv2D(
in_channels=ch_in,
out_channels=ch_out,
kernel_size=filter_size,
stride=stride,
padding=padding,
groups=groups,
bias_attr=False)
self.bn = nn.BatchNorm2D(
ch_out,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
self.act = get_act_fn(act) if act is None or isinstance(act, (
str, dict)) else act
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.act(x)
return x
class RepVggBlock(nn.Layer):
def __init__(self, ch_in, ch_out, act='relu'):
super(RepVggBlock, self).__init__()
self.ch_in = ch_in
self.ch_out = ch_out
self.conv1 = ConvBNLayer(
ch_in, ch_out, 3, stride=1, padding=1, act=None)
self.conv2 = ConvBNLayer(
ch_in, ch_out, 1, stride=1, padding=0, act=None)
self.act = get_act_fn(act) if act is None or isinstance(act, (
str, dict)) else act
def forward(self, x):
if hasattr(self, 'conv'):
y = self.conv(x)
else:
y = self.conv1(x) + self.conv2(x)
y = self.act(y)
return y
def convert_to_deploy(self):
if not hasattr(self, 'conv'):
self.conv = nn.Conv2D(
in_channels=self.ch_in,
out_channels=self.ch_out,
kernel_size=3,
stride=1,
padding=1,
groups=1)
kernel, bias = self.get_equivalent_kernel_bias()
self.conv.weight.set_value(kernel)
self.conv.bias.set_value(bias)
def get_equivalent_kernel_bias(self):
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
return kernel3x3 + self._pad_1x1_to_3x3_tensor(
kernel1x1), bias3x3 + bias1x1
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
if kernel1x1 is None:
return 0
else:
return nn.functional.pad(kernel1x1, [1, 1, 1, 1])
def _fuse_bn_tensor(self, branch):
if branch is None:
return 0, 0
kernel = branch.conv.weight
running_mean = branch.bn._mean
running_var = branch.bn._variance
gamma = branch.bn.weight
beta = branch.bn.bias
eps = branch.bn._epsilon
std = (running_var + eps).sqrt()
t = (gamma / std).reshape((-1, 1, 1, 1))
return kernel * t, beta - running_mean * gamma / std
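# Illustrative equivalence check for the re-parameterization above (assumed toy
# shapes, not part of the patch): after convert_to_deploy, the single fused 3x3
# conv should reproduce conv1(x) + conv2(x) up to numerical noise. eval() matters
# because the fusion uses the BN running statistics, not batch statistics.
#
#   block = RepVggBlock(64, 64, act='relu')
#   block.eval()
#   x = paddle.rand([1, 64, 32, 32])
#   y_train = block(x)              # conv1(x) + conv2(x), then act
#   block.convert_to_deploy()
#   y_deploy = block(x)             # single fused conv, then act
#   assert float((y_train - y_deploy).abs().max()) < 1e-5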
class BasicBlock(nn.Layer):
def __init__(self, ch_in, ch_out, act='relu', shortcut=True):
super(BasicBlock, self).__init__()
assert ch_in == ch_out
self.conv1 = ConvBNLayer(ch_in, ch_out, 3, stride=1, padding=1, act=act)
self.conv2 = RepVggBlock(ch_out, ch_out, act=act)
self.shortcut = shortcut
def forward(self, x):
y = self.conv1(x)
y = self.conv2(y)
if self.shortcut:
return paddle.add(x, y)
else:
return y
class EffectiveSELayer(nn.Layer):
""" Effective Squeeze-Excitation
From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
"""
def __init__(self, channels, act='hardsigmoid'):
super(EffectiveSELayer, self).__init__()
self.fc = nn.Conv2D(channels, channels, kernel_size=1, padding=0)
self.act = get_act_fn(act) if act is None or isinstance(act, (
str, dict)) else act
def forward(self, x):
x_se = x.mean((2, 3), keepdim=True)
x_se = self.fc(x_se)
return x * self.act(x_se)
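# Effective-SE replaces the reduce-then-expand MLP of the original SE block
# with a single full-width 1x1 conv followed by a (hard-)sigmoid gate. A
# minimal usage sketch (assumed shapes):
#
#   attn = EffectiveSELayer(128)
#   out = attn(paddle.rand([2, 128, 20, 20]))  # gated, same shape as input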
class CSPResStage(nn.Layer):
def __init__(self,
block_fn,
ch_in,
ch_out,
n,
stride,
act='relu',
attn='eca'):
super(CSPResStage, self).__init__()
ch_mid = (ch_in + ch_out) // 2
if stride == 2:
self.conv_down = ConvBNLayer(
ch_in, ch_mid, 3, stride=2, padding=1, act=act)
else:
self.conv_down = None
self.conv1 = ConvBNLayer(ch_mid, ch_mid // 2, 1, act=act)
self.conv2 = ConvBNLayer(ch_mid, ch_mid // 2, 1, act=act)
self.blocks = nn.Sequential(* [
block_fn(
ch_mid // 2, ch_mid // 2, act=act, shortcut=True)
for i in range(n)
])
if attn:
self.attn = EffectiveSELayer(ch_mid, act='hardsigmoid')
else:
self.attn = None
self.conv3 = ConvBNLayer(ch_mid, ch_out, 1, act=act)
def forward(self, x):
if self.conv_down is not None:
x = self.conv_down(x)
y1 = self.conv1(x)
y2 = self.blocks(self.conv2(x))
y = paddle.concat([y1, y2], axis=1)
if self.attn is not None:
y = self.attn(y)
y = self.conv3(y)
return y
@register
@serializable
class CSPResNet(nn.Layer):
__shared__ = ['width_mult', 'depth_mult', 'trt']
def __init__(self,
layers=[3, 6, 6, 3],
channels=[64, 128, 256, 512, 1024],
act='swish',
return_idx=[0, 1, 2, 3, 4],
depth_wise=False,
use_large_stem=False,
width_mult=1.0,
depth_mult=1.0,
trt=False):
super(CSPResNet, self).__init__()
channels = [max(round(c * width_mult), 1) for c in channels]
layers = [max(round(l * depth_mult), 1) for l in layers]
act = get_act_fn(
act, trt=trt) if act is None or isinstance(act,
(str, dict)) else act
if use_large_stem:
self.stem = nn.Sequential(
('conv1', ConvBNLayer(
3, channels[0] // 2, 3, stride=2, padding=1, act=act)),
('conv2', ConvBNLayer(
channels[0] // 2,
channels[0] // 2,
3,
stride=1,
padding=1,
act=act)), ('conv3', ConvBNLayer(
channels[0] // 2,
channels[0],
3,
stride=1,
padding=1,
act=act)))
else:
self.stem = nn.Sequential(
('conv1', ConvBNLayer(
3, channels[0] // 2, 3, stride=2, padding=1, act=act)),
('conv2', ConvBNLayer(
channels[0] // 2,
channels[0],
3,
stride=1,
padding=1,
act=act)))
n = len(channels) - 1
self.stages = nn.Sequential(* [(str(i), CSPResStage(
BasicBlock, channels[i], channels[i + 1], layers[i], 2, act=act))
for i in range(n)])
self._out_channels = channels[1:]
self._out_strides = [4, 8, 16, 32]
self.return_idx = return_idx
def forward(self, inputs):
x = inputs['image']
x = self.stem(x)
outs = []
for idx, stage in enumerate(self.stages):
x = stage(x)
if idx in self.return_idx:
outs.append(x)
return outs
@property
def out_shape(self):
return [
ShapeSpec(
channels=self._out_channels[i], stride=self._out_strides[i])
for i in self.return_idx
]
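# Illustrative usage sketch for the backbone (assumed input size): a 640x640
# image yields one feature map per entry of `return_idx`, at the strides
# recorded in `_out_strides`.
#
#   model = CSPResNet(return_idx=[1, 2, 3])
#   feats = model({'image': paddle.rand([1, 3, 640, 640])})
#   [tuple(f.shape) for f in feats]
#   # [(1, 256, 80, 80), (1, 512, 40, 40), (1, 1024, 20, 20)]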
@@ -77,8 +77,8 @@ class ConvBNLayer(nn.Layer):
         out = self.batch_norm(out)
         if self.act == 'leaky':
             out = F.leaky_relu(out, 0.1)
-        elif self.act == 'mish':
-            out = mish(out)
+        else:
+            out = getattr(F, self.act)(out)
         return out
@@ -156,7 +156,7 @@ class BasicBlock(nn.Layer):
         # channel route: 10-->5 --> 5-->10
         self.conv1 = ConvBNLayer(
             ch_in=ch_in,
-            ch_out=int(ch_out/2),
+            ch_out=int(ch_out / 2),
             filter_size=1,
             stride=1,
             padding=0,
@@ -165,8 +165,8 @@ class BasicBlock(nn.Layer):
             freeze_norm=freeze_norm,
             data_format=data_format)
         self.conv2 = ConvBNLayer(
-            ch_in=int(ch_out/2),
-            ch_out=ch_out ,
+            ch_in=int(ch_out / 2),
+            ch_out=ch_out,
             filter_size=3,
             stride=1,
             padding=1,
@@ -317,7 +317,7 @@ class DarkNet(nn.Layer):
             down_name,
             DownSample(
                 ch_in=int(ch_in[i]),
-                ch_out=int(ch_in[i+1]),
+                ch_out=int(ch_in[i + 1]),
                 norm_type=norm_type,
                 norm_decay=norm_decay,
                 freeze_norm=freeze_norm,
...
@@ -744,9 +744,9 @@ def distance2bbox(points, distance, max_shape=None):
 def bbox_center(boxes):
     """Get bbox centers from boxes.
     Args:
-        boxes (Tensor): boxes with shape (N, 4), "xmin, ymin, xmax, ymax" format.
+        boxes (Tensor): boxes with shape (..., 4), "xmin, ymin, xmax, ymax" format.
     Returns:
-        Tensor: boxes centers with shape (N, 2), "cx, cy" format.
+        Tensor: boxes centers with shape (..., 2), "cx, cy" format.
     """
     boxes_cx = (boxes[..., 0] + boxes[..., 2]) / 2
     boxes_cy = (boxes[..., 1] + boxes[..., 3]) / 2
@@ -782,7 +782,7 @@ def delta2bbox_v2(rois,
                   means=(0.0, 0.0, 0.0, 0.0),
                   stds=(1.0, 1.0, 1.0, 1.0),
                   max_shape=None,
-                  wh_ratio_clip=16.0/1000.0,
+                  wh_ratio_clip=16.0 / 1000.0,
                   ctr_clip=None):
     """Transform network output(delta) to bboxes.
     Based on https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/
...
@@ -32,6 +32,7 @@ from . import detr_head
 from . import sparsercnn_head
 from . import tood_head
 from . import retina_head
+from . import ppyolo_head
 from .bbox_head import *
 from .mask_head import *
@@ -53,3 +54,4 @@ from .detr_head import *
 from .sparsercnn_head import *
 from .tood_head import *
 from .retina_head import *
+from .ppyolo_head import *
@@ -464,12 +464,13 @@ class PicoHeadV2(GFLHead):
         assert len(fpn_feats) == len(
             self.fpn_stride
         ), "The size of fpn_feats is not equal to size of fpn_stride"
-        anchors, num_anchors_list, stride_tensor_list = generate_anchors_for_grid_cell(
+        anchors, _, num_anchors_list, stride_tensor_list = generate_anchors_for_grid_cell(
            fpn_feats, self.fpn_stride, self.grid_cell_scale, self.cell_offset)
+        anchors_split = paddle.split(anchors, num_anchors_list)
         cls_score_list, reg_list, box_list = [], [], []
         for i, fpn_feat, anchor, stride, align_cls in zip(
-                range(len(self.fpn_stride)), fpn_feats, anchors,
+                range(len(self.fpn_stride)), fpn_feats, anchors_split,
                 self.fpn_stride, self.cls_align):
             b, _, h, w = get_static_shape(fpn_feat)
             # task decomposition
@@ -507,10 +508,6 @@ class PicoHeadV2(GFLHead):
         cls_score_list = paddle.concat(cls_score_list, axis=1)
         box_list = paddle.concat(box_list, axis=1)
         reg_list = paddle.concat(reg_list, axis=1)
-        anchors = paddle.concat(anchors)
-        anchors.stop_gradient = True
-        stride_tensor_list = paddle.concat(stride_tensor_list)
-        stride_tensor_list.stop_gradient = True
         return cls_score_list, reg_list, box_list, anchors, num_anchors_list, stride_tensor_list
     def get_loss(self, head_outs, gt_meta):
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ..bbox_utils import batch_distance2bbox
from ..losses import GIoULoss
from ..initializer import bias_init_with_prob, constant_, normal_
from ..assigners.utils import generate_anchors_for_grid_cell
from ppdet.modeling.backbones.cspresnet import ConvBNLayer
from ppdet.modeling.ops import get_static_shape, paddle_distributed_is_initialized, get_act_fn
__all__ = ['PPYOLOHead']
class ESEAttn(nn.Layer):
def __init__(self, feat_channels, act='swish'):
super(ESEAttn, self).__init__()
self.fc = nn.Conv2D(feat_channels, feat_channels, 1)
self.conv = ConvBNLayer(feat_channels, feat_channels, 1, act=act)
self._init_weights()
def _init_weights(self):
normal_(self.fc.weight, std=0.001)
def forward(self, feat, avg_feat):
weight = F.sigmoid(self.fc(avg_feat))
return self.conv(feat * weight)
@register
class PPYOLOHead(nn.Layer):
__shared__ = ['num_classes', 'trt']
__inject__ = ['static_assigner', 'assigner', 'nms']
def __init__(self,
in_channels=[1024, 512, 256],
num_classes=80,
act='swish',
fpn_strides=(32, 16, 8),
grid_cell_scale=5.0,
grid_cell_offset=0.5,
reg_max=16,
static_assigner_epoch=4,
use_varifocal_loss=True,
static_assigner='ATSSAssigner',
assigner='TaskAlignedAssigner',
nms='MultiClassNMS',
eval_input_size=[],
loss_weight={
'class': 1.0,
'iou': 2.5,
'dfl': 0.5,
},
trt=False):
super(PPYOLOHead, self).__init__()
        assert len(in_channels) > 0, "len(in_channels) should be > 0"
self.in_channels = in_channels
self.num_classes = num_classes
self.fpn_strides = fpn_strides
self.grid_cell_scale = grid_cell_scale
self.grid_cell_offset = grid_cell_offset
self.reg_max = reg_max
self.iou_loss = GIoULoss()
self.loss_weight = loss_weight
self.use_varifocal_loss = use_varifocal_loss
self.eval_input_size = eval_input_size
self.static_assigner_epoch = static_assigner_epoch
self.static_assigner = static_assigner
self.assigner = assigner
self.nms = nms
# stem
self.stem_cls = nn.LayerList()
self.stem_reg = nn.LayerList()
act = get_act_fn(
act, trt=trt) if act is None or isinstance(act,
(str, dict)) else act
for in_c in self.in_channels:
self.stem_cls.append(ESEAttn(in_c, act=act))
self.stem_reg.append(ESEAttn(in_c, act=act))
# pred head
self.pred_cls = nn.LayerList()
self.pred_reg = nn.LayerList()
for in_c in self.in_channels:
self.pred_cls.append(
nn.Conv2D(
in_c, self.num_classes, 3, padding=1))
self.pred_reg.append(
nn.Conv2D(
in_c, 4 * (self.reg_max + 1), 3, padding=1))
# projection conv
self.proj_conv = nn.Conv2D(self.reg_max + 1, 1, 1, bias_attr=False)
self._init_weights()
@classmethod
def from_config(cls, cfg, input_shape):
return {'in_channels': [i.channels for i in input_shape], }
def _init_weights(self):
bias_cls = bias_init_with_prob(0.01)
for cls_, reg_ in zip(self.pred_cls, self.pred_reg):
constant_(cls_.weight)
constant_(cls_.bias, bias_cls)
constant_(reg_.weight)
constant_(reg_.bias, 1.0)
self.proj = paddle.linspace(0, self.reg_max, self.reg_max + 1)
self.proj_conv.weight.set_value(
self.proj.reshape([1, self.reg_max + 1, 1, 1]))
self.proj_conv.weight.stop_gradient = True
if self.eval_input_size:
anchor_points, stride_tensor = self._generate_anchors()
self.register_buffer('anchor_points', anchor_points)
self.register_buffer('stride_tensor', stride_tensor)
def forward_train(self, feats, targets):
anchors, anchor_points, num_anchors_list, stride_tensor = \
generate_anchors_for_grid_cell(
feats, self.fpn_strides, self.grid_cell_scale,
self.grid_cell_offset)
cls_score_list, reg_distri_list = [], []
for i, feat in enumerate(feats):
avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
cls_logit = self.pred_cls[i](self.stem_cls[i](feat, avg_feat) +
feat)
reg_distri = self.pred_reg[i](self.stem_reg[i](feat, avg_feat))
# cls and reg
cls_score = F.sigmoid(cls_logit)
cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
reg_distri_list.append(reg_distri.flatten(2).transpose([0, 2, 1]))
cls_score_list = paddle.concat(cls_score_list, axis=1)
reg_distri_list = paddle.concat(reg_distri_list, axis=1)
return self.get_loss([
cls_score_list, reg_distri_list, anchors, anchor_points,
num_anchors_list, stride_tensor
], targets)
def _generate_anchors(self, feats=None):
        # only used at eval time
anchor_points = []
stride_tensor = []
for i, stride in enumerate(self.fpn_strides):
if feats is not None:
_, _, h, w = feats[i].shape
else:
h = int(self.eval_input_size[0] / stride)
w = int(self.eval_input_size[1] / stride)
shift_x = paddle.arange(end=w) + self.grid_cell_offset
shift_y = paddle.arange(end=h) + self.grid_cell_offset
shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
anchor_point = paddle.cast(
paddle.stack(
[shift_x, shift_y], axis=-1), dtype='float32')
anchor_points.append(anchor_point.reshape([-1, 2]))
stride_tensor.append(
paddle.full(
[h * w, 1], stride, dtype='float32'))
anchor_points = paddle.concat(anchor_points)
stride_tensor = paddle.concat(stride_tensor)
return anchor_points, stride_tensor
def forward_eval(self, feats):
if self.eval_input_size:
anchor_points, stride_tensor = self.anchor_points, self.stride_tensor
else:
anchor_points, stride_tensor = self._generate_anchors(feats)
cls_score_list, reg_dist_list = [], []
for i, feat in enumerate(feats):
b, _, h, w = feat.shape
l = h * w
avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
cls_logit = self.pred_cls[i](self.stem_cls[i](feat, avg_feat) +
feat)
reg_dist = self.pred_reg[i](self.stem_reg[i](feat, avg_feat))
reg_dist = reg_dist.reshape([-1, 4, self.reg_max + 1, l]).transpose(
[0, 2, 1, 3])
reg_dist = self.proj_conv(F.softmax(reg_dist, axis=1))
# cls and reg
cls_score = F.sigmoid(cls_logit)
cls_score_list.append(cls_score.reshape([b, self.num_classes, l]))
reg_dist_list.append(reg_dist.reshape([b, 4, l]))
cls_score_list = paddle.concat(cls_score_list, axis=-1)
reg_dist_list = paddle.concat(reg_dist_list, axis=-1)
return cls_score_list, reg_dist_list, anchor_points, stride_tensor
def forward(self, feats, targets=None):
assert len(feats) == len(self.fpn_strides), \
"The size of feats is not equal to size of fpn_strides"
if self.training:
return self.forward_train(feats, targets)
else:
return self.forward_eval(feats)
@staticmethod
def _focal_loss(score, label, alpha=0.25, gamma=2.0):
weight = (score - label).pow(gamma)
if alpha > 0:
alpha_t = alpha * label + (1 - alpha) * (1 - label)
weight *= alpha_t
loss = F.binary_cross_entropy(
score, label, weight=weight, reduction='sum')
return loss
@staticmethod
def _varifocal_loss(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
weight = alpha * pred_score.pow(gamma) * (1 - label) + gt_score * label
loss = F.binary_cross_entropy(
pred_score, gt_score, weight=weight, reduction='sum')
return loss
def _bbox_decode(self, anchor_points, pred_dist):
b, l, _ = get_static_shape(pred_dist)
pred_dist = F.softmax(pred_dist.reshape([b, l, 4, self.reg_max + 1
])).matmul(self.proj)
return batch_distance2bbox(anchor_points, pred_dist)
def _bbox2distance(self, points, bbox):
x1y1, x2y2 = paddle.split(bbox, 2, -1)
lt = points - x1y1
rb = x2y2 - points
return paddle.concat([lt, rb], -1).clip(0, self.reg_max - 0.01)
def _df_loss(self, pred_dist, target):
target_left = paddle.cast(target, 'int64')
target_right = target_left + 1
weight_left = target_right.astype('float32') - target
weight_right = 1 - weight_left
loss_left = F.cross_entropy(
pred_dist, target_left, reduction='none') * weight_left
loss_right = F.cross_entropy(
pred_dist, target_right, reduction='none') * weight_right
return (loss_left + loss_right).mean(-1, keepdim=True)
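# Worked example for the distribution focal loss above (assumed target): for a
# continuous target t = 2.7, target_left = 2, weight_left = 3 - 2.7 = 0.3 and
# weight_right = 0.7, so the loss is
#     0.3 * CE(pred_dist, 2) + 0.7 * CE(pred_dist, 3),
# which pushes the predicted distribution to put mass 0.3 on bin 2 and 0.7 on
# bin 3, so that its expectation recovers 2.7.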
def _bbox_loss(self, pred_dist, pred_bboxes, anchor_points, assigned_labels,
assigned_bboxes, assigned_scores, assigned_scores_sum):
# select positive samples mask
mask_positive = (assigned_labels != self.num_classes)
num_pos = mask_positive.sum()
# pos/neg loss
if num_pos > 0:
# l1 + iou
bbox_mask = mask_positive.unsqueeze(-1).tile([1, 1, 4])
pred_bboxes_pos = paddle.masked_select(pred_bboxes,
bbox_mask).reshape([-1, 4])
assigned_bboxes_pos = paddle.masked_select(
assigned_bboxes, bbox_mask).reshape([-1, 4])
bbox_weight = paddle.masked_select(
assigned_scores.sum(-1), mask_positive).unsqueeze(-1)
loss_l1 = F.l1_loss(pred_bboxes_pos, assigned_bboxes_pos)
loss_iou = self.iou_loss(pred_bboxes_pos,
assigned_bboxes_pos) * bbox_weight
loss_iou = loss_iou.sum() / assigned_scores_sum
dist_mask = mask_positive.unsqueeze(-1).tile(
[1, 1, (self.reg_max + 1) * 4])
pred_dist_pos = paddle.masked_select(
pred_dist, dist_mask).reshape([-1, 4, self.reg_max + 1])
assigned_ltrb = self._bbox2distance(anchor_points, assigned_bboxes)
assigned_ltrb_pos = paddle.masked_select(
assigned_ltrb, bbox_mask).reshape([-1, 4])
loss_dfl = self._df_loss(pred_dist_pos,
assigned_ltrb_pos) * bbox_weight
loss_dfl = loss_dfl.sum() / assigned_scores_sum
else:
loss_l1 = paddle.zeros([1])
loss_iou = paddle.zeros([1])
loss_dfl = paddle.zeros([1])
return loss_l1, loss_iou, loss_dfl
def get_loss(self, head_outs, gt_meta):
pred_scores, pred_distri, anchors,\
anchor_points, num_anchors_list, stride_tensor = head_outs
anchor_points_s = anchor_points / stride_tensor
pred_bboxes = self._bbox_decode(anchor_points_s, pred_distri)
gt_labels = gt_meta['gt_class']
gt_bboxes = gt_meta['gt_bbox']
pad_gt_mask = gt_meta['pad_gt_mask']
# label assignment
if gt_meta['epoch_id'] < self.static_assigner_epoch:
assigned_labels, assigned_bboxes, assigned_scores = \
self.static_assigner(
anchors,
num_anchors_list,
gt_labels,
gt_bboxes,
pad_gt_mask,
bg_index=self.num_classes,
pred_bboxes=pred_bboxes.detach() * stride_tensor)
alpha_l = 0.25
else:
assigned_labels, assigned_bboxes, assigned_scores = \
self.assigner(
pred_scores.detach(),
pred_bboxes.detach() * stride_tensor,
anchor_points,
num_anchors_list,
gt_labels,
gt_bboxes,
pad_gt_mask,
bg_index=self.num_classes)
alpha_l = -1
# rescale bbox
assigned_bboxes /= stride_tensor
# cls loss
if self.use_varifocal_loss:
one_hot_label = F.one_hot(assigned_labels, self.num_classes)
loss_cls = self._varifocal_loss(pred_scores, assigned_scores,
one_hot_label)
else:
loss_cls = self._focal_loss(
pred_scores, assigned_scores, alpha=alpha_l)
assigned_scores_sum = assigned_scores.sum()
if paddle_distributed_is_initialized():
paddle.distributed.all_reduce(assigned_scores_sum)
assigned_scores_sum = paddle.clip(
assigned_scores_sum / paddle.distributed.get_world_size(),
min=1)
loss_cls /= assigned_scores_sum
loss_l1, loss_iou, loss_dfl = \
self._bbox_loss(pred_distri, pred_bboxes, anchor_points_s,
assigned_labels, assigned_bboxes, assigned_scores,
assigned_scores_sum)
loss = self.loss_weight['class'] * loss_cls + \
self.loss_weight['iou'] * loss_iou + \
self.loss_weight['dfl'] * loss_dfl
out_dict = {
'loss': loss,
'loss_cls': loss_cls,
'loss_iou': loss_iou,
'loss_dfl': loss_dfl,
'loss_l1': loss_l1,
}
return out_dict
def post_process(self, head_outs, img_shape, scale_factor):
pred_scores, pred_dist, anchor_points, stride_tensor = head_outs
pred_bboxes = batch_distance2bbox(anchor_points,
pred_dist.transpose([0, 2, 1]))
pred_bboxes *= stride_tensor
# scale bbox to origin
scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
scale_factor = paddle.concat(
[scale_x, scale_y, scale_x, scale_y], axis=-1).reshape([-1, 1, 4])
pred_bboxes /= scale_factor
bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
return bbox_pred, bbox_num
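# Illustrative end-to-end eval sketch (assumed wiring): when the YOLOv3
# architecture's `post_process` is None, the new fallback branch in
# YOLOv3._forward reaches the head's own post_process above:
#
#   head_outs = head.forward_eval(fpn_feats)
#   bbox_pred, bbox_num = head.post_process(
#       head_outs, inputs['im_shape'], inputs['scale_factor'])
#
# Boxes are decoded from anchor-point distances, rescaled to the original
# image by `scale_factor`, then filtered by the configured NMS.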
@@ -218,13 +218,17 @@ class TOODHead(nn.Layer):
         assert len(feats) == len(self.fpn_strides), \
             "The size of feats is not equal to size of fpn_strides"
-        anchors, num_anchors_list, stride_tensor_list = generate_anchors_for_grid_cell(
-            feats, self.fpn_strides, self.grid_cell_scale,
-            self.grid_cell_offset)
+        anchors, anchor_points, num_anchors_list, stride_tensor =\
+            generate_anchors_for_grid_cell(
+                feats, self.fpn_strides, self.grid_cell_scale,
+                self.grid_cell_offset)
+        anchor_centers_split = paddle.split(anchor_points / stride_tensor,
+                                            num_anchors_list)
         cls_score_list, bbox_pred_list = [], []
-        for feat, scale_reg, anchor, stride in zip(feats, self.scales_regs,
-                                                   anchors, self.fpn_strides):
+        for feat, scale_reg, anchor_centers, stride in zip(
+                feats, self.scales_regs, anchor_centers_split,
+                self.fpn_strides):
             b, _, h, w = get_static_shape(feat)
             inter_feats = []
             for inter_conv in self.inter_convs:
@@ -250,8 +254,8 @@ class TOODHead(nn.Layer):
             # reg prediction and alignment
             reg_dist = scale_reg(self.tood_reg(reg_feat).exp())
             reg_dist = reg_dist.flatten(2).transpose([0, 2, 1])
-            anchor_centers = bbox_center(anchor).unsqueeze(0) / stride
-            reg_bbox = batch_distance2bbox(anchor_centers, reg_dist)
+            reg_bbox = batch_distance2bbox(
+                anchor_centers.unsqueeze(0), reg_dist)
             if self.use_align_head:
                 reg_offset = F.relu(self.reg_offset_conv1(feat))
                 reg_offset = self.reg_offset_conv2(reg_offset)
@@ -268,12 +272,8 @@ class TOODHead(nn.Layer):
             bbox_pred_list.append(bbox_pred)
         cls_score_list = paddle.concat(cls_score_list, axis=1)
         bbox_pred_list = paddle.concat(bbox_pred_list, axis=1)
-        anchors = paddle.concat(anchors)
-        anchors.stop_gradient = True
-        stride_tensor_list = paddle.concat(stride_tensor_list).unsqueeze(0)
-        stride_tensor_list.stop_gradient = True
-        return cls_score_list, bbox_pred_list, anchors, num_anchors_list, stride_tensor_list
+        return cls_score_list, bbox_pred_list, anchors, num_anchors_list, stride_tensor
     @staticmethod
     def _focal_loss(score, label, alpha=0.25, gamma=2.0):
@@ -287,7 +287,7 @@ class TOODHead(nn.Layer):
     def get_loss(self, head_outs, gt_meta):
         pred_scores, pred_bboxes, anchors, \
-            num_anchors_list, stride_tensor_list = head_outs
+            num_anchors_list, stride_tensor = head_outs
         gt_labels = gt_meta['gt_class']
         gt_bboxes = gt_meta['gt_bbox']
         pad_gt_mask = gt_meta['pad_gt_mask']
@@ -304,7 +304,7 @@ class TOODHead(nn.Layer):
         else:
             assigned_labels, assigned_bboxes, assigned_scores = self.assigner(
                 pred_scores.detach(),
-                pred_bboxes.detach() * stride_tensor_list,
+                pred_bboxes.detach() * stride_tensor,
                 bbox_center(anchors),
                 num_anchors_list,
                 gt_labels,
@@ -314,7 +314,7 @@ class TOODHead(nn.Layer):
             alpha_l = -1
         # rescale bbox
-        assigned_bboxes /= stride_tensor_list
+        assigned_bboxes /= stride_tensor
         # classification loss
         loss_cls = self._focal_loss(pred_scores, assigned_scores, alpha=alpha_l)
         # select positive samples mask
...
@@ -251,7 +251,7 @@ class LiteConv(nn.Layer):
 class DropBlock(nn.Layer):
-    def __init__(self, block_size, keep_prob, name, data_format='NCHW'):
+    def __init__(self, block_size, keep_prob, name=None, data_format='NCHW'):
         """
         DropBlock layer, see https://arxiv.org/abs/1810.12890
...
@@ -20,6 +20,7 @@ from . import centernet_fpn
 from . import bifpn
 from . import csp_pan
 from . import es_pan
+from . import custom_pan
 from .fpn import *
 from .yolo_fpn import *
@@ -30,3 +31,4 @@ from .blazeface_fpn import *
 from .bifpn import *
 from .csp_pan import *
 from .es_pan import *
+from .custom_pan import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import DropBlock
from ppdet.modeling.ops import get_act_fn
from ..backbones.cspresnet import ConvBNLayer, BasicBlock
from ..shape_spec import ShapeSpec
__all__ = ['CustomCSPPAN']
class SPP(nn.Layer):
def __init__(self,
ch_in,
ch_out,
k,
pool_size,
act='swish',
data_format='NCHW'):
super(SPP, self).__init__()
self.pool = []
self.data_format = data_format
for i, size in enumerate(pool_size):
pool = self.add_sublayer(
'pool{}'.format(i),
nn.MaxPool2D(
kernel_size=size,
stride=1,
padding=size // 2,
data_format=data_format,
ceil_mode=False))
self.pool.append(pool)
self.conv = ConvBNLayer(ch_in, ch_out, k, padding=k // 2, act=act)
def forward(self, x):
outs = [x]
for pool in self.pool:
outs.append(pool(x))
if self.data_format == 'NCHW':
y = paddle.concat(outs, axis=1)
else:
y = paddle.concat(outs, axis=-1)
y = self.conv(y)
return y
class CSPStage(nn.Layer):
def __init__(self, block_fn, ch_in, ch_out, n, act='swish', spp=False):
super(CSPStage, self).__init__()
ch_mid = int(ch_out // 2)
self.conv1 = ConvBNLayer(ch_in, ch_mid, 1, act=act)
self.conv2 = ConvBNLayer(ch_in, ch_mid, 1, act=act)
self.convs = nn.Sequential()
next_ch_in = ch_mid
for i in range(n):
self.convs.add_sublayer(
str(i),
eval(block_fn)(next_ch_in, ch_mid, act=act, shortcut=False))
if i == (n - 1) // 2 and spp:
self.convs.add_sublayer(
'spp', SPP(ch_mid * 4, ch_mid, 1, [5, 9, 13], act=act))
next_ch_in = ch_mid
self.conv3 = ConvBNLayer(ch_mid * 2, ch_out, 1, act=act)
def forward(self, x):
y1 = self.conv1(x)
y2 = self.conv2(x)
y2 = self.convs(y2)
y = paddle.concat([y1, y2], axis=1)
y = self.conv3(y)
return y
@register
@serializable
class CustomCSPPAN(nn.Layer):
__shared__ = ['norm_type', 'data_format', 'width_mult', 'depth_mult', 'trt']
def __init__(self,
in_channels=[256, 512, 1024],
out_channels=[1024, 512, 256],
norm_type='bn',
act='leaky',
stage_fn='CSPStage',
block_fn='BasicBlock',
stage_num=1,
block_num=3,
drop_block=False,
block_size=3,
keep_prob=0.9,
spp=False,
data_format='NCHW',
width_mult=1.0,
depth_mult=1.0,
trt=False):
super(CustomCSPPAN, self).__init__()
out_channels = [max(round(c * width_mult), 1) for c in out_channels]
block_num = max(round(block_num * depth_mult), 1)
act = get_act_fn(
act, trt=trt) if act is None or isinstance(act,
(str, dict)) else act
self.num_blocks = len(in_channels)
self.data_format = data_format
self._out_channels = out_channels
in_channels = in_channels[::-1]
fpn_stages = []
fpn_routes = []
for i, (ch_in, ch_out) in enumerate(zip(in_channels, out_channels)):
if i > 0:
ch_in += ch_pre // 2
stage = nn.Sequential()
for j in range(stage_num):
stage.add_sublayer(
str(j),
eval(stage_fn)(block_fn,
ch_in if j == 0 else ch_out,
ch_out,
block_num,
act=act,
spp=(spp and i == 0)))
if drop_block:
stage.add_sublayer('drop', DropBlock(block_size, keep_prob))
fpn_stages.append(stage)
if i < self.num_blocks - 1:
fpn_routes.append(
ConvBNLayer(
ch_in=ch_out,
ch_out=ch_out // 2,
filter_size=1,
stride=1,
padding=0,
act=act))
ch_pre = ch_out
self.fpn_stages = nn.LayerList(fpn_stages)
self.fpn_routes = nn.LayerList(fpn_routes)
pan_stages = []
pan_routes = []
for i in reversed(range(self.num_blocks - 1)):
pan_routes.append(
ConvBNLayer(
ch_in=out_channels[i + 1],
ch_out=out_channels[i + 1],
filter_size=3,
stride=2,
padding=1,
act=act))
ch_in = out_channels[i] + out_channels[i + 1]
ch_out = out_channels[i]
stage = nn.Sequential()
for j in range(stage_num):
stage.add_sublayer(
str(j),
eval(stage_fn)(block_fn,
ch_in if j == 0 else ch_out,
ch_out,
block_num,
act=act,
spp=False))
if drop_block:
stage.add_sublayer('drop', DropBlock(block_size, keep_prob))
pan_stages.append(stage)
self.pan_stages = nn.LayerList(pan_stages[::-1])
self.pan_routes = nn.LayerList(pan_routes[::-1])
def forward(self, blocks, for_mot=False):
blocks = blocks[::-1]
fpn_feats = []
for i, block in enumerate(blocks):
if i > 0:
block = paddle.concat([route, block], axis=1)
route = self.fpn_stages[i](block)
fpn_feats.append(route)
if i < self.num_blocks - 1:
route = self.fpn_routes[i](route)
route = F.interpolate(
route, scale_factor=2., data_format=self.data_format)
pan_feats = [fpn_feats[-1], ]
route = fpn_feats[-1]
for i in reversed(range(self.num_blocks - 1)):
block = fpn_feats[i]
route = self.pan_routes[i](route)
block = paddle.concat([route, block], axis=1)
route = self.pan_stages[i](block)
pan_feats.append(route)
return pan_feats[::-1]
@classmethod
def from_config(cls, cfg, input_shape):
return {'in_channels': [i.channels for i in input_shape], }
@property
def out_shape(self):
return [ShapeSpec(channels=c) for c in self._out_channels]
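# Illustrative usage sketch for the neck (assumed shapes): features enter
# ordered from high resolution to low and come back reversed, matching the
# reversed `out_channels` ([1024, 512, 256] by default).
#
#   neck = CustomCSPPAN(in_channels=[256, 512, 1024], act='swish')
#   c3 = paddle.rand([1, 256, 80, 80])
#   c4 = paddle.rand([1, 512, 40, 40])
#   c5 = paddle.rand([1, 1024, 20, 20])
#   outs = neck([c3, c4, c5])
#   [tuple(o.shape) for o in outs]
#   # [(1, 1024, 20, 20), (1, 512, 40, 40), (1, 256, 80, 80)]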
@@ -114,7 +114,7 @@ class SPP(nn.Layer):
                  ch_out,
                  k,
                  pool_size,
-                 norm_type,
+                 norm_type='bn',
                  freeze_norm=False,
                  name='',
                  act='leaky',
...
@@ -20,28 +20,57 @@ from paddle.regularizer import L2Decay
 from paddle.fluid.framework import Variable, in_dygraph_mode
 from paddle.fluid import core
+from paddle.fluid.dygraph import parallel_helper
 from paddle.fluid.layer_helper import LayerHelper
 from paddle.fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype
 __all__ = [
-    'roi_pool',
-    'roi_align',
-    'prior_box',
-    'generate_proposals',
-    'iou_similarity',
-    'box_coder',
-    'yolo_box',
-    'multiclass_nms',
-    'distribute_fpn_proposals',
-    'collect_fpn_proposals',
-    'matrix_nms',
-    'batch_norm',
-    'mish',
+    'roi_pool', 'roi_align', 'prior_box', 'generate_proposals',
+    'iou_similarity', 'box_coder', 'yolo_box', 'multiclass_nms',
+    'distribute_fpn_proposals', 'collect_fpn_proposals', 'matrix_nms',
+    'batch_norm', 'mish', 'swish', 'identity'
 ]
+def identity(x):
+    return x
 def mish(x):
-    return x * paddle.tanh(F.softplus(x))
+    return F.mish(x) if hasattr(F, 'mish') else x * F.tanh(F.softplus(x))
+def swish(x):
+    return x * F.sigmoid(x)
+TRT_ACT_SPEC = {'swish': swish}
+ACT_SPEC = {'mish': mish}
+def get_act_fn(act=None, trt=False):
+    assert act is None or isinstance(act, (
+        str, dict)), 'name of activation should be str, dict or None'
+    if not act:
+        return identity
+    if isinstance(act, dict):
+        name = act['name']
+        act.pop('name')
+        kwargs = act
+    else:
+        name = act
+        kwargs = dict()
+    if trt and name in TRT_ACT_SPEC:
+        fn = TRT_ACT_SPEC[name]
+    elif name in ACT_SPEC:
+        fn = ACT_SPEC[name]
+    else:
+        fn = getattr(F, name)
+    return lambda x: fn(x, **kwargs)
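# Illustrative usage sketch for get_act_fn (assumed names): a plain string is
# looked up on paddle.nn.functional, a dict carries extra kwargs, and trt=True
# swaps in TensorRT-friendly variants from TRT_ACT_SPEC.
#
#   act = get_act_fn('leaky_relu')                            # F.leaky_relu
#   act = get_act_fn({'name': 'leaky_relu', 'negative_slope': 0.2})
#   act = get_act_fn('swish', trt=True)                       # x * F.sigmoid(x)
#   get_act_fn(None) is identity                              # no-op fallback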
 def batch_norm(ch,
@@ -1602,3 +1631,8 @@ def get_static_shape(tensor):
     shape = paddle.shape(tensor)
     shape.stop_gradient = True
     return shape
+def paddle_distributed_is_initialized():
+    return core.is_compiled_with_dist(
+    ) and parallel_helper._is_parallel_ctx_initialized()
@@ -43,7 +43,7 @@ class CosineDecay(object):
     the max_iters is much larger than the warmup iter
     """
-    def __init__(self, max_epochs=1000, use_warmup=True, eta_min=0):
+    def __init__(self, max_epochs=1000, use_warmup=True, eta_min=0.):
         self.max_epochs = max_epochs
         self.use_warmup = use_warmup
         self.eta_min = eta_min
@@ -65,6 +65,7 @@ class CosineDecay(object):
                 decayed_lr = base_lr * 0.5 * (math.cos(
                     (i - warmup_iters) * math.pi /
                     (max_iters - warmup_iters)) + 1)
+                decayed_lr = decayed_lr if decayed_lr > self.eta_min else self.eta_min
                 value.append(decayed_lr)
             return optimizer.lr.PiecewiseDecay(boundary, value)
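# Worked example for the new eta_min floor (assumed values): with
# base_lr = 0.01 and eta_min = 0.001, the raw cosine term at the final
# iteration is 0.01 * 0.5 * (cos(pi) + 1) = 0, so the scheduled value is
# clamped to 0.001 instead of decaying all the way to zero.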
@@ -331,6 +332,11 @@ class ModelEMA(object):
         for k, v in self.state_dict.items():
             self.state_dict[k] = paddle.zeros_like(v)
+    def resume(self, state_dict, step):
+        for k, v in state_dict.items():
+            self.state_dict[k] = v
+        self.step = step
     def update(self, model=None):
         if self.use_thres_step:
             decay = min(self.decay, (1 + self.step) / (10 + self.step))
...
@@ -62,7 +62,7 @@ def _strip_postfix(path):
     return path
-def load_weight(model, weight, optimizer=None):
+def load_weight(model, weight, optimizer=None, ema=None):
     if is_url(weight):
         weight = get_weights_path(weight)
@@ -102,6 +102,10 @@ def load_weight(model, weight, optimizer=None):
         last_epoch = optim_state_dict.pop('last_epoch')
         optimizer.set_state_dict(optim_state_dict)
+        if ema is not None and os.path.exists(path + '_ema.pdparams'):
+            ema_state_dict = paddle.load(path + '_ema.pdparams')
+            ema.resume(ema_state_dict,
+                       optim_state_dict['LR_Scheduler']['last_epoch'])
     return last_epoch
@@ -201,7 +205,7 @@ def load_pretrain_weight(model, pretrain_weight):
     logger.info('Finish loading model weights: {}'.format(weights_path))
-def save_model(model, optimizer, save_dir, save_name, last_epoch):
+def save_model(model, save_dir, save_name, last_epoch, optimizer=None):
     """
     save model into disk.
@@ -224,6 +228,7 @@ def save_model(model, optimizer, save_dir, save_name, last_epoch):
     assert isinstance(model,
                       dict), 'model is not a instance of nn.layer or dict'
     paddle.save(model, save_path + ".pdparams")
-    state_dict = optimizer.state_dict()
-    state_dict['last_epoch'] = last_epoch
-    paddle.save(state_dict, save_path + ".pdopt")
+    if optimizer is not None:
+        state_dict = optimizer.state_dict()
+        state_dict['last_epoch'] = last_epoch
+        paddle.save(state_dict, save_path + ".pdopt")
...