未验证 提交 df55cb9b 编写于 作者: G Guanghua Yu 提交者: GitHub

update PicoDet and GFL post_process (#5101)

上级 a3bc6d5b
architecture: PicoDet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ESNet_x1_0_pretrained.pdparams
export_post_process: False # Whether post-processing is included in the network
export_post_process: False # Whether post-processing is included in the network when export model.
PicoDet:
backbone: ESNet
......
......@@ -631,9 +631,12 @@ class Trainer(object):
im_shape = [image_shape[0], 2]
scale_factor = [image_shape[0], 2]
export_post_process = self.cfg.get('export_post_process', False)
if hasattr(self.model, 'deploy') and not export_post_process:
if hasattr(self.model, 'deploy'):
self.model.deploy = True
export_post_process = self.cfg.get('export_post_process', False)
if hasattr(self.model, 'export_post_process'):
self.model.export_post_process = export_post_process
image_shape = [None] + image_shape[1:]
if hasattr(self.model, 'fuse_norm'):
self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize',
False)
......
......@@ -41,7 +41,7 @@ class PicoDet(BaseArch):
self.backbone = backbone
self.neck = neck
self.head = head
self.deploy = False
self.export_post_process = True
@classmethod
def from_config(cls, cfg, *args, **kwargs):
......@@ -62,8 +62,8 @@ class PicoDet(BaseArch):
def _forward(self):
body_feats = self.backbone(self.inputs)
fpn_feats = self.neck(body_feats)
head_outs = self.head(fpn_feats, self.deploy)
if self.training or self.deploy:
head_outs = self.head(fpn_feats, self.export_post_process)
if self.training or not self.export_post_process:
return head_outs, None
else:
im_shape = self.inputs['im_shape']
......@@ -83,7 +83,7 @@ class PicoDet(BaseArch):
return loss
def get_pred(self):
if self.deploy:
if not self.export_post_process:
return {'picodet': self._forward()[0]}
else:
bbox_pred, bbox_num = self._forward()
......
......@@ -756,20 +756,22 @@ def bbox_center(boxes):
def batch_distance2bbox(points, distance, max_shapes=None):
    """Decode distance prediction to bounding box for batch.

    Args:
        points (Tensor): [B, ..., 2], "xy" format
        distance (Tensor): [B, ..., 4], "ltrb" format
        max_shapes (Tensor): [B, 2], "h,w" format, Shape of the image.
            When None, the decoded boxes are returned without clipping.
    Returns:
        Tensor: Decoded bboxes, "x1y1x2y2" format.
    """
    # Split the "ltrb" distances into the (left, top) and (right, bottom)
    # halves so both corners are computed with one tensor op each.
    lt, rb = paddle.split(distance, 2, -1)
    x1y1 = points - lt
    x2y2 = points + rb
    out_bbox = paddle.concat([x1y1, x2y2], -1)
    if max_shapes is not None:
        # [h, w] -> [w, h, w, h] so the limits line up with x1, y1, x2, y2.
        max_shapes = max_shapes.flip(-1).tile([1, 2])
        # Insert singleton axes after the batch axis until the limits have
        # the same rank as the boxes and can broadcast against them.
        delta_dim = out_bbox.ndim - max_shapes.ndim
        for _ in range(delta_dim):
            max_shapes = max_shapes.unsqueeze(1)
        # Clip to [0, max] for the whole batch at once instead of looping
        # over samples in Python.
        out_bbox = paddle.where(out_bbox < max_shapes, out_bbox, max_shapes)
        out_bbox = paddle.where(out_bbox > 0, out_bbox,
                                paddle.zeros_like(out_bbox))
    return out_bbox
