From f29556961c9d030bbfbf62d39f1ed12e951cb5b4 Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Mon, 9 Nov 2020 20:15:13 +0800 Subject: [PATCH] [Dygraph]modify operators and batch_operators in data preprocess (#1666) * modify operators and batch_operators in data preprocess * import BboxError and ImageError in operators * modify code according to review * modify some bugs * add Gt2Solov2TargetOp in batch_operators * modify code according to review --- ppdet/data/transform/__init__.py | 3 + ppdet/data/transform/batch_operator.py | 742 ++++++++++ ppdet/data/transform/batch_operators.py | 2 +- ppdet/data/transform/operator.py | 1796 +++++++++++++++++++++++ ppdet/data/transform/operators.py | 40 +- 5 files changed, 2543 insertions(+), 40 deletions(-) create mode 100644 ppdet/data/transform/batch_operator.py create mode 100644 ppdet/data/transform/operator.py diff --git a/ppdet/data/transform/__init__.py b/ppdet/data/transform/__init__.py index c5deb535a..8c606b3aa 100644 --- a/ppdet/data/transform/__init__.py +++ b/ppdet/data/transform/__init__.py @@ -15,8 +15,11 @@ from . import operators from . import batch_operators +# TODO: operators and batch_operators will be replaced by operator and batch_operator from .operators import * +from .operator import * from .batch_operators import * +from .batch_operator import * __all__ = [] __all__ += registered_ops diff --git a/ppdet/data/transform/batch_operator.py b/ppdet/data/transform/batch_operator.py new file mode 100644 index 000000000..09d273bfc --- /dev/null +++ b/ppdet/data/transform/batch_operator.py @@ -0,0 +1,742 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +try: + from collections.abc import Sequence +except Exception: + from collections import Sequence + +import logging +import cv2 +import numpy as np +from .operator import register_op, BaseOperator, ResizeOp +from .op_helper import jaccard_overlap, gaussian2D +from scipy import ndimage + +logger = logging.getLogger(__name__) + +__all__ = [ + 'PadBatchOp', + 'Gt2YoloTargetOp', + 'Gt2FCOSTargetOp', + 'Gt2TTFTargetOp', +] + + +@register_op +class PadBatchOp(BaseOperator): + """ + Pad a batch of samples so they can be divisible by a stride. + The layout of each image should be 'CHW'. + Args: + pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure + height and width is divisible by `pad_to_stride`. + """ + + def __init__(self, pad_to_stride=0, pad_gt=False): + super(PadBatchOp, self).__init__() + self.pad_to_stride = pad_to_stride + self.pad_gt = pad_gt + + def __call__(self, samples, context=None): + """ + Args: + samples (list): a batch of sample, each is dict. 
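+            context (dict): info about this sample processing; unused here.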
+ """ + coarsest_stride = self.pad_to_stride + + max_shape = np.array([data['image'].shape for data in samples]).max( + axis=0) + if coarsest_stride > 0: + max_shape[1] = int( + np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride) + max_shape[2] = int( + np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride) + + padding_batch = [] + for data in samples: + im = data['image'] + im_c, im_h, im_w = im.shape[:] + padding_im = np.zeros( + (im_c, max_shape[1], max_shape[2]), dtype=np.float32) + padding_im[:, :im_h, :im_w] = im + data['image'] = padding_im + if 'semantic' in data and data['semantic'] is not None: + semantic = data['semantic'] + padding_sem = np.zeros( + (1, max_shape[1], max_shape[2]), dtype=np.float32) + padding_sem[:, :im_h, :im_w] = semantic + data['semantic'] = padding_sem + if 'gt_segm' in data and data['gt_segm'] is not None: + gt_segm = data['gt_segm'] + padding_segm = np.zeros( + (gt_segm.shape[0], max_shape[1], max_shape[2]), + dtype=np.uint8) + padding_segm[:, :im_h, :im_w] = gt_segm + data['gt_segm'] = padding_segm + + if self.pad_gt: + gt_num = [] + if data['gt_poly'] is not None and len(data['gt_poly']) > 0: + pad_mask = True + else: + pad_mask = False + + if pad_mask: + poly_num = [] + poly_part_num = [] + point_num = [] + for data in samples: + gt_num.append(data['gt_bbox'].shape[0]) + if pad_mask: + poly_num.append(len(data['gt_poly'])) + for poly in data['gt_poly']: + poly_part_num.append(int(len(poly))) + for p_p in poly: + point_num.append(int(len(p_p) / 2)) + gt_num_max = max(gt_num) + gt_box_data = np.zeros([gt_num_max, 4]) + gt_class_data = np.zeros([gt_num_max]) + is_crowd_data = np.ones([gt_num_max]) + + if pad_mask: + poly_num_max = max(poly_num) + poly_part_num_max = max(poly_part_num) + point_num_max = max(point_num) + gt_masks_data = -np.ones( + [poly_num_max, poly_part_num_max, point_num_max, 2]) + + for i, data in enumerate(samples): + gt_num = data['gt_bbox'].shape[0] + gt_box_data[0:gt_num, :] = data['gt_bbox'] + gt_class_data[0:gt_num] = np.squeeze(data['gt_class']) + is_crowd_data[0:gt_num] = np.squeeze(data['is_crowd']) + if pad_mask: + for j, poly in enumerate(data['gt_poly']): + for k, p_p in enumerate(poly): + pp_np = np.array(p_p).reshape(-1, 2) + gt_masks_data[j, k, :pp_np.shape[0], :] = pp_np + data['gt_poly'] = gt_masks_data + data['gt_bbox'] = gt_box_data + data['gt_class'] = gt_class_data + data['is_crowd'] = is_crowd_data + + return samples + + +@register_op +class BatchRandomResizeOp(BaseOperator): + """ + Resize image to target size randomly. random target_size and interpolation method + Args: + target_size (int, list, tuple): image target size, if random size is True, must be list or tuple + keep_ratio (bool): whether keep_raio or not, default true + interp (int): the interpolation method + random_size (bool): whether random select target size of image + random_interp (bool): whether random select interpolation method + """ + + def __init__(self, + target_size, + keep_ratio=True, + interp=cv2.INTER_LINEAR, + random_size=True, + random_interp=False): + super(BatchRandomResizeOp, self).__init__() + self.keep_ratio = keep_ratio + self.interps = [ + cv2.INTER_NEAREST, + cv2.INTER_LINEAR, + cv2.INTER_AREA, + cv2.INTER_CUBIC, + cv2.INTER_LANCZOS4, + ] + self.interp = interp + assert isinstance(target_size, ( + int, Sequence)), "target_size must be int, list or tuple" + if random_size and not isinstance(target_size, list): + raise TypeError( + "Type of target_size is invalid when random_size is True. Must be List, now is {}". 
+                format(type(target_size)))
+        self.target_size = target_size
+        self.random_size = random_size
+        self.random_interp = random_interp
+
+    def __call__(self, samples, context=None):
+        if self.random_size:
+            target_size = np.random.choice(self.target_size)
+        else:
+            target_size = self.target_size
+
+        if self.random_interp:
+            interp = np.random.choice(self.interps)
+        else:
+            interp = self.interp
+
+        resizer = ResizeOp(
+            target_size, keep_ratio=self.keep_ratio, interp=interp)
+        return resizer(samples, context=context)
+
+
+@register_op
+class Gt2YoloTargetOp(BaseOperator):
+    """
+    Generate YOLOv3 targets by ground truth data, this operator is only used
+    in fine-grained YOLOv3 loss mode
+    """
+
+    def __init__(self,
+                 anchors,
+                 anchor_masks,
+                 downsample_ratios,
+                 num_classes=80,
+                 iou_thresh=1.):
+        super(Gt2YoloTargetOp, self).__init__()
+        self.anchors = anchors
+        self.anchor_masks = anchor_masks
+        self.downsample_ratios = downsample_ratios
+        self.num_classes = num_classes
+        self.iou_thresh = iou_thresh
+
+    def __call__(self, samples, context=None):
+        assert len(self.anchor_masks) == len(self.downsample_ratios), \
+            "'anchor_masks' and 'downsample_ratios' should have the same length."
+
+        h, w = samples[0]['image'].shape[1:3]
+        an_hw = np.array(self.anchors) / np.array([[w, h]])
+        for sample in samples:
+            # im, gt_bbox, gt_class, gt_score = sample
+            im = sample['image']
+            gt_bbox = sample['gt_bbox']
+            gt_class = sample['gt_class']
+            gt_score = sample['gt_score']
+            for i, (
+                    mask, downsample_ratio
+            ) in enumerate(zip(self.anchor_masks, self.downsample_ratios)):
+                grid_h = int(h / downsample_ratio)
+                grid_w = int(w / downsample_ratio)
+                target = np.zeros(
+                    (len(mask), 6 + self.num_classes, grid_h, grid_w),
+                    dtype=np.float32)
+                for b in range(gt_bbox.shape[0]):
+                    gx, gy, gw, gh = gt_bbox[b, :]
+                    cls = gt_class[b]
+                    score = gt_score[b]
+                    if gw <= 0. or gh <= 0. or score <= 0.:
+                        continue
+
+                    # find the best matching anchor index
+                    best_iou = 0.
+                    best_idx = -1
+                    for an_idx in range(an_hw.shape[0]):
+                        iou = jaccard_overlap(
+                            [0., 0., gw, gh],
+                            [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
+                        if iou > best_iou:
+                            best_iou = iou
+                            best_idx = an_idx
+
+                    gi = int(gx * grid_w)
+                    gj = int(gy * grid_h)
+
+                    # the gt box should be regressed in this layer if the best
+                    # matching anchor index is in this layer's anchor mask
+                    if best_idx in mask:
+                        best_n = mask.index(best_idx)
+
+                        # x, y, w, h, scale
+                        target[best_n, 0, gj, gi] = gx * grid_w - gi
+                        target[best_n, 1, gj, gi] = gy * grid_h - gj
+                        target[best_n, 2, gj, gi] = np.log(
+                            gw * w / self.anchors[best_idx][0])
+                        target[best_n, 3, gj, gi] = np.log(
+                            gh * h / self.anchors[best_idx][1])
+                        target[best_n, 4, gj, gi] = 2.0 - gw * gh
+
+                        # objectness record gt_score
+                        target[best_n, 5, gj, gi] = score
+
+                        # classification
+                        target[best_n, 6 + cls, gj, gi] = 1.
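+                        # per-anchor target layout at (gj, gi): [0:4] box
+                        # (tx, ty, tw, th), [4] loc-loss scale 2.0 - gw * gh,
+                        # [5] objectness (gt_score), [6:] one-hot class scores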
+ + # For non-matched anchors, calculate the target if the iou + # between anchor and gt is larger than iou_thresh + if self.iou_thresh < 1: + for idx, mask_i in enumerate(mask): + if mask_i == best_idx: continue + iou = jaccard_overlap( + [0., 0., gw, gh], + [0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]]) + if iou > self.iou_thresh: + # x, y, w, h, scale + target[idx, 0, gj, gi] = gx * grid_w - gi + target[idx, 1, gj, gi] = gy * grid_h - gj + target[idx, 2, gj, gi] = np.log( + gw * w / self.anchors[mask_i][0]) + target[idx, 3, gj, gi] = np.log( + gh * h / self.anchors[mask_i][1]) + target[idx, 4, gj, gi] = 2.0 - gw * gh + + # objectness record gt_score + target[idx, 5, gj, gi] = score + + # classification + target[idx, 6 + cls, gj, gi] = 1. + sample['target{}'.format(i)] = target + return samples + + +@register_op +class Gt2FCOSTargetOp(BaseOperator): + """ + Generate FCOS targets by groud truth data + """ + + def __init__(self, + object_sizes_boundary, + center_sampling_radius, + downsample_ratios, + norm_reg_targets=False): + super(Gt2FCOSTargetOp, self).__init__() + self.center_sampling_radius = center_sampling_radius + self.downsample_ratios = downsample_ratios + self.INF = np.inf + self.object_sizes_boundary = [-1] + object_sizes_boundary + [self.INF] + object_sizes_of_interest = [] + for i in range(len(self.object_sizes_boundary) - 1): + object_sizes_of_interest.append([ + self.object_sizes_boundary[i], self.object_sizes_boundary[i + 1] + ]) + self.object_sizes_of_interest = object_sizes_of_interest + self.norm_reg_targets = norm_reg_targets + + def _compute_points(self, w, h): + """ + compute the corresponding points in each feature map + :param h: image height + :param w: image width + :return: points from all feature map + """ + locations = [] + for stride in self.downsample_ratios: + shift_x = np.arange(0, w, stride).astype(np.float32) + shift_y = np.arange(0, h, stride).astype(np.float32) + shift_x, shift_y = np.meshgrid(shift_x, shift_y) + shift_x = shift_x.flatten() + shift_y = shift_y.flatten() + location = np.stack([shift_x, shift_y], axis=1) + stride // 2 + locations.append(location) + num_points_each_level = [len(location) for location in locations] + locations = np.concatenate(locations, axis=0) + return locations, num_points_each_level + + def _convert_xywh2xyxy(self, gt_bbox, w, h): + """ + convert the bounding box from style xywh to xyxy + :param gt_bbox: bounding boxes normalized into [0, 1] + :param w: image width + :param h: image height + :return: bounding boxes in xyxy style + """ + bboxes = gt_bbox.copy() + bboxes[:, [0, 2]] = bboxes[:, [0, 2]] * w + bboxes[:, [1, 3]] = bboxes[:, [1, 3]] * h + bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2] + bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3] + return bboxes + + def _check_inside_boxes_limited(self, gt_bbox, xs, ys, + num_points_each_level): + """ + check if points is within the clipped boxes + :param gt_bbox: bounding boxes + :param xs: horizontal coordinate of points + :param ys: vertical coordinate of points + :return: the mask of points is within gt_box or not + """ + bboxes = np.reshape( + gt_bbox, newshape=[1, gt_bbox.shape[0], gt_bbox.shape[1]]) + bboxes = np.tile(bboxes, reps=[xs.shape[0], 1, 1]) + ct_x = (bboxes[:, :, 0] + bboxes[:, :, 2]) / 2 + ct_y = (bboxes[:, :, 1] + bboxes[:, :, 3]) / 2 + beg = 0 + clipped_box = bboxes.copy() + for lvl, stride in enumerate(self.downsample_ratios): + end = beg + num_points_each_level[lvl] + stride_exp = self.center_sampling_radius * stride + clipped_box[beg:end, :, 0] = 
np.maximum( + bboxes[beg:end, :, 0], ct_x[beg:end, :] - stride_exp) + clipped_box[beg:end, :, 1] = np.maximum( + bboxes[beg:end, :, 1], ct_y[beg:end, :] - stride_exp) + clipped_box[beg:end, :, 2] = np.minimum( + bboxes[beg:end, :, 2], ct_x[beg:end, :] + stride_exp) + clipped_box[beg:end, :, 3] = np.minimum( + bboxes[beg:end, :, 3], ct_y[beg:end, :] + stride_exp) + beg = end + l_res = xs - clipped_box[:, :, 0] + r_res = clipped_box[:, :, 2] - xs + t_res = ys - clipped_box[:, :, 1] + b_res = clipped_box[:, :, 3] - ys + clipped_box_reg_targets = np.stack([l_res, t_res, r_res, b_res], axis=2) + inside_gt_box = np.min(clipped_box_reg_targets, axis=2) > 0 + return inside_gt_box + + def __call__(self, samples, context=None): + assert len(self.object_sizes_of_interest) == len(self.downsample_ratios), \ + "object_sizes_of_interest', and 'downsample_ratios' should have same length." + + for sample in samples: + # im, gt_bbox, gt_class, gt_score = sample + im = sample['image'] + bboxes = sample['gt_bbox'] + gt_class = sample['gt_class'] + # calculate the locations + h, w = im.shape[1:3] + points, num_points_each_level = self._compute_points(w, h) + object_scale_exp = [] + for i, num_pts in enumerate(num_points_each_level): + object_scale_exp.append( + np.tile( + np.array([self.object_sizes_of_interest[i]]), + reps=[num_pts, 1])) + object_scale_exp = np.concatenate(object_scale_exp, axis=0) + + gt_area = (bboxes[:, 2] - bboxes[:, 0]) * ( + bboxes[:, 3] - bboxes[:, 1]) + xs, ys = points[:, 0], points[:, 1] + xs = np.reshape(xs, newshape=[xs.shape[0], 1]) + xs = np.tile(xs, reps=[1, bboxes.shape[0]]) + ys = np.reshape(ys, newshape=[ys.shape[0], 1]) + ys = np.tile(ys, reps=[1, bboxes.shape[0]]) + + l_res = xs - bboxes[:, 0] + r_res = bboxes[:, 2] - xs + t_res = ys - bboxes[:, 1] + b_res = bboxes[:, 3] - ys + reg_targets = np.stack([l_res, t_res, r_res, b_res], axis=2) + if self.center_sampling_radius > 0: + is_inside_box = self._check_inside_boxes_limited( + bboxes, xs, ys, num_points_each_level) + else: + is_inside_box = np.min(reg_targets, axis=2) > 0 + # check if the targets is inside the corresponding level + max_reg_targets = np.max(reg_targets, axis=2) + lower_bound = np.tile( + np.expand_dims( + object_scale_exp[:, 0], axis=1), + reps=[1, max_reg_targets.shape[1]]) + high_bound = np.tile( + np.expand_dims( + object_scale_exp[:, 1], axis=1), + reps=[1, max_reg_targets.shape[1]]) + is_match_current_level = \ + (max_reg_targets > lower_bound) & \ + (max_reg_targets < high_bound) + points2gtarea = np.tile( + np.expand_dims( + gt_area, axis=0), reps=[xs.shape[0], 1]) + points2gtarea[is_inside_box == 0] = self.INF + points2gtarea[is_match_current_level == 0] = self.INF + points2min_area = points2gtarea.min(axis=1) + points2min_area_ind = points2gtarea.argmin(axis=1) + labels = gt_class[points2min_area_ind] + 1 + labels[points2min_area == self.INF] = 0 + reg_targets = reg_targets[range(xs.shape[0]), points2min_area_ind] + ctn_targets = np.sqrt((reg_targets[:, [0, 2]].min(axis=1) / \ + reg_targets[:, [0, 2]].max(axis=1)) * \ + (reg_targets[:, [1, 3]].min(axis=1) / \ + reg_targets[:, [1, 3]].max(axis=1))).astype(np.float32) + ctn_targets = np.reshape( + ctn_targets, newshape=[ctn_targets.shape[0], 1]) + ctn_targets[labels <= 0] = 0 + pos_ind = np.nonzero(labels != 0) + reg_targets_pos = reg_targets[pos_ind[0], :] + split_sections = [] + beg = 0 + for lvl in range(len(num_points_each_level)): + end = beg + num_points_each_level[lvl] + split_sections.append(end) + beg = end + labels_by_level = 
np.split(labels, split_sections, axis=0) + reg_targets_by_level = np.split(reg_targets, split_sections, axis=0) + ctn_targets_by_level = np.split(ctn_targets, split_sections, axis=0) + for lvl in range(len(self.downsample_ratios)): + grid_w = int(np.ceil(w / self.downsample_ratios[lvl])) + grid_h = int(np.ceil(h / self.downsample_ratios[lvl])) + if self.norm_reg_targets: + sample['reg_target{}'.format(lvl)] = \ + np.reshape( + reg_targets_by_level[lvl] / \ + self.downsample_ratios[lvl], + newshape=[grid_h, grid_w, 4]) + else: + sample['reg_target{}'.format(lvl)] = np.reshape( + reg_targets_by_level[lvl], + newshape=[grid_h, grid_w, 4]) + sample['labels{}'.format(lvl)] = np.reshape( + labels_by_level[lvl], newshape=[grid_h, grid_w, 1]) + sample['centerness{}'.format(lvl)] = np.reshape( + ctn_targets_by_level[lvl], newshape=[grid_h, grid_w, 1]) + return samples + + +@register_op +class Gt2TTFTargetOp(BaseOperator): + """ + Gt2TTFTarget + Generate TTFNet targets by ground truth data + + Args: + num_classes(int): the number of classes. + down_ratio(int): the down ratio from images to heatmap, 4 by default. + alpha(float): the alpha parameter to generate gaussian target. + 0.54 by default. + """ + + def __init__(self, num_classes, down_ratio=4, alpha=0.54): + super(Gt2TTFTargetOp, self).__init__() + self.down_ratio = down_ratio + self.num_classes = num_classes + self.alpha = alpha + + def __call__(self, samples, context=None): + output_size = samples[0]['image'].shape[1] + feat_size = output_size // self.down_ratio + for sample in samples: + heatmap = np.zeros( + (self.num_classes, feat_size, feat_size), dtype='float32') + box_target = np.ones( + (4, feat_size, feat_size), dtype='float32') * -1 + reg_weight = np.zeros((1, feat_size, feat_size), dtype='float32') + + gt_bbox = sample['gt_bbox'] + gt_class = sample['gt_class'] + + bbox_w = gt_bbox[:, 2] - gt_bbox[:, 0] + 1 + bbox_h = gt_bbox[:, 3] - gt_bbox[:, 1] + 1 + area = bbox_w * bbox_h + boxes_areas_log = np.log(area) + boxes_ind = np.argsort(boxes_areas_log, axis=0)[::-1] + boxes_area_topk_log = boxes_areas_log[boxes_ind] + gt_bbox = gt_bbox[boxes_ind] + gt_class = gt_class[boxes_ind] + + feat_gt_bbox = gt_bbox / self.down_ratio + feat_gt_bbox = np.clip(feat_gt_bbox, 0, feat_size - 1) + feat_hs, feat_ws = (feat_gt_bbox[:, 3] - feat_gt_bbox[:, 1], + feat_gt_bbox[:, 2] - feat_gt_bbox[:, 0]) + + ct_inds = np.stack( + [(gt_bbox[:, 0] + gt_bbox[:, 2]) / 2, + (gt_bbox[:, 1] + gt_bbox[:, 3]) / 2], + axis=1) / self.down_ratio + + h_radiuses_alpha = (feat_hs / 2. * self.alpha).astype('int32') + w_radiuses_alpha = (feat_ws / 2. 
* self.alpha).astype('int32') + + for k in range(len(gt_bbox)): + cls_id = gt_class[k] + fake_heatmap = np.zeros((feat_size, feat_size), dtype='float32') + self.draw_truncate_gaussian(fake_heatmap, ct_inds[k], + h_radiuses_alpha[k], + w_radiuses_alpha[k]) + + heatmap[cls_id] = np.maximum(heatmap[cls_id], fake_heatmap) + box_target_inds = fake_heatmap > 0 + box_target[:, box_target_inds] = gt_bbox[k][:, None] + + local_heatmap = fake_heatmap[box_target_inds] + ct_div = np.sum(local_heatmap) + local_heatmap *= boxes_area_topk_log[k] + reg_weight[0, box_target_inds] = local_heatmap / ct_div + sample['ttf_heatmap'] = heatmap + sample['ttf_box_target'] = box_target + sample['ttf_reg_weight'] = reg_weight + return samples + + def draw_truncate_gaussian(self, heatmap, center, h_radius, w_radius): + h, w = 2 * h_radius + 1, 2 * w_radius + 1 + sigma_x = w / 6 + sigma_y = h / 6 + gaussian = gaussian2D((h, w), sigma_x, sigma_y) + + x, y = int(center[0]), int(center[1]) + + height, width = heatmap.shape[0:2] + + left, right = min(x, w_radius), min(width - x, w_radius + 1) + top, bottom = min(y, h_radius), min(height - y, h_radius + 1) + + masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] + masked_gaussian = gaussian[h_radius - top:h_radius + bottom, w_radius - + left:w_radius + right] + if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: + heatmap[y - top:y + bottom, x - left:x + right] = np.maximum( + masked_heatmap, masked_gaussian) + return heatmap + + +@register_op +class Gt2Solov2TargetOp(BaseOperator): + """Assign mask target and labels in SOLOv2 network. + Args: + num_grids (list): The list of feature map grids size. + scale_ranges (list): The list of mask boundary range. + coord_sigma (float): The coefficient of coordinate area length. + sampling_ratio (float): The ratio of down sampling. 
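+    Each sample must provide 'gt_bbox', 'gt_class' and 'gt_segm'; targets are
+    written back as 'cate_label{i}', 'ins_label{i}', 'grid_order{i}' and 'fg_num'.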
+ """ + + def __init__(self, + num_grids=[40, 36, 24, 16, 12], + scale_ranges=[[1, 96], [48, 192], [96, 384], [192, 768], + [384, 2048]], + coord_sigma=0.2, + sampling_ratio=4.0): + super(Gt2Solov2TargetOp, self).__init__() + self.num_grids = num_grids + self.scale_ranges = scale_ranges + self.coord_sigma = coord_sigma + self.sampling_ratio = sampling_ratio + + def _scale_size(self, im, scale): + h, w = im.shape[:2] + new_size = (int(w * float(scale) + 0.5), int(h * float(scale) + 0.5)) + resized_img = cv2.resize( + im, None, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) + return resized_img + + def __call__(self, samples, context=None): + sample_id = 0 + for sample in samples: + gt_bboxes_raw = sample['gt_bbox'] + gt_labels_raw = sample['gt_class'] + im_c, im_h, im_w = sample['image'].shape[:] + gt_masks_raw = sample['gt_segm'].astype(np.uint8) + mask_feat_size = [ + int(im_h / self.sampling_ratio), int(im_w / self.sampling_ratio) + ] + gt_areas = np.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * + (gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1])) + ins_ind_label_list = [] + idx = 0 + for (lower_bound, upper_bound), num_grid \ + in zip(self.scale_ranges, self.num_grids): + + hit_indices = ((gt_areas >= lower_bound) & + (gt_areas <= upper_bound)).nonzero()[0] + num_ins = len(hit_indices) + + ins_label = [] + grid_order = [] + cate_label = np.zeros([num_grid, num_grid], dtype=np.int64) + ins_ind_label = np.zeros([num_grid**2], dtype=np.bool) + + if num_ins == 0: + ins_label = np.zeros( + [1, mask_feat_size[0], mask_feat_size[1]], + dtype=np.uint8) + ins_ind_label_list.append(ins_ind_label) + sample['cate_label{}'.format(idx)] = cate_label.flatten() + sample['ins_label{}'.format(idx)] = ins_label + sample['grid_order{}'.format(idx)] = np.asarray( + [sample_id * num_grid * num_grid + 0]) + idx += 1 + continue + gt_bboxes = gt_bboxes_raw[hit_indices] + gt_labels = gt_labels_raw[hit_indices] + gt_masks = gt_masks_raw[hit_indices, ...] + + half_ws = 0.5 * ( + gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.coord_sigma + half_hs = 0.5 * ( + gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.coord_sigma + + for seg_mask, gt_label, half_h, half_w in zip( + gt_masks, gt_labels, half_hs, half_ws): + if seg_mask.sum() == 0: + continue + # mass center + upsampled_size = (mask_feat_size[0] * 4, + mask_feat_size[1] * 4) + center_h, center_w = ndimage.measurements.center_of_mass( + seg_mask) + coord_w = int( + (center_w / upsampled_size[1]) // (1. / num_grid)) + coord_h = int( + (center_h / upsampled_size[0]) // (1. / num_grid)) + + # left, top, right, down + top_box = max(0, + int(((center_h - half_h) / upsampled_size[0]) + // (1. / num_grid))) + down_box = min(num_grid - 1, + int(((center_h + half_h) / upsampled_size[0]) + // (1. / num_grid))) + left_box = max(0, + int(((center_w - half_w) / upsampled_size[1]) + // (1. / num_grid))) + right_box = min(num_grid - 1, + int(((center_w + half_w) / + upsampled_size[1]) // (1. / num_grid))) + + top = max(top_box, coord_h - 1) + down = min(down_box, coord_h + 1) + left = max(coord_w - 1, left_box) + right = min(right_box, coord_w + 1) + + cate_label[top:(down + 1), left:(right + 1)] = gt_label + seg_mask = self._scale_size( + seg_mask, scale=1. 
/ self.sampling_ratio) + for i in range(top, down + 1): + for j in range(left, right + 1): + label = int(i * num_grid + j) + cur_ins_label = np.zeros( + [mask_feat_size[0], mask_feat_size[1]], + dtype=np.uint8) + cur_ins_label[:seg_mask.shape[0], :seg_mask.shape[ + 1]] = seg_mask + ins_label.append(cur_ins_label) + ins_ind_label[label] = True + grid_order.append( + [sample_id * num_grid * num_grid + label]) + if ins_label == []: + ins_label = np.zeros( + [1, mask_feat_size[0], mask_feat_size[1]], + dtype=np.uint8) + ins_ind_label_list.append(ins_ind_label) + sample['cate_label{}'.format(idx)] = cate_label.flatten() + sample['ins_label{}'.format(idx)] = ins_label + sample['grid_order{}'.format(idx)] = np.asarray( + [sample_id * num_grid * num_grid + 0]) + else: + ins_label = np.stack(ins_label, axis=0) + ins_ind_label_list.append(ins_ind_label) + sample['cate_label{}'.format(idx)] = cate_label.flatten() + sample['ins_label{}'.format(idx)] = ins_label + sample['grid_order{}'.format(idx)] = np.asarray(grid_order) + assert len(grid_order) > 0 + idx += 1 + ins_ind_labels = np.concatenate([ + ins_ind_labels_level_img + for ins_ind_labels_level_img in ins_ind_label_list + ]) + fg_num = np.sum(ins_ind_labels) + sample['fg_num'] = fg_num + sample_id += 1 + + return samples diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py index 538344438..e60954bb9 100644 --- a/ppdet/data/transform/batch_operators.py +++ b/ppdet/data/transform/batch_operators.py @@ -24,7 +24,7 @@ except Exception: import logging import cv2 import numpy as np -from .operators import register_op, BaseOperator +from .operator import register_op, BaseOperator from .op_helper import jaccard_overlap, gaussian2D logger = logging.getLogger(__name__) diff --git a/ppdet/data/transform/operator.py b/ppdet/data/transform/operator.py new file mode 100644 index 000000000..83bc1064d --- /dev/null +++ b/ppdet/data/transform/operator.py @@ -0,0 +1,1796 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# function: +# operators to process sample, +# eg: decode/resize/crop image + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +try: + from collections.abc import Sequence +except Exception: + from collections import Sequence + +from numbers import Number + +import uuid +import logging +import random +import math +import numpy as np +import os + +import cv2 +from PIL import Image, ImageEnhance, ImageDraw + +from ppdet.core.workspace import serializable +from ppdet.modeling.layers import AnchorGrid + +from .op_helper import (satisfy_sample_constraint, filter_and_process, + generate_sample_bbox, clip_bbox, data_anchor_sampling, + satisfy_sample_constraint_coverage, crop_image_sampling, + generate_sample_bbox_square, bbox_area_sampling, + is_poly, gaussian_radius, draw_gaussian) + +logger = logging.getLogger(__name__) + +registered_ops = [] + + +def register_op(cls): + registered_ops.append(cls.__name__) + if not hasattr(BaseOperator, cls.__name__): + setattr(BaseOperator, cls.__name__, cls) + else: + raise KeyError("The {} class has been registered.".format(cls.__name__)) + return serializable(cls) + + +class BboxError(ValueError): + pass + + +class ImageError(ValueError): + pass + + +class BaseOperator(object): + def __init__(self, name=None): + if name is None: + name = self.__class__.__name__ + self._id = name + '_' + str(uuid.uuid4())[-6:] + + def apply(self, sample, context=None): + """ Process a sample. + Args: + sample (dict): a dict of sample, eg: {'image':xx, 'label': xxx} + context (dict): info about this sample processing + Returns: + result (dict): a processed sample + """ + return sample + + def __call__(self, sample, context=None): + """ Process a sample. + Args: + sample (dict): a dict of sample, eg: {'image':xx, 'label': xxx} + context (dict): info about this sample processing + Returns: + result (dict): a processed sample + """ + if isinstance(sample, Sequence): + for i in range(len(sample)): + sample[i] = self.apply(sample[i], context) + + sample = self.apply(sample, context) + return sample + + def __str__(self): + return str(self._id) + + +@register_op +class DecodeOp(BaseOperator): + def __init__(self): + """ Transform the image data to numpy format following the rgb format + """ + super(DecodeOp, self).__init__() + + def apply(self, sample, context=None): + """ load image if 'im_file' field is not empty but 'image' is""" + if 'image' not in sample: + with open(sample['im_file'], 'rb') as f: + sample['image'] = f.read() + + im = sample['image'] + data = np.frombuffer(im, dtype='uint8') + im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode + + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + + sample['image'] = im + if 'h' not in sample: + sample['h'] = im.shape[0] + elif sample['h'] != im.shape[0]: + logger.warn( + "The actual image height: {} is not equal to the " + "height: {} in annotation, and update sample['h'] by actual " + "image height.".format(im.shape[0], sample['h'])) + sample['h'] = im.shape[0] + if 'w' not in sample: + sample['w'] = im.shape[1] + elif sample['w'] != im.shape[1]: + logger.warn( + "The actual image width: {} is not equal to the " + "width: {} in annotation, and update sample['w'] by actual " + "image width.".format(im.shape[1], sample['w'])) + sample['w'] = im.shape[1] + + sample['im_shape'] = im.shape[:2] + sample['scale_factor'] = [1., 1.] 
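+        # a minimal usage sketch (file name below is hypothetical):
+        #   sample = DecodeOp()({'im_file': 'demo.jpg'})
+        #   sample['image']  # HWC uint8 array in RGB order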
+ return sample + + +@register_op +class PermuteOp(BaseOperator): + def __init__(self): + """ + Change the channel to be (C, H, W) + """ + super(PermuteOp, self).__init__() + + def apply(self, sample, context=None): + im = sample['image'] + im = im.transpose((2, 0, 1)) + sample['image'] = im + return sample + + +@register_op +class LightingOp(BaseOperator): + """ + Lighting the imagen by eigenvalues and eigenvectors + Args: + eigval (list): eigenvalues + eigvec (list): eigenvectors + alphastd (float): random weight of lighting, 0.1 by default + """ + + def __init__(self, eigval, eigvec, alphastd=0.1): + super(LightingOp, self).__init__() + self.alphastd = alphastd + self.eigval = np.array(eigval).astype('float32') + self.eigvec = np.array(eigvec).astype('float32') + + def apply(self, sample, context=None): + alpha = np.random.normal(scale=self.alphastd, size=(3, )) + sample['image'] += np.dot(self.eigvec, self.eigval * alpha) + return sample + + +@register_op +class NormalizeImageOp(BaseOperator): + def __init__(self, mean=[0.485, 0.456, 0.406], std=[1, 1, 1], + is_scale=True): + """ + Args: + mean (list): the pixel mean + std (list): the pixel variance + """ + super(NormalizeImageOp, self).__init__() + self.mean = mean + self.std = std + self.is_scale = is_scale + if not (isinstance(self.mean, list) and isinstance(self.std, list) and + isinstance(self.is_scale, bool)): + raise TypeError("{}: input type is invalid.".format(self)) + from functools import reduce + if reduce(lambda x, y: x * y, self.std) == 0: + raise ValueError('{}: std is invalid!'.format(self)) + + def apply(self, sample, context=None): + """Normalize the image. + Operators: + 1.(optional) Scale the image to [0,1] + 2. Each pixel minus mean and is divided by std + """ + im = sample['image'] + im = im.astype(np.float32, copy=False) + mean = np.array(self.mean)[np.newaxis, np.newaxis, :] + std = np.array(self.std)[np.newaxis, np.newaxis, :] + + if self.is_scale: + im = im / 255.0 + + im -= mean + im /= std + + sample['image'] = im + return sample + + +@register_op +class GridMaskOp(BaseOperator): + def __init__(self, + use_h=True, + use_w=True, + rotate=1, + offset=False, + ratio=0.5, + mode=1, + prob=0.7, + upper_iter=360000): + """ + GridMask Data Augmentation, see https://arxiv.org/abs/2001.04086 + Args: + use_h (bool): whether to mask vertically + use_w (boo;): whether to mask horizontally + rotate (float): angle for the mask to rotate + offset (float): mask offset + ratio (float): mask ratio + mode (int): gridmask mode + prob (float): max probability to carry out gridmask + upper_iter (int): suggested to be equal to global max_iter + """ + super(GridMaskOp, self).__init__() + self.use_h = use_h + self.use_w = use_w + self.rotate = rotate + self.offset = offset + self.ratio = ratio + self.mode = mode + self.prob = prob + self.upper_iter = upper_iter + + from .gridmask_utils import GridMask + self.gridmask_op = GridMask( + use_h, + use_w, + rotate=rotate, + offset=offset, + ratio=ratio, + mode=mode, + prob=prob, + upper_iter=upper_iter) + + def apply(self, sample, context=None): + sample['image'] = self.gridmask_op(sample['image'], sample['curr_iter']) + return sample + + +@register_op +class RandomDistortOp(BaseOperator): + """Random color distortion. + Args: + hue (list): hue settings. in [lower, upper, probability] format. + saturation (list): saturation settings. in [lower, upper, probability] format. + contrast (list): contrast settings. in [lower, upper, probability] format. + brightness (list): brightness settings. 
+            in [lower, upper, probability] format.
+        random_apply (bool): whether to apply in random (YOLO) or fixed (SSD)
+            order.
+        count (int): the number of distortions to apply
+        random_channel (bool): whether to swap channels randomly
+    """
+
+    def __init__(self,
+                 hue=[-18, 18, 0.5],
+                 saturation=[0.5, 1.5, 0.5],
+                 contrast=[0.5, 1.5, 0.5],
+                 brightness=[0.5, 1.5, 0.5],
+                 random_apply=True,
+                 count=4,
+                 random_channel=False):
+        super(RandomDistortOp, self).__init__()
+        self.hue = hue
+        self.saturation = saturation
+        self.contrast = contrast
+        self.brightness = brightness
+        self.random_apply = random_apply
+        self.count = count
+        self.random_channel = random_channel
+
+    def apply_hue(self, img):
+        low, high, prob = self.hue
+        if np.random.uniform(0., 1.) < prob:
+            return img
+
+        img = img.astype(np.float32)
+        img[..., 0] += random.uniform(low, high)
+        img[..., 0][img[..., 0] > 360] -= 360
+        img[..., 0][img[..., 0] < 0] += 360
+        return img
+
+    def apply_saturation(self, img):
+        low, high, prob = self.saturation
+        if np.random.uniform(0., 1.) < prob:
+            return img
+        delta = np.random.uniform(low, high)
+        img = img.astype(np.float32)
+        img[..., 1] *= delta
+        return img
+
+    def apply_contrast(self, img):
+        low, high, prob = self.contrast
+        if np.random.uniform(0., 1.) < prob:
+            return img
+        delta = np.random.uniform(low, high)
+        img = img.astype(np.float32)
+        img *= delta
+        return img
+
+    def apply_brightness(self, img):
+        low, high, prob = self.brightness
+        if np.random.uniform(0., 1.) < prob:
+            return img
+        delta = np.random.uniform(low, high)
+        img = img.astype(np.float32)
+        img += delta
+        return img
+
+    def apply(self, sample, context=None):
+        img = sample['image']
+        if self.random_apply:
+            functions = [
+                self.apply_brightness,
+                self.apply_contrast,
+                lambda img: cv2.cvtColor(self.apply_saturation(cv2.cvtColor(img, cv2.COLOR_RGB2HSV)), cv2.COLOR_HSV2RGB),
+                lambda img: cv2.cvtColor(self.apply_hue(cv2.cvtColor(img, cv2.COLOR_RGB2HSV)), cv2.COLOR_HSV2RGB),
+            ]
+            distortions = np.random.permutation(functions)[:self.count]
+            for func in distortions:
+                img = func(img)
+            sample['image'] = img
+            return sample
+
+        img = self.apply_brightness(img)
+        mode = np.random.randint(0, 2)
+
+        if mode:
+            img = self.apply_contrast(img)
+
+        img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
+        img = self.apply_saturation(img)
+        img = self.apply_hue(img)
+        img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
+
+        if not mode:
+            img = self.apply_contrast(img)
+
+        if self.random_channel:
+            if np.random.randint(0, 2):
+                img = img[..., np.random.permutation(3)]
+        sample['image'] = img
+        return sample
+
+
+@register_op
+class AutoAugmentOp(BaseOperator):
+    def __init__(self, autoaug_type="v1"):
+        """
+        Args:
+            autoaug_type (str): autoaug type, support v0, v1, v2, v3, test
+        """
+        super(AutoAugmentOp, self).__init__()
+        self.autoaug_type = autoaug_type
+
+    def apply(self, sample, context=None):
+        """
+        Learning Data Augmentation Strategies for Object Detection, see https://arxiv.org/abs/1906.11172
+        """
+        im = sample['image']
+        gt_bbox = sample['gt_bbox']
+        if not isinstance(im, np.ndarray):
+            raise TypeError("{}: image is not a numpy array.".format(self))
+        if len(im.shape) != 3:
+            raise ImageError("{}: image is not 3-dimensional.".format(self))
+        if len(gt_bbox) == 0:
+            return sample
+
+        height, width, _ = im.shape
+        norm_gt_bbox = np.ones_like(gt_bbox, dtype=np.float32)
+        norm_gt_bbox[:, 0] = gt_bbox[:, 1] / float(height)
+        norm_gt_bbox[:, 1] = gt_bbox[:, 0] / float(width)
+        norm_gt_bbox[:, 2] = gt_bbox[:, 3] / float(height)
+        norm_gt_bbox[:, 3] = gt_bbox[:, 2] / float(width)
+
+        from .autoaugment_utils import distort_image_with_autoaugment
+        im, norm_gt_bbox = distort_image_with_autoaugment(im, norm_gt_bbox,
+                                                          self.autoaug_type)
+
+        gt_bbox[:, 0] = norm_gt_bbox[:, 1] * float(width)
+        gt_bbox[:, 1] = norm_gt_bbox[:, 0] * float(height)
+        gt_bbox[:, 2] = norm_gt_bbox[:, 3] * float(width)
+        gt_bbox[:, 3] = norm_gt_bbox[:, 2] * float(height)
+
+        sample['image'] = im
+        sample['gt_bbox'] = gt_bbox
+        return sample
+
+
+@register_op
+class RandomFlipOp(BaseOperator):
+    def __init__(self, prob=0.5, is_mask_flip=False):
+        """
+        Args:
+            prob (float): the probability of flipping image
+            is_mask_flip (bool): whether flip the segmentation
+        """
+        super(RandomFlipOp, self).__init__()
+        self.prob = prob
+        self.is_mask_flip = is_mask_flip
+        if not (isinstance(self.prob, float) and
+                isinstance(self.is_mask_flip, bool)):
+            raise TypeError("{}: input type is invalid.".format(self))
+
+    def apply_segm(self, segms, height, width):
+        def _flip_poly(poly, width):
+            flipped_poly = np.array(poly)
+            flipped_poly[0::2] = width - np.array(poly[0::2]) - 1
+            return flipped_poly.tolist()
+
+        def _flip_rle(rle, height, width):
+            if 'counts' in rle and type(rle['counts']) == list:
+                rle = mask_util.frPyObjects(rle, height, width)
+            mask = mask_util.decode(rle)
+            mask = mask[:, ::-1]
+            rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
+            return rle
+
+        flipped_segms = []
+        for segm in segms:
+            if is_poly(segm):
+                # Polygon format
+                flipped_segms.append([_flip_poly(poly, width) for poly in segm])
+            else:
+                # RLE format
+                import pycocotools.mask as mask_util
+                flipped_segms.append(_flip_rle(segm, height, width))
+        return flipped_segms
+
+    def apply_keypoint(self, gt_keypoint, width):
+        for i in range(gt_keypoint.shape[1]):
+            if i % 2 == 0:
+                old_x = gt_keypoint[:, i].copy()
+                gt_keypoint[:, i] = width - old_x - 1
+        return gt_keypoint
+
+    def apply_image(self, image):
+        return image[:, ::-1, :]
+
+    def apply_bbox(self, bbox, width):
+        bbox[:, 0::2] = width - bbox[:, 0::2] - 1
+        return bbox
+
+    def apply(self, sample, context=None):
+        """Flip the image and bounding box.
+        Operators:
+            1. Flip the image numpy.
+            2. Transform the bboxes' x coordinates.
+              (Must judge whether the coordinates are normalized!)
+            3. Transform the segmentations' x coordinates.
+              (Must judge whether the coordinates are normalized!)
+        Output:
+            sample: the image, bounding box and segmentation part
+                    in sample are flipped.
+        """
+        if np.random.uniform(0, 1) < self.prob:
+            im = sample['image']
+            height, width = im.shape[:2]
+            im = self.apply_image(im)
+            if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
+                sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], width)
+            if self.is_mask_flip and 'gt_poly' in sample and len(sample[
+                    'gt_poly']) > 0:
+                sample['gt_poly'] = self.apply_segm(sample['gt_poly'], height,
+                                                    width)
+            if 'gt_keypoint' in sample and len(sample['gt_keypoint']) > 0:
+                sample['gt_keypoint'] = self.apply_keypoint(
+                    sample['gt_keypoint'], width)
+
+            if 'semantic' in sample and sample['semantic'] is not None:
+                sample['semantic'] = sample['semantic'][:, ::-1]
+
+            if 'gt_segm' in sample and sample['gt_segm'] is not None:
+                sample['gt_segm'] = sample['gt_segm'][:, :, ::-1]
+
+            sample['flipped'] = True
+            sample['image'] = im
+        return sample
+
+
+@register_op
+class ResizeOp(BaseOperator):
+    def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
+        """
+        Resize image to target size.
+        If keep_ratio is True, resize the image's long side to the maximum
+        of target_size; if keep_ratio is False, resize the image to the
+        target size (h, w).
+        Args:
+            target_size (int|list): image target size
+            keep_ratio (bool): whether to keep the aspect ratio, default True
+            interp (int): the interpolation method
+        """
+        super(ResizeOp, self).__init__()
+        self.keep_ratio = keep_ratio
+        self.interp = interp
+        if not isinstance(target_size, (int, list, tuple)):
+            raise TypeError(
+                "Type of target_size is invalid. Must be Integer or List or Tuple, now is {}".
+                format(type(target_size)))
+        if isinstance(target_size, int):
+            target_size = [target_size, target_size]
+        self.target_size = target_size
+
+    def apply_image(self, image, scale):
+        im_scale_x, im_scale_y = scale
+        return cv2.resize(
+            image,
+            None,
+            None,
+            fx=im_scale_x,
+            fy=im_scale_y,
+            interpolation=self.interp)
+
+    def apply_bbox(self, bbox, scale, size):
+        im_scale_x, im_scale_y = scale
+        resize_w, resize_h = size
+        bbox[:, 0::2] *= im_scale_x
+        bbox[:, 1::2] *= im_scale_y
+        bbox[:, 0::2] = np.clip(bbox[:, 0::2], 0, resize_w - 1)
+        bbox[:, 1::2] = np.clip(bbox[:, 1::2], 0, resize_h - 1)
+        return bbox
+
+    def apply_segm(self, segms, im_size, scale):
+        def _resize_poly(poly, im_scale_x, im_scale_y):
+            resized_poly = np.array(poly)
+            resized_poly[0::2] *= im_scale_x
+            resized_poly[1::2] *= im_scale_y
+            return resized_poly.tolist()
+
+        def _resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y):
+            if 'counts' in rle and type(rle['counts']) == list:
+                rle = mask_util.frPyObjects(rle, im_h, im_w)
+
+            mask = mask_util.decode(rle)
+            # resize the decoded mask itself, not the input image
+            mask = cv2.resize(
+                mask,
+                None,
+                None,
+                fx=im_scale_x,
+                fy=im_scale_y,
+                interpolation=self.interp)
+            rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
+            return rle
+
+        im_h, im_w = im_size
+        im_scale_x, im_scale_y = scale
+        resized_segms = []
+        for segm in segms:
+            if is_poly(segm):
+                # Polygon format
+                resized_segms.append([
+                    _resize_poly(poly, im_scale_x, im_scale_y) for poly in segm
+                ])
+            else:
+                # RLE format
+                import pycocotools.mask as mask_util
+                resized_segms.append(
+                    _resize_rle(segm, im_h, im_w, im_scale_x, im_scale_y))
+
+        return resized_segms
+
+    def apply(self, sample, context=None):
+        """ Resize the image numpy.
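+        Updates 'im_shape' and 'scale_factor', and resizes 'gt_bbox',
+        'gt_poly', 'semantic' and 'gt_segm' when present.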
+ """ + im = sample['image'] + if not isinstance(im, np.ndarray): + raise TypeError("{}: image type is not numpy.".format(self)) + if len(im.shape) != 3: + raise ImageError('{}: image is not 3-dimensional.'.format(self)) + + # apply image + im_shape = im.shape + if self.keep_ratio: + + im_size_min = np.min(im_shape[0:2]) + im_size_max = np.max(im_shape[0:2]) + + target_size_min = np.min(self.target_size) + target_size_max = np.max(self.target_size) + + im_scale = min(target_size_min / im_size_min, + target_size_max / im_size_max) + + resize_h = int(im_scale * im_shape[0]) + resize_w = int(im_scale * im_shape[1]) + + im_scale_x = im_scale + im_scale_y = im_scale + else: + resize_h, resize_w = self.target_size + im_scale_y = resize_h / im_shape[0] + im_scale_x = resize_w / im_shape[1] + + im = self.apply_image(sample['image'], [im_scale_x, im_scale_y]) + sample['image'] = im + sample['im_shape'] = [resize_h, resize_w] + scale_factor = sample['scale_factor'] + sample['scale_factor'] = [ + scale_factor[0] * im_scale_y, scale_factor[1] * im_scale_x + ] + + # apply bbox + if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0: + sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], + [im_scale_x, im_scale_y], + [resize_w, resize_h]) + + # apply polygon + if 'gt_poly' in sample and len(sample['gt_poly']) > 0: + sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_shape, + [im_scale_x, im_scale_y]) + + # apply semantic + if 'semantic' in sample and sample['semantic']: + semantic = sample['semantic'] + semantic = cv2.resize( + semantic.astype('float32'), + None, + None, + fx=im_scale_x, + fy=im_scale_y, + interpolation=self.interp) + semantic = np.asarray(semantic).astype('int32') + semantic = np.expand_dims(semantic, 0) + sample['semantic'] = semantic + + # apply gt_segm + if 'gt_segm' in sample and len(sample['gt_segm']) > 0: + masks = [ + cv2.resize( + gt_segm, + None, + None, + fx=im_scale_x, + fy=im_scale_y, + interpolation=cv2.INTER_NEAREST) + for gt_segm in sample['gt_segm'] + ] + sample['gt_segm'] = np.asarray(masks).astype(np.uint8) + + return sample + + +@register_op +class MultiscaleTestResizeOp(BaseOperator): + def __init__(self, + origin_target_size=[800, 1333], + target_size=[], + interp=cv2.INTER_LINEAR, + use_flip=True): + """ + Rescale image to the each size in target size, and capped at max_size. + Args: + origin_target_size (list): origin target size of image + target_size (list): A list of target sizes of image. + interp (int): the interpolation method. + use_flip (bool): whether use flip augmentation. + """ + super(MultiscaleTestResizeOp, self).__init__() + self.interp = interp + self.use_flip = use_flip + + if not isinstance(target_size, list): + raise TypeError( + "Type of target_size is invalid. Must be List, now is {}". + format(type(target_size))) + self.target_size = target_size + + if not isinstance(origin_target_size, list): + raise TypeError( + "Type of target_size is invalid. Must be List, now is {}". + format(type(target_size))) + + self.origin_target_size = origin_target_size + + def apply(self, sample, context=None): + """ Resize the image numpy for multi-scale test. 
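+        Returns a list of samples: one resized copy per target size, plus a
+        flipped copy when use_flip is True.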
+ """ + samples = [] + resizer = ResizeOp( + self.origin_target_size, keep_ratio=True, interp=self.interp) + samples.append(resizer(sample.copy(), context)) + if self.use_flip: + flipper = RandomFlipOp(1.1) + samples.append(flipper(sample.copy(), context=context)) + + for size in self.target_size: + resizer = ResizeOp(size, keep_ratio=True, interp=self.interp) + samples.append(resizer(sample.copy(), context)) + + return samples + + +@register_op +class RandomResizeOp(BaseOperator): + def __init__(self, + target_size, + keep_ratio=True, + interp=cv2.INTER_LINEAR, + random_size=True, + random_interp=False): + """ + Resize image to target size randomly. random target_size and interpolation method + Args: + target_size (int, list, tuple): image target size, if random size is True, must be list or tuple + keep_ratio (bool): whether keep_raio or not, default true + interp (int): the interpolation method + random_size (bool): whether random select target size of image + random_interp (bool): whether random select interpolation method + """ + super(RandomResizeOp, self).__init__() + self.keep_ratio = keep_ratio + self.interp = interp + self.interps = [ + cv2.INTER_NEAREST, + cv2.INTER_LINEAR, + cv2.INTER_AREA, + cv2.INTER_CUBIC, + cv2.INTER_LANCZOS4, + ] + assert isinstance(target_size, ( + int, Sequence)), "target_size must be int, list or tuple" + if random_size and not isinstance(target_size, list): + raise TypeError( + "Type of target_size is invalid when random_size is True. Must be List, now is {}". + format(type(target_size))) + self.target_size = target_size + self.random_size = random_size + self.random_interp = random_interp + + def apply(self, sample, context=None): + """ Resize the image numpy. + """ + if self.random_size: + target_size = random.choice(self.target_size) + else: + target_size = self.target_size + + if self.random_interp: + interp = random.choice(self.interps) + else: + interp = self.interp + + resizer = ResizeOp(target_size, self.keep_ratio, interp) + return resizer(sample, context=context) + + +@register_op +class RandomExpandOp(BaseOperator): + """Random expand the canvas. + Args: + ratio (float): maximum expansion ratio. + prob (float): probability to expand. + fill_value (list): color value used to fill the canvas. in RGB order. + """ + + def __init__(self, ratio=4., prob=0.5, fill_value=(127.5, 127.5, 127.5)): + super(RandomExpandOp, self).__init__() + assert ratio > 1.01, "expand ratio must be larger than 1.01" + self.ratio = ratio + self.prob = prob + assert isinstance(fill_value, (Number, Sequence)), \ + "fill value must be either float or sequence" + if isinstance(fill_value, Number): + fill_value = (fill_value, ) * 3 + if not isinstance(fill_value, tuple): + fill_value = tuple(fill_value) + self.fill_value = fill_value + + def apply(self, sample, context=None): + if np.random.uniform(0., 1.) < self.prob: + return sample + + im = sample['image'] + height, width = im.shape[:2] + ratio = np.random.uniform(1., self.ratio) + h = int(height * ratio) + w = int(width * ratio) + if not h > height or not w > width: + return sample + y = np.random.randint(0, h - height) + x = np.random.randint(0, w - width) + offsets, size = [x, y], [h, w] + + pad = Pad(size, pad_mode=-1, offsets=offsets) + + return pad(sample, context=context) + + +@register_op +class CropWithSampling(BaseOperator): + def __init__(self, batch_sampler, satisfy_all=False, avoid_no_bbox=True): + """ + Args: + batch_sampler (list): Multiple sets of different + parameters for cropping. 
+            satisfy_all (bool): whether all boxes must satisfy.
+            e.g. batch_sampler format:
+                [[1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0],
+                 [1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 1.0],
+                 [1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 1.0],
+                 [1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 1.0],
+                 [1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 1.0],
+                 [1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 1.0],
+                 [1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0]]
+                [max sample, max trial, min scale, max scale,
+                 min aspect ratio, max aspect ratio,
+                 min overlap, max overlap]
+            avoid_no_bbox (bool): whether to avoid the
+                situation where the box does not appear.
+        """
+        super(CropWithSampling, self).__init__()
+        self.batch_sampler = batch_sampler
+        self.satisfy_all = satisfy_all
+        self.avoid_no_bbox = avoid_no_bbox
+
+    def apply(self, sample, context):
+        """
+        Crop the image and modify bounding box.
+        Operators:
+            1. Scale the image width and height.
+            2. Crop the image according to a random sample.
+            3. Rescale the bounding box.
+            4. Determine if the new bbox is satisfied in the new image.
+        Returns:
+            sample: the image and bounding boxes are replaced.
+        """
+        assert 'image' in sample, "image data not found"
+        im = sample['image']
+        gt_bbox = sample['gt_bbox']
+        gt_class = sample['gt_class']
+        im_height, im_width = im.shape[:2]
+        gt_score = None
+        if 'gt_score' in sample:
+            gt_score = sample['gt_score']
+        sampled_bbox = []
+        gt_bbox = gt_bbox.tolist()
+        for sampler in self.batch_sampler:
+            found = 0
+            for i in range(sampler[1]):
+                if found >= sampler[0]:
+                    break
+                sample_bbox = generate_sample_bbox(sampler)
+                if satisfy_sample_constraint(sampler, sample_bbox, gt_bbox,
+                                             self.satisfy_all):
+                    sampled_bbox.append(sample_bbox)
+                    found = found + 1
+        im = np.array(im)
+        while sampled_bbox:
+            idx = int(np.random.uniform(0, len(sampled_bbox)))
+            sample_bbox = sampled_bbox.pop(idx)
+            sample_bbox = clip_bbox(sample_bbox)
+            crop_bbox, crop_class, crop_score = \
+                filter_and_process(sample_bbox, gt_bbox, gt_class, scores=gt_score)
+            if self.avoid_no_bbox:
+                if len(crop_bbox) < 1:
+                    continue
+            xmin = int(sample_bbox[0] * im_width)
+            xmax = int(sample_bbox[2] * im_width)
+            ymin = int(sample_bbox[1] * im_height)
+            ymax = int(sample_bbox[3] * im_height)
+            im = im[ymin:ymax, xmin:xmax]
+            sample['image'] = im
+            sample['gt_bbox'] = crop_bbox
+            sample['gt_class'] = crop_class
+            sample['gt_score'] = crop_score
+            return sample
+        return sample
+
+
+@register_op
+class CropWithDataAchorSampling(BaseOperator):
+    def __init__(self,
+                 batch_sampler,
+                 anchor_sampler=None,
+                 target_size=None,
+                 das_anchor_scales=[16, 32, 64, 128],
+                 sampling_prob=0.5,
+                 min_size=8.,
+                 avoid_no_bbox=True):
+        """
+        Args:
+            anchor_sampler (list): anchor_sampling sets of different
+                parameters for cropping.
+                e.g. [[1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 0.0]]
+            batch_sampler (list): Multiple sets of different
+                parameters for cropping.
+                e.g. [[1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
+                      [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
+                      [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
+                      [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
+                      [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]]
+                [max sample, max trial, min scale, max scale,
+                 min aspect ratio, max aspect ratio,
+                 min overlap, max overlap, min coverage, max coverage]
+            target_size (int): target image size.
+            das_anchor_scales (list[float]): a list of anchor scales in data
+                anchor sampling.
+            min_size (float): minimum size of sampled bbox.
+            avoid_no_bbox (bool): whether to avoid the
+                situation where the box does not appear.
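+            sampling_prob (float): threshold in [0, 1] deciding which branch
+                runs; draws above it use data-anchor sampling, others use the
+                plain batch samplers.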
+ """ + super(CropWithDataAchorSampling, self).__init__() + self.anchor_sampler = anchor_sampler + self.batch_sampler = batch_sampler + self.target_size = target_size + self.sampling_prob = sampling_prob + self.min_size = min_size + self.avoid_no_bbox = avoid_no_bbox + self.das_anchor_scales = np.array(das_anchor_scales) + + def apply(self, sample, context): + """ + Crop the image and modify bounding box. + Operators: + 1. Scale the image width and height. + 2. Crop the image according to a radom sample. + 3. Rescale the bounding box. + 4. Determine if the new bbox is satisfied in the new image. + Returns: + sample: the image, bounding box are replaced. + """ + assert 'image' in sample, "image data not found" + im = sample['image'] + gt_bbox = sample['gt_bbox'] + gt_class = sample['gt_class'] + image_height, image_width = im.shape[:2] + gt_score = None + if 'gt_score' in sample: + gt_score = sample['gt_score'] + sampled_bbox = [] + gt_bbox = gt_bbox.tolist() + + prob = np.random.uniform(0., 1.) + if prob > self.sampling_prob: # anchor sampling + assert self.anchor_sampler + for sampler in self.anchor_sampler: + found = 0 + for i in range(sampler[1]): + if found >= sampler[0]: + break + sample_bbox = data_anchor_sampling( + gt_bbox, image_width, image_height, + self.das_anchor_scales, self.target_size) + if sample_bbox == 0: + break + if satisfy_sample_constraint_coverage(sampler, sample_bbox, + gt_bbox): + sampled_bbox.append(sample_bbox) + found = found + 1 + im = np.array(im) + while sampled_bbox: + idx = int(np.random.uniform(0, len(sampled_bbox))) + sample_bbox = sampled_bbox.pop(idx) + + if 'gt_keypoint' in sample.keys(): + keypoints = (sample['gt_keypoint'], + sample['keypoint_ignore']) + crop_bbox, crop_class, crop_score, gt_keypoints = \ + filter_and_process(sample_bbox, gt_bbox, gt_class, + scores=gt_score, + keypoints=keypoints) + else: + crop_bbox, crop_class, crop_score = filter_and_process( + sample_bbox, gt_bbox, gt_class, scores=gt_score) + crop_bbox, crop_class, crop_score = bbox_area_sampling( + crop_bbox, crop_class, crop_score, self.target_size, + self.min_size) + + if self.avoid_no_bbox: + if len(crop_bbox) < 1: + continue + im = crop_image_sampling(im, sample_bbox, image_width, + image_height, self.target_size) + sample['image'] = im + sample['gt_bbox'] = crop_bbox + sample['gt_class'] = crop_class + sample['gt_score'] = crop_score + if 'gt_keypoint' in sample.keys(): + sample['gt_keypoint'] = gt_keypoints[0] + sample['keypoint_ignore'] = gt_keypoints[1] + return sample + return sample + + else: + for sampler in self.batch_sampler: + found = 0 + for i in range(sampler[1]): + if found >= sampler[0]: + break + sample_bbox = generate_sample_bbox_square( + sampler, image_width, image_height) + if satisfy_sample_constraint_coverage(sampler, sample_bbox, + gt_bbox): + sampled_bbox.append(sample_bbox) + found = found + 1 + im = np.array(im) + while sampled_bbox: + idx = int(np.random.uniform(0, len(sampled_bbox))) + sample_bbox = sampled_bbox.pop(idx) + sample_bbox = clip_bbox(sample_bbox) + + if 'gt_keypoint' in sample.keys(): + keypoints = (sample['gt_keypoint'], + sample['keypoint_ignore']) + crop_bbox, crop_class, crop_score, gt_keypoints = \ + filter_and_process(sample_bbox, gt_bbox, gt_class, + scores=gt_score, + keypoints=keypoints) + else: + crop_bbox, crop_class, crop_score = filter_and_process( + sample_bbox, gt_bbox, gt_class, scores=gt_score) + # sampling bbox according the bbox area + crop_bbox, crop_class, crop_score = bbox_area_sampling( + crop_bbox, 
+
+                if self.avoid_no_bbox:
+                    if len(crop_bbox) < 1:
+                        continue
+                xmin = int(sample_bbox[0] * image_width)
+                xmax = int(sample_bbox[2] * image_width)
+                ymin = int(sample_bbox[1] * image_height)
+                ymax = int(sample_bbox[3] * image_height)
+                im = im[ymin:ymax, xmin:xmax]
+                sample['image'] = im
+                sample['gt_bbox'] = crop_bbox
+                sample['gt_class'] = crop_class
+                sample['gt_score'] = crop_score
+                if 'gt_keypoint' in sample.keys():
+                    sample['gt_keypoint'] = gt_keypoints[0]
+                    sample['keypoint_ignore'] = gt_keypoints[1]
+                return sample
+            return sample
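# A rough sketch of the data-anchor-sampling idea used above (PyramidBox,
# https://arxiv.org/abs/1803.07737): pick one gt box, measure its scale,
# and resample the crop so that box lands near one of das_anchor_scales.
# The actual implementation is op_helper.data_anchor_sampling; the helper
# below is a hypothetical illustration and assumes gt_bbox is a non-empty
# array of pixel-coordinate boxes.
import numpy as np


def sketch_anchor_scale(gt_bbox, das_anchor_scales=(16, 32, 64, 128)):
    # pick one ground-truth box and measure its scale (sqrt of area)
    box = gt_bbox[np.random.randint(len(gt_bbox))]
    box_scale = np.sqrt((box[2] - box[0]) * (box[3] - box[1]))
    # index of the closest anchor scale
    idx = int(np.argmin(np.abs(np.asarray(das_anchor_scales) - box_scale)))
    # re-sample among scales up to one step above the closest one, which
    # tends to upsample small objects
    choose = np.random.randint(0, min(idx + 2, len(das_anchor_scales)))
    return das_anchor_scales[choose]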
logging.getLogger("shapely").setLevel(logging.WARNING) + # Polygon format + crop_segms.append(_crop_poly(segm, crop)) + else: + # RLE format + import pycocotools.mask as mask_util + crop_segms.append(_crop_rle(segm, crop, height, width)) + return crop_segms + + def apply(self, sample, context=None): + if 'gt_bbox' in sample and len(sample['gt_bbox']) == 0: + return sample + + h, w = sample['image'].shape[:2] + gt_bbox = sample['gt_bbox'] + + # NOTE Original method attempts to generate one candidate for each + # threshold then randomly sample one from the resulting list. + # Here a short circuit approach is taken, i.e., randomly choose a + # threshold and attempt to find a valid crop, and simply return the + # first one found. + # The probability is not exactly the same, kinda resembling the + # "Monty Hall" problem. Actually carrying out the attempts will affect + # observability (just like opening doors in the "Monty Hall" game). + thresholds = list(self.thresholds) + if self.allow_no_crop: + thresholds.append('no_crop') + np.random.shuffle(thresholds) + + for thresh in thresholds: + if thresh == 'no_crop': + return sample + + found = False + for i in range(self.num_attempts): + scale = np.random.uniform(*self.scaling) + if self.aspect_ratio is not None: + min_ar, max_ar = self.aspect_ratio + aspect_ratio = np.random.uniform( + max(min_ar, scale**2), min(max_ar, scale**-2)) + h_scale = scale / np.sqrt(aspect_ratio) + w_scale = scale * np.sqrt(aspect_ratio) + else: + h_scale = np.random.uniform(*self.scaling) + w_scale = np.random.uniform(*self.scaling) + crop_h = h * h_scale + crop_w = w * w_scale + if self.aspect_ratio is None: + if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0: + continue + + crop_h = int(crop_h) + crop_w = int(crop_w) + crop_y = np.random.randint(0, h - crop_h) + crop_x = np.random.randint(0, w - crop_w) + crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h] + iou = self._iou_matrix( + gt_bbox, np.array( + [crop_box], dtype=np.float32)) + if iou.max() < thresh: + continue + + if self.cover_all_box and iou.min() < thresh: + continue + + cropped_box, valid_ids = self._crop_box_with_center_constraint( + gt_bbox, np.array( + crop_box, dtype=np.float32)) + if valid_ids.size > 0: + found = True + break + + if found: + if self.is_mask_crop and 'gt_poly' in sample and len(sample[ + 'gt_poly']) > 0: + crop_polys = self.crop_segms( + sample['gt_poly'], + valid_ids, + np.array( + crop_box, dtype=np.int64), + h, + w) + if [] in crop_polys: + delete_id = list() + valid_polys = list() + for id, crop_poly in enumerate(crop_polys): + if crop_poly == []: + delete_id.append(id) + else: + valid_polys.append(crop_poly) + valid_ids = np.delete(valid_ids, delete_id) + if len(valid_polys) == 0: + return sample + sample['gt_poly'] = valid_polys + else: + sample['gt_poly'] = crop_polys + + if 'gt_segm' in sample: + sample['gt_segm'] = self._crop_segm(sample['gt_segm'], + crop_box) + sample['gt_segm'] = np.take( + sample['gt_segm'], valid_ids, axis=0) + + sample['image'] = self._crop_image(sample['image'], crop_box) + sample['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0) + sample['gt_class'] = np.take( + sample['gt_class'], valid_ids, axis=0) + if 'gt_score' in sample: + sample['gt_score'] = np.take( + sample['gt_score'], valid_ids, axis=0) + + if 'is_crowd' in sample: + sample['is_crowd'] = np.take( + sample['is_crowd'], valid_ids, axis=0) + return sample + + return sample + + def _iou_matrix(self, a, b): + tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + br_i = 
+        br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+        area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
+        area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+        area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
+        area_o = (area_a[:, np.newaxis] + area_b - area_i)
+        return area_i / (area_o + 1e-10)
+
+    def _crop_box_with_center_constraint(self, box, crop):
+        cropped_box = box.copy()
+
+        cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2])
+        cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:])
+        cropped_box[:, :2] -= crop[:2]
+        cropped_box[:, 2:] -= crop[:2]
+
+        centers = (box[:, :2] + box[:, 2:]) / 2
+        valid = np.logical_and(crop[:2] <= centers,
+                               centers < crop[2:]).all(axis=1)
+        valid = np.logical_and(
+            valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
+
+        return cropped_box, np.where(valid)[0]
+
+    def _crop_image(self, img, crop):
+        x1, y1, x2, y2 = crop
+        return img[y1:y2, x1:x2, :]
+
+    def _crop_segm(self, segm, crop):
+        x1, y1, x2, y2 = crop
+        return segm[:, y1:y2, x1:x2]
+
+
+@register_op
+class RandomScaledCropOp(BaseOperator):
+    """Resize image and bbox based on long side (with optional random scaling),
+    then crop or pad image to target size.
+    Args:
+        target_dim (int): target size.
+        scale_range (list): random scale range.
+        interp (int): interpolation method, default to `cv2.INTER_LINEAR`.
+    """
+
+    def __init__(self,
+                 target_dim=512,
+                 scale_range=[.1, 2.],
+                 interp=cv2.INTER_LINEAR):
+        super(RandomScaledCropOp, self).__init__()
+        self.target_dim = target_dim
+        self.scale_range = scale_range
+        self.interp = interp
+
+    def apply(self, sample, context=None):
+        img = sample['image']
+        h, w = img.shape[:2]
+        random_scale = np.random.uniform(*self.scale_range)
+        dim = self.target_dim
+        random_dim = int(dim * random_scale)
+        dim_max = max(h, w)
+        scale = random_dim / dim_max
+        resize_w = int(round(w * scale))
+        resize_h = int(round(h * scale))
+        offset_x = int(max(0, np.random.uniform(0., resize_w - dim)))
+        offset_y = int(max(0, np.random.uniform(0., resize_h - dim)))
+
+        img = cv2.resize(img, (resize_w, resize_h), interpolation=self.interp)
+        img = np.array(img)
+        canvas = np.zeros((dim, dim, 3), dtype=img.dtype)
+        canvas[:min(dim, resize_h), :min(dim, resize_w), :] = img[
+            offset_y:offset_y + dim, offset_x:offset_x + dim, :]
+        sample['image'] = canvas
+        sample['im_shape'] = [resize_h, resize_w]
+        scale_factor = sample['scale_factor']
+        sample['scale_factor'] = [
+            scale_factor[0] * scale, scale_factor[1] * scale
+        ]
+
+        if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
+            scale_array = np.array([scale, scale] * 2, dtype=np.float32)
+            shift_array = np.array([offset_x, offset_y] * 2, dtype=np.float32)
+            boxes = sample['gt_bbox'] * scale_array - shift_array
+            boxes = np.clip(boxes, 0, dim - 1)
+            # filter boxes with no area
+            area = np.prod(boxes[..., 2:] - boxes[..., :2], axis=1)
+            valid = (area > 1.).nonzero()[0]
+            sample['gt_bbox'] = boxes[valid]
+            sample['gt_class'] = sample['gt_class'][valid]
+
+        return sample
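# The resize/offset arithmetic above, traced with hypothetical numbers:
# a 480x640 image, target_dim=512, and a drawn random_scale of 1.5.
import numpy as np


def sketch_scaled_crop_dims(h=480, w=640, dim=512, random_scale=1.5):
    random_dim = int(dim * random_scale)  # 768
    scale = random_dim / max(h, w)        # 768 / 640 = 1.2
    resize_h = int(round(h * scale))      # 576
    resize_w = int(round(w * scale))      # 768
    # a dim x dim window is cut from the resized image at a random offset;
    # where the resized side is shorter than dim the canvas keeps zeros
    offset_x = int(max(0, np.random.uniform(0., resize_w - dim)))  # [0, 256]
    offset_y = int(max(0, np.random.uniform(0., resize_h - dim)))  # [0, 64]
    return resize_h, resize_w, offset_x, offset_y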
+
+
+@register_op
+class CutmixOp(BaseOperator):
+    def __init__(self, alpha=1.5, beta=1.5):
+        """
+        CutMix: Regularization Strategy to Train Strong Classifiers with
+        Localizable Features, see https://arxiv.org/abs/1905.04899
+        Cutmix image and gt_bbox/gt_score
+        Args:
+            alpha (float): alpha parameter of the Beta distribution
+            beta (float): beta parameter of the Beta distribution
+        """
+        super(CutmixOp, self).__init__()
+        self.alpha = alpha
+        self.beta = beta
+        if self.alpha <= 0.0:
+            raise ValueError("alpha should be positive in {}".format(self))
+        if self.beta <= 0.0:
+            raise ValueError("beta should be positive in {}".format(self))
+
+    def apply_image(self, img1, img2, factor):
+        """Cut a random region from img2 and paste it onto img1."""
+        h = max(img1.shape[0], img2.shape[0])
+        w = max(img1.shape[1], img2.shape[1])
+        cut_rat = np.sqrt(1. - factor)
+
+        cut_w = int(w * cut_rat)
+        cut_h = int(h * cut_rat)
+
+        # uniform
+        cx = np.random.randint(w)
+        cy = np.random.randint(h)
+
+        bbx1 = np.clip(cx - cut_w // 2, 0, w - 1)
+        bby1 = np.clip(cy - cut_h // 2, 0, h - 1)
+        bbx2 = np.clip(cx + cut_w // 2, 0, w - 1)
+        bby2 = np.clip(cy + cut_h // 2, 0, h - 1)
+
+        img_1 = np.zeros((h, w, img1.shape[2]), 'float32')
+        img_1[:img1.shape[0], :img1.shape[1], :] = \
+            img1.astype('float32')
+        img_2 = np.zeros((h, w, img2.shape[2]), 'float32')
+        img_2[:img2.shape[0], :img2.shape[1], :] = \
+            img2.astype('float32')
+        img_1[bby1:bby2, bbx1:bbx2, :] = img_2[bby1:bby2, bbx1:bbx2, :]
+        return img_1
+
+    def __call__(self, sample, context=None):
+        if not isinstance(sample, Sequence):
+            return sample
+
+        assert len(sample) == 2, 'cutmix needs two samples'
+
+        factor = np.random.beta(self.alpha, self.beta)
+        factor = max(0.0, min(1.0, factor))
+        if factor >= 1.0:
+            return sample[0]
+        if factor <= 0.0:
+            return sample[1]
+        img1 = sample[0]['image']
+        img2 = sample[1]['image']
+        img = self.apply_image(img1, img2, factor)
+        gt_bbox1 = sample[0]['gt_bbox']
+        gt_bbox2 = sample[1]['gt_bbox']
+        gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
+        gt_class1 = sample[0]['gt_class']
+        gt_class2 = sample[1]['gt_class']
+        gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
+        gt_score1 = sample[0]['gt_score']
+        gt_score2 = sample[1]['gt_score']
+        gt_score = np.concatenate(
+            (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
+        sample = sample[0]
+        sample['image'] = img
+        sample['gt_bbox'] = gt_bbox
+        sample['gt_score'] = gt_score
+        sample['gt_class'] = gt_class
+        return sample
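# Concretely, the CutMix pasted region covers a (1 - factor) fraction of
# the canvas area. A hypothetical trace with factor = 0.7 on a 416x416
# canvas:
import numpy as np


def sketch_cut_box(factor=0.7, h=416, w=416):
    cut_rat = np.sqrt(1. - factor)                      # ~0.548
    cut_w, cut_h = int(w * cut_rat), int(h * cut_rat)   # 227 x 227 region
    cx, cy = np.random.randint(w), np.random.randint(h)
    bbx1 = np.clip(cx - cut_w // 2, 0, w - 1)
    bby1 = np.clip(cy - cut_h // 2, 0, h - 1)
    bbx2 = np.clip(cx + cut_w // 2, 0, w - 1)
    bby2 = np.clip(cy + cut_h // 2, 0, h - 1)
    # gt scores of the two sources are weighted factor and (1 - factor)
    return bbx1, bby1, bbx2, bby2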
+
+
+@register_op
+class MixupOp(BaseOperator):
+    def __init__(self, alpha=1.5, beta=1.5):
+        """ Mixup image and gt_bbox/gt_score
+        Args:
+            alpha (float): alpha parameter of the Beta distribution
+            beta (float): beta parameter of the Beta distribution
+        """
+        super(MixupOp, self).__init__()
+        self.alpha = alpha
+        self.beta = beta
+        if self.alpha <= 0.0:
+            raise ValueError("alpha should be positive in {}".format(self))
+        if self.beta <= 0.0:
+            raise ValueError("beta should be positive in {}".format(self))
+
+    def apply_image(self, img1, img2, factor):
+        h = max(img1.shape[0], img2.shape[0])
+        w = max(img1.shape[1], img2.shape[1])
+        img = np.zeros((h, w, img1.shape[2]), 'float32')
+        img[:img1.shape[0], :img1.shape[1], :] = \
+            img1.astype('float32') * factor
+        img[:img2.shape[0], :img2.shape[1], :] += \
+            img2.astype('float32') * (1.0 - factor)
+        return img.astype('uint8')
+
+    def __call__(self, sample, context=None):
+        if not isinstance(sample, Sequence):
+            return sample
+
+        assert len(sample) == 2, 'mixup needs two samples'
+
+        factor = np.random.beta(self.alpha, self.beta)
+        factor = max(0.0, min(1.0, factor))
+        if factor >= 1.0:
+            return sample[0]
+        if factor <= 0.0:
+            return sample[1]
+        im = self.apply_image(sample[0]['image'], sample[1]['image'], factor)
+        # apply bbox and score
+        gt_bbox1 = sample[0]['gt_bbox']
+        gt_bbox2 = sample[1]['gt_bbox']
+        gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
+        gt_class1 = sample[0]['gt_class']
+        gt_class2 = sample[1]['gt_class']
+        gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
+
+        gt_score1 = sample[0]['gt_score']
+        gt_score2 = sample[1]['gt_score']
+        gt_score = np.concatenate(
+            (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
+
+        is_crowd1 = sample[0]['is_crowd']
+        is_crowd2 = sample[1]['is_crowd']
+        is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
+
+        sample = sample[0]
+        sample['image'] = im
+        sample['gt_bbox'] = gt_bbox
+        sample['gt_score'] = gt_score
+        sample['gt_class'] = gt_class
+        sample['is_crowd'] = is_crowd
+        return sample
+
+
+@register_op
+class NormalizeBoxOp(BaseOperator):
+    """Transform the bounding box's coordinates to [0, 1]."""
+
+    def __init__(self):
+        super(NormalizeBoxOp, self).__init__()
+
+    def apply(self, sample, context):
+        im = sample['image']
+        gt_bbox = sample['gt_bbox']
+        height, width, _ = im.shape
+        for i in range(gt_bbox.shape[0]):
+            gt_bbox[i][0] = gt_bbox[i][0] / width
+            gt_bbox[i][1] = gt_bbox[i][1] / height
+            gt_bbox[i][2] = gt_bbox[i][2] / width
+            gt_bbox[i][3] = gt_bbox[i][3] / height
+        sample['gt_bbox'] = gt_bbox
+
+        if 'gt_keypoint' in sample.keys():
+            gt_keypoint = sample['gt_keypoint']
+
+            for i in range(gt_keypoint.shape[1]):
+                if i % 2:
+                    gt_keypoint[:, i] = gt_keypoint[:, i] / height
+                else:
+                    gt_keypoint[:, i] = gt_keypoint[:, i] / width
+            sample['gt_keypoint'] = gt_keypoint
+
+        return sample
+
+
+@register_op
+class BboxXYXY2XYWH(BaseOperator):
+    """
+    Convert bbox from the (xmin, ymin, xmax, ymax) format to the
+    (center_x, center_y, width, height) format.
+    """
+
+    def __init__(self):
+        super(BboxXYXY2XYWH, self).__init__()
+
+    def apply(self, sample, context=None):
+        assert 'gt_bbox' in sample
+        bbox = sample['gt_bbox']
+        bbox[:, 2:4] = bbox[:, 2:4] - bbox[:, :2]
+        bbox[:, :2] = bbox[:, :2] + bbox[:, 2:4] / 2.
+        sample['gt_bbox'] = bbox
+        return sample
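# A quick numeric check of the corner-to-center conversion above:
import numpy as np


def check_xyxy2xywh():
    bbox = np.array([[10., 20., 50., 80.]])        # xmin, ymin, xmax, ymax
    bbox[:, 2:4] = bbox[:, 2:4] - bbox[:, :2]      # width, height -> 40, 60
    bbox[:, :2] = bbox[:, :2] + bbox[:, 2:4] / 2.  # center -> 30, 50
    return bbox                                    # [[30., 50., 40., 60.]]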
+ """ + + def __init__(self, output_dir='output/debug', is_normalized=False): + super(DebugVisibleImageOp, self).__init__() + self.is_normalized = is_normalized + self.output_dir = output_dir + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + if not isinstance(self.is_normalized, bool): + raise TypeError("{}: input type is invalid.".format(self)) + + def apply(self, sample, context=None): + image = Image.open(sample['im_file']).convert('RGB') + out_file_name = sample['im_file'].split('/')[-1] + width = sample['w'] + height = sample['h'] + gt_bbox = sample['gt_bbox'] + gt_class = sample['gt_class'] + draw = ImageDraw.Draw(image) + for i in range(gt_bbox.shape[0]): + if self.is_normalized: + gt_bbox[i][0] = gt_bbox[i][0] * width + gt_bbox[i][1] = gt_bbox[i][1] * height + gt_bbox[i][2] = gt_bbox[i][2] * width + gt_bbox[i][3] = gt_bbox[i][3] * height + + xmin, ymin, xmax, ymax = gt_bbox[i] + draw.line( + [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), + (xmin, ymin)], + width=2, + fill='green') + # draw label + text = str(gt_class[i][0]) + tw, th = draw.textsize(text) + draw.rectangle( + [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill='green') + draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255)) + + if 'gt_keypoint' in sample.keys(): + gt_keypoint = sample['gt_keypoint'] + if self.is_normalized: + for i in range(gt_keypoint.shape[1]): + if i % 2: + gt_keypoint[:, i] = gt_keypoint[:, i] * height + else: + gt_keypoint[:, i] = gt_keypoint[:, i] * width + for i in range(gt_keypoint.shape[0]): + keypoint = gt_keypoint[i] + for j in range(int(keypoint.shape[0] / 2)): + x1 = round(keypoint[2 * j]).astype(np.int32) + y1 = round(keypoint[2 * j + 1]).astype(np.int32) + draw.ellipse( + (x1, y1, x1 + 5, y1 + 5), fill='green', outline='green') + save_path = os.path.join(self.output_dir, out_file_name) + image.save(save_path, quality=95) + return sample + + +@register_op +class Pad(BaseOperator): + def __init__(self, + size=None, + size_divisor=32, + pad_mode=0, + offsets=None, + fill_value=(127.5, 127.5, 127.5)): + """ + Pad image to a specified size or multiple of size_divisor. random target_size and interpolation method + Args: + size (int, Sequence): image target size, if None, pad to multiple of size_divisor, default None + size_divisor (int): size divisor, default 32 + pad_mode (int): pad mode, currently only supports four modes [-1, 0, 1, 2]. if -1, use specified offsets + if 0, only pad to right and bottom. if 1, pad according to center. if 2, only pad left and top + fill_value (bool): rgb value of pad area, default (127.5, 127.5, 127.5) + """ + super(Pad, self).__init__() + + if not isinstance(size, (int, Sequence)): + raise TypeError( + "Type of target_size is invalid when random_size is True. 
+
+
+@register_op
+class Pad(BaseOperator):
+    def __init__(self,
+                 size=None,
+                 size_divisor=32,
+                 pad_mode=0,
+                 offsets=None,
+                 fill_value=(127.5, 127.5, 127.5)):
+        """
+        Pad image to a specified size or to a multiple of size_divisor.
+        Args:
+            size (int, Sequence): image target size; if None, pad to a
+                multiple of size_divisor. Default None.
+            size_divisor (int): size divisor, default 32
+            pad_mode (int): pad mode, currently only supports four modes
+                [-1, 0, 1, 2]. if -1, use specified offsets; if 0, only pad
+                to right and bottom; if 1, pad according to center; if 2,
+                only pad left and top
+            fill_value (tuple): RGB value of the padded area, default
+                (127.5, 127.5, 127.5)
+        """
+        super(Pad, self).__init__()
+
+        if size is not None and not isinstance(size, (int, Sequence)):
+            raise TypeError(
+                "Type of size is invalid. Must be int or Sequence, "
+                "now is {}".format(type(size)))
+
+        if isinstance(size, int):
+            size = [size, size]
+
+        assert pad_mode in [
+            -1, 0, 1, 2
+        ], 'currently only supports four modes [-1, 0, 1, 2]'
+        if pad_mode == -1:
+            assert offsets, 'if pad_mode is -1, offsets should not be None'
+
+        self.size = size
+        self.size_divisor = size_divisor
+        self.pad_mode = pad_mode
+        self.fill_value = fill_value
+        self.offsets = offsets
+
+    def apply_segm(self, segms, offsets, im_size, size):
+        def _expand_poly(poly, x, y):
+            expanded_poly = np.array(poly)
+            expanded_poly[0::2] += x
+            expanded_poly[1::2] += y
+            return expanded_poly.tolist()
+
+        def _expand_rle(rle, x, y, height, width, h, w):
+            if 'counts' in rle and type(rle['counts']) == list:
+                rle = mask_util.frPyObjects(rle, height, width)
+            mask = mask_util.decode(rle)
+            expanded_mask = np.full((h, w), 0).astype(mask.dtype)
+            expanded_mask[y:y + height, x:x + width] = mask
+            rle = mask_util.encode(
+                np.array(
+                    expanded_mask, order='F', dtype=np.uint8))
+            return rle
+
+        x, y = offsets
+        height, width = im_size
+        h, w = size
+        expanded_segms = []
+        for segm in segms:
+            if is_poly(segm):
+                # Polygon format
+                expanded_segms.append(
+                    [_expand_poly(poly, x, y) for poly in segm])
+            else:
+                # RLE format
+                import pycocotools.mask as mask_util
+                expanded_segms.append(
+                    _expand_rle(segm, x, y, height, width, h, w))
+        return expanded_segms
+
+    def apply_bbox(self, bbox, offsets):
+        return bbox + np.array(offsets * 2, dtype=np.float32)
+
+    def apply_keypoint(self, keypoints, offsets):
+        n = len(keypoints[0]) // 2
+        return keypoints + np.array(offsets * n, dtype=np.float32)
+
+    def apply_image(self, image, offsets, im_size, size):
+        x, y = offsets
+        im_h, im_w = im_size
+        h, w = size
+        canvas = np.ones((h, w, 3), dtype=np.uint8)
+        canvas *= np.array(self.fill_value, dtype=np.uint8)
+        canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.uint8)
+        return canvas
+
+    def apply(self, sample, context=None):
+        im = sample['image']
+        im_h, im_w = im.shape[:2]
+        if self.size:
+            h, w = self.size
+            assert (
+                im_h <= h and im_w <= w
+            ), '(h, w) of target size should be no less than (im_h, im_w)'
+        else:
+            h = int(np.ceil(im_h / self.size_divisor) * self.size_divisor)
+            w = int(np.ceil(im_w / self.size_divisor) * self.size_divisor)
+
+        if h == im_h and w == im_w:
+            return sample
+
+        if self.pad_mode == -1:
+            offset_x, offset_y = self.offsets
+        elif self.pad_mode == 0:
+            offset_y, offset_x = 0, 0
+        elif self.pad_mode == 1:
+            offset_y, offset_x = (h - im_h) // 2, (w - im_w) // 2
+        else:
+            offset_y, offset_x = h - im_h, w - im_w
+
+        offsets, im_size, size = [offset_x, offset_y], [im_h, im_w], [h, w]
+
+        sample['image'] = self.apply_image(im, offsets, im_size, size)
+
+        if self.pad_mode == 0:
+            return sample
+        if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
+            sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], offsets)
+
+        if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
+            sample['gt_poly'] = self.apply_segm(sample['gt_poly'], offsets,
+                                                im_size, size)
+
+        if 'gt_keypoint' in sample and len(sample['gt_keypoint']) > 0:
+            sample['gt_keypoint'] = self.apply_keypoint(sample['gt_keypoint'],
+                                                        offsets)
+
+        return sample
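# The size_divisor branch of Pad.apply, traced with hypothetical numbers:
# a 500x375 image padded up to the next multiple of 32 on each side.
import numpy as np


def sketch_pad_dims(im_h=500, im_w=375, size_divisor=32):
    h = int(np.ceil(im_h / size_divisor) * size_divisor)   # 512
    w = int(np.ceil(im_w / size_divisor) * size_divisor)   # 384
    # pad_mode 1 centers the image on the canvas
    offset_y, offset_x = (h - im_h) // 2, (w - im_w) // 2  # 6, 4
    return h, w, offset_x, offset_y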
+
+
+@register_op
+class Poly2Mask(BaseOperator):
+    """
+    Convert ground-truth polygons to mask annotations.
+    """
+
+    def __init__(self):
+        super(Poly2Mask, self).__init__()
+        import pycocotools.mask as maskUtils
+        self.maskutils = maskUtils
+
+    def _poly2mask(self, mask_ann, img_h, img_w):
+        if isinstance(mask_ann, list):
+            # polygon -- a single object might consist of multiple parts,
+            # so we merge all parts into one mask RLE code
+            rles = self.maskutils.frPyObjects(mask_ann, img_h, img_w)
+            rle = self.maskutils.merge(rles)
+        elif isinstance(mask_ann['counts'], list):
+            # uncompressed RLE
+            rle = self.maskutils.frPyObjects(mask_ann, img_h, img_w)
+        else:
+            # RLE
+            rle = mask_ann
+        mask = self.maskutils.decode(rle)
+        return mask
+
+    def apply(self, sample, context=None):
+        assert 'gt_poly' in sample
+        im_h = sample['h']
+        im_w = sample['w']
+        masks = [
+            self._poly2mask(gt_poly, im_h, im_w)
+            for gt_poly in sample['gt_poly']
+        ]
+        sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
+        return sample
diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py
index c36b9af09..f1c476f27 100644
--- a/ppdet/data/transform/operators.py
+++ b/ppdet/data/transform/operators.py
@@ -39,6 +39,7 @@ from PIL import Image, ImageEnhance, ImageDraw
 from ppdet.core.workspace import serializable
 from ppdet.modeling.layers import AnchorGrid
+from .operator import register_op, BaseOperator, BboxError, ImageError
 from .op_helper import (satisfy_sample_constraint, filter_and_process,
                         generate_sample_bbox, clip_bbox,
                         data_anchor_sampling,
@@ -48,45 +49,6 @@ from .op_helper import (satisfy_sample_constraint, filter_and_process,
 logger = logging.getLogger(__name__)
 
-registered_ops = []
-
-
-def register_op(cls):
-    registered_ops.append(cls.__name__)
-    if not hasattr(BaseOperator, cls.__name__):
-        setattr(BaseOperator, cls.__name__, cls)
-    else:
-        raise KeyError("The {} class has been registered.".format(cls.__name__))
-    return serializable(cls)
-
-
-class BboxError(ValueError):
-    pass
-
-
-class ImageError(ValueError):
-    pass
-
-
-class BaseOperator(object):
-    def __init__(self, name=None):
-        if name is None:
-            name = self.__class__.__name__
-        self._id = name + '_' + str(uuid.uuid4())[-6:]
-
-    def __call__(self, sample, context=None):
-        """ Process a sample.
-        Args:
-            sample (dict): a dict of sample, eg: {'image':xx, 'label': xxx}
-            context (dict): info about this sample processing
-        Returns:
-            result (dict): a processed sample
-        """
-        return sample
-
-    def __str__(self):
-        return str(self._id)
-
 @register_op
 class DecodeImage(BaseOperator):
-- 
GitLab