Unverified commit 6ac9743c, authored by: W wangguanzhong, committed by: GitHub

Add detection api for Paddle 2.0 (#1575)

* migrate detection ops from Paddle to PaddleDetection

* add unittest for detection ops

* update api & unittest

* add copyright & api name to migrate
Parent 0e192063
...
@@ -38,7 +38,7 @@ import cv2
 from PIL import Image, ImageEnhance, ImageDraw
 from ppdet.core.workspace import serializable
-from ppdet.modeling.ops import AnchorGrid
+from ppdet.modeling.layers import AnchorGrid
 from .op_helper import (satisfy_sample_constraint, filter_and_process,
                         generate_sample_bbox, clip_bbox, data_anchor_sampling,
...
...
@@ -4,6 +4,7 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 from ppdet.core.workspace import register
+from . import ops

 @register
@@ -218,7 +219,7 @@ class Proposal(object):
         start_level = 2
         end_level = start_level + len(rpn_head_out)
-        rois_collect, rois_num_collect = fluid.layers.collect_fpn_proposals(
+        rois_collect, rois_num_collect = ops.collect_fpn_proposals(
             rpn_rois_list,
             rpn_prob_list,
             start_level,
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from numbers import Integral
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
# required by AnchorGrid.__call__ below
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import NumpyArrayInitializer
from ppdet.core.workspace import register, serializable
from ppdet.py_op.target import generate_rpn_anchor_target, generate_proposal_target, generate_mask_target
from ppdet.py_op.post_process import bbox_post_process
@register
@serializable
class AnchorGeneratorRPN(object):
def __init__(self,
anchor_sizes=[32, 64, 128, 256, 512],
aspect_ratios=[0.5, 1.0, 2.0],
stride=[16.0, 16.0],
variance=[1.0, 1.0, 1.0, 1.0],
anchor_start_size=None):
super(AnchorGeneratorRPN, self).__init__()
self.anchor_sizes = anchor_sizes
self.aspect_ratios = aspect_ratios
self.stride = stride
self.variance = variance
self.anchor_start_size = anchor_start_size
def __call__(self, input, level=None):
anchor_sizes = self.anchor_sizes if (
level is None or self.anchor_start_size is None) else (
self.anchor_start_size * 2**level)
stride = self.stride if (
level is None or self.anchor_start_size is None) else (
self.stride[0] * (2.**level), self.stride[1] * (2.**level))
anchor, var = fluid.layers.anchor_generator(
input=input,
anchor_sizes=anchor_sizes,
aspect_ratios=self.aspect_ratios,
stride=stride,
variance=self.variance)
return anchor, var
@register
@serializable
class AnchorTargetGeneratorRPN(object):
def __init__(self,
batch_size_per_im=256,
straddle_thresh=0.,
fg_fraction=0.5,
positive_overlap=0.7,
negative_overlap=0.3,
use_random=True):
super(AnchorTargetGeneratorRPN, self).__init__()
self.batch_size_per_im = batch_size_per_im
self.straddle_thresh = straddle_thresh
self.fg_fraction = fg_fraction
self.positive_overlap = positive_overlap
self.negative_overlap = negative_overlap
self.use_random = use_random
def __call__(self, cls_logits, bbox_pred, anchor_box, gt_boxes, is_crowd,
im_info):
anchor_box = anchor_box.numpy()
gt_boxes = gt_boxes.numpy()
is_crowd = is_crowd.numpy()
im_info = im_info.numpy()
loc_indexes, score_indexes, tgt_labels, tgt_bboxes, bbox_inside_weights = generate_rpn_anchor_target(
anchor_box, gt_boxes, is_crowd, im_info, self.straddle_thresh,
self.batch_size_per_im, self.positive_overlap,
self.negative_overlap, self.fg_fraction, self.use_random)
loc_indexes = to_variable(loc_indexes)
score_indexes = to_variable(score_indexes)
tgt_labels = to_variable(tgt_labels)
tgt_bboxes = to_variable(tgt_bboxes)
bbox_inside_weights = to_variable(bbox_inside_weights)
loc_indexes.stop_gradient = True
score_indexes.stop_gradient = True
tgt_labels.stop_gradient = True
cls_logits = fluid.layers.reshape(x=cls_logits, shape=(-1, ))
bbox_pred = fluid.layers.reshape(x=bbox_pred, shape=(-1, 4))
pred_cls_logits = fluid.layers.gather(cls_logits, score_indexes)
pred_bbox_pred = fluid.layers.gather(bbox_pred, loc_indexes)
return pred_cls_logits, pred_bbox_pred, tgt_labels, tgt_bboxes, bbox_inside_weights
@register
@serializable
class AnchorGeneratorYOLO(object):
def __init__(self,
anchors=[
10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90,
156, 198, 373, 326
],
anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]]):
super(AnchorGeneratorYOLO, self).__init__()
self.anchors = anchors
self.anchor_masks = anchor_masks
def __call__(self):
anchor_num = len(self.anchors)
mask_anchors = []
for i in range(len(self.anchor_masks)):
mask_anchor = []
for m in self.anchor_masks[i]:
assert m < anchor_num, "anchor mask index overflow"
mask_anchor.extend(self.anchors[2 * m:2 * m + 2])
mask_anchors.append(mask_anchor)
return self.anchors, self.anchor_masks, mask_anchors
@register
@serializable
class AnchorTargetGeneratorYOLO(object):
def __init__(self,
ignore_thresh=0.7,
downsample_ratio=32,
label_smooth=True):
super(AnchorTargetGeneratorYOLO, self).__init__()
self.ignore_thresh = ignore_thresh
self.downsample_ratio = downsample_ratio
self.label_smooth = label_smooth
def __call__(self, ):
# TODO: split yolov3_loss into here
outs = {
'ignore_thresh': self.ignore_thresh,
'downsample_ratio': self.downsample_ratio,
'label_smooth': self.label_smooth
}
return outs
@register
@serializable
class ProposalGenerator(object):
__append_doc__ = True
def __init__(self,
train_pre_nms_top_n=12000,
train_post_nms_top_n=2000,
infer_pre_nms_top_n=6000,
infer_post_nms_top_n=1000,
nms_thresh=.5,
min_size=.1,
eta=1.):
super(ProposalGenerator, self).__init__()
self.train_pre_nms_top_n = train_pre_nms_top_n
self.train_post_nms_top_n = train_post_nms_top_n
self.infer_pre_nms_top_n = infer_pre_nms_top_n
self.infer_post_nms_top_n = infer_post_nms_top_n
self.nms_thresh = nms_thresh
self.min_size = min_size
self.eta = eta
def __call__(self,
scores,
bbox_deltas,
anchors,
variances,
im_info,
mode='train'):
pre_nms_top_n = self.train_pre_nms_top_n if mode == 'train' else self.infer_pre_nms_top_n
post_nms_top_n = self.train_post_nms_top_n if mode == 'train' else self.infer_post_nms_top_n
rpn_rois, rpn_rois_prob, rpn_rois_num = fluid.layers.generate_proposals(
scores,
bbox_deltas,
im_info,
anchors,
variances,
pre_nms_top_n=pre_nms_top_n,
post_nms_top_n=post_nms_top_n,
nms_thresh=self.nms_thresh,
min_size=self.min_size,
eta=self.eta,
return_rois_num=True)
return rpn_rois, rpn_rois_prob, rpn_rois_num, post_nms_top_n
@register
@serializable
class ProposalTargetGenerator(object):
__shared__ = ['num_classes']
def __init__(self,
batch_size_per_im=512,
fg_fraction=.25,
fg_thresh=[.5, ],
bg_thresh_hi=[.5, ],
bg_thresh_lo=[0., ],
bbox_reg_weights=[[0.1, 0.1, 0.2, 0.2]],
num_classes=81,
use_random=True,
is_cls_agnostic=False,
is_cascade_rcnn=False):
super(ProposalTargetGenerator, self).__init__()
self.batch_size_per_im = batch_size_per_im
self.fg_fraction = fg_fraction
self.fg_thresh = fg_thresh
self.bg_thresh_hi = bg_thresh_hi
self.bg_thresh_lo = bg_thresh_lo
self.bbox_reg_weights = bbox_reg_weights
self.num_classes = num_classes
self.use_random = use_random
self.is_cls_agnostic = is_cls_agnostic
self.is_cascade_rcnn = is_cascade_rcnn
def __call__(self,
rpn_rois,
rpn_rois_num,
gt_classes,
is_crowd,
gt_boxes,
im_info,
stage=0):
rpn_rois = rpn_rois.numpy()
rpn_rois_num = rpn_rois_num.numpy()
gt_classes = gt_classes.numpy()
gt_boxes = gt_boxes.numpy()
is_crowd = is_crowd.numpy()
im_info = im_info.numpy()
outs = generate_proposal_target(
rpn_rois, rpn_rois_num, gt_classes, is_crowd, gt_boxes, im_info,
self.batch_size_per_im, self.fg_fraction, self.fg_thresh[stage],
self.bg_thresh_hi[stage], self.bg_thresh_lo[stage],
self.bbox_reg_weights[stage], self.num_classes, self.use_random,
self.is_cls_agnostic, self.is_cascade_rcnn)
outs = [to_variable(v) for v in outs]
for v in outs:
v.stop_gradient = True
return outs
@register
@serializable
class MaskTargetGenerator(object):
__shared__ = ['num_classes', 'mask_resolution']
def __init__(self, num_classes=81, mask_resolution=14):
super(MaskTargetGenerator, self).__init__()
self.num_classes = num_classes
self.mask_resolution = mask_resolution
def __call__(self, im_info, gt_classes, is_crowd, gt_segms, rois, rois_num,
labels_int32):
im_info = im_info.numpy()
gt_classes = gt_classes.numpy()
is_crowd = is_crowd.numpy()
gt_segms = gt_segms.numpy()
rois = rois.numpy()
rois_num = rois_num.numpy()
labels_int32 = labels_int32.numpy()
outs = generate_mask_target(im_info, gt_classes, is_crowd, gt_segms,
rois, rois_num, labels_int32,
self.num_classes, self.mask_resolution)
outs = [to_variable(v) for v in outs]
for v in outs:
v.stop_gradient = True
return outs
@register
class RoIExtractor(object):
def __init__(self,
resolution=14,
sampling_ratio=0,
canconical_level=4,
canonical_size=224,
start_level=0,
end_level=3):
super(RoIExtractor, self).__init__()
self.resolution = resolution
self.sampling_ratio = sampling_ratio
self.canconical_level = canconical_level
self.canonical_size = canonical_size
self.start_level = start_level
self.end_level = end_level
def __call__(self, feats, rois, spatial_scale):
roi, rois_num = rois
cur_l = 0
if self.start_level == self.end_level:
rois_feat = fluid.layers.roi_align(
feats[self.start_level],
roi,
self.resolution,
self.resolution,
spatial_scale,
rois_num=rois_num)
return rois_feat
offset = 2
k_min = self.start_level + offset
k_max = self.end_level + offset
rois_dist, restore_index, rois_num_dist = fluid.layers.distribute_fpn_proposals(
roi,
k_min,
k_max,
self.canconical_level,
self.canonical_size,
rois_num=rois_num)
rois_feat_list = []
for lvl in range(self.start_level, self.end_level + 1):
roi_feat = fluid.layers.roi_align(
feats[lvl],
rois_dist[lvl],
self.resolution,
self.resolution,
spatial_scale[lvl],
sampling_ratio=self.sampling_ratio,
rois_num=rois_num_dist[lvl])
rois_feat_list.append(roi_feat)
rois_feat_shuffle = fluid.layers.concat(rois_feat_list)
rois_feat = fluid.layers.gather(rois_feat_shuffle, restore_index)
return rois_feat
@register
@serializable
class DecodeClipNms(object):
__shared__ = ['num_classes']
def __init__(
self,
num_classes=81,
keep_top_k=100,
score_threshold=0.05,
nms_threshold=0.5, ):
super(DecodeClipNms, self).__init__()
self.num_classes = num_classes
self.keep_top_k = keep_top_k
self.score_threshold = score_threshold
self.nms_threshold = nms_threshold
def __call__(self, bboxes, bbox_prob, bbox_delta, im_info):
bboxes_np = (i.numpy() for i in bboxes)
# bbox, bbox_num
outs = bbox_post_process(bboxes_np,
bbox_prob.numpy(),
bbox_delta.numpy(),
im_info.numpy(), self.keep_top_k,
self.score_threshold, self.nms_threshold,
self.num_classes)
outs = [to_variable(v) for v in outs]
for v in outs:
v.stop_gradient = True
return outs
@register
@serializable
class MultiClassNMS(object):
__op__ = fluid.layers.multiclass_nms
__append_doc__ = True
def __init__(self,
score_threshold=.05,
nms_top_k=-1,
keep_top_k=100,
nms_threshold=.5,
normalized=False,
nms_eta=1.0,
background_label=0):
super(MultiClassNMS, self).__init__()
self.score_threshold = score_threshold
self.nms_top_k = nms_top_k
self.keep_top_k = keep_top_k
self.nms_threshold = nms_threshold
self.normalized = normalized
self.nms_eta = nms_eta
self.background_label = background_label
@register
@serializable
class YOLOBox(object):
def __init__(
self,
conf_thresh=0.005,
downsample_ratio=32,
clip_bbox=True, ):
self.conf_thresh = conf_thresh
self.downsample_ratio = downsample_ratio
self.clip_bbox = clip_bbox
def __call__(self, x, img_size, anchors, num_classes, stage=0):
outs = fluid.layers.yolo_box(x, img_size, anchors, num_classes,
self.conf_thresh, self.downsample_ratio //
2**stage, self.clip_bbox)
return outs
@register
@serializable
class AnchorGrid(object):
"""Generate anchor grid
Args:
image_size (int or list): input image size, may be a single integer or
list of [h, w]. Default: 512
min_level (int): min level of the feature pyramid. Default: 3
max_level (int): max level of the feature pyramid. Default: 7
anchor_base_scale: base anchor scale. Default: 4
num_scales: number of anchor scales. Default: 3
aspect_ratios: aspect ratios. default: [[1, 1], [1.4, 0.7], [0.7, 1.4]]
"""
def __init__(self,
image_size=512,
min_level=3,
max_level=7,
anchor_base_scale=4,
num_scales=3,
aspect_ratios=[[1, 1], [1.4, 0.7], [0.7, 1.4]]):
super(AnchorGrid, self).__init__()
if isinstance(image_size, Integral):
self.image_size = [image_size, image_size]
else:
self.image_size = image_size
for dim in self.image_size:
assert dim % 2 ** max_level == 0, \
"image size should be multiple of the max level stride"
self.min_level = min_level
self.max_level = max_level
self.anchor_base_scale = anchor_base_scale
self.num_scales = num_scales
self.aspect_ratios = aspect_ratios
@property
def base_cell(self):
if not hasattr(self, '_base_cell'):
self._base_cell = self.make_cell()
return self._base_cell
def make_cell(self):
scales = [2**(i / self.num_scales) for i in range(self.num_scales)]
scales = np.array(scales)
ratios = np.array(self.aspect_ratios)
ws = np.outer(scales, ratios[:, 0]).reshape(-1, 1)
hs = np.outer(scales, ratios[:, 1]).reshape(-1, 1)
anchors = np.hstack((-0.5 * ws, -0.5 * hs, 0.5 * ws, 0.5 * hs))
return anchors
def make_grid(self, stride):
cell = self.base_cell * stride * self.anchor_base_scale
x_steps = np.arange(stride // 2, self.image_size[1], stride)
y_steps = np.arange(stride // 2, self.image_size[0], stride)
offset_x, offset_y = np.meshgrid(x_steps, y_steps)
offset_x = offset_x.flatten()
offset_y = offset_y.flatten()
offsets = np.stack((offset_x, offset_y, offset_x, offset_y), axis=-1)
offsets = offsets[:, np.newaxis, :]
return (cell + offsets).reshape(-1, 4)
def generate(self):
return [
self.make_grid(2**l)
for l in range(self.min_level, self.max_level + 1)
]
def __call__(self):
if not hasattr(self, '_anchor_vars'):
anchor_vars = []
helper = LayerHelper('anchor_grid')
for idx, l in enumerate(range(self.min_level, self.max_level + 1)):
stride = 2**l
anchors = self.make_grid(stride)
var = helper.create_parameter(
attr=ParamAttr(name='anchors_{}'.format(idx)),
shape=anchors.shape,
dtype='float32',
stop_gradient=True,
default_initializer=NumpyArrayInitializer(anchors))
anchor_vars.append(var)
var.persistable = True
self._anchor_vars = anchor_vars
return self._anchor_vars
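To make the anchor layout above concrete, here is a small standalone NumPy sketch of what AnchorGrid.make_cell and make_grid compute for a single pyramid level. The stride, image size and scale values below are illustrative choices (taken from the class defaults), not part of the diff:

import numpy as np

# Reproduce AnchorGrid.make_cell() with the defaults:
# anchor_base_scale=4, num_scales=3, aspect_ratios=[[1, 1], [1.4, 0.7], [0.7, 1.4]]
num_scales = 3
aspect_ratios = np.array([[1, 1], [1.4, 0.7], [0.7, 1.4]])
scales = np.array([2**(i / num_scales) for i in range(num_scales)])

ws = np.outer(scales, aspect_ratios[:, 0]).reshape(-1, 1)
hs = np.outer(scales, aspect_ratios[:, 1]).reshape(-1, 1)
base_cell = np.hstack((-0.5 * ws, -0.5 * hs, 0.5 * ws, 0.5 * hs))  # (9, 4) unit anchors

# Reproduce AnchorGrid.make_grid() for one level: stride 8 (level 3) on a 64x64 input.
stride, anchor_base_scale, image_size = 8, 4, (64, 64)
cell = base_cell * stride * anchor_base_scale  # base anchor size = 8 * 4 = 32 px

x_steps = np.arange(stride // 2, image_size[1], stride)
y_steps = np.arange(stride // 2, image_size[0], stride)
offset_x, offset_y = np.meshgrid(x_steps, y_steps)
ox, oy = offset_x.flatten(), offset_y.flatten()
offsets = np.stack((ox, oy, ox, oy), axis=-1)[:, np.newaxis, :]

anchors = (cell + offsets).reshape(-1, 4)
print(anchors.shape)  # (576, 4): 8x8 grid positions, 9 anchors each, in xyxy format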
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid.framework import Variable, in_dygraph_mode
from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph import layers
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype
import math
import six
import numpy as np
from functools import reduce

__all__ = [
    #'roi_pool',
    #'roi_align',
    #'prior_box',
    #'anchor_generator',
    #'generate_proposals',
    #'iou_similarity',
    #'box_coder',
    #'yolo_box',
    #'multiclass_nms',
    #'distribute_fpn_proposals',
    'collect_fpn_proposals',
    #'matrix_nms',
]


def collect_fpn_proposals(multi_rois,
                          multi_scores,
                          min_level,
                          max_level,
                          post_nms_top_n,
                          rois_num_per_level=None,
                          name=None):
    """
    **This OP only supports LoDTensor as input**. Concat multi-level RoIs
    (Region of Interest) and select N RoIs with respect to multi_scores.
    This operation performs the following steps:
    1. Choose num_level RoIs and scores as input: num_level = max_level - min_level
    2. Concat multi-level RoIs and scores
    3. Sort scores and select post_nms_top_n scores
    4. Gather RoIs by selected indices from scores
    5. Re-sort RoIs by corresponding batch_id
    Args:
        multi_rois(list): List of RoIs to collect. Element in list is 2-D
            LoDTensor with shape [N, 4] and data type is float32 or float64,
            N is the number of RoIs.
        multi_scores(list): List of scores of RoIs to collect. Element in list
            is 2-D LoDTensor with shape [N, 1] and data type is float32 or
            float64, N is the number of RoIs.
        min_level(int): The lowest level of FPN layer to collect
        max_level(int): The highest level of FPN layer to collect
        post_nms_top_n(int): The number of selected RoIs
        rois_num_per_level(list, optional): The List of RoIs' numbers.
            Each element is 1-D Tensor which contains the RoIs' number of each
            image on each level and the shape is [B] and data type is
            int32, B is the number of images. If it is not None then return
            a 1-D Tensor contains the output RoIs' number of each image and
            the shape is [B]. Default: None
        name(str, optional): For detailed information, please refer
            to :ref:`api_guide_Name`. Usually name is no need to set and
            None by default.
    Returns:
        Variable:
        fpn_rois(Variable): 2-D LoDTensor with shape [N, 4] and data type is
            float32 or float64. Selected RoIs.
        rois_num(Tensor): 1-D Tensor contains the RoIs's number of each
            image. The shape is [B] and data type is int32. B is the number of
            images.
    Examples:
        .. code-block:: python
            import paddle.fluid as fluid
            import paddle
            paddle.enable_static()
            multi_rois = []
            multi_scores = []
            for i in range(4):
                multi_rois.append(fluid.data(
                    name='roi_'+str(i), shape=[None, 4], dtype='float32', lod_level=1))
            for i in range(4):
                multi_scores.append(fluid.data(
                    name='score_'+str(i), shape=[None, 1], dtype='float32', lod_level=1))
            fpn_rois = fluid.layers.collect_fpn_proposals(
                multi_rois=multi_rois,
                multi_scores=multi_scores,
                min_level=2,
                max_level=5,
                post_nms_top_n=2000)
    """
    check_type(multi_rois, 'multi_rois', list, 'collect_fpn_proposals')
    check_type(multi_scores, 'multi_scores', list, 'collect_fpn_proposals')
    num_lvl = max_level - min_level + 1
    input_rois = multi_rois[:num_lvl]
    input_scores = multi_scores[:num_lvl]

    if in_dygraph_mode():
        assert rois_num_per_level is not None, "rois_num_per_level should not be None in dygraph mode."
        attrs = ('post_nms_topN', post_nms_top_n)
        output_rois, rois_num = core.ops.collect_fpn_proposals(
            input_rois, input_scores, rois_num_per_level, *attrs)
        return output_rois, rois_num

    helper = LayerHelper('collect_fpn_proposals', **locals())
    dtype = helper.input_dtype('multi_rois')
    check_dtype(dtype, 'multi_rois', ['float32', 'float64'],
                'collect_fpn_proposals')
    output_rois = helper.create_variable_for_type_inference(dtype)
    output_rois.stop_gradient = True

    inputs = {
        'MultiLevelRois': input_rois,
        'MultiLevelScores': input_scores,
    }
    outputs = {'FpnRois': output_rois}
    if rois_num_per_level is not None:
        inputs['MultiLevelRoIsNum'] = rois_num_per_level
        rois_num = helper.create_variable_for_type_inference(dtype='int32')
        rois_num.stop_gradient = True
        outputs['RoisNum'] = rois_num
    helper.append_op(
        type='collect_fpn_proposals',
        inputs=inputs,
        outputs=outputs,
        attrs={'post_nms_topN': post_nms_top_n})
    if rois_num_per_level is not None:
        return output_rois, rois_num
    return output_rois
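For reference, a minimal dygraph usage sketch of the op defined above; it mirrors the unit test added later in this commit, and the tensor shapes and per-level RoI counts are illustrative values:

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import base
import ppdet.modeling.ops as ops

with fluid.dygraph.guard():
    # Four FPN levels (min_level=2 .. max_level=5), 5 RoIs per level for a
    # batch of two images (2 + 3 RoIs).
    multi_rois = [base.to_variable(np.random.rand(5, 4).astype('float32'))
                  for _ in range(4)]
    multi_scores = [base.to_variable(np.random.rand(5, 1).astype('float32'))
                    for _ in range(4)]
    rois_num_per_level = [base.to_variable(np.array([2, 3]).astype('int32'))
                          for _ in range(4)]
    # rois_num_per_level is mandatory in dygraph mode; both the collected RoIs
    # and the per-image RoI counts are returned.
    fpn_rois, rois_num = ops.collect_fpn_proposals(
        multi_rois, multi_scores, 2, 5, 10,
        rois_num_per_level=rois_num_per_level)
    print(fpn_rois.shape, rois_num.numpy())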
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import contextlib
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import Program
from paddle.fluid import core
class LayerTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.seed = 111
@classmethod
def tearDownClass(cls):
pass
def _get_place(self, force_to_use_cpu=False):
# this option for ops that only have cpu kernel
if force_to_use_cpu:
return core.CPUPlace()
else:
if core.is_compiled_with_cuda():
return core.CUDAPlace(0)
return core.CPUPlace()
@contextlib.contextmanager
def static_graph(self):
scope = fluid.core.Scope()
program = Program()
with fluid.scope_guard(scope):
with fluid.program_guard(program):
paddle.manual_seed(self.seed)
paddle.framework.random._manual_program_seed(self.seed)
yield
def get_static_graph_result(self,
feed,
fetch_list,
with_lod=False,
force_to_use_cpu=False):
exe = fluid.Executor(self._get_place(force_to_use_cpu))
exe.run(fluid.default_startup_program())
return exe.run(fluid.default_main_program(),
feed=feed,
fetch_list=fetch_list,
return_numpy=(not with_lod))
@contextlib.contextmanager
def dynamic_graph(self, force_to_use_cpu=False):
with fluid.dygraph.guard(
self._get_place(force_to_use_cpu=force_to_use_cpu)):
paddle.manual_seed(self.seed)
paddle.framework.random._manual_program_seed(self.seed)
yield
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os, sys
# add python path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 4)))
if parent_path not in sys.path:
sys.path.append(parent_path)
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import Program, program_guard
from paddle.fluid.dygraph import base
import ppdet.modeling.ops as ops
from ppdet.modeling.tests.test_base import LayerTest
class TestCollectFpnProposals(LayerTest):
def test_collect_fpn_proposals(self):
multi_bboxes_np = []
multi_scores_np = []
rois_num_per_level_np = []
for i in range(4):
bboxes_np = np.random.rand(5, 4).astype('float32')
scores_np = np.random.rand(5, 1).astype('float32')
rois_num = np.array([2, 3]).astype('int32')
multi_bboxes_np.append(bboxes_np)
multi_scores_np.append(scores_np)
rois_num_per_level_np.append(rois_num)
paddle.enable_static()
with self.static_graph():
multi_bboxes = []
multi_scores = []
rois_num_per_level = []
for i in range(4):
bboxes = paddle.static.data(
name='rois' + str(i),
shape=[5, 4],
dtype='float32',
lod_level=1)
scores = paddle.static.data(
name='scores' + str(i),
shape=[5, 1],
dtype='float32',
lod_level=1)
rois_num = paddle.static.data(
name='rois_num' + str(i), shape=[None], dtype='int32')
multi_bboxes.append(bboxes)
multi_scores.append(scores)
rois_num_per_level.append(rois_num)
fpn_rois, rois_num = ops.collect_fpn_proposals(
multi_bboxes,
multi_scores,
2,
5,
10,
rois_num_per_level=rois_num_per_level)
feed = {}
for i in range(4):
feed['rois' + str(i)] = multi_bboxes_np[i]
feed['scores' + str(i)] = multi_scores_np[i]
feed['rois_num' + str(i)] = rois_num_per_level_np[i]
fpn_rois_stat, rois_num_stat = self.get_static_graph_result(
feed=feed, fetch_list=[fpn_rois, rois_num], with_lod=True)
fpn_rois_stat = np.array(fpn_rois_stat)
rois_num_stat = np.array(rois_num_stat)
paddle.disable_static()
with self.dynamic_graph():
multi_bboxes_dy = []
multi_scores_dy = []
rois_num_per_level_dy = []
for i in range(4):
bboxes_dy = base.to_variable(multi_bboxes_np[i])
scores_dy = base.to_variable(multi_scores_np[i])
rois_num_dy = base.to_variable(rois_num_per_level_np[i])
multi_bboxes_dy.append(bboxes_dy)
multi_scores_dy.append(scores_dy)
rois_num_per_level_dy.append(rois_num_dy)
fpn_rois_dy, rois_num_dy = ops.collect_fpn_proposals(
multi_bboxes_dy,
multi_scores_dy,
2,
5,
10,
rois_num_per_level=rois_num_per_level_dy)
fpn_rois_dy = fpn_rois_dy.numpy()
rois_num_dy = rois_num_dy.numpy()
self.assertTrue(np.array_equal(fpn_rois_stat, fpn_rois_dy))
self.assertTrue(np.array_equal(rois_num_stat, rois_num_dy))
def test_collect_fpn_proposals_error(self):
def generate_input(bbox_type, score_type, name):
multi_bboxes = []
multi_scores = []
for i in range(4):
bboxes = paddle.static.data(
name='rois' + name + str(i),
shape=[10, 4],
dtype=bbox_type,
lod_level=1)
scores = paddle.static.data(
name='scores' + name + str(i),
shape=[10, 1],
dtype=score_type,
lod_level=1)
multi_bboxes.append(bboxes)
multi_scores.append(scores)
return multi_bboxes, multi_scores
paddle.enable_static()
program = Program()
with program_guard(program):
bbox1 = paddle.static.data(
name='rois', shape=[5, 10, 4], dtype='float32', lod_level=1)
score1 = paddle.static.data(
name='scores', shape=[5, 10, 1], dtype='float32', lod_level=1)
bbox2, score2 = generate_input('int32', 'float32', '2')
self.assertRaises(
TypeError,
ops.collect_fpn_proposals,
multi_rois=bbox1,
multi_scores=score1,
min_level=2,
max_level=5,
post_nms_top_n=2000)
self.assertRaises(
TypeError,
ops.collect_fpn_proposals,
multi_rois=bbox2,
multi_scores=score2,
min_level=2,
max_level=5,
post_nms_top_n=2000)
if __name__ == '__main__':
unittest.main()
...
@@ -121,7 +121,6 @@ def run(FLAGS, cfg):
         strategy = paddle.distributed.init_parallel_env()
         model = paddle.DataParallel(model, strategy)
-        logger.info("success!")

     # Data Reader
     start_iter = 0
     if cfg.use_gpu:
...