......@@ -29,7 +29,7 @@ from paddle.nn.initializer import Normal, Constant
from ppdet.core.workspace import register
from ppdet.modeling.layers import ConvNormLayer
from ppdet.modeling.bbox_utils import distance2bbox, bbox2distance
from ppdet.modeling.bbox_utils import distance2bbox, bbox2distance, batch_distance2bbox
from ppdet.data.transform.atss_assigner import bbox_overlaps
......@@ -241,18 +241,34 @@ class GFLHead(nn.Layer):
), "The size of fpn_feats is not equal to size of fpn_stride"
cls_logits_list = []
bboxes_reg_list = []
for scale_reg, fpn_feat in zip(self.scales_regs, fpn_feats):
for stride, scale_reg, fpn_feat in zip(self.fpn_stride,
self.scales_regs, fpn_feats):
conv_cls_feat, conv_reg_feat = self.conv_feat(fpn_feat)
cls_logits = self.gfl_head_cls(conv_cls_feat)
bbox_reg = scale_reg(self.gfl_head_reg(conv_reg_feat))
cls_score = self.gfl_head_cls(conv_cls_feat)
bbox_pred = scale_reg(self.gfl_head_reg(conv_reg_feat))
if self.dgqp_module:
quality_score = self.dgqp_module(bbox_reg)
cls_logits = F.sigmoid(cls_logits) * quality_score
quality_score = self.dgqp_module(bbox_pred)
cls_score = F.sigmoid(cls_score) * quality_score
if not self.training:
cls_logits = F.sigmoid(cls_logits.transpose([0, 2, 3, 1]))
bbox_reg = bbox_reg.transpose([0, 2, 3, 1])
cls_logits_list.append(cls_logits)
bboxes_reg_list.append(bbox_reg)
cls_score = F.sigmoid(cls_score.transpose([0, 2, 3, 1]))
bbox_pred = bbox_pred.transpose([0, 2, 3, 1])
b, cell_h, cell_w, _ = paddle.shape(cls_score)
y, x = self.get_single_level_center_point(
[cell_h, cell_w], stride, cell_offset=self.cell_offset)
center_points = paddle.stack([x, y], axis=-1)
cls_score = cls_score.reshape([b, -1, self.cls_out_channels])
bbox_pred = self.distribution_project(bbox_pred) * stride
bbox_pred = bbox_pred.reshape([b, cell_h * cell_w, 4])
# NOTE: When keep_ratio=False and the image shape values are
# multiples of 32, max_shapes is left unset in batch_distance2bbox
# to speed up model prediction. If max_shapes needs to be set,
# please use inputs['im_shape'].
bbox_pred = batch_distance2bbox(
center_points, bbox_pred, max_shapes=None)
cls_logits_list.append(cls_score)
bboxes_reg_list.append(bbox_pred)
return (cls_logits_list, bboxes_reg_list)
......@@ -410,71 +426,15 @@ class GFLHead(nn.Layer):
x = x.flatten()
return y, x
def get_bboxes_single(self,
                      cls_scores,
                      bbox_preds,
                      img_shape,
                      scale_factor,
                      rescale=True,
                      cell_offset=0):
    """Decode the multi-level head outputs of one image into boxes and scores.

    Each level is paired with its stride from ``self.fpn_stride``. For every
    level the center points are generated, the regression output is passed
    through ``self.distribution_project`` and multiplied by the stride, and
    at most ``self.nms_pre`` candidates (ranked by their maximum class
    score) are kept before converting to boxes with ``distance2bbox``.

    Args:
        cls_scores (list): Per-level class scores of a single image. The
            first two dimensions are read as the feature-map size
            (presumably [H, W, C] -- TODO confirm against the caller).
        bbox_preds (list): Per-level regression outputs of a single image,
            same length as ``cls_scores``.
        img_shape: Forwarded to ``distance2bbox`` as ``max_shape``.
        scale_factor: Indexed with ``[::-1]``; per the inline comment it is
            [h_scale, w_scale].
        rescale (bool): When True, divide the boxes by the scale factor.
        cell_offset: Forwarded to ``get_single_level_center_point``.

    Returns:
        tuple: ``(mlvl_bboxes, mlvl_scores)`` -- the boxes of all levels
        concatenated, and the scores of all levels concatenated and then
        transposed with ``[1, 0]``.
    """
    assert len(cls_scores) == len(bbox_preds)
    mlvl_bboxes = []
    mlvl_scores = []
    for stride, cls_score, bbox_pred in zip(self.fpn_stride, cls_scores,
                                            bbox_preds):
        # Feature-map size taken from the runtime shape of the score map.
        featmap_size = [
            paddle.shape(cls_score)[0], paddle.shape(cls_score)[1]
        ]
        y, x = self.get_single_level_center_point(
            featmap_size, stride, cell_offset=cell_offset)
        center_points = paddle.stack([x, y], axis=-1)
        scores = cls_score.reshape([-1, self.cls_out_channels])
        bbox_pred = self.distribution_project(bbox_pred) * stride
        # Keep only the nms_pre highest-scoring candidates of this level.
        if scores.shape[0] > self.nms_pre:
            max_scores = scores.max(axis=1)
            _, topk_inds = max_scores.topk(self.nms_pre)
            center_points = center_points.gather(topk_inds)
            bbox_pred = bbox_pred.gather(topk_inds)
            scores = scores.gather(topk_inds)
        bboxes = distance2bbox(
            center_points, bbox_pred, max_shape=img_shape)
        mlvl_bboxes.append(bboxes)
        mlvl_scores.append(scores)
    mlvl_bboxes = paddle.concat(mlvl_bboxes)
    if rescale:
        # [h_scale, w_scale] to [w_scale, h_scale, w_scale, h_scale]
        im_scale = paddle.concat([scale_factor[::-1], scale_factor[::-1]])
        mlvl_bboxes /= im_scale
    mlvl_scores = paddle.concat(mlvl_scores)
    mlvl_scores = mlvl_scores.transpose([1, 0])
    return mlvl_bboxes, mlvl_scores
