Unverified  Commit 571dcb34  Authored by: W wangguanzhong  Committed by: GitHub

support no label training (#2576)

* support no label training

* fix load_crowd

* update to fit rbox

* add empty_ratio

* update configs

* support mask rcnn

* clean drop_empty for mot & keypoint

* refine is_crowd
Parent 6604bc45
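As a quick illustration of what this commit enables (this example is mine, not part of the commit): the no-label and crowd options live on the dataset. A minimal Python sketch, assuming the usual ppdet import path; paths and the 0.3 ratio are placeholders, and the same keys can be set on TrainDataset in a YAML config.

    # Hypothetical usage sketch of the new COCODataSet options; values are illustrative only.
    from ppdet.data.source.coco import COCODataSet

    dataset = COCODataSet(
        dataset_dir='dataset/coco',
        image_dir='train2017',
        anno_path='annotations/instances_train2017.json',
        data_fields=['image', 'gt_bbox', 'gt_class', 'is_crowd'],
        load_crowd=True,    # also load iscrowd=1 annotations so they can be ignored via ignore_thresh
        allow_empty=True,   # keep images that have no ground-truth boxes
        empty_ratio=0.3)    # target share of empty records in the final record list
    dataset.parse_dataset()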
@@ -25,7 +25,6 @@ EvalReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false
 TestReader:
......
@@ -25,7 +25,6 @@ EvalReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false
 TestReader:
......
@@ -26,7 +26,6 @@ EvalReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false
 TestReader:
......
@@ -34,7 +34,6 @@ EvalReader:
   - NormalizeImage: {mean: [123, 117, 104], std: [127.502231, 127.502231, 127.502231], is_scale: false}
   - Permute: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
......
@@ -25,7 +25,6 @@ EvalReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false
 TestReader:
......
@@ -25,7 +25,6 @@ EvalReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false
 TestReader:
......
@@ -119,7 +119,6 @@ EvalReader:
       is_scale: true
   - Permute: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
   sample_transforms:
......
@@ -120,7 +120,6 @@ EvalReader:
       is_scale: true
   - Permute: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
   sample_transforms:
......
@@ -119,7 +119,6 @@ EvalReader:
       is_scale: true
   - Permute: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
   sample_transforms:
......
@@ -128,7 +128,6 @@ EvalReader:
       is_scale: true
   - Permute: {}
   batch_size: 16
-  drop_empty: false
 TestReader:
   sample_transforms:
......
@@ -129,7 +129,6 @@ EvalReader:
       is_scale: true
   - Permute: {}
   batch_size: 16
-  drop_empty: false
 TestReader:
   sample_transforms:
......
@@ -25,7 +25,6 @@ EvalReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false
 TestReader:
......
@@ -26,7 +26,6 @@ EvalReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false
 TestReader:
@@ -40,4 +39,3 @@ TestReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false
@@ -45,7 +45,6 @@ EvalReader:
   - BboxCXCYWH2XYXY: {}
   - Norm2PixelBbox: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
......
@@ -45,7 +45,6 @@ EvalReader:
   - BboxCXCYWH2XYXY: {}
   - Norm2PixelBbox: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
......
@@ -45,7 +45,6 @@ EvalReader:
   - BboxCXCYWH2XYXY: {}
   - Norm2PixelBbox: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
......
@@ -30,7 +30,6 @@ EvalReader:
   - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
   - Permute: {}
   batch_size: 8
-  drop_empty: false
 TestReader:
   inputs_def:
......
@@ -30,7 +30,6 @@ EvalReader:
   - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
   - Permute: {}
   batch_size: 8
-  drop_empty: false
 TestReader:
   inputs_def:
......
@@ -30,7 +30,6 @@ EvalReader:
   - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
   - Permute: {}
   batch_size: 8
-  drop_empty: false
 TestReader:
   inputs_def:
......
@@ -47,7 +47,6 @@ EvalReader:
   - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
   - Permute: {}
   batch_size: 8
-  drop_empty: false
 TestReader:
   inputs_def:
......
@@ -47,7 +47,6 @@ EvalReader:
   - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
   - Permute: {}
   batch_size: 8
-  drop_empty: false
 TestReader:
   inputs_def:
......
@@ -47,7 +47,6 @@ EvalReader:
   - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
   - Permute: {}
   batch_size: 8
-  drop_empty: false
 TestReader:
   inputs_def:
......
@@ -27,7 +27,6 @@ EvalReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false
 TestReader:
......
@@ -28,7 +28,6 @@ EvalReader:
   batch_size: 1
   shuffle: false
   drop_last: false
-  drop_empty: false
 TestReader:
......
@@ -26,7 +26,6 @@ EvalReader:
   - NormalizeImage: {mean: [127.5, 127.5, 127.5], std: [127.502231, 127.502231, 127.502231], is_scale: false}
   - Permute: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
......
@@ -29,7 +29,6 @@ EvalReader:
   - NormalizeImage: {mean: [104., 117., 123.], std: [1., 1., 1.], is_scale: false}
   - Permute: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
   inputs_def:
......
@@ -26,7 +26,6 @@ EvalReader:
   - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
   - Permute: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
......
@@ -26,7 +26,6 @@ EvalReader:
   - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
   - Permute: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
......
@@ -27,7 +27,6 @@ EvalReader:
   - Permute: {}
   batch_size: 1
   drop_last: false
-  drop_empty: false
 TestReader:
   sample_transforms:
@@ -37,4 +36,3 @@ TestReader:
   - Permute: {}
   batch_size: 1
   drop_last: false
-  drop_empty: false
@@ -27,7 +27,6 @@ EvalReader:
   - Permute: {}
   batch_size: 1
   drop_last: false
-  drop_empty: false
 TestReader:
   sample_transforms:
@@ -37,4 +36,3 @@ TestReader:
   - Permute: {}
   batch_size: 1
   drop_last: false
-  drop_empty: false
@@ -22,7 +22,6 @@ EvalReader:
   - Permute: {}
   batch_size: 1
   drop_last: false
-  drop_empty: false
 TestReader:
   sample_transforms:
@@ -32,4 +31,3 @@ TestReader:
   - Permute: {}
   batch_size: 1
   drop_last: false
-  drop_empty: false
@@ -32,7 +32,6 @@ EvalReader:
   - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
   - Permute: {}
   batch_size: 1
-  drop_empty: false
 TestReader:
   inputs_def:
......
@@ -116,8 +116,6 @@ class BaseDataLoader(object):
         shuffle (bool): whether to shuffle samples
         drop_last (bool): whether to drop the last incomplete,
             default False
-        drop_empty (bool): whether to drop samples with no ground
-            truth labels, default True
         num_classes (int): class number of dataset, default 80
         collate_batch (bool): whether to collate batch in dataloader.
             If set to True, the samples will collate into batch according
@@ -140,7 +138,6 @@ class BaseDataLoader(object):
                  batch_size=1,
                  shuffle=False,
                  drop_last=False,
-                 drop_empty=True,
                  num_classes=80,
                  collate_batch=True,
                  use_shared_memory=False,
@@ -231,13 +228,12 @@ class TrainReader(BaseDataLoader):
                  batch_size=1,
                  shuffle=True,
                  drop_last=True,
-                 drop_empty=True,
                  num_classes=80,
                  collate_batch=True,
                  **kwargs):
-        super(TrainReader, self).__init__(
-            sample_transforms, batch_transforms, batch_size, shuffle, drop_last,
-            drop_empty, num_classes, collate_batch, **kwargs)
+        super(TrainReader, self).__init__(sample_transforms, batch_transforms,
+                                          batch_size, shuffle, drop_last,
+                                          num_classes, collate_batch, **kwargs)

 @register
@@ -250,12 +246,11 @@ class EvalReader(BaseDataLoader):
                  batch_size=1,
                  shuffle=False,
                  drop_last=True,
-                 drop_empty=True,
                  num_classes=80,
                  **kwargs):
         super(EvalReader, self).__init__(sample_transforms, batch_transforms,
                                          batch_size, shuffle, drop_last,
-                                         drop_empty, num_classes, **kwargs)
+                                         num_classes, **kwargs)

 @register
@@ -268,12 +263,11 @@ class TestReader(BaseDataLoader):
                  batch_size=1,
                  shuffle=False,
                  drop_last=False,
-                 drop_empty=True,
                  num_classes=80,
                  **kwargs):
         super(TestReader, self).__init__(sample_transforms, batch_transforms,
                                          batch_size, shuffle, drop_last,
-                                         drop_empty, num_classes, **kwargs)
+                                         num_classes, **kwargs)

 @register
@@ -286,12 +280,11 @@ class EvalMOTReader(BaseDataLoader):
                  batch_size=1,
                  shuffle=False,
                  drop_last=False,
-                 drop_empty=True,
                  num_classes=1,
                  **kwargs):
         super(EvalMOTReader, self).__init__(sample_transforms, batch_transforms,
                                             batch_size, shuffle, drop_last,
-                                            drop_empty, num_classes, **kwargs)
+                                            num_classes, **kwargs)

 @register
@@ -304,9 +297,8 @@ class TestMOTReader(BaseDataLoader):
                  batch_size=1,
                  shuffle=False,
                  drop_last=False,
-                 drop_empty=True,
                  num_classes=1,
                  **kwargs):
         super(TestMOTReader, self).__init__(sample_transforms, batch_transforms,
                                             batch_size, shuffle, drop_last,
-                                            drop_empty, num_classes, **kwargs)
+                                            num_classes, **kwargs)
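Since drop_empty is gone from the reader layer, a reader is now built without it. A hedged sketch of the updated signature (readers are normally instantiated from the YAML Reader configs; the single-entry transform list below assumes the same name-to-params mapping format used in those configs):

    # Hypothetical sketch of the updated TrainReader signature; drop_empty is no longer a parameter.
    from ppdet.data.reader import TrainReader

    reader = TrainReader(
        sample_transforms=[{'Decode': {}}],   # abbreviated transform list, assumed dict format
        batch_size=2,
        shuffle=True,
        drop_last=True,
        num_classes=80,
        collate_batch=True)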
@@ -33,6 +33,12 @@ class COCODataSet(DetDataset):
         anno_path (str): coco annotation file path.
         data_fields (list): key name of data dictionary, at least have 'image'.
         sample_num (int): number of samples to load, -1 means all.
+        load_crowd (bool): whether to load crowded ground-truth.
+            False as default
+        allow_empty (bool): whether to load empty entry. False as default
+        empty_ratio (float): the ratio of empty record number to total
+            record's, if empty_ratio is out of [0. ,1.), do not sample the
+            records. 1. as default
     """

     def __init__(self,
@@ -40,11 +46,26 @@ class COCODataSet(DetDataset):
                  image_dir=None,
                  anno_path=None,
                  data_fields=['image'],
-                 sample_num=-1):
+                 sample_num=-1,
+                 load_crowd=False,
+                 allow_empty=False,
+                 empty_ratio=1.):
         super(COCODataSet, self).__init__(dataset_dir, image_dir, anno_path,
                                           data_fields, sample_num)
         self.load_image_only = False
         self.load_semantic = False
+        self.load_crowd = load_crowd
+        self.allow_empty = allow_empty
+        self.empty_ratio = empty_ratio
+
+    def _sample_empty(self, records, num):
+        # if empty_ratio is out of [0. ,1.), do not sample the records
+        if self.empty_ratio < 0. or self.empty_ratio >= 1.:
+            return records
+        import random
+        sample_num = int(num * self.empty_ratio / (1 - self.empty_ratio))
+        records = random.sample(records, sample_num)
+        return records

     def parse_dataset(self):
         anno_path = os.path.join(self.dataset_dir, self.anno_path)
@@ -58,6 +79,7 @@ class COCODataSet(DetDataset):
         img_ids.sort()
         cat_ids = coco.getCatIds()
         records = []
+        empty_records = []
         ct = 0

         self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
@@ -79,6 +101,7 @@ class COCODataSet(DetDataset):
             im_path = os.path.join(image_dir,
                                    im_fname) if image_dir else im_fname
+            is_empty = False
             if not os.path.exists(im_path):
                 logger.warning('Illegal image file: {}, and it will be '
                                'ignored'.format(im_path))
@@ -98,12 +121,16 @@ class COCODataSet(DetDataset):
             } if 'image' in self.data_fields else {}

             if not self.load_image_only:
-                ins_anno_ids = coco.getAnnIds(imgIds=[img_id], iscrowd=False)
+                ins_anno_ids = coco.getAnnIds(
+                    imgIds=[img_id], iscrowd=None if self.load_crowd else False)
                 instances = coco.loadAnns(ins_anno_ids)

                 bboxes = []
+                is_rbox_anno = False
                 for inst in instances:
                     # check gt bbox
+                    if inst.get('ignore', False):
+                        continue
                     if 'bbox' not in inst.keys():
                         continue
                     else:
@@ -137,8 +164,10 @@ class COCODataSet(DetDataset):
                             img_id, float(inst['area']), x1, y1, x2, y2))

                 num_bbox = len(bboxes)
-                if num_bbox <= 0:
+                if num_bbox <= 0 and not self.allow_empty:
                     continue
+                elif num_bbox <= 0:
+                    is_empty = True

                 gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
                 if is_rbox_anno:
@@ -165,7 +194,8 @@ class COCODataSet(DetDataset):
                         gt_poly[i] = box['segmentation']
                         has_segmentation = True

-                if has_segmentation and not any(gt_poly):
+                if has_segmentation and not any(
+                        gt_poly) and not self.allow_empty:
                     continue

                 if is_rbox_anno:
@@ -196,10 +226,16 @@ class COCODataSet(DetDataset):
             logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
                 im_path, img_id, im_h, im_w))
+            if is_empty:
+                empty_records.append(coco_rec)
+            else:
                 records.append(coco_rec)
             ct += 1
             if self.sample_num > 0 and ct >= self.sample_num:
                 break
-        assert len(records) > 0, 'not found any coco record in %s' % (anno_path)
+        assert ct > 0, 'not found any coco record in %s' % (anno_path)
         logger.debug('{} samples in file {}'.format(ct, anno_path))
+        if len(empty_records) > 0:
+            empty_records = self._sample_empty(empty_records, len(records))
+            records += empty_records
         self.roidbs = records
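The empty_ratio bookkeeping is easiest to see with numbers. A small standalone sketch of the same formula used in _sample_empty above (the counts are illustrative, not from the commit):

    # Worked example of the empty_ratio sampling formula.
    import random

    empty_ratio = 0.3                  # desired fraction of empty images in the final roidbs
    records = list(range(7000))        # stand-ins for 7000 images that do have boxes
    empty_records = list(range(5000))  # stand-ins for 5000 images without boxes

    # same formula as COCODataSet._sample_empty:
    # keep n empties so that n / (len(records) + n) == empty_ratio
    sample_num = int(len(records) * empty_ratio / (1 - empty_ratio))  # 3000
    empty_records = random.sample(empty_records, sample_num)
    roidbs = records + empty_records   # 10000 records, 30% of them empty

Note that with empty_ratio outside [0., 1.), _sample_empty returns the empty records unchanged, so every empty image is kept.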
@@ -124,6 +124,10 @@ def bbox_overlaps(boxes1, boxes2):
     Return:
         overlaps (Tensor): overlaps between boxes1 and boxes2 with shape [M, N]
     """
+    M = boxes1.shape[0]
+    N = boxes2.shape[0]
+    if M * N == 0:
+        return paddle.zeros([M, N], dtype='float32')
     area1 = bbox_area(boxes1)
     area2 = bbox_area(boxes2)
......
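For an image without labels, gt_bbox is an empty [0, 4] tensor, which is exactly the case the new early return covers. A quick sanity check, assuming bbox_overlaps is importable from ppdet's bbox utilities (import path is my assumption):

    # Illustrative check of the new empty-input guard; import path assumed.
    import paddle
    from ppdet.modeling.bbox_utils import bbox_overlaps

    gt_boxes = paddle.zeros([0, 4], dtype='float32')   # an image with no annotations
    anchors = paddle.rand([100, 4], dtype='float32')
    iou = bbox_overlaps(gt_boxes, anchors)
    print(iou.shape)   # [0, 100]: an empty IoU matrix instead of a runtime error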
@@ -265,14 +265,26 @@ class BBoxHead(nn.Layer):
             targets (list[List[Tensor]]): bbox targets containing tgt_labels, tgt_bboxes and tgt_gt_inds
             rois (List[Tensor]): RoIs generated in each batch
         """
+        cls_name = 'loss_bbox_cls'
+        reg_name = 'loss_bbox_reg'
+        loss_bbox = {}
+
         # TODO: better pass args
         tgt_labels, tgt_bboxes, tgt_gt_inds = targets
+
+        # bbox cls
         tgt_labels = paddle.concat(tgt_labels) if len(
             tgt_labels) > 1 else tgt_labels[0]
+        valid_inds = paddle.nonzero(tgt_labels >= 0).flatten()
+        if valid_inds.shape[0] == 0:
+            loss_bbox[cls_name] = paddle.zeros([1], dtype='float32')
+        else:
             tgt_labels = tgt_labels.cast('int64')
             tgt_labels.stop_gradient = True
             loss_bbox_cls = F.cross_entropy(
                 input=scores, label=tgt_labels, reduction='mean')
+            loss_bbox[cls_name] = loss_bbox_cls

         # bbox reg
         cls_agnostic_bbox_reg = deltas.shape[1] == 4
@@ -281,14 +293,9 @@ class BBoxHead(nn.Layer):
             paddle.logical_and(tgt_labels >= 0, tgt_labels <
                                self.num_classes)).flatten()
-        cls_name = 'loss_bbox_cls'
-        reg_name = 'loss_bbox_reg'
-        loss_bbox = {}
-        loss_weight = 1.
         if fg_inds.numel() == 0:
-            fg_inds = paddle.zeros([1], dtype='int32')
-            loss_weight = 0.
+            loss_bbox[reg_name] = paddle.zeros([1], dtype='float32')
+            return loss_bbox

         if cls_agnostic_bbox_reg:
             reg_delta = paddle.gather(deltas, fg_inds)
@@ -323,8 +330,7 @@ class BBoxHead(nn.Layer):
             loss_bbox_reg = paddle.abs(reg_delta - reg_target).sum(
             ) / tgt_labels.shape[0]

-        loss_bbox[cls_name] = loss_bbox_cls * loss_weight
-        loss_bbox[reg_name] = loss_bbox_reg * loss_weight
+        loss_bbox[reg_name] = loss_bbox_reg

         return loss_bbox
......
@@ -238,6 +238,9 @@ class RPNHead(nn.Layer):
         valid_ind = paddle.nonzero(valid_mask)

         # cls loss
+        if valid_ind.shape[0] == 0:
+            loss_rpn_cls = paddle.zeros([1], dtype='float32')
+        else:
             score_pred = paddle.gather(scores, valid_ind)
             score_label = paddle.gather(score_tgt, valid_ind).cast('float32')
             score_label.stop_gradient = True
@@ -245,6 +248,9 @@ class RPNHead(nn.Layer):
                 logit=score_pred, label=score_label, reduction="sum")

         # reg loss
+        if pos_ind.shape[0] == 0:
+            loss_rpn_reg = paddle.zeros([1], dtype='float32')
+        else:
             loc_pred = paddle.gather(deltas, pos_ind)
             loc_tgt = paddle.concat(loc_tgt)
             loc_tgt = paddle.gather(loc_tgt, pos_ind)
......
@@ -28,31 +28,38 @@ def rpn_anchor_target(anchors,
                       rpn_fg_fraction,
                       use_random=True,
                       batch_size=1,
+                      ignore_thresh=-1,
+                      is_crowd=None,
                       weights=[1., 1., 1., 1.]):
     tgt_labels = []
     tgt_bboxes = []
     tgt_deltas = []
     for i in range(batch_size):
         gt_bbox = gt_boxes[i]
+        is_crowd_i = is_crowd[i] if is_crowd else None
         # Step1: match anchor and gt_bbox
         matches, match_labels = label_box(
-            anchors, gt_bbox, rpn_positive_overlap, rpn_negative_overlap, True)
+            anchors, gt_bbox, rpn_positive_overlap, rpn_negative_overlap, True,
+            ignore_thresh, is_crowd_i)
         # Step2: sample anchor
         fg_inds, bg_inds = subsample_labels(match_labels, rpn_batch_size_per_im,
                                             rpn_fg_fraction, 0, use_random)
         # Fill with the ignore label (-1), then set positive and negative labels
         labels = paddle.full(match_labels.shape, -1, dtype='int32')
-        labels = paddle.scatter(labels, fg_inds, paddle.ones_like(fg_inds))
+        if bg_inds.shape[0] > 0:
             labels = paddle.scatter(labels, bg_inds, paddle.zeros_like(bg_inds))
+        if fg_inds.shape[0] > 0:
+            labels = paddle.scatter(labels, fg_inds, paddle.ones_like(fg_inds))
         # Step3: make output
+        if gt_bbox.shape[0] == 0:
+            matched_gt_boxes = paddle.zeros([0, 4])
+            tgt_delta = paddle.zeros([0, 4])
+        else:
             matched_gt_boxes = paddle.gather(gt_bbox, matches)
             tgt_delta = bbox2delta(anchors, matched_gt_boxes, weights)
-        labels.stop_gradient = True
             matched_gt_boxes.stop_gradient = True
             tgt_delta.stop_gradient = True
+        labels.stop_gradient = True
         tgt_labels.append(labels)
         tgt_bboxes.append(matched_gt_boxes)
         tgt_deltas.append(tgt_delta)
@@ -60,16 +67,46 @@ def rpn_anchor_target(anchors,
     return tgt_labels, tgt_bboxes, tgt_deltas

-def label_box(anchors, gt_boxes, positive_overlap, negative_overlap,
-              allow_low_quality):
+def label_box(anchors,
+              gt_boxes,
+              positive_overlap,
+              negative_overlap,
+              allow_low_quality,
+              ignore_thresh,
+              is_crowd=None):
     iou = bbox_overlaps(gt_boxes, anchors)
-    if iou.numel() == 0:
+    n_gt = gt_boxes.shape[0]
+    if n_gt == 0 or is_crowd is None:
+        n_gt_crowd = 0
+    else:
+        n_gt_crowd = paddle.nonzero(is_crowd).shape[0]
+    if iou.shape[0] == 0 or n_gt_crowd == n_gt:
+        # No truth, assign everything to background
         default_matches = paddle.full((iou.shape[1], ), 0, dtype='int64')
-        default_match_labels = paddle.full((iou.shape[1], ), -1, dtype='int32')
+        default_match_labels = paddle.full((iou.shape[1], ), 0, dtype='int32')
         return default_matches, default_match_labels
+    # if ignore_thresh > 0, remove anchor if it is closed to
+    # one of the crowded ground-truth
+    if n_gt_crowd > 0:
+        N_a = anchors.shape[0]
+        ones = paddle.ones([N_a])
+        mask = is_crowd * ones
+        if ignore_thresh > 0:
+            crowd_iou = iou * mask
+            valid = (paddle.sum((crowd_iou > ignore_thresh).cast('int32'),
+                                axis=0) > 0).cast('float32')
+            iou = iou * (1 - valid) - valid
+        # ignore the iou between anchor and crowded ground-truth
+        iou = iou * (1 - mask) - mask

     matched_vals, matches = paddle.topk(iou, k=1, axis=0)
     match_labels = paddle.full(matches.shape, -1, dtype='int32')
-    match_labels = paddle.where(matched_vals < negative_overlap,
+    # set ignored anchor with iou = -1
+    neg_cond = paddle.logical_and(matched_vals > -1,
+                                  matched_vals < negative_overlap)
+    match_labels = paddle.where(neg_cond,
                                 paddle.zeros_like(match_labels), match_labels)
     match_labels = paddle.where(matched_vals >= positive_overlap,
                                 paddle.ones_like(match_labels), match_labels)
@@ -84,6 +121,7 @@ def label_box(anchors,
     matches = matches.flatten()
     match_labels = match_labels.flatten()
     return matches, match_labels
@@ -96,24 +134,36 @@ def subsample_labels(labels,
         paddle.logical_and(labels != -1, labels != bg_label))
     negative = paddle.nonzero(labels == bg_label)

-    positive = positive.cast('int32').flatten()
-    negative = negative.cast('int32').flatten()
     fg_num = int(num_samples * fg_fraction)
     fg_num = min(positive.numel(), fg_num)
     bg_num = num_samples - fg_num
     bg_num = min(negative.numel(), bg_num)
+    if fg_num == 0 and bg_num == 0:
+        fg_inds = paddle.zeros([0], dtype='int32')
+        bg_inds = paddle.zeros([0], dtype='int32')
+        return fg_inds, bg_inds

     # randomly select positive and negative examples
-    fg_perm = paddle.randperm(positive.numel(), dtype='int32')
-    fg_perm = paddle.slice(fg_perm, axes=[0], starts=[0], ends=[fg_num])
+    negative = negative.cast('int32').flatten()
     bg_perm = paddle.randperm(negative.numel(), dtype='int32')
     bg_perm = paddle.slice(bg_perm, axes=[0], starts=[0], ends=[bg_num])
     if use_random:
-        fg_inds = paddle.gather(positive, fg_perm)
         bg_inds = paddle.gather(negative, bg_perm)
     else:
-        fg_inds = paddle.slice(positive, axes=[0], starts=[0], ends=[fg_num])
         bg_inds = paddle.slice(negative, axes=[0], starts=[0], ends=[bg_num])
+    if fg_num == 0:
+        fg_inds = paddle.zeros([0], dtype='int32')
+        return fg_inds, bg_inds
+
+    positive = positive.cast('int32').flatten()
+    fg_perm = paddle.randperm(positive.numel(), dtype='int32')
+    fg_perm = paddle.slice(fg_perm, axes=[0], starts=[0], ends=[fg_num])
+    if use_random:
+        fg_inds = paddle.gather(positive, fg_perm)
+    else:
+        fg_inds = paddle.slice(positive, axes=[0], starts=[0], ends=[fg_num])

     return fg_inds, bg_inds
@@ -125,6 +175,8 @@ def generate_proposal_target(rpn_rois,
                              fg_thresh,
                              bg_thresh,
                              num_classes,
+                             ignore_thresh=-1.,
+                             is_crowd=None,
                              use_random=True,
                              is_cascade=False,
                              cascade_iou=0.5):
@@ -141,17 +193,18 @@ def generate_proposal_target(rpn_rois,
     bg_thresh = cascade_iou if is_cascade else bg_thresh
     for i, rpn_roi in enumerate(rpn_rois):
         gt_bbox = gt_boxes[i]
+        is_crowd_i = is_crowd[i] if is_crowd else None
         gt_class = paddle.squeeze(gt_classes[i], axis=-1)
-        # Concat RoIs and gt boxes except cascade rcnn
-        if not is_cascade:
+        # Concat RoIs and gt boxes except cascade rcnn or none gt
+        if not is_cascade and gt_bbox.shape[0] > 0:
             bbox = paddle.concat([rpn_roi, gt_bbox])
         else:
             bbox = rpn_roi

         # Step1: label bbox
         matches, match_labels = label_box(bbox, gt_bbox, fg_thresh, bg_thresh,
-                                          False)
+                                          False, ignore_thresh, is_crowd_i)
         # Step2: sample bbox
         sampled_inds, sampled_gt_classes = sample_bbox(
             matches, match_labels, gt_class, batch_size_per_im, fg_fraction,
@@ -162,7 +215,10 @@ def generate_proposal_target(rpn_rois,
                                        sampled_inds)
         sampled_gt_ind = matches if is_cascade else paddle.gather(matches,
                                                                   sampled_inds)
-        sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
+        if gt_bbox.shape[0] > 0:
+            sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
+        else:
+            sampled_bbox = paddle.zeros([0, 4], dtype='float32')
         rois_per_image.stop_gradient = True
         sampled_gt_ind.stop_gradient = True
@@ -184,6 +240,13 @@ def sample_bbox(matches,
                 num_classes,
                 use_random=True,
                 is_cascade=False):
+    n_gt = gt_classes.shape[0]
+    if n_gt == 0:
+        # No truth, assign everything to background
+        gt_classes = paddle.ones(matches.shape, dtype='int32') * num_classes
+        #return matches, match_labels + num_classes
+    else:
         gt_classes = paddle.gather(gt_classes, matches)
         gt_classes = paddle.where(match_labels == 0,
                                   paddle.ones_like(gt_classes) * num_classes,
@@ -191,11 +254,17 @@ def sample_bbox(matches,
         gt_classes = paddle.where(match_labels == -1,
                                   paddle.ones_like(gt_classes) * -1, gt_classes)
     if is_cascade:
-        return matches, gt_classes
+        index = paddle.arange(matches.shape[0])
+        return index, gt_classes

     rois_per_image = int(batch_size_per_im)
     fg_inds, bg_inds = subsample_labels(gt_classes, rois_per_image, fg_fraction,
                                         num_classes, use_random)
+    if fg_inds.shape[0] == 0 and bg_inds.shape[0] == 0:
+        # fake output labeled with -1 when all boxes are neither
+        # foreground nor background
+        sampled_inds = paddle.zeros([1], dtype='int32')
+    else:
         sampled_inds = paddle.concat([fg_inds, bg_inds])
     sampled_gt_classes = paddle.gather(gt_classes, sampled_inds)
     return sampled_inds, sampled_gt_classes
@@ -268,20 +337,29 @@ def generate_mask_target(gt_segms, rois, labels_int32, sampled_gt_inds,
         # to generate mask target with ground-truth
         boxes = fg_rois.numpy()
         gt_segms_per_im = gt_segms[k]

         new_segm = []
         inds_per_im = inds_per_im.numpy()
+        if len(gt_segms_per_im) > 0:
             for i in inds_per_im:
                 new_segm.append(gt_segms_per_im[i])
         fg_inds_new = fg_inds.reshape([-1]).numpy()

         results = []
+        if len(gt_segms_per_im) > 0:
             for j in fg_inds_new:
                 results.append(
                     rasterize_polygons_within_box(new_segm[j], boxes[j],
                                                   resolution))
+        else:
+            results.append(paddle.ones([resolution, resolution], dtype='int32'))

         fg_classes = paddle.gather(labels_per_im, fg_inds)
         weight = paddle.ones([fg_rois.shape[0]], dtype='float32')
         if not has_fg:
+            # now all sampled classes are background
+            # which will cause error in loss calculation,
+            # make fake classes with weight of 0.
+            fg_classes = paddle.zeros([1], dtype='int32')
             weight = weight - 1
         tgt_mask = paddle.stack(results)
         tgt_mask.stop_gradient = True
......
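The crowd handling in label_box is the least obvious part of this diff: crowd rows (and, when ignore_thresh > 0, anchors that overlap a crowd box too strongly) get their IoU forced to -1, so their best match never rises above -1 and they keep the ignore label. A tiny numpy illustration of the same masking arithmetic (standalone, not ppdet code; the numbers are made up):

    # Standalone numpy illustration of the crowd-masking trick used in label_box.
    import numpy as np

    iou = np.array([[0.8, 0.1, 0.0],     # gt 0 (normal) vs 3 anchors
                    [0.2, 0.9, 0.6]])    # gt 1 (is_crowd = 1) vs 3 anchors
    is_crowd = np.array([[0.], [1.]])    # one flag per ground-truth box
    mask = is_crowd * np.ones(iou.shape[1])   # broadcasts to a per-row crowd mask

    ignore_thresh = 0.7
    # anchors whose IoU with any crowd gt exceeds ignore_thresh become invalid columns
    valid = (np.sum((iou * mask) > ignore_thresh, axis=0) > 0).astype('float32')
    iou = iou * (1 - valid) - valid      # anchor 1 column -> -1
    # the crowd rows themselves are wiped out as well
    iou = iou * (1 - mask) - mask        # gt 1 row -> -1

    print(iou)
    # [[ 0.8 -1.   0. ]
    #  [-1.  -1.  -1. ]]
    # anchor 1's best IoU is now -1, so the neg_cond check (matched_vals > -1)
    # leaves it at the ignore label (-1); anchor 0 stays positive, anchor 2 background.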
@@ -44,6 +44,8 @@ class RPNTargetAssign(object):
         negative_overlap (float): Maximum overlap allowed between an anchor
             and ground-truth box for the (anchor, gt box) pair to be
             a background sample. default 0.3
+        ignore_thresh(float): Threshold for ignoring the is_crowd ground-truth
+            if the value is larger than zero.
         use_random (bool): Use random sampling to choose foreground and
             background boxes, default true.
     """
@@ -53,12 +55,14 @@ class RPNTargetAssign(object):
                  fg_fraction=0.5,
                  positive_overlap=0.7,
                  negative_overlap=0.3,
+                 ignore_thresh=-1.,
                  use_random=True):
         super(RPNTargetAssign, self).__init__()
         self.batch_size_per_im = batch_size_per_im
         self.fg_fraction = fg_fraction
         self.positive_overlap = positive_overlap
         self.negative_overlap = negative_overlap
+        self.ignore_thresh = ignore_thresh
         self.use_random = use_random

     def __call__(self, inputs, anchors):
@@ -67,11 +71,12 @@ class RPNTargetAssign(object):
             anchor_box (Tensor): [num_anchors, 4], num_anchors are all anchors in all feature maps.
         """
         gt_boxes = inputs['gt_bbox']
+        is_crowd = inputs.get('is_crowd', None)
         batch_size = len(gt_boxes)
         tgt_labels, tgt_bboxes, tgt_deltas = rpn_anchor_target(
             anchors, gt_boxes, self.batch_size_per_im, self.positive_overlap,
             self.negative_overlap, self.fg_fraction, self.use_random,
-            batch_size)
+            batch_size, self.ignore_thresh, is_crowd)
         norm = self.batch_size_per_im * batch_size
         return tgt_labels, tgt_bboxes, tgt_deltas, norm
@@ -101,6 +106,8 @@ class BBoxAssigner(object):
         bg_thresh (float): Maximum overlap allowed between a RoI
             and ground-truth box for the (roi, gt box) pair to be
             a background sample. default 0.5
+        ignore_thresh(float): Threshold for ignoring the is_crowd ground-truth
+            if the value is larger than zero.
         use_random (bool): Use random sampling to choose foreground and
             background boxes, default true
         cascade_iou (list[iou]): The list of overlap to select foreground and
@@ -113,6 +120,7 @@ class BBoxAssigner(object):
                  fg_fraction=.25,
                  fg_thresh=.5,
                  bg_thresh=.5,
+                 ignore_thresh=-1.,
                  use_random=True,
                  cascade_iou=[0.5, 0.6, 0.7],
                  num_classes=80):
@@ -121,6 +129,7 @@ class BBoxAssigner(object):
         self.fg_fraction = fg_fraction
         self.fg_thresh = fg_thresh
         self.bg_thresh = bg_thresh
+        self.ignore_thresh = ignore_thresh
         self.use_random = use_random
         self.cascade_iou = cascade_iou
         self.num_classes = num_classes
@@ -133,12 +142,14 @@ class BBoxAssigner(object):
                        is_cascade=False):
         gt_classes = inputs['gt_class']
         gt_boxes = inputs['gt_bbox']
+        is_crowd = inputs.get('is_crowd', None)
         # rois, tgt_labels, tgt_bboxes, tgt_gt_inds
         # new_rois_num
         outs = generate_proposal_target(
             rpn_rois, gt_classes, gt_boxes, self.batch_size_per_im,
             self.fg_fraction, self.fg_thresh, self.bg_thresh, self.num_classes,
-            self.use_random, is_cascade, self.cascade_iou[stage])
+            self.ignore_thresh, is_crowd, self.use_random, is_cascade,
+            self.cascade_iou[stage])
         rois = outs[0]
         rois_num = outs[-1]
         # tgt_labels, tgt_bboxes, tgt_gt_inds
......
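With the assigners now reading is_crowd from the inputs, crowd regions can be excluded from RPN and RCNN sampling by setting ignore_thresh. A hedged sketch of how the new argument might be enabled when constructing the assigners directly (the import path and the 0.7/0.5 values are my assumptions; in practice these fields are set in the model's YAML config):

    # Illustrative construction of the assigners with the new ignore_thresh argument.
    from ppdet.modeling.proposal_generator.target_layer import RPNTargetAssign, BBoxAssigner

    rpn_assigner = RPNTargetAssign(ignore_thresh=0.7)   # ignore anchors overlapping crowd gt by > 0.7
    bbox_assigner = BBoxAssigner(ignore_thresh=0.5)     # same idea for RoI sampling
    # with the default ignore_thresh=-1. the crowd-ignoring branch stays disabled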