未验证 提交 b2c75899 编写于 作者: F FlyingQianMM 提交者: GitHub

add is_mask_expand and is_mask_crop in RandomExpand and RamdomCrop (#465)

* add mask operations in RandomExpand and RandomCrop

* remove all [] in cropped polys

* is_poly -> _is_poly()

* change the judgment on whether there are polygons

* abstract is_poly to op_helper.py

* add is_mask_expand and is_mask_crop in RandomExpand and RamdomCrop
上级 063e9b20
...@@ -1064,10 +1064,16 @@ class MixupImage(BaseOperator): ...@@ -1064,10 +1064,16 @@ class MixupImage(BaseOperator):
gt_score2 = sample['mixup']['gt_score'] gt_score2 = sample['mixup']['gt_score']
gt_score = np.concatenate( gt_score = np.concatenate(
(gt_score1 * factor, gt_score2 * (1. - factor)), axis=0) (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
is_crowd1 = sample['is_crowd']
is_crowd2 = sample['mixup']['is_crowd']
is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
sample['image'] = im sample['image'] = im
sample['gt_bbox'] = gt_bbox sample['gt_bbox'] = gt_bbox
sample['gt_score'] = gt_score sample['gt_score'] = gt_score
sample['gt_class'] = gt_class sample['gt_class'] = gt_class
sample['is_crowd'] = is_crowd
sample['h'] = im.shape[0] sample['h'] = im.shape[0]
sample['w'] = im.shape[1] sample['w'] = im.shape[1]
sample.pop('mixup') sample.pop('mixup')
...@@ -1298,9 +1304,14 @@ class RandomExpand(BaseOperator): ...@@ -1298,9 +1304,14 @@ class RandomExpand(BaseOperator):
ratio (float): maximum expansion ratio. ratio (float): maximum expansion ratio.
prob (float): probability to expand. prob (float): probability to expand.
fill_value (list): color value used to fill the canvas. in RGB order. fill_value (list): color value used to fill the canvas. in RGB order.
is_mask_expand(bool): whether expand the segmentation.
""" """
def __init__(self, ratio=4., prob=0.5, fill_value=(127.5, ) * 3): def __init__(self,
ratio=4.,
prob=0.5,
fill_value=(127.5, ) * 3,
is_mask_expand=False):
super(RandomExpand, self).__init__() super(RandomExpand, self).__init__()
assert ratio > 1.01, "expand ratio must be larger than 1.01" assert ratio > 1.01, "expand ratio must be larger than 1.01"
self.ratio = ratio self.ratio = ratio
...@@ -1312,6 +1323,7 @@ class RandomExpand(BaseOperator): ...@@ -1312,6 +1323,7 @@ class RandomExpand(BaseOperator):
if not isinstance(fill_value, tuple): if not isinstance(fill_value, tuple):
fill_value = tuple(fill_value) fill_value = tuple(fill_value)
self.fill_value = fill_value self.fill_value = fill_value
self.is_mask_expand = is_mask_expand
def expand_segms(self, segms, x, y, height, width, ratio): def expand_segms(self, segms, x, y, height, width, ratio):
def _expand_poly(poly, x, y): def _expand_poly(poly, x, y):
...@@ -1369,7 +1381,8 @@ class RandomExpand(BaseOperator): ...@@ -1369,7 +1381,8 @@ class RandomExpand(BaseOperator):
sample['image'] = canvas sample['image'] = canvas
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0: if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
sample['gt_bbox'] += np.array([x, y] * 2, dtype=np.float32) sample['gt_bbox'] += np.array([x, y] * 2, dtype=np.float32)
if 'gt_poly' in sample and len(sample['gt_poly']) > 0: if self.is_mask_expand and 'gt_poly' in sample and len(sample[
'gt_poly']) > 0:
sample['gt_poly'] = self.expand_segms(sample['gt_poly'], x, y, sample['gt_poly'] = self.expand_segms(sample['gt_poly'], x, y,
height, width, expand_ratio) height, width, expand_ratio)
return sample return sample
...@@ -1388,6 +1401,7 @@ class RandomCrop(BaseOperator): ...@@ -1388,6 +1401,7 @@ class RandomCrop(BaseOperator):
num_attempts (int): number of tries before giving up. num_attempts (int): number of tries before giving up.
allow_no_crop (bool): allow return without actually cropping them. allow_no_crop (bool): allow return without actually cropping them.
cover_all_box (bool): ensure all bboxes are covered in the final crop. cover_all_box (bool): ensure all bboxes are covered in the final crop.
is_mask_crop(bool): whether crop the segmentation.
""" """
def __init__(self, def __init__(self,
...@@ -1396,7 +1410,8 @@ class RandomCrop(BaseOperator): ...@@ -1396,7 +1410,8 @@ class RandomCrop(BaseOperator):
scaling=[.3, 1.], scaling=[.3, 1.],
num_attempts=50, num_attempts=50,
allow_no_crop=True, allow_no_crop=True,
cover_all_box=False): cover_all_box=False,
is_mask_crop=False):
super(RandomCrop, self).__init__() super(RandomCrop, self).__init__()
self.aspect_ratio = aspect_ratio self.aspect_ratio = aspect_ratio
self.thresholds = thresholds self.thresholds = thresholds
...@@ -1404,6 +1419,7 @@ class RandomCrop(BaseOperator): ...@@ -1404,6 +1419,7 @@ class RandomCrop(BaseOperator):
self.num_attempts = num_attempts self.num_attempts = num_attempts
self.allow_no_crop = allow_no_crop self.allow_no_crop = allow_no_crop
self.cover_all_box = cover_all_box self.cover_all_box = cover_all_box
self.is_mask_crop = is_mask_crop
def crop_segms(self, segms, valid_ids, crop, height, width): def crop_segms(self, segms, valid_ids, crop, height, width):
def _crop_poly(segm, crop): def _crop_poly(segm, crop):
...@@ -1527,7 +1543,8 @@ class RandomCrop(BaseOperator): ...@@ -1527,7 +1543,8 @@ class RandomCrop(BaseOperator):
break break
if found: if found:
if 'gt_poly' in sample and len(sample['gt_poly']) > 0: if self.is_mask_crop and 'gt_poly' in sample and len(sample[
'gt_poly']) > 0:
crop_polys = self.crop_segms( crop_polys = self.crop_segms(
sample['gt_poly'], sample['gt_poly'],
valid_ids, valid_ids,
......
...@@ -5,3 +5,4 @@ tb-paddle ...@@ -5,3 +5,4 @@ tb-paddle
tensorboard >= 1.15 tensorboard >= 1.15
cython cython
pycocotools pycocotools
shapely
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册