Unverified commit 571dcb34 authored by W wangguanzhong and committed by GitHub

support no label training (#2576)

* support no label training

* fix load_crowd

* update to fit rbox

* add empty_ratio

* update configs

* support mask rcnn

* clean drop_empty for mot & keypoint

* refine is_crowd
Parent 6604bc45
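
The diffs below fall into three groups: the reader configs drop the old `drop_empty` flag, `COCODataSet` gains `load_crowd`, `allow_empty` and `empty_ratio` options so images without labels can be kept, and the RPN/BBox/Mask target code learns to cope with empty ground-truth and crowd regions. A minimal sketch of how the new dataset options might be exercised from Python; the paths here are illustrative placeholders and the import path is assumed from the PaddleDetection layout, not taken from this commit:

# Hedged sketch: exercises the new COCODataSet arguments added by this commit.
# dataset_dir / image_dir / anno_path are placeholder values, not from the diff.
from ppdet.data.source.coco import COCODataSet

dataset = COCODataSet(
    dataset_dir='dataset/coco',
    image_dir='train2017',
    anno_path='annotations/instances_train2017.json',
    data_fields=['image'],
    load_crowd=True,    # keep iscrowd annotations instead of filtering them out
    allow_empty=True,   # keep images that have no ground-truth boxes
    empty_ratio=0.25)   # cap empty images at roughly 25% of the final records
dataset.parse_dataset()
print(len(dataset.roidbs))  # labeled records plus the sampled empty records
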
......@@ -25,7 +25,6 @@ EvalReader:
batch_size: 1
shuffle: false
drop_last: false
drop_empty: false
TestReader:
......
......@@ -25,7 +25,6 @@ EvalReader:
batch_size: 1
shuffle: false
drop_last: false
drop_empty: false
TestReader:
......
......@@ -26,7 +26,6 @@ EvalReader:
batch_size: 1
shuffle: false
drop_last: false
drop_empty: false
TestReader:
......
......@@ -34,7 +34,6 @@ EvalReader:
- NormalizeImage: {mean: [123, 117, 104], std: [127.502231, 127.502231, 127.502231], is_scale: false}
- Permute: {}
batch_size: 1
drop_empty: false
TestReader:
......
......@@ -25,7 +25,6 @@ EvalReader:
batch_size: 1
shuffle: false
drop_last: false
drop_empty: false
TestReader:
......
......@@ -25,7 +25,6 @@ EvalReader:
batch_size: 1
shuffle: false
drop_last: false
drop_empty: false
TestReader:
......
......@@ -119,7 +119,6 @@ EvalReader:
is_scale: true
- Permute: {}
batch_size: 1
drop_empty: false
TestReader:
sample_transforms:
......
......@@ -120,7 +120,6 @@ EvalReader:
is_scale: true
- Permute: {}
batch_size: 1
drop_empty: false
TestReader:
sample_transforms:
......
......@@ -119,7 +119,6 @@ EvalReader:
is_scale: true
- Permute: {}
batch_size: 1
drop_empty: false
TestReader:
sample_transforms:
......
......@@ -128,7 +128,6 @@ EvalReader:
is_scale: true
- Permute: {}
batch_size: 16
drop_empty: false
TestReader:
sample_transforms:
......
......@@ -129,7 +129,6 @@ EvalReader:
is_scale: true
- Permute: {}
batch_size: 16
drop_empty: false
TestReader:
sample_transforms:
......
......@@ -25,7 +25,6 @@ EvalReader:
batch_size: 1
shuffle: false
drop_last: false
drop_empty: false
TestReader:
......
......@@ -26,7 +26,6 @@ EvalReader:
batch_size: 1
shuffle: false
drop_last: false
drop_empty: false
TestReader:
......@@ -40,4 +39,3 @@ TestReader:
batch_size: 1
shuffle: false
drop_last: false
drop_empty: false
......@@ -45,7 +45,6 @@ EvalReader:
- BboxCXCYWH2XYXY: {}
- Norm2PixelBbox: {}
batch_size: 1
drop_empty: false
TestReader:
......
......@@ -45,7 +45,6 @@ EvalReader:
- BboxCXCYWH2XYXY: {}
- Norm2PixelBbox: {}
batch_size: 1
drop_empty: false
TestReader:
......
......@@ -45,7 +45,6 @@ EvalReader:
- BboxCXCYWH2XYXY: {}
- Norm2PixelBbox: {}
batch_size: 1
drop_empty: false
TestReader:
......
......@@ -30,7 +30,6 @@ EvalReader:
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 8
drop_empty: false
TestReader:
inputs_def:
......
......@@ -30,7 +30,6 @@ EvalReader:
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 8
drop_empty: false
TestReader:
inputs_def:
......
......@@ -30,7 +30,6 @@ EvalReader:
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 8
drop_empty: false
TestReader:
inputs_def:
......
......@@ -47,7 +47,6 @@ EvalReader:
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 8
drop_empty: false
TestReader:
inputs_def:
......
......@@ -47,7 +47,6 @@ EvalReader:
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 8
drop_empty: false
TestReader:
inputs_def:
......
......@@ -47,7 +47,6 @@ EvalReader:
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 8
drop_empty: false
TestReader:
inputs_def:
......
......@@ -27,7 +27,6 @@ EvalReader:
batch_size: 1
shuffle: false
drop_last: false
drop_empty: false
TestReader:
......
......@@ -28,7 +28,6 @@ EvalReader:
batch_size: 1
shuffle: false
drop_last: false
drop_empty: false
TestReader:
......
......@@ -26,7 +26,6 @@ EvalReader:
- NormalizeImage: {mean: [127.5, 127.5, 127.5], std: [127.502231, 127.502231, 127.502231], is_scale: false}
- Permute: {}
batch_size: 1
drop_empty: false
TestReader:
......
......@@ -29,7 +29,6 @@ EvalReader:
- NormalizeImage: {mean: [104., 117., 123.], std: [1., 1., 1.], is_scale: false}
- Permute: {}
batch_size: 1
drop_empty: false
TestReader:
inputs_def:
......
......@@ -26,7 +26,6 @@ EvalReader:
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
- Permute: {}
batch_size: 1
drop_empty: false
TestReader:
......
......@@ -26,7 +26,6 @@ EvalReader:
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
- Permute: {}
batch_size: 1
drop_empty: false
TestReader:
......
......@@ -27,7 +27,6 @@ EvalReader:
- Permute: {}
batch_size: 1
drop_last: false
drop_empty: false
TestReader:
sample_transforms:
......@@ -37,4 +36,3 @@ TestReader:
- Permute: {}
batch_size: 1
drop_last: false
drop_empty: false
......@@ -27,7 +27,6 @@ EvalReader:
- Permute: {}
batch_size: 1
drop_last: false
drop_empty: false
TestReader:
sample_transforms:
......@@ -37,4 +36,3 @@ TestReader:
- Permute: {}
batch_size: 1
drop_last: false
drop_empty: false
......@@ -22,7 +22,6 @@ EvalReader:
- Permute: {}
batch_size: 1
drop_last: false
drop_empty: false
TestReader:
sample_transforms:
......@@ -32,4 +31,3 @@ TestReader:
- Permute: {}
batch_size: 1
drop_last: false
drop_empty: false
......@@ -32,7 +32,6 @@ EvalReader:
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- Permute: {}
batch_size: 1
drop_empty: false
TestReader:
inputs_def:
......
......@@ -116,8 +116,6 @@ class BaseDataLoader(object):
shuffle (bool): whether to shuffle samples
drop_last (bool): whether to drop the last incomplete batch,
default False
drop_empty (bool): whether to drop samples with no ground
truth labels, default True
num_classes (int): class number of dataset, default 80
collate_batch (bool): whether to collate batch in dataloader.
If set to True, the samples will collate into batch according
......@@ -140,7 +138,6 @@ class BaseDataLoader(object):
batch_size=1,
shuffle=False,
drop_last=False,
drop_empty=True,
num_classes=80,
collate_batch=True,
use_shared_memory=False,
......@@ -231,13 +228,12 @@ class TrainReader(BaseDataLoader):
batch_size=1,
shuffle=True,
drop_last=True,
drop_empty=True,
num_classes=80,
collate_batch=True,
**kwargs):
super(TrainReader, self).__init__(
sample_transforms, batch_transforms, batch_size, shuffle, drop_last,
drop_empty, num_classes, collate_batch, **kwargs)
super(TrainReader, self).__init__(sample_transforms, batch_transforms,
batch_size, shuffle, drop_last,
num_classes, collate_batch, **kwargs)
@register
......@@ -250,12 +246,11 @@ class EvalReader(BaseDataLoader):
batch_size=1,
shuffle=False,
drop_last=True,
drop_empty=True,
num_classes=80,
**kwargs):
super(EvalReader, self).__init__(sample_transforms, batch_transforms,
batch_size, shuffle, drop_last,
drop_empty, num_classes, **kwargs)
num_classes, **kwargs)
@register
......@@ -268,12 +263,11 @@ class TestReader(BaseDataLoader):
batch_size=1,
shuffle=False,
drop_last=False,
drop_empty=True,
num_classes=80,
**kwargs):
super(TestReader, self).__init__(sample_transforms, batch_transforms,
batch_size, shuffle, drop_last,
drop_empty, num_classes, **kwargs)
num_classes, **kwargs)
@register
......@@ -286,12 +280,11 @@ class EvalMOTReader(BaseDataLoader):
batch_size=1,
shuffle=False,
drop_last=False,
drop_empty=True,
num_classes=1,
**kwargs):
super(EvalMOTReader, self).__init__(sample_transforms, batch_transforms,
batch_size, shuffle, drop_last,
drop_empty, num_classes, **kwargs)
num_classes, **kwargs)
@register
......@@ -304,9 +297,8 @@ class TestMOTReader(BaseDataLoader):
batch_size=1,
shuffle=False,
drop_last=False,
drop_empty=True,
num_classes=1,
**kwargs):
super(TestMOTReader, self).__init__(sample_transforms, batch_transforms,
batch_size, shuffle, drop_last,
drop_empty, num_classes, **kwargs)
num_classes, **kwargs)
......@@ -33,6 +33,12 @@ class COCODataSet(DetDataset):
anno_path (str): coco annotation file path.
data_fields (list): key name of data dictionary, at least have 'image'.
sample_num (int): number of samples to load, -1 means all.
load_crowd (bool): whether to load crowded ground-truth.
False as default
allow_empty (bool): whether to load empty entries. False as default
empty_ratio (float): the ratio of empty records to the total record
number. If empty_ratio is out of [0., 1.), all empty records are
kept without sampling. 1. as default
"""
def __init__(self,
......@@ -40,11 +46,26 @@ class COCODataSet(DetDataset):
image_dir=None,
anno_path=None,
data_fields=['image'],
sample_num=-1):
sample_num=-1,
load_crowd=False,
allow_empty=False,
empty_ratio=1.):
super(COCODataSet, self).__init__(dataset_dir, image_dir, anno_path,
data_fields, sample_num)
self.load_image_only = False
self.load_semantic = False
self.load_crowd = load_crowd
self.allow_empty = allow_empty
self.empty_ratio = empty_ratio
def _sample_empty(self, records, num):
# if empty_ratio is out of [0., 1.), do not sample the records
if self.empty_ratio < 0. or self.empty_ratio >= 1.:
return records
import random
sample_num = int(num * self.empty_ratio / (1 - self.empty_ratio))
records = random.sample(records, sample_num)
return records
def parse_dataset(self):
anno_path = os.path.join(self.dataset_dir, self.anno_path)
......@@ -58,6 +79,7 @@ class COCODataSet(DetDataset):
img_ids.sort()
cat_ids = coco.getCatIds()
records = []
empty_records = []
ct = 0
self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
......@@ -79,6 +101,7 @@ class COCODataSet(DetDataset):
im_path = os.path.join(image_dir,
im_fname) if image_dir else im_fname
is_empty = False
if not os.path.exists(im_path):
logger.warning('Illegal image file: {}, and it will be '
'ignored'.format(im_path))
......@@ -98,12 +121,16 @@ class COCODataSet(DetDataset):
} if 'image' in self.data_fields else {}
if not self.load_image_only:
ins_anno_ids = coco.getAnnIds(imgIds=[img_id], iscrowd=False)
ins_anno_ids = coco.getAnnIds(
imgIds=[img_id], iscrowd=None if self.load_crowd else False)
instances = coco.loadAnns(ins_anno_ids)
bboxes = []
is_rbox_anno = False
for inst in instances:
# check gt bbox
if inst.get('ignore', False):
continue
if 'bbox' not in inst.keys():
continue
else:
......@@ -137,8 +164,10 @@ class COCODataSet(DetDataset):
img_id, float(inst['area']), x1, y1, x2, y2))
num_bbox = len(bboxes)
if num_bbox <= 0:
if num_bbox <= 0 and not self.allow_empty:
continue
elif num_bbox <= 0:
is_empty = True
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
if is_rbox_anno:
......@@ -165,7 +194,8 @@ class COCODataSet(DetDataset):
gt_poly[i] = box['segmentation']
has_segmentation = True
if has_segmentation and not any(gt_poly):
if has_segmentation and not any(
gt_poly) and not self.allow_empty:
continue
if is_rbox_anno:
......@@ -196,10 +226,16 @@ class COCODataSet(DetDataset):
logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
im_path, img_id, im_h, im_w))
if is_empty:
empty_records.append(coco_rec)
else:
records.append(coco_rec)
ct += 1
if self.sample_num > 0 and ct >= self.sample_num:
break
assert len(records) > 0, 'not found any coco record in %s' % (anno_path)
assert ct > 0, 'not found any coco record in %s' % (anno_path)
logger.debug('{} samples in file {}'.format(ct, anno_path))
if len(empty_records) > 0:
empty_records = self._sample_empty(empty_records, len(records))
records += empty_records
self.roidbs = records
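
The `empty_ratio` sampling above keeps empty images at roughly that fraction of the final record list: for `num` non-empty records, at most `int(num * empty_ratio / (1 - empty_ratio))` empty records are kept. A small standalone sketch of the arithmetic, assuming there are at least that many empty records to sample from:

import random

def sample_empty(empty_records, num_non_empty, empty_ratio):
    # Mirrors COCODataSet._sample_empty: a ratio outside [0., 1.) means "keep all".
    if empty_ratio < 0. or empty_ratio >= 1.:
        return empty_records
    sample_num = int(num_non_empty * empty_ratio / (1 - empty_ratio))
    return random.sample(empty_records, sample_num)

# With 700 labeled images and empty_ratio=0.25, int(700 * 0.25 / 0.75) = 233
# empty images are kept, i.e. 233 / (700 + 233) ~= 25% of the final roidbs.
empties = [{'im_id': i, 'gt_bbox': []} for i in range(500)]
print(len(sample_empty(empties, num_non_empty=700, empty_ratio=0.25)))  # 233
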
......@@ -124,6 +124,10 @@ def bbox_overlaps(boxes1, boxes2):
Return:
overlaps (Tensor): overlaps between boxes1 and boxes2 with shape [M, N]
"""
M = boxes1.shape[0]
N = boxes2.shape[0]
if M * N == 0:
return paddle.zeros([M, N], dtype='float32')
area1 = bbox_area(boxes1)
area2 = bbox_area(boxes2)
......
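
The early return added to `bbox_overlaps` above is what lets the target-assignment code below survive images with zero ground-truth boxes: instead of running the area/intersection math on empty tensors, it hands back a correctly shaped, empty IoU matrix. A quick shape check, assuming `bbox_overlaps` lives in `ppdet.modeling.bbox_utils`:

import paddle
from ppdet.modeling.bbox_utils import bbox_overlaps  # assumed module path

gt_boxes = paddle.zeros([0, 4], dtype='float32')  # an image without labels
anchors = paddle.rand([5, 4])                     # five dummy anchors
iou = bbox_overlaps(gt_boxes, anchors)
print(iou.shape)  # [0, 5]: empty but well shaped, so label_box below can
                  # branch on iou.shape[0] == 0 and fall back to background
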
......@@ -265,14 +265,26 @@ class BBoxHead(nn.Layer):
targets (list[List[Tensor]]): bbox targets containing tgt_labels, tgt_bboxes and tgt_gt_inds
rois (List[Tensor]): RoIs generated in each batch
"""
cls_name = 'loss_bbox_cls'
reg_name = 'loss_bbox_reg'
loss_bbox = {}
# TODO: better pass args
tgt_labels, tgt_bboxes, tgt_gt_inds = targets
# bbox cls
tgt_labels = paddle.concat(tgt_labels) if len(
tgt_labels) > 1 else tgt_labels[0]
valid_inds = paddle.nonzero(tgt_labels >= 0).flatten()
if valid_inds.shape[0] == 0:
loss_bbox[cls_name] = paddle.zeros([1], dtype='float32')
else:
tgt_labels = tgt_labels.cast('int64')
tgt_labels.stop_gradient = True
loss_bbox_cls = F.cross_entropy(
input=scores, label=tgt_labels, reduction='mean')
loss_bbox[cls_name] = loss_bbox_cls
# bbox reg
cls_agnostic_bbox_reg = deltas.shape[1] == 4
......@@ -281,14 +293,9 @@ class BBoxHead(nn.Layer):
paddle.logical_and(tgt_labels >= 0, tgt_labels <
self.num_classes)).flatten()
cls_name = 'loss_bbox_cls'
reg_name = 'loss_bbox_reg'
loss_bbox = {}
loss_weight = 1.
if fg_inds.numel() == 0:
fg_inds = paddle.zeros([1], dtype='int32')
loss_weight = 0.
loss_bbox[reg_name] = paddle.zeros([1], dtype='float32')
return loss_bbox
if cls_agnostic_bbox_reg:
reg_delta = paddle.gather(deltas, fg_inds)
......@@ -323,8 +330,7 @@ class BBoxHead(nn.Layer):
loss_bbox_reg = paddle.abs(reg_delta - reg_target).sum(
) / tgt_labels.shape[0]
loss_bbox[cls_name] = loss_bbox_cls * loss_weight
loss_bbox[reg_name] = loss_bbox_reg * loss_weight
loss_bbox[reg_name] = loss_bbox_reg
return loss_bbox
......
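
The reworked `BBoxHead` loss above replaces the old `loss_weight` trick with explicit zero tensors: when an image contributes no valid or no foreground targets, the corresponding loss entry is emitted as `paddle.zeros([1])`, so the loss dict keeps the same keys and the backward pass stays well defined. A distilled sketch of that guard pattern, not the exact BBoxHead code:

import paddle
import paddle.nn.functional as F

def cls_loss_or_zero(scores, tgt_labels):
    # Guard pattern from the diff: no valid targets -> a zero loss placeholder.
    valid_inds = paddle.nonzero(tgt_labels >= 0).flatten()
    if valid_inds.shape[0] == 0:
        return paddle.zeros([1], dtype='float32')
    valid_scores = paddle.gather(scores, valid_inds)
    valid_labels = paddle.gather(tgt_labels, valid_inds).cast('int64')
    return F.cross_entropy(input=valid_scores, label=valid_labels, reduction='mean')

scores = paddle.rand([4, 81])                                   # 80 classes + background
print(cls_loss_or_zero(scores, paddle.full([4], -1, 'int32')))  # zero placeholder loss
print(cls_loss_or_zero(scores, paddle.to_tensor([0, 3, 80, 80], 'int32')))  # ordinary CE loss
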
......@@ -238,6 +238,9 @@ class RPNHead(nn.Layer):
valid_ind = paddle.nonzero(valid_mask)
# cls loss
if valid_ind.shape[0] == 0:
loss_rpn_cls = paddle.zeros([1], dtype='float32')
else:
score_pred = paddle.gather(scores, valid_ind)
score_label = paddle.gather(score_tgt, valid_ind).cast('float32')
score_label.stop_gradient = True
......@@ -245,6 +248,9 @@ class RPNHead(nn.Layer):
logit=score_pred, label=score_label, reduction="sum")
# reg loss
if pos_ind.shape[0] == 0:
loss_rpn_reg = paddle.zeros([1], dtype='float32')
else:
loc_pred = paddle.gather(deltas, pos_ind)
loc_tgt = paddle.concat(loc_tgt)
loc_tgt = paddle.gather(loc_tgt, pos_ind)
......
......@@ -28,31 +28,38 @@ def rpn_anchor_target(anchors,
rpn_fg_fraction,
use_random=True,
batch_size=1,
ignore_thresh=-1,
is_crowd=None,
weights=[1., 1., 1., 1.]):
tgt_labels = []
tgt_bboxes = []
tgt_deltas = []
for i in range(batch_size):
gt_bbox = gt_boxes[i]
is_crowd_i = is_crowd[i] if is_crowd else None
# Step1: match anchor and gt_bbox
matches, match_labels = label_box(
anchors, gt_bbox, rpn_positive_overlap, rpn_negative_overlap, True)
anchors, gt_bbox, rpn_positive_overlap, rpn_negative_overlap, True,
ignore_thresh, is_crowd_i)
# Step2: sample anchor
fg_inds, bg_inds = subsample_labels(match_labels, rpn_batch_size_per_im,
rpn_fg_fraction, 0, use_random)
# Fill with the ignore label (-1), then set positive and negative labels
labels = paddle.full(match_labels.shape, -1, dtype='int32')
labels = paddle.scatter(labels, fg_inds, paddle.ones_like(fg_inds))
if bg_inds.shape[0] > 0:
labels = paddle.scatter(labels, bg_inds, paddle.zeros_like(bg_inds))
if fg_inds.shape[0] > 0:
labels = paddle.scatter(labels, fg_inds, paddle.ones_like(fg_inds))
# Step3: make output
if gt_bbox.shape[0] == 0:
matched_gt_boxes = paddle.zeros([0, 4])
tgt_delta = paddle.zeros([0, 4])
else:
matched_gt_boxes = paddle.gather(gt_bbox, matches)
tgt_delta = bbox2delta(anchors, matched_gt_boxes, weights)
labels.stop_gradient = True
matched_gt_boxes.stop_gradient = True
tgt_delta.stop_gradient = True
labels.stop_gradient = True
tgt_labels.append(labels)
tgt_bboxes.append(matched_gt_boxes)
tgt_deltas.append(tgt_delta)
......@@ -60,16 +67,46 @@ def rpn_anchor_target(anchors,
return tgt_labels, tgt_bboxes, tgt_deltas
def label_box(anchors, gt_boxes, positive_overlap, negative_overlap,
allow_low_quality):
def label_box(anchors,
gt_boxes,
positive_overlap,
negative_overlap,
allow_low_quality,
ignore_thresh,
is_crowd=None):
iou = bbox_overlaps(gt_boxes, anchors)
if iou.numel() == 0:
n_gt = gt_boxes.shape[0]
if n_gt == 0 or is_crowd is None:
n_gt_crowd = 0
else:
n_gt_crowd = paddle.nonzero(is_crowd).shape[0]
if iou.shape[0] == 0 or n_gt_crowd == n_gt:
# No truth, assign everything to background
default_matches = paddle.full((iou.shape[1], ), 0, dtype='int64')
default_match_labels = paddle.full((iou.shape[1], ), -1, dtype='int32')
default_match_labels = paddle.full((iou.shape[1], ), 0, dtype='int32')
return default_matches, default_match_labels
# if ignore_thresh > 0, ignore the anchor if it is close to
# one of the crowded ground-truths
if n_gt_crowd > 0:
N_a = anchors.shape[0]
ones = paddle.ones([N_a])
mask = is_crowd * ones
if ignore_thresh > 0:
crowd_iou = iou * mask
valid = (paddle.sum((crowd_iou > ignore_thresh).cast('int32'),
axis=0) > 0).cast('float32')
iou = iou * (1 - valid) - valid
# ignore the iou between anchor and crowded ground-truth
iou = iou * (1 - mask) - mask
matched_vals, matches = paddle.topk(iou, k=1, axis=0)
match_labels = paddle.full(matches.shape, -1, dtype='int32')
match_labels = paddle.where(matched_vals < negative_overlap,
# set ignored anchor with iou = -1
neg_cond = paddle.logical_and(matched_vals > -1,
matched_vals < negative_overlap)
match_labels = paddle.where(neg_cond,
paddle.zeros_like(match_labels), match_labels)
match_labels = paddle.where(matched_vals >= positive_overlap,
paddle.ones_like(match_labels), match_labels)
......@@ -84,6 +121,7 @@ def label_box(anchors, gt_boxes, positive_overlap, negative_overlap,
matches = matches.flatten()
match_labels = match_labels.flatten()
return matches, match_labels
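
The crowd handling added to `label_box` above works purely through IoU masking: any anchor whose IoU with a crowd ground-truth exceeds `ignore_thresh` has its whole IoU column forced to -1, and the crowd rows themselves are forced to -1 so a crowd box can never become a match; anchors whose best IoU ends up at -1 then fail the `matched_vals > -1` condition and keep the ignore label. A small numeric sketch of the masking with toy values (two ground-truths, of which the second is crowd, and three anchors):

import paddle

iou = paddle.to_tensor([[0.8, 0.1, 0.2],    # gt 0 (normal) vs anchors 0..2
                        [0.1, 0.9, 0.2]])   # gt 1 (crowd)  vs anchors 0..2
is_crowd = paddle.to_tensor([[0.], [1.]])   # only gt 1 is crowd
ignore_thresh = 0.7

mask = is_crowd * paddle.ones([3])          # 1 on crowd rows, 0 elsewhere
crowd_iou = iou * mask                      # IoU against crowd gts only
valid = (paddle.sum((crowd_iou > ignore_thresh).cast('int32'),
                    axis=0) > 0).cast('float32')
iou = iou * (1 - valid) - valid             # anchor 1 is now ignored everywhere
iou = iou * (1 - mask) - mask               # the crowd gt row can never match
print(iou.numpy())  # roughly [[0.8, -1., 0.2], [-1., -1., -1.]]
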
......@@ -96,24 +134,36 @@ def subsample_labels(labels,
paddle.logical_and(labels != -1, labels != bg_label))
negative = paddle.nonzero(labels == bg_label)
positive = positive.cast('int32').flatten()
negative = negative.cast('int32').flatten()
fg_num = int(num_samples * fg_fraction)
fg_num = min(positive.numel(), fg_num)
bg_num = num_samples - fg_num
bg_num = min(negative.numel(), bg_num)
if fg_num == 0 and bg_num == 0:
fg_inds = paddle.zeros([0], dtype='int32')
bg_inds = paddle.zeros([0], dtype='int32')
return fg_inds, bg_inds
# randomly select positive and negative examples
fg_perm = paddle.randperm(positive.numel(), dtype='int32')
fg_perm = paddle.slice(fg_perm, axes=[0], starts=[0], ends=[fg_num])
negative = negative.cast('int32').flatten()
bg_perm = paddle.randperm(negative.numel(), dtype='int32')
bg_perm = paddle.slice(bg_perm, axes=[0], starts=[0], ends=[bg_num])
if use_random:
fg_inds = paddle.gather(positive, fg_perm)
bg_inds = paddle.gather(negative, bg_perm)
else:
fg_inds = paddle.slice(positive, axes=[0], starts=[0], ends=[fg_num])
bg_inds = paddle.slice(negative, axes=[0], starts=[0], ends=[bg_num])
if fg_num == 0:
fg_inds = paddle.zeros([0], dtype='int32')
return fg_inds, bg_inds
positive = positive.cast('int32').flatten()
fg_perm = paddle.randperm(positive.numel(), dtype='int32')
fg_perm = paddle.slice(fg_perm, axes=[0], starts=[0], ends=[fg_num])
if use_random:
fg_inds = paddle.gather(positive, fg_perm)
else:
fg_inds = paddle.slice(positive, axes=[0], starts=[0], ends=[fg_num])
return fg_inds, bg_inds
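
The early return added to `subsample_labels` above keeps sampling from crashing when an image contributes neither foreground nor background anchors, for example when every anchor was pushed to the ignore label by the crowd masking. A tiny behavioural sketch; the module path is assumed from the PaddleDetection layout and the arguments are passed positionally as labels, num_samples, fg_fraction, bg_label, use_random:

import paddle
# assumed path: this file is ppdet/modeling/proposal_generator/target.py
from ppdet.modeling.proposal_generator.target import subsample_labels

# 1 = foreground, 0 = background, -1 = ignored
all_ignored = paddle.to_tensor([-1, -1, -1, -1], dtype='int32')
fg, bg = subsample_labels(all_ignored, 256, 0.5, 0, False)
print(fg.shape, bg.shape)   # [0] [0]: nothing to sample, but no crash

mixed = paddle.to_tensor([1, 0, 0, -1], dtype='int32')
fg, bg = subsample_labels(mixed, 2, 0.5, 0, False)
print(fg.shape, bg.shape)   # one foreground index and one background index
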
......@@ -125,6 +175,8 @@ def generate_proposal_target(rpn_rois,
fg_thresh,
bg_thresh,
num_classes,
ignore_thresh=-1.,
is_crowd=None,
use_random=True,
is_cascade=False,
cascade_iou=0.5):
......@@ -141,17 +193,18 @@ def generate_proposal_target(rpn_rois,
bg_thresh = cascade_iou if is_cascade else bg_thresh
for i, rpn_roi in enumerate(rpn_rois):
gt_bbox = gt_boxes[i]
is_crowd_i = is_crowd[i] if is_crowd else None
gt_class = paddle.squeeze(gt_classes[i], axis=-1)
# Concat RoIs and gt boxes except cascade rcnn
if not is_cascade:
# Concat RoIs and gt boxes except cascade rcnn or none gt
if not is_cascade and gt_bbox.shape[0] > 0:
bbox = paddle.concat([rpn_roi, gt_bbox])
else:
bbox = rpn_roi
# Step1: label bbox
matches, match_labels = label_box(bbox, gt_bbox, fg_thresh, bg_thresh,
False)
False, ignore_thresh, is_crowd_i)
# Step2: sample bbox
sampled_inds, sampled_gt_classes = sample_bbox(
matches, match_labels, gt_class, batch_size_per_im, fg_fraction,
......@@ -162,7 +215,10 @@ def generate_proposal_target(rpn_rois,
sampled_inds)
sampled_gt_ind = matches if is_cascade else paddle.gather(matches,
sampled_inds)
if gt_bbox.shape[0] > 0:
sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
else:
sampled_bbox = paddle.zeros([0, 4], dtype='float32')
rois_per_image.stop_gradient = True
sampled_gt_ind.stop_gradient = True
......@@ -184,6 +240,13 @@ def sample_bbox(matches,
num_classes,
use_random=True,
is_cascade=False):
n_gt = gt_classes.shape[0]
if n_gt == 0:
# No truth, assign everything to background
gt_classes = paddle.ones(matches.shape, dtype='int32') * num_classes
#return matches, match_labels + num_classes
else:
gt_classes = paddle.gather(gt_classes, matches)
gt_classes = paddle.where(match_labels == 0,
paddle.ones_like(gt_classes) * num_classes,
......@@ -191,11 +254,17 @@ def sample_bbox(matches,
gt_classes = paddle.where(match_labels == -1,
paddle.ones_like(gt_classes) * -1, gt_classes)
if is_cascade:
return matches, gt_classes
index = paddle.arange(matches.shape[0])
return index, gt_classes
rois_per_image = int(batch_size_per_im)
fg_inds, bg_inds = subsample_labels(gt_classes, rois_per_image, fg_fraction,
num_classes, use_random)
if fg_inds.shape[0] == 0 and bg_inds.shape[0] == 0:
# fake output labeled with -1 when all boxes are neither
# foreground nor background
sampled_inds = paddle.zeros([1], dtype='int32')
else:
sampled_inds = paddle.concat([fg_inds, bg_inds])
sampled_gt_classes = paddle.gather(gt_classes, sampled_inds)
return sampled_inds, sampled_gt_classes
......@@ -268,20 +337,29 @@ def generate_mask_target(gt_segms, rois, labels_int32, sampled_gt_inds,
# to generate mask target with ground-truth
boxes = fg_rois.numpy()
gt_segms_per_im = gt_segms[k]
new_segm = []
inds_per_im = inds_per_im.numpy()
if len(gt_segms_per_im) > 0:
for i in inds_per_im:
new_segm.append(gt_segms_per_im[i])
fg_inds_new = fg_inds.reshape([-1]).numpy()
results = []
if len(gt_segms_per_im) > 0:
for j in fg_inds_new:
results.append(
rasterize_polygons_within_box(new_segm[j], boxes[j],
resolution))
else:
results.append(paddle.ones([resolution, resolution], dtype='int32'))
fg_classes = paddle.gather(labels_per_im, fg_inds)
weight = paddle.ones([fg_rois.shape[0]], dtype='float32')
if not has_fg:
# now all sampled classes are background
# which will cause error in loss calculation,
# make fake classes with weight of 0.
fg_classes = paddle.zeros([1], dtype='int32')
weight = weight - 1
tgt_mask = paddle.stack(results)
tgt_mask.stop_gradient = True
......
......@@ -44,6 +44,8 @@ class RPNTargetAssign(object):
negative_overlap (float): Maximum overlap allowed between an anchor
and ground-truth box for the (anchor, gt box) pair to be
a background sample. default 0.3
ignore_thresh (float): IoU threshold above which an anchor overlapping an
is_crowd ground-truth is ignored. Takes effect only when the value is larger than zero.
use_random (bool): Use random sampling to choose foreground and
background boxes, default true.
"""
......@@ -53,12 +55,14 @@ class RPNTargetAssign(object):
fg_fraction=0.5,
positive_overlap=0.7,
negative_overlap=0.3,
ignore_thresh=-1.,
use_random=True):
super(RPNTargetAssign, self).__init__()
self.batch_size_per_im = batch_size_per_im
self.fg_fraction = fg_fraction
self.positive_overlap = positive_overlap
self.negative_overlap = negative_overlap
self.ignore_thresh = ignore_thresh
self.use_random = use_random
def __call__(self, inputs, anchors):
......@@ -67,11 +71,12 @@ class RPNTargetAssign(object):
anchor_box (Tensor): [num_anchors, 4], num_anchors are all anchors in all feature maps.
"""
gt_boxes = inputs['gt_bbox']
is_crowd = inputs.get('is_crowd', None)
batch_size = len(gt_boxes)
tgt_labels, tgt_bboxes, tgt_deltas = rpn_anchor_target(
anchors, gt_boxes, self.batch_size_per_im, self.positive_overlap,
self.negative_overlap, self.fg_fraction, self.use_random,
batch_size)
batch_size, self.ignore_thresh, is_crowd)
norm = self.batch_size_per_im * batch_size
return tgt_labels, tgt_bboxes, tgt_deltas, norm
......@@ -101,6 +106,8 @@ class BBoxAssigner(object):
bg_thresh (float): Maximum overlap allowed between a RoI
and ground-truth box for the (roi, gt box) pair to be
a background sample. default 0.5
ignore_thresh (float): IoU threshold above which a RoI overlapping an
is_crowd ground-truth is ignored. Takes effect only when the value is larger than zero.
use_random (bool): Use random sampling to choose foreground and
background boxes, default true
cascade_iou (list[iou]): The list of overlap to select foreground and
......@@ -113,6 +120,7 @@ class BBoxAssigner(object):
fg_fraction=.25,
fg_thresh=.5,
bg_thresh=.5,
ignore_thresh=-1.,
use_random=True,
cascade_iou=[0.5, 0.6, 0.7],
num_classes=80):
......@@ -121,6 +129,7 @@ class BBoxAssigner(object):
self.fg_fraction = fg_fraction
self.fg_thresh = fg_thresh
self.bg_thresh = bg_thresh
self.ignore_thresh = ignore_thresh
self.use_random = use_random
self.cascade_iou = cascade_iou
self.num_classes = num_classes
......@@ -133,12 +142,14 @@ class BBoxAssigner(object):
is_cascade=False):
gt_classes = inputs['gt_class']
gt_boxes = inputs['gt_bbox']
is_crowd = inputs.get('is_crowd', None)
# rois, tgt_labels, tgt_bboxes, tgt_gt_inds
# new_rois_num
outs = generate_proposal_target(
rpn_rois, gt_classes, gt_boxes, self.batch_size_per_im,
self.fg_fraction, self.fg_thresh, self.bg_thresh, self.num_classes,
self.use_random, is_cascade, self.cascade_iou[stage])
self.ignore_thresh, is_crowd, self.use_random, is_cascade,
self.cascade_iou[stage])
rois = outs[0]
rois_num = outs[-1]
# tgt_labels, tgt_bboxes, tgt_gt_inds
......
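
With `RPNTargetAssign` and `BBoxAssigner` now accepting `ignore_thresh` and reading `is_crowd` straight from the feed dict, crowd handling is switched on from the assigner configuration alone and degrades to the previous behaviour when the key is missing. A hedged construction sketch; the 0.7 value is an arbitrary illustration and the module path is assumed:

# assumed path: ppdet/modeling/proposal_generator/target_layer.py
from ppdet.modeling.proposal_generator.target_layer import (RPNTargetAssign,
                                                            BBoxAssigner)

# ignore_thresh <= 0 (the default -1.) disables crowd handling; a positive
# value ignores anchors/RoIs whose IoU with a crowd box exceeds it.
rpn_assigner = RPNTargetAssign(ignore_thresh=0.7)
bbox_assigner = BBoxAssigner(ignore_thresh=0.7, num_classes=80)

# Both call inputs.get('is_crowd', None), so a feed dict without the key
# (e.g. the MOT and keypoint readers cleaned up in this commit) simply
# skips the crowd masking.
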