From eba2fca7099a8ae7130bb016a0d91cc762aa3e6d Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Wed, 24 Mar 2021 20:27:30 +0800 Subject: [PATCH] support not pad gt in rcnn model (#2411) * support not pad gt in rcnn model * add comment in reader --- .../_base_/cascade_fpn_reader.yml | 7 ++- .../_base_/cascade_mask_fpn_reader.yml | 7 ++- .../faster_rcnn/_base_/faster_fpn_reader.yml | 7 ++- configs/faster_rcnn/_base_/faster_reader.yml | 7 ++- configs/mask_rcnn/_base_/mask_fpn_reader.yml | 7 ++- configs/mask_rcnn/_base_/mask_reader.yml | 7 ++- configs/ttfnet/_base_/ttfnet_reader.yml | 2 +- deploy/python/preprocess.py | 4 +- ppdet/data/reader.py | 39 +++++++++---- ppdet/data/transform/batch_operators.py | 56 ++----------------- ppdet/engine/export_utils.py | 3 +- ppdet/modeling/proposal_generator/target.py | 55 ++++++------------ .../proposal_generator/target_layer.py | 2 +- 13 files changed, 80 insertions(+), 123 deletions(-) diff --git a/configs/cascade_rcnn/_base_/cascade_fpn_reader.yml b/configs/cascade_rcnn/_base_/cascade_fpn_reader.yml index cf54ecced..7329be68c 100644 --- a/configs/cascade_rcnn/_base_/cascade_fpn_reader.yml +++ b/configs/cascade_rcnn/_base_/cascade_fpn_reader.yml @@ -7,10 +7,11 @@ TrainReader: - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: 32, pad_gt: true} + - PadBatch: {pad_to_stride: 32} batch_size: 1 shuffle: true drop_last: true + collate_batch: false EvalReader: @@ -20,7 +21,7 @@ EvalReader: - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: 32, pad_gt: false} + - PadBatch: {pad_to_stride: 32} batch_size: 1 shuffle: false drop_last: false @@ -34,7 +35,7 @@ TestReader: - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: 32, pad_gt: false} + - PadBatch: {pad_to_stride: 32} batch_size: 1 shuffle: false drop_last: false diff --git a/configs/cascade_rcnn/_base_/cascade_mask_fpn_reader.yml b/configs/cascade_rcnn/_base_/cascade_mask_fpn_reader.yml index cf54ecced..7329be68c 100644 --- a/configs/cascade_rcnn/_base_/cascade_mask_fpn_reader.yml +++ b/configs/cascade_rcnn/_base_/cascade_mask_fpn_reader.yml @@ -7,10 +7,11 @@ TrainReader: - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: 32, pad_gt: true} + - PadBatch: {pad_to_stride: 32} batch_size: 1 shuffle: true drop_last: true + collate_batch: false EvalReader: @@ -20,7 +21,7 @@ EvalReader: - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: 32, pad_gt: false} + - PadBatch: {pad_to_stride: 32} batch_size: 1 shuffle: false drop_last: false @@ -34,7 +35,7 @@ TestReader: - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: 32, pad_gt: false} + - PadBatch: {pad_to_stride: 32} batch_size: 1 shuffle: false drop_last: false diff --git a/configs/faster_rcnn/_base_/faster_fpn_reader.yml b/configs/faster_rcnn/_base_/faster_fpn_reader.yml index cf54ecced..7329be68c 100644 --- a/configs/faster_rcnn/_base_/faster_fpn_reader.yml +++ b/configs/faster_rcnn/_base_/faster_fpn_reader.yml @@ -7,10 +7,11 @@ TrainReader: - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: 32, pad_gt: true} + - PadBatch: {pad_to_stride: 32} batch_size: 1 shuffle: true drop_last: true + collate_batch: false EvalReader: @@ -20,7 +21,7 @@ EvalReader: - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: 32, pad_gt: false} + - PadBatch: {pad_to_stride: 32} batch_size: 1 shuffle: false drop_last: false @@ -34,7 +35,7 @@ TestReader: - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: 32, pad_gt: false} + - PadBatch: {pad_to_stride: 32} batch_size: 1 shuffle: false drop_last: false diff --git a/configs/faster_rcnn/_base_/faster_reader.yml b/configs/faster_rcnn/_base_/faster_reader.yml index c1be1de4c..1f6eaa86d 100644 --- a/configs/faster_rcnn/_base_/faster_reader.yml +++ b/configs/faster_rcnn/_base_/faster_reader.yml @@ -7,10 +7,11 @@ TrainReader: - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: -1., pad_gt: true} + - PadBatch: {pad_to_stride: -1.} batch_size: 1 shuffle: true drop_last: true + collate_batch: false EvalReader: @@ -20,7 +21,7 @@ EvalReader: - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: -1., pad_gt: false} + - PadBatch: {pad_to_stride: -1.} batch_size: 1 shuffle: false drop_last: false @@ -34,7 +35,7 @@ TestReader: - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: -1., pad_gt: false} + - PadBatch: {pad_to_stride: -1.} batch_size: 1 shuffle: false drop_last: false diff --git a/configs/mask_rcnn/_base_/mask_fpn_reader.yml b/configs/mask_rcnn/_base_/mask_fpn_reader.yml index d2cb8ec96..d8b552cc0 100644 --- a/configs/mask_rcnn/_base_/mask_fpn_reader.yml +++ b/configs/mask_rcnn/_base_/mask_fpn_reader.yml @@ -7,10 +7,11 @@ TrainReader: - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: 32, pad_gt: true} + - PadBatch: {pad_to_stride: 32} batch_size: 1 shuffle: true drop_last: true + collate_batch: false EvalReader: sample_transforms: @@ -19,7 +20,7 @@ EvalReader: - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: 32, pad_gt: false} + - PadBatch: {pad_to_stride: 32} batch_size: 1 shuffle: false drop_last: false @@ -33,7 +34,7 @@ TestReader: - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: 32, pad_gt: false} + - PadBatch: {pad_to_stride: 32} batch_size: 1 shuffle: false drop_last: false diff --git a/configs/mask_rcnn/_base_/mask_reader.yml b/configs/mask_rcnn/_base_/mask_reader.yml index 22ef9f44b..277fd617f 100644 --- a/configs/mask_rcnn/_base_/mask_reader.yml +++ b/configs/mask_rcnn/_base_/mask_reader.yml @@ -7,10 +7,11 @@ TrainReader: - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: -1., pad_gt: true} + - PadBatch: {pad_to_stride: -1.} batch_size: 1 shuffle: true drop_last: true + collate_batch: false EvalReader: @@ -20,7 +21,7 @@ EvalReader: - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: -1., pad_gt: false} + - PadBatch: {pad_to_stride: -1.} batch_size: 1 shuffle: false drop_last: false @@ -34,7 +35,7 @@ TestReader: - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - Permute: {} batch_transforms: - - PadBatch: {pad_to_stride: -1., pad_gt: false} + - PadBatch: {pad_to_stride: -1.} batch_size: 1 shuffle: false drop_last: false diff --git a/configs/ttfnet/_base_/ttfnet_reader.yml b/configs/ttfnet/_base_/ttfnet_reader.yml index 5a69c59dd..f9ed6cc57 100644 --- a/configs/ttfnet/_base_/ttfnet_reader.yml +++ b/configs/ttfnet/_base_/ttfnet_reader.yml @@ -8,7 +8,7 @@ TrainReader: - Permute: {} batch_transforms: - Gt2TTFTarget: {down_ratio: 4} - - PadBatch: {pad_to_stride: 32, pad_gt: true} + - PadBatch: {pad_to_stride: 32} batch_size: 12 shuffle: true drop_last: true diff --git a/deploy/python/preprocess.py b/deploy/python/preprocess.py index 371b1172d..3d0c1b9b1 100644 --- a/deploy/python/preprocess.py +++ b/deploy/python/preprocess.py @@ -172,9 +172,9 @@ class Permute(object): class PadStride(object): - """ padding image for model with FPN , instead PadBatch(pad_to_stride, pad_gt) in original config + """ padding image for model with FPN, instead PadBatch(pad_to_stride) in original config Args: - stride (bool): model with FPN need image shape % stride == 0 + stride (bool): model with FPN need image shape % stride == 0 """ def __init__(self, stride=0): diff --git a/ppdet/data/reader.py b/ppdet/data/reader.py index 084ad7b93..126051734 100644 --- a/ppdet/data/reader.py +++ b/ppdet/data/reader.py @@ -64,10 +64,11 @@ class Compose(object): class BatchCompose(Compose): - def __init__(self, transforms, num_classes=80): + def __init__(self, transforms, num_classes=80, collate_batch=True): super(BatchCompose, self).__init__(transforms, num_classes) self.output_fields = mp.Manager().list([]) self.lock = mp.Lock() + self.collate_batch = collate_batch def __call__(self, data): for f in self.transforms_cls: @@ -103,11 +104,27 @@ class BatchCompose(Compose): self.output_fields.append(k) self.lock.release() - data = [[data[i][k] for k in self.output_fields] - for i in range(len(data))] - data = list(zip(*data)) + batch_data = [] + # If set collate_batch=True, all data will collate a batch + # and it will transfor to paddle.tensor. + # If set collate_batch=False, `image`, `im_shape` and + # `scale_factor` will collate a batch, but `gt` data(such as: + # gt_bbox, gt_class, gt_poly.etc.) will not collate a batch + # and it will transfor to list[Tensor] or list[list]. + if self.collate_batch: + data = [[data[i][k] for k in self.output_fields] + for i in range(len(data))] + data = list(zip(*data)) + batch_data = [np.stack(d, axis=0) for d in data] + else: + for k in self.output_fields: + tmp_data = [] + for i in range(len(data)): + tmp_data.append(data[i][k]) + if not 'gt_' in k and not 'is_crowd' in k: + tmp_data = np.stack(tmp_data, axis=0) + batch_data.append(tmp_data) - batch_data = [np.stack(d, axis=0) for d in data] return batch_data @@ -145,6 +162,7 @@ class BaseDataLoader(object): drop_last=False, drop_empty=True, num_classes=80, + collate_batch=True, use_shared_memory=False, **kwargs): # sample transform @@ -152,8 +170,8 @@ class BaseDataLoader(object): sample_transforms, num_classes=num_classes) # batch transfrom - self._batch_transforms = BatchCompose(batch_transforms, num_classes) - + self._batch_transforms = BatchCompose(batch_transforms, num_classes, + collate_batch) self.batch_size = batch_size self.shuffle = shuffle self.drop_last = drop_last @@ -238,10 +256,11 @@ class TrainReader(BaseDataLoader): drop_last=True, drop_empty=True, num_classes=80, + collate_batch=True, **kwargs): - super(TrainReader, self).__init__(sample_transforms, batch_transforms, - batch_size, shuffle, drop_last, - drop_empty, num_classes, **kwargs) + super(TrainReader, self).__init__( + sample_transforms, batch_transforms, batch_size, shuffle, drop_last, + drop_empty, num_classes, collate_batch, **kwargs) @register diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py index bd99c6f93..b749111d3 100644 --- a/ppdet/data/transform/batch_operators.py +++ b/ppdet/data/transform/batch_operators.py @@ -50,10 +50,9 @@ class PadBatch(BaseOperator): height and width is divisible by `pad_to_stride`. """ - def __init__(self, pad_to_stride=0, pad_gt=False): + def __init__(self, pad_to_stride=0): super(PadBatch, self).__init__() self.pad_to_stride = pad_to_stride - self.pad_gt = pad_gt def __call__(self, samples, context=None): """ @@ -70,7 +69,6 @@ class PadBatch(BaseOperator): max_shape[2] = int( np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride) - padding_batch = [] for data in samples: im = data['image'] im_c, im_h, im_w = im.shape[:] @@ -92,55 +90,6 @@ class PadBatch(BaseOperator): padding_segm[:, :im_h, :im_w] = gt_segm data['gt_segm'] = padding_segm - if self.pad_gt: - gt_num = [] - if 'gt_poly' in data and data['gt_poly'] is not None and len(data[ - 'gt_poly']) > 0: - pad_mask = True - else: - pad_mask = False - - if pad_mask: - poly_num = [] - poly_part_num = [] - point_num = [] - for data in samples: - gt_num.append(data['gt_bbox'].shape[0]) - if pad_mask: - poly_num.append(len(data['gt_poly'])) - for poly in data['gt_poly']: - poly_part_num.append(int(len(poly))) - for p_p in poly: - point_num.append(int(len(p_p) / 2)) - gt_num_max = max(gt_num) - - for i, data in enumerate(samples): - gt_box_data = -np.ones([gt_num_max, 4], dtype=np.float32) - gt_class_data = -np.ones([gt_num_max], dtype=np.int32) - is_crowd_data = np.ones([gt_num_max], dtype=np.int32) - - if pad_mask: - poly_num_max = max(poly_num) - poly_part_num_max = max(poly_part_num) - point_num_max = max(point_num) - gt_masks_data = -np.ones( - [poly_num_max, poly_part_num_max, point_num_max, 2], - dtype=np.float32) - - gt_num = data['gt_bbox'].shape[0] - gt_box_data[0:gt_num, :] = data['gt_bbox'] - gt_class_data[0:gt_num] = np.squeeze(data['gt_class']) - is_crowd_data[0:gt_num] = np.squeeze(data['is_crowd']) - if pad_mask: - for j, poly in enumerate(data['gt_poly']): - for k, p_p in enumerate(poly): - pp_np = np.array(p_p).reshape(-1, 2) - gt_masks_data[j, k, :pp_np.shape[0], :] = pp_np - data['gt_poly'] = gt_masks_data - data['gt_bbox'] = gt_box_data - data['gt_class'] = gt_class_data - data['is_crowd'] = is_crowd_data - return samples @@ -585,6 +534,9 @@ class Gt2TTFTarget(BaseOperator): sample['ttf_heatmap'] = heatmap sample['ttf_box_target'] = box_target sample['ttf_reg_weight'] = reg_weight + sample.pop('is_crowd') + sample.pop('gt_class') + sample.pop('gt_bbox') return samples def draw_truncate_gaussian(self, heatmap, center, h_radius, w_radius): diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py index 7fa2403d5..c1b140615 100644 --- a/ppdet/engine/export_utils.py +++ b/ppdet/engine/export_utils.py @@ -56,10 +56,9 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape): preprocess_list.append(p) batch_transforms = reader_cfg.get('batch_transforms', None) if batch_transforms: - methods = [list(bt.keys())[0] for bt in batch_transforms] for bt in batch_transforms: for key, value in bt.items(): - # for deploy/infer, use PadStride(stride) instead PadBatch(pad_to_stride, pad_gt) + # for deploy/infer, use PadStride(stride) instead PadBatch(pad_to_stride) if key == 'PadBatch': preprocess_list.append({ 'type': 'PadStride', diff --git a/ppdet/modeling/proposal_generator/target.py b/ppdet/modeling/proposal_generator/target.py index b66f0d9cd..b4d490a52 100644 --- a/ppdet/modeling/proposal_generator/target.py +++ b/ppdet/modeling/proposal_generator/target.py @@ -139,7 +139,8 @@ def generate_proposal_target(rpn_rois, bg_thresh = cascade_iou if is_cascade else bg_thresh for i, rpn_roi in enumerate(rpn_rois): gt_bbox = gt_boxes[i] - gt_class = gt_classes[i] + gt_class = paddle.squeeze(gt_classes[i], axis=-1) + if not is_cascade: bbox = paddle.concat([rpn_roi, gt_bbox]) else: @@ -197,25 +198,6 @@ def sample_bbox(matches, return sampled_inds, sampled_gt_classes -def _strip_pad(gt_polys): - new_gt_polys = [] - for i in range(gt_polys.shape[0]): - gt_segs = [] - for j in range(gt_polys[i].shape[0]): - new_poly = [] - polys = gt_polys[i][j] - for ii in range(polys.shape[0]): - x, y = polys[ii] - if (x == -1 and y == -1): - continue - elif (x >= 0 or y >= 0): - new_poly.extend([x, y]) # array, one poly - if len(new_poly) > 6: - gt_segs.append(np.array(new_poly).astype('float64')) - new_gt_polys.append(gt_segs) - return new_gt_polys - - def polygons_to_mask(polygons, height, width): """ Args: @@ -233,8 +215,7 @@ def polygons_to_mask(polygons, height, width): def rasterize_polygons_within_box(poly, box, resolution): w, h = box[2] - box[0], box[3] - box[1] - - polygons = copy.deepcopy(poly) + polygons = [np.asarray(p, dtype=np.float64) for p in poly] for p in polygons: p[0::2] = p[0::2] - box[0] p[1::2] = p[1::2] - box[1] @@ -265,36 +246,36 @@ def generate_mask_target(gt_segms, rois, labels_int32, sampled_gt_inds, mask_index = [] tgt_weights = [] for k in range(len(rois)): - has_fg = True - rois_per_im = rois[k] - gt_segms_per_im = gt_segms[k] labels_per_im = labels_int32[k] fg_inds = paddle.nonzero( paddle.logical_and(labels_per_im != -1, labels_per_im != num_classes)) + has_fg = True if fg_inds.numel() == 0: has_fg = False fg_inds = paddle.ones([1], dtype='int32') - inds_per_im = sampled_gt_inds[k] inds_per_im = paddle.gather(inds_per_im, fg_inds) - gt_segms_per_im = paddle.gather(gt_segms_per_im, inds_per_im) - + rois_per_im = rois[k] fg_rois = paddle.gather(rois_per_im, fg_inds) + boxes = fg_rois.numpy() + gt_segms_per_im = gt_segms[k] + new_segm = [] + inds_per_im = inds_per_im.numpy() + for i in inds_per_im: + new_segm.append(gt_segms_per_im[i]) + fg_inds_new = fg_inds.reshape([-1]).numpy() + results = [] + for j in fg_inds_new: + results.append( + rasterize_polygons_within_box(new_segm[j], boxes[j], + resolution)) + fg_classes = paddle.gather(labels_per_im, fg_inds) - fg_segms = paddle.gather(gt_segms_per_im, fg_inds) weight = paddle.ones([fg_rois.shape[0]], dtype='float32') if not has_fg: weight = weight - 1 - # remove padding - gt_polys = fg_segms.numpy() - boxes = fg_rois.numpy() - new_gt_polys = _strip_pad(gt_polys) - results = [ - rasterize_polygons_within_box(poly, box, resolution) - for poly, box in zip(new_gt_polys, boxes) - ] tgt_mask = paddle.stack(results) tgt_mask.stop_gradient = True fg_rois.stop_gradient = True diff --git a/ppdet/modeling/proposal_generator/target_layer.py b/ppdet/modeling/proposal_generator/target_layer.py index 4586cadf3..1087638b9 100644 --- a/ppdet/modeling/proposal_generator/target_layer.py +++ b/ppdet/modeling/proposal_generator/target_layer.py @@ -41,7 +41,7 @@ class RPNTargetAssign(object): anchor_box (Tensor): [num_anchors, 4], num_anchors are all anchors in all feature maps. """ gt_boxes = inputs['gt_bbox'] - batch_size = gt_boxes.shape[0] + batch_size = len(gt_boxes) tgt_labels, tgt_bboxes, tgt_deltas = rpn_anchor_target( anchors, gt_boxes, self.batch_size_per_im, self.positive_overlap, self.negative_overlap, self.fg_fraction, self.use_random, -- GitLab