未验证 提交 0a3d768c 编写于 作者: S shangliang Xu 提交者: GitHub

[dev] update assigner and tood_head (#5169)

上级 6ee18c2b
worker_num: 4
TrainReader:
sample_transforms:
- Decode: {}
- RandomFlip: {prob: 0.5}
- Resize: {target_size: [800, 1333], keep_ratio: true}
- NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
- Permute: {}
- Decode: {}
- RandomFlip: {prob: 0.5}
- Resize: {target_size: [800, 1333], keep_ratio: true}
- NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
- PadBatch: {pad_to_stride: 32}
- PadGT: {}
batch_size: 4
shuffle: true
drop_last: true
collate_batch: false
collate_batch: true
use_shared_memory: true
......
......@@ -47,6 +47,7 @@ __all__ = [
'PadMaskBatch',
'Gt2GFLTarget',
'Gt2CenterNetTarget',
'PadGT',
]
......@@ -72,13 +73,15 @@ class PadBatch(BaseOperator):
coarsest_stride = self.pad_to_stride
# multi scale input is nested list
if isinstance(samples, typing.Sequence) and len(samples) > 0 and isinstance(samples[0], typing.Sequence):
if isinstance(samples,
typing.Sequence) and len(samples) > 0 and isinstance(
samples[0], typing.Sequence):
inner_samples = samples[0]
else:
inner_samples = samples
max_shape = np.array([data['image'].shape for data in inner_samples]).max(
axis=0)
max_shape = np.array(
[data['image'].shape for data in inner_samples]).max(axis=0)
if coarsest_stride > 0:
max_shape[1] = int(
np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
......@@ -1066,3 +1069,56 @@ class Gt2CenterNetTarget(BaseOperator):
sample['size'] = wh
sample['offset'] = reg
return sample
@register_op
class PadGT(BaseOperator):
"""
Pad 0 to `gt_class`, `gt_bbox`, `gt_score`...
The num_max_boxes is the largest for batch.
Args:
return_gt_mask (bool): If true, return `pad_gt_mask`,
1 means bbox, 0 means no bbox.
"""
def __init__(self, return_gt_mask=True):
super(PadGT, self).__init__()
self.return_gt_mask = return_gt_mask
def __call__(self, samples, context=None):
num_max_boxes = max([len(s['gt_bbox']) for s in samples])
for sample in samples:
if self.return_gt_mask:
sample['pad_gt_mask'] = np.zeros(
(num_max_boxes, 1), dtype=np.float32)
if num_max_boxes == 0:
continue
num_gt = len(sample['gt_bbox'])
pad_gt_class = np.zeros((num_max_boxes, 1), dtype=np.int32)
pad_gt_bbox = np.zeros((num_max_boxes, 4), dtype=np.float32)
if num_gt > 0:
pad_gt_class[:num_gt] = sample['gt_class']
pad_gt_bbox[:num_gt] = sample['gt_bbox']
sample['gt_class'] = pad_gt_class
sample['gt_bbox'] = pad_gt_bbox
# pad_gt_mask
if 'pad_gt_mask' in sample:
sample['pad_gt_mask'][:num_gt] = 1
# gt_score
if 'gt_score' in sample:
pad_gt_score = np.zeros((num_max_boxes, 1), dtype=np.float32)
if num_gt > 0:
pad_gt_score[:num_gt] = sample['gt_score']
sample['gt_score'] = pad_gt_score
if 'is_crowd' in sample:
pad_is_crowd = np.zeros((num_max_boxes, 1), dtype=np.int32)
if num_gt > 0:
pad_is_crowd[:num_gt] = sample['is_crowd']
sample['is_crowd'] = pad_is_crowd
if 'difficult' in sample:
pad_diff = np.zeros((num_max_boxes, 1), dtype=np.int32)
if num_gt > 0:
pad_diff[:num_gt] = sample['difficult']
sample['difficult'] = pad_diff
return samples
......@@ -23,10 +23,13 @@ import paddle.nn.functional as F
from ppdet.core.workspace import register
from ..ops import iou_similarity
from ..bbox_utils import iou_similarity as batch_iou_similarity
from ..bbox_utils import bbox_center
from .utils import (pad_gt, check_points_inside_bboxes, compute_max_iou_anchor,
from .utils import (check_points_inside_bboxes, compute_max_iou_anchor,
compute_max_iou_gt)
__all__ = ['ATSSAssigner']
@register
class ATSSAssigner(nn.Layer):
......@@ -77,8 +80,10 @@ class ATSSAssigner(nn.Layer):
num_anchors_list,
gt_labels,
gt_bboxes,
pad_gt_mask,
bg_index,
gt_scores=None):
gt_scores=None,
pred_bboxes=None):
r"""This code is based on
https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/assigners/atss_assigner.py
......@@ -99,18 +104,18 @@ class ATSSAssigner(nn.Layer):
anchor_bboxes (Tensor, float32): pre-defined anchors, shape(L, 4),
"xmin, xmax, ymin, ymax" format
num_anchors_list (List): num of anchors in each level
gt_labels (Tensor|List[Tensor], int64): Label of gt_bboxes, shape(B, n, 1)
gt_bboxes (Tensor|List[Tensor], float32): Ground truth bboxes, shape(B, n, 4)
gt_labels (Tensor, int64|int32): Label of gt_bboxes, shape(B, n, 1)
gt_bboxes (Tensor, float32): Ground truth bboxes, shape(B, n, 4)
pad_gt_mask (Tensor, float32): 1 means bbox, 0 means no bbox, shape(B, n, 1)
bg_index (int): background index
gt_scores (Tensor|List[Tensor]|None, float32) Score of gt_bboxes,
gt_scores (Tensor|None, float32) Score of gt_bboxes,
shape(B, n, 1), if None, then it will initialize with one_hot label
pred_bboxes (Tensor, float32, optional): predicted bounding boxes, shape(B, L, 4)
Returns:
assigned_labels (Tensor): (B, L)
assigned_bboxes (Tensor): (B, L, 4)
assigned_scores (Tensor): (B, L, C)
assigned_scores (Tensor): (B, L, C), if pred_bboxes is not None, then output ious
"""
gt_labels, gt_bboxes, pad_gt_scores, pad_gt_mask = pad_gt(
gt_labels, gt_bboxes, gt_scores)
assert gt_labels.ndim == gt_bboxes.ndim and \
gt_bboxes.ndim == 3
......@@ -198,9 +203,14 @@ class ATSSAssigner(nn.Layer):
assigned_bboxes = assigned_bboxes.reshape([batch_size, num_anchors, 4])
assigned_scores = F.one_hot(assigned_labels, self.num_classes)
if gt_scores is not None:
if pred_bboxes is not None:
# assigned iou
ious = batch_iou_similarity(gt_bboxes, pred_bboxes) * mask_positive
ious = ious.max(axis=-2).unsqueeze(-1)
assigned_scores *= ious
elif gt_scores is not None:
gather_scores = paddle.gather(
pad_gt_scores.flatten(), assigned_gt_index.flatten(), axis=0)
gt_scores.flatten(), assigned_gt_index.flatten(), axis=0)
gather_scores = gather_scores.reshape([batch_size, num_anchors])
gather_scores = paddle.where(mask_positive_sum > 0, gather_scores,
paddle.zeros_like(gather_scores))
......
......@@ -22,9 +22,11 @@ import paddle.nn.functional as F
from ppdet.core.workspace import register
from ..bbox_utils import iou_similarity
from .utils import (pad_gt, gather_topk_anchors, check_points_inside_bboxes,
from .utils import (gather_topk_anchors, check_points_inside_bboxes,
compute_max_iou_anchor)
__all__ = ['TaskAlignedAssigner']
@register
class TaskAlignedAssigner(nn.Layer):
......@@ -43,8 +45,10 @@ class TaskAlignedAssigner(nn.Layer):
pred_scores,
pred_bboxes,
anchor_points,
num_anchors_list,
gt_labels,
gt_bboxes,
pad_gt_mask,
bg_index,
gt_scores=None):
r"""This code is based on
......@@ -61,20 +65,18 @@ class TaskAlignedAssigner(nn.Layer):
pred_scores (Tensor, float32): predicted class probability, shape(B, L, C)
pred_bboxes (Tensor, float32): predicted bounding boxes, shape(B, L, 4)
anchor_points (Tensor, float32): pre-defined anchors, shape(L, 2), "cxcy" format
gt_labels (Tensor|List[Tensor], int64): Label of gt_bboxes, shape(B, n, 1)
gt_bboxes (Tensor|List[Tensor], float32): Ground truth bboxes, shape(B, n, 4)
num_anchors_list (List): num of anchors in each level, shape(L)
gt_labels (Tensor, int64|int32): Label of gt_bboxes, shape(B, n, 1)
gt_bboxes (Tensor, float32): Ground truth bboxes, shape(B, n, 4)
pad_gt_mask (Tensor, float32): 1 means bbox, 0 means no bbox, shape(B, n, 1)
bg_index (int): background index
gt_scores (Tensor|List[Tensor]|None, float32) Score of gt_bboxes,
shape(B, n, 1), if None, then it will initialize with one_hot label
gt_scores (Tensor|None, float32) Score of gt_bboxes, shape(B, n, 1)
Returns:
assigned_labels (Tensor): (B, L)
assigned_bboxes (Tensor): (B, L, 4)
assigned_scores (Tensor): (B, L, C)
"""
assert pred_scores.ndim == pred_bboxes.ndim
gt_labels, gt_bboxes, pad_gt_scores, pad_gt_mask = pad_gt(
gt_labels, gt_bboxes, gt_scores)
assert gt_labels.ndim == gt_bboxes.ndim and \
gt_bboxes.ndim == 3
......
......@@ -286,9 +286,11 @@ class TOODHead(nn.Layer):
return loss
def get_loss(self, head_outs, gt_meta):
pred_scores, pred_bboxes, anchors, num_anchors_list, stride_tensor_list = head_outs
pred_scores, pred_bboxes, anchors, \
num_anchors_list, stride_tensor_list = head_outs
gt_labels = gt_meta['gt_class']
gt_bboxes = gt_meta['gt_bbox']
pad_gt_mask = gt_meta['pad_gt_mask']
# label assignment
if gt_meta['epoch_id'] < self.static_assigner_epoch:
assigned_labels, assigned_bboxes, assigned_scores = self.static_assigner(
......@@ -296,6 +298,7 @@ class TOODHead(nn.Layer):
num_anchors_list,
gt_labels,
gt_bboxes,
pad_gt_mask,
bg_index=self.num_classes)
alpha_l = 0.25
else:
......@@ -303,8 +306,10 @@ class TOODHead(nn.Layer):
pred_scores.detach(),
pred_bboxes.detach() * stride_tensor_list,
bbox_center(anchors),
num_anchors_list,
gt_labels,
gt_bboxes,
pad_gt_mask,
bg_index=self.num_classes)
alpha_l = -1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册