Unverified commit 6eb4784d authored by: S shangliang Xu committed by: GitHub

[mask dino] add mask_dino model (#7887)

align torch dn code

merge deformable dino mask-dino same code

reset norm attr

fix dino amp training

fix bbox_pred in detr postprocess
Parent 653604c0
......@@ -8,7 +8,7 @@ DETR:
backbone: ResNet
transformer: DeformableTransformer
detr_head: DeformableDETRHead
post_process: DETRBBoxPostProcess
post_process: DETRPostProcess
ResNet:
......@@ -40,7 +40,7 @@ DeformableDETRHead:
DETRLoss:
loss_coeff: {class: 2, bbox: 5, giou: 2, mask: 1, dice: 1}
loss_coeff: {class: 2, bbox: 5, giou: 2}
aux_loss: True
......
......@@ -28,8 +28,6 @@ EvalReader:
- Resize: {target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 1
shuffle: false
drop_last: false
......@@ -41,8 +39,6 @@ TestReader:
- Resize: {target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 1
shuffle: false
drop_last: false
......@@ -7,7 +7,7 @@ DETR:
backbone: ResNet
transformer: DETRTransformer
detr_head: DETRHead
post_process: DETRBBoxPostProcess
post_process: DETRPostProcess
ResNet:
......@@ -36,7 +36,7 @@ DETRHead:
DETRLoss:
loss_coeff: {class: 1, bbox: 5, giou: 2, no_object: 0.1, mask: 1, dice: 1}
loss_coeff: {class: 1, bbox: 5, giou: 2, no_object: 0.1}
aux_loss: True
......
......@@ -28,8 +28,6 @@ EvalReader:
- Resize: {target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 1
shuffle: false
drop_last: false
......@@ -41,8 +39,6 @@ TestReader:
- Resize: {target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 1
shuffle: false
drop_last: false
......@@ -8,7 +8,7 @@ DETR:
backbone: FocalNet
transformer: DINOTransformer
detr_head: DINOHead
post_process: DETRBBoxPostProcess
post_process: DETRPostProcess
FocalNet:
arch: 'focalnet_L_384_22k_fl4'
......@@ -41,5 +41,5 @@ DINOHead:
name: HungarianMatcher
matcher_coeff: {class: 2, bbox: 5, giou: 2}
DETRBBoxPostProcess:
DETRPostProcess:
num_top_queries: 300
......@@ -8,7 +8,7 @@ DETR:
backbone: ResNet
transformer: DINOTransformer
detr_head: DINOHead
post_process: DETRBBoxPostProcess
post_process: DETRPostProcess
ResNet:
# index 0 stands for res2
......@@ -45,5 +45,5 @@ DINOHead:
name: HungarianMatcher
matcher_coeff: {class: 2, bbox: 5, giou: 2}
DETRBBoxPostProcess:
DETRPostProcess:
num_top_queries: 300
......@@ -28,8 +28,6 @@ EvalReader:
- Resize: {target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 1
......@@ -39,6 +37,4 @@ TestReader:
- Resize: {target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 1
......@@ -7,7 +7,7 @@ DETR:
backbone: SwinTransformer
transformer: DINOTransformer
detr_head: DINOHead
post_process: DETRBBoxPostProcess
post_process: DETRPostProcess
SwinTransformer:
arch: 'swin_L_384' # ['swin_T_224', 'swin_S_224', 'swin_B_224', 'swin_L_224', 'swin_B_384', 'swin_L_384']
......@@ -42,5 +42,5 @@ DINOHead:
name: HungarianMatcher
matcher_coeff: {class: 2, bbox: 5, giou: 2}
DETRBBoxPostProcess:
DETRPostProcess:
num_top_queries: 300
......@@ -28,8 +28,6 @@ EvalReader:
- Resize: {target_size: [1184, 2000], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 1
shuffle: false
drop_last: false
......@@ -41,8 +39,6 @@ TestReader:
- Resize: {target_size: [1184, 2000], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 1
shuffle: false
drop_last: false
......@@ -28,8 +28,6 @@ EvalReader:
- Resize: {target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 1
shuffle: false
drop_last: false
......@@ -41,8 +39,6 @@ TestReader:
- Resize: {target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 1
shuffle: false
drop_last: false
......@@ -8,7 +8,7 @@ DETR:
backbone: ResNet
transformer: GroupDINOTransformer
detr_head: DINOHead
post_process: DETRBBoxPostProcess
post_process: DETRPostProcess
ResNet:
# index 0 stands for res2
......@@ -47,7 +47,7 @@ DINOHead:
name: HungarianMatcher
matcher_coeff: {class: 2, bbox: 5, giou: 2}
DETRBBoxPostProcess:
DETRPostProcess:
num_top_queries: 300
dual_queries: True
dual_groups: 10
......@@ -8,7 +8,7 @@ DETR:
neck: SimpleFeaturePyramid
transformer: GroupDINOTransformer
detr_head: DINOHead
post_process: DETRBBoxPostProcess
post_process: DETRPostProcess
VisionTransformer2D:
patch_size: 16
......@@ -62,7 +62,7 @@ DINOHead:
matcher_coeff: {class: 2, bbox: 5, giou: 2}
DETRBBoxPostProcess:
DETRPostProcess:
num_top_queries: 300
dual_queries: True
dual_groups: 10
......@@ -248,7 +248,7 @@ class EvalReader(BaseDataLoader):
batch_transforms=[],
batch_size=1,
shuffle=False,
drop_last=True,
drop_last=False,
num_classes=80,
**kwargs):
super(EvalReader, self).__init__(sample_transforms, batch_transforms,
......
......@@ -1782,56 +1782,110 @@ class RandomScaledCrop(BaseOperator):
"""Resize image and bbox based on long side (with optional random scaling),
then crop or pad image to target size.
Args:
target_dim (int): target size.
target_size (int|list): target size, "hw" format.
scale_range (list): random scale range.
interp (int): interpolation method, default to `cv2.INTER_LINEAR`.
fill_value (float|list|tuple): color value used to fill the canvas,
in RGB order.
"""
def __init__(self,
target_dim=512,
target_size=512,
scale_range=[.1, 2.],
interp=cv2.INTER_LINEAR):
interp=cv2.INTER_LINEAR,
fill_value=(123.675, 116.28, 103.53)):
super(RandomScaledCrop, self).__init__()
self.target_dim = target_dim
assert isinstance(target_size, (
Integral, Sequence)), "target_size must be Integer, List or Tuple"
if isinstance(target_size, Integral):
target_size = [target_size, ] * 2
self.target_size = target_size
self.scale_range = scale_range
self.interp = interp
assert isinstance(fill_value, (Number, Sequence)), \
"fill value must be either float or sequence"
if isinstance(fill_value, Number):
fill_value = (fill_value, ) * 3
if not isinstance(fill_value, tuple):
fill_value = tuple(fill_value)
self.fill_value = fill_value
def apply_image(self, img, output_size, offset_x, offset_y):
th, tw = self.target_size
rh, rw = output_size
img = cv2.resize(
img, (rw, rh), interpolation=self.interp).astype(np.float32)
canvas = np.ones([th, tw, 3], dtype=np.float32)
canvas *= np.array(self.fill_value, dtype=np.float32)
canvas[:min(th, rh), :min(tw, rw)] = \
img[offset_y:offset_y + th, offset_x:offset_x + tw]
return canvas
def apply_bbox(self, gt_bbox, gt_class, scale, offset_x, offset_y):
th, tw = self.target_size
shift_array = np.array(
[
offset_x,
offset_y,
] * 2, dtype=np.float32)
boxes = gt_bbox * scale - shift_array
boxes[:, 0::2] = np.clip(boxes[:, 0::2], 0, tw)
boxes[:, 1::2] = np.clip(boxes[:, 1::2], 0, th)
# filter boxes with no area
area = np.prod(boxes[..., 2:] - boxes[..., :2], axis=1)
valid = (area > 1.).nonzero()[0]
return boxes[valid], gt_class[valid], valid
def apply_segm(self, segms, output_size, offset_x, offset_y, valid=None):
th, tw = self.target_size
rh, rw = output_size
out_segms = []
for segm in segms:
segm = cv2.resize(segm, (rw, rh), interpolation=cv2.INTER_NEAREST)
segm = segm.astype(np.float32)
canvas = np.zeros([th, tw], dtype=segm.dtype)
canvas[:min(th, rh), :min(tw, rw)] = \
segm[offset_y:offset_y + th, offset_x:offset_x + tw]
out_segms.append(canvas)
out_segms = np.stack(out_segms)
return out_segms if valid is None else out_segms[valid]
def apply(self, sample, context=None):
img = sample['image']
h, w = img.shape[:2]
random_scale = np.random.uniform(*self.scale_range)
dim = self.target_dim
random_dim = int(dim * random_scale)
dim_max = max(h, w)
scale = random_dim / dim_max
resize_w = int(w * scale + 0.5)
resize_h = int(h * scale + 0.5)
offset_x = int(max(0, np.random.uniform(0., resize_w - dim)))
offset_y = int(max(0, np.random.uniform(0., resize_h - dim)))
img = cv2.resize(img, (resize_w, resize_h), interpolation=self.interp)
img = np.array(img)
canvas = np.zeros((dim, dim, 3), dtype=img.dtype)
canvas[:min(dim, resize_h), :min(dim, resize_w), :] = img[
offset_y:offset_y + dim, offset_x:offset_x + dim, :]
sample['image'] = canvas
sample['im_shape'] = np.asarray([resize_h, resize_w], dtype=np.float32)
scale_factor = sample['scale_factor']
target_scale_size = [t * random_scale for t in self.target_size]
# Compute actual rescaling applied to image.
scale = min(target_scale_size[0] / h, target_scale_size[1] / w)
output_size = [int(round(h * scale)), int(round(w * scale))]
# get offset
offset_x = int(
max(0, np.random.uniform(0., output_size[1] - self.target_size[1])))
offset_y = int(
max(0, np.random.uniform(0., output_size[0] - self.target_size[0])))
# apply to image
sample['image'] = self.apply_image(img, output_size, offset_x, offset_y)
# apply to bbox
valid = None
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
sample['gt_bbox'], sample['gt_class'], valid = self.apply_bbox(
sample['gt_bbox'], sample['gt_class'], scale, offset_x,
offset_y)
# apply to segm
if 'gt_segm' in sample and len(sample['gt_segm']) > 0:
sample['gt_segm'] = self.apply_segm(sample['gt_segm'], output_size,
offset_x, offset_y, valid)
sample['im_shape'] = np.asarray(output_size, dtype=np.float32)
scale_factor = sample['scale_factor']
sample['scale_factor'] = np.asarray(
[scale_factor[0] * scale, scale_factor[1] * scale],
dtype=np.float32)
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
scale_array = np.array([scale, scale] * 2, dtype=np.float32)
shift_array = np.array([offset_x, offset_y] * 2, dtype=np.float32)
boxes = sample['gt_bbox'] * scale_array - shift_array
boxes = np.clip(boxes, 0, dim - 1)
# filter boxes with no area
area = np.prod(boxes[..., 2:] - boxes[..., :2], axis=1)
valid = (area > 1.).nonzero()[0]
sample['gt_bbox'] = boxes[valid]
sample['gt_class'] = sample['gt_class'][valid]
return sample
......
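For reference, a minimal NumPy sketch (not part of the diff) of the scale/offset arithmetic in the rewritten RandomScaledCrop, assuming target_size=[512, 512] and a sampled random_scale of 1.2:

import numpy as np

h, w = 480, 640                       # original image size
target_size = [512, 512]              # th, tw
random_scale = 1.2                    # drawn from scale_range

# rescale so the whole image fits inside target_size * random_scale
target_scale_size = [t * random_scale for t in target_size]
scale = min(target_scale_size[0] / h, target_scale_size[1] / w)
output_size = [int(round(h * scale)), int(round(w * scale))]  # rh, rw

# random crop offsets; zero when the resized image is smaller than the canvas
offset_x = int(max(0, np.random.uniform(0., output_size[1] - target_size[1])))
offset_y = int(max(0, np.random.uniform(0., output_size[0] - target_size[0])))

# boxes are scaled, shifted and clipped exactly as in apply_bbox()
boxes = np.array([[10., 20., 200., 240.]], dtype=np.float32)
boxes = boxes * scale - np.array([offset_x, offset_y] * 2, dtype=np.float32)
boxes[:, 0::2] = np.clip(boxes[:, 0::2], 0, target_size[1])
boxes[:, 1::2] = np.clip(boxes[:, 1::2], 0, target_size[0])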
......@@ -28,14 +28,15 @@ __all__ = ['DETR']
class DETR(BaseArch):
__category__ = 'architecture'
__inject__ = ['post_process']
__shared__ = ['exclude_post_process']
__shared__ = ['with_mask', 'exclude_post_process']
def __init__(self,
backbone,
transformer='DETRTransformer',
detr_head='DETRHead',
neck=None,
post_process='DETRBBoxPostProcess',
post_process='DETRPostProcess',
with_mask=False,
exclude_post_process=False):
super(DETR, self).__init__()
self.backbone = backbone
......@@ -43,6 +44,7 @@ class DETR(BaseArch):
self.transformer = transformer
self.detr_head = detr_head
self.post_process = post_process
self.with_mask = with_mask
self.exclude_post_process = exclude_post_process
@classmethod
......@@ -95,13 +97,16 @@ class DETR(BaseArch):
else:
preds = self.detr_head(out_transformer, body_feats)
if self.exclude_post_process:
bboxes, logits, masks = preds
return bboxes, logits
bbox, bbox_num, mask = preds
else:
bbox, bbox_num = self.post_process(
preds, self.inputs['im_shape'], self.inputs['scale_factor'])
output = {'bbox': bbox, 'bbox_num': bbox_num}
return output
bbox, bbox_num, mask = self.post_process(
preds, self.inputs['im_shape'], self.inputs['scale_factor'],
paddle.shape(self.inputs['image'])[2:])
output = {'bbox': bbox, 'bbox_num': bbox_num}
if self.with_mask:
output['mask'] = mask
return output
def get_loss(self):
return self._forward()
......
......@@ -443,7 +443,8 @@ class ResNet(nn.Layer):
return_idx=[0, 1, 2, 3],
dcn_v2_stages=[-1],
num_stages=4,
std_senet=False):
std_senet=False,
freeze_stem_only=False):
"""
Residual Network, see https://arxiv.org/abs/1512.03385
......@@ -558,8 +559,9 @@ class ResNet(nn.Layer):
if freeze_at >= 0:
self._freeze_parameters(self.conv1)
for i in range(min(freeze_at + 1, num_stages)):
self._freeze_parameters(self.res_layers[i])
if not freeze_stem_only:
for i in range(min(freeze_at + 1, num_stages)):
self._freeze_parameters(self.res_layers[i])
def _freeze_parameters(self, m):
for p in m.parameters():
......
......@@ -24,7 +24,7 @@ import pycocotools.mask as mask_util
from ..initializer import linear_init_, constant_
from ..transformers.utils import inverse_sigmoid
__all__ = ['DETRHead', 'DeformableDETRHead', 'DINOHead']
__all__ = ['DETRHead', 'DeformableDETRHead', 'DINOHead', 'MaskDINOHead']
class MLP(nn.Layer):
......@@ -459,3 +459,75 @@ class DINOHead(nn.Layer):
dn_meta=dn_meta)
else:
return (dec_out_bboxes[-1], dec_out_logits[-1], None)
@register
class MaskDINOHead(nn.Layer):
__inject__ = ['loss']
def __init__(self, loss='DINOLoss'):
super(MaskDINOHead, self).__init__()
self.loss = loss
def forward(self, out_transformer, body_feats, inputs=None):
(dec_out_logits, dec_out_bboxes, dec_out_masks, enc_out, init_out,
dn_meta) = out_transformer
if self.training:
assert inputs is not None
assert 'gt_bbox' in inputs and 'gt_class' in inputs
assert 'gt_segm' in inputs
if dn_meta is not None:
dn_out_logits, dec_out_logits = paddle.split(
dec_out_logits, dn_meta['dn_num_split'], axis=2)
dn_out_bboxes, dec_out_bboxes = paddle.split(
dec_out_bboxes, dn_meta['dn_num_split'], axis=2)
dn_out_masks, dec_out_masks = paddle.split(
dec_out_masks, dn_meta['dn_num_split'], axis=2)
if init_out is not None:
init_out_logits, init_out_bboxes, init_out_masks = init_out
init_out_logits_dn, init_out_logits = paddle.split(
init_out_logits, dn_meta['dn_num_split'], axis=1)
init_out_bboxes_dn, init_out_bboxes = paddle.split(
init_out_bboxes, dn_meta['dn_num_split'], axis=1)
init_out_masks_dn, init_out_masks = paddle.split(
init_out_masks, dn_meta['dn_num_split'], axis=1)
dec_out_logits = paddle.concat(
[init_out_logits.unsqueeze(0), dec_out_logits])
dec_out_bboxes = paddle.concat(
[init_out_bboxes.unsqueeze(0), dec_out_bboxes])
dec_out_masks = paddle.concat(
[init_out_masks.unsqueeze(0), dec_out_masks])
dn_out_logits = paddle.concat(
[init_out_logits_dn.unsqueeze(0), dn_out_logits])
dn_out_bboxes = paddle.concat(
[init_out_bboxes_dn.unsqueeze(0), dn_out_bboxes])
dn_out_masks = paddle.concat(
[init_out_masks_dn.unsqueeze(0), dn_out_masks])
else:
dn_out_bboxes, dn_out_logits = None, None
dn_out_masks = None
enc_out_logits, enc_out_bboxes, enc_out_masks = enc_out
out_logits = paddle.concat(
[enc_out_logits.unsqueeze(0), dec_out_logits])
out_bboxes = paddle.concat(
[enc_out_bboxes.unsqueeze(0), dec_out_bboxes])
out_masks = paddle.concat(
[enc_out_masks.unsqueeze(0), dec_out_masks])
return self.loss(
out_bboxes,
out_logits,
inputs['gt_bbox'],
inputs['gt_class'],
masks=out_masks,
gt_mask=inputs['gt_segm'],
dn_out_logits=dn_out_logits,
dn_out_bboxes=dn_out_bboxes,
dn_out_masks=dn_out_masks,
dn_meta=dn_meta)
else:
return (dec_out_bboxes[-1], dec_out_logits[-1], dec_out_masks[-1])
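The dn_num_split bookkeeping above separates denoising queries from matching queries along the query axis. A toy sketch with made-up shapes (6 decoder layers, batch 2, 20 denoising + 900 matching queries, 80 classes):

import paddle

dec_out_logits = paddle.rand([6, 2, 20 + 900, 80])
dn_num_split = [20, 900]  # dn_meta['dn_num_split'] in the code above

dn_out_logits, dec_out_logits = paddle.split(
    dec_out_logits, dn_num_split, axis=2)
print(dn_out_logits.shape)   # [6, 2, 20, 80]  -> denoising branch
print(dec_out_logits.shape)  # [6, 2, 900, 80] -> matching branch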
......@@ -1135,6 +1135,7 @@ def _convert_attention_mask(attn_mask, dtype):
"""
return nn.layer.transformer._convert_attention_mask(attn_mask, dtype)
@register
class MultiHeadAttention(nn.Layer):
"""
......@@ -1296,7 +1297,6 @@ class MultiHeadAttention(nn.Layer):
self.dropout,
training=self.training,
mode="upscale_in_train")
out = paddle.matmul(weights, v)
# combine heads
......
......@@ -54,8 +54,8 @@ class DETRLoss(nn.Layer):
use_focal_loss (bool): Use focal loss or not.
"""
super(DETRLoss, self).__init__()
self.num_classes = num_classes
self.num_classes = num_classes
self.matcher = matcher
self.loss_coeff = loss_coeff
self.aux_loss = aux_loss
......@@ -76,8 +76,7 @@ class DETRLoss(nn.Layer):
postfix=""):
# logits: [b, query, num_classes], gt_class: list[[n, 1]]
name_class = "loss_class" + postfix
if logits is None:
return {name_class: paddle.zeros([1])}
target_label = paddle.full(logits.shape[:2], bg_index, dtype='int64')
bs, num_query_objects = target_label.shape
if sum(len(a) for a in gt_class) > 0:
......@@ -101,8 +100,7 @@ class DETRLoss(nn.Layer):
# boxes: [b, query, 4], gt_bbox: list[[n, 4]]
name_bbox = "loss_bbox" + postfix
name_giou = "loss_giou" + postfix
if boxes is None:
return {name_bbox: paddle.zeros([1]), name_giou: paddle.zeros([1])}
loss = dict()
if sum(len(a) for a in gt_bbox) == 0:
loss[name_bbox] = paddle.to_tensor([0.])
......@@ -124,8 +122,7 @@ class DETRLoss(nn.Layer):
# masks: [b, query, h, w], gt_mask: list[[n, H, W]]
name_mask = "loss_mask" + postfix
name_dice = "loss_dice" + postfix
if masks is None:
return {name_mask: paddle.zeros([1]), name_dice: paddle.zeros([1])}
loss = dict()
if sum(len(a) for a in gt_mask) == 0:
loss[name_mask] = paddle.to_tensor([0.])
......@@ -164,20 +161,22 @@ class DETRLoss(nn.Layer):
bg_index,
num_gts,
dn_match_indices=None,
postfix=""):
if boxes is None or logits is None:
return {
"loss_class_aux" + postfix: paddle.paddle.zeros([1]),
"loss_bbox_aux" + postfix: paddle.paddle.zeros([1]),
"loss_giou_aux" + postfix: paddle.paddle.zeros([1])
}
postfix="",
masks=None,
gt_mask=None):
loss_class = []
loss_bbox = []
loss_giou = []
for aux_boxes, aux_logits in zip(boxes, logits):
loss_bbox, loss_giou = [], []
loss_mask, loss_dice = [], []
for i, (aux_boxes, aux_logits) in enumerate(zip(boxes, logits)):
aux_masks = masks[i] if masks is not None else None
if dn_match_indices is None:
match_indices = self.matcher(aux_boxes, aux_logits, gt_bbox,
gt_class)
match_indices = self.matcher(
aux_boxes,
aux_logits,
gt_bbox,
gt_class,
masks=aux_masks,
gt_mask=gt_mask)
else:
match_indices = dn_match_indices
loss_class.append(
......@@ -188,11 +187,19 @@ class DETRLoss(nn.Layer):
num_gts, postfix)
loss_bbox.append(loss_['loss_bbox' + postfix])
loss_giou.append(loss_['loss_giou' + postfix])
if masks is not None and gt_mask is not None:
loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices,
num_gts, postfix)
loss_mask.append(loss_['loss_mask' + postfix])
loss_dice.append(loss_['loss_dice' + postfix])
loss = {
"loss_class_aux" + postfix: paddle.add_n(loss_class),
"loss_bbox_aux" + postfix: paddle.add_n(loss_bbox),
"loss_giou_aux" + postfix: paddle.add_n(loss_giou)
}
if masks is not None and gt_mask is not None:
loss["loss_mask_aux" + postfix] = paddle.add_n(loss_mask)
loss["loss_dice_aux" + postfix] = paddle.add_n(loss_dice)
return loss
def _get_index_updates(self, num_query_objects, target, match_indices):
......@@ -220,6 +227,44 @@ class DETRLoss(nn.Layer):
])
return src_assign, target_assign
def _get_num_gts(self, targets, dtype="float32"):
num_gts = sum(len(a) for a in targets)
num_gts = paddle.to_tensor([num_gts], dtype=dtype)
if paddle.distributed.get_world_size() > 1:
paddle.distributed.all_reduce(num_gts)
num_gts /= paddle.distributed.get_world_size()
num_gts = paddle.clip(num_gts, min=1.)
return num_gts
def _get_prediction_loss(self,
boxes,
logits,
gt_bbox,
gt_class,
masks=None,
gt_mask=None,
postfix="",
dn_match_indices=None,
num_gts=1):
if dn_match_indices is None:
match_indices = self.matcher(
boxes, logits, gt_bbox, gt_class, masks=masks, gt_mask=gt_mask)
else:
match_indices = dn_match_indices
loss = dict()
loss.update(
self._get_loss_class(logits, gt_class, match_indices,
self.num_classes, num_gts, postfix))
loss.update(
self._get_loss_bbox(boxes, gt_bbox, match_indices, num_gts,
postfix))
if masks is not None and gt_mask is not None:
loss.update(
self._get_loss_mask(masks, gt_mask, match_indices, num_gts,
postfix))
return loss
def forward(self,
boxes,
logits,
......@@ -231,48 +276,44 @@ class DETRLoss(nn.Layer):
**kwargs):
r"""
Args:
boxes (Tensor|None): [l, b, query, 4]
logits (Tensor|None): [l, b, query, num_classes]
boxes (Tensor): [l, b, query, 4]
logits (Tensor): [l, b, query, num_classes]
gt_bbox (List(Tensor)): list[[n, 4]]
gt_class (List(Tensor)): list[[n, 1]]
masks (Tensor, optional): [b, query, h, w]
masks (Tensor, optional): [l, b, query, h, w]
gt_mask (List(Tensor), optional): list[[n, H, W]]
postfix (str): postfix of loss name
"""
dn_match_indices = kwargs.get("dn_match_indices", None)
if dn_match_indices is None and (boxes is not None and
logits is not None):
match_indices = self.matcher(boxes[-1].detach(),
logits[-1].detach(), gt_bbox, gt_class)
else:
match_indices = dn_match_indices
num_gts = sum(len(a) for a in gt_bbox)
num_gts = paddle.to_tensor([num_gts], dtype="float32")
if paddle.distributed.get_world_size() > 1:
paddle.distributed.all_reduce(num_gts)
num_gts /= paddle.distributed.get_world_size()
num_gts = paddle.clip(num_gts, min=1.) * kwargs.get("dn_num_group", 1.)
dn_match_indices = kwargs.get("dn_match_indices", None)
num_gts = kwargs.get("num_gts", None)
if num_gts is None:
num_gts = self._get_num_gts(gt_class)
total_loss = dict()
total_loss.update(
self._get_loss_class(logits[
-1] if logits is not None else None, gt_class, match_indices,
self.num_classes, num_gts, postfix))
total_loss.update(
self._get_loss_bbox(boxes[-1] if boxes is not None else None,
gt_bbox, match_indices, num_gts, postfix))
if masks is not None and gt_mask is not None:
total_loss.update(
self._get_loss_mask(masks if masks is not None else None,
gt_mask, match_indices, num_gts, postfix))
total_loss = self._get_prediction_loss(
boxes[-1],
logits[-1],
gt_bbox,
gt_class,
masks=masks[-1] if masks is not None else None,
gt_mask=gt_mask,
postfix=postfix,
dn_match_indices=dn_match_indices,
num_gts=num_gts)
if self.aux_loss:
total_loss.update(
self._get_loss_aux(
boxes[:-1] if boxes is not None else None, logits[:-1]
if logits is not None else None, gt_bbox, gt_class,
self.num_classes, num_gts, dn_match_indices, postfix))
boxes[:-1],
logits[:-1],
gt_bbox,
gt_class,
self.num_classes,
num_gts,
dn_match_indices,
postfix,
masks=masks[:-1] if masks is not None else None,
gt_mask=gt_mask))
return total_loss
......@@ -291,8 +332,9 @@ class DINOLoss(DETRLoss):
dn_out_logits=None,
dn_meta=None,
**kwargs):
total_loss = super(DINOLoss, self).forward(boxes, logits, gt_bbox,
gt_class)
num_gts = self._get_num_gts(gt_class)
total_loss = super(DINOLoss, self).forward(
boxes, logits, gt_bbox, gt_class, num_gts=num_gts)
if dn_meta is not None:
dn_positive_idx, dn_num_group = \
......@@ -300,31 +342,186 @@ class DINOLoss(DETRLoss):
assert len(gt_class) == len(dn_positive_idx)
# denoising match indices
dn_match_indices = []
for i in range(len(gt_class)):
num_gt = len(gt_class[i])
if num_gt > 0:
gt_idx = paddle.arange(end=num_gt, dtype="int64")
gt_idx = gt_idx.unsqueeze(0).tile(
[dn_num_group, 1]).flatten()
assert len(gt_idx) == len(dn_positive_idx[i])
dn_match_indices.append((dn_positive_idx[i], gt_idx))
else:
dn_match_indices.append((paddle.zeros(
[0], dtype="int64"), paddle.zeros(
[0], dtype="int64")))
dn_match_indices = self.get_dn_match_indices(
gt_class, dn_positive_idx, dn_num_group)
# compute denoising training loss
num_gts *= dn_num_group
dn_loss = super(DINOLoss, self).forward(
dn_out_bboxes,
dn_out_logits,
gt_bbox,
gt_class,
postfix="_dn",
dn_match_indices=dn_match_indices,
num_gts=num_gts)
total_loss.update(dn_loss)
else:
dn_match_indices, dn_num_group = None, 1.
total_loss.update(
{k + '_dn': paddle.to_tensor([0.])
for k in total_loss.keys()})
return total_loss
@staticmethod
def get_dn_match_indices(labels, dn_positive_idx, dn_num_group):
dn_match_indices = []
for i in range(len(labels)):
num_gt = len(labels[i])
if num_gt > 0:
gt_idx = paddle.arange(end=num_gt, dtype="int64")
gt_idx = gt_idx.tile([dn_num_group])
assert len(dn_positive_idx[i]) == len(gt_idx)
dn_match_indices.append((dn_positive_idx[i], gt_idx))
else:
dn_match_indices.append((paddle.zeros(
[0], dtype="int64"), paddle.zeros(
[0], dtype="int64")))
return dn_match_indices
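A toy run of get_dn_match_indices (illustrative; the import path is an assumption): with one image holding 3 ground-truth boxes and dn_num_group=2, every GT index is repeated once per denoising group:

import paddle
from ppdet.modeling.losses.detr_loss import DINOLoss  # assumed module path

labels = [paddle.zeros([3, 1], dtype='int64')]        # 3 GT objects
dn_positive_idx = [paddle.arange(6, dtype='int64')]   # positive dn queries
indices = DINOLoss.get_dn_match_indices(labels, dn_positive_idx,
                                        dn_num_group=2)
print(indices[0][1].numpy())  # [0 1 2 0 1 2] -> GT index per dn query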
@register
class MaskDINOLoss(DETRLoss):
__shared__ = ['num_classes', 'use_focal_loss', 'num_sample_points']
__inject__ = ['matcher']
def __init__(self,
num_classes=80,
matcher='HungarianMatcher',
loss_coeff={
'class': 4,
'bbox': 5,
'giou': 2,
'mask': 5,
'dice': 5
},
aux_loss=True,
use_focal_loss=False,
num_sample_points=12544,
oversample_ratio=3.0,
important_sample_ratio=0.75):
super(MaskDINOLoss, self).__init__(num_classes, matcher, loss_coeff,
aux_loss, use_focal_loss)
assert oversample_ratio >= 1
assert important_sample_ratio <= 1 and important_sample_ratio >= 0
# compute denoising training loss
dn_loss = super(DINOLoss, self).forward(
dn_out_bboxes,
dn_out_logits,
self.num_sample_points = num_sample_points
self.oversample_ratio = oversample_ratio
self.important_sample_ratio = important_sample_ratio
self.num_oversample_points = int(num_sample_points * oversample_ratio)
self.num_important_points = int(num_sample_points *
important_sample_ratio)
self.num_random_points = num_sample_points - self.num_important_points
def forward(self,
boxes,
logits,
gt_bbox,
gt_class,
masks=None,
gt_mask=None,
postfix="",
dn_out_bboxes=None,
dn_out_logits=None,
dn_out_masks=None,
dn_meta=None,
**kwargs):
num_gts = self._get_num_gts(gt_class)
total_loss = super(MaskDINOLoss, self).forward(
boxes,
logits,
gt_bbox,
gt_class,
postfix="_dn",
dn_match_indices=dn_match_indices,
dn_num_group=dn_num_group)
total_loss.update(dn_loss)
masks=masks,
gt_mask=gt_mask,
num_gts=num_gts)
if dn_meta is not None:
dn_positive_idx, dn_num_group = \
dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
assert len(gt_class) == len(dn_positive_idx)
# denoising match indices
dn_match_indices = DINOLoss.get_dn_match_indices(
gt_class, dn_positive_idx, dn_num_group)
# compute denoising training loss
num_gts *= dn_num_group
dn_loss = super(MaskDINOLoss, self).forward(
dn_out_bboxes,
dn_out_logits,
gt_bbox,
gt_class,
masks=dn_out_masks,
gt_mask=gt_mask,
postfix="_dn",
dn_match_indices=dn_match_indices,
num_gts=num_gts)
total_loss.update(dn_loss)
else:
total_loss.update(
{k + '_dn': paddle.to_tensor([0.])
for k in total_loss.keys()})
return total_loss
def _get_loss_mask(self, masks, gt_mask, match_indices, num_gts,
postfix=""):
# masks: [b, query, h, w], gt_mask: list[[n, H, W]]
name_mask = "loss_mask" + postfix
name_dice = "loss_dice" + postfix
loss = dict()
if sum(len(a) for a in gt_mask) == 0:
loss[name_mask] = paddle.to_tensor([0.])
loss[name_dice] = paddle.to_tensor([0.])
return loss
src_masks, target_masks = self._get_src_target_assign(masks, gt_mask,
match_indices)
# sample points
sample_points = self._get_point_coords_by_uncertainty(src_masks)
sample_points = 2.0 * sample_points.unsqueeze(1) - 1.0
src_masks = F.grid_sample(
src_masks.unsqueeze(1), sample_points,
align_corners=False).squeeze([1, 2])
target_masks = F.grid_sample(
target_masks.unsqueeze(1), sample_points,
align_corners=False).squeeze([1, 2]).detach()
loss[name_mask] = self.loss_coeff[
'mask'] * F.binary_cross_entropy_with_logits(
src_masks, target_masks,
reduction='none').mean(1).sum() / num_gts
loss[name_dice] = self.loss_coeff['dice'] * self._dice_loss(
src_masks, target_masks, num_gts)
return loss
def _get_point_coords_by_uncertainty(self, masks):
# Sample points based on their uncertainty.
masks = masks.detach()
num_masks = masks.shape[0]
sample_points = paddle.rand(
[num_masks, 1, self.num_oversample_points, 2])
out_mask = F.grid_sample(
masks.unsqueeze(1), 2.0 * sample_points - 1.0,
align_corners=False).squeeze([1, 2])
out_mask = -paddle.abs(out_mask)
_, topk_ind = paddle.topk(out_mask, self.num_important_points, axis=1)
batch_ind = paddle.arange(end=num_masks, dtype=topk_ind.dtype)
batch_ind = batch_ind.unsqueeze(-1).tile([1, self.num_important_points])
topk_ind = paddle.stack([batch_ind, topk_ind], axis=-1)
sample_points = paddle.gather_nd(sample_points.squeeze(1), topk_ind)
if self.num_random_points > 0:
sample_points = paddle.concat(
[
sample_points,
paddle.rand([num_masks, self.num_random_points, 2])
],
axis=1)
return sample_points
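_get_point_coords_by_uncertainty keeps the oversampled points whose mask logits lie closest to zero (the most uncertain ones) and tops them up with random points. A tiny check of the selection rule on made-up logits:

import paddle

logits = paddle.to_tensor([[3.0, -0.1, 0.05, -4.0]])
uncertainty = -paddle.abs(logits)         # largest where logits are near 0
_, topk_ind = paddle.topk(uncertainty, k=2, axis=1)
print(topk_ind.numpy())                   # [[2 1]] -> the two least certain points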
......@@ -26,7 +26,7 @@ except Exception:
__all__ = [
'BBoxPostProcess', 'MaskPostProcess', 'JDEBBoxPostProcess',
'CenterNetPostProcess', 'DETRBBoxPostProcess', 'SparsePostProcess'
'CenterNetPostProcess', 'DETRPostProcess', 'SparsePostProcess'
]
......@@ -443,8 +443,8 @@ class CenterNetPostProcess(object):
@register
class DETRBBoxPostProcess(object):
__shared__ = ['num_classes', 'use_focal_loss']
class DETRPostProcess(object):
__shared__ = ['num_classes', 'use_focal_loss', 'with_mask']
__inject__ = []
def __init__(self,
......@@ -452,22 +452,39 @@ class DETRBBoxPostProcess(object):
num_top_queries=100,
dual_queries=False,
dual_groups=0,
use_focal_loss=False):
super(DETRBBoxPostProcess, self).__init__()
use_focal_loss=False,
with_mask=False,
mask_threshold=0.5,
use_avg_mask_score=False):
super(DETRPostProcess, self).__init__()
self.num_classes = num_classes
self.num_top_queries = num_top_queries
self.dual_queries = dual_queries
self.dual_groups = dual_groups
self.use_focal_loss = use_focal_loss
self.with_mask = with_mask
self.mask_threshold = mask_threshold
self.use_avg_mask_score = use_avg_mask_score
def __call__(self, head_out, im_shape, scale_factor):
def _mask_postprocess(self, mask_pred, score_pred, index):
mask_score = F.sigmoid(paddle.gather_nd(mask_pred, index))
mask_pred = (mask_score > self.mask_threshold).astype(mask_score.dtype)
if self.use_avg_mask_score:
avg_mask_score = (mask_pred * mask_score).sum([-2, -1]) / (
mask_pred.sum([-2, -1]) + 1e-6)
score_pred *= avg_mask_score
return mask_pred[0].astype('int32'), score_pred
def __call__(self, head_out, im_shape, scale_factor, pad_shape):
"""
Decode the bbox.
Args:
head_out (tuple): bbox_pred, cls_logit and masks of bbox_head output.
im_shape (Tensor): The shape of the input image.
im_shape (Tensor): The shape of the input image without padding.
scale_factor (Tensor): The scale factor of the input image.
pad_shape (Tensor): The shape of the input image with padding.
Returns:
bbox_pred (Tensor): The output prediction with shape [N, 6], including
labels, scores and bboxes. The size of bboxes are corresponding
......@@ -482,11 +499,13 @@ class DETRBBoxPostProcess(object):
bboxes[:, :int(num_queries // (self.dual_groups + 1)), :]
bbox_pred = bbox_cxcywh_to_xyxy(bboxes)
# calculate the original shape of the image
origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
img_h, img_w = paddle.split(origin_shape, 2, axis=-1)
origin_shape = paddle.concat(
[img_w, img_h, img_w, img_h], axis=-1).reshape([-1, 1, 4])
bbox_pred *= origin_shape
# calculate the shape of the image with padding
out_shape = pad_shape / im_shape * origin_shape
out_shape = out_shape.flip(1).tile([1, 2]).unsqueeze(1)
bbox_pred *= out_shape
scores = F.sigmoid(logits) if self.use_focal_loss else F.softmax(
logits)[:, :, :-1]
......@@ -512,6 +531,25 @@ class DETRBBoxPostProcess(object):
index = paddle.stack([batch_ind, index], axis=-1)
bbox_pred = paddle.gather_nd(bbox_pred, index)
mask_pred = None
if self.with_mask:
assert masks is not None
masks = F.interpolate(
masks, scale_factor=4, mode="bilinear", align_corners=False)
# TODO: Support prediction with bs>1.
# remove padding for input image
h, w = im_shape.astype('int32')[0]
masks = masks[..., :h, :w]
# get pred_mask in the original resolution.
img_h = img_h[0].astype('int32')
img_w = img_w[0].astype('int32')
masks = F.interpolate(
masks,
size=(img_h, img_w),
mode="bilinear",
align_corners=False)
mask_pred, scores = self._mask_postprocess(masks, scores, index)
bbox_pred = paddle.concat(
[
labels.unsqueeze(-1).astype('float32'), scores.unsqueeze(-1),
......@@ -519,9 +557,9 @@ class DETRBBoxPostProcess(object):
],
axis=-1)
bbox_num = paddle.to_tensor(
bbox_pred.shape[1], dtype='int32').tile([bbox_pred.shape[0]])
self.num_top_queries, dtype='int32').tile([bbox_pred.shape[0]])
bbox_pred = bbox_pred.reshape([-1, 6])
return bbox_pred, bbox_num
return bbox_pred, bbox_num, mask_pred
@register
......
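An illustrative call of the renamed DETRPostProcess with its new pad_shape argument and three-tuple return, on dummy tensors (the module path and all values here are assumptions for the sketch, not part of this diff):

import paddle
from ppdet.modeling.post_process import DETRPostProcess  # assumed path

bs, nq, nc = 1, 300, 80
preds = (paddle.rand([bs, nq, 4]),    # normalized cxcywh boxes
         paddle.randn([bs, nq, nc]),  # class logits
         None)                        # no masks in this sketch
im_shape = paddle.to_tensor([[800., 1066.]])       # shape before padding
scale_factor = paddle.to_tensor([[1.666, 1.666]])
pad_shape = paddle.to_tensor([800., 1088.])  # paddle.shape(image)[2:] at a real call site

post_process = DETRPostProcess(num_classes=nc, num_top_queries=300,
                               use_focal_loss=True)
bbox_pred, bbox_num, mask_pred = post_process(preds, im_shape,
                                              scale_factor, pad_shape)
# bbox_pred: [300, 6] rows of (label, score, x1, y1, x2, y2); mask_pred: None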
......@@ -19,6 +19,7 @@ from . import position_encoding
from . import deformable_transformer
from . import dino_transformer
from . import group_detr_transformer
from . import mask_dino_transformer
from .detr_transformer import *
from .utils import *
......@@ -28,3 +29,4 @@ from .deformable_transformer import *
from .dino_transformer import *
from .petr_transformer import *
from .group_detr_transformer import *
from .mask_dino_transformer import *
......@@ -167,23 +167,24 @@ class DeformableTransformerEncoderLayer(nn.Layer):
activation="relu",
n_levels=4,
n_points=4,
lr_mult=0.1,
weight_attr=None,
bias_attr=None):
super(DeformableTransformerEncoderLayer, self).__init__()
# self attention
self.self_attn = MSDeformableAttention(d_model, n_head, n_levels,
n_points)
n_points, lr_mult)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm1 = nn.LayerNorm(
d_model, weight_attr=weight_attr, bias_attr=bias_attr)
# ffn
self.linear1 = nn.Linear(d_model, dim_feedforward, weight_attr,
bias_attr)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.activation = getattr(F, activation)
self.dropout2 = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model, weight_attr,
bias_attr)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.dropout3 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(
d_model, weight_attr=weight_attr, bias_attr=bias_attr)
self._reset_parameters()
def _reset_parameters(self):
......@@ -207,10 +208,10 @@ class DeformableTransformerEncoderLayer(nn.Layer):
spatial_shapes,
level_start_index,
src_mask=None,
pos_embed=None):
query_pos_embed=None):
# self attention
src2 = self.self_attn(
self.with_pos_embed(src, pos_embed), reference_points, src,
self.with_pos_embed(src, query_pos_embed), reference_points, src,
spatial_shapes, level_start_index, src_mask)
src = src + self.dropout1(src2)
src = self.norm1(src)
......@@ -243,23 +244,22 @@ class DeformableTransformerEncoder(nn.Layer):
return reference_points
def forward(self,
src,
feat,
spatial_shapes,
level_start_index,
src_mask=None,
pos_embed=None,
feat_mask=None,
query_pos_embed=None,
valid_ratios=None):
output = src
if valid_ratios is None:
valid_ratios = paddle.ones(
[src.shape[0], spatial_shapes.shape[0], 2])
[feat.shape[0], spatial_shapes.shape[0], 2])
reference_points = self.get_reference_points(spatial_shapes,
valid_ratios)
for layer in self.layers:
output = layer(output, reference_points, spatial_shapes,
level_start_index, src_mask, pos_embed)
feat = layer(feat, reference_points, spatial_shapes,
level_start_index, feat_mask, query_pos_embed)
return output
return feat
class DeformableTransformerDecoderLayer(nn.Layer):
......@@ -271,6 +271,7 @@ class DeformableTransformerDecoderLayer(nn.Layer):
activation="relu",
n_levels=4,
n_points=4,
lr_mult=0.1,
weight_attr=None,
bias_attr=None):
super(DeformableTransformerDecoderLayer, self).__init__()
......@@ -278,23 +279,24 @@ class DeformableTransformerDecoderLayer(nn.Layer):
# self attention
self.self_attn = MultiHeadAttention(d_model, n_head, dropout=dropout)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm1 = nn.LayerNorm(
d_model, weight_attr=weight_attr, bias_attr=bias_attr)
# cross attention
self.cross_attn = MSDeformableAttention(d_model, n_head, n_levels,
n_points)
n_points, lr_mult)
self.dropout2 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(
d_model, weight_attr=weight_attr, bias_attr=bias_attr)
# ffn
self.linear1 = nn.Linear(d_model, dim_feedforward, weight_attr,
bias_attr)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.activation = getattr(F, activation)
self.dropout3 = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model, weight_attr,
bias_attr)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.dropout4 = nn.Dropout(dropout)
self.norm3 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(
d_model, weight_attr=weight_attr, bias_attr=bias_attr)
self._reset_parameters()
def _reset_parameters(self):
......@@ -378,7 +380,7 @@ class DeformableTransformer(nn.Layer):
num_queries=300,
position_embed_type='sine',
return_intermediate_dec=True,
backbone_num_channels=[512, 1024, 2048],
in_feats_channel=[512, 1024, 2048],
num_feature_levels=4,
num_encoder_points=4,
num_decoder_points=4,
......@@ -390,12 +392,12 @@ class DeformableTransformer(nn.Layer):
dropout=0.1,
activation="relu",
lr_mult=0.1,
weight_attr=None,
bias_attr=None):
pe_temperature=10000,
pe_offset=-0.5):
super(DeformableTransformer, self).__init__()
assert position_embed_type in ['sine', 'learned'], \
f'ValueError: position_embed_type not supported {position_embed_type}!'
assert len(backbone_num_channels) <= num_feature_levels
assert len(in_feats_channel) <= num_feature_levels
self.hidden_dim = hidden_dim
self.nhead = nhead
......@@ -403,13 +405,13 @@ class DeformableTransformer(nn.Layer):
encoder_layer = DeformableTransformerEncoderLayer(
hidden_dim, nhead, dim_feedforward, dropout, activation,
num_feature_levels, num_encoder_points, weight_attr, bias_attr)
num_feature_levels, num_encoder_points, lr_mult)
self.encoder = DeformableTransformerEncoder(encoder_layer,
num_encoder_layers)
decoder_layer = DeformableTransformerDecoderLayer(
hidden_dim, nhead, dim_feedforward, dropout, activation,
num_feature_levels, num_decoder_points, weight_attr, bias_attr)
num_feature_levels, num_decoder_points)
self.decoder = DeformableTransformerDecoder(
decoder_layer, num_decoder_layers, return_intermediate_dec)
......@@ -424,18 +426,14 @@ class DeformableTransformer(nn.Layer):
bias_attr=ParamAttr(learning_rate=lr_mult))
self.input_proj = nn.LayerList()
for in_channels in backbone_num_channels:
for in_channels in in_feats_channel:
self.input_proj.append(
nn.Sequential(
nn.Conv2D(
in_channels,
hidden_dim,
kernel_size=1,
weight_attr=weight_attr,
bias_attr=bias_attr),
in_channels, hidden_dim, kernel_size=1),
nn.GroupNorm(32, hidden_dim)))
in_channels = backbone_num_channels[-1]
for _ in range(num_feature_levels - len(backbone_num_channels)):
in_channels = in_feats_channel[-1]
for _ in range(num_feature_levels - len(in_feats_channel)):
self.input_proj.append(
nn.Sequential(
nn.Conv2D(
......@@ -443,17 +441,16 @@ class DeformableTransformer(nn.Layer):
hidden_dim,
kernel_size=3,
stride=2,
padding=1,
weight_attr=weight_attr,
bias_attr=bias_attr),
padding=1),
nn.GroupNorm(32, hidden_dim)))
in_channels = hidden_dim
self.position_embedding = PositionEmbedding(
hidden_dim // 2,
temperature=pe_temperature,
normalize=True if position_embed_type == 'sine' else False,
embed_type=position_embed_type,
offset=-0.5)
offset=pe_offset)
self._reset_parameters()
......@@ -469,7 +466,7 @@ class DeformableTransformer(nn.Layer):
@classmethod
def from_config(cls, cfg, input_shape):
return {'backbone_num_channels': [i.channels for i in input_shape], }
return {'in_feats_channel': [i.channels for i in input_shape], }
def forward(self, src_feats, src_mask=None, *args, **kwargs):
srcs = []
......
......@@ -243,6 +243,8 @@ class DETRTransformer(nn.Layer):
dim_feedforward=2048,
dropout=0.1,
activation="relu",
pe_temperature=10000,
pe_offset=0.,
attn_dropout=None,
act_dropout=None,
normalize_before=False):
......@@ -274,8 +276,10 @@ class DETRTransformer(nn.Layer):
self.query_pos_embed = nn.Embedding(num_queries, hidden_dim)
self.position_embedding = PositionEmbedding(
hidden_dim // 2,
temperature=pe_temperature,
normalize=True if position_embed_type == 'sine' else False,
embed_type=position_embed_type)
embed_type=position_embed_type,
offset=pe_offset)
self._reset_parameters()
......
......@@ -31,125 +31,18 @@ from ppdet.core.workspace import register
from ..layers import MultiHeadAttention
from .position_encoding import PositionEmbedding
from ..heads.detr_head import MLP
from .deformable_transformer import MSDeformableAttention
from .deformable_transformer import (MSDeformableAttention,
DeformableTransformerEncoderLayer,
DeformableTransformerEncoder)
from ..initializer import (linear_init_, constant_, xavier_uniform_, normal_,
bias_init_with_prob)
from .utils import (_get_clones, get_valid_ratio,
get_contrastive_denoising_training_group,
get_sine_pos_embed)
get_sine_pos_embed, inverse_sigmoid)
__all__ = ['DINOTransformer']
class DINOTransformerEncoderLayer(nn.Layer):
def __init__(self,
d_model=256,
n_head=8,
dim_feedforward=1024,
dropout=0.,
activation="relu",
n_levels=4,
n_points=4,
weight_attr=None,
bias_attr=None):
super(DINOTransformerEncoderLayer, self).__init__()
# self attention
self.self_attn = MSDeformableAttention(d_model, n_head, n_levels,
n_points, 1.0)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(
d_model,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
# ffn
self.linear1 = nn.Linear(d_model, dim_feedforward, weight_attr,
bias_attr)
self.activation = getattr(F, activation)
self.dropout2 = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model, weight_attr,
bias_attr)
self.dropout3 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(
d_model,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
self._reset_parameters()
def _reset_parameters(self):
linear_init_(self.linear1)
linear_init_(self.linear2)
xavier_uniform_(self.linear1.weight)
xavier_uniform_(self.linear2.weight)
def with_pos_embed(self, tensor, pos):
return tensor if pos is None else tensor + pos
def forward_ffn(self, src):
src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
src = src + self.dropout3(src2)
src = self.norm2(src)
return src
def forward(self,
src,
reference_points,
spatial_shapes,
level_start_index,
src_mask=None,
query_pos_embed=None):
# self attention
src2 = self.self_attn(
self.with_pos_embed(src, query_pos_embed), reference_points, src,
spatial_shapes, level_start_index, src_mask)
src = src + self.dropout1(src2)
src = self.norm1(src)
# ffn
src = self.forward_ffn(src)
return src
class DINOTransformerEncoder(nn.Layer):
def __init__(self, encoder_layer, num_layers):
super(DINOTransformerEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, offset=0.5):
valid_ratios = valid_ratios.unsqueeze(1)
reference_points = []
for i, (H, W) in enumerate(spatial_shapes):
ref_y, ref_x = paddle.meshgrid(
paddle.arange(end=H) + offset, paddle.arange(end=W) + offset)
ref_y = ref_y.flatten().unsqueeze(0) / (valid_ratios[:, :, i, 1] *
H)
ref_x = ref_x.flatten().unsqueeze(0) / (valid_ratios[:, :, i, 0] *
W)
reference_points.append(paddle.stack((ref_x, ref_y), axis=-1))
reference_points = paddle.concat(reference_points, 1).unsqueeze(2)
reference_points = reference_points * valid_ratios
return reference_points
def forward(self,
feat,
spatial_shapes,
level_start_index,
feat_mask=None,
query_pos_embed=None,
valid_ratios=None):
if valid_ratios is None:
valid_ratios = paddle.ones(
[feat.shape[0], spatial_shapes.shape[0], 2])
reference_points = self.get_reference_points(spatial_shapes,
valid_ratios)
for layer in self.layers:
feat = layer(feat, reference_points, spatial_shapes,
level_start_index, feat_mask, query_pos_embed)
return feat
class DINOTransformerDecoderLayer(nn.Layer):
def __init__(self,
d_model=256,
......@@ -159,6 +52,7 @@ class DINOTransformerDecoderLayer(nn.Layer):
activation="relu",
n_levels=4,
n_points=4,
lr_mult=1.0,
weight_attr=None,
bias_attr=None):
super(DINOTransformerDecoderLayer, self).__init__()
......@@ -167,31 +61,23 @@ class DINOTransformerDecoderLayer(nn.Layer):
self.self_attn = MultiHeadAttention(d_model, n_head, dropout=dropout)
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(
d_model,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
d_model, weight_attr=weight_attr, bias_attr=bias_attr)
# cross attention
self.cross_attn = MSDeformableAttention(d_model, n_head, n_levels,
n_points, 1.0)
n_points, lr_mult)
self.dropout2 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(
d_model,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
d_model, weight_attr=weight_attr, bias_attr=bias_attr)
# ffn
self.linear1 = nn.Linear(d_model, dim_feedforward, weight_attr,
bias_attr)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.activation = getattr(F, activation)
self.dropout3 = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model, weight_attr,
bias_attr)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.dropout4 = nn.Dropout(dropout)
self.norm3 = nn.LayerNorm(
d_model,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
d_model, weight_attr=weight_attr, bias_attr=bias_attr)
self._reset_parameters()
def _reset_parameters(self):
......@@ -218,7 +104,10 @@ class DINOTransformerDecoderLayer(nn.Layer):
# self attention
q = k = self.with_pos_embed(tgt, query_pos_embed)
if attn_mask is not None:
attn_mask = attn_mask.astype('bool')
attn_mask = paddle.where(
attn_mask.astype('bool'),
paddle.zeros(attn_mask.shape, tgt.dtype),
paddle.full(attn_mask.shape, float("-inf"), tgt.dtype))
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=attn_mask)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
......@@ -243,16 +132,14 @@ class DINOTransformerDecoder(nn.Layer):
hidden_dim,
decoder_layer,
num_layers,
return_intermediate=True):
weight_attr=None,
bias_attr=None):
super(DINOTransformerDecoder, self).__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.return_intermediate = return_intermediate
self.norm = nn.LayerNorm(
hidden_dim,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
hidden_dim, weight_attr=weight_attr, bias_attr=bias_attr)
def forward(self,
tgt,
......@@ -271,9 +158,10 @@ class DINOTransformerDecoder(nn.Layer):
output = tgt
intermediate = []
inter_ref_bboxes_unact = []
inter_bboxes = []
ref_points = F.sigmoid(ref_points_unact)
for i, layer in enumerate(self.layers):
reference_points_input = F.sigmoid(ref_points_unact).unsqueeze(
reference_points_input = ref_points.detach().unsqueeze(
2) * valid_ratios.tile([1, 1, 2]).unsqueeze(1)
query_pos_embed = get_sine_pos_embed(
reference_points_input[..., 0, :], self.hidden_dim // 2)
......@@ -283,19 +171,13 @@ class DINOTransformerDecoder(nn.Layer):
memory_spatial_shapes, memory_level_start_index,
attn_mask, memory_mask, query_pos_embed)
inter_ref_bbox_unact = bbox_head[i](output) + ref_points_unact
if self.return_intermediate:
intermediate.append(self.norm(output))
inter_ref_bboxes_unact.append(inter_ref_bbox_unact)
ref_points = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(
ref_points.detach()))
ref_points_unact = inter_ref_bbox_unact.detach()
intermediate.append(self.norm(output))
inter_bboxes.append(ref_points)
if self.return_intermediate:
return paddle.stack(intermediate), paddle.stack(
inter_ref_bboxes_unact)
return output, ref_points_unact
return paddle.stack(intermediate), paddle.stack(inter_bboxes)
@register
......@@ -307,8 +189,7 @@ class DINOTransformer(nn.Layer):
hidden_dim=256,
num_queries=900,
position_embed_type='sine',
return_intermediate_dec=True,
backbone_feat_channels=[512, 1024, 2048],
in_feats_channel=[512, 1024, 2048],
num_levels=4,
num_encoder_points=4,
num_decoder_points=4,
......@@ -318,6 +199,7 @@ class DINOTransformer(nn.Layer):
dim_feedforward=1024,
dropout=0.,
activation="relu",
lr_mult=1.0,
pe_temperature=10000,
pe_offset=-0.5,
num_denoising=100,
......@@ -328,7 +210,7 @@ class DINOTransformer(nn.Layer):
super(DINOTransformer, self).__init__()
assert position_embed_type in ['sine', 'learned'], \
f'ValueError: position_embed_type not supported {position_embed_type}!'
assert len(backbone_feat_channels) <= num_levels
assert len(in_feats_channel) <= num_levels
self.hidden_dim = hidden_dim
self.nhead = nhead
......@@ -338,20 +220,23 @@ class DINOTransformer(nn.Layer):
self.eps = eps
self.num_decoder_layers = num_decoder_layers
weight_attr = ParamAttr(regularizer=L2Decay(0.0))
bias_attr = ParamAttr(regularizer=L2Decay(0.0))
# backbone feature projection
self._build_input_proj_layer(backbone_feat_channels)
self._build_input_proj_layer(in_feats_channel, weight_attr, bias_attr)
# Transformer module
encoder_layer = DINOTransformerEncoderLayer(
encoder_layer = DeformableTransformerEncoderLayer(
hidden_dim, nhead, dim_feedforward, dropout, activation, num_levels,
num_encoder_points)
self.encoder = DINOTransformerEncoder(encoder_layer, num_encoder_layers)
num_encoder_points, lr_mult, weight_attr, bias_attr)
self.encoder = DeformableTransformerEncoder(encoder_layer,
num_encoder_layers)
decoder_layer = DINOTransformerDecoderLayer(
hidden_dim, nhead, dim_feedforward, dropout, activation, num_levels,
num_decoder_points)
num_decoder_points, lr_mult, weight_attr, bias_attr)
self.decoder = DINOTransformerDecoder(hidden_dim, decoder_layer,
num_decoder_layers,
return_intermediate_dec)
num_decoder_layers, weight_attr,
bias_attr)
# denoising part
self.denoising_class_embed = nn.Embedding(
......@@ -383,9 +268,7 @@ class DINOTransformer(nn.Layer):
self.enc_output = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(
hidden_dim,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0))))
hidden_dim, weight_attr=weight_attr, bias_attr=bias_attr))
self.enc_score_head = nn.Linear(hidden_dim, num_classes)
self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
# decoder head
......@@ -426,22 +309,25 @@ class DINOTransformer(nn.Layer):
@classmethod
def from_config(cls, cfg, input_shape):
return {'backbone_feat_channels': [i.channels for i in input_shape], }
return {'in_feats_channel': [i.channels for i in input_shape], }
def _build_input_proj_layer(self, backbone_feat_channels):
def _build_input_proj_layer(self,
in_feats_channel,
weight_attr=None,
bias_attr=None):
self.input_proj = nn.LayerList()
for in_channels in backbone_feat_channels:
for in_channels in in_feats_channel:
self.input_proj.append(
nn.Sequential(
('conv', nn.Conv2D(
in_channels, self.hidden_dim, kernel_size=1)),
('norm', nn.GroupNorm(
32,
self.hidden_dim,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0))))))
in_channels = backbone_feat_channels[-1]
for _ in range(self.num_levels - len(backbone_feat_channels)):
in_channels, self.hidden_dim, kernel_size=1)), (
'norm', nn.GroupNorm(
32,
self.hidden_dim,
weight_attr=weight_attr,
bias_attr=bias_attr))))
in_channels = in_feats_channel[-1]
for _ in range(self.num_levels - len(in_feats_channel)):
self.input_proj.append(
nn.Sequential(
('conv', nn.Conv2D(
......@@ -452,8 +338,8 @@ class DINOTransformer(nn.Layer):
padding=1)), ('norm', nn.GroupNorm(
32,
self.hidden_dim,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0))))))
weight_attr=weight_attr,
bias_attr=bias_attr))))
in_channels = self.hidden_dim
def _get_encoder_input(self, feats, pad_mask=None):
......@@ -540,7 +426,7 @@ class DINOTransformer(nn.Layer):
denoising_bbox_unact)
# decoder
inter_feats, inter_ref_bboxes_unact = self.decoder(
inter_feats, inter_bboxes = self.decoder(
target, init_ref_points_unact, memory, spatial_shapes,
level_start_index, self.dec_bbox_head, self.query_pos_head,
valid_ratios, attn_mask, mask_flatten)
......@@ -555,8 +441,7 @@ class DINOTransformer(nn.Layer):
else:
out_bboxes.append(
F.sigmoid(self.dec_bbox_head[i](inter_feats[i]) +
inter_ref_bboxes_unact[i - 1]))
inverse_sigmoid(inter_bboxes[i - 1])))
out_bboxes = paddle.stack(out_bboxes)
out_logits = paddle.stack(out_logits)
......@@ -579,11 +464,8 @@ class DINOTransformer(nn.Layer):
valid_H, valid_W = h, w
grid_y, grid_x = paddle.meshgrid(
paddle.arange(
end=h, dtype=memory.dtype),
paddle.arange(
end=w, dtype=memory.dtype))
grid_xy = paddle.stack([grid_x, grid_y], -1)
paddle.arange(end=h), paddle.arange(end=w))
grid_xy = paddle.stack([grid_x, grid_y], -1).astype(memory.dtype)
valid_WH = paddle.stack([valid_W, valid_H], -1).reshape(
[-1, 1, 1, 2]).astype(grid_xy.dtype)
......@@ -623,7 +505,7 @@ class DINOTransformer(nn.Layer):
_, topk_ind = paddle.topk(
enc_outputs_class.max(-1), self.num_queries, axis=1)
# extract region proposal boxes
batch_ind = paddle.arange(end=bs, dtype=topk_ind.dtype)
batch_ind = paddle.arange(end=bs).astype(topk_ind.dtype)
batch_ind = batch_ind.unsqueeze(-1).tile([1, self.num_queries])
topk_ind = paddle.stack([batch_ind, topk_ind], axis=-1)
reference_points_unact = paddle.gather_nd(enc_outputs_coord_unact,
......
This diff has been collapsed.
......@@ -34,13 +34,19 @@ __all__ = ['HungarianMatcher']
@register
@serializable
class HungarianMatcher(nn.Layer):
__shared__ = ['use_focal_loss']
__shared__ = ['use_focal_loss', 'with_mask', 'num_sample_points']
def __init__(self,
matcher_coeff={'class': 1,
'bbox': 5,
'giou': 2},
matcher_coeff={
'class': 1,
'bbox': 5,
'giou': 2,
'mask': 1,
'dice': 1
},
use_focal_loss=False,
with_mask=False,
num_sample_points=12544,
alpha=0.25,
gamma=2.0):
r"""
......@@ -50,18 +56,28 @@ class HungarianMatcher(nn.Layer):
super(HungarianMatcher, self).__init__()
self.matcher_coeff = matcher_coeff
self.use_focal_loss = use_focal_loss
self.with_mask = with_mask
self.num_sample_points = num_sample_points
self.alpha = alpha
self.gamma = gamma
self.giou_loss = GIoULoss()
def forward(self, boxes, logits, gt_bbox, gt_class):
def forward(self,
boxes,
logits,
gt_bbox,
gt_class,
masks=None,
gt_mask=None):
r"""
Args:
boxes (Tensor): [b, query, 4]
logits (Tensor): [b, query, num_classes]
gt_bbox (List(Tensor)): list[[n, 4]]
gt_class (List(Tensor)): list[[n, 1]]
masks (Tensor|None): [b, query, h, w]
gt_mask (List(Tensor)): list[[n, H, W]]
Returns:
A list of size batch_size, containing tuples of (index_i, index_j) where:
......@@ -72,18 +88,19 @@ class HungarianMatcher(nn.Layer):
"""
bs, num_queries = boxes.shape[:2]
num_gts = sum(len(a) for a in gt_class)
if num_gts == 0:
num_gts = [len(a) for a in gt_class]
if sum(num_gts) == 0:
return [(paddle.to_tensor(
[], dtype=paddle.int64), paddle.to_tensor(
[], dtype=paddle.int64)) for _ in range(bs)]
# We flatten to compute the cost matrices in a batch
# [batch_size * num_queries, num_classes]
logits = logits.detach()
out_prob = F.sigmoid(logits.flatten(
0, 1)) if self.use_focal_loss else F.softmax(logits.flatten(0, 1))
# [batch_size * num_queries, 4]
out_bbox = boxes.flatten(0, 1)
out_bbox = boxes.detach().flatten(0, 1)
# Also concat the target labels and boxes
tgt_ids = paddle.concat(gt_class).flatten()
......@@ -111,11 +128,53 @@ class HungarianMatcher(nn.Layer):
bbox_cxcywh_to_xyxy(tgt_bbox.unsqueeze(0))).squeeze(-1)
# Final cost matrix
C = self.matcher_coeff['class'] * cost_class + self.matcher_coeff['bbox'] * cost_bbox + \
C = self.matcher_coeff['class'] * cost_class + \
self.matcher_coeff['bbox'] * cost_bbox + \
self.matcher_coeff['giou'] * cost_giou
# Compute the mask cost and dice cost
if self.with_mask:
assert masks is not None and gt_mask is not None, \
'Make sure the input has `mask` and `gt_mask`'
# all masks share the same set of points for efficient matching
sample_points = paddle.rand([bs, 1, self.num_sample_points, 2])
sample_points = 2.0 * sample_points - 1.0
out_mask = F.grid_sample(
masks.detach(), sample_points, align_corners=False).squeeze(-2)
out_mask = out_mask.flatten(0, 1)
tgt_mask = paddle.concat(gt_mask).unsqueeze(1)
sample_points = paddle.concat([
a.tile([b, 1, 1, 1]) for a, b in zip(sample_points, num_gts)
if b > 0
])
tgt_mask = F.grid_sample(
tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
with paddle.amp.auto_cast(enable=False):
# binary cross entropy cost
pos_cost_mask = F.binary_cross_entropy_with_logits(
out_mask, paddle.ones_like(out_mask), reduction='none')
neg_cost_mask = F.binary_cross_entropy_with_logits(
out_mask, paddle.zeros_like(out_mask), reduction='none')
cost_mask = paddle.matmul(
pos_cost_mask, tgt_mask, transpose_y=True) + paddle.matmul(
neg_cost_mask, 1 - tgt_mask, transpose_y=True)
cost_mask /= self.num_sample_points
# dice cost
out_mask = F.sigmoid(out_mask)
numerator = 2 * paddle.matmul(
out_mask, tgt_mask, transpose_y=True)
denominator = out_mask.sum(
-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
cost_dice = 1 - (numerator + 1) / (denominator + 1)
C = C + self.matcher_coeff['mask'] * cost_mask + \
self.matcher_coeff['dice'] * cost_dice
C = C.reshape([bs, num_queries, -1])
C = [a.squeeze(0) for a in C.chunk(bs)]
sizes = [a.shape[0] for a in gt_bbox]
indices = [
linear_sum_assignment(c.split(sizes, -1)[i].numpy())
......
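The mask and dice costs above are computed pairwise over a shared set of sampled points with single matmuls. A toy sketch of just the dice step, for 2 predicted and 3 ground-truth masks sampled at 4 points:

import paddle
import paddle.nn.functional as F

out_mask = F.sigmoid(paddle.randn([2, 4]))  # [num_queries, num_points]
tgt_mask = paddle.rand([3, 4])              # [num_gts, num_points]

numerator = 2 * paddle.matmul(out_mask, tgt_mask, transpose_y=True)
denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
cost_dice = 1 - (numerator + 1) / (denominator + 1)
print(cost_dice.shape)  # [2, 3] -> one dice cost per (query, gt) pair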
......@@ -63,9 +63,9 @@ def sigmoid_focal_loss(logit, label, normalizer=1.0, alpha=0.25, gamma=2.0):
return loss.mean(1).sum() / normalizer
def inverse_sigmoid(x, eps=1e-6):
def inverse_sigmoid(x, eps=1e-5):
x = x.clip(min=0., max=1.)
return paddle.log(x / (1 - x + eps) + eps)
return paddle.log(x.clip(min=eps) / (1 - x).clip(min=eps))
def deformable_attention_core_func(value, value_spatial_shapes,
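A quick numeric check (sketch, not part of the diff) of the updated inverse_sigmoid: clipping both numerator and denominator keeps the endpoints finite at about +/-11.5 for eps=1e-5, while sigmoid still approximately inverts it in between:

import paddle
import paddle.nn.functional as F
from ppdet.modeling.transformers.utils import inverse_sigmoid

x = paddle.to_tensor([0.0, 0.25, 0.5, 0.75, 1.0])
y = inverse_sigmoid(x)
print(y.numpy())             # finite values, no inf at the endpoints
print(F.sigmoid(y).numpy())  # approximately recovers x (endpoints saturate at eps)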
......@@ -122,6 +122,99 @@ def get_valid_ratio(mask):
return paddle.stack([valid_ratio_w, valid_ratio_h], -1)
def get_denoising_training_group(targets,
num_classes,
num_queries,
class_embed,
num_denoising=100,
label_noise_ratio=0.5,
box_noise_scale=1.0):
if num_denoising <= 0:
return None, None, None, None
num_gts = [len(t) for t in targets["gt_class"]]
max_gt_num = max(num_gts)
if max_gt_num == 0:
return None, None, None, None
num_group = num_denoising // max_gt_num
num_group = 1 if num_group == 0 else num_group
# pad gt to max_num of a batch
bs = len(targets["gt_class"])
input_query_class = paddle.full(
[bs, max_gt_num], num_classes, dtype='int32')
input_query_bbox = paddle.zeros([bs, max_gt_num, 4])
pad_gt_mask = paddle.zeros([bs, max_gt_num])
for i in range(bs):
num_gt = num_gts[i]
if num_gt > 0:
input_query_class[i, :num_gt] = targets["gt_class"][i].squeeze(-1)
input_query_bbox[i, :num_gt] = targets["gt_bbox"][i]
pad_gt_mask[i, :num_gt] = 1
input_query_class = input_query_class.tile([1, num_group])
input_query_bbox = input_query_bbox.tile([1, num_group, 1])
pad_gt_mask = pad_gt_mask.tile([1, num_group])
dn_positive_idx = paddle.nonzero(pad_gt_mask)[:, 1]
dn_positive_idx = paddle.split(dn_positive_idx,
[n * num_group for n in num_gts])
# total denoising queries
num_denoising = int(max_gt_num * num_group)
if label_noise_ratio > 0:
input_query_class = input_query_class.flatten()
pad_gt_mask = pad_gt_mask.flatten()
# half of bbox prob
mask = paddle.rand(input_query_class.shape) < (label_noise_ratio * 0.5)
chosen_idx = paddle.nonzero(mask * pad_gt_mask).squeeze(-1)
# randomly put a new one here
new_label = paddle.randint_like(
chosen_idx, 0, num_classes, dtype=input_query_class.dtype)
input_query_class.scatter_(chosen_idx, new_label)
input_query_class.reshape_([bs, num_denoising])
pad_gt_mask.reshape_([bs, num_denoising])
if box_noise_scale > 0:
diff = paddle.concat(
[input_query_bbox[..., 2:] * 0.5, input_query_bbox[..., 2:]],
axis=-1) * box_noise_scale
diff *= (paddle.rand(input_query_bbox.shape) * 2.0 - 1.0)
input_query_bbox += diff
input_query_bbox = inverse_sigmoid(input_query_bbox)
class_embed = paddle.concat(
[class_embed, paddle.zeros([1, class_embed.shape[-1]])])
input_query_class = paddle.gather(
class_embed, input_query_class.flatten(),
axis=0).reshape([bs, num_denoising, -1])
tgt_size = num_denoising + num_queries
attn_mask = paddle.ones([tgt_size, tgt_size]) < 0
# match query cannot see the reconstruction
attn_mask[num_denoising:, :num_denoising] = True
# reconstruct cannot see each other
for i in range(num_group):
if i == 0:
attn_mask[max_gt_num * i:max_gt_num * (i + 1), max_gt_num * (i + 1):
num_denoising] = True
if i == num_group - 1:
attn_mask[max_gt_num * i:max_gt_num * (i + 1), :max_gt_num *
i] = True
else:
attn_mask[max_gt_num * i:max_gt_num * (i + 1), max_gt_num * (i + 1):
num_denoising] = True
attn_mask[max_gt_num * i:max_gt_num * (i + 1), :max_gt_num *
i] = True
attn_mask = ~attn_mask
dn_meta = {
"dn_positive_idx": dn_positive_idx,
"dn_num_group": num_group,
"dn_num_split": [num_denoising, num_queries]
}
return input_query_class, input_query_bbox, attn_mask, dn_meta
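The block structure of the denoising attention mask can be seen with a toy configuration; the two branches in the group loop above collapse into the same two assignments (empty slices are no-ops), and the function finally inverts the mask (~attn_mask), so True ends up marking positions that may attend:

import numpy as np

max_gt_num, num_group, num_queries = 2, 2, 3
num_denoising = max_gt_num * num_group
tgt_size = num_denoising + num_queries
attn_mask = np.zeros([tgt_size, tgt_size], dtype=bool)  # True = blocked here
attn_mask[num_denoising:, :num_denoising] = True  # match queries cannot see dn queries
for i in range(num_group):                        # dn groups cannot see each other
    attn_mask[max_gt_num * i:max_gt_num * (i + 1), :max_gt_num * i] = True
    attn_mask[max_gt_num * i:max_gt_num * (i + 1),
              max_gt_num * (i + 1):num_denoising] = True
print(attn_mask.astype(int))  # block-diagonal over dn groups; matching rows see each other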
def get_contrastive_denoising_training_group(targets,
num_classes,
num_queries,
......@@ -204,7 +297,7 @@ def get_contrastive_denoising_training_group(targets,
tgt_size = num_denoising + num_queries
attn_mask = paddle.ones([tgt_size, tgt_size]) < 0
# match query cannot see the reconstruct
# match query cannot see the reconstruction
attn_mask[num_denoising:, :num_denoising] = True
# reconstruct cannot see each other
for i in range(num_group):
......@@ -263,3 +356,42 @@ def get_sine_pos_embed(pos_tensor,
pos_res[0], pos_res[1] = pos_res[1], pos_res[0]
pos_res = paddle.concat(pos_res, axis=2)
return pos_res
def mask_to_box_coordinate(mask,
normalize=False,
format="xyxy",
dtype="float32"):
"""
Compute the bounding boxes around the provided mask.
Args:
mask (Tensor:bool): [b, c, h, w]
Returns:
bbox (Tensor): [b, c, 4]
"""
assert mask.ndim == 4
assert format in ["xyxy", "xywh"]
if mask.sum() == 0:
return paddle.zeros([mask.shape[0], mask.shape[1], 4], dtype=dtype)
h, w = mask.shape[-2:]
y, x = paddle.meshgrid(
paddle.arange(
end=h, dtype=dtype), paddle.arange(
end=w, dtype=dtype))
x_mask = x * mask
x_max = x_mask.flatten(-2).max(-1) + 1
x_min = paddle.where(mask, x_mask,
paddle.to_tensor(1e8)).flatten(-2).min(-1)
y_mask = y * mask
y_max = y_mask.flatten(-2).max(-1) + 1
y_min = paddle.where(mask, y_mask,
paddle.to_tensor(1e8)).flatten(-2).min(-1)
out_bbox = paddle.stack([x_min, y_min, x_max, y_max], axis=-1)
if normalize:
out_bbox /= paddle.to_tensor([w, h, w, h]).astype(dtype)
return out_bbox if format == "xyxy" else bbox_xyxy_to_cxcywh(out_bbox)
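A toy check of mask_to_box_coordinate (sketch): a 2x2 foreground block covering rows/cols 1-2 of a 4x4 mask yields the xyxy box [1, 1, 3, 3], the max side being exclusive because of the +1 above:

import paddle

mask = paddle.zeros([1, 1, 4, 4], dtype='bool')
mask[0, 0, 1:3, 1:3] = True
print(mask_to_box_coordinate(mask).numpy())  # [[[1. 1. 3. 3.]]]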