diff --git a/configs/yolov5/yolov5_reader.yml b/configs/yolov5/yolov5_reader.yml index ab0d6158401d8a050af9094ece6e9a67e35e4516..d84d8e1f48887a931ccc06fc97dfeddd9cb7076a 100644 --- a/configs/yolov5/yolov5_reader.yml +++ b/configs/yolov5/yolov5_reader.yml @@ -12,22 +12,27 @@ TrainReader: sample_transforms: - !DecodeImage to_rgb: True - # with_mosaic: True - # - !MosaicImage - # offset: 0.3 - # mosaic_scale: [0.8, 1.0] - # sample_scale: [0.8, 1.0] - # sample_flip: 0.5 - # use_cv2: true - # interp: 2 - - !NormalizeBox {} + with_mosaic: True + - !Mosaic + target_size: 640 + - !RandomPerspective + degree: 0 + translate: 0.1 + scale: 0.5 + shear: 0.0 + perspective: 0.0 + border: [-320, -320] + - !RandomFlipImage + prob: 0.5 + is_normalized: false + - !RandomHSV + hgain: 0.015 + sgain: 0.7 + vgain: 0.4 - !PadBox num_max_boxes: 50 - !BboxXYXY2XYWH {} batch_transforms: - - !RandomShape - sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640] - random_inter: True - !NormalizeImage mean: [0.0, 0.0, 0.0] std: [1.0, 1.0, 1.0] @@ -37,10 +42,6 @@ TrainReader: to_bgr: false channel_first: True # focus: false - - !Gt2YoloTarget - anchor_masks: [[0, 1, 2], [3, 4, 5], [6, 7, 8]] - anchors: [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], - [59, 119], [116, 90], [156, 198], [373, 326]] downsample_ratios: [8, 16, 32] batch_size: 2 mosaic_prob: 0.3 @@ -49,6 +50,9 @@ TrainReader: drop_last: true worker_num: 8 bufsize: 16 + target_size: 640 + rect: false + pad: 0.5 use_process: true EvalReader: diff --git a/ppdet/data/reader.py b/ppdet/data/reader.py index a9c7b69bd5bad2e75dc91a738b8bb51eddb44e45..9118914a9f1bb82aeb04d929e8ca20b753ff9023 100644 --- a/ppdet/data/reader.py +++ b/ppdet/data/reader.py @@ -21,6 +21,7 @@ import copy import functools import collections import traceback +import random import numpy as np import logging @@ -209,7 +210,8 @@ class Reader(object): memsize='3G', inputs_def=None, devices_num=1, - num_trainers=1): + num_trainers=1, + mosaic=False): 
def transform_bbox(bbox,
                   label,
                   M,
                   w,
                   h,
                   area_thr=0.25,
                   wh_thr=2,
                   ar_thr=20,
                   perspective=False):
    """
    Transform bboxes with the transformation matrix M, then clip them to
    the output image and drop the boxes that became degenerate.

    Args:
        bbox (np.ndarray): boxes of shape (n, 4) in [x1, y1, x2, y2] order.
        label (np.ndarray): per-box labels, indexed along the first axis.
        M (np.ndarray): matrix applied to the homogeneous corner points
            (2x3 affine, or 3x3 when `perspective` is set).
        w (int, float): width the transformed boxes are clipped to.
        h (int, float): height the transformed boxes are clipped to.
        area_thr (float): minimum ratio of clipped area to unclipped area.
        wh_thr (int, float): minimum width and height of a kept box.
        ar_thr (int, float): maximum aspect ratio of a kept box.
        perspective (bool): divide by the third homogeneous coordinate
            after the matrix product.

    Returns:
        tuple: (new_bbox, new_label) holding only the boxes that survived.
    """
    n = len(bbox)
    # the four corners of every box as homogeneous points (x, y, 1)
    xy = np.ones((n * 4, 3), dtype=np.float32)
    xy[:, :2] = bbox[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)
    xy = np.matmul(xy, M.T)
    if perspective:
        xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)
    else:
        xy = xy[:, :2].reshape(n, 8)

    # axis-aligned box enclosing the transformed corners
    x = xy[:, [0, 2, 4, 6]]
    y = xy[:, [1, 3, 5, 7]]
    new_bbox = np.stack((x.min(1), y.min(1), x.max(1), y.max(1)), axis=1)

    # forward every threshold; wh_thr / ar_thr used to be silently ignored
    new_bbox, mask = clip_bbox(new_bbox, w, h, area_thr, wh_thr, ar_thr)
    new_label = label[mask]
    return new_bbox, new_label