def decode(self, cls_scores, bbox_preds, im_shape, scale_factor,
           cell_offset):
    """Decode the head outputs image by image and stack the results.

    Args:
        cls_scores (list): Per-level class scores; the first axis of each
            entry is indexed by image id.
        bbox_preds (list): Per-level regression outputs, indexed the same
            way as ``cls_scores``.
        im_shape: Indexed by image id and forwarded to
            ``get_bboxes_single``.
        scale_factor: Indexed by image id and forwarded to
            ``get_bboxes_single``.
        cell_offset: Forwarded to ``get_bboxes_single``.

    Returns:
        tuple: ``(batch_bboxes, batch_scores)``, the per-image results
        stacked along a new leading axis.
    """
    level_count = len(cls_scores)
    image_count = cls_scores[0].shape[0]
    per_image_bboxes = []
    per_image_scores = []
    for idx in range(image_count):
        # Slice out this image's entry from every pyramid level.
        level_scores = [cls_scores[lvl][idx] for lvl in range(level_count)]
        level_preds = [bbox_preds[lvl][idx] for lvl in range(level_count)]
        image_bboxes, image_scores = self.get_bboxes_single(
            level_scores,
            level_preds,
            im_shape[idx],
            scale_factor[idx],
            cell_offset=cell_offset)
        per_image_bboxes.append(image_bboxes)
        per_image_scores.append(image_scores)
    stacked_bboxes = paddle.stack(per_image_bboxes, axis=0)
    stacked_scores = paddle.stack(per_image_scores, axis=0)
    return stacked_bboxes, stacked_scores
def post_process(self, gfl_head_outs, im_shape, scale_factor):
    """Rescale the head's box predictions and run NMS on the whole batch.

    The per-level outputs are concatenated along axis 1, the boxes are
    divided by the image scale, and the result is passed to ``self.nms``.
    No per-image decoding happens here; the boxes in ``gfl_head_outs`` are
    used as they are.

    Args:
        gfl_head_outs (tuple): ``(cls_scores, bboxes_reg)``, two lists of
            per-level tensors that can be concatenated along axis 1.
        im_shape: Not used by this method; kept so the signature stays
            compatible with existing callers.
        scale_factor (Tensor): Indexed as ``[:, ::-1]``; per the inline
            comment each row is [h_scale, w_scale].

    Returns:
        tuple: ``(bbox_pred, bbox_num)`` as returned by ``self.nms``.
    """
    cls_scores, bboxes_reg = gfl_head_outs
    bboxes = paddle.concat(bboxes_reg, axis=1)
    # rescale: [h_scale, w_scale] -> [w_scale, h_scale, w_scale, h_scale]
    im_scale = paddle.concat(
        [scale_factor[:, ::-1], scale_factor[:, ::-1]],
        axis=-1).unsqueeze(1)
    bboxes /= im_scale
    mlvl_scores = paddle.concat(cls_scores, axis=1)
    # Swap the last two axes of the concatenated scores before NMS.
    mlvl_scores = mlvl_scores.transpose([0, 2, 1])
    bbox_pred, bbox_num, _ = self.nms(bboxes, mlvl_scores)
    return bbox_pred, bbox_num
......@@ -26,6 +26,7 @@ from paddle.nn.initializer import Normal, Constant
from ppdet.core.workspace import register
from ppdet.modeling.layers import ConvNormLayer
from ppdet.modeling.bbox_utils import batch_distance2bbox
from .simota_head import OTAVFLHead
......@@ -238,7 +239,7 @@ class PicoHead(OTAVFLHead):
bias_attr=ParamAttr(initializer=Constant(value=0))))
self.head_reg_list.append(head_reg)
def forward(self, fpn_feats, deploy=False):
def forward(self, fpn_feats, export_post_process=True):
assert len(fpn_feats) == len(
self.fpn_stride
), "The size of fpn_feats is not equal to size of fpn_stride"
......@@ -260,7 +261,7 @@ class PicoHead(OTAVFLHead):
quality_score = self.dgqp_module(bbox_pred)
cls_score = F.sigmoid(cls_score) * quality_score
if deploy:
if not export_post_process:
# Now only supports batch size = 1 in deploy
# TODO(ygh): support batch size > 1
cls_score = F.sigmoid(cls_score).reshape(
......@@ -270,6 +271,21 @@ class PicoHead(OTAVFLHead):
elif not self.training:
cls_score = F.sigmoid(cls_score.transpose([0, 2, 3, 1]))
bbox_pred = bbox_pred.transpose([0, 2, 3, 1])
stride = self.fpn_stride[i]
b, cell_h, cell_w, _ = paddle.shape(cls_score)
y, x = self.get_single_level_center_point(
[cell_h, cell_w], stride, cell_offset=self.cell_offset)
center_points = paddle.stack([x, y], axis=-1)
cls_score = cls_score.reshape([b, -1, self.cls_out_channels])
bbox_pred = self.distribution_project(bbox_pred) * stride
bbox_pred = bbox_pred.reshape([b, cell_h * cell_w, 4])
# NOTE: When keep_ratio=False and the image shape values are
# multiples of 32, max_shapes is left unset in batch_distance2bbox
# to speed up model prediction. If max_shapes needs to be set,
# please use inputs['im_shape'].
bbox_pred = batch_distance2bbox(
center_points, bbox_pred, max_shapes=None)
cls_logits_list.append(cls_score)
bboxes_reg_list.append(bbox_pred)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册