Unverified commit bd7adb05 · authored by Guanghua Yu · committed via GitHub · parent 480c12d6

add pssdet model (#2590)

* add pssdet model
## Practical Server-Side Object Detection Solution
### Introduction
* Object detection in images has drawn wide attention from both academia and industry in recent years. Building on the ResNet50_vd pretrained model trained with the SSLD distillation scheme from [PaddleClas](https://github.com/PaddlePaddle/PaddleClas) (82.39% Top-1 accuracy on the ImageNet-1k validation set) and the rich set of operators in PaddleDetection, PaddlePaddle provides PSS-DET (Practical Server Side Detection), a practical object detection solution for server-side deployment. On the COCO 2017 detection dataset it reaches 41.2% COCO mAP at 61 FPS on a single V100 GPU.
### Model Zoo
| Backbone | Model | Images/GPU | LR schedule | Inference speed (FPS) | Box AP | Mask AP | Download | Config |
| :---------------------- | :-------------: | :-------: | :-----: | :------------: | :----: | :-----: | :-------------: | :-----: |
| ResNet50-vd-FPN-Dcnv2 | Faster | 2 | 3x | 61.425 | 41.2 | - | [model](https://paddledet.bj.bcebos.com/models/faster_rcnn_enhance_3x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/rcnn_enhance/faster_rcnn_enhance_3x_coco.yml) |
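The configuration files behind this entry are reproduced below. Training and evaluation are expected to follow the usual PaddleDetection workflow, e.g. `python tools/train.py -c configs/rcnn_enhance/faster_rcnn_enhance_3x_coco.yml` (the `tools/train.py` entry point is assumed from the standard repository layout and is not part of this patch).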
# Model definition (_base_/faster_rcnn_enhance.yml)
architecture: FasterRCNN
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams

FasterRCNN:
  backbone: ResNet
  neck: FPN
  rpn_head: RPNHead
  bbox_head: BBoxHead
  # post process
  bbox_post_process: BBoxPostProcess

ResNet:
  # index 0 stands for res2
  depth: 50
  norm_type: bn
  variant: d
  freeze_at: 0
  return_idx: [0,1,2,3]
  num_stages: 4
  dcn_v2_stages: [1,2,3]
  lr_mult_list: [0.05, 0.05, 0.1, 0.15]

FPN:
  in_channels: [256, 512, 1024, 2048]
  out_channel: 64

RPNHead:
  anchor_generator:
    aspect_ratios: [0.5, 1.0, 2.0]
    anchor_sizes: [[32], [64], [128], [256], [512]]
    strides: [4, 8, 16, 32, 64]
  rpn_target_assign:
    batch_size_per_im: 256
    fg_fraction: 0.5
    negative_overlap: 0.3
    positive_overlap: 0.7
    use_random: True
  train_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    pre_nms_top_n: 2000
    post_nms_top_n: 2000
    topk_after_collect: True
  test_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    pre_nms_top_n: 500
    post_nms_top_n: 300

BBoxHead:
  head: TwoFCHead
  roi_extractor:
    resolution: 7
    sampling_ratio: 0
    aligned: True
  bbox_assigner: BBoxLibraAssigner
  bbox_loss: DIouLoss

TwoFCHead:
  out_channel: 1024

BBoxLibraAssigner:
  batch_size_per_im: 512
  bg_thresh: 0.5
  fg_thresh: 0.5
  fg_fraction: 0.25
  use_random: True

DIouLoss:
  loss_weight: 10.0
  use_complete_iou_loss: true

BBoxPostProcess:
  decode: RCNNBox
  nms:
    name: MultiClassNMS
    keep_top_k: 100
    score_threshold: 0.05
    nms_threshold: 0.5
# Data reader (_base_/faster_rcnn_enhance_reader.yml)
worker_num: 2
TrainReader:
  sample_transforms:
  - Decode: {}
  - RandomResize: {target_size: [[384,1000], [416,1000], [448,1000], [480,1000], [512,1000], [544,1000], [576,1000], [608,1000], [640,1000], [672,1000]], interp: 2, keep_ratio: True}
  - RandomFlip: {prob: 0.5}
  - AutoAugment: {autoaug_type: v1}
  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
  - Permute: {}
  batch_transforms:
  - PadBatch: {pad_to_stride: 32}
  batch_size: 2
  shuffle: true
  drop_last: true
  collate_batch: false
  use_shared_memory: true

EvalReader:
  sample_transforms:
  - Decode: {}
  - Resize: {interp: 2, target_size: [640, 640], keep_ratio: True}
  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
  - Permute: {}
  batch_transforms:
  - PadBatch: {pad_to_stride: 32}
  batch_size: 1
  shuffle: false
  drop_last: false
  drop_empty: false

TestReader:
  sample_transforms:
  - Decode: {}
  - Resize: {interp: 2, target_size: [640, 640], keep_ratio: True}
  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
  - Permute: {}
  batch_transforms:
  - PadBatch: {pad_to_stride: 32}
  batch_size: 1
  shuffle: false
  drop_last: false
# Optimizer and learning-rate schedule (_base_/optimizer_3x.yml)
epoch: 36

LearningRate:
  base_lr: 0.02
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [24, 33]
  - !LinearWarmup
    start_factor: 0.
    steps: 1000

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0001
    type: L2
# Top-level config (faster_rcnn_enhance_3x_coco.yml)
_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  '_base_/optimizer_3x.yml',
  '_base_/faster_rcnn_enhance.yml',
  '_base_/faster_rcnn_enhance_reader.yml',
]

weights: output/faster_rcnn_enhance_r50_3x_coco/model_final
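The `_BASE_` list composes the component files shown above into a single configuration at load time. A minimal sketch of inspecting the merged result with PaddleDetection's workspace helper (the exact keys of the merged dict are an assumption, not part of this patch):

```python
# Sketch: load the composed PSS-DET config and read a few merged values.
from ppdet.core.workspace import load_config

cfg = load_config('configs/rcnn_enhance/faster_rcnn_enhance_3x_coco.yml')
print(cfg['architecture'])           # 'FasterRCNN'
print(cfg['epoch'])                  # 36, merged in from _base_/optimizer_3x.yml
print(cfg['BBoxHead']['bbox_loss'])  # 'DIouLoss', from _base_/faster_rcnn_enhance.yml
```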
@@ -1453,19 +1453,19 @@ def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams):
    # Check to see if prob is passed into function. This is used for operations
    # where we alter bboxes independently.
    # pytype:disable=wrong-arg-types
-   if 'prob' in inspect.getargspec(func)[0]:
+   if 'prob' in inspect.getfullargspec(func)[0]:
        args = tuple([prob] + list(args))
    # pytype:enable=wrong-arg-types

    # Add in replace arg if it is required for the function that is being called.
-   if 'replace' in inspect.getargspec(func)[0]:
+   if 'replace' in inspect.getfullargspec(func)[0]:
        # Make sure replace is the final argument
-       assert 'replace' == inspect.getargspec(func)[0][-1]
+       assert 'replace' == inspect.getfullargspec(func)[0][-1]
        args = tuple(list(args) + [replace_value])

    # Add bboxes as the second positional argument for the function if it does
    # not already exist.
-   if 'bboxes' not in inspect.getargspec(func)[0]:
+   if 'bboxes' not in inspect.getfullargspec(func)[0]:
        func = bbox_wrapper(func)
    return (func, prob, args)
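The only change in these hunks is the move from `inspect.getargspec` (deprecated and removed in Python 3.11) to `inspect.getfullargspec`, which exposes the same `.args` list used here. A minimal standalone illustration (the `shear_x` signature below is hypothetical, not taken from the patch):

```python
import inspect

def shear_x(image, bboxes, level, prob, replace):  # hypothetical augmentation op
    return image, bboxes

spec = inspect.getfullargspec(shear_x)
assert 'prob' in spec.args           # same membership check as in the patched code
assert spec.args[-1] == 'replace'    # 'replace' must be the final argument
assert spec.args[1] == 'bboxes'      # 'bboxes' is the second positional argument
```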
@@ -1473,11 +1473,11 @@ def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams):
def _apply_func_with_prob(func, image, args, prob, bboxes):
    """Apply `func` to image w/ `args` as input with probability `prob`."""
    assert isinstance(args, tuple)
-   assert 'bboxes' == inspect.getargspec(func)[0][1]
+   assert 'bboxes' == inspect.getfullargspec(func)[0][1]

    # If prob is a function argument, then this randomness is being handled
    # inside the function, so make sure it is always called.
-   if 'prob' in inspect.getargspec(func)[0]:
+   if 'prob' in inspect.getfullargspec(func)[0]:
        prob = 1.0

    # Apply the function with probability `prob`.
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
@@ -162,7 +164,7 @@ class XConvNormHead(nn.Layer):

@register
class BBoxHead(nn.Layer):
    __shared__ = ['num_classes']
-   __inject__ = ['bbox_assigner']
+   __inject__ = ['bbox_assigner', 'bbox_loss']
    """
    RCNN bbox head
@@ -184,7 +186,8 @@ class BBoxHead(nn.Layer):
                 bbox_assigner='BboxAssigner',
                 with_pool=False,
                 num_classes=80,
-                bbox_weight=[10., 10., 5., 5.]):
+                bbox_weight=[10., 10., 5., 5.],
+                bbox_loss=None):
        super(BBoxHead, self).__init__()
        self.head = head
        self.roi_extractor = roi_extractor
@@ -195,6 +198,7 @@ class BBoxHead(nn.Layer):
        self.with_pool = with_pool
        self.num_classes = num_classes
        self.bbox_weight = bbox_weight
+       self.bbox_loss = bbox_loss

        self.bbox_score = nn.Linear(
            in_channel,
@@ -311,14 +315,51 @@ class BBoxHead(nn.Layer):
            reg_target = paddle.gather(reg_target, fg_inds)
            reg_target.stop_gradient = True
-           loss_bbox_reg = paddle.abs(reg_delta - reg_target).sum(
-           ) / tgt_labels.shape[0]
+           if self.bbox_loss is not None:
+               reg_delta = self.bbox_transform(reg_delta)
+               reg_target = self.bbox_transform(reg_target)
+               loss_bbox_reg = self.bbox_loss(
+                   reg_delta, reg_target).sum() / tgt_labels.shape[0]
+               loss_bbox_reg *= self.num_classes
+           else:
+               loss_bbox_reg = paddle.abs(reg_delta - reg_target).sum(
+               ) / tgt_labels.shape[0]

        loss_bbox[cls_name] = loss_bbox_cls * loss_weight
        loss_bbox[reg_name] = loss_bbox_reg * loss_weight

        return loss_bbox

    def bbox_transform(self, deltas, weights=[0.1, 0.1, 0.2, 0.2]):
        wx, wy, ww, wh = weights

        deltas = paddle.reshape(deltas, shape=(0, -1, 4))

        dx = paddle.slice(deltas, axes=[2], starts=[0], ends=[1]) * wx
        dy = paddle.slice(deltas, axes=[2], starts=[1], ends=[2]) * wy
        dw = paddle.slice(deltas, axes=[2], starts=[2], ends=[3]) * ww
        dh = paddle.slice(deltas, axes=[2], starts=[3], ends=[4]) * wh

        dw = paddle.clip(dw, -1.e10, np.log(1000. / 16))
        dh = paddle.clip(dh, -1.e10, np.log(1000. / 16))

        pred_ctr_x = dx
        pred_ctr_y = dy
        pred_w = paddle.exp(dw)
        pred_h = paddle.exp(dh)

        x1 = pred_ctr_x - 0.5 * pred_w
        y1 = pred_ctr_y - 0.5 * pred_h
        x2 = pred_ctr_x + 0.5 * pred_w
        y2 = pred_ctr_y + 0.5 * pred_h

        x1 = paddle.reshape(x1, shape=(-1, ))
        y1 = paddle.reshape(y1, shape=(-1, ))
        x2 = paddle.reshape(x2, shape=(-1, ))
        y2 = paddle.reshape(y2, shape=(-1, ))

        return paddle.concat([x1, y1, x2, y2])

    def get_prediction(self, score, delta):
        bbox_prob = F.softmax(score)
        return delta, bbox_prob
@@ -16,12 +16,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle
import paddle.nn.functional as F

from ppdet.core.workspace import register, serializable
from ..bbox_utils import xywh2xyxy, bbox_iou

-__all__ = ['IouLoss', 'GIoULoss']
+__all__ = ['IouLoss', 'GIoULoss', 'DIouLoss']


@register
@@ -129,3 +131,74 @@ class GIoULoss(object):
        else:
            loss = paddle.mean(giou * iou_weight)
        return loss * self.loss_weight


@register
@serializable
class DIouLoss(GIoULoss):
    """
    Distance-IoU Loss, see https://arxiv.org/abs/1911.08287
    Args:
        loss_weight (float): giou loss weight, default as 1
        eps (float): epsilon to avoid divide by zero, default as 1e-10
        use_complete_iou_loss (bool): whether to use complete iou loss
    """

    def __init__(self, loss_weight=1., eps=1e-10, use_complete_iou_loss=True):
        super(DIouLoss, self).__init__(loss_weight=loss_weight, eps=eps)
        self.use_complete_iou_loss = use_complete_iou_loss

    def __call__(self, pbox, gbox, iou_weight=1.):
        x1, y1, x2, y2 = paddle.split(pbox, num_or_sections=4, axis=-1)
        x1g, y1g, x2g, y2g = paddle.split(gbox, num_or_sections=4, axis=-1)
        cx = (x1 + x2) / 2
        cy = (y1 + y2) / 2
        w = x2 - x1
        h = y2 - y1

        cxg = (x1g + x2g) / 2
        cyg = (y1g + y2g) / 2
        wg = x2g - x1g
        hg = y2g - y1g

        x2 = paddle.maximum(x1, x2)
        y2 = paddle.maximum(y1, y2)

        # A and B
        xkis1 = paddle.maximum(x1, x1g)
        ykis1 = paddle.maximum(y1, y1g)
        xkis2 = paddle.minimum(x2, x2g)
        ykis2 = paddle.minimum(y2, y2g)

        # A or B
        xc1 = paddle.minimum(x1, x1g)
        yc1 = paddle.minimum(y1, y1g)
        xc2 = paddle.maximum(x2, x2g)
        yc2 = paddle.maximum(y2, y2g)

        intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
        intsctk = intsctk * paddle.greater_than(
            xkis2, xkis1) * paddle.greater_than(ykis2, ykis1)
        unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g
                                                        ) - intsctk + self.eps
        iouk = intsctk / unionk

        # DIOU term
        dist_intersection = (cx - cxg) * (cx - cxg) + (cy - cyg) * (cy - cyg)
        dist_union = (xc2 - xc1) * (xc2 - xc1) + (yc2 - yc1) * (yc2 - yc1)
        diou_term = (dist_intersection + self.eps) / (dist_union + self.eps)

        # CIOU term
        ciou_term = 0
        if self.use_complete_iou_loss:
            ar_gt = wg / hg
            ar_pred = w / h
            arctan = paddle.atan(ar_gt) - paddle.atan(ar_pred)
            ar_loss = 4. / np.pi / np.pi * arctan * arctan
            alpha = ar_loss / (1 - iouk + ar_loss + self.eps)
            alpha.stop_gradient = True
            ciou_term = alpha * ar_loss

        diou = paddle.mean((1 - iouk + ciou_term + diou_term) * iou_weight)

        return diou * self.loss_weight
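For reference, the objective implemented by `DIouLoss.__call__` above is the standard DIoU/CIoU loss from the linked paper; in the code, `diou_term` is the squared center distance over the squared enclosing-box diagonal and `ar_loss` corresponds to the aspect-ratio term $v$:

$$
\mathcal{L} = 1 - \mathrm{IoU} + \frac{\rho^2(b,\, b^{gt})}{c^2} + \alpha v,\qquad
v = \frac{4}{\pi^2}\left(\arctan\frac{w^{gt}}{h^{gt}} - \arctan\frac{w}{h}\right)^2,\qquad
\alpha = \frac{v}{(1 - \mathrm{IoU}) + v},
$$

where $\rho$ is the distance between the box centers and $c$ is the diagonal length of the smallest box enclosing both boxes. The $\alpha v$ term is dropped when `use_complete_iou_loss` is false, leaving the plain DIoU loss.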
@@ -301,3 +301,287 @@ def generate_mask_target(gt_segms, rois, labels_int32, sampled_gt_inds,
    tgt_weights = paddle.concat(tgt_weights, axis=0)
    return mask_rois, mask_rois_num, tgt_classes, tgt_masks, mask_index, tgt_weights
def libra_sample_pos(max_overlaps, max_classes, pos_inds, num_expected):
    if len(pos_inds) <= num_expected:
        return pos_inds
    else:
        unique_gt_inds = np.unique(max_classes[pos_inds])
        num_gts = len(unique_gt_inds)
        num_per_gt = int(round(num_expected / float(num_gts)) + 1)

        sampled_inds = []
        for i in unique_gt_inds:
            inds = np.nonzero(max_classes == i)[0]
            before_len = len(inds)
            inds = list(set(inds) & set(pos_inds))
            after_len = len(inds)
            if len(inds) > num_per_gt:
                inds = np.random.choice(inds, size=num_per_gt, replace=False)
            sampled_inds.extend(list(inds))  # combine as a new sampler
        if len(sampled_inds) < num_expected:
            num_extra = num_expected - len(sampled_inds)
            extra_inds = np.array(list(set(pos_inds) - set(sampled_inds)))
            assert len(sampled_inds) + len(extra_inds) == len(pos_inds), \
                "sum of sampled_inds({}) and extra_inds({}) length must be equal with pos_inds({})!".format(
                    len(sampled_inds), len(extra_inds), len(pos_inds))
            if len(extra_inds) > num_extra:
                extra_inds = np.random.choice(
                    extra_inds, size=num_extra, replace=False)
            sampled_inds.extend(extra_inds.tolist())
        elif len(sampled_inds) > num_expected:
            sampled_inds = np.random.choice(
                sampled_inds, size=num_expected, replace=False)
        return paddle.to_tensor(sampled_inds)
def libra_sample_via_interval(max_overlaps, full_set, num_expected, floor_thr,
                              num_bins, bg_thresh):
    max_iou = max_overlaps.max()
    iou_interval = (max_iou - floor_thr) / num_bins
    per_num_expected = int(num_expected / num_bins)

    sampled_inds = []
    for i in range(num_bins):
        start_iou = floor_thr + i * iou_interval
        end_iou = floor_thr + (i + 1) * iou_interval

        tmp_set = set(
            np.where(
                np.logical_and(max_overlaps >= start_iou, max_overlaps <
                               end_iou))[0])
        tmp_inds = list(tmp_set & full_set)

        if len(tmp_inds) > per_num_expected:
            tmp_sampled_set = np.random.choice(
                tmp_inds, size=per_num_expected, replace=False)
        else:
            tmp_sampled_set = np.array(tmp_inds, dtype=np.int)
        sampled_inds.append(tmp_sampled_set)

    sampled_inds = np.concatenate(sampled_inds)
    if len(sampled_inds) < num_expected:
        num_extra = num_expected - len(sampled_inds)
        extra_inds = np.array(list(full_set - set(sampled_inds)))
        assert len(sampled_inds) + len(extra_inds) == len(full_set), \
            "sum of sampled_inds({}) and extra_inds({}) length must be equal with full_set({})!".format(
                len(sampled_inds), len(extra_inds), len(full_set))
        if len(extra_inds) > num_extra:
            extra_inds = np.random.choice(extra_inds, num_extra, replace=False)
        sampled_inds = np.concatenate([sampled_inds, extra_inds])

    return sampled_inds
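The idea behind `libra_sample_via_interval` is that negatives are drawn roughly evenly from equal-width IoU bins rather than uniformly over all candidates, so harder negatives (higher IoU) are not drowned out. A hypothetical standalone check with fake IoU values (not part of the patch):

```python
import numpy as np

max_overlaps = np.random.uniform(0.0, 0.5, size=1000)   # fake proposal-vs-gt IoUs
full_set = set(range(len(max_overlaps)))
sampled = libra_sample_via_interval(max_overlaps, full_set, num_expected=60,
                                    floor_thr=0.0, num_bins=3, bg_thresh=0.5)
print(len(sampled))   # ~60 indices, spread across three IoU bins of width ~0.17
```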
def libra_sample_neg(max_overlaps,
                     max_classes,
                     neg_inds,
                     num_expected,
                     floor_thr=-1,
                     floor_fraction=0,
                     num_bins=3,
                     bg_thresh=0.5):
    if len(neg_inds) <= num_expected:
        return neg_inds
    else:
        # balance sampling for negative samples
        neg_set = set(neg_inds.tolist())
        if floor_thr > 0:
            floor_set = set(
                np.where(
                    np.logical_and(max_overlaps >= 0, max_overlaps < floor_thr))
                [0])
            iou_sampling_set = set(np.where(max_overlaps >= floor_thr)[0])
        elif floor_thr == 0:
            floor_set = set(np.where(max_overlaps == 0)[0])
            iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
        else:
            floor_set = set()
            iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
            floor_thr = 0

        floor_neg_inds = list(floor_set & neg_set)
        iou_sampling_neg_inds = list(iou_sampling_set & neg_set)

        num_expected_iou_sampling = int(num_expected * (1 - floor_fraction))
        if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
            if num_bins >= 2:
                iou_sampled_inds = libra_sample_via_interval(
                    max_overlaps,
                    set(iou_sampling_neg_inds), num_expected_iou_sampling,
                    floor_thr, num_bins, bg_thresh)
            else:
                iou_sampled_inds = np.random.choice(
                    iou_sampling_neg_inds,
                    size=num_expected_iou_sampling,
                    replace=False)
        else:
            iou_sampled_inds = np.array(iou_sampling_neg_inds, dtype=np.int)
        num_expected_floor = num_expected - len(iou_sampled_inds)
        if len(floor_neg_inds) > num_expected_floor:
            sampled_floor_inds = np.random.choice(
                floor_neg_inds, size=num_expected_floor, replace=False)
        else:
            sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int)
        sampled_inds = np.concatenate((sampled_floor_inds, iou_sampled_inds))
        if len(sampled_inds) < num_expected:
            num_extra = num_expected - len(sampled_inds)
            extra_inds = np.array(list(neg_set - set(sampled_inds)))
            if len(extra_inds) > num_extra:
                extra_inds = np.random.choice(
                    extra_inds, size=num_extra, replace=False)
            sampled_inds = np.concatenate((sampled_inds, extra_inds))
        return paddle.to_tensor(sampled_inds)
def libra_label_box(anchors, gt_boxes, gt_classes, positive_overlap,
                    negative_overlap, num_classes):
    # TODO: use paddle API to speed up
    gt_classes = gt_classes.numpy()
    gt_overlaps = np.zeros((anchors.shape[0], num_classes))
    matches = np.zeros((anchors.shape[0]), dtype=np.int32)
    if len(gt_boxes) > 0:
        proposal_to_gt_overlaps = bbox_overlaps(anchors, gt_boxes).numpy()
        overlaps_argmax = proposal_to_gt_overlaps.argmax(axis=1)
        overlaps_max = proposal_to_gt_overlaps.max(axis=1)
        # Boxes which with non-zero overlap with gt boxes
        overlapped_boxes_ind = np.where(overlaps_max > 0)[0]
        overlapped_boxes_gt_classes = gt_classes[overlaps_argmax[
            overlapped_boxes_ind]]

        for idx in range(len(overlapped_boxes_ind)):
            gt_overlaps[overlapped_boxes_ind[idx], overlapped_boxes_gt_classes[
                idx]] = overlaps_max[overlapped_boxes_ind[idx]]
            matches[overlapped_boxes_ind[idx]] = overlaps_argmax[
                overlapped_boxes_ind[idx]]

    gt_overlaps = paddle.to_tensor(gt_overlaps)
    matches = paddle.to_tensor(matches)

    matched_vals = paddle.max(gt_overlaps, axis=1)
    match_labels = paddle.full(matches.shape, -1, dtype='int32')
    match_labels = paddle.where(matched_vals < negative_overlap,
                                paddle.zeros_like(match_labels), match_labels)
    match_labels = paddle.where(matched_vals >= positive_overlap,
                                paddle.ones_like(match_labels), match_labels)

    return matches, match_labels, matched_vals
def libra_sample_bbox(matches,
                      match_labels,
                      matched_vals,
                      gt_classes,
                      batch_size_per_im,
                      num_classes,
                      fg_fraction,
                      fg_thresh,
                      bg_thresh,
                      num_bins,
                      use_random=True,
                      is_cascade_rcnn=False):
    rois_per_image = int(batch_size_per_im)
    fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))
    bg_rois_per_im = rois_per_image - fg_rois_per_im

    if is_cascade_rcnn:
        fg_inds = paddle.nonzero(matched_vals >= fg_thresh)
        bg_inds = paddle.nonzero(matched_vals < bg_thresh)
    else:
        matched_vals_np = matched_vals.numpy()
        match_labels_np = match_labels.numpy()

        # sample fg
        fg_inds = paddle.nonzero(matched_vals >= fg_thresh).flatten()
        fg_nums = int(np.minimum(fg_rois_per_im, fg_inds.shape[0]))
        if (fg_inds.shape[0] > fg_nums) and use_random:
            fg_inds = libra_sample_pos(matched_vals_np, match_labels_np,
                                       fg_inds.numpy(), fg_rois_per_im)
        fg_inds = fg_inds[:fg_nums]

        # sample bg
        bg_inds = paddle.nonzero(matched_vals < bg_thresh).flatten()
        bg_nums = int(np.minimum(rois_per_image - fg_nums, bg_inds.shape[0]))
        if (bg_inds.shape[0] > bg_nums) and use_random:
            bg_inds = libra_sample_neg(
                matched_vals_np,
                match_labels_np,
                bg_inds.numpy(),
                bg_rois_per_im,
                num_bins=num_bins,
                bg_thresh=bg_thresh)
        bg_inds = bg_inds[:bg_nums]

    sampled_inds = paddle.concat([fg_inds, bg_inds])

    gt_classes = paddle.gather(gt_classes, matches)
    gt_classes = paddle.where(match_labels == 0,
                              paddle.ones_like(gt_classes) * num_classes,
                              gt_classes)
    gt_classes = paddle.where(match_labels == -1,
                              paddle.ones_like(gt_classes) * -1, gt_classes)
    sampled_gt_classes = paddle.gather(gt_classes, sampled_inds)

    return sampled_inds, sampled_gt_classes
def libra_generate_proposal_target(rpn_rois,
                                   gt_classes,
                                   gt_boxes,
                                   batch_size_per_im,
                                   fg_fraction,
                                   fg_thresh,
                                   bg_thresh,
                                   num_classes,
                                   use_random=True,
                                   is_cascade_rcnn=False,
                                   max_overlaps=None,
                                   num_bins=3):
    rois_with_gt = []
    tgt_labels = []
    tgt_bboxes = []
    sampled_max_overlaps = []
    tgt_gt_inds = []
    new_rois_num = []

    for i, rpn_roi in enumerate(rpn_rois):
        max_overlap = max_overlaps[i] if is_cascade_rcnn else None
        gt_bbox = gt_boxes[i]
        gt_class = paddle.squeeze(gt_classes[i], axis=-1)
        if is_cascade_rcnn:
            rpn_roi = filter_roi(rpn_roi, max_overlap)
        bbox = paddle.concat([rpn_roi, gt_bbox])

        # Step1: label bbox
        matches, match_labels, matched_vals = libra_label_box(
            bbox, gt_bbox, gt_class, fg_thresh, bg_thresh, num_classes)

        # Step2: sample bbox
        sampled_inds, sampled_gt_classes = libra_sample_bbox(
            matches, match_labels, matched_vals, gt_class, batch_size_per_im,
            num_classes, fg_fraction, fg_thresh, bg_thresh, num_bins,
            use_random, is_cascade_rcnn)

        # Step3: make output
        rois_per_image = paddle.gather(bbox, sampled_inds)
        sampled_gt_ind = paddle.gather(matches, sampled_inds)
        sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
        sampled_overlap = paddle.gather(matched_vals, sampled_inds)

        rois_per_image.stop_gradient = True
        sampled_gt_ind.stop_gradient = True
        sampled_bbox.stop_gradient = True
        sampled_overlap.stop_gradient = True

        tgt_labels.append(sampled_gt_classes)
        tgt_bboxes.append(sampled_bbox)
        rois_with_gt.append(rois_per_image)
        sampled_max_overlaps.append(sampled_overlap)
        tgt_gt_inds.append(sampled_gt_ind)
        new_rois_num.append(paddle.shape(sampled_inds)[0])
    new_rois_num = paddle.concat(new_rois_num)
    # rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num
    return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num
@@ -14,7 +14,8 @@
import sys

import paddle
from ppdet.core.workspace import register, serializable
-from .target import rpn_anchor_target, generate_proposal_target, generate_mask_target
+from .target import rpn_anchor_target, generate_proposal_target, generate_mask_target, libra_generate_proposal_target
from ppdet.modeling import bbox_utils
import numpy as np
@@ -95,15 +96,15 @@ class BBoxAssigner(object):
            default 512
        fg_fraction (float): Fraction of RoIs that is labeled
            foreground, default 0.25
-       positive_overlap (float): Minimum overlap required between a RoI
+       fg_thresh (float): Minimum overlap required between a RoI
            and ground-truth box for the (roi, gt box) pair to be
            a foreground sample. default 0.5
-       negative_overlap (float): Maximum overlap allowed between a RoI
+       bg_thresh (float): Maximum overlap allowed between a RoI
            and ground-truth box for the (roi, gt box) pair to be
            a background sample. default 0.5
        use_random (bool): Use random sampling to choose foreground and
            background boxes, default true
        cascade_iou (list[iou]): The list of overlap to select foreground and
            background of each stage, which is only used In Cascade RCNN.
        num_classes (int): The number of class.
    """
@@ -146,6 +147,77 @@ class BBoxAssigner(object):
        return rois, rois_num, targets


@register
class BBoxLibraAssigner(object):
    __shared__ = ['num_classes']
    """
    Libra-RCNN targets assignment module

    The assignment consists of three steps:
        1. Match RoIs and ground-truth box, label the RoIs with foreground
           or background sample
        2. Sample anchors to keep the properly ratio between foreground and
           background
        3. Generate the targets for classification and regression branch

    Args:
        batch_size_per_im (int): Total number of RoIs per image.
            default 512
        fg_fraction (float): Fraction of RoIs that is labeled
            foreground, default 0.25
        fg_thresh (float): Minimum overlap required between a RoI
            and ground-truth box for the (roi, gt box) pair to be
            a foreground sample. default 0.5
        bg_thresh (float): Maximum overlap allowed between a RoI
            and ground-truth box for the (roi, gt box) pair to be
            a background sample. default 0.5
        use_random (bool): Use random sampling to choose foreground and
            background boxes, default true
        cascade_iou (list[iou]): The list of overlap to select foreground and
            background of each stage, which is only used In Cascade RCNN.
        num_classes (int): The number of class.
        num_bins (int): The number of libra_sample.
    """
    def __init__(self,
                 batch_size_per_im=512,
                 fg_fraction=.25,
                 fg_thresh=.5,
                 bg_thresh=.5,
                 use_random=True,
                 cascade_iou=[0.5, 0.6, 0.7],
                 num_classes=80,
                 num_bins=3):
        super(BBoxLibraAssigner, self).__init__()
        self.batch_size_per_im = batch_size_per_im
        self.fg_fraction = fg_fraction
        self.fg_thresh = fg_thresh
        self.bg_thresh = bg_thresh
        self.use_random = use_random
        self.cascade_iou = cascade_iou
        self.num_classes = num_classes
        self.num_bins = num_bins

    def __call__(self,
                 rpn_rois,
                 rpn_rois_num,
                 inputs,
                 stage=0,
                 is_cascade=False):
        gt_classes = inputs['gt_class']
        gt_boxes = inputs['gt_bbox']
        # rois, tgt_labels, tgt_bboxes, tgt_gt_inds
        outs = libra_generate_proposal_target(
            rpn_rois, gt_classes, gt_boxes, self.batch_size_per_im,
            self.fg_fraction, self.fg_thresh, self.bg_thresh, self.num_classes,
            self.use_random, is_cascade, self.cascade_iou[stage], self.num_bins)
        rois = outs[0]
        rois_num = outs[-1]
        # tgt_labels, tgt_bboxes, tgt_gt_inds
        targets = outs[1:4]
        return rois, rois_num, targets
@register
@serializable
class MaskAssigner(object):