From dcf97ccda6df6a2efe4f61437c7bcb6bd82f9797 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Wed, 25 Nov 2020 22:34:01 +0800 Subject: [PATCH] adapt new im_shape in Mask R-CNN-FPN (#1760) * adapt new im_shape in Mask R-CNN-FPN * fix training --- configs/_base_/readers/mask_reader.yml | 12 ++-- ppdet/data/transform/operator.py | 20 +++---- ppdet/modeling/architecture/mask_rcnn.py | 6 +- ppdet/modeling/bbox.py | 7 ++- ppdet/modeling/head/mask_head.py | 16 +++--- ppdet/modeling/layers.py | 72 ++++++++++++++---------- ppdet/modeling/post_process.py | 15 ++--- ppdet/py_op/post_process.py | 18 ++++-- 8 files changed, 94 insertions(+), 72 deletions(-) diff --git a/configs/_base_/readers/mask_reader.yml b/configs/_base_/readers/mask_reader.yml index 07159041f..c7296653e 100644 --- a/configs/_base_/readers/mask_reader.yml +++ b/configs/_base_/readers/mask_reader.yml @@ -17,14 +17,14 @@ TrainReader: EvalReader: inputs_def: - fields: ['image', 'im_info', 'im_id'] + fields: ['image', 'im_shape', 'scale_factor', 'im_id'] sample_transforms: - - DecodeImage: {to_rgb: true} - - NormalizeImage: {is_channel_first: false, is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} - - ResizeImage: {interp: 1, max_size: 1333, target_size: 800, use_cv2: true} - - Permute: {channel_first: true, to_bgr: false} + - DecodeOp: {} + - NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - ResizeOp: {interp: 1, target_size: [800, 1333]} + - PermuteOp: {} batch_transforms: - - PadBatch: {pad_to_stride: 32, use_padded_im_info: false, pad_gt: false} + - PadBatchOp: {pad_to_stride: 32, pad_gt: false} batch_size: 1 shuffle: false drop_last: false diff --git a/ppdet/data/transform/operator.py b/ppdet/data/transform/operator.py index 760383e91..9aa691b77 100644 --- a/ppdet/data/transform/operator.py +++ b/ppdet/data/transform/operator.py @@ -141,7 +141,7 @@ class DecodeOp(BaseOperator): "image width.".format(im.shape[1], 
sample['w'])) sample['w'] = im.shape[1] - sample['im_shape'] = np.array(im.shape[:2], dtype=np.int32) + sample['im_shape'] = np.array(im.shape[:2], dtype=np.float32) sample['scale_factor'] = np.array([1., 1.], dtype=np.float32) return sample @@ -666,8 +666,8 @@ class ResizeOp(BaseOperator): im_scale = min(target_size_min / im_size_min, target_size_max / im_size_max) - resize_h = int(im_scale * im_shape[0]) - resize_w = int(im_scale * im_shape[1]) + resize_h = im_scale * float(im_shape[0]) + resize_w = im_scale * float(im_shape[1]) im_scale_x = im_scale im_scale_y = im_scale @@ -678,14 +678,14 @@ class ResizeOp(BaseOperator): im = self.apply_image(sample['image'], [im_scale_x, im_scale_y]) sample['image'] = im - sample['im_shape'] = np.array([resize_h, resize_w], dtype=np.int32) + sample['im_shape'] = np.asarray([resize_h, resize_w], dtype=np.float32) if 'scale_factor' in sample: scale_factor = sample['scale_factor'] - sample['scale_factor'] = np.array( + sample['scale_factor'] = np.asarray( [scale_factor[0] * im_scale_y, scale_factor[1] * im_scale_x], dtype=np.float32) else: - sample['scale_factor'] = np.array( + sample['scale_factor'] = np.asarray( [im_scale_y, im_scale_x], dtype=np.float32) # apply bbox @@ -1397,8 +1397,8 @@ class RandomScaledCropOp(BaseOperator): random_dim = int(dim * random_scale) dim_max = max(h, w) scale = random_dim / dim_max - resize_w = int(round(w * scale)) - resize_h = int(round(h * scale)) + resize_w = w * scale + resize_h = h * scale offset_x = int(max(0, np.random.uniform(0., resize_w - dim))) offset_y = int(max(0, np.random.uniform(0., resize_h - dim))) @@ -1408,9 +1408,9 @@ class RandomScaledCropOp(BaseOperator): canvas[:min(dim, resize_h), :min(dim, resize_w), :] = img[ offset_y:offset_y + dim, offset_x:offset_x + dim, :] sample['image'] = canvas - sample['im_shape'] = np.array([resize_h, resize_w], dtype=np.int32) + sample['im_shape'] = np.asarray([resize_h, resize_w], dtype=np.float32) scale_factor = sample['sacle_factor'] - 
sample['scale_factor'] = np.array( + sample['scale_factor'] = np.asarray( [scale_factor[0] * scale, scale_factor[1] * scale], dtype=np.float32) diff --git a/ppdet/modeling/architecture/mask_rcnn.py b/ppdet/modeling/architecture/mask_rcnn.py index eee788a55..9c542be3f 100644 --- a/ppdet/modeling/architecture/mask_rcnn.py +++ b/ppdet/modeling/architecture/mask_rcnn.py @@ -96,7 +96,8 @@ class MaskRCNN(BaseArch): self.bbox_head_out, rois) # Refine bbox by the output from bbox_head at test stage self.bboxes = self.bbox_post_process(bbox_pred, bboxes, - self.inputs['im_info']) + self.inputs['im_shape'], + self.inputs['scale_factor']) else: # Proposal RoI for Mask branch # bboxes update at training stage only @@ -134,7 +135,8 @@ class MaskRCNN(BaseArch): def get_pred(self, ): mask = self.mask_post_process(self.bboxes, self.mask_head_out, - self.inputs['im_info']) + self.inputs['im_shape'], + self.inputs['scale_factor']) bbox, bbox_num = self.bboxes output = { 'bbox': bbox.numpy(), diff --git a/ppdet/modeling/bbox.py b/ppdet/modeling/bbox.py index 78b7e787a..85dafe4ab 100644 --- a/ppdet/modeling/bbox.py +++ b/ppdet/modeling/bbox.py @@ -93,6 +93,11 @@ class Proposal(object): self.proposal_target_generator = proposal_target_generator def generate_proposal(self, inputs, rpn_head_out, anchor_out): + # TODO: delete im_info + try: + im_shape = inputs['im_info'] + except: + im_shape = inputs['im_shape'] rpn_rois_list = [] rpn_prob_list = [] rpn_rois_num_list = [] @@ -104,7 +109,7 @@ class Proposal(object): bbox_deltas=rpn_delta, anchors=anchor, variances=var, - im_info=inputs['im_info'], + im_shape=im_shape, mode=inputs['mode']) if len(rpn_head_out) == 1: return rpn_rois, rpn_rois_num diff --git a/ppdet/modeling/head/mask_head.py b/ppdet/modeling/head/mask_head.py index 93b9ba6f8..7db51eb18 100644 --- a/ppdet/modeling/head/mask_head.py +++ b/ppdet/modeling/head/mask_head.py @@ -138,7 +138,7 @@ class MaskHead(Layer): return mask_head_out def forward_test(self, - im_info, + 
scale_factor, body_feats, bboxes, bbox_feat, @@ -149,12 +149,14 @@ class MaskHead(Layer): if bbox.shape[0] == 0: mask_head_out = bbox else: - im_info_expand = [] + scale_factor_list = [] for idx, num in enumerate(bbox_num): for n in range(num): - im_info_expand.append(im_info[idx, -1]) - im_info_expand = paddle.concat(im_info_expand) - scaled_bbox = paddle.multiply(bbox[:, 2:], im_info_expand, axis=0) + scale_factor_list.append(scale_factor[idx, 0]) + scale_factor_list = paddle.cast( + paddle.concat(scale_factor_list), 'float32') + scaled_bbox = paddle.multiply( + bbox[:, 2:], scale_factor_list, axis=0) scaled_bboxes = (scaled_bbox, bbox_num) mask_feat = self.mask_feat(body_feats, scaled_bboxes, bbox_feat, mask_index, spatial_scale, stage) @@ -174,8 +176,8 @@ class MaskHead(Layer): mask_head_out = self.forward_train(body_feats, bboxes, bbox_feat, mask_index, spatial_scale, stage) else: - im_info = inputs['im_info'] - mask_head_out = self.forward_test(im_info, body_feats, bboxes, + scale_factor = inputs['scale_factor'] + mask_head_out = self.forward_test(scale_factor, body_feats, bboxes, bbox_feat, mask_index, spatial_scale, stage) return mask_head_out diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py index 15176341e..a5ae2e900 100644 --- a/ppdet/modeling/layers.py +++ b/ppdet/modeling/layers.py @@ -16,8 +16,7 @@ import numpy as np from numbers import Integral import paddle -import paddle.fluid as fluid -from paddle.fluid.dygraph.base import to_variable +from paddle import to_tensor from ppdet.core.workspace import register, serializable from ppdet.py_op.target import generate_rpn_anchor_target, generate_proposal_target, generate_mask_target from ppdet.py_op.post_process import bbox_post_process @@ -86,20 +85,20 @@ class AnchorTargetGeneratorRPN(object): self.batch_size_per_im, self.positive_overlap, self.negative_overlap, self.fg_fraction, self.use_random) - loc_indexes = to_variable(loc_indexes) - score_indexes = to_variable(score_indexes) - 
tgt_labels = to_variable(tgt_labels) - tgt_bboxes = to_variable(tgt_bboxes) - bbox_inside_weights = to_variable(bbox_inside_weights) + loc_indexes = to_tensor(loc_indexes) + score_indexes = to_tensor(score_indexes) + tgt_labels = to_tensor(tgt_labels) + tgt_bboxes = to_tensor(tgt_bboxes) + bbox_inside_weights = to_tensor(bbox_inside_weights) loc_indexes.stop_gradient = True score_indexes.stop_gradient = True tgt_labels.stop_gradient = True - cls_logits = fluid.layers.reshape(x=cls_logits, shape=(-1, )) - bbox_pred = fluid.layers.reshape(x=bbox_pred, shape=(-1, 4)) - pred_cls_logits = fluid.layers.gather(cls_logits, score_indexes) - pred_bbox_pred = fluid.layers.gather(bbox_pred, loc_indexes) + cls_logits = paddle.reshape(x=cls_logits, shape=(-1, )) + bbox_pred = paddle.reshape(x=bbox_pred, shape=(-1, 4)) + pred_cls_logits = paddle.gather(cls_logits, score_indexes) + pred_bbox_pred = paddle.gather(bbox_pred, loc_indexes) return pred_cls_logits, pred_bbox_pred, tgt_labels, tgt_bboxes, bbox_inside_weights @@ -131,22 +130,38 @@ class ProposalGenerator(object): bbox_deltas, anchors, variances, - im_info, + im_shape, mode='train'): pre_nms_top_n = self.train_pre_nms_top_n if mode == 'train' else self.infer_pre_nms_top_n post_nms_top_n = self.train_post_nms_top_n if mode == 'train' else self.infer_post_nms_top_n - rpn_rois, rpn_rois_prob, rpn_rois_num = fluid.layers.generate_proposals( - scores, - bbox_deltas, - im_info, - anchors, - variances, - pre_nms_top_n=pre_nms_top_n, - post_nms_top_n=post_nms_top_n, - nms_thresh=self.nms_thresh, - min_size=self.min_size, - eta=self.eta, - return_rois_num=True) + # TODO delete im_info + if im_shape.shape[1] > 2: + import paddle.fluid as fluid + rpn_rois, rpn_rois_prob, rpn_rois_num = fluid.layers.generate_proposals( + scores, + bbox_deltas, + im_shape, + anchors, + variances, + pre_nms_top_n=pre_nms_top_n, + post_nms_top_n=post_nms_top_n, + nms_thresh=self.nms_thresh, + min_size=self.min_size, + eta=self.eta, + 
return_rois_num=True) + else: + rpn_rois, rpn_rois_prob, rpn_rois_num = ops.generate_proposals( + scores, + bbox_deltas, + im_shape, + anchors, + variances, + pre_nms_top_n=pre_nms_top_n, + post_nms_top_n=post_nms_top_n, + nms_thresh=self.nms_thresh, + min_size=self.min_size, + eta=self.eta, + return_rois_num=True) return rpn_rois, rpn_rois_prob, rpn_rois_num, post_nms_top_n @@ -198,7 +213,7 @@ class ProposalTargetGenerator(object): self.bg_thresh_hi[stage], self.bg_thresh_lo[stage], self.bbox_reg_weights[stage], self.num_classes, self.use_random, self.is_cls_agnostic, self.is_cascade_rcnn) - outs = [to_variable(v) for v in outs] + outs = [to_tensor(v) for v in outs] for v in outs: v.stop_gradient = True return outs @@ -227,7 +242,7 @@ class MaskTargetGenerator(object): rois, rois_num, labels_int32, self.num_classes, self.mask_resolution) - outs = [to_variable(v) for v in outs] + outs = [to_tensor(v) for v in outs] for v in outs: v.stop_gradient = True return outs @@ -260,7 +275,7 @@ class RCNNBox(object): scale_list = [] origin_shape_list = [] for idx in range(self.batch_size): - scale = scale_factor[idx, :] + scale = scale_factor[idx, :][0] rois_num_per_im = rois_num[idx] expand_scale = paddle.expand(scale, [rois_num_per_im, 1]) scale_list.append(expand_scale) @@ -327,7 +342,7 @@ class DecodeClipNms(object): im_info.numpy(), self.keep_top_k, self.score_threshold, self.nms_threshold, self.num_classes) - outs = [to_variable(v) for v in outs] + outs = [to_tensor(v) for v in outs] for v in outs: v.stop_gradient = True return outs @@ -407,7 +422,6 @@ class YOLOBox(object): def __call__(self, yolo_head_out, anchors, im_shape, scale_factor=None): boxes_list = [] scores_list = [] - im_shape = paddle.cast(im_shape, 'float32') if scale_factor is not None: origin_shape = im_shape / scale_factor else: diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py index fab68eeae..24827f510 100644 --- a/ppdet/modeling/post_process.py +++ 
b/ppdet/modeling/post_process.py @@ -17,14 +17,7 @@ class BBoxPostProcess(object): self.nms = nms def __call__(self, head_out, rois, im_shape, scale_factor=None): - # TODO: compatible for im_info - # remove after unify the im_shape. scale_factor - if im_shape.shape[1] > 2: - origin_shape = im_shape[:, :2] - scale_factor = im_shape[:, 2:] - else: - origin_shape = im_shape - bboxes, score = self.decode(head_out, rois, origin_shape, scale_factor) + bboxes, score = self.decode(head_out, rois, im_shape, scale_factor) bbox_pred, bbox_num = self.nms(bboxes, score) return bbox_pred, bbox_num @@ -38,12 +31,12 @@ class MaskPostProcess(object): self.mask_resolution = mask_resolution self.binary_thresh = binary_thresh - def __call__(self, bboxes, mask_head_out, im_info): + def __call__(self, bboxes, mask_head_out, im_shape, scale_factor=None): # TODO: modify related ops for deploying bboxes_np = (i.numpy() for i in bboxes) mask = mask_post_process(bboxes_np, mask_head_out.numpy(), - im_info.numpy(), self.mask_resolution, - self.binary_thresh) + im_shape.numpy(), scale_factor[:, 0].numpy(), + self.mask_resolution, self.binary_thresh) mask = {'mask': mask} return mask diff --git a/ppdet/py_op/post_process.py b/ppdet/py_op/post_process.py index 7ae75bcd6..f303ff401 100755 --- a/ppdet/py_op/post_process.py +++ b/ppdet/py_op/post_process.py @@ -10,7 +10,8 @@ import cv2 def bbox_post_process(bboxes, bbox_prob, bbox_deltas, - im_info, + im_shape, + scale_factor, keep_top_k=100, score_thresh=0.05, nms_thresh=0.5, @@ -27,14 +28,14 @@ def bbox_post_process(bboxes, end_num += box_num boxes = bbox[st_num:end_num, :] # bbox - boxes = boxes / im_info[i][2] # scale + boxes = boxes / scale_factor[i] # scale bbox_delta = bbox_deltas[st_num:end_num, :, :] # bbox delta bbox_delta = np.reshape(bbox_delta, (box_num, -1)) # step1: decode boxes = delta2bbox(bbox_delta, boxes, bbox_reg_weights) # step2: clip - boxes = clip_bbox(boxes, im_info[i][:2] / im_info[i][2]) + boxes = clip_bbox(boxes, 
im_shape[i][:2] / scale_factor[i]) # step3: nms cls_boxes = [[] for _ in range(class_nums)] scores_n = bbox_prob[st_num:end_num, :] @@ -72,7 +73,12 @@ def bbox_post_process(bboxes, @jit -def mask_post_process(bboxes, masks, im_info, resolution=14, binary_thresh=0.5): +def mask_post_process(bboxes, + masks, + im_shape, + scale_factor, + resolution=14, + binary_thresh=0.5): if masks.shape[0] == 0: return masks bbox, bbox_nums = bboxes @@ -93,8 +99,8 @@ def mask_post_process(bboxes, masks, im_info, resolution=14, binary_thresh=0.5): labels_n = labels[st_num:end_num] masks_n = masks[st_num:end_num] - im_h = int(round(im_info[i][0] / im_info[i][2])) - im_w = int(round(im_info[i][1] / im_info[i][2])) + im_h = int(round(im_shape[i][0] / scale_factor[i])) + im_w = int(round(im_shape[i][1] / scale_factor[i])) boxes_n = expand_bbox(boxes_n, scale) boxes_n = boxes_n.astype(np.int32) padded_mask = np.zeros((M + 2, M + 2), dtype=np.float32) -- GitLab