diff --git a/ppdet/data/transform/op_helper.py b/ppdet/data/transform/op_helper.py index 02d219546d36c9fe1d0bc3cf81d9b41e108807f9..6919ece50cc9a89baf849704fe037265c7416fe4 100644 --- a/ppdet/data/transform/op_helper.py +++ b/ppdet/data/transform/op_helper.py @@ -61,7 +61,10 @@ def is_overlap(object_bbox, sample_bbox): return True -def filter_and_process(sample_bbox, bboxes, labels, scores=None, +def filter_and_process(sample_bbox, + bboxes, + labels, + scores=None, keypoints=None): new_bboxes = [] new_labels = [] @@ -92,8 +95,8 @@ def filter_and_process(sample_bbox, bboxes, labels, scores=None, for j in range(len(sample_keypoint)): kp_len = sample_height if j % 2 else sample_width sample_coord = sample_bbox[1] if j % 2 else sample_bbox[0] - sample_keypoint[j] = ( - sample_keypoint[j] - sample_coord) / kp_len + sample_keypoint[j] = (sample_keypoint[j] - + sample_coord) / kp_len sample_keypoint[j] = max(min(sample_keypoint[j], 1.0), 0.0) new_keypoints.append(sample_keypoint) new_kp_ignore.append(keypoints[1][i]) @@ -261,12 +264,12 @@ def jaccard_overlap(sample_bbox, object_bbox): intersect_ymin = max(sample_bbox[1], object_bbox[1]) intersect_xmax = min(sample_bbox[2], object_bbox[2]) intersect_ymax = min(sample_bbox[3], object_bbox[3]) - intersect_size = (intersect_xmax - intersect_xmin) * ( - intersect_ymax - intersect_ymin) + intersect_size = (intersect_xmax - intersect_xmin) * (intersect_ymax - + intersect_ymin) sample_bbox_size = bbox_area(sample_bbox) object_bbox_size = bbox_area(object_bbox) - overlap = intersect_size / ( - sample_bbox_size + object_bbox_size - intersect_size) + overlap = intersect_size / (sample_bbox_size + object_bbox_size - + intersect_size) return overlap @@ -276,8 +279,10 @@ def intersect_bbox(bbox1, bbox2): intersection_box = [0.0, 0.0, 0.0, 0.0] else: intersection_box = [ - max(bbox1[0], bbox2[0]), max(bbox1[1], bbox2[1]), - min(bbox1[2], bbox2[2]), min(bbox1[3], bbox2[3]) + max(bbox1[0], bbox2[0]), + max(bbox1[1], bbox2[1]), + min(bbox1[2], bbox2[2]), + min(bbox1[3], bbox2[3]) ] return intersection_box @@ -401,8 +406,8 @@ def crop_image_sampling(img, sample_bbox, image_width, image_height, sample_img[roi_y1: roi_y2, roi_x1: roi_x2] = \ img[cross_y1: cross_y2, cross_x1: cross_x2] - sample_img = cv2.resize( - sample_img, (target_size, target_size), interpolation=cv2.INTER_AREA) + sample_img = cv2.resize(sample_img, (target_size, target_size), + interpolation=cv2.INTER_AREA) return sample_img @@ -449,8 +454,8 @@ def draw_gaussian(heatmap, center, radius, k=1, delte=6): top, bottom = min(y, radius), min(height - y, radius + 1) masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] - masked_gaussian = gaussian[radius - top:radius + bottom, radius - left: - radius + right] + masked_gaussian = gaussian[radius - top:radius + bottom, + radius - left:radius + right] np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap) @@ -458,7 +463,53 @@ def gaussian2D(shape, sigma_x=1, sigma_y=1): m, n = [(ss - 1.) / 2. for ss in shape] y, x = np.ogrid[-m:m + 1, -n:n + 1] - h = np.exp(-(x * x / (2 * sigma_x * sigma_x) + y * y / (2 * sigma_y * - sigma_y))) + h = np.exp(-(x * x / (2 * sigma_x * sigma_x) + y * y / + (2 * sigma_y * sigma_y))) h[h < np.finfo(h.dtype).eps * h.max()] = 0 return h + + +def transform_bbox(bbox, + label, + M, + w, + h, + area_thr=0.25, + wh_thr=2, + ar_thr=20, + perspective=False): + # rotate bbox + n = len(bbox) + xy = np.ones((n * 4, 3), dtype=np.float32) + xy[:, :2] = bbox[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) + xy = xy @ M.T + if perspective: + xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) + else: + xy = xy[:, :2].reshape(n, 8) + # get new bboxes + x = xy[:, [0, 2, 4, 6]] + y = xy[:, [1, 3, 5, 7]] + new_bbox = np.concatenate( + (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T + # clip boxes + new_bbox, mask = clip_bbox(new_bbox, w, h, area_thr) + new_label = label[mask] + return new_bbox, new_label + + +def clip_bbox(bbox, w, h, area_thr=0.25, wh_thr=2, ar_thr=20): + # clip boxes + area1 = (bbox[:, 2:4] - bbox[:, 0:2]).prod(1) + bbox[:, [0, 2]] = bbox[:, [0, 2]].clip(0, w) + bbox[:, [1, 3]] = bbox[:, [1, 3]].clip(0, h) + # compute + area2 = (bbox[:, 2:4] - bbox[:, 0:2]).prod(1) + area_ratio = area2 / (area1 + 1e-16) + wh = bbox[:, 2:4] - bbox[:, 0:2] + ar_ratio = np.maximum(wh[:, 1] / (wh[:, 0] + 1e-16), + wh[:, 0] / (wh[:, 1] + 1e-16)) + mask = (area_ratio > area_thr) & ( + (wh > wh_thr).all(1)) & (ar_ratio < ar_thr) + bbox = bbox[mask] + return bbox, mask diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py index f0eca28c374b3ae18a8d45bb8960ad3df33640b8..dcb0aac4180de86b44f5e4fd31070eff1a0954c1 100644 --- a/ppdet/data/transform/operators.py +++ b/ppdet/data/transform/operators.py @@ -44,7 +44,7 @@ from .op_helper import (satisfy_sample_constraint, filter_and_process, generate_sample_bbox, clip_bbox, data_anchor_sampling, satisfy_sample_constraint_coverage, crop_image_sampling, generate_sample_bbox_square, bbox_area_sampling, - is_poly, gaussian_radius, draw_gaussian) + is_poly, gaussian_radius, draw_gaussian, transform_bbox, clip_bbox) logger = logging.getLogger(__name__) @@ -2555,3 +2555,389 @@ class DebugVisibleImage(BaseOperator): save_path = os.path.join(self.output_dir, out_file_name) image.save(save_path, quality=95) return sample + +@register_op +class Rotate(BaseOperator): + def __init__(self, + degree, + scale=1.0, + center=None, + area_thr=0.25, + border_value=(114, 114, 114)): + super(Rotate, self).__init__() + self.degree = degree + self.scale = scale + self.center = center + self.area_thr = area_thr + self.border_value = border_value + + def __call__(self, sample, context=None): + im = sample['image'] + bbox = sample['gt_bbox'] + label = sample['gt_class'] + + # rotate image + height, width = im.shape[:2] + if self.center is None: + self.center = (width // 2, height // 2) + M = cv2.getRotationMatrix2D(self.center, self.degree, self.scale) + im = cv2.warpAffine(im, + M, (width, height), + borderValue=self.border_value) + + # rotate bbox + if bbox.shape[0] > 0: + new_bbox, new_label = transform_bbox(bbox, label, M, width, height, self.area_thr) + else: + new_bbox, new_label = bbox, label + sample['image'] = im + sample['gt_bbox'] = new_bbox.astype(np.float32) + sample['gt_class'] = new_label.astype(np.int32) + return sample + + +@register_op +class RandomRotate(BaseOperator): + def __init__(self, + degree, + scale=0.0, + center=None, + area_thr=0.25, + border_value=(114, 114, 114)): + super(RandomRotate, self).__init__() + if isinstance(degree, (int, float)): + degree = abs(degree) + degree = (-degree, degree) + elif isinstance(degree, list) or isinstance(degree, tuple): + assert len(degree) == 2, 'len of degree is not equal to 2' + else: + raise ValueError('degree is not reasonable') + + self.degree = degree + self.scale = scale + self.center = center + self.area_thr = area_thr + self.border_value = border_value + + def __call__(self, sample, context=None): + degree = random.uniform(*self.degree) + scale = random.uniform(1 - self.scale, 1 + self.scale) + rotate = Rotate(degree, scale, self.center, self.area_thr, self.border_value) + return rotate(sample, context) + + +@register_op +class Shear(BaseOperator): + def __init__(self, shear, area_thr=0.25, border_value=(114, 114, 114)): + super(Shear, self).__init__() + if isinstance(shear, (int, float)): + shear = (shear, shear) + elif isinstance(shear, list) or isinstance(shear, tuple): + assert len(shear) == 2, 'len of shear is not equal to 2' + else: + raise ValueError('shear is not reasonable') + + self.shear = shear + self.area_thr = area_thr + self.border_value = border_value + + def __call__(self, sample, context=None): + im = sample['image'] + bbox = sample['gt_bbox'] + label = sample['gt_class'] + + # shear image + height, width = im.shape[:2] + shear_x = math.tan(self.shear[0] * math.pi / 180) + shear_y = math.tan(self.shear[1] * math.pi / 180) + M = np.array([[1, shear_x, 0], [shear_y, 1, 0]]) + im = cv2.warpAffine(im, + M, (width, height), + borderValue=self.border_value) + + # shear box + if bbox.shape[0] > 0: + new_bbox, new_label = transform_bbox(bbox, label, M, width, height, self.area_thr) + else: + new_bbox, new_label = bbox, label + sample['image'] = im + sample['gt_bbox'] = new_bbox.astype(np.float32) + sample['gt_class'] = new_label.astype(np.int32) + return sample + + +@register_op +class RandomShear(BaseOperator): + def __init__(self, + shear_x, + shear_y, + area_thr=0.25, + border_value=(114, 114, 114)): + super(RandomShear, self).__init__() + if isinstance(shear_x, (int, float)): + shear_x = abs(shear_x) + shear_x = (-shear_x, shear_x) + elif isinstance(shear_x, list) or isinstance(shear_x, tuple): + assert len(shear_x) == 2, 'len of shear_x is not equal to 2' + else: + raise ValueError('shear_x is not reasonable') + + if isinstance(shear_y, (int, float)): + shear_y = abs(shear_y) + shear_y = (-shear_y, shear_y) + elif isinstance(shear_y, list) or isinstance(shear_y, tuple): + assert len(shear_y) == 2, 'len of shear_y is not equal to 2' + else: + raise ValueError('shear_y is not reasonable') + + self.shear_x = shear_x + self.shear_y = shear_y + self.area_thr = area_thr + self.border_value = border_value + + def __call__(self, sample, context=None): + shear_x = random.uniform(*self.shear_x) + shear_y = random.uniform(*self.shear_y) + shear = Shear((shear_x, shear_y), self.area_thr, self.border_value) + return shear(sample, context) + + +@register_op +class Translate(BaseOperator): + def __init__(self, translate, area_thr=0.25, border_value=(114, 114, 114)): + super(Translate, self).__init__() + if isinstance(translate, (int, float)): + translate = (translate, translate) + elif isinstance(translate, list) or isinstance(translate, tuple): + assert len(translate) == 2, 'len of translate is not equal to 2' + else: + raise ValueError('translate is not reasonable') + + assert abs(translate[0]) < 1 and abs(translate[1]) < 1, 'translate should be in (-1, 1)' + + self.translate = translate + self.area_thr = area_thr + self.border_value = border_value + + def __call__(self, sample, context=None): + im = sample['image'] + bbox = sample['gt_bbox'] + label = sample['gt_class'] + + # translate image + height, width = im.shape[:2] + translate_x = int(self.translate[0] * width) + translate_y = int(self.translate[1] * height) + + dst_cords = [ + max(0, translate_y), + max(0, translate_x), + min(height, translate_y + height), + min(width, translate_x + width) + ] + src_cords = [ + max(-translate_y, 0), + max(-translate_x, 0), + min(-translate_y + height, height), + min(-translate_x + width, width) + ] + canvas = np.ones(im.shape, dtype=np.uint8) * self.border_value + canvas[dst_cords[0]:dst_cords[2], + dst_cords[1]:dst_cords[3], :] = im[src_cords[0]:src_cords[2], + src_cords[1]:src_cords[3], :] + + if bbox.shape[0] > 0: + new_bbox = bbox + [translate_x, translate_y, translate_x, translate_y] + # compute + new_bbox, mask = clip_bbox(new_bbox, width, height, self.area_thr) + new_label = label[mask] + else: + new_bbox, new_label = bbox, label + sample['image'] = canvas.astype(np.uint8) + sample['gt_bbox'] = new_bbox.astype(np.float32) + sample['gt_class'] = new_label.astype(np.int32) + return sample + + +@register_op +class RandomTranslate(BaseOperator): + def __init__(self, + translate_x, + translate_y, + area_thr=0.25, + border_value=(114, 114, 114)): + super(RandomTranslate, self).__init__() + if isinstance(translate_x, (int, float)): + translate_x = abs(translate_x) + translate_x = (-translate_x, translate_x) + elif isinstance(translate_x, list) or isinstance(translate_x, tuple): + assert len(translate_x) == 2, 'len of translate_x is not equal to 2' + else: + raise ValueError('translate_x is not reasonable') + + if isinstance(translate_y, (int, float)): + translate_y = abs(translate_y) + translate_y = (-translate_y, translate_y) + elif isinstance(translate_y, list) or isinstance(translate_y, tuple): + assert len(translate_y) == 2, 'len of translate_y is not equal to 2' + else: + raise ValueError('translate_y is not reasonable') + + self.translate_x = translate_x + self.translate_y = translate_y + self.area_thr = area_thr + self.border_value = border_value + + def __call__(self, sample, context=None): + translate_x = random.uniform(*self.translate_x) + translate_y = random.uniform(*self.translate_y) + translate = Translate((translate_x, translate_y), self.area_thr, self.border_value) + return translate(sample, context) + +@register_op +class Scale(BaseOperator): + def __init__(self, scale, area_thr=0.25, border_value=(114, 114, 114)): + super(Scale, self).__init__() + if isinstance(scale, (int, float)): + scale = (scale, scale) + elif isinstance(scale, list) or isinstance(scale, tuple): + assert len(scale) == 2, 'len of scale is not equal to 2' + else: + raise ValueError('scale is not reasonable') + + assert scale[0] > 0. and scale[1] > 0., 'scale should be great than 0' + + self.scale = scale + self.area_thr = area_thr + self.border_value = border_value + + def __call__(self, sample, context=None): + im = sample['image'] + bbox = sample['gt_bbox'] + label = sample['gt_class'] + + # scale image + height, width = im.shape[:2] + dsize = (int(self.scale[0] * width), int(self.scale[1] * height)) + dst_img = cv2.resize(im, dsize) + canvas = np.ones_like(im, dtype=np.uint8) * self.border_value + y_lim = min(height, dsize[1]) + x_lim = min(width, dsize[0]) + canvas[:y_lim, :x_lim, :] = dst_img[:y_lim, :x_lim, :] + # scale bbox + if bbox.shape[0] > 0: + new_bbox = bbox * [self.scale[0], self.scale[1], self.scale[0], self.scale[1]] + new_bbox, mask = clip_bbox(new_bbox, width, height, self.area_thr) + new_label = label[mask] + else: + new_bbox, new_label = bbox, label + + sample['image'] = canvas.astype(np.uint8) + sample['gt_bbox'] = new_bbox.astype(np.float32) + sample['gt_class'] = new_label.astype(np.int32) + return sample + + +@register_op +class RandomScale(BaseOperator): + def __init__(self, scale_x, scale_y, area_thr=0.25, border_value=(114, 114, 114)): + super(RandomScale, self).__init__() + if isinstance(scale_x, (int, float)): + assert scale_x > 0., 'scale_x should be great than 0' + scale_x = (0., scale_x) + elif isinstance(scale_x, list) or isinstance(scale_x, tuple): + assert len(scale_x) == 2, 'len of scale_x is not equal to 2' + else: + raise ValueError('scale_x is not reasonable') + + if isinstance(scale_y, (int, float)): + assert scale_y > 0., 'scale_y should be great than 0' + scale_y = (0., scale_y) + elif isinstance(scale_y, list) or isinstance(scale_y, tuple): + assert len(scale_y) == 2, 'len of scale_y is not equal to 2' + else: + raise ValueError('scale_y is not reasonable') + + self.scale_x = scale_x + self.scale_y = scale_y + self.area_thr = area_thr + self.border_value = border_value + + def __call__(self, sample, context=None): + scale_x = random.uniform(*self.scale_x) + scale_y = random.uniform(*self.scale_y) + scale = Scale((scale_x, scale_y), self.area_thr, self.border_value) + return scale(sample, context) + +@register_op +class RandomPerspective(BaseOperator): + def __init__(self, degree=10, translate=0.1, scale=0.1, shear=10, perspective=0.0, border=(0, 0), area_thr=0.25, border_value=(114,114,114)): + super(RandomPerspective, self).__init__() + self.degree = degree + self.translate = translate + self.scale = scale + self.shear = shear + self.perspective = perspective + self.border = border + self.area_thr = area_thr + self.border_value = border_value + + def __call__(self, sample, context=None): + im = sample['image'] + bbox = sample['gt_bbox'] + label = sample['gt_class'] + + height = im.shape[0] + self.border[0] + width = im.shape[1] + self.border[1] + + # center + C = np.eye(3) + C[0, 2] = -im.shape[1] / 2 + C[1, 2] = -im.shape[0] / 2 + + # perspective + P = np.eye(3) + P[2, 0] = random.uniform(-self.perspective, self.perspective) + P[2, 1] = random.uniform(-self.perspective, self.perspective) + + # Rotation and scale + R = np.eye(3) + a = random.uniform(-self.degree, self.degree) + s = random.uniform(1 - self.scale, 1 + self.scale) + R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s) + + # Shear + S = np.eye(3) + # shear x (deg) + S[0, 1] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) + # shear y (deg) + S[1, 0] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) + + # Translation + T = np.eye(3) + T[0, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * width + T[1, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * height + + # matmul + M = T @ S @ R @ P @ C + if (self.border[0] != 0) or (self.border[1] != 0) or (M != np.eye(3)).any(): + if self.perspective: + im = cv2.warpPerspective(im, M, dsize=(width, height), borderValue=self.border_value) + else: + im = cv2.warpAffine(im, M[:2], dsize=(width, height), borderValue=self.border_value) + + if bbox.shape[0] > 0: + new_bbox, new_label = transform_bbox(bbox, label, M, width, height, area_thr=self.area_thr, perspective=self.perspective) + else: + new_bbox, new_label = bbox, label + + sample['image'] = im + sample['gt_bbox'] = new_bbox.astype(np.float32) + sample['gt_class'] = new_label.astype(np.int32) + return sample + + + + + +