未验证 提交 80b2627b 编写于 作者: F Feng Ni 提交者: GitHub

refine fcos codes (#6949)

* refine fcos codes

* refine fcos postprocess

* fix fcos deploy

* fix fcos deploy
上级 92078713
......@@ -5,22 +5,21 @@ FCOS:
backbone: ResNet
neck: FPN
fcos_head: FCOSHead
fcos_post_process: FCOSPostProcess
ResNet:
# index 0 stands for res2
depth: 50
variant: 'b'
norm_type: bn
freeze_at: 0
return_idx: [1,2,3]
freeze_at: 0 # res2
return_idx: [1, 2, 3]
num_stages: 4
FPN:
out_channel: 256
spatial_scales: [0.125, 0.0625, 0.03125]
extra_stage: 2
has_extra_convs: true
use_c5: false
has_extra_convs: True
use_c5: False
FCOSHead:
fcos_feat:
......@@ -29,22 +28,18 @@ FCOSHead:
feat_out: 256
num_convs: 4
norm_type: "gn"
use_dcn: false
use_dcn: False
fpn_stride: [8, 16, 32, 64, 128]
prior_prob: 0.01
fcos_loss: FCOSLoss
norm_reg_targets: true
centerness_on_reg: true
FCOSLoss:
norm_reg_targets: True
centerness_on_reg: True
num_shift: 0.5
fcos_loss:
name: FCOSLoss
loss_alpha: 0.25
loss_gamma: 2.0
iou_loss_type: "giou"
reg_weights: 1.0
FCOSPostProcess:
decode:
name: FCOSBox
nms:
name: MultiClassNMS
nms_top_k: 1000
......
......@@ -2,11 +2,11 @@ worker_num: 2
TrainReader:
sample_transforms:
- Decode: {}
- RandomFlip: {prob: 0.5}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Resize: {target_size: [800, 1333], keep_ratio: true, interp: 1}
- Permute: {}
- Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- RandomFlip: {}
batch_transforms:
- Permute: {}
- PadBatch: {pad_to_stride: 128}
- Gt2FCOSTarget:
object_sizes_boundary: [64, 128, 256, 512]
......@@ -14,29 +14,28 @@ TrainReader:
downsample_ratios: [8, 16, 32, 64, 128]
norm_reg_targets: True
batch_size: 2
shuffle: true
drop_last: true
shuffle: True
drop_last: True
EvalReader:
sample_transforms:
- Decode: {}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Resize: {interp: 1, target_size: [800, 1333], keep_ratio: True}
- Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 128}
batch_size: 1
shuffle: false
TestReader:
sample_transforms:
- Decode: {}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Resize: {interp: 1, target_size: [800, 1333], keep_ratio: True}
- Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 128}
batch_size: 1
shuffle: false
fuse_normalize: True
......@@ -9,24 +9,8 @@ _BASE_: [
weights: output/fcos_dcn_r50_fpn_1x_coco/model_final
ResNet:
depth: 50
norm_type: bn
freeze_at: 0
return_idx: [1,2,3]
num_stages: 4
dcn_v2_stages: [1,2,3]
dcn_v2_stages: [1, 2, 3]
FCOSHead:
fcos_feat:
name: FCOSFeat
feat_in: 256
feat_out: 256
num_convs: 4
norm_type: "gn"
use_dcn: true
num_classes: 80
fpn_stride: [8, 16, 32, 64, 128]
prior_prob: 0.01
fcos_loss: FCOSLoss
norm_reg_targets: true
centerness_on_reg: true
use_dcn: True
......@@ -11,11 +11,11 @@ weights: output/fcos_r50_fpn_multiscale_2x_coco/model_final
TrainReader:
sample_transforms:
- Decode: {}
- RandomFlip: {prob: 0.5}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], keep_ratio: true, interp: 1}
- Permute: {}
- RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], keep_ratio: True, interp: 1}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- RandomFlip: {}
batch_transforms:
- Permute: {}
- PadBatch: {pad_to_stride: 128}
- Gt2FCOSTarget:
object_sizes_boundary: [64, 128, 256, 512]
......@@ -23,8 +23,8 @@ TrainReader:
downsample_ratios: [8, 16, 32, 64, 128]
norm_reg_targets: True
batch_size: 2
shuffle: true
drop_last: true
shuffle: True
drop_last: True
epoch: 24
......
......@@ -32,22 +32,15 @@ class FCOS(BaseArch):
backbone (object): backbone instance
neck (object): 'FPN' instance
fcos_head (object): 'FCOSHead' instance
post_process (object): 'FCOSPostProcess' instance
"""
__category__ = 'architecture'
__inject__ = ['fcos_post_process']
def __init__(self,
backbone,
neck,
fcos_head='FCOSHead',
fcos_post_process='FCOSPostProcess'):
def __init__(self, backbone, neck='FPN', fcos_head='FCOSHead'):
super(FCOS, self).__init__()
self.backbone = backbone
self.neck = neck
self.fcos_head = fcos_head
self.fcos_post_process = fcos_post_process
@classmethod
def from_config(cls, cfg, *args, **kwargs):
......@@ -68,38 +61,18 @@ class FCOS(BaseArch):
def _forward(self):
body_feats = self.backbone(self.inputs)
fpn_feats = self.neck(body_feats)
fcos_head_outs = self.fcos_head(fpn_feats, self.training)
if not self.training:
scale_factor = self.inputs['scale_factor']
bboxes = self.fcos_post_process(fcos_head_outs, scale_factor)
return bboxes
if self.training:
losses = self.fcos_head(fpn_feats, self.inputs)
return losses
else:
return fcos_head_outs
def get_loss(self, ):
loss = {}
tag_labels, tag_bboxes, tag_centerness = [], [], []
for i in range(len(self.fcos_head.fpn_stride)):
# labels, reg_target, centerness
k_lbl = 'labels{}'.format(i)
if k_lbl in self.inputs:
tag_labels.append(self.inputs[k_lbl])
k_box = 'reg_target{}'.format(i)
if k_box in self.inputs:
tag_bboxes.append(self.inputs[k_box])
k_ctn = 'centerness{}'.format(i)
if k_ctn in self.inputs:
tag_centerness.append(self.inputs[k_ctn])
fcos_head_outs = self._forward()
loss_fcos = self.fcos_head.get_loss(fcos_head_outs, tag_labels,
tag_bboxes, tag_centerness)
loss.update(loss_fcos)
total_loss = paddle.add_n(list(loss.values()))
loss.update({'loss': total_loss})
return loss
fcos_head_outs = self.fcos_head(fpn_feats)
bbox_pred, bbox_num = self.fcos_head.post_process(
fcos_head_outs, self.inputs['scale_factor'])
return {'bbox': bbox_pred, 'bbox_num': bbox_num}
def get_loss(self):
return self._forward()
def get_pred(self):
bbox_pred, bbox_num = self._forward()
output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
return output
return self._forward()
......@@ -24,7 +24,9 @@ from paddle import ParamAttr
from paddle.nn.initializer import Normal, Constant
from ppdet.core.workspace import register
from ppdet.modeling.layers import ConvNormLayer
from ppdet.modeling.layers import ConvNormLayer, MultiClassNMS
__all__ = ['FCOSFeat', 'FCOSHead']
class ScaleReg(nn.Layer):
......@@ -115,25 +117,31 @@ class FCOSHead(nn.Layer):
"""
FCOSHead
Args:
fcos_feat (object): Instance of 'FCOSFeat'
num_classes (int): Number of classes
fcos_feat (object): Instance of 'FCOSFeat'
fpn_stride (list): The stride of each FPN Layer
prior_prob (float): Used to set the bias init for the class prediction layer
fcos_loss (object): Instance of 'FCOSLoss'
norm_reg_targets (bool): Normalization the regression target if true
centerness_on_reg (bool): The prediction of centerness on regression or clssification branch
num_shift (float): Relative offset between the center of the first shift and the top-left corner of img
fcos_loss (object): Instance of 'FCOSLoss'
nms (object): Instance of 'MultiClassNMS'
trt (bool): Whether to use trt in nms of deploy
"""
__inject__ = ['fcos_feat', 'fcos_loss']
__shared__ = ['num_classes']
__inject__ = ['fcos_feat', 'fcos_loss', 'nms']
__shared__ = ['num_classes', 'trt']
def __init__(self,
fcos_feat,
num_classes=80,
fcos_feat='FCOSFeat',
fpn_stride=[8, 16, 32, 64, 128],
prior_prob=0.01,
fcos_loss='FCOSLoss',
norm_reg_targets=True,
centerness_on_reg=True):
centerness_on_reg=True,
num_shift=0.5,
fcos_loss='FCOSLoss',
nms='MultiClassNMS',
trt=False):
super(FCOSHead, self).__init__()
self.fcos_feat = fcos_feat
self.num_classes = num_classes
......@@ -142,6 +150,10 @@ class FCOSHead(nn.Layer):
self.fcos_loss = fcos_loss
self.norm_reg_targets = norm_reg_targets
self.centerness_on_reg = centerness_on_reg
self.num_shift = num_shift
self.nms = nms
if isinstance(self.nms, MultiClassNMS) and trt:
self.nms.trt = trt
conv_cls_name = "fcos_head_cls"
bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)
......@@ -191,7 +203,7 @@ class FCOSHead(nn.Layer):
scale_reg = self.add_sublayer(feat_name, ScaleReg())
self.scales_regs.append(scale_reg)
def _compute_locations_by_level(self, fpn_stride, feature):
def _compute_locations_by_level(self, fpn_stride, feature, num_shift=0.5):
"""
Compute locations of anchor points of each FPN layer
Args:
......@@ -200,25 +212,21 @@ class FCOSHead(nn.Layer):
Return:
Anchor points locations of current FPN feature map
"""
shape_fm = paddle.shape(feature)
shape_fm.stop_gradient = True
h, w = shape_fm[2], shape_fm[3]
h, w = feature.shape[2], feature.shape[3]
shift_x = paddle.arange(0, w * fpn_stride, fpn_stride)
shift_y = paddle.arange(0, h * fpn_stride, fpn_stride)
shift_x = paddle.unsqueeze(shift_x, axis=0)
shift_y = paddle.unsqueeze(shift_y, axis=1)
shift_x = paddle.expand(shift_x, shape=[h, w])
shift_y = paddle.expand(shift_y, shape=[h, w])
shift_x.stop_gradient = True
shift_y.stop_gradient = True
shift_x = paddle.reshape(shift_x, shape=[-1])
shift_y = paddle.reshape(shift_y, shape=[-1])
location = paddle.stack(
[shift_x, shift_y], axis=-1) + float(fpn_stride) / 2
location.stop_gradient = True
[shift_x, shift_y], axis=-1) + float(fpn_stride * num_shift)
return location
def forward(self, fpn_feats, is_training):
def forward(self, fpn_feats, targets=None):
assert len(fpn_feats) == len(
self.fpn_stride
), "The size of fpn_feats is not equal to size of fpn_stride"
......@@ -236,7 +244,8 @@ class FCOSHead(nn.Layer):
centerness = self.fcos_head_centerness(fcos_cls_feat)
if self.norm_reg_targets:
bbox_reg = F.relu(bbox_reg)
if not is_training:
if not self.training:
# eval or infer
bbox_reg = bbox_reg * fpn_stride
else:
bbox_reg = paddle.exp(bbox_reg)
......@@ -244,17 +253,85 @@ class FCOSHead(nn.Layer):
bboxes_reg_list.append(bbox_reg)
centerness_list.append(centerness)
if not is_training:
if self.training:
losses = {}
fcos_head_outs = [cls_logits_list, bboxes_reg_list, centerness_list]
losses_fcos = self.get_loss(fcos_head_outs, targets)
losses.update(losses_fcos)
total_loss = paddle.add_n(list(losses.values()))
losses.update({'loss': total_loss})
return losses
else:
# eval or infer
locations_list = []
for fpn_stride, feature in zip(self.fpn_stride, fpn_feats):
location = self._compute_locations_by_level(fpn_stride, feature)
location = self._compute_locations_by_level(fpn_stride, feature,
self.num_shift)
locations_list.append(location)
return locations_list, cls_logits_list, bboxes_reg_list, centerness_list
else:
return cls_logits_list, bboxes_reg_list, centerness_list
fcos_head_outs = [
locations_list, cls_logits_list, bboxes_reg_list,
centerness_list
]
return fcos_head_outs
def get_loss(self, fcos_head_outs, tag_labels, tag_bboxes, tag_centerness):
def get_loss(self, fcos_head_outs, targets):
cls_logits, bboxes_reg, centerness = fcos_head_outs
return self.fcos_loss(cls_logits, bboxes_reg, centerness, tag_labels,
tag_bboxes, tag_centerness)
# get labels,reg_target,centerness
tag_labels, tag_bboxes, tag_centerness = [], [], []
for i in range(len(self.fpn_stride)):
k_lbl = 'labels{}'.format(i)
if k_lbl in targets:
tag_labels.append(targets[k_lbl])
k_box = 'reg_target{}'.format(i)
if k_box in targets:
tag_bboxes.append(targets[k_box])
k_ctn = 'centerness{}'.format(i)
if k_ctn in targets:
tag_centerness.append(targets[k_ctn])
losses_fcos = self.fcos_loss(cls_logits, bboxes_reg, centerness,
tag_labels, tag_bboxes, tag_centerness)
return losses_fcos
def _post_process_by_level(self, locations, box_cls, box_reg, box_ctn):
box_scores = F.sigmoid(box_cls).flatten(2).transpose([0, 2, 1])
box_centerness = F.sigmoid(box_ctn).flatten(2).transpose([0, 2, 1])
pred_scores = box_scores * box_centerness
box_reg_ch_last = box_reg.flatten(2).transpose([0, 2, 1])
box_reg_decoding = paddle.stack(
[
locations[:, 0] - box_reg_ch_last[:, :, 0],
locations[:, 1] - box_reg_ch_last[:, :, 1],
locations[:, 0] + box_reg_ch_last[:, :, 2],
locations[:, 1] + box_reg_ch_last[:, :, 3]
],
axis=1)
pred_boxes = box_reg_decoding.transpose([0, 2, 1])
return pred_scores, pred_boxes
def post_process(self, fcos_head_outs, scale_factor):
locations, cls_logits, bboxes_reg, centerness = fcos_head_outs
pred_bboxes, pred_scores = [], []
for pts, cls, reg, ctn in zip(locations, cls_logits, bboxes_reg,
centerness):
scores, boxes = self._post_process_by_level(pts, cls, reg, ctn)
pred_scores.append(scores)
pred_bboxes.append(boxes)
pred_bboxes = paddle.concat(pred_bboxes, axis=1)
pred_scores = paddle.concat(pred_scores, axis=1)
# scale bbox to origin
scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
scale_factor = paddle.concat(
[scale_x, scale_y, scale_x, scale_y], axis=-1).reshape([-1, 1, 4])
pred_bboxes /= scale_factor
pred_scores = pred_scores.transpose([0, 2, 1])
bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
return bbox_pred, bbox_num
......@@ -703,98 +703,6 @@ class SSDBox(object):
return output_boxes, output_scores
@register
@serializable
class FCOSBox(object):
__shared__ = ['num_classes']
def __init__(self, num_classes=80):
super(FCOSBox, self).__init__()
self.num_classes = num_classes
def _merge_hw(self, inputs, ch_type="channel_first"):
"""
Merge h and w of the feature map into one dimension.
Args:
inputs (Tensor): Tensor of the input feature map
ch_type (str): "channel_first" or "channel_last" style
Return:
new_shape (Tensor): The new shape after h and w merged
"""
shape_ = paddle.shape(inputs)
bs, ch, hi, wi = shape_[0], shape_[1], shape_[2], shape_[3]
img_size = hi * wi
img_size.stop_gradient = True
if ch_type == "channel_first":
new_shape = paddle.concat([bs, ch, img_size])
elif ch_type == "channel_last":
new_shape = paddle.concat([bs, img_size, ch])
else:
raise KeyError("Wrong ch_type %s" % ch_type)
new_shape.stop_gradient = True
return new_shape
def _postprocessing_by_level(self, locations, box_cls, box_reg, box_ctn,
scale_factor):
"""
Postprocess each layer of the output with corresponding locations.
Args:
locations (Tensor): anchor points for current layer, [H*W, 2]
box_cls (Tensor): categories prediction, [N, C, H, W],
C is the number of classes
box_reg (Tensor): bounding box prediction, [N, 4, H, W]
box_ctn (Tensor): centerness prediction, [N, 1, H, W]
scale_factor (Tensor): [h_scale, w_scale] for input images
Return:
box_cls_ch_last (Tensor): score for each category, in [N, C, M]
C is the number of classes and M is the number of anchor points
box_reg_decoding (Tensor): decoded bounding box, in [N, M, 4]
last dimension is [x1, y1, x2, y2]
"""
act_shape_cls = self._merge_hw(box_cls)
box_cls_ch_last = paddle.reshape(x=box_cls, shape=act_shape_cls)
box_cls_ch_last = F.sigmoid(box_cls_ch_last)
act_shape_reg = self._merge_hw(box_reg)
box_reg_ch_last = paddle.reshape(x=box_reg, shape=act_shape_reg)
box_reg_ch_last = paddle.transpose(box_reg_ch_last, perm=[0, 2, 1])
box_reg_decoding = paddle.stack(
[
locations[:, 0] - box_reg_ch_last[:, :, 0],
locations[:, 1] - box_reg_ch_last[:, :, 1],
locations[:, 0] + box_reg_ch_last[:, :, 2],
locations[:, 1] + box_reg_ch_last[:, :, 3]
],
axis=1)
box_reg_decoding = paddle.transpose(box_reg_decoding, perm=[0, 2, 1])
act_shape_ctn = self._merge_hw(box_ctn)
box_ctn_ch_last = paddle.reshape(x=box_ctn, shape=act_shape_ctn)
box_ctn_ch_last = F.sigmoid(box_ctn_ch_last)
# recover the location to original image
im_scale = paddle.concat([scale_factor, scale_factor], axis=1)
im_scale = paddle.expand(im_scale, [box_reg_decoding.shape[0], 4])
im_scale = paddle.reshape(im_scale, [box_reg_decoding.shape[0], -1, 4])
box_reg_decoding = box_reg_decoding / im_scale
box_cls_ch_last = box_cls_ch_last * box_ctn_ch_last
return box_cls_ch_last, box_reg_decoding
def __call__(self, locations, cls_logits, bboxes_reg, centerness,
scale_factor):
pred_boxes_ = []
pred_scores_ = []
for pts, cls, box, ctn in zip(locations, cls_logits, bboxes_reg,
centerness):
pred_scores_lvl, pred_boxes_lvl = self._postprocessing_by_level(
pts, cls, box, ctn, scale_factor)
pred_boxes_.append(pred_boxes_lvl)
pred_scores_.append(pred_scores_lvl)
pred_boxes = paddle.concat(pred_boxes_, axis=1)
pred_scores = paddle.concat(pred_scores_, axis=2)
return pred_boxes, pred_scores
@register
class TTFBox(object):
__shared__ = ['down_ratio']
......
......@@ -26,9 +26,8 @@ except Exception:
from collections import Sequence
__all__ = [
'BBoxPostProcess', 'MaskPostProcess', 'FCOSPostProcess',
'JDEBBoxPostProcess', 'CenterNetPostProcess', 'DETRBBoxPostProcess',
'SparsePostProcess'
'BBoxPostProcess', 'MaskPostProcess', 'JDEBBoxPostProcess',
'CenterNetPostProcess', 'DETRBBoxPostProcess', 'SparsePostProcess'
]
......@@ -37,8 +36,12 @@ class BBoxPostProcess(object):
__shared__ = ['num_classes', 'export_onnx', 'export_eb']
__inject__ = ['decode', 'nms']
def __init__(self, num_classes=80, decode=None, nms=None,
export_onnx=False, export_eb=False):
def __init__(self,
num_classes=80,
decode=None,
nms=None,
export_onnx=False,
export_eb=False):
super(BBoxPostProcess, self).__init__()
self.num_classes = num_classes
self.decode = decode
......@@ -279,26 +282,6 @@ class MaskPostProcess(object):
return pred_result
@register
class FCOSPostProcess(object):
__inject__ = ['decode', 'nms']
def __init__(self, decode=None, nms=None):
super(FCOSPostProcess, self).__init__()
self.decode = decode
self.nms = nms
def __call__(self, fcos_head_outs, scale_factor):
"""
Decode the bbox and do NMS in FCOS.
"""
locations, cls_logits, bboxes_reg, centerness = fcos_head_outs
bboxes, score = self.decode(locations, cls_logits, bboxes_reg,
centerness, scale_factor)
bbox_pred, bbox_num, _ = self.nms(bboxes, score)
return bbox_pred, bbox_num
@register
class JDEBBoxPostProcess(nn.Layer):
__shared__ = ['num_classes']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册