Unverified commit ef83ab8a authored by shangliang Xu, committed by GitHub

Add PP-YOLOv3 code (#5281)

* [ppyolov3] add ppyolov3 base code

* add ppyolov3 s/m/x

* modify ema

* modify code to convert onnx successfully

* support arbitrary shape

* update config to use amp default

* refine ppyolo_head code

* modify reparameter code

* refine act layer

* adapt pico_head and tood_head code

* remove ppyolov3 yaml

* fix codestyle
Co-authored-by: wangxinxin08 <wangxinxin08@baidu.com>
Parent 629e1533
......@@ -1747,7 +1747,7 @@ class Mixup(BaseOperator):
gt_score2 = np.ones_like(sample[1]['gt_class'])
gt_score = np.concatenate(
(gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
result['gt_score'] = gt_score
result['gt_score'] = gt_score.astype('float32')
if 'is_crowd' in sample[0]:
is_crowd1 = sample[0]['is_crowd']
is_crowd2 = sample[1]['is_crowd']
......
......@@ -32,7 +32,10 @@ from ppdet.metrics import get_infer_results
from ppdet.utils.logger import setup_logger
logger = setup_logger('ppdet.engine')
__all__ = ['Callback', 'ComposeCallback', 'LogPrinter', 'Checkpointer', 'VisualDLWriter', 'SniperProposalsGenerator']
__all__ = [
'Callback', 'ComposeCallback', 'LogPrinter', 'Checkpointer',
'VisualDLWriter', 'SniperProposalsGenerator'
]
class Callback(object):
......@@ -202,8 +205,14 @@ class Checkpointer(Callback):
logger.info("Best test {} ap is {:0.3f}.".format(
key, self.best_ap))
if weight:
save_model(weight, self.model.optimizer, self.save_dir,
save_name, epoch_id + 1)
if self.model.use_ema:
save_model(status['weight'], self.save_dir, save_name,
epoch_id + 1, self.model.optimizer)
save_model(weight, self.save_dir,
'{}_ema'.format(save_name), epoch_id + 1)
else:
save_model(weight, self.save_dir, save_name, epoch_id + 1,
self.model.optimizer)
class WiferFaceEval(Callback):
......
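With EMA enabled, the Checkpointer above now writes two checkpoints per snapshot. A minimal sketch of the resulting layout, assuming save_name='best_model' and that status['weight'] carries the raw (non-EMA) parameters as set up in trainer.py below:

# Assumed checkpoint layout after this change:
#   {save_dir}/best_model.pdparams      raw model weights
#   {save_dir}/best_model.pdopt         optimizer state (+ last_epoch)
#   {save_dir}/best_model_ema.pdparams  EMA-averaged weights, no .pdopt
# Resuming restores all three (see load_weight further down):
#   load_weight(model, '{save_dir}/best_model', optimizer, ema)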
......@@ -339,7 +339,8 @@ class Trainer(object):
self.start_epoch = load_weight(self.model.student_model, weights,
self.optimizer)
else:
self.start_epoch = load_weight(self.model, weights, self.optimizer)
self.start_epoch = load_weight(self.model, weights, self.optimizer,
self.ema if self.use_ema else None)
logger.debug("Resume weights of epoch {}".format(self.start_epoch))
def train(self, validate=False):
......@@ -432,21 +433,23 @@ class Trainer(object):
self.status['batch_time'].update(time.time() - iter_tic)
self._compose_callback.on_step_end(self.status)
if self.use_ema:
self.ema.update(self.model)
self.ema.update()
iter_tic = time.time()
# apply ema weight on model
if self.use_ema:
weight = copy.deepcopy(self.model.state_dict())
self.model.set_dict(self.ema.apply())
if self.cfg.get('unstructured_prune'):
self.pruner.update_params()
is_snapshot = (self._nranks < 2 or self._local_rank == 0) \
and ((epoch_id + 1) % self.cfg.snapshot_epoch == 0 or epoch_id == self.end_epoch - 1)
if is_snapshot and self.use_ema:
# apply ema weight on model
weight = copy.deepcopy(self.model.state_dict())
self.model.set_dict(self.ema.apply())
self.status['weight'] = weight
self._compose_callback.on_epoch_end(self.status)
if validate and (self._nranks < 2 or self._local_rank == 0) \
and ((epoch_id + 1) % self.cfg.snapshot_epoch == 0 \
or epoch_id == self.end_epoch - 1):
if validate and is_snapshot:
if not hasattr(self, '_eval_loader'):
# build evaluation dataset and loader
self._eval_dataset = self.cfg.EvalDataset
......@@ -467,13 +470,15 @@ class Trainer(object):
Init_mark = True
self._init_metrics(validate=validate)
self._reset_metrics()
with paddle.no_grad():
self.status['save_best_model'] = True
self._eval_with_loader(self._eval_loader)
# restore origin weight on model
if self.use_ema:
if is_snapshot and self.use_ema:
# reset original weight
self.model.set_dict(weight)
self.status.pop('weight')
self._compose_callback.on_train_end(self.status)
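The is_snapshot refactor above makes the EMA handling symmetric: EMA weights are swapped in only for the snapshot, then swapped back. A minimal sketch of the pattern, with run_eval_and_checkpoint() as a hypothetical stand-in for the eval and callback calls:

import copy
if is_snapshot and use_ema:
    backup = copy.deepcopy(model.state_dict())  # keep raw training weights
    model.set_dict(ema.apply())                 # snapshot/eval with EMA-averaged weights
    run_eval_and_checkpoint()                   # hypothetical stand-in for the calls above
    model.set_dict(backup)                      # restore raw weights for the next epoch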
......@@ -634,6 +639,11 @@ class Trainer(object):
if hasattr(self.model, 'deploy'):
self.model.deploy = True
for layer in self.model.sublayers():
if hasattr(layer, 'convert_to_deploy'):
layer.convert_to_deploy()
export_post_process = self.cfg.get('export_post_process', False)
if hasattr(self.model, 'export_post_process'):
self.model.export_post_process = export_post_process
......
......@@ -109,10 +109,14 @@ class YOLOv3(BaseArch):
if self.return_idx:
_, bbox, bbox_num, _ = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors)
else:
elif self.post_process is not None:
bbox, bbox_num = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor'])
else:
bbox, bbox_num = self.yolo_head.post_process(
yolo_head_outs, self.inputs['im_shape'],
self.inputs['scale_factor'])
output = {'bbox': bbox, 'bbox_num': bbox_num}
return output
......
......@@ -107,11 +107,15 @@ def gather_topk_anchors(metrics, topk, largest=True, topk_mask=None, eps=1e-9):
return is_in_topk.astype(metrics.dtype)
def check_points_inside_bboxes(points, bboxes, eps=1e-9):
def check_points_inside_bboxes(points,
bboxes,
center_radius_tensor=None,
eps=1e-9):
r"""
Args:
points (Tensor, float32): shape[L, 2], "xy" format, L: num_anchors
bboxes (Tensor, float32): shape[B, n, 4], "xmin, ymin, xmax, ymax" format
center_radius_tensor (Tensor, float32): shape [L, 1] Default: None.
eps (float): Default: 1e-9
Returns:
is_in_bboxes (Tensor, float32): shape[B, n, L], value=1. means selected
......@@ -119,6 +123,19 @@ def check_points_inside_bboxes(points, bboxes, eps=1e-9):
points = points.unsqueeze([0, 1])
x, y = points.chunk(2, axis=-1)
xmin, ymin, xmax, ymax = bboxes.unsqueeze(2).chunk(4, axis=-1)
if center_radius_tensor is not None:
center_radius_tensor = center_radius_tensor.unsqueeze([0, 1])
bboxes_cx = (xmin + xmax) / 2.
bboxes_cy = (ymin + ymax) / 2.
xmin_sampling = bboxes_cx - center_radius_tensor
ymin_sampling = bboxes_cy - center_radius_tensor
xmax_sampling = bboxes_cx + center_radius_tensor
ymax_sampling = bboxes_cy + center_radius_tensor
xmin = paddle.maximum(xmin, xmin_sampling)
ymin = paddle.maximum(ymin, ymin_sampling)
xmax = paddle.minimum(xmax, xmax_sampling)
ymax = paddle.minimum(ymax, ymax_sampling)
l = x - xmin
t = y - ymin
r = xmax - x
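The new center_radius_tensor adds center sampling (in the spirit of FCOS/SimOTA): each gt box is intersected with a window of the given radius around its center, and a point must fall inside that intersection. A toy numeric check, assuming the radius is already in pixels:

# box = (0, 0, 10, 10), center = (5, 5), center_radius = 2
#   xmin = max(0, 5 - 2) = 3,  xmax = min(10, 5 + 2) = 7   (likewise for y)
# point (4, 4) is still counted as inside; point (1, 1) no longer is.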
......@@ -167,14 +184,16 @@ def generate_anchors_for_grid_cell(feats,
grid_cell_size (float): anchor size
grid_cell_offset (float): The range is between 0 and 1.
Returns:
anchors (List[Tensor]): shape[s, (l, 4)]
num_anchors_list (List[int]): shape[s]
stride_tensor_list (List[Tensor]): shape[s, (l, 1)]
anchors (Tensor): shape[l, 4], "xmin, ymin, xmax, ymax" format.
anchor_points (Tensor): shape[l, 2], "x, y" format.
num_anchors_list (List[int]): shape[s], contains [s_1, s_2, ...].
stride_tensor (Tensor): shape[l, 1], contains the stride for each scale.
"""
assert len(feats) == len(fpn_strides)
anchors = []
anchor_points = []
num_anchors_list = []
stride_tensor_list = []
stride_tensor = []
for feat, stride in zip(feats, fpn_strides):
_, _, h, w = feat.shape
cell_half_size = grid_cell_size * stride * 0.5
......@@ -187,8 +206,19 @@ def generate_anchors_for_grid_cell(feats,
shift_x + cell_half_size, shift_y + cell_half_size
],
axis=-1).astype(feat.dtype)
anchor_point = paddle.stack(
[shift_x, shift_y], axis=-1).astype(feat.dtype)
anchors.append(anchor.reshape([-1, 4]))
anchor_points.append(anchor_point.reshape([-1, 2]))
num_anchors_list.append(len(anchors[-1]))
stride_tensor_list.append(
paddle.full([num_anchors_list[-1], 1], stride))
return anchors, num_anchors_list, stride_tensor_list
stride_tensor.append(
paddle.full(
[num_anchors_list[-1], 1], stride, dtype=feat.dtype))
anchors = paddle.concat(anchors)
anchors.stop_gradient = True
anchor_points = paddle.concat(anchor_points)
anchor_points.stop_gradient = True
stride_tensor = paddle.concat(stride_tensor)
stride_tensor.stop_gradient = True
return anchors, anchor_points, num_anchors_list, stride_tensor
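generate_anchors_for_grid_cell now returns concatenated tensors plus anchor points instead of per-level lists. A shape sketch under assumed inputs (640x640 image, fpn_strides=(32, 16, 8)):

# feats: [N,C,20,20], [N,C,40,40], [N,C,80,80]  ->  l = 400 + 1600 + 6400 = 8400
# anchors:          [8400, 4]  xmin, ymin, xmax, ymax of each scaled grid cell
# anchor_points:    [8400, 2]  cell centers in input-image coordinates
# num_anchors_list: [400, 1600, 6400], used to split back per level when needed
# stride_tensor:    [8400, 1]  per-anchor stride, for (de)normalizing distances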
......@@ -29,6 +29,7 @@ from . import swin_transformer
from . import lcnet
from . import hardnet
from . import esnet
from . import cspresnet
from .vgg import *
from .resnet import *
......@@ -47,3 +48,4 @@ from .swin_transformer import *
from .lcnet import *
from .hardnet import *
from .esnet import *
from .cspresnet import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from ppdet.modeling.ops import get_act_fn
from ppdet.core.workspace import register, serializable
from ..shape_spec import ShapeSpec
__all__ = ['CSPResNet', 'BasicBlock', 'EffectiveSELayer', 'ConvBNLayer']
class ConvBNLayer(nn.Layer):
def __init__(self,
ch_in,
ch_out,
filter_size=3,
stride=1,
groups=1,
padding=0,
act=None):
super(ConvBNLayer, self).__init__()
self.conv = nn.Conv2D(
in_channels=ch_in,
out_channels=ch_out,
kernel_size=filter_size,
stride=stride,
padding=padding,
groups=groups,
bias_attr=False)
self.bn = nn.BatchNorm2D(
ch_out,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
self.act = get_act_fn(act) if act is None or isinstance(act, (
str, dict)) else act
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.act(x)
return x
class RepVggBlock(nn.Layer):
def __init__(self, ch_in, ch_out, act='relu'):
super(RepVggBlock, self).__init__()
self.ch_in = ch_in
self.ch_out = ch_out
self.conv1 = ConvBNLayer(
ch_in, ch_out, 3, stride=1, padding=1, act=None)
self.conv2 = ConvBNLayer(
ch_in, ch_out, 1, stride=1, padding=0, act=None)
self.act = get_act_fn(act) if act is None or isinstance(act, (
str, dict)) else act
def forward(self, x):
if hasattr(self, 'conv'):
y = self.conv(x)
else:
y = self.conv1(x) + self.conv2(x)
y = self.act(y)
return y
def convert_to_deploy(self):
if not hasattr(self, 'conv'):
self.conv = nn.Conv2D(
in_channels=self.ch_in,
out_channels=self.ch_out,
kernel_size=3,
stride=1,
padding=1,
groups=1)
kernel, bias = self.get_equivalent_kernel_bias()
self.conv.weight.set_value(kernel)
self.conv.bias.set_value(bias)
def get_equivalent_kernel_bias(self):
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
return kernel3x3 + self._pad_1x1_to_3x3_tensor(
kernel1x1), bias3x3 + bias1x1
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
if kernel1x1 is None:
return 0
else:
return nn.functional.pad(kernel1x1, [1, 1, 1, 1])
def _fuse_bn_tensor(self, branch):
if branch is None:
return 0, 0
kernel = branch.conv.weight
running_mean = branch.bn._mean
running_var = branch.bn._variance
gamma = branch.bn.weight
beta = branch.bn.bias
eps = branch.bn._epsilon
std = (running_var + eps).sqrt()
t = (gamma / std).reshape((-1, 1, 1, 1))
return kernel * t, beta - running_mean * gamma / std
class BasicBlock(nn.Layer):
def __init__(self, ch_in, ch_out, act='relu', shortcut=True):
super(BasicBlock, self).__init__()
assert ch_in == ch_out
self.conv1 = ConvBNLayer(ch_in, ch_out, 3, stride=1, padding=1, act=act)
self.conv2 = RepVggBlock(ch_out, ch_out, act=act)
self.shortcut = shortcut
def forward(self, x):
y = self.conv1(x)
y = self.conv2(y)
if self.shortcut:
return paddle.add(x, y)
else:
return y
class EffectiveSELayer(nn.Layer):
""" Effective Squeeze-Excitation
From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
"""
def __init__(self, channels, act='hardsigmoid'):
super(EffectiveSELayer, self).__init__()
self.fc = nn.Conv2D(channels, channels, kernel_size=1, padding=0)
self.act = get_act_fn(act) if act is None or isinstance(act, (
str, dict)) else act
def forward(self, x):
x_se = x.mean((2, 3), keepdim=True)
x_se = self.fc(x_se)
return x * self.act(x_se)
class CSPResStage(nn.Layer):
def __init__(self,
block_fn,
ch_in,
ch_out,
n,
stride,
act='relu',
attn='eca'):
super(CSPResStage, self).__init__()
ch_mid = (ch_in + ch_out) // 2
if stride == 2:
self.conv_down = ConvBNLayer(
ch_in, ch_mid, 3, stride=2, padding=1, act=act)
else:
self.conv_down = None
self.conv1 = ConvBNLayer(ch_mid, ch_mid // 2, 1, act=act)
self.conv2 = ConvBNLayer(ch_mid, ch_mid // 2, 1, act=act)
self.blocks = nn.Sequential(* [
block_fn(
ch_mid // 2, ch_mid // 2, act=act, shortcut=True)
for i in range(n)
])
if attn:
self.attn = EffectiveSELayer(ch_mid, act='hardsigmoid')
else:
self.attn = None
self.conv3 = ConvBNLayer(ch_mid, ch_out, 1, act=act)
def forward(self, x):
if self.conv_down is not None:
x = self.conv_down(x)
y1 = self.conv1(x)
y2 = self.blocks(self.conv2(x))
y = paddle.concat([y1, y2], axis=1)
if self.attn is not None:
y = self.attn(y)
y = self.conv3(y)
return y
@register
@serializable
class CSPResNet(nn.Layer):
__shared__ = ['width_mult', 'depth_mult', 'trt']
def __init__(self,
layers=[3, 6, 6, 3],
channels=[64, 128, 256, 512, 1024],
act='swish',
return_idx=[0, 1, 2, 3, 4],
depth_wise=False,
use_large_stem=False,
width_mult=1.0,
depth_mult=1.0,
trt=False):
super(CSPResNet, self).__init__()
channels = [max(round(c * width_mult), 1) for c in channels]
layers = [max(round(l * depth_mult), 1) for l in layers]
act = get_act_fn(
act, trt=trt) if act is None or isinstance(act,
(str, dict)) else act
if use_large_stem:
self.stem = nn.Sequential(
('conv1', ConvBNLayer(
3, channels[0] // 2, 3, stride=2, padding=1, act=act)),
('conv2', ConvBNLayer(
channels[0] // 2,
channels[0] // 2,
3,
stride=1,
padding=1,
act=act)), ('conv3', ConvBNLayer(
channels[0] // 2,
channels[0],
3,
stride=1,
padding=1,
act=act)))
else:
self.stem = nn.Sequential(
('conv1', ConvBNLayer(
3, channels[0] // 2, 3, stride=2, padding=1, act=act)),
('conv2', ConvBNLayer(
channels[0] // 2,
channels[0],
3,
stride=1,
padding=1,
act=act)))
n = len(channels) - 1
self.stages = nn.Sequential(* [(str(i), CSPResStage(
BasicBlock, channels[i], channels[i + 1], layers[i], 2, act=act))
for i in range(n)])
self._out_channels = channels[1:]
self._out_strides = [4, 8, 16, 32]
self.return_idx = return_idx
def forward(self, inputs):
x = inputs['image']
x = self.stem(x)
outs = []
for idx, stage in enumerate(self.stages):
x = stage(x)
if idx in self.return_idx:
outs.append(x)
return outs
@property
def out_shape(self):
return [
ShapeSpec(
channels=self._out_channels[i], stride=self._out_strides[i])
for i in self.return_idx
]
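RepVggBlock.convert_to_deploy above folds the 3x3+BN and 1x1+BN branches into a single 3x3 conv. A minimal numerical check of that equivalence, assuming the classes in this file (BN must be in eval mode for the fusion to hold):

import paddle
block = RepVggBlock(32, 32, act='relu')
block.eval()                                    # use BN running stats
x = paddle.randn([1, 32, 16, 16])
y_train = block(x)                              # conv1(x) + conv2(x) branch sum
block.convert_to_deploy()                       # fuse both branches into block.conv
y_deploy = block(x)                             # single 3x3 conv path
print(float((y_train - y_deploy).abs().max()))  # expected to be ~1e-6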
......@@ -77,8 +77,8 @@ class ConvBNLayer(nn.Layer):
out = self.batch_norm(out)
if self.act == 'leaky':
out = F.leaky_relu(out, 0.1)
elif self.act == 'mish':
out = mish(out)
else:
out = getattr(F, self.act)(out)
return out
......@@ -156,7 +156,7 @@ class BasicBlock(nn.Layer):
# channel route: 10-->5 --> 5-->10
self.conv1 = ConvBNLayer(
ch_in=ch_in,
ch_out=int(ch_out/2),
ch_out=int(ch_out / 2),
filter_size=1,
stride=1,
padding=0,
......@@ -165,8 +165,8 @@ class BasicBlock(nn.Layer):
freeze_norm=freeze_norm,
data_format=data_format)
self.conv2 = ConvBNLayer(
ch_in=int(ch_out/2),
ch_out=ch_out ,
ch_in=int(ch_out / 2),
ch_out=ch_out,
filter_size=3,
stride=1,
padding=1,
......@@ -317,7 +317,7 @@ class DarkNet(nn.Layer):
down_name,
DownSample(
ch_in=int(ch_in[i]),
ch_out=int(ch_in[i+1]),
ch_out=int(ch_in[i + 1]),
norm_type=norm_type,
norm_decay=norm_decay,
freeze_norm=freeze_norm,
......
......@@ -744,9 +744,9 @@ def distance2bbox(points, distance, max_shape=None):
def bbox_center(boxes):
"""Get bbox centers from boxes.
Args:
boxes (Tensor): boxes with shape (N, 4), "xmin, ymin, xmax, ymax" format.
boxes (Tensor): boxes with shape (..., 4), "xmin, ymin, xmax, ymax" format.
Returns:
Tensor: boxes centers with shape (N, 2), "cx, cy" format.
Tensor: boxes centers with shape (..., 2), "cx, cy" format.
"""
boxes_cx = (boxes[..., 0] + boxes[..., 2]) / 2
boxes_cy = (boxes[..., 1] + boxes[..., 3]) / 2
......@@ -782,7 +782,7 @@ def delta2bbox_v2(rois,
means=(0.0, 0.0, 0.0, 0.0),
stds=(1.0, 1.0, 1.0, 1.0),
max_shape=None,
wh_ratio_clip=16.0/1000.0,
wh_ratio_clip=16.0 / 1000.0,
ctr_clip=None):
"""Transform network output(delta) to bboxes.
Based on https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/
......@@ -860,7 +860,7 @@ def bbox2delta_v2(src_boxes,
dw = paddle.log(tgt_w / src_w)
dh = paddle.log(tgt_h / src_h)
deltas = paddle.stack((dx, dy, dw, dh), axis=1) # [n, 4]
deltas = paddle.stack((dx, dy, dw, dh), axis=1) # [n, 4]
means = paddle.to_tensor(means, place=src_boxes.place)
stds = paddle.to_tensor(stds, place=src_boxes.place)
deltas = (deltas - means) / stds
......
......@@ -32,6 +32,7 @@ from . import detr_head
from . import sparsercnn_head
from . import tood_head
from . import retina_head
from . import ppyolo_head
from .bbox_head import *
from .mask_head import *
......@@ -53,3 +54,4 @@ from .detr_head import *
from .sparsercnn_head import *
from .tood_head import *
from .retina_head import *
from .ppyolo_head import *
......@@ -464,12 +464,13 @@ class PicoHeadV2(GFLHead):
assert len(fpn_feats) == len(
self.fpn_stride
), "The size of fpn_feats is not equal to size of fpn_stride"
anchors, num_anchors_list, stride_tensor_list = generate_anchors_for_grid_cell(
anchors, _, num_anchors_list, stride_tensor_list = generate_anchors_for_grid_cell(
fpn_feats, self.fpn_stride, self.grid_cell_scale, self.cell_offset)
anchors_split = paddle.split(anchors, num_anchors_list)
cls_score_list, reg_list, box_list = [], [], []
for i, fpn_feat, anchor, stride, align_cls in zip(
range(len(self.fpn_stride)), fpn_feats, anchors,
range(len(self.fpn_stride)), fpn_feats, anchors_split,
self.fpn_stride, self.cls_align):
b, _, h, w = get_static_shape(fpn_feat)
# task decomposition
......@@ -507,10 +508,6 @@ class PicoHeadV2(GFLHead):
cls_score_list = paddle.concat(cls_score_list, axis=1)
box_list = paddle.concat(box_list, axis=1)
reg_list = paddle.concat(reg_list, axis=1)
anchors = paddle.concat(anchors)
anchors.stop_gradient = True
stride_tensor_list = paddle.concat(stride_tensor_list)
stride_tensor_list.stop_gradient = True
return cls_score_list, reg_list, box_list, anchors, num_anchors_list, stride_tensor_list
def get_loss(self, head_outs, gt_meta):
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ..bbox_utils import batch_distance2bbox
from ..losses import GIoULoss
from ..initializer import bias_init_with_prob, constant_, normal_
from ..assigners.utils import generate_anchors_for_grid_cell
from ppdet.modeling.backbones.cspresnet import ConvBNLayer
from ppdet.modeling.ops import get_static_shape, paddle_distributed_is_initialized, get_act_fn
__all__ = ['PPYOLOHead']
class ESEAttn(nn.Layer):
def __init__(self, feat_channels, act='swish'):
super(ESEAttn, self).__init__()
self.fc = nn.Conv2D(feat_channels, feat_channels, 1)
self.conv = ConvBNLayer(feat_channels, feat_channels, 1, act=act)
self._init_weights()
def _init_weights(self):
normal_(self.fc.weight, std=0.001)
def forward(self, feat, avg_feat):
weight = F.sigmoid(self.fc(avg_feat))
return self.conv(feat * weight)
@register
class PPYOLOHead(nn.Layer):
__shared__ = ['num_classes', 'trt']
__inject__ = ['static_assigner', 'assigner', 'nms']
def __init__(self,
in_channels=[1024, 512, 256],
num_classes=80,
act='swish',
fpn_strides=(32, 16, 8),
grid_cell_scale=5.0,
grid_cell_offset=0.5,
reg_max=16,
static_assigner_epoch=4,
use_varifocal_loss=True,
static_assigner='ATSSAssigner',
assigner='TaskAlignedAssigner',
nms='MultiClassNMS',
eval_input_size=[],
loss_weight={
'class': 1.0,
'iou': 2.5,
'dfl': 0.5,
},
trt=False):
super(PPYOLOHead, self).__init__()
assert len(in_channels) > 0, "len(in_channels) should > 0"
self.in_channels = in_channels
self.num_classes = num_classes
self.fpn_strides = fpn_strides
self.grid_cell_scale = grid_cell_scale
self.grid_cell_offset = grid_cell_offset
self.reg_max = reg_max
self.iou_loss = GIoULoss()
self.loss_weight = loss_weight
self.use_varifocal_loss = use_varifocal_loss
self.eval_input_size = eval_input_size
self.static_assigner_epoch = static_assigner_epoch
self.static_assigner = static_assigner
self.assigner = assigner
self.nms = nms
# stem
self.stem_cls = nn.LayerList()
self.stem_reg = nn.LayerList()
act = get_act_fn(
act, trt=trt) if act is None or isinstance(act,
(str, dict)) else act
for in_c in self.in_channels:
self.stem_cls.append(ESEAttn(in_c, act=act))
self.stem_reg.append(ESEAttn(in_c, act=act))
# pred head
self.pred_cls = nn.LayerList()
self.pred_reg = nn.LayerList()
for in_c in self.in_channels:
self.pred_cls.append(
nn.Conv2D(
in_c, self.num_classes, 3, padding=1))
self.pred_reg.append(
nn.Conv2D(
in_c, 4 * (self.reg_max + 1), 3, padding=1))
# projection conv
self.proj_conv = nn.Conv2D(self.reg_max + 1, 1, 1, bias_attr=False)
self._init_weights()
@classmethod
def from_config(cls, cfg, input_shape):
return {'in_channels': [i.channels for i in input_shape], }
def _init_weights(self):
bias_cls = bias_init_with_prob(0.01)
for cls_, reg_ in zip(self.pred_cls, self.pred_reg):
constant_(cls_.weight)
constant_(cls_.bias, bias_cls)
constant_(reg_.weight)
constant_(reg_.bias, 1.0)
self.proj = paddle.linspace(0, self.reg_max, self.reg_max + 1)
self.proj_conv.weight.set_value(
self.proj.reshape([1, self.reg_max + 1, 1, 1]))
self.proj_conv.weight.stop_gradient = True
if self.eval_input_size:
anchor_points, stride_tensor = self._generate_anchors()
self.register_buffer('anchor_points', anchor_points)
self.register_buffer('stride_tensor', stride_tensor)
def forward_train(self, feats, targets):
anchors, anchor_points, num_anchors_list, stride_tensor = \
generate_anchors_for_grid_cell(
feats, self.fpn_strides, self.grid_cell_scale,
self.grid_cell_offset)
cls_score_list, reg_distri_list = [], []
for i, feat in enumerate(feats):
avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
cls_logit = self.pred_cls[i](self.stem_cls[i](feat, avg_feat) +
feat)
reg_distri = self.pred_reg[i](self.stem_reg[i](feat, avg_feat))
# cls and reg
cls_score = F.sigmoid(cls_logit)
cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
reg_distri_list.append(reg_distri.flatten(2).transpose([0, 2, 1]))
cls_score_list = paddle.concat(cls_score_list, axis=1)
reg_distri_list = paddle.concat(reg_distri_list, axis=1)
return self.get_loss([
cls_score_list, reg_distri_list, anchors, anchor_points,
num_anchors_list, stride_tensor
], targets)
def _generate_anchors(self, feats=None):
# just use in eval time
anchor_points = []
stride_tensor = []
for i, stride in enumerate(self.fpn_strides):
if feats is not None:
_, _, h, w = feats[i].shape
else:
h = int(self.eval_input_size[0] / stride)
w = int(self.eval_input_size[1] / stride)
shift_x = paddle.arange(end=w) + self.grid_cell_offset
shift_y = paddle.arange(end=h) + self.grid_cell_offset
shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
anchor_point = paddle.cast(
paddle.stack(
[shift_x, shift_y], axis=-1), dtype='float32')
anchor_points.append(anchor_point.reshape([-1, 2]))
stride_tensor.append(
paddle.full(
[h * w, 1], stride, dtype='float32'))
anchor_points = paddle.concat(anchor_points)
stride_tensor = paddle.concat(stride_tensor)
return anchor_points, stride_tensor
def forward_eval(self, feats):
if self.eval_input_size:
anchor_points, stride_tensor = self.anchor_points, self.stride_tensor
else:
anchor_points, stride_tensor = self._generate_anchors(feats)
cls_score_list, reg_dist_list = [], []
for i, feat in enumerate(feats):
b, _, h, w = feat.shape
l = h * w
avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
cls_logit = self.pred_cls[i](self.stem_cls[i](feat, avg_feat) +
feat)
reg_dist = self.pred_reg[i](self.stem_reg[i](feat, avg_feat))
reg_dist = reg_dist.reshape([-1, 4, self.reg_max + 1, l]).transpose(
[0, 2, 1, 3])
reg_dist = self.proj_conv(F.softmax(reg_dist, axis=1))
# cls and reg
cls_score = F.sigmoid(cls_logit)
cls_score_list.append(cls_score.reshape([b, self.num_classes, l]))
reg_dist_list.append(reg_dist.reshape([b, 4, l]))
cls_score_list = paddle.concat(cls_score_list, axis=-1)
reg_dist_list = paddle.concat(reg_dist_list, axis=-1)
return cls_score_list, reg_dist_list, anchor_points, stride_tensor
def forward(self, feats, targets=None):
assert len(feats) == len(self.fpn_strides), \
"The size of feats is not equal to size of fpn_strides"
if self.training:
return self.forward_train(feats, targets)
else:
return self.forward_eval(feats)
@staticmethod
def _focal_loss(score, label, alpha=0.25, gamma=2.0):
weight = (score - label).pow(gamma)
if alpha > 0:
alpha_t = alpha * label + (1 - alpha) * (1 - label)
weight *= alpha_t
loss = F.binary_cross_entropy(
score, label, weight=weight, reduction='sum')
return loss
@staticmethod
def _varifocal_loss(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
weight = alpha * pred_score.pow(gamma) * (1 - label) + gt_score * label
loss = F.binary_cross_entropy(
pred_score, gt_score, weight=weight, reduction='sum')
return loss
def _bbox_decode(self, anchor_points, pred_dist):
b, l, _ = get_static_shape(pred_dist)
pred_dist = F.softmax(pred_dist.reshape([b, l, 4, self.reg_max + 1
])).matmul(self.proj)
return batch_distance2bbox(anchor_points, pred_dist)
def _bbox2distance(self, points, bbox):
x1y1, x2y2 = paddle.split(bbox, 2, -1)
lt = points - x1y1
rb = x2y2 - points
return paddle.concat([lt, rb], -1).clip(0, self.reg_max - 0.01)
def _df_loss(self, pred_dist, target):
target_left = paddle.cast(target, 'int64')
target_right = target_left + 1
weight_left = target_right.astype('float32') - target
weight_right = 1 - weight_left
loss_left = F.cross_entropy(
pred_dist, target_left, reduction='none') * weight_left
loss_right = F.cross_entropy(
pred_dist, target_right, reduction='none') * weight_right
return (loss_left + loss_right).mean(-1, keepdim=True)
def _bbox_loss(self, pred_dist, pred_bboxes, anchor_points, assigned_labels,
assigned_bboxes, assigned_scores, assigned_scores_sum):
# select positive samples mask
mask_positive = (assigned_labels != self.num_classes)
num_pos = mask_positive.sum()
# pos/neg loss
if num_pos > 0:
# l1 + iou
bbox_mask = mask_positive.unsqueeze(-1).tile([1, 1, 4])
pred_bboxes_pos = paddle.masked_select(pred_bboxes,
bbox_mask).reshape([-1, 4])
assigned_bboxes_pos = paddle.masked_select(
assigned_bboxes, bbox_mask).reshape([-1, 4])
bbox_weight = paddle.masked_select(
assigned_scores.sum(-1), mask_positive).unsqueeze(-1)
loss_l1 = F.l1_loss(pred_bboxes_pos, assigned_bboxes_pos)
loss_iou = self.iou_loss(pred_bboxes_pos,
assigned_bboxes_pos) * bbox_weight
loss_iou = loss_iou.sum() / assigned_scores_sum
dist_mask = mask_positive.unsqueeze(-1).tile(
[1, 1, (self.reg_max + 1) * 4])
pred_dist_pos = paddle.masked_select(
pred_dist, dist_mask).reshape([-1, 4, self.reg_max + 1])
assigned_ltrb = self._bbox2distance(anchor_points, assigned_bboxes)
assigned_ltrb_pos = paddle.masked_select(
assigned_ltrb, bbox_mask).reshape([-1, 4])
loss_dfl = self._df_loss(pred_dist_pos,
assigned_ltrb_pos) * bbox_weight
loss_dfl = loss_dfl.sum() / assigned_scores_sum
else:
loss_l1 = paddle.zeros([1])
loss_iou = paddle.zeros([1])
loss_dfl = paddle.zeros([1])
return loss_l1, loss_iou, loss_dfl
def get_loss(self, head_outs, gt_meta):
pred_scores, pred_distri, anchors,\
anchor_points, num_anchors_list, stride_tensor = head_outs
anchor_points_s = anchor_points / stride_tensor
pred_bboxes = self._bbox_decode(anchor_points_s, pred_distri)
gt_labels = gt_meta['gt_class']
gt_bboxes = gt_meta['gt_bbox']
pad_gt_mask = gt_meta['pad_gt_mask']
# label assignment
if gt_meta['epoch_id'] < self.static_assigner_epoch:
assigned_labels, assigned_bboxes, assigned_scores = \
self.static_assigner(
anchors,
num_anchors_list,
gt_labels,
gt_bboxes,
pad_gt_mask,
bg_index=self.num_classes,
pred_bboxes=pred_bboxes.detach() * stride_tensor)
alpha_l = 0.25
else:
assigned_labels, assigned_bboxes, assigned_scores = \
self.assigner(
pred_scores.detach(),
pred_bboxes.detach() * stride_tensor,
anchor_points,
num_anchors_list,
gt_labels,
gt_bboxes,
pad_gt_mask,
bg_index=self.num_classes)
alpha_l = -1
# rescale bbox
assigned_bboxes /= stride_tensor
# cls loss
if self.use_varifocal_loss:
one_hot_label = F.one_hot(assigned_labels, self.num_classes)
loss_cls = self._varifocal_loss(pred_scores, assigned_scores,
one_hot_label)
else:
loss_cls = self._focal_loss(
pred_scores, assigned_scores, alpha=alpha_l)
assigned_scores_sum = assigned_scores.sum()
if paddle_distributed_is_initialized():
paddle.distributed.all_reduce(assigned_scores_sum)
assigned_scores_sum = paddle.clip(
assigned_scores_sum / paddle.distributed.get_world_size(),
min=1)
loss_cls /= assigned_scores_sum
loss_l1, loss_iou, loss_dfl = \
self._bbox_loss(pred_distri, pred_bboxes, anchor_points_s,
assigned_labels, assigned_bboxes, assigned_scores,
assigned_scores_sum)
loss = self.loss_weight['class'] * loss_cls + \
self.loss_weight['iou'] * loss_iou + \
self.loss_weight['dfl'] * loss_dfl
out_dict = {
'loss': loss,
'loss_cls': loss_cls,
'loss_iou': loss_iou,
'loss_dfl': loss_dfl,
'loss_l1': loss_l1,
}
return out_dict
def post_process(self, head_outs, img_shape, scale_factor):
pred_scores, pred_dist, anchor_points, stride_tensor = head_outs
pred_bboxes = batch_distance2bbox(anchor_points,
pred_dist.transpose([0, 2, 1]))
pred_bboxes *= stride_tensor
# scale bbox to origin
scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
scale_factor = paddle.concat(
[scale_x, scale_y, scale_x, scale_y], axis=-1).reshape([-1, 1, 4])
pred_bboxes /= scale_factor
bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
return bbox_pred, bbox_num
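PPYOLOHead regresses each box side as a discrete distribution over reg_max+1 bins and decodes it by taking the expectation; this is what proj_conv and _bbox_decode compute. A minimal sketch with assumed sizes:

import paddle
import paddle.nn.functional as F
reg_max = 16
proj = paddle.linspace(0, reg_max, reg_max + 1)  # bin values [0, 1, ..., 16]
logits = paddle.randn([1, 1, 4, reg_max + 1])    # [batch, anchors, 4 sides, bins]
dist = F.softmax(logits, axis=-1).matmul(proj)   # expected l/t/r/b distance per side
# batch_distance2bbox then turns (cx, cy) and (l, t, r, b) into xmin/ymin/xmax/ymax.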
......@@ -218,13 +218,17 @@ class TOODHead(nn.Layer):
assert len(feats) == len(self.fpn_strides), \
"The size of feats is not equal to size of fpn_strides"
anchors, num_anchors_list, stride_tensor_list = generate_anchors_for_grid_cell(
anchors, anchor_points, num_anchors_list, stride_tensor =\
generate_anchors_for_grid_cell(
feats, self.fpn_strides, self.grid_cell_scale,
self.grid_cell_offset)
anchor_centers_split = paddle.split(anchor_points / stride_tensor,
num_anchors_list)
cls_score_list, bbox_pred_list = [], []
for feat, scale_reg, anchor, stride in zip(feats, self.scales_regs,
anchors, self.fpn_strides):
for feat, scale_reg, anchor_centers, stride in zip(
feats, self.scales_regs, anchor_centers_split,
self.fpn_strides):
b, _, h, w = get_static_shape(feat)
inter_feats = []
for inter_conv in self.inter_convs:
......@@ -250,8 +254,8 @@ class TOODHead(nn.Layer):
# reg prediction and alignment
reg_dist = scale_reg(self.tood_reg(reg_feat).exp())
reg_dist = reg_dist.flatten(2).transpose([0, 2, 1])
anchor_centers = bbox_center(anchor).unsqueeze(0) / stride
reg_bbox = batch_distance2bbox(anchor_centers, reg_dist)
reg_bbox = batch_distance2bbox(
anchor_centers.unsqueeze(0), reg_dist)
if self.use_align_head:
reg_offset = F.relu(self.reg_offset_conv1(feat))
reg_offset = self.reg_offset_conv2(reg_offset)
......@@ -268,12 +272,8 @@ class TOODHead(nn.Layer):
bbox_pred_list.append(bbox_pred)
cls_score_list = paddle.concat(cls_score_list, axis=1)
bbox_pred_list = paddle.concat(bbox_pred_list, axis=1)
anchors = paddle.concat(anchors)
anchors.stop_gradient = True
stride_tensor_list = paddle.concat(stride_tensor_list).unsqueeze(0)
stride_tensor_list.stop_gradient = True
return cls_score_list, bbox_pred_list, anchors, num_anchors_list, stride_tensor_list
return cls_score_list, bbox_pred_list, anchors, num_anchors_list, stride_tensor
@staticmethod
def _focal_loss(score, label, alpha=0.25, gamma=2.0):
......@@ -287,7 +287,7 @@ class TOODHead(nn.Layer):
def get_loss(self, head_outs, gt_meta):
pred_scores, pred_bboxes, anchors, \
num_anchors_list, stride_tensor_list = head_outs
num_anchors_list, stride_tensor = head_outs
gt_labels = gt_meta['gt_class']
gt_bboxes = gt_meta['gt_bbox']
pad_gt_mask = gt_meta['pad_gt_mask']
......@@ -304,7 +304,7 @@ class TOODHead(nn.Layer):
else:
assigned_labels, assigned_bboxes, assigned_scores = self.assigner(
pred_scores.detach(),
pred_bboxes.detach() * stride_tensor_list,
pred_bboxes.detach() * stride_tensor,
bbox_center(anchors),
num_anchors_list,
gt_labels,
......@@ -314,7 +314,7 @@ class TOODHead(nn.Layer):
alpha_l = -1
# rescale bbox
assigned_bboxes /= stride_tensor_list
assigned_bboxes /= stride_tensor
# classification loss
loss_cls = self._focal_loss(pred_scores, assigned_scores, alpha=alpha_l)
# select positive samples mask
......
......@@ -251,7 +251,7 @@ class LiteConv(nn.Layer):
class DropBlock(nn.Layer):
def __init__(self, block_size, keep_prob, name, data_format='NCHW'):
def __init__(self, block_size, keep_prob, name=None, data_format='NCHW'):
"""
DropBlock layer, see https://arxiv.org/abs/1810.12890
......
......@@ -20,6 +20,7 @@ from . import centernet_fpn
from . import bifpn
from . import csp_pan
from . import es_pan
from . import custom_pan
from .fpn import *
from .yolo_fpn import *
......@@ -30,3 +31,4 @@ from .blazeface_fpn import *
from .bifpn import *
from .csp_pan import *
from .es_pan import *
from .custom_pan import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import DropBlock
from ppdet.modeling.ops import get_act_fn
from ..backbones.cspresnet import ConvBNLayer, BasicBlock
from ..shape_spec import ShapeSpec
__all__ = ['CustomCSPPAN']
class SPP(nn.Layer):
def __init__(self,
ch_in,
ch_out,
k,
pool_size,
act='swish',
data_format='NCHW'):
super(SPP, self).__init__()
self.pool = []
self.data_format = data_format
for i, size in enumerate(pool_size):
pool = self.add_sublayer(
'pool{}'.format(i),
nn.MaxPool2D(
kernel_size=size,
stride=1,
padding=size // 2,
data_format=data_format,
ceil_mode=False))
self.pool.append(pool)
self.conv = ConvBNLayer(ch_in, ch_out, k, padding=k // 2, act=act)
def forward(self, x):
outs = [x]
for pool in self.pool:
outs.append(pool(x))
if self.data_format == 'NCHW':
y = paddle.concat(outs, axis=1)
else:
y = paddle.concat(outs, axis=-1)
y = self.conv(y)
return y
class CSPStage(nn.Layer):
def __init__(self, block_fn, ch_in, ch_out, n, act='swish', spp=False):
super(CSPStage, self).__init__()
ch_mid = int(ch_out // 2)
self.conv1 = ConvBNLayer(ch_in, ch_mid, 1, act=act)
self.conv2 = ConvBNLayer(ch_in, ch_mid, 1, act=act)
self.convs = nn.Sequential()
next_ch_in = ch_mid
for i in range(n):
self.convs.add_sublayer(
str(i),
eval(block_fn)(next_ch_in, ch_mid, act=act, shortcut=False))
if i == (n - 1) // 2 and spp:
self.convs.add_sublayer(
'spp', SPP(ch_mid * 4, ch_mid, 1, [5, 9, 13], act=act))
next_ch_in = ch_mid
self.conv3 = ConvBNLayer(ch_mid * 2, ch_out, 1, act=act)
def forward(self, x):
y1 = self.conv1(x)
y2 = self.conv2(x)
y2 = self.convs(y2)
y = paddle.concat([y1, y2], axis=1)
y = self.conv3(y)
return y
@register
@serializable
class CustomCSPPAN(nn.Layer):
__shared__ = ['norm_type', 'data_format', 'width_mult', 'depth_mult', 'trt']
def __init__(self,
in_channels=[256, 512, 1024],
out_channels=[1024, 512, 256],
norm_type='bn',
act='leaky',
stage_fn='CSPStage',
block_fn='BasicBlock',
stage_num=1,
block_num=3,
drop_block=False,
block_size=3,
keep_prob=0.9,
spp=False,
data_format='NCHW',
width_mult=1.0,
depth_mult=1.0,
trt=False):
super(CustomCSPPAN, self).__init__()
out_channels = [max(round(c * width_mult), 1) for c in out_channels]
block_num = max(round(block_num * depth_mult), 1)
act = get_act_fn(
act, trt=trt) if act is None or isinstance(act,
(str, dict)) else act
self.num_blocks = len(in_channels)
self.data_format = data_format
self._out_channels = out_channels
in_channels = in_channels[::-1]
fpn_stages = []
fpn_routes = []
for i, (ch_in, ch_out) in enumerate(zip(in_channels, out_channels)):
if i > 0:
ch_in += ch_pre // 2
stage = nn.Sequential()
for j in range(stage_num):
stage.add_sublayer(
str(j),
eval(stage_fn)(block_fn,
ch_in if j == 0 else ch_out,
ch_out,
block_num,
act=act,
spp=(spp and i == 0)))
if drop_block:
stage.add_sublayer('drop', DropBlock(block_size, keep_prob))
fpn_stages.append(stage)
if i < self.num_blocks - 1:
fpn_routes.append(
ConvBNLayer(
ch_in=ch_out,
ch_out=ch_out // 2,
filter_size=1,
stride=1,
padding=0,
act=act))
ch_pre = ch_out
self.fpn_stages = nn.LayerList(fpn_stages)
self.fpn_routes = nn.LayerList(fpn_routes)
pan_stages = []
pan_routes = []
for i in reversed(range(self.num_blocks - 1)):
pan_routes.append(
ConvBNLayer(
ch_in=out_channels[i + 1],
ch_out=out_channels[i + 1],
filter_size=3,
stride=2,
padding=1,
act=act))
ch_in = out_channels[i] + out_channels[i + 1]
ch_out = out_channels[i]
stage = nn.Sequential()
for j in range(stage_num):
stage.add_sublayer(
str(j),
eval(stage_fn)(block_fn,
ch_in if j == 0 else ch_out,
ch_out,
block_num,
act=act,
spp=False))
if drop_block:
stage.add_sublayer('drop', DropBlock(block_size, keep_prob))
pan_stages.append(stage)
self.pan_stages = nn.LayerList(pan_stages[::-1])
self.pan_routes = nn.LayerList(pan_routes[::-1])
def forward(self, blocks, for_mot=False):
blocks = blocks[::-1]
fpn_feats = []
for i, block in enumerate(blocks):
if i > 0:
block = paddle.concat([route, block], axis=1)
route = self.fpn_stages[i](block)
fpn_feats.append(route)
if i < self.num_blocks - 1:
route = self.fpn_routes[i](route)
route = F.interpolate(
route, scale_factor=2., data_format=self.data_format)
pan_feats = [fpn_feats[-1], ]
route = fpn_feats[-1]
for i in reversed(range(self.num_blocks - 1)):
block = fpn_feats[i]
route = self.pan_routes[i](route)
block = paddle.concat([route, block], axis=1)
route = self.pan_stages[i](block)
pan_feats.append(route)
return pan_feats[::-1]
@classmethod
def from_config(cls, cfg, input_shape):
return {'in_channels': [i.channels for i in input_shape], }
@property
def out_shape(self):
return [ShapeSpec(channels=c) for c in self._out_channels]
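CustomCSPPAN runs a top-down FPN pass followed by a bottom-up PAN pass, with CSPStage blocks at every fusion point. A shape-flow sketch under assumed config (in_channels=[256, 512, 1024], out_channels=[1024, 512, 256], width/depth_mult=1.0):

# top-down:  C5 [N,1024,20,20] -> P5 [N,1024,20,20]
#            route 1x1 (1024->512) + 2x upsample, concat C4 -> P4 [N,512,40,40]
#            route 1x1 (512->256)  + 2x upsample, concat C3 -> P3 [N,256,80,80]
# bottom-up: P3 -> stride-2 3x3 conv, concat P4 -> N4; N4 -> stride-2 conv, concat P5 -> N5
# returned coarse to fine: [N,1024,20,20], [N,512,40,40], [N,256,80,80], strides [32,16,8]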
......@@ -114,7 +114,7 @@ class SPP(nn.Layer):
ch_out,
k,
pool_size,
norm_type,
norm_type='bn',
freeze_norm=False,
name='',
act='leaky',
......
......@@ -20,28 +20,57 @@ from paddle.regularizer import L2Decay
from paddle.fluid.framework import Variable, in_dygraph_mode
from paddle.fluid import core
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype
__all__ = [
'roi_pool',
'roi_align',
'prior_box',
'generate_proposals',
'iou_similarity',
'box_coder',
'yolo_box',
'multiclass_nms',
'distribute_fpn_proposals',
'collect_fpn_proposals',
'matrix_nms',
'batch_norm',
'mish',
'roi_pool', 'roi_align', 'prior_box', 'generate_proposals',
'iou_similarity', 'box_coder', 'yolo_box', 'multiclass_nms',
'distribute_fpn_proposals', 'collect_fpn_proposals', 'matrix_nms',
'batch_norm', 'mish', 'swish', 'identity'
]
def identity(x):
return x
def mish(x):
return x * paddle.tanh(F.softplus(x))
return F.mish(x) if hasattr(F, 'mish') else x * F.tanh(F.softplus(x))
def swish(x):
return x * F.sigmoid(x)
TRT_ACT_SPEC = {'swish': swish}
ACT_SPEC = {'mish': mish}
def get_act_fn(act=None, trt=False):
assert act is None or isinstance(act, (
str, dict)), 'name of activation should be str, dict or None'
if not act:
return identity
if isinstance(act, dict):
name = act['name']
act.pop('name')
kwargs = act
else:
name = act
kwargs = dict()
if trt and name in TRT_ACT_SPEC:
fn = TRT_ACT_SPEC[name]
elif name in ACT_SPEC:
fn = ACT_SPEC[name]
else:
fn = getattr(F, name)
return lambda x: fn(x, **kwargs)
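get_act_fn resolves a string or dict spec into a callable, preferring the hand-written TRT-friendly versions when trt=True. A usage sketch (the specs shown are assumptions):

import paddle
act = get_act_fn('mish')                         # F.mish when available, else the fallback above
act = get_act_fn({'name': 'leaky_relu', 'negative_slope': 0.1})  # extra keys become kwargs
act = get_act_fn('swish', trt=True)              # plain x * sigmoid(x), safe for TensorRT export
y = act(paddle.randn([2, 8, 4, 4]))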
def batch_norm(ch,
......@@ -1602,3 +1631,8 @@ def get_static_shape(tensor):
shape = paddle.shape(tensor)
shape.stop_gradient = True
return shape
def paddle_distributed_is_initialized():
return core.is_compiled_with_dist(
) and parallel_helper._is_parallel_ctx_initialized()
......@@ -43,7 +43,7 @@ class CosineDecay(object):
the max_iters is much larger than the warmup iter
"""
def __init__(self, max_epochs=1000, use_warmup=True, eta_min=0):
def __init__(self, max_epochs=1000, use_warmup=True, eta_min=0.):
self.max_epochs = max_epochs
self.use_warmup = use_warmup
self.eta_min = eta_min
......@@ -65,6 +65,7 @@ class CosineDecay(object):
decayed_lr = base_lr * 0.5 * (math.cos(
(i - warmup_iters) * math.pi /
(max_iters - warmup_iters)) + 1)
decayed_lr = decayed_lr if decayed_lr > self.eta_min else self.eta_min
value.append(decayed_lr)
return optimizer.lr.PiecewiseDecay(boundary, value)
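The new eta_min acts as a floor on the cosine schedule. In plain terms, for an iteration i past warmup (symbols as in the code above):

#   lr(i) = max(eta_min, 0.5 * base_lr * (1 + cos(pi * (i - warmup_iters) / (max_iters - warmup_iters))))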
......@@ -110,7 +111,7 @@ class PiecewiseDecay(object):
boundary = [int(step_per_epoch) * i for i in self.milestones]
value = [base_lr] # during step[0, boundary[0]] is base_lr
# self.values is setted directly in config
# self.values is set directly in config
if self.values is not None:
assert len(self.milestones) + 1 == len(self.values)
return optimizer.lr.PiecewiseDecay(boundary, self.values)
......@@ -201,7 +202,7 @@ class LearningRate(object):
return self.schedulers[0](base_lr=self.base_lr,
step_per_epoch=step_per_epoch)
# TODO: split warmup & decay
# TODO: split warmup & decay
# warmup
boundary, value = self.schedulers[1](self.base_lr, step_per_epoch)
# decay
......@@ -299,10 +300,10 @@ class ModelEMA(object):
Ema's parameter are updated with the formula:
`ema_param = decay * ema_param + (1 - decay) * cur_param`.
Defaults is 0.9998.
use_thres_step (bool): Whether set decay by thres_step or not
cycle_epoch (int): The epoch of interval to reset ema_param and
use_thres_step (bool): Whether to set decay by thres_step or not
cycle_epoch (int): The epoch of interval to reset ema_param and
step. Defaults is -1, which means not reset. Its function is to
add a regular effect to ema, which is set according to experience
add a regular effect to ema, which is set according to experience
and is effective when the total training epoch is large.
"""
......@@ -331,6 +332,11 @@ class ModelEMA(object):
for k, v in self.state_dict.items():
self.state_dict[k] = paddle.zeros_like(v)
def resume(self, state_dict, step):
for k, v in state_dict.items():
self.state_dict[k] = v
self.step = step
def update(self, model=None):
if self.use_thres_step:
decay = min(self.decay, (1 + self.step) / (10 + self.step))
......
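The new ModelEMA.resume makes the averaged weights survive a restart. For reference, the update rule with the thres_step warmup, matching update() above:

#   decay_t   = min(decay, (1 + step) / (10 + step))     # small early on, approaches decay later
#   ema_param = decay_t * ema_param + (1 - decay_t) * cur_param
# resume() restores the averaged state dict and the step counter, so decay_t
# continues from where the interrupted run left off rather than re-warming.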
......@@ -62,7 +62,7 @@ def _strip_postfix(path):
return path
def load_weight(model, weight, optimizer=None):
def load_weight(model, weight, optimizer=None, ema=None):
if is_url(weight):
weight = get_weights_path(weight)
......@@ -102,6 +102,10 @@ def load_weight(model, weight, optimizer=None):
last_epoch = optim_state_dict.pop('last_epoch')
optimizer.set_state_dict(optim_state_dict)
if ema is not None and os.path.exists(path + '_ema.pdparams'):
ema_state_dict = paddle.load(path + '_ema.pdparams')
ema.resume(ema_state_dict,
optim_state_dict['LR_Scheduler']['last_epoch'])
return last_epoch
......@@ -112,9 +116,9 @@ def match_state_dict(model_state_dict, weight_state_dict):
The method supposes that all the names in pretrained weight state dict are
subclass of the names in models`, if the prefix 'backbone.' in pretrained weight
keys is stripped. And we could get the candidates for each model key. Then we
keys is stripped. And we could get the candidates for each model key. Then we
select the name with the longest matched size as the final match result. For
example, the model state dict has the name of
example, the model state dict has the name of
'backbone.res2.res2a.branch2a.conv.weight' and the pretrained weight as
name of 'res2.res2a.branch2a.conv.weight' and 'branch2a.conv.weight'. We
match the 'res2.res2a.branch2a.conv.weight' to the model key.
......@@ -125,7 +129,7 @@ def match_state_dict(model_state_dict, weight_state_dict):
def match(a, b):
if b.startswith('backbone.res5'):
# In Faster RCNN, res5 pretrained weights have prefix of backbone,
# In Faster RCNN, res5 pretrained weights have prefix of backbone,
# however, the corresponding model weights have difficult prefix,
# bbox_head.
b = b[9:]
......@@ -201,7 +205,7 @@ def load_pretrain_weight(model, pretrain_weight):
logger.info('Finish loading model weights: {}'.format(weights_path))
def save_model(model, optimizer, save_dir, save_name, last_epoch):
def save_model(model, save_dir, save_name, last_epoch, optimizer=None):
"""
save model into disk.
......@@ -224,7 +228,8 @@ def save_model(model, optimizer, save_dir, save_name, last_epoch):
assert isinstance(model,
dict), 'model is not a instance of nn.layer or dict'
paddle.save(model, save_path + ".pdparams")
state_dict = optimizer.state_dict()
state_dict['last_epoch'] = last_epoch
paddle.save(state_dict, save_path + ".pdopt")
if optimizer is not None:
state_dict = optimizer.state_dict()
state_dict['last_epoch'] = last_epoch
paddle.save(state_dict, save_path + ".pdopt")
logger.info("Save checkpoint: {}".format(save_dir))