# NOTE(review): the import list in operators.py already pulled a `clip_bbox`
# from this module before this function was added, so this definition may
# shadow an older helper with a different signature - confirm.
def clip_bbox(bbox, w, h, area_thr=0.25, wh_thr=2, ar_thr=20):
    """
    Clip bboxes to [0, w] x [0, h] and filter out degenerate ones.

    Args:
        bbox (np.ndarray): boxes of shape (n, 4) in [x1, y1, x2, y2] order.
        w (int, float): upper bound for x coordinates.
        h (int, float): upper bound for y coordinates.
        area_thr (float): a box is kept only if clipped_area / original_area
            is greater than this value.
        wh_thr (int, float): a box is kept only if both its width and height
            are greater than this value.
        ar_thr (int, float): a box is kept only if its aspect ratio
            (max of w/h and h/w) is smaller than this value.

    Returns:
        tuple: (bbox, mask), the kept clipped boxes and the boolean mask of
            shape (n,) that selects them from the input.
    """
    # work on a copy so the caller's array is not modified
    bbox = np.array(bbox, copy=True)
    area1 = (bbox[:, 2:4] - bbox[:, 0:2]).prod(1)
    bbox[:, [0, 2]] = bbox[:, [0, 2]].clip(0, w)
    bbox[:, [1, 3]] = bbox[:, [1, 3]].clip(0, h)

    wh = bbox[:, 2:4] - bbox[:, 0:2]
    area2 = wh.prod(1)
    # the small epsilon keeps the divisions defined for empty boxes
    area_ratio = area2 / (area1 + 1e-16)
    ar_ratio = np.maximum(wh[:, 1] / (wh[:, 0] + 1e-16),
                          wh[:, 0] / (wh[:, 1] + 1e-16))
    mask = (area_ratio > area_thr) & (
        (wh > wh_thr).all(1)) & (ar_ratio < ar_thr)
    return bbox[mask], mask
