Unverified commit dfc40ee0, authored by shangliang Xu, committed by GitHub

add DETR (#3690)

Parent: fdf98755
architecture: DETR
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
hidden_dim: 256

DETR:
  backbone: ResNet
  transformer: DETRTransformer
  detr_head: DETRHead
  post_process: DETRBBoxPostProcess

ResNet:
  # index 0 stands for res2
  depth: 50
  norm_type: bn
  freeze_at: 0
  return_idx: [3]
  lr_mult_list: [0.0, 0.1, 0.1, 0.1]
  num_stages: 4

DETRTransformer:
  num_queries: 100
  position_embed_type: sine
  nhead: 8
  num_encoder_layers: 6
  num_decoder_layers: 6
  dim_feedforward: 2048
  dropout: 0.1
  activation: relu

DETRHead:
  num_mlp_layers: 3

DETRLoss:
  loss_coeff: {class: 1, bbox: 5, giou: 2, no_object: 0.1, mask: 1, dice: 1}
  aux_loss: True

HungarianMatcher:
  matcher_coeff: {class: 1, bbox: 5, giou: 2}
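
The HungarianMatcher section above sets the relative weights of the classification, L1 box, and GIoU costs used to assign each of the 100 object queries to at most one ground-truth box. Below is a minimal NumPy sketch of that bipartite matching; it is illustrative only (function names, shapes, and the plain softmax-probability class cost are assumptions, not PaddleDetection's HungarianMatcher), but the cost combination mirrors the class/bbox/giou coefficients in the config.

import numpy as np
from scipy.optimize import linear_sum_assignment


def box_cxcywh_to_xyxy(b):
    cx, cy, w, h = b[:, 0], b[:, 1], b[:, 2], b[:, 3]
    return np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=-1)


def pairwise_giou(a, b):
    # generalized IoU between every box in a ([Q, 4]) and b ([G, 4]), cxcywh format
    a, b = box_cxcywh_to_xyxy(a), box_cxcywh_to_xyxy(b)
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    lt = np.maximum(a[:, None, :2], b[None, :, :2])
    rb = np.minimum(a[:, None, 2:], b[None, :, 2:])
    wh = np.clip(rb - lt, 0, None)
    inter = wh[..., 0] * wh[..., 1]
    union = area_a[:, None] + area_b[None, :] - inter
    iou = inter / union
    # smallest box enclosing each pair
    wh_c = np.maximum(a[:, None, 2:], b[None, :, 2:]) - np.minimum(
        a[:, None, :2], b[None, :, :2])
    area_c = wh_c[..., 0] * wh_c[..., 1]
    return iou - (area_c - union) / area_c


def hungarian_match(cls_prob, pred_bbox, gt_class, gt_bbox,
                    w_class=1., w_bbox=5., w_giou=2.):
    # cls_prob: [Q, num_classes] scores; gt_class: [G] int labels;
    # pred_bbox / gt_bbox: normalized cxcywh boxes
    cost_class = -cls_prob[:, gt_class]                              # [Q, G]
    cost_bbox = np.abs(pred_bbox[:, None] - gt_bbox[None]).sum(-1)   # [Q, G]
    cost_giou = -pairwise_giou(pred_bbox, gt_bbox)                   # [Q, G]
    cost = w_class * cost_class + w_bbox * cost_bbox + w_giou * cost_giou
    query_idx, gt_idx = linear_sum_assignment(cost)
    return query_idx, gt_idx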
worker_num: 0

TrainReader:
  sample_transforms:
  - Decode: {}
  - RandomFlip: {prob: 0.5}
  - RandomSelect: { transforms1: [ RandomShortSideResize: { short_side_sizes: [ 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800 ], max_size: 1333 } ],
                    transforms2: [
                        RandomShortSideResize: { short_side_sizes: [ 400, 500, 600 ] },
                        RandomSizeCrop: { min_size: 384, max_size: 600 },
                        RandomShortSideResize: { short_side_sizes: [ 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800 ], max_size: 1333 } ]
                  }
  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
  - NormalizeBox: {}
  - BboxXYXY2XYWH: {}
  - Permute: {}
  batch_transforms:
  - PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
  batch_size: 2
  shuffle: true
  drop_last: true
  collate_batch: false
  use_shared_memory: false
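
In the TrainReader above, RandomSelect implements DETR's multi-scale augmentation: each sample either goes through a single random short-side resize, or through a resize / random crop / resize chain. A rough sketch of the selection logic follows; the 0.5 probability is an assumed default and the transform callables are placeholders, not ppdet operators.

import random


def random_select(sample, transforms1, transforms2, p=0.5):
    # apply one of the two transform pipelines, chosen at random per sample
    chosen = transforms1 if random.random() < p else transforms2
    for t in chosen:
        sample = t(sample)
    return sample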
EvalReader:
  sample_transforms:
  - Decode: {}
  - Resize: {target_size: [800, 1333], keep_ratio: True}
  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
  - Permute: {}
  batch_transforms:
  - PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
  batch_size: 1
  shuffle: false
  drop_last: false
  drop_empty: false

TestReader:
  sample_transforms:
  - Decode: {}
  - Resize: {target_size: [800, 1333], keep_ratio: True}
  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
  - Permute: {}
  batch_transforms:
  - PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
  batch_size: 1
  shuffle: false
  drop_last: false
epoch: 500

LearningRate:
  base_lr: 0.0001
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [400]
    use_warmup: false

OptimizerBuilder:
  clip_grad_by_norm: 0.1
  regularizer: false
  optimizer:
    type: AdamW
    weight_decay: 0.0001
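
Taken together, this schedule trains for 500 epochs with AdamW at a constant 1e-4 learning rate (scaled down on the backbone via lr_mult_list), drops it by gamma=0.1 at epoch 400 with no warmup, and clips gradients by norm at 0.1. A tiny illustration of the resulting learning-rate curve (not PaddleDetection's PiecewiseDecay class, just its effect, with values taken from the config):

def piecewise_lr(epoch, base_lr=0.0001, gamma=0.1, milestones=(400,)):
    # learning rate after applying every milestone that has been passed
    lr = base_lr
    for m in milestones:
        if epoch >= m:
            lr *= gamma
    return lr


print(piecewise_lr(10))   # 0.0001 (before the milestone)
print(piecewise_lr(450))  # ~1e-05 (after the single drop at epoch 400)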
_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  '_base_/optimizer_1x.yml',
  '_base_/detr_r50.yml',
  '_base_/detr_reader.yml',
]

weights: output/detr_r50_1x_coco/model_final
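
The top-level config pulls everything together through _BASE_: the listed files are loaded and merged in order, and keys defined in this file (here, weights) override the merged result. A rough sketch of that resolution, with plain dicts standing in for the parsed YAML files (PaddleDetection's actual config loader is more involved):

def deep_merge(dst, src):
    # recursively overlay src onto dst; later values win
    for k, v in src.items():
        if isinstance(v, dict) and isinstance(dst.get(k), dict):
            deep_merge(dst[k], v)
        else:
            dst[k] = v
    return dst


base_cfgs = [                      # stand-ins for the files listed in _BASE_
    {'epoch': 500, 'LearningRate': {'base_lr': 0.0001}},
    {'architecture': 'DETR', 'DETRTransformer': {'num_queries': 100}},
]
cfg = {}
for base in base_cfgs:
    deep_merge(cfg, base)
deep_merge(cfg, {'weights': 'output/detr_r50_1x_coco/model_final'})
print(cfg['architecture'], cfg['weights'])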
@@ -33,7 +33,7 @@ logger = setup_logger(__name__)

 __all__ = [
     'PadBatch', 'BatchRandomResize', 'Gt2YoloTarget', 'Gt2FCOSTarget',
-    'Gt2TTFTarget', 'Gt2Solov2Target', 'Gt2SparseRCNNTarget'
+    'Gt2TTFTarget', 'Gt2Solov2Target', 'Gt2SparseRCNNTarget', 'PadMaskBatch'
 ]
@@ -764,10 +764,79 @@ class Gt2SparseRCNNTarget(BaseOperator):
             img_whwh = np.array([w, h, w, h], dtype=np.int32)
             sample["img_whwh"] = img_whwh
             if "scale_factor" in sample:
-                sample["scale_factor_wh"] = np.array([sample["scale_factor"][1], sample["scale_factor"][0]],
-                                                     dtype=np.float32)
+                sample["scale_factor_wh"] = np.array(
+                    [sample["scale_factor"][1], sample["scale_factor"][0]],
+                    dtype=np.float32)
                 sample.pop("scale_factor")
             else:
-                sample["scale_factor_wh"] = np.array([1.0, 1.0], dtype=np.float32)
+                sample["scale_factor_wh"] = np.array(
+                    [1.0, 1.0], dtype=np.float32)

         return samples
+
+
+@register_op
+class PadMaskBatch(BaseOperator):
+    """
+    Pad a batch of samples so they can be divisible by a stride.
+    The layout of each image should be 'CHW'.
+    Args:
+        pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure
+            height and width is divisible by `pad_to_stride`.
+        return_pad_mask (bool): If `return_pad_mask = True`, return
+            `pad_mask` for transformer.
+    """
+
+    def __init__(self, pad_to_stride=0, return_pad_mask=False):
+        super(PadMaskBatch, self).__init__()
+        self.pad_to_stride = pad_to_stride
+        self.return_pad_mask = return_pad_mask
+
+    def __call__(self, samples, context=None):
+        """
+        Args:
+            samples (list): a batch of sample, each is dict.
+        """
+        coarsest_stride = self.pad_to_stride
+
+        max_shape = np.array([data['image'].shape for data in samples]).max(
+            axis=0)
+        if coarsest_stride > 0:
+            max_shape[1] = int(
+                np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
+            max_shape[2] = int(
+                np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride)
+
+        for data in samples:
+            im = data['image']
+            im_c, im_h, im_w = im.shape[:]
+            padding_im = np.zeros(
+                (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
+            padding_im[:, :im_h, :im_w] = im
+            data['image'] = padding_im
+            if 'semantic' in data and data['semantic'] is not None:
+                semantic = data['semantic']
+                padding_sem = np.zeros(
+                    (1, max_shape[1], max_shape[2]), dtype=np.float32)
+                padding_sem[:, :im_h, :im_w] = semantic
+                data['semantic'] = padding_sem
+            if 'gt_segm' in data and data['gt_segm'] is not None:
+                gt_segm = data['gt_segm']
+                padding_segm = np.zeros(
+                    (gt_segm.shape[0], max_shape[1], max_shape[2]),
+                    dtype=np.uint8)
+                padding_segm[:, :im_h, :im_w] = gt_segm
+                data['gt_segm'] = padding_segm
+            if self.return_pad_mask:
+                padding_mask = np.zeros(
+                    (max_shape[1], max_shape[2]), dtype=np.float32)
+                padding_mask[:im_h, :im_w] = 1.
+                data['pad_mask'] = padding_mask
+            if 'gt_rbox2poly' in data and data['gt_rbox2poly'] is not None:
+                # ploy to rbox
+                polys = data['gt_rbox2poly']
+                rbox = bbox_utils.poly2rbox(polys)
+                data['gt_rbox'] = rbox
+
+        return samples
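
A quick sanity check of the newly added operator: two CHW images of different sizes are zero-padded to the batch maximum and, because return_pad_mask is true, each sample gains a pad_mask that is 1 over valid pixels and 0 over padding, which is what the DETR transformer consumes. The import path is assumed to be the batch-operators module this commit modifies (ppdet.data.transform.batch_operators).

import numpy as np
from ppdet.data.transform.batch_operators import PadMaskBatch  # assumed path

samples = [
    {'image': np.random.rand(3, 480, 640).astype(np.float32)},
    {'image': np.random.rand(3, 512, 768).astype(np.float32)},
]
op = PadMaskBatch(pad_to_stride=-1, return_pad_mask=True)
samples = op(samples)

print(samples[0]['image'].shape)     # (3, 512, 768) -- padded to the batch max
print(samples[0]['pad_mask'].shape)  # (512, 768)
print(samples[0]['pad_mask'][:480, :640].min())  # 1.0, valid region
print(samples[0]['pad_mask'][480:, :].max())     # 0.0, padded rows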