未验证 提交 f2955696 编写于 作者: W wangxinxin08 提交者: GitHub

[Dygraph]modify operators and batch_operators in data preprocess (#1666)

* modify operators and batch_operators in data preprocess

* import BboxError and ImageError in operators

* modify code according to review

* modify some bugs

* add Gt2Solov2TargetOp in batch_operators

* modify code according to review
上级 d4a6d324
......@@ -15,8 +15,11 @@
from . import operators
from . import batch_operators
# TODO: operators and batch_operators will be replaced by operator and batch_operator
from .operators import *
from .operator import *
from .batch_operators import *
from .batch_operator import *
__all__ = []
__all__ += registered_ops
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
try:
from collections.abc import Sequence
except Exception:
from collections import Sequence
import logging
import cv2
import numpy as np
from .operator import register_op, BaseOperator, ResizeOp
from .op_helper import jaccard_overlap, gaussian2D
from scipy import ndimage
logger = logging.getLogger(__name__)
__all__ = [
'PadBatchOp',
'Gt2YoloTargetOp',
'Gt2FCOSTargetOp',
'Gt2TTFTargetOp',
]
@register_op
class PadBatchOp(BaseOperator):
"""
Pad a batch of samples so they can be divisible by a stride.
The layout of each image should be 'CHW'.
Args:
pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure
height and width is divisible by `pad_to_stride`.
"""
def __init__(self, pad_to_stride=0, pad_gt=False):
super(PadBatchOp, self).__init__()
self.pad_to_stride = pad_to_stride
self.pad_gt = pad_gt
def __call__(self, samples, context=None):
"""
Args:
samples (list): a batch of sample, each is dict.
"""
coarsest_stride = self.pad_to_stride
max_shape = np.array([data['image'].shape for data in samples]).max(
axis=0)
if coarsest_stride > 0:
max_shape[1] = int(
np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
max_shape[2] = int(
np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride)
padding_batch = []
for data in samples:
im = data['image']
im_c, im_h, im_w = im.shape[:]
padding_im = np.zeros(
(im_c, max_shape[1], max_shape[2]), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
data['image'] = padding_im
if 'semantic' in data and data['semantic'] is not None:
semantic = data['semantic']
padding_sem = np.zeros(
(1, max_shape[1], max_shape[2]), dtype=np.float32)
padding_sem[:, :im_h, :im_w] = semantic
data['semantic'] = padding_sem
if 'gt_segm' in data and data['gt_segm'] is not None:
gt_segm = data['gt_segm']
padding_segm = np.zeros(
(gt_segm.shape[0], max_shape[1], max_shape[2]),
dtype=np.uint8)
padding_segm[:, :im_h, :im_w] = gt_segm
data['gt_segm'] = padding_segm
if self.pad_gt:
gt_num = []
if data['gt_poly'] is not None and len(data['gt_poly']) > 0:
pad_mask = True
else:
pad_mask = False
if pad_mask:
poly_num = []
poly_part_num = []
point_num = []
for data in samples:
gt_num.append(data['gt_bbox'].shape[0])
if pad_mask:
poly_num.append(len(data['gt_poly']))
for poly in data['gt_poly']:
poly_part_num.append(int(len(poly)))
for p_p in poly:
point_num.append(int(len(p_p) / 2))
gt_num_max = max(gt_num)
gt_box_data = np.zeros([gt_num_max, 4])
gt_class_data = np.zeros([gt_num_max])
is_crowd_data = np.ones([gt_num_max])
if pad_mask:
poly_num_max = max(poly_num)
poly_part_num_max = max(poly_part_num)
point_num_max = max(point_num)
gt_masks_data = -np.ones(
[poly_num_max, poly_part_num_max, point_num_max, 2])
for i, data in enumerate(samples):
gt_num = data['gt_bbox'].shape[0]
gt_box_data[0:gt_num, :] = data['gt_bbox']
gt_class_data[0:gt_num] = np.squeeze(data['gt_class'])
is_crowd_data[0:gt_num] = np.squeeze(data['is_crowd'])
if pad_mask:
for j, poly in enumerate(data['gt_poly']):
for k, p_p in enumerate(poly):
pp_np = np.array(p_p).reshape(-1, 2)
gt_masks_data[j, k, :pp_np.shape[0], :] = pp_np
data['gt_poly'] = gt_masks_data
data['gt_bbox'] = gt_box_data
data['gt_class'] = gt_class_data
data['is_crowd'] = is_crowd_data
return samples
@register_op
class BatchRandomResizeOp(BaseOperator):
"""
Resize image to target size randomly. random target_size and interpolation method
Args:
target_size (int, list, tuple): image target size, if random size is True, must be list or tuple
keep_ratio (bool): whether keep_raio or not, default true
interp (int): the interpolation method
random_size (bool): whether random select target size of image
random_interp (bool): whether random select interpolation method
"""
def __init__(self,
target_size,
keep_ratio=True,
interp=cv2.INTER_LINEAR,
random_size=True,
random_interp=False):
super(BatchRandomResizeOp, self).__init__()
self.keep_ratio = keep_ratio
self.interps = [
cv2.INTER_NEAREST,
cv2.INTER_LINEAR,
cv2.INTER_AREA,
cv2.INTER_CUBIC,
cv2.INTER_LANCZOS4,
]
self.interp = interp
assert isinstance(target_size, (
int, Sequence)), "target_size must be int, list or tuple"
if random_size and not isinstance(target_size, list):
raise TypeError(
"Type of target_size is invalid when random_size is True. Must be List, now is {}".
format(type(target_size)))
self.target_size = target_size
self.random_size = random_size
self.random_interp = random_interp
def __call__(self, samples, context=None):
if self.random_size:
target_size = np.random.choice(self.target_size)
else:
target_size = self.target_size
if self.random_interp:
interp = np.random.choice(self.interps)
else:
interp = self.interp
resizer = ResizeOp(
target_size, keep_ratio=self.keep_ratio, interp=interp)
return resizer(samples, context=context)
@register_op
class Gt2YoloTargetOp(BaseOperator):
"""
Generate YOLOv3 targets by groud truth data, this operator is only used in
fine grained YOLOv3 loss mode
"""
def __init__(self,
anchors,
anchor_masks,
downsample_ratios,
num_classes=80,
iou_thresh=1.):
super(Gt2YoloTargetOp, self).__init__()
self.anchors = anchors
self.anchor_masks = anchor_masks
self.downsample_ratios = downsample_ratios
self.num_classes = num_classes
self.iou_thresh = iou_thresh
def __call__(self, samples, context=None):
assert len(self.anchor_masks) == len(self.downsample_ratios), \
"anchor_masks', and 'downsample_ratios' should have same length."
h, w = samples[0]['image'].shape[1:3]
an_hw = np.array(self.anchors) / np.array([[w, h]])
for sample in samples:
# im, gt_bbox, gt_class, gt_score = sample
im = sample['image']
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
gt_score = sample['gt_score']
for i, (
mask, downsample_ratio
) in enumerate(zip(self.anchor_masks, self.downsample_ratios)):
grid_h = int(h / downsample_ratio)
grid_w = int(w / downsample_ratio)
target = np.zeros(
(len(mask), 6 + self.num_classes, grid_h, grid_w),
dtype=np.float32)
for b in range(gt_bbox.shape[0]):
gx, gy, gw, gh = gt_bbox[b, :]
cls = gt_class[b]
score = gt_score[b]
if gw <= 0. or gh <= 0. or score <= 0.:
continue
# find best match anchor index
best_iou = 0.
best_idx = -1
for an_idx in range(an_hw.shape[0]):
iou = jaccard_overlap(
[0., 0., gw, gh],
[0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
if iou > best_iou:
best_iou = iou
best_idx = an_idx
gi = int(gx * grid_w)
gj = int(gy * grid_h)
# gtbox should be regresed in this layes if best match
# anchor index in anchor mask of this layer
if best_idx in mask:
best_n = mask.index(best_idx)
# x, y, w, h, scale
target[best_n, 0, gj, gi] = gx * grid_w - gi
target[best_n, 1, gj, gi] = gy * grid_h - gj
target[best_n, 2, gj, gi] = np.log(
gw * w / self.anchors[best_idx][0])
target[best_n, 3, gj, gi] = np.log(
gh * h / self.anchors[best_idx][1])
target[best_n, 4, gj, gi] = 2.0 - gw * gh
# objectness record gt_score
target[best_n, 5, gj, gi] = score
# classification
target[best_n, 6 + cls, gj, gi] = 1.
# For non-matched anchors, calculate the target if the iou
# between anchor and gt is larger than iou_thresh
if self.iou_thresh < 1:
for idx, mask_i in enumerate(mask):
if mask_i == best_idx: continue
iou = jaccard_overlap(
[0., 0., gw, gh],
[0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
if iou > self.iou_thresh:
# x, y, w, h, scale
target[idx, 0, gj, gi] = gx * grid_w - gi
target[idx, 1, gj, gi] = gy * grid_h - gj
target[idx, 2, gj, gi] = np.log(
gw * w / self.anchors[mask_i][0])
target[idx, 3, gj, gi] = np.log(
gh * h / self.anchors[mask_i][1])
target[idx, 4, gj, gi] = 2.0 - gw * gh
# objectness record gt_score
target[idx, 5, gj, gi] = score
# classification
target[idx, 6 + cls, gj, gi] = 1.
sample['target{}'.format(i)] = target
return samples
@register_op
class Gt2FCOSTargetOp(BaseOperator):
"""
Generate FCOS targets by groud truth data
"""
def __init__(self,
object_sizes_boundary,
center_sampling_radius,
downsample_ratios,
norm_reg_targets=False):
super(Gt2FCOSTargetOp, self).__init__()
self.center_sampling_radius = center_sampling_radius
self.downsample_ratios = downsample_ratios
self.INF = np.inf
self.object_sizes_boundary = [-1] + object_sizes_boundary + [self.INF]
object_sizes_of_interest = []
for i in range(len(self.object_sizes_boundary) - 1):
object_sizes_of_interest.append([
self.object_sizes_boundary[i], self.object_sizes_boundary[i + 1]
])
self.object_sizes_of_interest = object_sizes_of_interest
self.norm_reg_targets = norm_reg_targets
def _compute_points(self, w, h):
"""
compute the corresponding points in each feature map
:param h: image height
:param w: image width
:return: points from all feature map
"""
locations = []
for stride in self.downsample_ratios:
shift_x = np.arange(0, w, stride).astype(np.float32)
shift_y = np.arange(0, h, stride).astype(np.float32)
shift_x, shift_y = np.meshgrid(shift_x, shift_y)
shift_x = shift_x.flatten()
shift_y = shift_y.flatten()
location = np.stack([shift_x, shift_y], axis=1) + stride // 2
locations.append(location)
num_points_each_level = [len(location) for location in locations]
locations = np.concatenate(locations, axis=0)
return locations, num_points_each_level
def _convert_xywh2xyxy(self, gt_bbox, w, h):
"""
convert the bounding box from style xywh to xyxy
:param gt_bbox: bounding boxes normalized into [0, 1]
:param w: image width
:param h: image height
:return: bounding boxes in xyxy style
"""
bboxes = gt_bbox.copy()
bboxes[:, [0, 2]] = bboxes[:, [0, 2]] * w
bboxes[:, [1, 3]] = bboxes[:, [1, 3]] * h
bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2]
bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3]
return bboxes
def _check_inside_boxes_limited(self, gt_bbox, xs, ys,
num_points_each_level):
"""
check if points is within the clipped boxes
:param gt_bbox: bounding boxes
:param xs: horizontal coordinate of points
:param ys: vertical coordinate of points
:return: the mask of points is within gt_box or not
"""
bboxes = np.reshape(
gt_bbox, newshape=[1, gt_bbox.shape[0], gt_bbox.shape[1]])
bboxes = np.tile(bboxes, reps=[xs.shape[0], 1, 1])
ct_x = (bboxes[:, :, 0] + bboxes[:, :, 2]) / 2
ct_y = (bboxes[:, :, 1] + bboxes[:, :, 3]) / 2
beg = 0
clipped_box = bboxes.copy()
for lvl, stride in enumerate(self.downsample_ratios):
end = beg + num_points_each_level[lvl]
stride_exp = self.center_sampling_radius * stride
clipped_box[beg:end, :, 0] = np.maximum(
bboxes[beg:end, :, 0], ct_x[beg:end, :] - stride_exp)
clipped_box[beg:end, :, 1] = np.maximum(
bboxes[beg:end, :, 1], ct_y[beg:end, :] - stride_exp)
clipped_box[beg:end, :, 2] = np.minimum(
bboxes[beg:end, :, 2], ct_x[beg:end, :] + stride_exp)
clipped_box[beg:end, :, 3] = np.minimum(
bboxes[beg:end, :, 3], ct_y[beg:end, :] + stride_exp)
beg = end
l_res = xs - clipped_box[:, :, 0]
r_res = clipped_box[:, :, 2] - xs
t_res = ys - clipped_box[:, :, 1]
b_res = clipped_box[:, :, 3] - ys
clipped_box_reg_targets = np.stack([l_res, t_res, r_res, b_res], axis=2)
inside_gt_box = np.min(clipped_box_reg_targets, axis=2) > 0
return inside_gt_box
def __call__(self, samples, context=None):
assert len(self.object_sizes_of_interest) == len(self.downsample_ratios), \
"object_sizes_of_interest', and 'downsample_ratios' should have same length."
for sample in samples:
# im, gt_bbox, gt_class, gt_score = sample
im = sample['image']
bboxes = sample['gt_bbox']
gt_class = sample['gt_class']
# calculate the locations
h, w = im.shape[1:3]
points, num_points_each_level = self._compute_points(w, h)
object_scale_exp = []
for i, num_pts in enumerate(num_points_each_level):
object_scale_exp.append(
np.tile(
np.array([self.object_sizes_of_interest[i]]),
reps=[num_pts, 1]))
object_scale_exp = np.concatenate(object_scale_exp, axis=0)
gt_area = (bboxes[:, 2] - bboxes[:, 0]) * (
bboxes[:, 3] - bboxes[:, 1])
xs, ys = points[:, 0], points[:, 1]
xs = np.reshape(xs, newshape=[xs.shape[0], 1])
xs = np.tile(xs, reps=[1, bboxes.shape[0]])
ys = np.reshape(ys, newshape=[ys.shape[0], 1])
ys = np.tile(ys, reps=[1, bboxes.shape[0]])
l_res = xs - bboxes[:, 0]
r_res = bboxes[:, 2] - xs
t_res = ys - bboxes[:, 1]
b_res = bboxes[:, 3] - ys
reg_targets = np.stack([l_res, t_res, r_res, b_res], axis=2)
if self.center_sampling_radius > 0:
is_inside_box = self._check_inside_boxes_limited(
bboxes, xs, ys, num_points_each_level)
else:
is_inside_box = np.min(reg_targets, axis=2) > 0
# check if the targets is inside the corresponding level
max_reg_targets = np.max(reg_targets, axis=2)
lower_bound = np.tile(
np.expand_dims(
object_scale_exp[:, 0], axis=1),
reps=[1, max_reg_targets.shape[1]])
high_bound = np.tile(
np.expand_dims(
object_scale_exp[:, 1], axis=1),
reps=[1, max_reg_targets.shape[1]])
is_match_current_level = \
(max_reg_targets > lower_bound) & \
(max_reg_targets < high_bound)
points2gtarea = np.tile(
np.expand_dims(
gt_area, axis=0), reps=[xs.shape[0], 1])
points2gtarea[is_inside_box == 0] = self.INF
points2gtarea[is_match_current_level == 0] = self.INF
points2min_area = points2gtarea.min(axis=1)
points2min_area_ind = points2gtarea.argmin(axis=1)
labels = gt_class[points2min_area_ind] + 1
labels[points2min_area == self.INF] = 0
reg_targets = reg_targets[range(xs.shape[0]), points2min_area_ind]
ctn_targets = np.sqrt((reg_targets[:, [0, 2]].min(axis=1) / \
reg_targets[:, [0, 2]].max(axis=1)) * \
(reg_targets[:, [1, 3]].min(axis=1) / \
reg_targets[:, [1, 3]].max(axis=1))).astype(np.float32)
ctn_targets = np.reshape(
ctn_targets, newshape=[ctn_targets.shape[0], 1])
ctn_targets[labels <= 0] = 0
pos_ind = np.nonzero(labels != 0)
reg_targets_pos = reg_targets[pos_ind[0], :]
split_sections = []
beg = 0
for lvl in range(len(num_points_each_level)):
end = beg + num_points_each_level[lvl]
split_sections.append(end)
beg = end
labels_by_level = np.split(labels, split_sections, axis=0)
reg_targets_by_level = np.split(reg_targets, split_sections, axis=0)
ctn_targets_by_level = np.split(ctn_targets, split_sections, axis=0)
for lvl in range(len(self.downsample_ratios)):
grid_w = int(np.ceil(w / self.downsample_ratios[lvl]))
grid_h = int(np.ceil(h / self.downsample_ratios[lvl]))
if self.norm_reg_targets:
sample['reg_target{}'.format(lvl)] = \
np.reshape(
reg_targets_by_level[lvl] / \
self.downsample_ratios[lvl],
newshape=[grid_h, grid_w, 4])
else:
sample['reg_target{}'.format(lvl)] = np.reshape(
reg_targets_by_level[lvl],
newshape=[grid_h, grid_w, 4])
sample['labels{}'.format(lvl)] = np.reshape(
labels_by_level[lvl], newshape=[grid_h, grid_w, 1])
sample['centerness{}'.format(lvl)] = np.reshape(
ctn_targets_by_level[lvl], newshape=[grid_h, grid_w, 1])
return samples
@register_op
class Gt2TTFTargetOp(BaseOperator):
"""
Gt2TTFTarget
Generate TTFNet targets by ground truth data
Args:
num_classes(int): the number of classes.
down_ratio(int): the down ratio from images to heatmap, 4 by default.
alpha(float): the alpha parameter to generate gaussian target.
0.54 by default.
"""
def __init__(self, num_classes, down_ratio=4, alpha=0.54):
super(Gt2TTFTargetOp, self).__init__()
self.down_ratio = down_ratio
self.num_classes = num_classes
self.alpha = alpha
def __call__(self, samples, context=None):
output_size = samples[0]['image'].shape[1]
feat_size = output_size // self.down_ratio
for sample in samples:
heatmap = np.zeros(
(self.num_classes, feat_size, feat_size), dtype='float32')
box_target = np.ones(
(4, feat_size, feat_size), dtype='float32') * -1
reg_weight = np.zeros((1, feat_size, feat_size), dtype='float32')
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
bbox_w = gt_bbox[:, 2] - gt_bbox[:, 0] + 1
bbox_h = gt_bbox[:, 3] - gt_bbox[:, 1] + 1
area = bbox_w * bbox_h
boxes_areas_log = np.log(area)
boxes_ind = np.argsort(boxes_areas_log, axis=0)[::-1]
boxes_area_topk_log = boxes_areas_log[boxes_ind]
gt_bbox = gt_bbox[boxes_ind]
gt_class = gt_class[boxes_ind]
feat_gt_bbox = gt_bbox / self.down_ratio
feat_gt_bbox = np.clip(feat_gt_bbox, 0, feat_size - 1)
feat_hs, feat_ws = (feat_gt_bbox[:, 3] - feat_gt_bbox[:, 1],
feat_gt_bbox[:, 2] - feat_gt_bbox[:, 0])
ct_inds = np.stack(
[(gt_bbox[:, 0] + gt_bbox[:, 2]) / 2,
(gt_bbox[:, 1] + gt_bbox[:, 3]) / 2],
axis=1) / self.down_ratio
h_radiuses_alpha = (feat_hs / 2. * self.alpha).astype('int32')
w_radiuses_alpha = (feat_ws / 2. * self.alpha).astype('int32')
for k in range(len(gt_bbox)):
cls_id = gt_class[k]
fake_heatmap = np.zeros((feat_size, feat_size), dtype='float32')
self.draw_truncate_gaussian(fake_heatmap, ct_inds[k],
h_radiuses_alpha[k],
w_radiuses_alpha[k])
heatmap[cls_id] = np.maximum(heatmap[cls_id], fake_heatmap)
box_target_inds = fake_heatmap > 0
box_target[:, box_target_inds] = gt_bbox[k][:, None]
local_heatmap = fake_heatmap[box_target_inds]
ct_div = np.sum(local_heatmap)
local_heatmap *= boxes_area_topk_log[k]
reg_weight[0, box_target_inds] = local_heatmap / ct_div
sample['ttf_heatmap'] = heatmap
sample['ttf_box_target'] = box_target
sample['ttf_reg_weight'] = reg_weight
return samples
def draw_truncate_gaussian(self, heatmap, center, h_radius, w_radius):
h, w = 2 * h_radius + 1, 2 * w_radius + 1
sigma_x = w / 6
sigma_y = h / 6
gaussian = gaussian2D((h, w), sigma_x, sigma_y)
x, y = int(center[0]), int(center[1])
height, width = heatmap.shape[0:2]
left, right = min(x, w_radius), min(width - x, w_radius + 1)
top, bottom = min(y, h_radius), min(height - y, h_radius + 1)
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
masked_gaussian = gaussian[h_radius - top:h_radius + bottom, w_radius -
left:w_radius + right]
if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
heatmap[y - top:y + bottom, x - left:x + right] = np.maximum(
masked_heatmap, masked_gaussian)
return heatmap
@register_op
class Gt2Solov2TargetOp(BaseOperator):
"""Assign mask target and labels in SOLOv2 network.
Args:
num_grids (list): The list of feature map grids size.
scale_ranges (list): The list of mask boundary range.
coord_sigma (float): The coefficient of coordinate area length.
sampling_ratio (float): The ratio of down sampling.
"""
def __init__(self,
num_grids=[40, 36, 24, 16, 12],
scale_ranges=[[1, 96], [48, 192], [96, 384], [192, 768],
[384, 2048]],
coord_sigma=0.2,
sampling_ratio=4.0):
super(Gt2Solov2TargetOp, self).__init__()
self.num_grids = num_grids
self.scale_ranges = scale_ranges
self.coord_sigma = coord_sigma
self.sampling_ratio = sampling_ratio
def _scale_size(self, im, scale):
h, w = im.shape[:2]
new_size = (int(w * float(scale) + 0.5), int(h * float(scale) + 0.5))
resized_img = cv2.resize(
im, None, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
return resized_img
def __call__(self, samples, context=None):
sample_id = 0
for sample in samples:
gt_bboxes_raw = sample['gt_bbox']
gt_labels_raw = sample['gt_class']
im_c, im_h, im_w = sample['image'].shape[:]
gt_masks_raw = sample['gt_segm'].astype(np.uint8)
mask_feat_size = [
int(im_h / self.sampling_ratio), int(im_w / self.sampling_ratio)
]
gt_areas = np.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) *
(gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
ins_ind_label_list = []
idx = 0
for (lower_bound, upper_bound), num_grid \
in zip(self.scale_ranges, self.num_grids):
hit_indices = ((gt_areas >= lower_bound) &
(gt_areas <= upper_bound)).nonzero()[0]
num_ins = len(hit_indices)
ins_label = []
grid_order = []
cate_label = np.zeros([num_grid, num_grid], dtype=np.int64)
ins_ind_label = np.zeros([num_grid**2], dtype=np.bool)
if num_ins == 0:
ins_label = np.zeros(
[1, mask_feat_size[0], mask_feat_size[1]],
dtype=np.uint8)
ins_ind_label_list.append(ins_ind_label)
sample['cate_label{}'.format(idx)] = cate_label.flatten()
sample['ins_label{}'.format(idx)] = ins_label
sample['grid_order{}'.format(idx)] = np.asarray(
[sample_id * num_grid * num_grid + 0])
idx += 1
continue
gt_bboxes = gt_bboxes_raw[hit_indices]
gt_labels = gt_labels_raw[hit_indices]
gt_masks = gt_masks_raw[hit_indices, ...]
half_ws = 0.5 * (
gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.coord_sigma
half_hs = 0.5 * (
gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.coord_sigma
for seg_mask, gt_label, half_h, half_w in zip(
gt_masks, gt_labels, half_hs, half_ws):
if seg_mask.sum() == 0:
continue
# mass center
upsampled_size = (mask_feat_size[0] * 4,
mask_feat_size[1] * 4)
center_h, center_w = ndimage.measurements.center_of_mass(
seg_mask)
coord_w = int(
(center_w / upsampled_size[1]) // (1. / num_grid))
coord_h = int(
(center_h / upsampled_size[0]) // (1. / num_grid))
# left, top, right, down
top_box = max(0,
int(((center_h - half_h) / upsampled_size[0])
// (1. / num_grid)))
down_box = min(num_grid - 1,
int(((center_h + half_h) / upsampled_size[0])
// (1. / num_grid)))
left_box = max(0,
int(((center_w - half_w) / upsampled_size[1])
// (1. / num_grid)))
right_box = min(num_grid - 1,
int(((center_w + half_w) /
upsampled_size[1]) // (1. / num_grid)))
top = max(top_box, coord_h - 1)
down = min(down_box, coord_h + 1)
left = max(coord_w - 1, left_box)
right = min(right_box, coord_w + 1)
cate_label[top:(down + 1), left:(right + 1)] = gt_label
seg_mask = self._scale_size(
seg_mask, scale=1. / self.sampling_ratio)
for i in range(top, down + 1):
for j in range(left, right + 1):
label = int(i * num_grid + j)
cur_ins_label = np.zeros(
[mask_feat_size[0], mask_feat_size[1]],
dtype=np.uint8)
cur_ins_label[:seg_mask.shape[0], :seg_mask.shape[
1]] = seg_mask
ins_label.append(cur_ins_label)
ins_ind_label[label] = True
grid_order.append(
[sample_id * num_grid * num_grid + label])
if ins_label == []:
ins_label = np.zeros(
[1, mask_feat_size[0], mask_feat_size[1]],
dtype=np.uint8)
ins_ind_label_list.append(ins_ind_label)
sample['cate_label{}'.format(idx)] = cate_label.flatten()
sample['ins_label{}'.format(idx)] = ins_label
sample['grid_order{}'.format(idx)] = np.asarray(
[sample_id * num_grid * num_grid + 0])
else:
ins_label = np.stack(ins_label, axis=0)
ins_ind_label_list.append(ins_ind_label)
sample['cate_label{}'.format(idx)] = cate_label.flatten()
sample['ins_label{}'.format(idx)] = ins_label
sample['grid_order{}'.format(idx)] = np.asarray(grid_order)
assert len(grid_order) > 0
idx += 1
ins_ind_labels = np.concatenate([
ins_ind_labels_level_img
for ins_ind_labels_level_img in ins_ind_label_list
])
fg_num = np.sum(ins_ind_labels)
sample['fg_num'] = fg_num
sample_id += 1
return samples
......@@ -24,7 +24,7 @@ except Exception:
import logging
import cv2
import numpy as np
from .operators import register_op, BaseOperator
from .operator import register_op, BaseOperator
from .op_helper import jaccard_overlap, gaussian2D
logger = logging.getLogger(__name__)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# function:
# operators to process sample,
# eg: decode/resize/crop image
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
try:
from collections.abc import Sequence
except Exception:
from collections import Sequence
from numbers import Number
import uuid
import logging
import random
import math
import numpy as np
import os
import cv2
from PIL import Image, ImageEnhance, ImageDraw
from ppdet.core.workspace import serializable
from ppdet.modeling.layers import AnchorGrid
from .op_helper import (satisfy_sample_constraint, filter_and_process,
generate_sample_bbox, clip_bbox, data_anchor_sampling,
satisfy_sample_constraint_coverage, crop_image_sampling,
generate_sample_bbox_square, bbox_area_sampling,
is_poly, gaussian_radius, draw_gaussian)
logger = logging.getLogger(__name__)
registered_ops = []
def register_op(cls):
registered_ops.append(cls.__name__)
if not hasattr(BaseOperator, cls.__name__):
setattr(BaseOperator, cls.__name__, cls)
else:
raise KeyError("The {} class has been registered.".format(cls.__name__))
return serializable(cls)
class BboxError(ValueError):
pass
class ImageError(ValueError):
pass
class BaseOperator(object):
def __init__(self, name=None):
if name is None:
name = self.__class__.__name__
self._id = name + '_' + str(uuid.uuid4())[-6:]
def apply(self, sample, context=None):
""" Process a sample.
Args:
sample (dict): a dict of sample, eg: {'image':xx, 'label': xxx}
context (dict): info about this sample processing
Returns:
result (dict): a processed sample
"""
return sample
def __call__(self, sample, context=None):
""" Process a sample.
Args:
sample (dict): a dict of sample, eg: {'image':xx, 'label': xxx}
context (dict): info about this sample processing
Returns:
result (dict): a processed sample
"""
if isinstance(sample, Sequence):
for i in range(len(sample)):
sample[i] = self.apply(sample[i], context)
sample = self.apply(sample, context)
return sample
def __str__(self):
return str(self._id)
@register_op
class DecodeOp(BaseOperator):
def __init__(self):
""" Transform the image data to numpy format following the rgb format
"""
super(DecodeOp, self).__init__()
def apply(self, sample, context=None):
""" load image if 'im_file' field is not empty but 'image' is"""
if 'image' not in sample:
with open(sample['im_file'], 'rb') as f:
sample['image'] = f.read()
im = sample['image']
data = np.frombuffer(im, dtype='uint8')
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
sample['image'] = im
if 'h' not in sample:
sample['h'] = im.shape[0]
elif sample['h'] != im.shape[0]:
logger.warn(
"The actual image height: {} is not equal to the "
"height: {} in annotation, and update sample['h'] by actual "
"image height.".format(im.shape[0], sample['h']))
sample['h'] = im.shape[0]
if 'w' not in sample:
sample['w'] = im.shape[1]
elif sample['w'] != im.shape[1]:
logger.warn(
"The actual image width: {} is not equal to the "
"width: {} in annotation, and update sample['w'] by actual "
"image width.".format(im.shape[1], sample['w']))
sample['w'] = im.shape[1]
sample['im_shape'] = im.shape[:2]
sample['scale_factor'] = [1., 1.]
return sample
@register_op
class PermuteOp(BaseOperator):
def __init__(self):
"""
Change the channel to be (C, H, W)
"""
super(PermuteOp, self).__init__()
def apply(self, sample, context=None):
im = sample['image']
im = im.transpose((2, 0, 1))
sample['image'] = im
return sample
@register_op
class LightingOp(BaseOperator):
"""
Lighting the imagen by eigenvalues and eigenvectors
Args:
eigval (list): eigenvalues
eigvec (list): eigenvectors
alphastd (float): random weight of lighting, 0.1 by default
"""
def __init__(self, eigval, eigvec, alphastd=0.1):
super(LightingOp, self).__init__()
self.alphastd = alphastd
self.eigval = np.array(eigval).astype('float32')
self.eigvec = np.array(eigvec).astype('float32')
def apply(self, sample, context=None):
alpha = np.random.normal(scale=self.alphastd, size=(3, ))
sample['image'] += np.dot(self.eigvec, self.eigval * alpha)
return sample
@register_op
class NormalizeImageOp(BaseOperator):
def __init__(self, mean=[0.485, 0.456, 0.406], std=[1, 1, 1],
is_scale=True):
"""
Args:
mean (list): the pixel mean
std (list): the pixel variance
"""
super(NormalizeImageOp, self).__init__()
self.mean = mean
self.std = std
self.is_scale = is_scale
if not (isinstance(self.mean, list) and isinstance(self.std, list) and
isinstance(self.is_scale, bool)):
raise TypeError("{}: input type is invalid.".format(self))
from functools import reduce
if reduce(lambda x, y: x * y, self.std) == 0:
raise ValueError('{}: std is invalid!'.format(self))
def apply(self, sample, context=None):
"""Normalize the image.
Operators:
1.(optional) Scale the image to [0,1]
2. Each pixel minus mean and is divided by std
"""
im = sample['image']
im = im.astype(np.float32, copy=False)
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
if self.is_scale:
im = im / 255.0
im -= mean
im /= std
sample['image'] = im
return sample
@register_op
class GridMaskOp(BaseOperator):
def __init__(self,
use_h=True,
use_w=True,
rotate=1,
offset=False,
ratio=0.5,
mode=1,
prob=0.7,
upper_iter=360000):
"""
GridMask Data Augmentation, see https://arxiv.org/abs/2001.04086
Args:
use_h (bool): whether to mask vertically
use_w (boo;): whether to mask horizontally
rotate (float): angle for the mask to rotate
offset (float): mask offset
ratio (float): mask ratio
mode (int): gridmask mode
prob (float): max probability to carry out gridmask
upper_iter (int): suggested to be equal to global max_iter
"""
super(GridMaskOp, self).__init__()
self.use_h = use_h
self.use_w = use_w
self.rotate = rotate
self.offset = offset
self.ratio = ratio
self.mode = mode
self.prob = prob
self.upper_iter = upper_iter
from .gridmask_utils import GridMask
self.gridmask_op = GridMask(
use_h,
use_w,
rotate=rotate,
offset=offset,
ratio=ratio,
mode=mode,
prob=prob,
upper_iter=upper_iter)
def apply(self, sample, context=None):
sample['image'] = self.gridmask_op(sample['image'], sample['curr_iter'])
return sample
@register_op
class RandomDistortOp(BaseOperator):
"""Random color distortion.
Args:
hue (list): hue settings. in [lower, upper, probability] format.
saturation (list): saturation settings. in [lower, upper, probability] format.
contrast (list): contrast settings. in [lower, upper, probability] format.
brightness (list): brightness settings. in [lower, upper, probability] format.
random_apply (bool): whether to apply in random (yolo) or fixed (SSD)
order.
count (int): the number of doing distrot
random_channel (bool): whether to swap channels randomly
"""
def __init__(self,
hue=[-18, 18, 0.5],
saturation=[0.5, 1.5, 0.5],
contrast=[0.5, 1.5, 0.5],
brightness=[0.5, 1.5, 0.5],
random_apply=True,
count=4,
random_channel=False):
super(RandomDistortOp, self).__init__()
self.hue = hue
self.saturation = saturation
self.contrast = contrast
self.brightness = brightness
self.random_apply = random_apply
self.count = count
self.random_channel = random_channel
def apply_hue(self, img):
low, high, prob = self.hue
if np.random.uniform(0., 1.) < prob:
return img
img = img.astype(np.float32)
img[..., 0] += random.uniform(low, high)
img[..., 0][img[..., 0] > 360] -= 360
img[..., 0][img[..., 0] < 0] += 360
return img
def apply_saturation(self, img):
low, high, prob = self.saturation
if np.random.uniform(0., 1.) < prob:
return img
delta = np.random.uniform(low, high)
img = img.astype(np.float32)
img[..., 1] *= delta
return img
def apply_contrast(self, img):
low, high, prob = self.contrast
if np.random.uniform(0., 1.) < prob:
return img
delta = np.random.uniform(low, high)
img = img.astype(np.float32)
img *= delta
return img
def apply_brightness(self, img):
low, high, prob = self.brightness
if np.random.uniform(0., 1.) < prob:
return img
delta = np.random.uniform(low, high)
img = img.astype(np.float32)
img += delta
return img
def apply(self, sample, context=None):
img = sample['image']
if self.random_apply:
functions = [
self.apply_brightness,
self.apply_contrast,
lambda img: cv2.cvtColor(self.apply_saturation(cv2.cvtColor(img, cv2.COLOR_RGB2HSV)), cv2.COLOR_HSV2RGB),
lambda img: cv2.cvtColor(self.apply_hue(cv2.cvtColor(img, cv2.COLOR_RGB2HSV)), cv2.COLOR_HSV2RGB),
]
distortions = np.random.permutation(functions)[:count]
for func in distortions:
img = func(img)
sample['image'] = img
return sample
img = self.apply_brightness(img)
mode = np.random.randint(0, 2)
if mode:
img = self.apply_contrast(img)
img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
img = self.apply_saturation(img)
img = self.apply_hue(img)
img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
if not mode:
img = self.apply_contrast(img)
if self.random_channel:
if np.random.randint(0, 2):
img = img[..., np.random.permutation(3)]
sample['image'] = img
return sample
@register_op
class AutoAugmentOp(BaseOperator):
def __init__(self, autoaug_type="v1"):
"""
Args:
autoaug_type (str): autoaug type, support v0, v1, v2, v3, test
"""
super(AutoAugmentOp, self).__init__()
self.autoaug_type = autoaug_type
def apply(self, sample, context=None):
"""
Learning Data Augmentation Strategies for Object Detection, see https://arxiv.org/abs/1906.11172
"""
im = sample['image']
gt_bbox = sample['gt_bbox']
if not isinstance(im, np.ndarray):
raise TypeError("{}: image is not a numpy array.".format(self))
if len(im.shape) != 3:
raise ImageError("{}: image is not 3-dimensional.".format(self))
if len(gt_bbox) == 0:
return sample
height, width, _ = im.shape
norm_gt_bbox = np.ones_like(gt_bbox, dtype=np.float32)
norm_gt_bbox[:, 0] = gt_bbox[:, 1] / float(height)
norm_gt_bbox[:, 1] = gt_bbox[:, 0] / float(width)
norm_gt_bbox[:, 2] = gt_bbox[:, 3] / float(height)
norm_gt_bbox[:, 3] = gt_bbox[:, 2] / float(width)
from .autoaugment_utils import distort_image_with_autoaugment
im, norm_gt_bbox = distort_image_with_autoaugment(im, norm_gt_bbox,
self.autoaug_type)
gt_bbox[:, 0] = norm_gt_bbox[:, 1] * float(width)
gt_bbox[:, 1] = norm_gt_bbox[:, 0] * float(height)
gt_bbox[:, 2] = norm_gt_bbox[:, 3] * float(width)
gt_bbox[:, 3] = norm_gt_bbox[:, 2] * float(height)
sample['image'] = im
sample['gt_bbox'] = gt_bbox
return sample
@register_op
class RandomFlipOp(BaseOperator):
def __init__(self, prob=0.5, is_mask_flip=False):
"""
Args:
prob (float): the probability of flipping image
is_mask_flip (bool): whether flip the segmentation
"""
super(RandomFlipOp, self).__init__()
self.prob = prob
self.is_mask_flip = is_mask_flip
if not (isinstance(self.prob, float) and
isinstance(self.is_mask_flip, bool)):
raise TypeError("{}: input type is invalid.".format(self))
def apply_segm(self, segms, height, width):
def _flip_poly(poly, width):
flipped_poly = np.array(poly)
flipped_poly[0::2] = width - np.array(poly[0::2]) - 1
return flipped_poly.tolist()
def _flip_rle(rle, height, width):
if 'counts' in rle and type(rle['counts']) == list:
rle = mask_util.frPyObjects(rle, height, width)
mask = mask_util.decode(rle)
mask = mask[:, ::-1]
rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
return rle
flipped_segms = []
for segm in segms:
if is_poly(segm):
# Polygon format
flipped_segms.append([_flip_poly(poly, width) for poly in segm])
else:
# RLE format
import pycocotools.mask as mask_util
flipped_segms.append(_flip_rle(segm, height, width))
return flipped_segms
def apply_keypoint(self, gt_keypoint, width):
for i in range(gt_keypoint.shape[1]):
if i % 2 == 0:
old_x = gt_keypoint[:, i].copy()
gt_keypoint[:, i] = width - old_x - 1
return gt_keypoint
def apply_image(self, image):
return image[:, ::-1, :]
def apply_bbox(self, bbox, width):
bbox[:, 0::2] = width - bbox[:, 0::2] - 1
return bbox
def apply(self, sample, context=None):
"""Filp the image and bounding box.
Operators:
1. Flip the image numpy.
2. Transform the bboxes' x coordinates.
(Must judge whether the coordinates are normalized!)
3. Transform the segmentations' x coordinates.
(Must judge whether the coordinates are normalized!)
Output:
sample: the image, bounding box and segmentation part
in sample are flipped.
"""
if np.random.uniform(0, 1) < self.prob:
im = sample['image']
height, width = im.shape[:2]
im = self.apply_image(im)
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], width)
if self.is_mask_flip and 'gt_poly' in sample and len(sample[
'gt_poly']) > 0:
sample['gt_poly'] = self.apply_segm(sample['gt_poly'], height,
width)
if 'gt_keypoint' in sample and len(sample['gt_keypoint']) > 0:
sample['gt_keypoint'] = self.apply_keypoint(
sample['gt_keypoint'], width)
if 'semantic' in sample and sample['semantic']:
sample['semantic'] = sample['semantic'][:, ::-1]
if 'gt_segm' in sample and sample['gt_segm']:
sample['gt_segm'] = sample['gt_segm'][:, :, ::-1]
sample['flipped'] = True
sample['image'] = im
return sample
@register_op
class ResizeOp(BaseOperator):
def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
"""
Resize image to target size. if keep_ratio is True,
resize the image's long side to the maximum of target_size
if keep_ratio is False, resize the image to target size(h, w)
Args:
target_size (int|list): image target size
keep_ratio (bool): whether keep_ratio or not, default true
interp (int): the interpolation method
"""
super(ResizeOp, self).__init__()
self.keep_ratio = keep_ratio
self.interp = interp
if not isinstance(target_size, (int, list, tuple)):
raise TypeError(
"Type of target_size is invalid. Must be Integer or List or Tuple, now is {}".
format(type(target_size)))
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
def apply_image(self, image, scale):
im_scale_x, im_scale_y = scale
return cv2.resize(
image,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
def apply_bbox(self, bbox, scale, size):
im_scale_x, im_scale_y = scale
resize_w, resize_h = size
bbox[:, 0::2] *= im_scale_x
bbox[:, 1::2] *= im_scale_y
bbox[:, 0::2] = np.clip(bbox[:, 0::2], 0, resize_w - 1)
bbox[:, 1::2] = np.clip(bbox[:, 1::2], 0, resize_h - 1)
return bbox
def apply_segm(self, segms, im_size, scale):
def _resize_poly(poly, im_scale_x, im_scale_y):
resized_poly = np.array(poly)
resized_poly[0::2] *= im_scale_x
resized_poly[1::2] *= im_scale_y
return resized_poly.tolist()
def _resize_rle(rle, im_h, im_w, im_scale_x, im_scale_y):
if 'counts' in rle and type(rle['counts']) == list:
rle = mask_util.frPyObjects(rle, im_h, im_w)
mask = mask_util.decode(rle)
mask = cv2.resize(
image,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
return rle
im_h, im_w = im_size
im_scale_x, im_scale_y = scale
resized_segms = []
for segm in segms:
if is_poly(segm):
# Polygon format
resized_segms.append([
_resize_poly(poly, im_scale_x, im_scale_y) for poly in segm
])
else:
# RLE format
import pycocotools.mask as mask_util
resized_segms.append(
_resize_rle(segm, im_h, im_w, im_scale_x, im_scale_y))
return resized_segms
def apply(self, sample, context=None):
""" Resize the image numpy.
"""
im = sample['image']
if not isinstance(im, np.ndarray):
raise TypeError("{}: image type is not numpy.".format(self))
if len(im.shape) != 3:
raise ImageError('{}: image is not 3-dimensional.'.format(self))
# apply image
im_shape = im.shape
if self.keep_ratio:
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])
target_size_min = np.min(self.target_size)
target_size_max = np.max(self.target_size)
im_scale = min(target_size_min / im_size_min,
target_size_max / im_size_max)
resize_h = int(im_scale * im_shape[0])
resize_w = int(im_scale * im_shape[1])
im_scale_x = im_scale
im_scale_y = im_scale
else:
resize_h, resize_w = self.target_size
im_scale_y = resize_h / im_shape[0]
im_scale_x = resize_w / im_shape[1]
im = self.apply_image(sample['image'], [im_scale_x, im_scale_y])
sample['image'] = im
sample['im_shape'] = [resize_h, resize_w]
scale_factor = sample['scale_factor']
sample['scale_factor'] = [
scale_factor[0] * im_scale_y, scale_factor[1] * im_scale_x
]
# apply bbox
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'],
[im_scale_x, im_scale_y],
[resize_w, resize_h])
# apply polygon
if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_shape,
[im_scale_x, im_scale_y])
# apply semantic
if 'semantic' in sample and sample['semantic']:
semantic = sample['semantic']
semantic = cv2.resize(
semantic.astype('float32'),
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
semantic = np.asarray(semantic).astype('int32')
semantic = np.expand_dims(semantic, 0)
sample['semantic'] = semantic
# apply gt_segm
if 'gt_segm' in sample and len(sample['gt_segm']) > 0:
masks = [
cv2.resize(
gt_segm,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=cv2.INTER_NEAREST)
for gt_segm in sample['gt_segm']
]
sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
return sample
@register_op
class MultiscaleTestResizeOp(BaseOperator):
def __init__(self,
origin_target_size=[800, 1333],
target_size=[],
interp=cv2.INTER_LINEAR,
use_flip=True):
"""
Rescale image to the each size in target size, and capped at max_size.
Args:
origin_target_size (list): origin target size of image
target_size (list): A list of target sizes of image.
interp (int): the interpolation method.
use_flip (bool): whether use flip augmentation.
"""
super(MultiscaleTestResizeOp, self).__init__()
self.interp = interp
self.use_flip = use_flip
if not isinstance(target_size, list):
raise TypeError(
"Type of target_size is invalid. Must be List, now is {}".
format(type(target_size)))
self.target_size = target_size
if not isinstance(origin_target_size, list):
raise TypeError(
"Type of target_size is invalid. Must be List, now is {}".
format(type(target_size)))
self.origin_target_size = origin_target_size
def apply(self, sample, context=None):
""" Resize the image numpy for multi-scale test.
"""
samples = []
resizer = ResizeOp(
self.origin_target_size, keep_ratio=True, interp=self.interp)
samples.append(resizer(sample.copy(), context))
if self.use_flip:
flipper = RandomFlipOp(1.1)
samples.append(flipper(sample.copy(), context=context))
for size in self.target_size:
resizer = ResizeOp(size, keep_ratio=True, interp=self.interp)
samples.append(resizer(sample.copy(), context))
return samples
@register_op
class RandomResizeOp(BaseOperator):
def __init__(self,
target_size,
keep_ratio=True,
interp=cv2.INTER_LINEAR,
random_size=True,
random_interp=False):
"""
Resize image to target size randomly. random target_size and interpolation method
Args:
target_size (int, list, tuple): image target size, if random size is True, must be list or tuple
keep_ratio (bool): whether keep_raio or not, default true
interp (int): the interpolation method
random_size (bool): whether random select target size of image
random_interp (bool): whether random select interpolation method
"""
super(RandomResizeOp, self).__init__()
self.keep_ratio = keep_ratio
self.interp = interp
self.interps = [
cv2.INTER_NEAREST,
cv2.INTER_LINEAR,
cv2.INTER_AREA,
cv2.INTER_CUBIC,
cv2.INTER_LANCZOS4,
]
assert isinstance(target_size, (
int, Sequence)), "target_size must be int, list or tuple"
if random_size and not isinstance(target_size, list):
raise TypeError(
"Type of target_size is invalid when random_size is True. Must be List, now is {}".
format(type(target_size)))
self.target_size = target_size
self.random_size = random_size
self.random_interp = random_interp
def apply(self, sample, context=None):
""" Resize the image numpy.
"""
if self.random_size:
target_size = random.choice(self.target_size)
else:
target_size = self.target_size
if self.random_interp:
interp = random.choice(self.interps)
else:
interp = self.interp
resizer = ResizeOp(target_size, self.keep_ratio, interp)
return resizer(sample, context=context)
@register_op
class RandomExpandOp(BaseOperator):
"""Random expand the canvas.
Args:
ratio (float): maximum expansion ratio.
prob (float): probability to expand.
fill_value (list): color value used to fill the canvas. in RGB order.
"""
def __init__(self, ratio=4., prob=0.5, fill_value=(127.5, 127.5, 127.5)):
super(RandomExpandOp, self).__init__()
assert ratio > 1.01, "expand ratio must be larger than 1.01"
self.ratio = ratio
self.prob = prob
assert isinstance(fill_value, (Number, Sequence)), \
"fill value must be either float or sequence"
if isinstance(fill_value, Number):
fill_value = (fill_value, ) * 3
if not isinstance(fill_value, tuple):
fill_value = tuple(fill_value)
self.fill_value = fill_value
def apply(self, sample, context=None):
if np.random.uniform(0., 1.) < self.prob:
return sample
im = sample['image']
height, width = im.shape[:2]
ratio = np.random.uniform(1., self.ratio)
h = int(height * ratio)
w = int(width * ratio)
if not h > height or not w > width:
return sample
y = np.random.randint(0, h - height)
x = np.random.randint(0, w - width)
offsets, size = [x, y], [h, w]
pad = Pad(size, pad_mode=-1, offsets=offsets)
return pad(sample, context=context)
@register_op
class CropWithSampling(BaseOperator):
def __init__(self, batch_sampler, satisfy_all=False, avoid_no_bbox=True):
"""
Args:
batch_sampler (list): Multiple sets of different
parameters for cropping.
satisfy_all (bool): whether all boxes must satisfy.
e.g.[[1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0]]
[max sample, max trial, min scale, max scale,
min aspect ratio, max aspect ratio,
min overlap, max overlap]
avoid_no_bbox (bool): whether to to avoid the
situation where the box does not appear.
"""
super(CropWithSampling, self).__init__()
self.batch_sampler = batch_sampler
self.satisfy_all = satisfy_all
self.avoid_no_bbox = avoid_no_bbox
def apply(self, sample, context):
"""
Crop the image and modify bounding box.
Operators:
1. Scale the image width and height.
2. Crop the image according to a radom sample.
3. Rescale the bounding box.
4. Determine if the new bbox is satisfied in the new image.
Returns:
sample: the image, bounding box are replaced.
"""
assert 'image' in sample, "image data not found"
im = sample['image']
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
im_height, im_width = im.shape[:2]
gt_score = None
if 'gt_score' in sample:
gt_score = sample['gt_score']
sampled_bbox = []
gt_bbox = gt_bbox.tolist()
for sampler in self.batch_sampler:
found = 0
for i in range(sampler[1]):
if found >= sampler[0]:
break
sample_bbox = generate_sample_bbox(sampler)
if satisfy_sample_constraint(sampler, sample_bbox, gt_bbox,
self.satisfy_all):
sampled_bbox.append(sample_bbox)
found = found + 1
im = np.array(im)
while sampled_bbox:
idx = int(np.random.uniform(0, len(sampled_bbox)))
sample_bbox = sampled_bbox.pop(idx)
sample_bbox = clip_bbox(sample_bbox)
crop_bbox, crop_class, crop_score = \
filter_and_process(sample_bbox, gt_bbox, gt_class, scores=gt_score)
if self.avoid_no_bbox:
if len(crop_bbox) < 1:
continue
xmin = int(sample_bbox[0] * im_width)
xmax = int(sample_bbox[2] * im_width)
ymin = int(sample_bbox[1] * im_height)
ymax = int(sample_bbox[3] * im_height)
im = im[ymin:ymax, xmin:xmax]
sample['image'] = im
sample['gt_bbox'] = crop_bbox
sample['gt_class'] = crop_class
sample['gt_score'] = crop_score
return sample
return sample
@register_op
class CropWithDataAchorSampling(BaseOperator):
def __init__(self,
batch_sampler,
anchor_sampler=None,
target_size=None,
das_anchor_scales=[16, 32, 64, 128],
sampling_prob=0.5,
min_size=8.,
avoid_no_bbox=True):
"""
Args:
anchor_sampler (list): anchor_sampling sets of different
parameters for cropping.
batch_sampler (list): Multiple sets of different
parameters for cropping.
e.g.[[1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 0.0]]
[[1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]]
[max sample, max trial, min scale, max scale,
min aspect ratio, max aspect ratio,
min overlap, max overlap, min coverage, max coverage]
target_size (bool): target image size.
das_anchor_scales (list[float]): a list of anchor scales in data
anchor smapling.
min_size (float): minimum size of sampled bbox.
avoid_no_bbox (bool): whether to to avoid the
situation where the box does not appear.
"""
super(CropWithDataAchorSampling, self).__init__()
self.anchor_sampler = anchor_sampler
self.batch_sampler = batch_sampler
self.target_size = target_size
self.sampling_prob = sampling_prob
self.min_size = min_size
self.avoid_no_bbox = avoid_no_bbox
self.das_anchor_scales = np.array(das_anchor_scales)
def apply(self, sample, context):
"""
Crop the image and modify bounding box.
Operators:
1. Scale the image width and height.
2. Crop the image according to a radom sample.
3. Rescale the bounding box.
4. Determine if the new bbox is satisfied in the new image.
Returns:
sample: the image, bounding box are replaced.
"""
assert 'image' in sample, "image data not found"
im = sample['image']
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
image_height, image_width = im.shape[:2]
gt_score = None
if 'gt_score' in sample:
gt_score = sample['gt_score']
sampled_bbox = []
gt_bbox = gt_bbox.tolist()
prob = np.random.uniform(0., 1.)
if prob > self.sampling_prob: # anchor sampling
assert self.anchor_sampler
for sampler in self.anchor_sampler:
found = 0
for i in range(sampler[1]):
if found >= sampler[0]:
break
sample_bbox = data_anchor_sampling(
gt_bbox, image_width, image_height,
self.das_anchor_scales, self.target_size)
if sample_bbox == 0:
break
if satisfy_sample_constraint_coverage(sampler, sample_bbox,
gt_bbox):
sampled_bbox.append(sample_bbox)
found = found + 1
im = np.array(im)
while sampled_bbox:
idx = int(np.random.uniform(0, len(sampled_bbox)))
sample_bbox = sampled_bbox.pop(idx)
if 'gt_keypoint' in sample.keys():
keypoints = (sample['gt_keypoint'],
sample['keypoint_ignore'])
crop_bbox, crop_class, crop_score, gt_keypoints = \
filter_and_process(sample_bbox, gt_bbox, gt_class,
scores=gt_score,
keypoints=keypoints)
else:
crop_bbox, crop_class, crop_score = filter_and_process(
sample_bbox, gt_bbox, gt_class, scores=gt_score)
crop_bbox, crop_class, crop_score = bbox_area_sampling(
crop_bbox, crop_class, crop_score, self.target_size,
self.min_size)
if self.avoid_no_bbox:
if len(crop_bbox) < 1:
continue
im = crop_image_sampling(im, sample_bbox, image_width,
image_height, self.target_size)
sample['image'] = im
sample['gt_bbox'] = crop_bbox
sample['gt_class'] = crop_class
sample['gt_score'] = crop_score
if 'gt_keypoint' in sample.keys():
sample['gt_keypoint'] = gt_keypoints[0]
sample['keypoint_ignore'] = gt_keypoints[1]
return sample
return sample
else:
for sampler in self.batch_sampler:
found = 0
for i in range(sampler[1]):
if found >= sampler[0]:
break
sample_bbox = generate_sample_bbox_square(
sampler, image_width, image_height)
if satisfy_sample_constraint_coverage(sampler, sample_bbox,
gt_bbox):
sampled_bbox.append(sample_bbox)
found = found + 1
im = np.array(im)
while sampled_bbox:
idx = int(np.random.uniform(0, len(sampled_bbox)))
sample_bbox = sampled_bbox.pop(idx)
sample_bbox = clip_bbox(sample_bbox)
if 'gt_keypoint' in sample.keys():
keypoints = (sample['gt_keypoint'],
sample['keypoint_ignore'])
crop_bbox, crop_class, crop_score, gt_keypoints = \
filter_and_process(sample_bbox, gt_bbox, gt_class,
scores=gt_score,
keypoints=keypoints)
else:
crop_bbox, crop_class, crop_score = filter_and_process(
sample_bbox, gt_bbox, gt_class, scores=gt_score)
# sampling bbox according the bbox area
crop_bbox, crop_class, crop_score = bbox_area_sampling(
crop_bbox, crop_class, crop_score, self.target_size,
self.min_size)
if self.avoid_no_bbox:
if len(crop_bbox) < 1:
continue
xmin = int(sample_bbox[0] * image_width)
xmax = int(sample_bbox[2] * image_width)
ymin = int(sample_bbox[1] * image_height)
ymax = int(sample_bbox[3] * image_height)
im = im[ymin:ymax, xmin:xmax]
sample['image'] = im
sample['gt_bbox'] = crop_bbox
sample['gt_class'] = crop_class
sample['gt_score'] = crop_score
if 'gt_keypoint' in sample.keys():
sample['gt_keypoint'] = gt_keypoints[0]
sample['keypoint_ignore'] = gt_keypoints[1]
return sample
return sample
@register_op
class RandomCropOp(BaseOperator):
"""Random crop image and bboxes.
Args:
aspect_ratio (list): aspect ratio of cropped region.
in [min, max] format.
thresholds (list): iou thresholds for decide a valid bbox crop.
scaling (list): ratio between a cropped region and the original image.
in [min, max] format.
num_attempts (int): number of tries before giving up.
allow_no_crop (bool): allow return without actually cropping them.
cover_all_box (bool): ensure all bboxes are covered in the final crop.
is_mask_crop(bool): whether crop the segmentation.
"""
def __init__(self,
aspect_ratio=[.5, 2.],
thresholds=[.0, .1, .3, .5, .7, .9],
scaling=[.3, 1.],
num_attempts=50,
allow_no_crop=True,
cover_all_box=False,
is_mask_crop=False):
super(RandomCropOp, self).__init__()
self.aspect_ratio = aspect_ratio
self.thresholds = thresholds
self.scaling = scaling
self.num_attempts = num_attempts
self.allow_no_crop = allow_no_crop
self.cover_all_box = cover_all_box
self.is_mask_crop = is_mask_crop
def crop_segms(self, segms, valid_ids, crop, height, width):
def _crop_poly(segm, crop):
xmin, ymin, xmax, ymax = crop
crop_coord = [xmin, ymin, xmin, ymax, xmax, ymax, xmax, ymin]
crop_p = np.array(crop_coord).reshape(4, 2)
crop_p = Polygon(crop_p)
crop_segm = list()
for poly in segm:
poly = np.array(poly).reshape(len(poly) // 2, 2)
polygon = Polygon(poly)
if not polygon.is_valid:
exterior = polygon.exterior
multi_lines = exterior.intersection(exterior)
polygons = shapely.ops.polygonize(multi_lines)
polygon = MultiPolygon(polygons)
multi_polygon = list()
if isinstance(polygon, MultiPolygon):
multi_polygon = copy.deepcopy(polygon)
else:
multi_polygon.append(copy.deepcopy(polygon))
for per_polygon in multi_polygon:
inter = per_polygon.intersection(crop_p)
if not inter:
continue
if isinstance(inter, (MultiPolygon, GeometryCollection)):
for part in inter:
if not isinstance(part, Polygon):
continue
part = np.squeeze(
np.array(part.exterior.coords[:-1]).reshape(1,
-1))
part[0::2] -= xmin
part[1::2] -= ymin
crop_segm.append(part.tolist())
elif isinstance(inter, Polygon):
crop_poly = np.squeeze(
np.array(inter.exterior.coords[:-1]).reshape(1, -1))
crop_poly[0::2] -= xmin
crop_poly[1::2] -= ymin
crop_segm.append(crop_poly.tolist())
else:
continue
return crop_segm
def _crop_rle(rle, crop, height, width):
if 'counts' in rle and type(rle['counts']) == list:
rle = mask_util.frPyObjects(rle, height, width)
mask = mask_util.decode(rle)
mask = mask[crop[1]:crop[3], crop[0]:crop[2]]
rle = mask_util.encode(np.array(mask, order='F', dtype=np.uint8))
return rle
crop_segms = []
for id in valid_ids:
segm = segms[id]
if is_poly(segm):
import copy
import shapely.ops
from shapely.geometry import Polygon, MultiPolygon, GeometryCollection
logging.getLogger("shapely").setLevel(logging.WARNING)
# Polygon format
crop_segms.append(_crop_poly(segm, crop))
else:
# RLE format
import pycocotools.mask as mask_util
crop_segms.append(_crop_rle(segm, crop, height, width))
return crop_segms
def apply(self, sample, context=None):
if 'gt_bbox' in sample and len(sample['gt_bbox']) == 0:
return sample
h, w = sample['image'].shape[:2]
gt_bbox = sample['gt_bbox']
# NOTE Original method attempts to generate one candidate for each
# threshold then randomly sample one from the resulting list.
# Here a short circuit approach is taken, i.e., randomly choose a
# threshold and attempt to find a valid crop, and simply return the
# first one found.
# The probability is not exactly the same, kinda resembling the
# "Monty Hall" problem. Actually carrying out the attempts will affect
# observability (just like opening doors in the "Monty Hall" game).
thresholds = list(self.thresholds)
if self.allow_no_crop:
thresholds.append('no_crop')
np.random.shuffle(thresholds)
for thresh in thresholds:
if thresh == 'no_crop':
return sample
found = False
for i in range(self.num_attempts):
scale = np.random.uniform(*self.scaling)
if self.aspect_ratio is not None:
min_ar, max_ar = self.aspect_ratio
aspect_ratio = np.random.uniform(
max(min_ar, scale**2), min(max_ar, scale**-2))
h_scale = scale / np.sqrt(aspect_ratio)
w_scale = scale * np.sqrt(aspect_ratio)
else:
h_scale = np.random.uniform(*self.scaling)
w_scale = np.random.uniform(*self.scaling)
crop_h = h * h_scale
crop_w = w * w_scale
if self.aspect_ratio is None:
if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0:
continue
crop_h = int(crop_h)
crop_w = int(crop_w)
crop_y = np.random.randint(0, h - crop_h)
crop_x = np.random.randint(0, w - crop_w)
crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
iou = self._iou_matrix(
gt_bbox, np.array(
[crop_box], dtype=np.float32))
if iou.max() < thresh:
continue
if self.cover_all_box and iou.min() < thresh:
continue
cropped_box, valid_ids = self._crop_box_with_center_constraint(
gt_bbox, np.array(
crop_box, dtype=np.float32))
if valid_ids.size > 0:
found = True
break
if found:
if self.is_mask_crop and 'gt_poly' in sample and len(sample[
'gt_poly']) > 0:
crop_polys = self.crop_segms(
sample['gt_poly'],
valid_ids,
np.array(
crop_box, dtype=np.int64),
h,
w)
if [] in crop_polys:
delete_id = list()
valid_polys = list()
for id, crop_poly in enumerate(crop_polys):
if crop_poly == []:
delete_id.append(id)
else:
valid_polys.append(crop_poly)
valid_ids = np.delete(valid_ids, delete_id)
if len(valid_polys) == 0:
return sample
sample['gt_poly'] = valid_polys
else:
sample['gt_poly'] = crop_polys
if 'gt_segm' in sample:
sample['gt_segm'] = self._crop_segm(sample['gt_segm'],
crop_box)
sample['gt_segm'] = np.take(
sample['gt_segm'], valid_ids, axis=0)
sample['image'] = self._crop_image(sample['image'], crop_box)
sample['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0)
sample['gt_class'] = np.take(
sample['gt_class'], valid_ids, axis=0)
if 'gt_score' in sample:
sample['gt_score'] = np.take(
sample['gt_score'], valid_ids, axis=0)
if 'is_crowd' in sample:
sample['is_crowd'] = np.take(
sample['is_crowd'], valid_ids, axis=0)
return sample
return sample
def _iou_matrix(self, a, b):
tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
area_o = (area_a[:, np.newaxis] + area_b - area_i)
return area_i / (area_o + 1e-10)
def _crop_box_with_center_constraint(self, box, crop):
cropped_box = box.copy()
cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2])
cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:])
cropped_box[:, :2] -= crop[:2]
cropped_box[:, 2:] -= crop[:2]
centers = (box[:, :2] + box[:, 2:]) / 2
valid = np.logical_and(crop[:2] <= centers,
centers < crop[2:]).all(axis=1)
valid = np.logical_and(
valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
return cropped_box, np.where(valid)[0]
def _crop_image(self, img, crop):
x1, y1, x2, y2 = crop
return img[y1:y2, x1:x2, :]
def _crop_segm(self, segm, crop):
x1, y1, x2, y2 = crop
return segm[:, y1:y2, x1:x2]
@register_op
class RandomScaledCropOp(BaseOperator):
"""Resize image and bbox based on long side (with optional random scaling),
then crop or pad image to target size.
Args:
target_dim (int): target size.
scale_range (list): random scale range.
interp (int): interpolation method, default to `cv2.INTER_LINEAR`.
"""
def __init__(self,
target_dim=512,
scale_range=[.1, 2.],
interp=cv2.INTER_LINEAR):
super(RandomScaledCropOp, self).__init__()
self.target_dim = target_dim
self.scale_range = scale_range
self.interp = interp
def apply(self, sample, context=None):
img = sample['image']
h, w = img.shape[:2]
random_scale = np.random.uniform(*self.scale_range)
dim = self.target_dim
random_dim = int(dim * random_scale)
dim_max = max(h, w)
scale = random_dim / dim_max
resize_w = int(round(w * scale))
resize_h = int(round(h * scale))
offset_x = int(max(0, np.random.uniform(0., resize_w - dim)))
offset_y = int(max(0, np.random.uniform(0., resize_h - dim)))
img = cv2.resize(img, (resize_w, resize_h), interpolation=self.interp)
img = np.array(img)
canvas = np.zeros((dim, dim, 3), dtype=img.dtype)
canvas[:min(dim, resize_h), :min(dim, resize_w), :] = img[
offset_y:offset_y + dim, offset_x:offset_x + dim, :]
sample['image'] = canvas
sample['im_shape'] = [resize_h, resize_w]
scale_factor = sample['sacle_factor']
sample['scale_factor'] = [
scale_factor[0] * scale, scale_factor[1] * scale
]
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
scale_array = np.array([scale, scale] * 2, dtype=np.float32)
shift_array = np.array([offset_x, offset_y] * 2, dtype=np.float32)
boxes = sample['gt_bbox'] * scale_array - shift_array
boxes = np.clip(boxes, 0, dim - 1)
# filter boxes with no area
area = np.prod(boxes[..., 2:] - boxes[..., :2], axis=1)
valid = (area > 1.).nonzero()[0]
sample['gt_bbox'] = boxes[valid]
sample['gt_class'] = sample['gt_class'][valid]
return sample
@register_op
class CutmixOp(BaseOperator):
def __init__(self, alpha=1.5, beta=1.5):
"""
CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features, see https://arxiv.org/abs/1905.04899
Cutmix image and gt_bbbox/gt_score
Args:
alpha (float): alpha parameter of beta distribute
beta (float): beta parameter of beta distribute
"""
super(CutmixOp, self).__init__()
self.alpha = alpha
self.beta = beta
if self.alpha <= 0.0:
raise ValueError("alpha shold be positive in {}".format(self))
if self.beta <= 0.0:
raise ValueError("beta shold be positive in {}".format(self))
def apply_image(self, img1, img2, factor):
""" _rand_bbox """
h = max(img1.shape[0], img2.shape[0])
w = max(img1.shape[1], img2.shape[1])
cut_rat = np.sqrt(1. - factor)
cut_w = np.int(w * cut_rat)
cut_h = np.int(h * cut_rat)
# uniform
cx = np.random.randint(w)
cy = np.random.randint(h)
bbx1 = np.clip(cx - cut_w // 2, 0, w - 1)
bby1 = np.clip(cy - cut_h // 2, 0, h - 1)
bbx2 = np.clip(cx + cut_w // 2, 0, w - 1)
bby2 = np.clip(cy + cut_h // 2, 0, h - 1)
img_1 = np.zeros((h, w, img1.shape[2]), 'float32')
img_1[:img1.shape[0], :img1.shape[1], :] = \
img1.astype('float32')
img_2 = np.zeros((h, w, img2.shape[2]), 'float32')
img_2[:img2.shape[0], :img2.shape[1], :] = \
img2.astype('float32')
img_1[bby1:bby2, bbx1:bbx2, :] = img2[bby1:bby2, bbx1:bbx2, :]
return img_1
def __call__(self, sample, context=None):
if not isinstance(sample, Sequence):
return sample
assert len(sample) == 2, 'cutmix need two samples'
factor = np.random.beta(self.alpha, self.beta)
factor = max(0.0, min(1.0, factor))
if factor >= 1.0:
return sample[0]
if factor <= 0.0:
return sample[1]
img1 = sample[0]['image']
img2 = sample[1]['image']
img = self.apply_image(img1, img2, factor)
gt_bbox1 = sample[0]['gt_bbox']
gt_bbox2 = sample[1]['gt_bbox']
gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
gt_class1 = sample[0]['gt_class']
gt_class2 = sample[1]['gt_class']
gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
gt_score1 = sample[0]['gt_score']
gt_score2 = sample[1]['gt_score']
gt_score = np.concatenate(
(gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
sample = sample[0]
sample['image'] = img
sample['gt_bbox'] = gt_bbox
sample['gt_score'] = gt_score
sample['gt_class'] = gt_class
return sample
@register_op
class MixupOp(BaseOperator):
def __init__(self, alpha=1.5, beta=1.5):
""" Mixup image and gt_bbbox/gt_score
Args:
alpha (float): alpha parameter of beta distribute
beta (float): beta parameter of beta distribute
"""
super(MixupOp, self).__init__()
self.alpha = alpha
self.beta = beta
if self.alpha <= 0.0:
raise ValueError("alpha shold be positive in {}".format(self))
if self.beta <= 0.0:
raise ValueError("beta shold be positive in {}".format(self))
def apply_image(self, img1, img2, factor):
h = max(img1.shape[0], img2.shape[0])
w = max(img1.shape[1], img2.shape[1])
img = np.zeros((h, w, img1.shape[2]), 'float32')
img[:img1.shape[0], :img1.shape[1], :] = \
img1.astype('float32') * factor
img[:img2.shape[0], :img2.shape[1], :] += \
img2.astype('float32') * (1.0 - factor)
return img.astype('uint8')
def __call__(self, sample, context=None):
if not isinstance(sample, Sequence):
return sample
assert len(sample) == 2, 'mixup need two samples'
factor = np.random.beta(self.alpha, self.beta)
factor = max(0.0, min(1.0, factor))
if factor >= 1.0:
return sample[0]
if factor <= 0.0:
return sample[1]
im = self.apply_image(sample[0]['image'], sample[1]['image'], factor)
# apply bbox and score
gt_bbox1 = sample[0]['gt_bbox']
gt_bbox2 = sample[1]['gt_bbox']
gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
gt_class1 = sample[0]['gt_class']
gt_class2 = sample[1]['gt_class']
gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
gt_score1 = sample[0]['gt_score']
gt_score2 = sample[1]['gt_score']
gt_score = np.concatenate(
(gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
is_crowd1 = sample[0]['is_crowd']
is_crowd2 = sample[1]['is_crowd']
is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
sample = sample[0]
sample['image'] = im
sample['gt_bbox'] = gt_bbox
sample['gt_score'] = gt_score
sample['gt_class'] = gt_class
sample['is_crowd'] = is_crowd
return sample
@register_op
class NormalizeBoxOp(BaseOperator):
"""Transform the bounding box's coornidates to [0,1]."""
def __init__(self):
super(NormalizeBoxOp, self).__init__()
def apply(self, sample, context):
im = sample['image']
gt_bbox = sample['gt_bbox']
height, width, _ = im.shape
for i in range(gt_bbox.shape[0]):
gt_bbox[i][0] = gt_bbox[i][0] / width
gt_bbox[i][1] = gt_bbox[i][1] / height
gt_bbox[i][2] = gt_bbox[i][2] / width
gt_bbox[i][3] = gt_bbox[i][3] / height
sample['gt_bbox'] = gt_bbox
if 'gt_keypoint' in sample.keys():
gt_keypoint = sample['gt_keypoint']
for i in range(gt_keypoint.shape[1]):
if i % 2:
gt_keypoint[:, i] = gt_keypoint[:, i] / height
else:
gt_keypoint[:, i] = gt_keypoint[:, i] / width
sample['gt_keypoint'] = gt_keypoint
return sample
@register_op
class BboxXYXY2XYWH(BaseOperator):
"""
Convert bbox XYXY format to XYWH format.
"""
def __init__(self):
super(BboxXYXY2XYWH, self).__init__()
def apply(self, sample, context=None):
assert 'gt_bbox' in sample
bbox = sample['gt_bbox']
bbox[:, 2:4] = bbox[:, 2:4] - bbox[:, :2]
bbox[:, :2] = bbox[:, :2] + bbox[:, 2:4] / 2.
sample['gt_bbox'] = bbox
return sample
@register_op
class DebugVisibleImageOp(BaseOperator):
"""
In debug mode, visualize images according to `gt_box`.
(Currently only supported when not cropping and flipping image.)
"""
def __init__(self, output_dir='output/debug', is_normalized=False):
super(DebugVisibleImageOp, self).__init__()
self.is_normalized = is_normalized
self.output_dir = output_dir
if not os.path.isdir(output_dir):
os.makedirs(output_dir)
if not isinstance(self.is_normalized, bool):
raise TypeError("{}: input type is invalid.".format(self))
def apply(self, sample, context=None):
image = Image.open(sample['im_file']).convert('RGB')
out_file_name = sample['im_file'].split('/')[-1]
width = sample['w']
height = sample['h']
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
draw = ImageDraw.Draw(image)
for i in range(gt_bbox.shape[0]):
if self.is_normalized:
gt_bbox[i][0] = gt_bbox[i][0] * width
gt_bbox[i][1] = gt_bbox[i][1] * height
gt_bbox[i][2] = gt_bbox[i][2] * width
gt_bbox[i][3] = gt_bbox[i][3] * height
xmin, ymin, xmax, ymax = gt_bbox[i]
draw.line(
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
(xmin, ymin)],
width=2,
fill='green')
# draw label
text = str(gt_class[i][0])
tw, th = draw.textsize(text)
draw.rectangle(
[(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill='green')
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
if 'gt_keypoint' in sample.keys():
gt_keypoint = sample['gt_keypoint']
if self.is_normalized:
for i in range(gt_keypoint.shape[1]):
if i % 2:
gt_keypoint[:, i] = gt_keypoint[:, i] * height
else:
gt_keypoint[:, i] = gt_keypoint[:, i] * width
for i in range(gt_keypoint.shape[0]):
keypoint = gt_keypoint[i]
for j in range(int(keypoint.shape[0] / 2)):
x1 = round(keypoint[2 * j]).astype(np.int32)
y1 = round(keypoint[2 * j + 1]).astype(np.int32)
draw.ellipse(
(x1, y1, x1 + 5, y1 + 5), fill='green', outline='green')
save_path = os.path.join(self.output_dir, out_file_name)
image.save(save_path, quality=95)
return sample
@register_op
class Pad(BaseOperator):
def __init__(self,
size=None,
size_divisor=32,
pad_mode=0,
offsets=None,
fill_value=(127.5, 127.5, 127.5)):
"""
Pad image to a specified size or multiple of size_divisor. random target_size and interpolation method
Args:
size (int, Sequence): image target size, if None, pad to multiple of size_divisor, default None
size_divisor (int): size divisor, default 32
pad_mode (int): pad mode, currently only supports four modes [-1, 0, 1, 2]. if -1, use specified offsets
if 0, only pad to right and bottom. if 1, pad according to center. if 2, only pad left and top
fill_value (bool): rgb value of pad area, default (127.5, 127.5, 127.5)
"""
super(Pad, self).__init__()
if not isinstance(size, (int, Sequence)):
raise TypeError(
"Type of target_size is invalid when random_size is True. \
Must be List, now is {}".format(type(size)))
if isinstance(size, int):
size = [size, size]
assert pad_mode in [
-1, 0, 1, 2
], 'currently only supports four modes [-1, 0, 1, 2]'
assert pad_mode == -1 and offsets, 'if pad_mode is -1, offsets should not be None'
self.size = size
self.size_divisor = size_divisor
self.pad_mode = pad_mode
self.fill_value = fill_value
self.offsets = offsets
def apply_segm(self, segms, offsets, im_size, size):
def _expand_poly(poly, x, y):
expanded_poly = np.array(poly)
expanded_poly[0::2] += x
expanded_poly[1::2] += y
return expanded_poly.tolist()
def _expand_rle(rle, x, y, height, width, h, w):
if 'counts' in rle and type(rle['counts']) == list:
rle = mask_util.frPyObjects(rle, height, width)
mask = mask_util.decode(rle)
expanded_mask = np.full((h, w), 0).astype(mask.dtype)
expanded_mask[y:y + height, x:x + width] = mask
rle = mask_util.encode(
np.array(
expanded_mask, order='F', dtype=np.uint8))
return rle
x, y = offsets
height, width = im_size
h, w = size
expanded_segms = []
for segm in segms:
if is_poly(segm):
# Polygon format
expanded_segms.append(
[_expand_poly(poly, x, y) for poly in segm])
else:
# RLE format
import pycocotools.mask as mask_util
expanded_segms.append(
_expand_rle(segm, x, y, height, width, h, w))
return expanded_segms
def apply_bbox(self, bbox, offsets):
return bbox + np.array(offsets * 2, dtype=np.float32)
def apply_keypoint(self, keypoints, offsets):
n = len(keypoints[0]) // 2
return keypoints + np.array(offsets * n, dtype=np.float32)
def apply_image(self, image, offsets, im_size, size):
x, y = offsets
im_h, im_w = im_size
h, w = size
canvas = np.ones((h, w, 3), dtype=np.uint8)
canvas *= np.array(self.fill_value, dtype=np.uint8)
canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.uint8)
return canvas
def apply(self, sample, context=None):
im = sample['image']
im_h, im_w = im.shape[:2]
if self.size:
h, w = self.size
assert (
im_h < h and im_w < w
), '(h, w) of target size should be greater than (im_h, im_w)'
else:
h = np.ceil(im_h // self.size_divisor) * self.size_divisor
w = np.ceil(im_w / self.size_divisor) * self.size_divisor
if h == im_h and w == im_w:
return sample
if self.pad_mode == -1:
offset_x, offset_y = self.offsets
elif self.pad_mode == 0:
offset_y, offset_x = 0, 0
elif self.pad_mode == 1:
offset_y, offset_x = (h - im_h) // 2, (w - im_w) // 2
else:
offset_y, offset_x = h - im_h, w - im_w
offsets, im_size, size = [offset_x, offset_y], [im_h, im_w], [h, w]
sample['image'] = self.apply_image(im, offsets, im_size, size)
if self.pad_mode == 0:
return sample
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], offsets)
if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
sample['gt_poly'] = self.apply_segm(sample['gt_poly'], offsets,
im_size, size)
if 'gt_keypoint' in sample and len(sample['gt_keypoint']) > 0:
sample['gt_keypoint'] = self.apply_keypoint(sample['gt_keypoint'],
offsets)
return sample
@register_op
class Poly2Mask(BaseOperator):
"""
gt poly to mask annotations
"""
def __init__(self):
super(Poly2Mask, self).__init__()
import pycocotools.mask as maskUtils
self.maskutils = maskUtils
def _poly2mask(self, mask_ann, img_h, img_w):
if isinstance(mask_ann, list):
# polygon -- a single object might consist of multiple parts
# we merge all parts into one mask rle code
rles = self.maskutils.frPyObjects(mask_ann, img_h, img_w)
rle = self.maskutils.merge(rles)
elif isinstance(mask_ann['counts'], list):
# uncompressed RLE
rle = self.maskutils.frPyObjects(mask_ann, img_h, img_w)
else:
# rle
rle = mask_ann
mask = self.maskutils.decode(rle)
return mask
def apply(self, sample, context=None):
assert 'gt_poly' in sample
im_h = sample['h']
im_w = sample['w']
masks = [
self._poly2mask(gt_poly, im_h, im_w)
for gt_poly in sample['gt_poly']
]
sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
return sample
......@@ -39,6 +39,7 @@ from PIL import Image, ImageEnhance, ImageDraw
from ppdet.core.workspace import serializable
from ppdet.modeling.layers import AnchorGrid
from .operator import register_op, BaseOperator, BboxError, ImageError
from .op_helper import (satisfy_sample_constraint, filter_and_process,
generate_sample_bbox, clip_bbox, data_anchor_sampling,
......@@ -48,45 +49,6 @@ from .op_helper import (satisfy_sample_constraint, filter_and_process,
logger = logging.getLogger(__name__)
registered_ops = []
def register_op(cls):
registered_ops.append(cls.__name__)
if not hasattr(BaseOperator, cls.__name__):
setattr(BaseOperator, cls.__name__, cls)
else:
raise KeyError("The {} class has been registered.".format(cls.__name__))
return serializable(cls)
class BboxError(ValueError):
pass
class ImageError(ValueError):
pass
class BaseOperator(object):
def __init__(self, name=None):
if name is None:
name = self.__class__.__name__
self._id = name + '_' + str(uuid.uuid4())[-6:]
def __call__(self, sample, context=None):
""" Process a sample.
Args:
sample (dict): a dict of sample, eg: {'image':xx, 'label': xxx}
context (dict): info about this sample processing
Returns:
result (dict): a processed sample
"""
return sample
def __str__(self):
return str(self._id)
@register_op
class DecodeImage(BaseOperator):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册