satisfy_sample_constraint, filter_and_process, generate_sample_bbox, + clip_bbox, data_anchor_sampling, satisfy_sample_constraint_coverage, + crop_image_sampling, generate_sample_bbox_square, bbox_area_sampling, + is_poly, gaussian_radius, draw_gaussian, transform_bbox, clip_bbox) logger = logging.getLogger(__name__) @@ -90,7 +90,11 @@ class BaseOperator(object): @register_op class DecodeImage(BaseOperator): - def __init__(self, to_rgb=True, with_mixup=False, with_cutmix=False): + def __init__(self, + to_rgb=True, + with_mixup=False, + with_cutmix=False, + with_mosaic=False): """ Transform the image data to numpy format. Args: to_rgb (bool): whether to convert BGR to RGB @@ -102,6 +106,7 @@ class DecodeImage(BaseOperator): self.to_rgb = to_rgb self.with_mixup = with_mixup self.with_cutmix = with_cutmix + self.with_mosaic = with_mosaic if not isinstance(self.to_rgb, bool): raise TypeError("{}: input type is invalid.".format(self)) if not isinstance(self.with_mixup, bool): @@ -150,7 +155,11 @@ class DecodeImage(BaseOperator): if self.with_cutmix and 'cutmix' in sample: self.__call__(sample['cutmix'], context) - # decode semantic label + if self.with_mosaic and 'mosaic' in sample: + for x in sample['mosaic']: + self.__call__(x, context) + + # decode semantic label if 'semantic' in sample.keys() and sample['semantic'] is not None: sem_file = sample['semantic'] sem = cv2.imread(sem_file, cv2.IMREAD_GRAYSCALE) @@ -292,11 +301,11 @@ class ResizeImage(BaseOperator): self.use_cv2 = use_cv2 if not (isinstance(target_size, int) or isinstance(target_size, list)): raise TypeError( - "Type of target_size is invalid. Must be Integer or List, now is {}". - format(type(target_size))) + "Type of target_size is invalid. 
@register_op
class ResizeAndKeepRatio(BaseOperator):
    """Resize the image so that its longer side equals target_size.

    The aspect ratio is kept. Ground-truth boxes, when present, are scaled
    by the same factor and clipped to the resized image.

    Args:
        target_size (int): length of the longer side after resizing.
        augment (bool): when False, shrinking uses cv2.INTER_AREA; every
            other case uses cv2.INTER_LINEAR.
        with_mosaic (bool): also resize every sample stored under
            sample['mosaic'].
    """

    def __init__(self, target_size, augment=False, with_mosaic=False):
        super(ResizeAndKeepRatio, self).__init__()
        self.target_size = target_size
        self.augment = augment
        # was never stored, so __call__ failed with AttributeError
        self.with_mosaic = with_mosaic

    def __call__(self, sample, context=None):
        im = sample['image']
        # samples without annotations carry no gt_bbox
        bbox = sample.get('gt_bbox')

        h0, w0 = im.shape[:2]
        r = self.target_size / max(h0, w0)
        if r != 1:
            interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR
            im = cv2.resize(
                im, (int(w0 * r), int(h0 * r)), interpolation=interp)

            if bbox is not None:
                new_h, new_w = im.shape[:2]
                bbox = bbox * (r, r, r, r)
                # x coordinates are bounded by the new width, y by the new
                # height; the old `bbox.clip(h0, w0)` used h0 as the lower
                # and w0 as the upper bound of every coordinate
                bbox[:, [0, 2]] = bbox[:, [0, 2]].clip(0, new_w)
                bbox[:, [1, 3]] = bbox[:, [1, 3]].clip(0, new_h)

        sample['image'] = im
        sample['im_size'] = [float(h0), float(w0)]
        sample['im_scale'] = [1. / r, 1. / r]
        if bbox is not None:
            sample['gt_bbox'] = bbox

        # 'mosaic' must be the string key, not a bare name
        if self.with_mosaic and 'mosaic' in sample:
            for x in sample['mosaic']:
                self.__call__(x, context)

        return sample
@register_op
class Rotate(BaseOperator):
    """Rotate image and bboxes

    Args:
        degree (int, float): the angle of rotation in degrees
        scale (float): scale factor
        center (tuple): center of the rotation in the source image; when
            None the center of each processed image is used
        area_thr (float): the area threshold of bbox to be kept after
            rotation, default 0.25
        border_value (tuple): value used in case of a constant border,
            default (114, 114, 114)
    """

    def __init__(self,
                 degree,
                 scale=1.0,
                 center=None,
                 area_thr=0.25,
                 border_value=(114, 114, 114)):
        super(Rotate, self).__init__()
        self.degree = degree
        self.scale = scale
        self.center = center
        self.area_thr = area_thr
        self.border_value = border_value

    def __call__(self, sample, context=None):
        im = sample['image']
        bbox = sample['gt_bbox']
        label = sample['gt_class']

        # rotate image
        height, width = im.shape[:2]
        # Resolve the default center per image. Writing it back to
        # self.center made the first image's center stick for all later
        # images, whatever their size.
        center = self.center
        if center is None:
            center = (width // 2, height // 2)
        M = cv2.getRotationMatrix2D(center, self.degree, self.scale)
        im = cv2.warpAffine(
            im, M, (width, height), borderValue=self.border_value)

        # rotate bbox
        if bbox.shape[0] > 0:
            new_bbox, new_label = transform_bbox(bbox, label, M, width, height,
                                                 self.area_thr)
        else:
            new_bbox, new_label = bbox, label
        sample['image'] = im
        sample['gt_bbox'] = new_bbox.astype(np.float32)
        sample['gt_class'] = new_label.astype(np.int32)
        return sample


@register_op
class RandomRotate(BaseOperator):
    """Rotate image and bboxes randomly

    Args:
        degree (int, float, list, tuple): if (int, float), the rotation
            degree is uniformly sampled in [-abs(degree), abs(degree)];
            if (list, tuple), it is uniformly sampled in
            [degree[0], degree[1]]
        scale (float): the scale factor is uniformly sampled in
            [1 - scale, 1 + scale]
        center (tuple): center of the rotation in the source image
        area_thr (float): the area threshold of bbox to be kept after
            rotation, default 0.25
        border_value (tuple): value used in case of a constant border,
            default (114, 114, 114)
    """

    def __init__(self,
                 degree,
                 scale=0.0,
                 center=None,
                 area_thr=0.25,
                 border_value=(114, 114, 114)):
        super(RandomRotate, self).__init__()
        if isinstance(degree, (int, float)):
            degree = abs(degree)
            degree = (-degree, degree)
        elif isinstance(degree, (list, tuple)):
            assert len(degree) == 2, 'len of degree is not equal to 2'
        else:
            raise ValueError('degree is not reasonable')

        self.degree = degree
        self.scale = scale
        self.center = center
        self.area_thr = area_thr
        self.border_value = border_value

    def __call__(self, sample, context=None):
        # draw the parameters, then delegate to a one-off Rotate operator
        degree = random.uniform(*self.degree)
        scale = random.uniform(1 - self.scale, 1 + self.scale)
        rotate = Rotate(degree, scale, self.center, self.area_thr,
                        self.border_value)
        return rotate(sample, context)
@register_op
class Shear(BaseOperator):
    """Apply a fixed shear to the image and its bboxes.

    Args:
        shear (int, float, list, tuple): a single number is used for both
            shear_x and shear_y; a pair is read as [shear_x, shear_y].
            Values are angles in degrees.
        area_thr (float): the area threshold of bbox to be kept after
            shearing, default 0.25
        border_value (tuple): value used in case of a constant border,
            default (114, 114, 114)
    """

    def __init__(self, shear, area_thr=0.25, border_value=(114, 114, 114)):
        super(Shear, self).__init__()
        if isinstance(shear, (int, float)):
            shear = (shear, shear)
        elif isinstance(shear, (list, tuple)):
            assert len(shear) == 2, 'len of shear is not equal to 2'
        else:
            raise ValueError('shear is not reasonable')

        self.shear = shear
        self.area_thr = area_thr
        self.border_value = border_value

    def __call__(self, sample, context=None):
        image = sample['image']
        boxes = sample['gt_bbox']
        classes = sample['gt_class']

        img_h, img_w = image.shape[:2]
        # degrees -> tangent of the angle, one factor per axis
        tan_x = math.tan(self.shear[0] * math.pi / 180)
        tan_y = math.tan(self.shear[1] * math.pi / 180)
        matrix = np.array([[1, tan_x, 0], [tan_y, 1, 0]])
        warped = cv2.warpAffine(
            image, matrix, (img_w, img_h), borderValue=self.border_value)

        # boxes follow the same matrix; nothing to do without boxes
        if boxes.shape[0] > 0:
            boxes, classes = transform_bbox(boxes, classes, matrix, img_w,
                                            img_h, self.area_thr)

        sample['image'] = warped
        sample['gt_bbox'] = boxes.astype(np.float32)
        sample['gt_class'] = classes.astype(np.int32)
        return sample
@register_op
class RandomShear(BaseOperator):
    """Shear the image and its bboxes by randomly drawn angles.

    Args:
        shear_x (int, float, list, tuple): a number gives the range
            [-abs(shear_x), abs(shear_x)]; a pair gives the range
            [shear_x[0], shear_x[1]]. Angles are in degrees.
        shear_y (int, float, list, tuple): same convention as shear_x,
            for the y axis.
        area_thr (float): the area threshold of bbox to be kept after
            shearing, default 0.25
        border_value (tuple): value used in case of a constant border,
            default (114, 114, 114)
    """

    def __init__(self,
                 shear_x,
                 shear_y,
                 area_thr=0.25,
                 border_value=(114, 114, 114)):
        super(RandomShear, self).__init__()
        self.shear_x = self._as_range(shear_x, 'shear_x')
        self.shear_y = self._as_range(shear_y, 'shear_y')
        self.area_thr = area_thr
        self.border_value = border_value

    @staticmethod
    def _as_range(value, name):
        # normalise a number or a pair into a sampling range
        if isinstance(value, (int, float)):
            bound = abs(value)
            return (-bound, bound)
        if isinstance(value, (list, tuple)):
            assert len(value) == 2, 'len of {} is not equal to 2'.format(name)
            return value
        raise ValueError('{} is not reasonable'.format(name))

    def __call__(self, sample, context=None):
        # x is drawn before y, then a one-off Shear operator does the work
        angle_x = random.uniform(*self.shear_x)
        angle_y = random.uniform(*self.shear_y)
        op = Shear((angle_x, angle_y), self.area_thr, self.border_value)
        return op(sample, context)
@register_op
class Translate(BaseOperator):
    """Shift the image and its bboxes by a fixed fraction of the image size.

    Args:
        translate (int, float, list, tuple): a single number is used for
            both axes; a pair is read as [translate_x, translate_y]. Each
            value is a fraction of the image width / height and must lie
            in (-1, 1).
        area_thr (float): the area threshold of bbox to be kept after
            translation, default 0.25
        border_value (tuple): value used in case of a constant border,
            default (114, 114, 114)
    """

    def __init__(self, translate, area_thr=0.25, border_value=(114, 114, 114)):
        super(Translate, self).__init__()
        if isinstance(translate, (int, float)):
            translate = (translate, translate)
        elif isinstance(translate, (list, tuple)):
            assert len(translate) == 2, 'len of translate is not equal to 2'
        else:
            raise ValueError('translate is not reasonable')

        assert abs(translate[0]) < 1 and abs(translate[
            1]) < 1, 'translate should be in (-1, 1)'

        self.translate = translate
        self.area_thr = area_thr
        self.border_value = border_value

    def __call__(self, sample, context=None):
        image = sample['image']
        boxes = sample['gt_bbox']
        classes = sample['gt_class']

        img_h, img_w = image.shape[:2]
        dx = int(self.translate[0] * img_w)
        dy = int(self.translate[1] * img_h)

        # window of the canvas that receives pixels ...
        dst_top, dst_left = max(0, dy), max(0, dx)
        dst_bottom = min(img_h, dy + img_h)
        dst_right = min(img_w, dx + img_w)
        # ... and the matching window of the source image
        src_top, src_left = max(-dy, 0), max(-dx, 0)
        src_bottom = min(-dy + img_h, img_h)
        src_right = min(-dx + img_w, img_w)

        canvas = np.ones(image.shape, dtype=np.uint8) * self.border_value
        canvas[dst_top:dst_bottom, dst_left:dst_right, :] = image[
            src_top:src_bottom, src_left:src_right, :]

        if boxes.shape[0] > 0:
            shifted = boxes + [dx, dy, dx, dy]
            boxes, keep = clip_bbox(shifted, img_w, img_h, self.area_thr)
            classes = classes[keep]

        sample['image'] = canvas.astype(np.uint8)
        sample['gt_bbox'] = boxes.astype(np.float32)
        sample['gt_class'] = classes.astype(np.int32)
        return sample
@register_op
class RandomTranslate(BaseOperator):
    """Shift the image and its bboxes by randomly drawn fractions.

    Args:
        translate_x (int, float, list, tuple): a number gives the range
            [-abs(translate_x), abs(translate_x)]; a pair gives the range
            [translate_x[0], translate_x[1]]. Values are fractions of the
            original shape.
        translate_y (int, float, list, tuple): same convention as
            translate_x, for the y axis.
        area_thr (float): the area threshold of bbox to be kept after
            translation, default 0.25
        border_value (tuple): value used in case of a constant border,
            default (114, 114, 114)
    """

    def __init__(self,
                 translate_x,
                 translate_y,
                 area_thr=0.25,
                 border_value=(114, 114, 114)):
        super(RandomTranslate, self).__init__()
        self.translate_x = self._as_range(translate_x, 'translate_x')
        self.translate_y = self._as_range(translate_y, 'translate_y')
        self.area_thr = area_thr
        self.border_value = border_value

    @staticmethod
    def _as_range(value, name):
        # normalise a number or a pair into a sampling range
        if isinstance(value, (int, float)):
            bound = abs(value)
            return (-bound, bound)
        if isinstance(value, (list, tuple)):
            assert len(value) == 2, 'len of {} is not equal to 2'.format(name)
            return value
        raise ValueError('{} is not reasonable'.format(name))

    def __call__(self, sample, context=None):
        # x is drawn before y, then a one-off Translate operator does the work
        frac_x = random.uniform(*self.translate_x)
        frac_y = random.uniform(*self.translate_y)
        op = Translate((frac_x, frac_y), self.area_thr, self.border_value)
        return op(sample, context)
@register_op
class Scale(BaseOperator):
    """Scale image and bboxes

    The scaled image is pasted at the top-left corner of a canvas that has
    the size of the input image: it is cropped when it grows and padded
    with border_value when it shrinks.

    Args:
        scale (int, float, list, tuple): if (int, float), scale_x and
            scale_y are both equal to scale; if (list, tuple), it means
            [scale_x, scale_y]
        area_thr (float): the area threshold of bbox to be kept after
            scaling, default 0.25
        border_value (tuple): value used in case of a constant border,
            default (114, 114, 114)
    """

    def __init__(self, scale, area_thr=0.25, border_value=(114, 114, 114)):
        super(Scale, self).__init__()
        if isinstance(scale, (int, float)):
            scale = (scale, scale)
        elif isinstance(scale, (list, tuple)):
            assert len(scale) == 2, 'len of scale is not equal to 2'
        else:
            raise ValueError('scale is not reasonable')

        assert scale[0] > 0. and scale[1] > 0., 'scale should be great than 0'

        self.scale = scale
        self.area_thr = area_thr
        self.border_value = border_value

    def __call__(self, sample, context=None):
        im = sample['image']
        bbox = sample['gt_bbox']
        label = sample['gt_class']

        # scale image
        height, width = im.shape[:2]
        # A very small positive scale truncates to 0 pixels, which
        # cv2.resize rejects; keep at least one pixel per side.
        dsize = (max(1, int(self.scale[0] * width)),
                 max(1, int(self.scale[1] * height)))
        dst_img = cv2.resize(im, dsize)
        canvas = np.ones_like(im, dtype=np.uint8) * self.border_value
        y_lim = min(height, dsize[1])
        x_lim = min(width, dsize[0])
        canvas[:y_lim, :x_lim, :] = dst_img[:y_lim, :x_lim, :]

        # scale bbox
        if bbox.shape[0] > 0:
            new_bbox = bbox * [
                self.scale[0], self.scale[1], self.scale[0], self.scale[1]
            ]
            new_bbox, mask = clip_bbox(new_bbox, width, height, self.area_thr)
            new_label = label[mask]
        else:
            new_bbox, new_label = bbox, label

        sample['image'] = canvas.astype(np.uint8)
        sample['gt_bbox'] = new_bbox.astype(np.float32)
        sample['gt_class'] = new_label.astype(np.int32)
        return sample
@register_op
class RandomScale(BaseOperator):
    """Scale the image and its bboxes by randomly drawn factors.

    Args:
        scale_x (int, float, list, tuple): a number gives the range
            [0, scale_x]; a pair gives the range [scale_x[0], scale_x[1]].
        scale_y (int, float, list, tuple): same convention as scale_x,
            for the y axis.
        area_thr (float): the area threshold of bbox to be kept after
            scaling, default 0.25
        border_value (tuple): value used in case of a constant border,
            default (114, 114, 114)
    """

    def __init__(self,
                 scale_x,
                 scale_y,
                 area_thr=0.25,
                 border_value=(114, 114, 114)):
        super(RandomScale, self).__init__()
        self.scale_x = self._as_range(scale_x, 'scale_x')
        self.scale_y = self._as_range(scale_y, 'scale_y')
        self.area_thr = area_thr
        self.border_value = border_value

    @staticmethod
    def _as_range(value, name):
        # normalise a positive number or a pair into a sampling range
        if isinstance(value, (int, float)):
            assert value > 0., '{} should be great than 0'.format(name)
            return (0., value)
        if isinstance(value, (list, tuple)):
            assert len(value) == 2, 'len of {} is not equal to 2'.format(name)
            return value
        raise ValueError('{} is not reasonable'.format(name))

    def __call__(self, sample, context=None):
        # x is drawn before y, then a one-off Scale operator does the work
        factor_x = random.uniform(*self.scale_x)
        factor_y = random.uniform(*self.scale_y)
        op = Scale((factor_x, factor_y), self.area_thr, self.border_value)
        return op(sample, context)
@register_op
class RandomPerspective(BaseOperator):
    """Rotate, translate, scale, shear and perspective-warp image and bboxes randomly

    The five elementary matrices are composed into one matrix
    M = T * S * R * P * C, which is applied to the image and to the bboxes.

    Args:
        degree (int): rotation degree, uniformly sampled in [-degree, degree]
        translate (float): translate fraction, translate_x and translate_y are uniformly sampled
            in [0.5 - translate, 0.5 + translate] and multiplied by the output width / height
        scale (float): scale factor, uniformly sampled in [1 - scale, 1 + scale]
        shear (int): shear degree, shear_x and shear_y are uniformly sampled in [-shear, shear]
        perspective (float): perspective_x and perspective_y are uniformly sampled in [-perspective, perspective]
        border (tuple, list): (border_h, border_w) added to the input height / width to get
            the output height / width
        area_thr (float): the area threshold of bbox to be kept after transformation, default 0.25
        border_value (tuple): value used in case of a constant border, default (114, 114, 114)
    """

    def __init__(self,
                 degree=10,
                 translate=0.1,
                 scale=0.1,
                 shear=10,
                 perspective=0.0,
                 border=(0, 0),
                 area_thr=0.25,
                 border_value=(114, 114, 114)):
        super(RandomPerspective, self).__init__()
        self.degree = degree
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.perspective = perspective
        self.border = border
        self.area_thr = area_thr
        self.border_value = border_value

    def __call__(self, sample, context=None):
        im = sample['image']
        bbox = sample['gt_bbox']
        label = sample['gt_class']

        # output size = input size + border
        # NOTE(review): the border is added once, not once per side; with a
        # negative border of half the target size on a 2x mosaic canvas the
        # output is 1.5x the target size - confirm this is intended.
        height = im.shape[0] + self.border[0]
        width = im.shape[1] + self.border[1]

        # center: move the input image center to the origin
        C = np.eye(3)
        C[0, 2] = -im.shape[1] / 2
        C[1, 2] = -im.shape[0] / 2

        # perspective
        P = np.eye(3)
        P[2, 0] = random.uniform(-self.perspective, self.perspective)
        P[2, 1] = random.uniform(-self.perspective, self.perspective)

        # Rotation and scale
        R = np.eye(3)
        a = random.uniform(-self.degree, self.degree)
        s = random.uniform(1 - self.scale, 1 + self.scale)
        R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

        # Shear
        S = np.eye(3)
        # shear x (deg)
        S[0, 1] = math.tan(
            random.uniform(-self.shear, self.shear) * math.pi / 180)
        # shear y (deg)
        S[1, 0] = math.tan(
            random.uniform(-self.shear, self.shear) * math.pi / 180)

        # Translation: move the origin to a random point around the output center
        T = np.eye(3)
        T[0, 2] = random.uniform(0.5 - self.translate,
                                 0.5 + self.translate) * width
        T[1, 2] = random.uniform(0.5 - self.translate,
                                 0.5 + self.translate) * height

        # matmul
        # M = T @ S @ R @ P @ C
        M = np.eye(3)
        for cM in [T, S, R, P, C]:
            M = np.matmul(M, cM)

        # skip the warp only when the size is unchanged and M is the identity
        if (self.border[0] != 0) or (self.border[1] != 0) or (
                M != np.eye(3)).any():
            if self.perspective:
                im = cv2.warpPerspective(
                    im, M, dsize=(width, height), borderValue=self.border_value)
            else:
                # without perspective only the affine part of M is needed
                im = cv2.warpAffine(
                    im,
                    M[:2],
                    dsize=(width, height),
                    borderValue=self.border_value)

        if bbox.shape[0] > 0:
            new_bbox, new_label = transform_bbox(
                bbox,
                label,
                M,
                width,
                height,
                area_thr=self.area_thr,
                perspective=self.perspective)
        else:
            new_bbox, new_label = bbox, label

        sample['image'] = im
        sample['gt_bbox'] = new_bbox.astype(np.float32)
        sample['gt_class'] = new_label.astype(np.int32)
        return sample
@register_op
class RandomHSV(BaseOperator):
    """Randomly jitter hue, saturation and value of the image.

    Args:
        hgain (float): the hue multiplier is sampled in [1 - hgain, 1 + hgain]
        sgain (float): the saturation multiplier is sampled in
            [1 - sgain, 1 + sgain]
        vgain (float): the value multiplier is sampled in
            [1 - vgain, 1 + vgain]
    """

    def __init__(self, hgain=0.5, sgain=0.5, vgain=0.5):
        super(RandomHSV, self).__init__()
        self.gains = [hgain, sgain, vgain]

    def __call__(self, sample, context=None):
        im = sample['image']
        # one random multiplier per channel, centered on 1
        r = np.random.uniform(-1, 1, 3) * self.gains + 1
        # NOTE(review): the conversion assumes BGR channel order - confirm
        # against the to_rgb setting of the decoding step.
        hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV))
        # lookup tables over all 8-bit values; hue wraps around at 180
        x = np.arange(0, 256, dtype=np.int16)
        lut_hue = ((x * r[0]) % 180).astype(np.uint8)
        lut_sat = np.clip(x * r[1], 0, 255).astype(np.uint8)
        lut_val = np.clip(x * r[2], 0, 255).astype(np.uint8)
        im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat),
                            cv2.LUT(val, lut_val))).astype(np.uint8)
        im = cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR)
        sample['image'] = im
        return sample


@register_op
class Mosaic(BaseOperator):
    """Combine the sample and the samples under sample['mosaic'] into one
    image of size (2 * target_size, 2 * target_size).

    The tiles are placed top-left, top-right, bottom-left and bottom-right
    around a randomly drawn center; their bboxes are shifted accordingly.

    Args:
        target_size (int): half of the side length of the mosaic canvas.
        mosaic_border (tuple, list): (border_y, border_x); the center is
            sampled in [-border, 2 * target_size + border] on each axis.
            Defaults to (-target_size // 2, -target_size // 2).
        border_value (tuple): fill value of the canvas,
            default (114, 114, 114)
    """

    def __init__(self,
                 target_size,
                 mosaic_border=None,
                 border_value=(114, 114, 114)):
        super(Mosaic, self).__init__()
        self.target_size = target_size
        if mosaic_border is None:
            # Both entries must be negative so that the center stays inside
            # the canvas. The former default had a positive x border, which
            # let the x center fall outside [0, 2 * target_size].
            mosaic_border = (-target_size // 2, -target_size // 2)
        self.mosaic_border = mosaic_border
        self.border_value = border_value

    def __call__(self, sample, context=None):
        # nothing to combine when the reader attached no extra samples
        if 'mosaic' not in sample:
            return sample

        s = self.target_size
        ims = [sample['image']]
        bboxes = [sample['gt_bbox']]
        labels = [sample['gt_class']]
        for x in sample['mosaic']:
            ims.append(x['image'])
            bboxes.append(x['gt_bbox'])
            labels.append(x['gt_class'])
        # a mosaic has four tiles; extra samples would reuse stale coordinates
        ims, bboxes, labels = ims[:4], bboxes[:4], labels[:4]

        yc, xc = [
            int(random.uniform(-b, 2 * s + b)) for b in self.mosaic_border
        ]
        new_im = np.full(
            (s * 2, s * 2, ims[0].shape[2]), self.border_value, dtype=np.uint8)

        for i, im in enumerate(ims):
            h, w, _ = im.shape
            # (x1a, y1a, x2a, y2a): window on the canvas
            # (x1b, y1b, x2b, y2b): window on the tile
            if i == 0:  # top left
                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc
                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h
            elif i == 1:  # top right
                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
            elif i == 2:  # bottom left
                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
                # the right edge of the tile is w (was max(xc, w), which the
                # slice clamped to w anyway)
                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
            else:  # bottom right
                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2,
                                                                     yc + h)
                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)

            new_im[y1a:y2a, x1a:x2a] = im[y1b:y2b, x1b:x2b]
            # offset between tile coordinates and canvas coordinates
            padw = x1a - x1b
            padh = y1a - y1b
            bboxes[i] = bboxes[i] + (padw, padh, padw, padh)

        # NOTE(review): only gt_bbox / gt_class are merged; any other
        # per-box field of the sample keeps the values of the first image -
        # confirm nothing downstream reads such fields.
        new_bbox = np.concatenate(bboxes, axis=0)
        new_label = np.concatenate(labels, axis=0)
        sample['image'] = new_im.astype(np.uint8)
        sample['gt_bbox'] = new_bbox.astype(np.float32)
        sample['gt_class'] = new_label.astype(np.int32)
        return sample