Unverified commit 571dcb34, authored by wangguanzhong, committed by GitHub

support no label training (#2576)

* support no label training

* fix load_crowd

* update to fit rbox

* add empty_ratio

* update configs

* support mask rcnn

* clean drop_empty for mot & keypoint

* refine is_crowd
Parent 6604bc45
@@ -25,7 +25,6 @@ EvalReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false
 TestReader:
...

@@ -25,7 +25,6 @@ EvalReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false
 TestReader:
...

@@ -26,7 +26,6 @@ EvalReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false
 TestReader:
...

@@ -34,7 +34,6 @@ EvalReader:
   - NormalizeImage: {mean: [123, 117, 104], std: [127.502231, 127.502231, 127.502231], is_scale: false}
   - Permute: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
...

@@ -25,7 +25,6 @@ EvalReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false
 TestReader:
...

@@ -25,7 +25,6 @@ EvalReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false
 TestReader:
...
@@ -119,7 +119,6 @@ EvalReader:
       is_scale: true
   - Permute: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
   sample_transforms:
...

@@ -120,7 +120,6 @@ EvalReader:
       is_scale: true
   - Permute: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
   sample_transforms:
...

@@ -119,7 +119,6 @@ EvalReader:
       is_scale: true
   - Permute: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
   sample_transforms:
...

@@ -128,7 +128,6 @@ EvalReader:
       is_scale: true
   - Permute: {}
   batch_size: 16
-  drop_empty: false
 TestReader:
   sample_transforms:
...

@@ -129,7 +129,6 @@ EvalReader:
       is_scale: true
   - Permute: {}
   batch_size: 16
-  drop_empty: false
 TestReader:
   sample_transforms:
...

@@ -25,7 +25,6 @@ EvalReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false
 TestReader:
...
@@ -26,7 +26,6 @@ EvalReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false
 TestReader:
@@ -40,4 +39,3 @@ TestReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false

@@ -45,7 +45,6 @@ EvalReader:
   - BboxCXCYWH2XYXY: {}
   - Norm2PixelBbox: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
...

@@ -45,7 +45,6 @@ EvalReader:
   - BboxCXCYWH2XYXY: {}
   - Norm2PixelBbox: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
...

@@ -45,7 +45,6 @@ EvalReader:
   - BboxCXCYWH2XYXY: {}
   - Norm2PixelBbox: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
...
@@ -30,7 +30,6 @@ EvalReader:
   - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
   - Permute: {}
   batch_size: 8
-  drop_empty: false
 TestReader:
   inputs_def:
...

@@ -30,7 +30,6 @@ EvalReader:
   - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
   - Permute: {}
   batch_size: 8
-  drop_empty: false
 TestReader:
   inputs_def:
...

@@ -30,7 +30,6 @@ EvalReader:
   - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
   - Permute: {}
   batch_size: 8
-  drop_empty: false
 TestReader:
   inputs_def:
...

@@ -47,7 +47,6 @@ EvalReader:
   - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
   - Permute: {}
   batch_size: 8
-  drop_empty: false
 TestReader:
   inputs_def:
...

@@ -47,7 +47,6 @@ EvalReader:
   - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
   - Permute: {}
   batch_size: 8
-  drop_empty: false
 TestReader:
   inputs_def:
...

@@ -47,7 +47,6 @@ EvalReader:
   - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
   - Permute: {}
   batch_size: 8
-  drop_empty: false
 TestReader:
   inputs_def:
...
@@ -27,7 +27,6 @@ EvalReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false
 TestReader:
...

@@ -28,7 +28,6 @@ EvalReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false
 TestReader:
...

@@ -26,7 +26,6 @@ EvalReader:
   - NormalizeImage: {mean: [127.5, 127.5, 127.5], std: [127.502231, 127.502231, 127.502231], is_scale: false}
   - Permute: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
...

@@ -29,7 +29,6 @@ EvalReader:
   - NormalizeImage: {mean: [104., 117., 123.], std: [1., 1., 1.], is_scale: false}
   - Permute: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
   inputs_def:
...

@@ -26,7 +26,6 @@ EvalReader:
   - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
   - Permute: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
...

@@ -26,7 +26,6 @@ EvalReader:
   - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
   - Permute: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
...
@@ -27,7 +27,6 @@ EvalReader:
   - Permute: {}
   batch_size: 1
   drop_last: false
-  drop_empty: false
 TestReader:
   sample_transforms:
@@ -37,4 +36,3 @@ TestReader:
   - Permute: {}
   batch_size: 1
   drop_last: false
-  drop_empty: false

@@ -27,7 +27,6 @@ EvalReader:
   - Permute: {}
   batch_size: 1
   drop_last: false
-  drop_empty: false
 TestReader:
   sample_transforms:
@@ -37,4 +36,3 @@ TestReader:
   - Permute: {}
   batch_size: 1
   drop_last: false
-  drop_empty: false

@@ -22,7 +22,6 @@ EvalReader:
   - Permute: {}
   batch_size: 1
   drop_last: false
-  drop_empty: false
 TestReader:
   sample_transforms:
@@ -32,4 +31,3 @@ TestReader:
   - Permute: {}
   batch_size: 1
   drop_last: false
-  drop_empty: false

@@ -32,7 +32,6 @@ EvalReader:
   - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
   - Permute: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
   inputs_def:
...
@@ -116,8 +116,6 @@ class BaseDataLoader(object):
         shuffle (bool): whether to shuffle samples
         drop_last (bool): whether to drop the last incomplete,
             default False
-        drop_empty (bool): whether to drop samples with no ground
-            truth labels, default True
         num_classes (int): class number of dataset, default 80
         collate_batch (bool): whether to collate batch in dataloader.
             If set to True, the samples will collate into batch according
@@ -140,7 +138,6 @@ class BaseDataLoader(object):
                  batch_size=1,
                  shuffle=False,
                  drop_last=False,
-                 drop_empty=True,
                  num_classes=80,
                  collate_batch=True,
                  use_shared_memory=False,
@@ -231,13 +228,12 @@ class TrainReader(BaseDataLoader):
                  batch_size=1,
                  shuffle=True,
                  drop_last=True,
-                 drop_empty=True,
                  num_classes=80,
                  collate_batch=True,
                  **kwargs):
-        super(TrainReader, self).__init__(
-            sample_transforms, batch_transforms, batch_size, shuffle, drop_last,
-            drop_empty, num_classes, collate_batch, **kwargs)
+        super(TrainReader, self).__init__(sample_transforms, batch_transforms,
+                                          batch_size, shuffle, drop_last,
+                                          num_classes, collate_batch, **kwargs)


 @register
@@ -250,12 +246,11 @@ class EvalReader(BaseDataLoader):
                  batch_size=1,
                  shuffle=False,
                  drop_last=True,
-                 drop_empty=True,
                  num_classes=80,
                  **kwargs):
         super(EvalReader, self).__init__(sample_transforms, batch_transforms,
                                          batch_size, shuffle, drop_last,
-                                         drop_empty, num_classes, **kwargs)
+                                         num_classes, **kwargs)


 @register
@@ -268,12 +263,11 @@ class TestReader(BaseDataLoader):
                  batch_size=1,
                  shuffle=False,
                  drop_last=False,
-                 drop_empty=True,
                  num_classes=80,
                  **kwargs):
         super(TestReader, self).__init__(sample_transforms, batch_transforms,
                                          batch_size, shuffle, drop_last,
-                                         drop_empty, num_classes, **kwargs)
+                                         num_classes, **kwargs)


 @register
@@ -286,12 +280,11 @@ class EvalMOTReader(BaseDataLoader):
                  batch_size=1,
                  shuffle=False,
                  drop_last=False,
-                 drop_empty=True,
                  num_classes=1,
                  **kwargs):
         super(EvalMOTReader, self).__init__(sample_transforms, batch_transforms,
                                             batch_size, shuffle, drop_last,
-                                            drop_empty, num_classes, **kwargs)
+                                            num_classes, **kwargs)


 @register
@@ -304,9 +297,8 @@ class TestMOTReader(BaseDataLoader):
                  batch_size=1,
                  shuffle=False,
                  drop_last=False,
-                 drop_empty=True,
                  num_classes=1,
                  **kwargs):
         super(TestMOTReader, self).__init__(sample_transforms, batch_transforms,
                                             batch_size, shuffle, drop_last,
-                                            drop_empty, num_classes, **kwargs)
+                                            num_classes, **kwargs)
@@ -33,6 +33,12 @@ class COCODataSet(DetDataset):
         anno_path (str): coco annotation file path.
         data_fields (list): key name of data dictionary, at least have 'image'.
         sample_num (int): number of samples to load, -1 means all.
+        load_crowd (bool): whether to load crowded ground-truth.
+            False as default
+        allow_empty (bool): whether to load empty entry. False as default
+        empty_ratio (float): the ratio of empty record number to total
+            record's, if empty_ratio is out of [0. ,1.), do not sample the
+            records. 1. as default
     """

     def __init__(self,
@@ -40,11 +46,26 @@ class COCODataSet(DetDataset):
                  image_dir=None,
                  anno_path=None,
                  data_fields=['image'],
-                 sample_num=-1):
+                 sample_num=-1,
+                 load_crowd=False,
+                 allow_empty=False,
+                 empty_ratio=1.):
         super(COCODataSet, self).__init__(dataset_dir, image_dir, anno_path,
                                           data_fields, sample_num)
         self.load_image_only = False
         self.load_semantic = False
+        self.load_crowd = load_crowd
+        self.allow_empty = allow_empty
+        self.empty_ratio = empty_ratio
+
+    def _sample_empty(self, records, num):
+        # if empty_ratio is out of [0. ,1.), do not sample the records
+        if self.empty_ratio < 0. or self.empty_ratio >= 1.:
+            return records
+        import random
+        sample_num = int(num * self.empty_ratio / (1 - self.empty_ratio))
+        records = random.sample(records, sample_num)
+        return records

     def parse_dataset(self):
         anno_path = os.path.join(self.dataset_dir, self.anno_path)
@@ -58,6 +79,7 @@ class COCODataSet(DetDataset):
         img_ids.sort()
         cat_ids = coco.getCatIds()
         records = []
+        empty_records = []
         ct = 0
         self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
@@ -79,6 +101,7 @@ class COCODataSet(DetDataset):
             im_path = os.path.join(image_dir,
                                    im_fname) if image_dir else im_fname
+            is_empty = False
             if not os.path.exists(im_path):
                 logger.warning('Illegal image file: {}, and it will be '
                                'ignored'.format(im_path))
@@ -98,12 +121,16 @@ class COCODataSet(DetDataset):
             } if 'image' in self.data_fields else {}
             if not self.load_image_only:
-                ins_anno_ids = coco.getAnnIds(imgIds=[img_id], iscrowd=False)
+                ins_anno_ids = coco.getAnnIds(
+                    imgIds=[img_id], iscrowd=None if self.load_crowd else False)
                 instances = coco.loadAnns(ins_anno_ids)
                 bboxes = []
+                is_rbox_anno = False
                 for inst in instances:
                     # check gt bbox
+                    if inst.get('ignore', False):
+                        continue
                     if 'bbox' not in inst.keys():
                         continue
                     else:
@@ -137,8 +164,10 @@ class COCODataSet(DetDataset):
                             img_id, float(inst['area']), x1, y1, x2, y2))
                 num_bbox = len(bboxes)
-                if num_bbox <= 0:
+                if num_bbox <= 0 and not self.allow_empty:
                     continue
+                elif num_bbox <= 0:
+                    is_empty = True
                 gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
                 if is_rbox_anno:
@@ -165,7 +194,8 @@ class COCODataSet(DetDataset):
                         gt_poly[i] = box['segmentation']
                         has_segmentation = True
-                if has_segmentation and not any(gt_poly):
+                if has_segmentation and not any(
+                        gt_poly) and not self.allow_empty:
                     continue
                 if is_rbox_anno:
@@ -196,10 +226,16 @@ class COCODataSet(DetDataset):
             logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
                 im_path, img_id, im_h, im_w))
-            records.append(coco_rec)
+            if is_empty:
+                empty_records.append(coco_rec)
+            else:
+                records.append(coco_rec)
             ct += 1
             if self.sample_num > 0 and ct >= self.sample_num:
                 break
-        assert len(records) > 0, 'not found any coco record in %s' % (anno_path)
+        assert ct > 0, 'not found any coco record in %s' % (anno_path)
         logger.debug('{} samples in file {}'.format(ct, anno_path))
+        if len(empty_records) > 0:
+            empty_records = self._sample_empty(empty_records, len(records))
+            records += empty_records
         self.roidbs = records
@@ -124,6 +124,10 @@ def bbox_overlaps(boxes1, boxes2):
     Return:
         overlaps (Tensor): overlaps between boxes1 and boxes2 with shape [M, N]
     """
+    M = boxes1.shape[0]
+    N = boxes2.shape[0]
+    if M * N == 0:
+        return paddle.zeros([M, N], dtype='float32')
     area1 = bbox_area(boxes1)
     area2 = bbox_area(boxes2)
...
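Note: the added guard just returns a correctly shaped, empty IoU matrix when either box tensor has zero rows. A self-contained sketch of that behavior (the stand-in below mirrors only the new early return, not the full function):

```python
import paddle

def bbox_overlaps_empty_guard(boxes1, boxes2):
    # stand-in mirroring just the added early return
    M, N = boxes1.shape[0], boxes2.shape[0]
    if M * N == 0:
        return paddle.zeros([M, N], dtype='float32')
    raise NotImplementedError('full IoU math elided in this sketch')

print(bbox_overlaps_empty_guard(paddle.zeros([0, 4]), paddle.rand([5, 4])).shape)  # [0, 5]
```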
@@ -265,14 +265,26 @@ class BBoxHead(nn.Layer):
             targets (list[List[Tensor]]): bbox targets containing tgt_labels, tgt_bboxes and tgt_gt_inds
             rois (List[Tensor]): RoIs generated in each batch
         """
+        cls_name = 'loss_bbox_cls'
+        reg_name = 'loss_bbox_reg'
+        loss_bbox = {}
         # TODO: better pass args
         tgt_labels, tgt_bboxes, tgt_gt_inds = targets

+        # bbox cls
         tgt_labels = paddle.concat(tgt_labels) if len(
             tgt_labels) > 1 else tgt_labels[0]
-        tgt_labels = tgt_labels.cast('int64')
-        tgt_labels.stop_gradient = True
-        loss_bbox_cls = F.cross_entropy(
-            input=scores, label=tgt_labels, reduction='mean')
+        valid_inds = paddle.nonzero(tgt_labels >= 0).flatten()
+        if valid_inds.shape[0] == 0:
+            loss_bbox[cls_name] = paddle.zeros([1], dtype='float32')
+        else:
+            tgt_labels = tgt_labels.cast('int64')
+            tgt_labels.stop_gradient = True
+            loss_bbox_cls = F.cross_entropy(
+                input=scores, label=tgt_labels, reduction='mean')
+            loss_bbox[cls_name] = loss_bbox_cls

         # bbox reg
         cls_agnostic_bbox_reg = deltas.shape[1] == 4
@@ -281,14 +293,9 @@ class BBoxHead(nn.Layer):
             paddle.logical_and(tgt_labels >= 0, tgt_labels <
                                self.num_classes)).flatten()
-        cls_name = 'loss_bbox_cls'
-        reg_name = 'loss_bbox_reg'
-        loss_bbox = {}
-        loss_weight = 1.
         if fg_inds.numel() == 0:
-            fg_inds = paddle.zeros([1], dtype='int32')
-            loss_weight = 0.
+            loss_bbox[reg_name] = paddle.zeros([1], dtype='float32')
+            return loss_bbox

         if cls_agnostic_bbox_reg:
             reg_delta = paddle.gather(deltas, fg_inds)
@@ -323,8 +330,7 @@ class BBoxHead(nn.Layer):
             loss_bbox_reg = paddle.abs(reg_delta - reg_target).sum(
             ) / tgt_labels.shape[0]

-        loss_bbox[cls_name] = loss_bbox_cls * loss_weight
-        loss_bbox[reg_name] = loss_bbox_reg * loss_weight
+        loss_bbox[reg_name] = loss_bbox_reg

         return loss_bbox
...
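Note: the guard above follows the pattern used throughout this change for images without labels: when no valid targets remain, the loss entry becomes a zero tensor so the loss dict keeps its keys and backpropagation still runs. A minimal sketch of that pattern (hypothetical helper, not part of the diff):

```python
import paddle
import paddle.nn.functional as F

def cls_loss_or_zero(scores, tgt_labels):
    # zero placeholder when no target has a valid (>= 0) class label
    valid_inds = paddle.nonzero(tgt_labels >= 0).flatten()
    if valid_inds.shape[0] == 0:
        return paddle.zeros([1], dtype='float32')
    tgt_labels = tgt_labels.cast('int64')
    tgt_labels.stop_gradient = True
    return F.cross_entropy(input=scores, label=tgt_labels, reduction='mean')
```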
@@ -238,18 +238,24 @@ class RPNHead(nn.Layer):
         valid_ind = paddle.nonzero(valid_mask)

         # cls loss
-        score_pred = paddle.gather(scores, valid_ind)
-        score_label = paddle.gather(score_tgt, valid_ind).cast('float32')
-        score_label.stop_gradient = True
-        loss_rpn_cls = F.binary_cross_entropy_with_logits(
-            logit=score_pred, label=score_label, reduction="sum")
+        if valid_ind.shape[0] == 0:
+            loss_rpn_cls = paddle.zeros([1], dtype='float32')
+        else:
+            score_pred = paddle.gather(scores, valid_ind)
+            score_label = paddle.gather(score_tgt, valid_ind).cast('float32')
+            score_label.stop_gradient = True
+            loss_rpn_cls = F.binary_cross_entropy_with_logits(
+                logit=score_pred, label=score_label, reduction="sum")

         # reg loss
-        loc_pred = paddle.gather(deltas, pos_ind)
-        loc_tgt = paddle.concat(loc_tgt)
-        loc_tgt = paddle.gather(loc_tgt, pos_ind)
-        loc_tgt.stop_gradient = True
-        loss_rpn_reg = paddle.abs(loc_pred - loc_tgt).sum()
+        if pos_ind.shape[0] == 0:
+            loss_rpn_reg = paddle.zeros([1], dtype='float32')
+        else:
+            loc_pred = paddle.gather(deltas, pos_ind)
+            loc_tgt = paddle.concat(loc_tgt)
+            loc_tgt = paddle.gather(loc_tgt, pos_ind)
+            loc_tgt.stop_gradient = True
+            loss_rpn_reg = paddle.abs(loc_pred - loc_tgt).sum()
         return {
             'loss_rpn_cls': loss_rpn_cls / norm,
             'loss_rpn_reg': loss_rpn_reg / norm
...
@@ -28,31 +28,38 @@ def rpn_anchor_target(anchors,
                       rpn_fg_fraction,
                       use_random=True,
                       batch_size=1,
+                      ignore_thresh=-1,
+                      is_crowd=None,
                       weights=[1., 1., 1., 1.]):
     tgt_labels = []
     tgt_bboxes = []
     tgt_deltas = []
     for i in range(batch_size):
         gt_bbox = gt_boxes[i]
+        is_crowd_i = is_crowd[i] if is_crowd else None
         # Step1: match anchor and gt_bbox
         matches, match_labels = label_box(
-            anchors, gt_bbox, rpn_positive_overlap, rpn_negative_overlap, True)
+            anchors, gt_bbox, rpn_positive_overlap, rpn_negative_overlap, True,
+            ignore_thresh, is_crowd_i)
         # Step2: sample anchor
         fg_inds, bg_inds = subsample_labels(match_labels, rpn_batch_size_per_im,
                                             rpn_fg_fraction, 0, use_random)
         # Fill with the ignore label (-1), then set positive and negative labels
         labels = paddle.full(match_labels.shape, -1, dtype='int32')
-        labels = paddle.scatter(labels, fg_inds, paddle.ones_like(fg_inds))
-        labels = paddle.scatter(labels, bg_inds, paddle.zeros_like(bg_inds))
+        if bg_inds.shape[0] > 0:
+            labels = paddle.scatter(labels, bg_inds, paddle.zeros_like(bg_inds))
+        if fg_inds.shape[0] > 0:
+            labels = paddle.scatter(labels, fg_inds, paddle.ones_like(fg_inds))
         # Step3: make output
-        matched_gt_boxes = paddle.gather(gt_bbox, matches)
-
-        tgt_delta = bbox2delta(anchors, matched_gt_boxes, weights)
+        if gt_bbox.shape[0] == 0:
+            matched_gt_boxes = paddle.zeros([0, 4])
+            tgt_delta = paddle.zeros([0, 4])
+        else:
+            matched_gt_boxes = paddle.gather(gt_bbox, matches)
+            tgt_delta = bbox2delta(anchors, matched_gt_boxes, weights)
+            matched_gt_boxes.stop_gradient = True
+            tgt_delta.stop_gradient = True
         labels.stop_gradient = True
-        matched_gt_boxes.stop_gradient = True
-        tgt_delta.stop_gradient = True
         tgt_labels.append(labels)
         tgt_bboxes.append(matched_gt_boxes)
         tgt_deltas.append(tgt_delta)
@@ -60,16 +67,46 @@ def rpn_anchor_target(anchors,
     return tgt_labels, tgt_bboxes, tgt_deltas


-def label_box(anchors, gt_boxes, positive_overlap, negative_overlap,
-              allow_low_quality):
+def label_box(anchors,
+              gt_boxes,
+              positive_overlap,
+              negative_overlap,
+              allow_low_quality,
+              ignore_thresh,
+              is_crowd=None):
     iou = bbox_overlaps(gt_boxes, anchors)
-    if iou.numel() == 0:
+    n_gt = gt_boxes.shape[0]
+    if n_gt == 0 or is_crowd is None:
+        n_gt_crowd = 0
+    else:
+        n_gt_crowd = paddle.nonzero(is_crowd).shape[0]
+    if iou.shape[0] == 0 or n_gt_crowd == n_gt:
+        # No truth, assign everything to background
         default_matches = paddle.full((iou.shape[1], ), 0, dtype='int64')
-        default_match_labels = paddle.full((iou.shape[1], ), -1, dtype='int32')
+        default_match_labels = paddle.full((iou.shape[1], ), 0, dtype='int32')
         return default_matches, default_match_labels
+    # if ignore_thresh > 0, remove anchor if it is closed to
+    # one of the crowded ground-truth
+    if n_gt_crowd > 0:
+        N_a = anchors.shape[0]
+        ones = paddle.ones([N_a])
+        mask = is_crowd * ones

+        if ignore_thresh > 0:
+            crowd_iou = iou * mask
+            valid = (paddle.sum((crowd_iou > ignore_thresh).cast('int32'),
+                                axis=0) > 0).cast('float32')
+            iou = iou * (1 - valid) - valid
+
+        # ignore the iou between anchor and crowded ground-truth
+        iou = iou * (1 - mask) - mask

     matched_vals, matches = paddle.topk(iou, k=1, axis=0)
     match_labels = paddle.full(matches.shape, -1, dtype='int32')
-    match_labels = paddle.where(matched_vals < negative_overlap,
+    # set ignored anchor with iou = -1
+    neg_cond = paddle.logical_and(matched_vals > -1,
+                                  matched_vals < negative_overlap)
+    match_labels = paddle.where(neg_cond,
                                 paddle.zeros_like(match_labels), match_labels)
     match_labels = paddle.where(matched_vals >= positive_overlap,
                                 paddle.ones_like(match_labels), match_labels)
@@ -84,6 +121,7 @@ def label_box(anchors, gt_boxes, positive_overlap, negative_overlap,
     matches = matches.flatten()
     match_labels = match_labels.flatten()

     return matches, match_labels
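Note: the crowd handling above reduces to masking the IoU matrix: anchors whose IoU with any crowd ground-truth exceeds ignore_thresh are forced to -1 across all rows, and the rows of crowd ground-truths are dropped the same way, so those anchors end up labeled -1 (ignored) rather than background. A NumPy sketch with hypothetical values (one normal and one crowd ground-truth, thresholds 0.7/0.3), illustrating only the masking math, not the full function:

```python
import numpy as np

# 2 ground-truths (gt 1 is a crowd region) vs. 3 anchors; IoU has shape [n_gt, n_anchor]
iou = np.array([[0.80, 0.10, 0.10],
                [0.05, 0.75, 0.20]])
is_crowd = np.array([[0.], [1.]])
ignore_thresh, pos_overlap, neg_overlap = 0.7, 0.7, 0.3

mask = is_crowd * np.ones(iou.shape[1])                  # rows belonging to crowd gts
crowd_iou = iou * mask
valid = (np.sum(crowd_iou > ignore_thresh, axis=0) > 0).astype('float32')
iou = iou * (1 - valid) - valid                          # anchors too close to a crowd gt -> -1
iou = iou * (1 - mask) - mask                            # drop crowd rows entirely

matched_vals = iou.max(axis=0)
labels = np.full(iou.shape[1], -1)
labels[(matched_vals > -1) & (matched_vals < neg_overlap)] = 0   # background
labels[matched_vals >= pos_overlap] = 1                          # foreground
print(labels)  # [ 1 -1  0]: the middle anchor stays ignored instead of becoming background
```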
@@ -96,24 +134,36 @@ def subsample_labels(labels,
         paddle.logical_and(labels != -1, labels != bg_label))
     negative = paddle.nonzero(labels == bg_label)

-    positive = positive.cast('int32').flatten()
-    negative = negative.cast('int32').flatten()
-
     fg_num = int(num_samples * fg_fraction)
     fg_num = min(positive.numel(), fg_num)
     bg_num = num_samples - fg_num
     bg_num = min(negative.numel(), bg_num)
+    if fg_num == 0 and bg_num == 0:
+        fg_inds = paddle.zeros([0], dtype='int32')
+        bg_inds = paddle.zeros([0], dtype='int32')
+        return fg_inds, bg_inds

     # randomly select positive and negative examples
-    fg_perm = paddle.randperm(positive.numel(), dtype='int32')
-    fg_perm = paddle.slice(fg_perm, axes=[0], starts=[0], ends=[fg_num])
+
+    negative = negative.cast('int32').flatten()
     bg_perm = paddle.randperm(negative.numel(), dtype='int32')
     bg_perm = paddle.slice(bg_perm, axes=[0], starts=[0], ends=[bg_num])
     if use_random:
-        fg_inds = paddle.gather(positive, fg_perm)
         bg_inds = paddle.gather(negative, bg_perm)
     else:
-        fg_inds = paddle.slice(positive, axes=[0], starts=[0], ends=[fg_num])
         bg_inds = paddle.slice(negative, axes=[0], starts=[0], ends=[bg_num])
+    if fg_num == 0:
+        fg_inds = paddle.zeros([0], dtype='int32')
+        return fg_inds, bg_inds
+
+    positive = positive.cast('int32').flatten()
+    fg_perm = paddle.randperm(positive.numel(), dtype='int32')
+    fg_perm = paddle.slice(fg_perm, axes=[0], starts=[0], ends=[fg_num])
+    if use_random:
+        fg_inds = paddle.gather(positive, fg_perm)
+    else:
+        fg_inds = paddle.slice(positive, axes=[0], starts=[0], ends=[fg_num])

     return fg_inds, bg_inds
@@ -125,6 +175,8 @@ def generate_proposal_target(rpn_rois,
                              fg_thresh,
                              bg_thresh,
                              num_classes,
+                             ignore_thresh=-1.,
+                             is_crowd=None,
                              use_random=True,
                              is_cascade=False,
                              cascade_iou=0.5):
@@ -141,17 +193,18 @@ def generate_proposal_target(rpn_rois,
     bg_thresh = cascade_iou if is_cascade else bg_thresh
     for i, rpn_roi in enumerate(rpn_rois):
         gt_bbox = gt_boxes[i]
+        is_crowd_i = is_crowd[i] if is_crowd else None
         gt_class = paddle.squeeze(gt_classes[i], axis=-1)
-        # Concat RoIs and gt boxes except cascade rcnn
-        if not is_cascade:
+
+        # Concat RoIs and gt boxes except cascade rcnn or none gt
+        if not is_cascade and gt_bbox.shape[0] > 0:
             bbox = paddle.concat([rpn_roi, gt_bbox])
         else:
             bbox = rpn_roi

         # Step1: label bbox
         matches, match_labels = label_box(bbox, gt_bbox, fg_thresh, bg_thresh,
-                                          False)
+                                          False, ignore_thresh, is_crowd_i)
         # Step2: sample bbox
         sampled_inds, sampled_gt_classes = sample_bbox(
             matches, match_labels, gt_class, batch_size_per_im, fg_fraction,
@@ -162,7 +215,10 @@ def generate_proposal_target(rpn_rois,
                                        sampled_inds)
         sampled_gt_ind = matches if is_cascade else paddle.gather(matches,
                                                                   sampled_inds)
-        sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
+        if gt_bbox.shape[0] > 0:
+            sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
+        else:
+            sampled_bbox = paddle.zeros([0, 4], dtype='float32')

         rois_per_image.stop_gradient = True
         sampled_gt_ind.stop_gradient = True
@@ -184,19 +240,32 @@ def sample_bbox(matches,
                 num_classes,
                 use_random=True,
                 is_cascade=False):
-    gt_classes = paddle.gather(gt_classes, matches)
-    gt_classes = paddle.where(match_labels == 0,
-                              paddle.ones_like(gt_classes) * num_classes,
-                              gt_classes)
-    gt_classes = paddle.where(match_labels == -1,
-                              paddle.ones_like(gt_classes) * -1, gt_classes)
+    n_gt = gt_classes.shape[0]
+    if n_gt == 0:
+        # No truth, assign everything to background
+        gt_classes = paddle.ones(matches.shape, dtype='int32') * num_classes
+        #return matches, match_labels + num_classes
+    else:
+        gt_classes = paddle.gather(gt_classes, matches)
+        gt_classes = paddle.where(match_labels == 0,
+                                  paddle.ones_like(gt_classes) * num_classes,
+                                  gt_classes)
+        gt_classes = paddle.where(match_labels == -1,
+                                  paddle.ones_like(gt_classes) * -1, gt_classes)
     if is_cascade:
-        return matches, gt_classes
+        index = paddle.arange(matches.shape[0])
+        return index, gt_classes

     rois_per_image = int(batch_size_per_im)

     fg_inds, bg_inds = subsample_labels(gt_classes, rois_per_image, fg_fraction,
                                         num_classes, use_random)
-    sampled_inds = paddle.concat([fg_inds, bg_inds])
+    if fg_inds.shape[0] == 0 and bg_inds.shape[0] == 0:
+        # fake output labeled with -1 when all boxes are neither
+        # foreground nor background
+        sampled_inds = paddle.zeros([1], dtype='int32')
+    else:
+        sampled_inds = paddle.concat([fg_inds, bg_inds])

     sampled_gt_classes = paddle.gather(gt_classes, sampled_inds)

     return sampled_inds, sampled_gt_classes
@@ -268,20 +337,29 @@ def generate_mask_target(gt_segms, rois, labels_int32, sampled_gt_inds,
             # to generate mask target with ground-truth
             boxes = fg_rois.numpy()
             gt_segms_per_im = gt_segms[k]
             new_segm = []
             inds_per_im = inds_per_im.numpy()
-            for i in inds_per_im:
-                new_segm.append(gt_segms_per_im[i])
+            if len(gt_segms_per_im) > 0:
+                for i in inds_per_im:
+                    new_segm.append(gt_segms_per_im[i])
             fg_inds_new = fg_inds.reshape([-1]).numpy()
             results = []
-            for j in fg_inds_new:
-                results.append(
-                    rasterize_polygons_within_box(new_segm[j], boxes[j],
-                                                  resolution))
+            if len(gt_segms_per_im) > 0:
+                for j in fg_inds_new:
+                    results.append(
+                        rasterize_polygons_within_box(new_segm[j], boxes[j],
+                                                      resolution))
+            else:
+                results.append(paddle.ones([resolution, resolution], dtype='int32'))

             fg_classes = paddle.gather(labels_per_im, fg_inds)
             weight = paddle.ones([fg_rois.shape[0]], dtype='float32')
             if not has_fg:
+                # now all sampled classes are background
+                # which will cause error in loss calculation,
+                # make fake classes with weight of 0.
+                fg_classes = paddle.zeros([1], dtype='int32')
                 weight = weight - 1
             tgt_mask = paddle.stack(results)
             tgt_mask.stop_gradient = True
...
@@ -44,6 +44,8 @@ class RPNTargetAssign(object):
         negative_overlap (float): Maximum overlap allowed between an anchor
             and ground-truth box for the (anchor, gt box) pair to be
             a background sample. default 0.3
+        ignore_thresh(float): Threshold for ignoring the is_crowd ground-truth
+            if the value is larger than zero.
         use_random (bool): Use random sampling to choose foreground and
             background boxes, default true.
     """
@@ -53,12 +55,14 @@ class RPNTargetAssign(object):
                  fg_fraction=0.5,
                  positive_overlap=0.7,
                  negative_overlap=0.3,
+                 ignore_thresh=-1.,
                  use_random=True):
         super(RPNTargetAssign, self).__init__()
         self.batch_size_per_im = batch_size_per_im
         self.fg_fraction = fg_fraction
         self.positive_overlap = positive_overlap
         self.negative_overlap = negative_overlap
+        self.ignore_thresh = ignore_thresh
         self.use_random = use_random

     def __call__(self, inputs, anchors):
@@ -67,11 +71,12 @@ class RPNTargetAssign(object):
             anchor_box (Tensor): [num_anchors, 4], num_anchors are all anchors in all feature maps.
         """
         gt_boxes = inputs['gt_bbox']
+        is_crowd = inputs.get('is_crowd', None)
         batch_size = len(gt_boxes)
         tgt_labels, tgt_bboxes, tgt_deltas = rpn_anchor_target(
             anchors, gt_boxes, self.batch_size_per_im, self.positive_overlap,
             self.negative_overlap, self.fg_fraction, self.use_random,
-            batch_size)
+            batch_size, self.ignore_thresh, is_crowd)
         norm = self.batch_size_per_im * batch_size
         return tgt_labels, tgt_bboxes, tgt_deltas, norm
@@ -101,7 +106,9 @@ class BBoxAssigner(object):
         bg_thresh (float): Maximum overlap allowed between a RoI
             and ground-truth box for the (roi, gt box) pair to be
             a background sample. default 0.5
+        ignore_thresh(float): Threshold for ignoring the is_crowd ground-truth
+            if the value is larger than zero.
         use_random (bool): Use random sampling to choose foreground and
             background boxes, default true
         cascade_iou (list[iou]): The list of overlap to select foreground and
             background of each stage, which is only used In Cascade RCNN.
@@ -113,6 +120,7 @@ class BBoxAssigner(object):
                  fg_fraction=.25,
                  fg_thresh=.5,
                  bg_thresh=.5,
+                 ignore_thresh=-1.,
                  use_random=True,
                  cascade_iou=[0.5, 0.6, 0.7],
                  num_classes=80):
@@ -121,6 +129,7 @@ class BBoxAssigner(object):
         self.fg_fraction = fg_fraction
         self.fg_thresh = fg_thresh
         self.bg_thresh = bg_thresh
+        self.ignore_thresh = ignore_thresh
         self.use_random = use_random
         self.cascade_iou = cascade_iou
         self.num_classes = num_classes
@@ -133,12 +142,14 @@ class BBoxAssigner(object):
                  is_cascade=False):
         gt_classes = inputs['gt_class']
         gt_boxes = inputs['gt_bbox']
+        is_crowd = inputs.get('is_crowd', None)
         # rois, tgt_labels, tgt_bboxes, tgt_gt_inds
         # new_rois_num
         outs = generate_proposal_target(
             rpn_rois, gt_classes, gt_boxes, self.batch_size_per_im,
             self.fg_fraction, self.fg_thresh, self.bg_thresh, self.num_classes,
-            self.use_random, is_cascade, self.cascade_iou[stage])
+            self.ignore_thresh, is_crowd, self.use_random, is_cascade,
+            self.cascade_iou[stage])
         rois = outs[0]
         rois_num = outs[-1]
         # tgt_labels, tgt_bboxes, tgt_gt_inds
...