From 6ac9743c97e03d86fad4eb9a1777bf6344eed93b Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Tue, 20 Oct 2020 09:51:56 +0800 Subject: [PATCH] Add detection api for Paddle 2.0 (#1575) * migrate detection ops from Paddle to PaddleDetection * add unittest for detection ops * update api & unittest * add copyright & api name to migrate --- ppdet/data/transform/operators.py | 2 +- ppdet/modeling/bbox.py | 3 +- ppdet/modeling/layers.py | 496 ++++++++++++++++++++++++ ppdet/modeling/ops.py | 619 +++++++----------------------- ppdet/modeling/tests/__init__.py | 13 + ppdet/modeling/tests/test_base.py | 73 ++++ ppdet/modeling/tests/test_ops.py | 158 ++++++++ tools/train.py | 1 - 8 files changed, 883 insertions(+), 482 deletions(-) create mode 100644 ppdet/modeling/layers.py create mode 100644 ppdet/modeling/tests/__init__.py create mode 100644 ppdet/modeling/tests/test_base.py create mode 100644 ppdet/modeling/tests/test_ops.py diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py index eb9f287fa..c36b9af09 100644 --- a/ppdet/data/transform/operators.py +++ b/ppdet/data/transform/operators.py @@ -38,7 +38,7 @@ import cv2 from PIL import Image, ImageEnhance, ImageDraw from ppdet.core.workspace import serializable -from ppdet.modeling.ops import AnchorGrid +from ppdet.modeling.layers import AnchorGrid from .op_helper import (satisfy_sample_constraint, filter_and_process, generate_sample_bbox, clip_bbox, data_anchor_sampling, diff --git a/ppdet/modeling/bbox.py b/ppdet/modeling/bbox.py index f60cc65e1..31e03a92e 100644 --- a/ppdet/modeling/bbox.py +++ b/ppdet/modeling/bbox.py @@ -4,6 +4,7 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F from ppdet.core.workspace import register +from . import ops @register @@ -218,7 +219,7 @@ class Proposal(object): start_level = 2 end_level = start_level + len(rpn_head_out) - rois_collect, rois_num_collect = fluid.layers.collect_fpn_proposals( + rois_collect, rois_num_collect = ops.collect_fpn_proposals( rpn_rois_list, rpn_prob_list, start_level, diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py new file mode 100644 index 000000000..0488d8233 --- /dev/null +++ b/ppdet/modeling/layers.py @@ -0,0 +1,496 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
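+
+# Dygraph-oriented detection components migrated here from ppdet/modeling/ops.py:
+# anchor generators, RPN/proposal/mask target generators, RoI feature
+# extraction, and NMS / box-decoding wrappers.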
+ +import numpy as np +from numbers import Integral +import paddle.fluid as fluid +from paddle.fluid.dygraph.base import to_variable +from ppdet.core.workspace import register, serializable +from ppdet.py_op.target import generate_rpn_anchor_target, generate_proposal_target, generate_mask_target +from ppdet.py_op.post_process import bbox_post_process + + +@register +@serializable +class AnchorGeneratorRPN(object): + def __init__(self, + anchor_sizes=[32, 64, 128, 256, 512], + aspect_ratios=[0.5, 1.0, 2.0], + stride=[16.0, 16.0], + variance=[1.0, 1.0, 1.0, 1.0], + anchor_start_size=None): + super(AnchorGeneratorRPN, self).__init__() + self.anchor_sizes = anchor_sizes + self.aspect_ratios = aspect_ratios + self.stride = stride + self.variance = variance + self.anchor_start_size = anchor_start_size + + def __call__(self, input, level=None): + anchor_sizes = self.anchor_sizes if ( + level is None or self.anchor_start_size is None) else ( + self.anchor_start_size * 2**level) + stride = self.stride if ( + level is None or self.anchor_start_size is None) else ( + self.stride[0] * (2.**level), self.stride[1] * (2.**level)) + anchor, var = fluid.layers.anchor_generator( + input=input, + anchor_sizes=anchor_sizes, + aspect_ratios=self.aspect_ratios, + stride=stride, + variance=self.variance) + return anchor, var + + +@register +@serializable +class AnchorTargetGeneratorRPN(object): + def __init__(self, + batch_size_per_im=256, + straddle_thresh=0., + fg_fraction=0.5, + positive_overlap=0.7, + negative_overlap=0.3, + use_random=True): + super(AnchorTargetGeneratorRPN, self).__init__() + self.batch_size_per_im = batch_size_per_im + self.straddle_thresh = straddle_thresh + self.fg_fraction = fg_fraction + self.positive_overlap = positive_overlap + self.negative_overlap = negative_overlap + self.use_random = use_random + + def __call__(self, cls_logits, bbox_pred, anchor_box, gt_boxes, is_crowd, + im_info): + anchor_box = anchor_box.numpy() + gt_boxes = gt_boxes.numpy() + is_crowd = is_crowd.numpy() + im_info = im_info.numpy() + loc_indexes, score_indexes, tgt_labels, tgt_bboxes, bbox_inside_weights = generate_rpn_anchor_target( + anchor_box, gt_boxes, is_crowd, im_info, self.straddle_thresh, + self.batch_size_per_im, self.positive_overlap, + self.negative_overlap, self.fg_fraction, self.use_random) + + loc_indexes = to_variable(loc_indexes) + score_indexes = to_variable(score_indexes) + tgt_labels = to_variable(tgt_labels) + tgt_bboxes = to_variable(tgt_bboxes) + bbox_inside_weights = to_variable(bbox_inside_weights) + + loc_indexes.stop_gradient = True + score_indexes.stop_gradient = True + tgt_labels.stop_gradient = True + + cls_logits = fluid.layers.reshape(x=cls_logits, shape=(-1, )) + bbox_pred = fluid.layers.reshape(x=bbox_pred, shape=(-1, 4)) + pred_cls_logits = fluid.layers.gather(cls_logits, score_indexes) + pred_bbox_pred = fluid.layers.gather(bbox_pred, loc_indexes) + + return pred_cls_logits, pred_bbox_pred, tgt_labels, tgt_bboxes, bbox_inside_weights + + +@register +@serializable +class AnchorGeneratorYOLO(object): + def __init__(self, + anchors=[ + 10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, + 156, 198, 373, 326 + ], + anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]]): + super(AnchorGeneratorYOLO, self).__init__() + self.anchors = anchors + self.anchor_masks = anchor_masks + + def __call__(self): + anchor_num = len(self.anchors) + mask_anchors = [] + for i in range(len(self.anchor_masks)): + mask_anchor = [] + for m in self.anchor_masks[i]: + assert m < anchor_num, "anchor 
mask index overflow" + mask_anchor.extend(self.anchors[2 * m:2 * m + 2]) + mask_anchors.append(mask_anchor) + return self.anchors, self.anchor_masks, mask_anchors + + +@register +@serializable +class AnchorTargetGeneratorYOLO(object): + def __init__(self, + ignore_thresh=0.7, + downsample_ratio=32, + label_smooth=True): + super(AnchorTargetGeneratorYOLO, self).__init__() + self.ignore_thresh = ignore_thresh + self.downsample_ratio = downsample_ratio + self.label_smooth = label_smooth + + def __call__(self, ): + # TODO: split yolov3_loss into here + outs = { + 'ignore_thresh': self.ignore_thresh, + 'downsample_ratio': self.downsample_ratio, + 'label_smooth': self.label_smooth + } + return outs + + +@register +@serializable +class ProposalGenerator(object): + __append_doc__ = True + + def __init__(self, + train_pre_nms_top_n=12000, + train_post_nms_top_n=2000, + infer_pre_nms_top_n=6000, + infer_post_nms_top_n=1000, + nms_thresh=.5, + min_size=.1, + eta=1.): + super(ProposalGenerator, self).__init__() + self.train_pre_nms_top_n = train_pre_nms_top_n + self.train_post_nms_top_n = train_post_nms_top_n + self.infer_pre_nms_top_n = infer_pre_nms_top_n + self.infer_post_nms_top_n = infer_post_nms_top_n + self.nms_thresh = nms_thresh + self.min_size = min_size + self.eta = eta + + def __call__(self, + scores, + bbox_deltas, + anchors, + variances, + im_info, + mode='train'): + pre_nms_top_n = self.train_pre_nms_top_n if mode == 'train' else self.infer_pre_nms_top_n + post_nms_top_n = self.train_post_nms_top_n if mode == 'train' else self.infer_post_nms_top_n + rpn_rois, rpn_rois_prob, rpn_rois_num = fluid.layers.generate_proposals( + scores, + bbox_deltas, + im_info, + anchors, + variances, + pre_nms_top_n=pre_nms_top_n, + post_nms_top_n=post_nms_top_n, + nms_thresh=self.nms_thresh, + min_size=self.min_size, + eta=self.eta, + return_rois_num=True) + return rpn_rois, rpn_rois_prob, rpn_rois_num, post_nms_top_n + + +@register +@serializable +class ProposalTargetGenerator(object): + __shared__ = ['num_classes'] + + def __init__(self, + batch_size_per_im=512, + fg_fraction=.25, + fg_thresh=[.5, ], + bg_thresh_hi=[.5, ], + bg_thresh_lo=[0., ], + bbox_reg_weights=[[0.1, 0.1, 0.2, 0.2]], + num_classes=81, + use_random=True, + is_cls_agnostic=False, + is_cascade_rcnn=False): + super(ProposalTargetGenerator, self).__init__() + self.batch_size_per_im = batch_size_per_im + self.fg_fraction = fg_fraction + self.fg_thresh = fg_thresh + self.bg_thresh_hi = bg_thresh_hi + self.bg_thresh_lo = bg_thresh_lo + self.bbox_reg_weights = bbox_reg_weights + self.num_classes = num_classes + self.use_random = use_random + self.is_cls_agnostic = is_cls_agnostic + self.is_cascade_rcnn = is_cascade_rcnn + + def __call__(self, + rpn_rois, + rpn_rois_num, + gt_classes, + is_crowd, + gt_boxes, + im_info, + stage=0): + rpn_rois = rpn_rois.numpy() + rpn_rois_num = rpn_rois_num.numpy() + gt_classes = gt_classes.numpy() + gt_boxes = gt_boxes.numpy() + is_crowd = is_crowd.numpy() + im_info = im_info.numpy() + outs = generate_proposal_target( + rpn_rois, rpn_rois_num, gt_classes, is_crowd, gt_boxes, im_info, + self.batch_size_per_im, self.fg_fraction, self.fg_thresh[stage], + self.bg_thresh_hi[stage], self.bg_thresh_lo[stage], + self.bbox_reg_weights[stage], self.num_classes, self.use_random, + self.is_cls_agnostic, self.is_cascade_rcnn) + outs = [to_variable(v) for v in outs] + for v in outs: + v.stop_gradient = True + return outs + + +@register +@serializable +class MaskTargetGenerator(object): + __shared__ = ['num_classes', 
'mask_resolution'] + + def __init__(self, num_classes=81, mask_resolution=14): + super(MaskTargetGenerator, self).__init__() + self.num_classes = num_classes + self.mask_resolution = mask_resolution + + def __call__(self, im_info, gt_classes, is_crowd, gt_segms, rois, rois_num, + labels_int32): + im_info = im_info.numpy() + gt_classes = gt_classes.numpy() + is_crowd = is_crowd.numpy() + gt_segms = gt_segms.numpy() + rois = rois.numpy() + rois_num = rois_num.numpy() + labels_int32 = labels_int32.numpy() + outs = generate_mask_target(im_info, gt_classes, is_crowd, gt_segms, + rois, rois_num, labels_int32, + self.num_classes, self.mask_resolution) + + outs = [to_variable(v) for v in outs] + for v in outs: + v.stop_gradient = True + return outs + + +@register +class RoIExtractor(object): + def __init__(self, + resolution=14, + sampling_ratio=0, + canconical_level=4, + canonical_size=224, + start_level=0, + end_level=3): + super(RoIExtractor, self).__init__() + self.resolution = resolution + self.sampling_ratio = sampling_ratio + self.canconical_level = canconical_level + self.canonical_size = canonical_size + self.start_level = start_level + self.end_level = end_level + + def __call__(self, feats, rois, spatial_scale): + roi, rois_num = rois + cur_l = 0 + if self.start_level == self.end_level: + rois_feat = fluid.layers.roi_align( + feats[self.start_level], + roi, + self.resolution, + self.resolution, + spatial_scale, + rois_num=rois_num) + return rois_feat + offset = 2 + k_min = self.start_level + offset + k_max = self.end_level + offset + rois_dist, restore_index, rois_num_dist = fluid.layers.distribute_fpn_proposals( + roi, + k_min, + k_max, + self.canconical_level, + self.canonical_size, + rois_num=rois_num) + + rois_feat_list = [] + for lvl in range(self.start_level, self.end_level + 1): + roi_feat = fluid.layers.roi_align( + feats[lvl], + rois_dist[lvl], + self.resolution, + self.resolution, + spatial_scale[lvl], + sampling_ratio=self.sampling_ratio, + rois_num=rois_num_dist[lvl]) + rois_feat_list.append(roi_feat) + rois_feat_shuffle = fluid.layers.concat(rois_feat_list) + rois_feat = fluid.layers.gather(rois_feat_shuffle, restore_index) + + return rois_feat + + +@register +@serializable +class DecodeClipNms(object): + __shared__ = ['num_classes'] + + def __init__( + self, + num_classes=81, + keep_top_k=100, + score_threshold=0.05, + nms_threshold=0.5, ): + super(DecodeClipNms, self).__init__() + self.num_classes = num_classes + self.keep_top_k = keep_top_k + self.score_threshold = score_threshold + self.nms_threshold = nms_threshold + + def __call__(self, bboxes, bbox_prob, bbox_delta, im_info): + bboxes_np = (i.numpy() for i in bboxes) + # bbox, bbox_num + outs = bbox_post_process(bboxes_np, + bbox_prob.numpy(), + bbox_delta.numpy(), + im_info.numpy(), self.keep_top_k, + self.score_threshold, self.nms_threshold, + self.num_classes) + outs = [to_variable(v) for v in outs] + for v in outs: + v.stop_gradient = True + return outs + + +@register +@serializable +class MultiClassNMS(object): + __op__ = fluid.layers.multiclass_nms + __append_doc__ = True + + def __init__(self, + score_threshold=.05, + nms_top_k=-1, + keep_top_k=100, + nms_threshold=.5, + normalized=False, + nms_eta=1.0, + background_label=0): + super(MultiClassNMS, self).__init__() + self.score_threshold = score_threshold + self.nms_top_k = nms_top_k + self.keep_top_k = keep_top_k + self.nms_threshold = nms_threshold + self.normalized = normalized + self.nms_eta = nms_eta + self.background_label = background_label + + 
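+# Note: the *TargetGenerator classes above run their label assignment in
+# numpy (via the ppdet.py_op helpers) and wrap the results back with
+# to_variable(), setting stop_gradient=True so that no gradient flows
+# through target generation.
+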
+@register +@serializable +class YOLOBox(object): + def __init__( + self, + conf_thresh=0.005, + downsample_ratio=32, + clip_bbox=True, ): + self.conf_thresh = conf_thresh + self.downsample_ratio = downsample_ratio + self.clip_bbox = clip_bbox + + def __call__(self, x, img_size, anchors, num_classes, stage=0): + outs = fluid.layers.yolo_box(x, img_size, anchors, num_classes, + self.conf_thresh, self.downsample_ratio // + 2**stage, self.clip_bbox) + return outs + + +@register +@serializable +class AnchorGrid(object): + """Generate anchor grid + + Args: + image_size (int or list): input image size, may be a single integer or + list of [h, w]. Default: 512 + min_level (int): min level of the feature pyramid. Default: 3 + max_level (int): max level of the feature pyramid. Default: 7 + anchor_base_scale: base anchor scale. Default: 4 + num_scales: number of anchor scales. Default: 3 + aspect_ratios: aspect ratios. default: [[1, 1], [1.4, 0.7], [0.7, 1.4]] + """ + + def __init__(self, + image_size=512, + min_level=3, + max_level=7, + anchor_base_scale=4, + num_scales=3, + aspect_ratios=[[1, 1], [1.4, 0.7], [0.7, 1.4]]): + super(AnchorGrid, self).__init__() + if isinstance(image_size, Integral): + self.image_size = [image_size, image_size] + else: + self.image_size = image_size + for dim in self.image_size: + assert dim % 2 ** max_level == 0, \ + "image size should be multiple of the max level stride" + self.min_level = min_level + self.max_level = max_level + self.anchor_base_scale = anchor_base_scale + self.num_scales = num_scales + self.aspect_ratios = aspect_ratios + + @property + def base_cell(self): + if not hasattr(self, '_base_cell'): + self._base_cell = self.make_cell() + return self._base_cell + + def make_cell(self): + scales = [2**(i / self.num_scales) for i in range(self.num_scales)] + scales = np.array(scales) + ratios = np.array(self.aspect_ratios) + ws = np.outer(scales, ratios[:, 0]).reshape(-1, 1) + hs = np.outer(scales, ratios[:, 1]).reshape(-1, 1) + anchors = np.hstack((-0.5 * ws, -0.5 * hs, 0.5 * ws, 0.5 * hs)) + return anchors + + def make_grid(self, stride): + cell = self.base_cell * stride * self.anchor_base_scale + x_steps = np.arange(stride // 2, self.image_size[1], stride) + y_steps = np.arange(stride // 2, self.image_size[0], stride) + offset_x, offset_y = np.meshgrid(x_steps, y_steps) + offset_x = offset_x.flatten() + offset_y = offset_y.flatten() + offsets = np.stack((offset_x, offset_y, offset_x, offset_y), axis=-1) + offsets = offsets[:, np.newaxis, :] + return (cell + offsets).reshape(-1, 4) + + def generate(self): + return [ + self.make_grid(2**l) + for l in range(self.min_level, self.max_level + 1) + ] + + def __call__(self): + if not hasattr(self, '_anchor_vars'): + anchor_vars = [] + helper = LayerHelper('anchor_grid') + for idx, l in enumerate(range(self.min_level, self.max_level + 1)): + stride = 2**l + anchors = self.make_grid(stride) + var = helper.create_parameter( + attr=ParamAttr(name='anchors_{}'.format(idx)), + shape=anchors.shape, + dtype='float32', + stop_gradient=True, + default_initializer=NumpyArrayInitializer(anchors)) + anchor_vars.append(var) + var.persistable = True + self._anchor_vars = anchor_vars + + return self._anchor_vars diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py index 7f2edbac1..d385b891e 100644 --- a/ppdet/modeling/ops.py +++ b/ppdet/modeling/ops.py @@ -1,482 +1,143 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle + +from paddle.fluid.framework import Variable, in_dygraph_mode +from paddle.fluid import core +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.dygraph import layers +from paddle.fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype +import math +import six import numpy as np -from numbers import Integral -import paddle.fluid as fluid -from paddle.fluid.dygraph.base import to_variable -from ppdet.core.workspace import register, serializable -from ppdet.py_op.target import generate_rpn_anchor_target, generate_proposal_target, generate_mask_target -from ppdet.py_op.post_process import bbox_post_process - - -@register -@serializable -class AnchorGeneratorRPN(object): - def __init__(self, - anchor_sizes=[32, 64, 128, 256, 512], - aspect_ratios=[0.5, 1.0, 2.0], - stride=[16.0, 16.0], - variance=[1.0, 1.0, 1.0, 1.0], - anchor_start_size=None): - super(AnchorGeneratorRPN, self).__init__() - self.anchor_sizes = anchor_sizes - self.aspect_ratios = aspect_ratios - self.stride = stride - self.variance = variance - self.anchor_start_size = anchor_start_size - - def __call__(self, input, level=None): - anchor_sizes = self.anchor_sizes if ( - level is None or self.anchor_start_size is None) else ( - self.anchor_start_size * 2**level) - stride = self.stride if ( - level is None or self.anchor_start_size is None) else ( - self.stride[0] * (2.**level), self.stride[1] * (2.**level)) - anchor, var = fluid.layers.anchor_generator( - input=input, - anchor_sizes=anchor_sizes, - aspect_ratios=self.aspect_ratios, - stride=stride, - variance=self.variance) - return anchor, var - - -@register -@serializable -class AnchorTargetGeneratorRPN(object): - def __init__(self, - batch_size_per_im=256, - straddle_thresh=0., - fg_fraction=0.5, - positive_overlap=0.7, - negative_overlap=0.3, - use_random=True): - super(AnchorTargetGeneratorRPN, self).__init__() - self.batch_size_per_im = batch_size_per_im - self.straddle_thresh = straddle_thresh - self.fg_fraction = fg_fraction - self.positive_overlap = positive_overlap - self.negative_overlap = negative_overlap - self.use_random = use_random - - def __call__(self, cls_logits, bbox_pred, anchor_box, gt_boxes, is_crowd, - im_info): - anchor_box = anchor_box.numpy() - gt_boxes = gt_boxes.numpy() - is_crowd = is_crowd.numpy() - im_info = im_info.numpy() - loc_indexes, score_indexes, tgt_labels, tgt_bboxes, bbox_inside_weights = generate_rpn_anchor_target( - anchor_box, gt_boxes, is_crowd, im_info, self.straddle_thresh, - self.batch_size_per_im, self.positive_overlap, - self.negative_overlap, self.fg_fraction, self.use_random) - - loc_indexes = to_variable(loc_indexes) - score_indexes = to_variable(score_indexes) - tgt_labels = to_variable(tgt_labels) - tgt_bboxes = to_variable(tgt_bboxes) - bbox_inside_weights = to_variable(bbox_inside_weights) - - loc_indexes.stop_gradient = True - score_indexes.stop_gradient = True - tgt_labels.stop_gradient = True - - 
cls_logits = fluid.layers.reshape(x=cls_logits, shape=(-1, )) - bbox_pred = fluid.layers.reshape(x=bbox_pred, shape=(-1, 4)) - pred_cls_logits = fluid.layers.gather(cls_logits, score_indexes) - pred_bbox_pred = fluid.layers.gather(bbox_pred, loc_indexes) - - return pred_cls_logits, pred_bbox_pred, tgt_labels, tgt_bboxes, bbox_inside_weights - - -@register -@serializable -class AnchorGeneratorYOLO(object): - def __init__(self, - anchors=[ - 10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, - 156, 198, 373, 326 - ], - anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]]): - super(AnchorGeneratorYOLO, self).__init__() - self.anchors = anchors - self.anchor_masks = anchor_masks - - def __call__(self): - anchor_num = len(self.anchors) - mask_anchors = [] - for i in range(len(self.anchor_masks)): - mask_anchor = [] - for m in self.anchor_masks[i]: - assert m < anchor_num, "anchor mask index overflow" - mask_anchor.extend(self.anchors[2 * m:2 * m + 2]) - mask_anchors.append(mask_anchor) - return self.anchors, self.anchor_masks, mask_anchors - - -@register -@serializable -class AnchorTargetGeneratorYOLO(object): - def __init__(self, - ignore_thresh=0.7, - downsample_ratio=32, - label_smooth=True): - super(AnchorTargetGeneratorYOLO, self).__init__() - self.ignore_thresh = ignore_thresh - self.downsample_ratio = downsample_ratio - self.label_smooth = label_smooth - - def __call__(self, ): - # TODO: split yolov3_loss into here - outs = { - 'ignore_thresh': self.ignore_thresh, - 'downsample_ratio': self.downsample_ratio, - 'label_smooth': self.label_smooth - } - return outs - - -@register -@serializable -class ProposalGenerator(object): - __append_doc__ = True - - def __init__(self, - train_pre_nms_top_n=12000, - train_post_nms_top_n=2000, - infer_pre_nms_top_n=6000, - infer_post_nms_top_n=1000, - nms_thresh=.5, - min_size=.1, - eta=1.): - super(ProposalGenerator, self).__init__() - self.train_pre_nms_top_n = train_pre_nms_top_n - self.train_post_nms_top_n = train_post_nms_top_n - self.infer_pre_nms_top_n = infer_pre_nms_top_n - self.infer_post_nms_top_n = infer_post_nms_top_n - self.nms_thresh = nms_thresh - self.min_size = min_size - self.eta = eta - - def __call__(self, - scores, - bbox_deltas, - anchors, - variances, - im_info, - mode='train'): - pre_nms_top_n = self.train_pre_nms_top_n if mode == 'train' else self.infer_pre_nms_top_n - post_nms_top_n = self.train_post_nms_top_n if mode == 'train' else self.infer_post_nms_top_n - rpn_rois, rpn_rois_prob, rpn_rois_num = fluid.layers.generate_proposals( - scores, - bbox_deltas, - im_info, - anchors, - variances, - pre_nms_top_n=pre_nms_top_n, - post_nms_top_n=post_nms_top_n, - nms_thresh=self.nms_thresh, - min_size=self.min_size, - eta=self.eta, - return_rois_num=True) - return rpn_rois, rpn_rois_prob, rpn_rois_num, post_nms_top_n - - -@register -@serializable -class ProposalTargetGenerator(object): - __shared__ = ['num_classes'] - - def __init__(self, - batch_size_per_im=512, - fg_fraction=.25, - fg_thresh=[.5, ], - bg_thresh_hi=[.5, ], - bg_thresh_lo=[0., ], - bbox_reg_weights=[[0.1, 0.1, 0.2, 0.2]], - num_classes=81, - use_random=True, - is_cls_agnostic=False, - is_cascade_rcnn=False): - super(ProposalTargetGenerator, self).__init__() - self.batch_size_per_im = batch_size_per_im - self.fg_fraction = fg_fraction - self.fg_thresh = fg_thresh - self.bg_thresh_hi = bg_thresh_hi - self.bg_thresh_lo = bg_thresh_lo - self.bbox_reg_weights = bbox_reg_weights - self.num_classes = num_classes - self.use_random = use_random - self.is_cls_agnostic = 
is_cls_agnostic - self.is_cascade_rcnn = is_cascade_rcnn - - def __call__(self, - rpn_rois, - rpn_rois_num, - gt_classes, - is_crowd, - gt_boxes, - im_info, - stage=0): - rpn_rois = rpn_rois.numpy() - rpn_rois_num = rpn_rois_num.numpy() - gt_classes = gt_classes.numpy() - gt_boxes = gt_boxes.numpy() - is_crowd = is_crowd.numpy() - im_info = im_info.numpy() - outs = generate_proposal_target( - rpn_rois, rpn_rois_num, gt_classes, is_crowd, gt_boxes, im_info, - self.batch_size_per_im, self.fg_fraction, self.fg_thresh[stage], - self.bg_thresh_hi[stage], self.bg_thresh_lo[stage], - self.bbox_reg_weights[stage], self.num_classes, self.use_random, - self.is_cls_agnostic, self.is_cascade_rcnn) - outs = [to_variable(v) for v in outs] - for v in outs: - v.stop_gradient = True - return outs - - -@register -@serializable -class MaskTargetGenerator(object): - __shared__ = ['num_classes', 'mask_resolution'] - - def __init__(self, num_classes=81, mask_resolution=14): - super(MaskTargetGenerator, self).__init__() - self.num_classes = num_classes - self.mask_resolution = mask_resolution - - def __call__(self, im_info, gt_classes, is_crowd, gt_segms, rois, rois_num, - labels_int32): - im_info = im_info.numpy() - gt_classes = gt_classes.numpy() - is_crowd = is_crowd.numpy() - gt_segms = gt_segms.numpy() - rois = rois.numpy() - rois_num = rois_num.numpy() - labels_int32 = labels_int32.numpy() - outs = generate_mask_target(im_info, gt_classes, is_crowd, gt_segms, - rois, rois_num, labels_int32, - self.num_classes, self.mask_resolution) - - outs = [to_variable(v) for v in outs] - for v in outs: - v.stop_gradient = True - return outs - - -@register -class RoIExtractor(object): - def __init__(self, - resolution=14, - sampling_ratio=0, - canconical_level=4, - canonical_size=224, - start_level=0, - end_level=3): - super(RoIExtractor, self).__init__() - self.resolution = resolution - self.sampling_ratio = sampling_ratio - self.canconical_level = canconical_level - self.canonical_size = canonical_size - self.start_level = start_level - self.end_level = end_level - - def __call__(self, feats, rois, spatial_scale): - roi, rois_num = rois - cur_l = 0 - if self.start_level == self.end_level: - rois_feat = fluid.layers.roi_align( - feats[self.start_level], - roi, - self.resolution, - self.resolution, - spatial_scale, - rois_num=rois_num) - return rois_feat - offset = 2 - k_min = self.start_level + offset - k_max = self.end_level + offset - rois_dist, restore_index, rois_num_dist = fluid.layers.distribute_fpn_proposals( - roi, - k_min, - k_max, - self.canconical_level, - self.canonical_size, - rois_num=rois_num) - - rois_feat_list = [] - for lvl in range(self.start_level, self.end_level + 1): - roi_feat = fluid.layers.roi_align( - feats[lvl], - rois_dist[lvl], - self.resolution, - self.resolution, - spatial_scale[lvl], - sampling_ratio=self.sampling_ratio, - rois_num=rois_num_dist[lvl]) - rois_feat_list.append(roi_feat) - rois_feat_shuffle = fluid.layers.concat(rois_feat_list) - rois_feat = fluid.layers.gather(rois_feat_shuffle, restore_index) - - return rois_feat - - -@register -@serializable -class DecodeClipNms(object): - __shared__ = ['num_classes'] - - def __init__( - self, - num_classes=81, - keep_top_k=100, - score_threshold=0.05, - nms_threshold=0.5, ): - super(DecodeClipNms, self).__init__() - self.num_classes = num_classes - self.keep_top_k = keep_top_k - self.score_threshold = score_threshold - self.nms_threshold = nms_threshold - - def __call__(self, bboxes, bbox_prob, bbox_delta, im_info): - bboxes_np = 
(i.numpy() for i in bboxes) - # bbox, bbox_num - outs = bbox_post_process(bboxes_np, - bbox_prob.numpy(), - bbox_delta.numpy(), - im_info.numpy(), self.keep_top_k, - self.score_threshold, self.nms_threshold, - self.num_classes) - outs = [to_variable(v) for v in outs] - for v in outs: - v.stop_gradient = True - return outs - - -@register -@serializable -class MultiClassNMS(object): - __op__ = fluid.layers.multiclass_nms - __append_doc__ = True - - def __init__(self, - score_threshold=.05, - nms_top_k=-1, - keep_top_k=100, - nms_threshold=.5, - normalized=False, - nms_eta=1.0, - background_label=0): - super(MultiClassNMS, self).__init__() - self.score_threshold = score_threshold - self.nms_top_k = nms_top_k - self.keep_top_k = keep_top_k - self.nms_threshold = nms_threshold - self.normalized = normalized - self.nms_eta = nms_eta - self.background_label = background_label - - -@register -@serializable -class YOLOBox(object): - def __init__( - self, - conf_thresh=0.005, - downsample_ratio=32, - clip_bbox=True, ): - self.conf_thresh = conf_thresh - self.downsample_ratio = downsample_ratio - self.clip_bbox = clip_bbox - - def __call__(self, x, img_size, anchors, num_classes, stage=0): - outs = fluid.layers.yolo_box(x, img_size, anchors, num_classes, - self.conf_thresh, self.downsample_ratio // - 2**stage, self.clip_bbox) - return outs - - -@register -@serializable -class AnchorGrid(object): - """Generate anchor grid - +from functools import reduce + +__all__ = [ + #'roi_pool', + #'roi_align', + #'prior_box', + #'anchor_generator', + #'generate_proposals', + #'iou_similarity', + #'box_coder', + #'yolo_box', + #'multiclass_nms', + #'distribute_fpn_proposals', + 'collect_fpn_proposals', + #'matrix_nms', +] + + +def collect_fpn_proposals(multi_rois, + multi_scores, + min_level, + max_level, + post_nms_top_n, + rois_num_per_level=None, + name=None): + """ + + **This OP only supports LoDTensor as input**. Concat multi-level RoIs + (Region of Interest) and select N RoIs with respect to multi_scores. + This operation performs the following steps: + 1. Choose num_level RoIs and scores as input: num_level = max_level - min_level + 2. Concat multi-level RoIs and scores + 3. Sort scores and select post_nms_top_n scores + 4. Gather RoIs by selected indices from scores + 5. Re-sort RoIs by corresponding batch_id Args: - image_size (int or list): input image size, may be a single integer or - list of [h, w]. Default: 512 - min_level (int): min level of the feature pyramid. Default: 3 - max_level (int): max level of the feature pyramid. Default: 7 - anchor_base_scale: base anchor scale. Default: 4 - num_scales: number of anchor scales. Default: 3 - aspect_ratios: aspect ratios. default: [[1, 1], [1.4, 0.7], [0.7, 1.4]] + multi_rois(list): List of RoIs to collect. Element in list is 2-D + LoDTensor with shape [N, 4] and data type is float32 or float64, + N is the number of RoIs. + multi_scores(list): List of scores of RoIs to collect. Element in list + is 2-D LoDTensor with shape [N, 1] and data type is float32 or + float64, N is the number of RoIs. + min_level(int): The lowest level of FPN layer to collect + max_level(int): The highest level of FPN layer to collect + post_nms_top_n(int): The number of selected RoIs + rois_num_per_level(list, optional): The List of RoIs' numbers. + Each element is 1-D Tensor which contains the RoIs' number of each + image on each level and the shape is [B] and data type is + int32, B is the number of images. 
If it is not None then return + a 1-D Tensor contains the output RoIs' number of each image and + the shape is [B]. Default: None + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + Returns: + Variable: + fpn_rois(Variable): 2-D LoDTensor with shape [N, 4] and data type is + float32 or float64. Selected RoIs. + rois_num(Tensor): 1-D Tensor contains the RoIs's number of each + image. The shape is [B] and data type is int32. B is the number of + images. + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import paddle + paddle.enable_static() + multi_rois = [] + multi_scores = [] + for i in range(4): + multi_rois.append(fluid.data( + name='roi_'+str(i), shape=[None, 4], dtype='float32', lod_level=1)) + for i in range(4): + multi_scores.append(fluid.data( + name='score_'+str(i), shape=[None, 1], dtype='float32', lod_level=1)) + fpn_rois = fluid.layers.collect_fpn_proposals( + multi_rois=multi_rois, + multi_scores=multi_scores, + min_level=2, + max_level=5, + post_nms_top_n=2000) """ - - def __init__(self, - image_size=512, - min_level=3, - max_level=7, - anchor_base_scale=4, - num_scales=3, - aspect_ratios=[[1, 1], [1.4, 0.7], [0.7, 1.4]]): - super(AnchorGrid, self).__init__() - if isinstance(image_size, Integral): - self.image_size = [image_size, image_size] - else: - self.image_size = image_size - for dim in self.image_size: - assert dim % 2 ** max_level == 0, \ - "image size should be multiple of the max level stride" - self.min_level = min_level - self.max_level = max_level - self.anchor_base_scale = anchor_base_scale - self.num_scales = num_scales - self.aspect_ratios = aspect_ratios - - @property - def base_cell(self): - if not hasattr(self, '_base_cell'): - self._base_cell = self.make_cell() - return self._base_cell - - def make_cell(self): - scales = [2**(i / self.num_scales) for i in range(self.num_scales)] - scales = np.array(scales) - ratios = np.array(self.aspect_ratios) - ws = np.outer(scales, ratios[:, 0]).reshape(-1, 1) - hs = np.outer(scales, ratios[:, 1]).reshape(-1, 1) - anchors = np.hstack((-0.5 * ws, -0.5 * hs, 0.5 * ws, 0.5 * hs)) - return anchors - - def make_grid(self, stride): - cell = self.base_cell * stride * self.anchor_base_scale - x_steps = np.arange(stride // 2, self.image_size[1], stride) - y_steps = np.arange(stride // 2, self.image_size[0], stride) - offset_x, offset_y = np.meshgrid(x_steps, y_steps) - offset_x = offset_x.flatten() - offset_y = offset_y.flatten() - offsets = np.stack((offset_x, offset_y, offset_x, offset_y), axis=-1) - offsets = offsets[:, np.newaxis, :] - return (cell + offsets).reshape(-1, 4) - - def generate(self): - return [ - self.make_grid(2**l) - for l in range(self.min_level, self.max_level + 1) - ] - - def __call__(self): - if not hasattr(self, '_anchor_vars'): - anchor_vars = [] - helper = LayerHelper('anchor_grid') - for idx, l in enumerate(range(self.min_level, self.max_level + 1)): - stride = 2**l - anchors = self.make_grid(stride) - var = helper.create_parameter( - attr=ParamAttr(name='anchors_{}'.format(idx)), - shape=anchors.shape, - dtype='float32', - stop_gradient=True, - default_initializer=NumpyArrayInitializer(anchors)) - anchor_vars.append(var) - var.persistable = True - self._anchor_vars = anchor_vars - - return self._anchor_vars + check_type(multi_rois, 'multi_rois', list, 'collect_fpn_proposals') + check_type(multi_scores, 'multi_scores', list, 'collect_fpn_proposals') + num_lvl = max_level - min_level + 1 
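+    # keep only the RoIs/scores for the num_lvl levels in [min_level, max_level]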
+ input_rois = multi_rois[:num_lvl] + input_scores = multi_scores[:num_lvl] + + if in_dygraph_mode(): + assert rois_num_per_level is not None, "rois_num_per_level should not be None in dygraph mode." + attrs = ('post_nms_topN', post_nms_top_n) + output_rois, rois_num = core.ops.collect_fpn_proposals( + input_rois, input_scores, rois_num_per_level, *attrs) + + helper = LayerHelper('collect_fpn_proposals', **locals()) + dtype = helper.input_dtype('multi_rois') + check_dtype(dtype, 'multi_rois', ['float32', 'float64'], + 'collect_fpn_proposals') + output_rois = helper.create_variable_for_type_inference(dtype) + output_rois.stop_gradient = True + + inputs = { + 'MultiLevelRois': input_rois, + 'MultiLevelScores': input_scores, + } + outputs = {'FpnRois': output_rois} + if rois_num_per_level is not None: + inputs['MultiLevelRoIsNum'] = rois_num_per_level + rois_num = helper.create_variable_for_type_inference(dtype='int32') + rois_num.stop_gradient = True + outputs['RoisNum'] = rois_num + helper.append_op( + type='collect_fpn_proposals', + inputs=inputs, + outputs=outputs, + attrs={'post_nms_topN': post_nms_top_n}) + if rois_num_per_level is not None: + return output_rois, rois_num + return output_rois diff --git a/ppdet/modeling/tests/__init__.py b/ppdet/modeling/tests/__init__.py new file mode 100644 index 000000000..847ddc47a --- /dev/null +++ b/ppdet/modeling/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/ppdet/modeling/tests/test_base.py b/ppdet/modeling/tests/test_base.py new file mode 100644 index 000000000..28fce2b41 --- /dev/null +++ b/ppdet/modeling/tests/test_base.py @@ -0,0 +1,73 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function +import unittest + +import contextlib +import numpy as np + +import paddle +import paddle.fluid as fluid +from paddle.fluid.framework import Program +from paddle.fluid import core + + +class LayerTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.seed = 111 + + @classmethod + def tearDownClass(cls): + pass + + def _get_place(self, force_to_use_cpu=False): + # this option for ops that only have cpu kernel + if force_to_use_cpu: + return core.CPUPlace() + else: + if core.is_compiled_with_cuda(): + return core.CUDAPlace(0) + return core.CPUPlace() + + @contextlib.contextmanager + def static_graph(self): + scope = fluid.core.Scope() + program = Program() + with fluid.scope_guard(scope): + with fluid.program_guard(program): + paddle.manual_seed(self.seed) + paddle.framework.random._manual_program_seed(self.seed) + yield + + def get_static_graph_result(self, + feed, + fetch_list, + with_lod=False, + force_to_use_cpu=False): + exe = fluid.Executor(self._get_place(force_to_use_cpu)) + exe.run(fluid.default_startup_program()) + return exe.run(fluid.default_main_program(), + feed=feed, + fetch_list=fetch_list, + return_numpy=(not with_lod)) + + @contextlib.contextmanager + def dynamic_graph(self, force_to_use_cpu=False): + with fluid.dygraph.guard( + self._get_place(force_to_use_cpu=force_to_use_cpu)): + paddle.manual_seed(self.seed) + paddle.framework.random._manual_program_seed(self.seed) + yield diff --git a/ppdet/modeling/tests/test_ops.py b/ppdet/modeling/tests/test_ops.py new file mode 100644 index 000000000..555ac3243 --- /dev/null +++ b/ppdet/modeling/tests/test_ops.py @@ -0,0 +1,158 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function +import os, sys +# add python path of PadleDetection to sys.path +parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 4))) +if parent_path not in sys.path: + sys.path.append(parent_path) + +import unittest +import numpy as np + +import paddle +import paddle.fluid as fluid +from paddle.fluid.framework import Program, program_guard +from paddle.fluid.dygraph import base + +import ppdet.modeling.ops as ops +from ppdet.modeling.tests.test_base import LayerTest + + +class TestCollectFpnProposals(LayerTest): + def test_collect_fpn_proposals(self): + multi_bboxes_np = [] + multi_scores_np = [] + rois_num_per_level_np = [] + for i in range(4): + bboxes_np = np.random.rand(5, 4).astype('float32') + scores_np = np.random.rand(5, 1).astype('float32') + rois_num = np.array([2, 3]).astype('int32') + multi_bboxes_np.append(bboxes_np) + multi_scores_np.append(scores_np) + rois_num_per_level_np.append(rois_num) + + paddle.enable_static() + with self.static_graph(): + multi_bboxes = [] + multi_scores = [] + rois_num_per_level = [] + for i in range(4): + bboxes = paddle.static.data( + name='rois' + str(i), + shape=[5, 4], + dtype='float32', + lod_level=1) + scores = paddle.static.data( + name='scores' + str(i), + shape=[5, 1], + dtype='float32', + lod_level=1) + rois_num = paddle.static.data( + name='rois_num' + str(i), shape=[None], dtype='int32') + + multi_bboxes.append(bboxes) + multi_scores.append(scores) + rois_num_per_level.append(rois_num) + + fpn_rois, rois_num = ops.collect_fpn_proposals( + multi_bboxes, + multi_scores, + 2, + 5, + 10, + rois_num_per_level=rois_num_per_level) + feed = {} + for i in range(4): + feed['rois' + str(i)] = multi_bboxes_np[i] + feed['scores' + str(i)] = multi_scores_np[i] + feed['rois_num' + str(i)] = rois_num_per_level_np[i] + fpn_rois_stat, rois_num_stat = self.get_static_graph_result( + feed=feed, fetch_list=[fpn_rois, rois_num], with_lod=True) + fpn_rois_stat = np.array(fpn_rois_stat) + rois_num_stat = np.array(rois_num_stat) + + paddle.disable_static() + with self.dynamic_graph(): + multi_bboxes_dy = [] + multi_scores_dy = [] + rois_num_per_level_dy = [] + for i in range(4): + bboxes_dy = base.to_variable(multi_bboxes_np[i]) + scores_dy = base.to_variable(multi_scores_np[i]) + rois_num_dy = base.to_variable(rois_num_per_level_np[i]) + multi_bboxes_dy.append(bboxes_dy) + multi_scores_dy.append(scores_dy) + rois_num_per_level_dy.append(rois_num_dy) + fpn_rois_dy, rois_num_dy = ops.collect_fpn_proposals( + multi_bboxes_dy, + multi_scores_dy, + 2, + 5, + 10, + rois_num_per_level=rois_num_per_level_dy) + fpn_rois_dy = fpn_rois_dy.numpy() + rois_num_dy = rois_num_dy.numpy() + + self.assertTrue(np.array_equal(fpn_rois_stat, fpn_rois_dy)) + self.assertTrue(np.array_equal(rois_num_stat, rois_num_dy)) + + def test_collect_fpn_proposals_error(self): + def generate_input(bbox_type, score_type, name): + multi_bboxes = [] + multi_scores = [] + for i in range(4): + bboxes = paddle.static.data( + name='rois' + name + str(i), + shape=[10, 4], + dtype=bbox_type, + lod_level=1) + scores = paddle.static.data( + name='scores' + name + str(i), + shape=[10, 1], + dtype=score_type, + lod_level=1) + multi_bboxes.append(bboxes) + multi_scores.append(scores) + return multi_bboxes, multi_scores + + paddle.enable_static() + program = Program() + with program_guard(program): + bbox1 = paddle.static.data( + name='rois', shape=[5, 10, 4], dtype='float32', lod_level=1) + score1 = paddle.static.data( + name='scores', shape=[5, 10, 1], 
dtype='float32', lod_level=1) + bbox2, score2 = generate_input('int32', 'float32', '2') + self.assertRaises( + TypeError, + ops.collect_fpn_proposals, + multi_rois=bbox1, + multi_scores=score1, + min_level=2, + max_level=5, + post_nms_top_n=2000) + self.assertRaises( + TypeError, + ops.collect_fpn_proposals, + multi_rois=bbox2, + multi_scores=score2, + min_level=2, + max_level=5, + post_nms_top_n=2000) + + +if __name__ == '__main__': + unittest.main() diff --git a/tools/train.py b/tools/train.py index 178df7958..1887f8e69 100755 --- a/tools/train.py +++ b/tools/train.py @@ -121,7 +121,6 @@ def run(FLAGS, cfg): strategy = paddle.distributed.init_parallel_env() model = paddle.DataParallel(model, strategy) - logger.info("success!") # Data Reader start_iter = 0 if cfg.use_gpu: -- GitLab
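
For reference, a minimal dygraph usage sketch of the migrated ops.collect_fpn_proposals, modelled on the unit test above; the shapes, random values and per-image RoI counts (2 + 3) are illustrative only, and it assumes the dygraph branch returns its two outputs as exercised by the test:

import numpy as np
import paddle
from paddle.fluid.dygraph import base
import ppdet.modeling.ops as ops

paddle.disable_static()

# four FPN levels, 5 RoIs per level, batch of 2 images (2 + 3 RoIs per level)
multi_rois = [base.to_variable(np.random.rand(5, 4).astype('float32')) for _ in range(4)]
multi_scores = [base.to_variable(np.random.rand(5, 1).astype('float32')) for _ in range(4)]
rois_num_per_level = [base.to_variable(np.array([2, 3], dtype='int32')) for _ in range(4)]

# collect at most the top-10 RoIs across levels 2..5; rois_num holds the
# resulting per-image RoI counts
fpn_rois, rois_num = ops.collect_fpn_proposals(
    multi_rois,
    multi_scores,
    min_level=2,
    max_level=5,
    post_nms_top_n=10,
    rois_num_per_level=rois_num_per_level)
print(fpn_rois.shape, rois_num.numpy())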