未验证 提交 1cae8144 编写于 作者: W wangxinxin08 提交者: GitHub

add mosaic data augmentation (#3185)

上级 bc3a1145
......@@ -462,3 +462,62 @@ def gaussian2D(shape, sigma_x=1, sigma_y=1):
sigma_y)))
h[h < np.finfo(h.dtype).eps * h.max()] = 0
return h
def transform_bbox(sample,
M,
w,
h,
area_thr=0.25,
wh_thr=2,
ar_thr=20,
perspective=False):
"""
transfrom bbox according to tranformation matrix M,
refer to https://github.com/ultralytics/yolov5/blob/develop/utils/datasets.py
"""
bbox = sample['gt_bbox']
label = sample['gt_class']
# rotate bbox
n = len(bbox)
xy = np.ones((n * 4, 3), dtype=np.float32)
xy[:, :2] = bbox[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)
# xy = xy @ M.T
xy = np.matmul(xy, M.T)
if perspective:
xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)
else:
xy = xy[:, :2].reshape(n, 8)
# get new bboxes
x = xy[:, [0, 2, 4, 6]]
y = xy[:, [1, 3, 5, 7]]
bbox = np.concatenate(
(x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
# clip boxes
mask = filter_bbox(bbox, w, h, area_thr)
sample['gt_bbox'] = bbox[mask]
sample['gt_class'] = sample['gt_class'][mask]
if 'is_crowd' in sample:
sample['is_crowd'] = sample['is_crowd'][mask]
if 'difficult' in sample:
sample['difficult'] = sample['difficult'][mask]
return sample
def filter_bbox(bbox, w, h, area_thr=0.25, wh_thr=2, ar_thr=20):
"""
filter bbox, refer to https://github.com/ultralytics/yolov5/blob/develop/utils/datasets.py
"""
# clip boxes
area1 = (bbox[:, 2:4] - bbox[:, 0:2]).prod(1)
bbox[:, [0, 2]] = bbox[:, [0, 2]].clip(0, w)
bbox[:, [1, 3]] = bbox[:, [1, 3]].clip(0, h)
# compute
area2 = (bbox[:, 2:4] - bbox[:, 0:2]).prod(1)
area_ratio = area2 / (area1 + 1e-16)
wh = bbox[:, 2:4] - bbox[:, 0:2]
ar_ratio = np.maximum(wh[:, 1] / (wh[:, 0] + 1e-16),
wh[:, 0] / (wh[:, 1] + 1e-16))
mask = (area_ratio > area_thr) & (
(wh > wh_thr).all(1)) & (ar_ratio < ar_thr)
return mask
......@@ -45,7 +45,7 @@ from .op_helper import (satisfy_sample_constraint, filter_and_process,
generate_sample_bbox, clip_bbox, data_anchor_sampling,
satisfy_sample_constraint_coverage, crop_image_sampling,
generate_sample_bbox_square, bbox_area_sampling,
is_poly, gaussian_radius, draw_gaussian)
is_poly, gaussian_radius, draw_gaussian, transform_bbox)
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
......@@ -1767,8 +1767,8 @@ class DebugVisibleImage(BaseOperator):
raise TypeError("{}: input type is invalid.".format(self))
def apply(self, sample, context=None):
image = Image.open(sample['im_file']).convert('RGB')
out_file_name = sample['im_file'].split('/')[-1]
image = Image.fromarray(sample['image'].astype(np.uint8))
out_file_name = '{:012d}.jpg'.format(sample['im_id'][0])
width = sample['w']
height = sample['h']
gt_bbox = sample['gt_bbox']
......@@ -2348,5 +2348,183 @@ class RandomResizeCrop(BaseOperator):
for gt_segm in sample['gt_segm']
]
sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
return sample
class RandomPerspective(BaseOperator):
"""
Rotate, tranlate, scale, shear and perspect image and bboxes randomly,
refer to https://github.com/ultralytics/yolov5/blob/develop/utils/datasets.py
Args:
degree (int): rotation degree, uniformly sampled in [-degree, degree]
translate (float): translate fraction, translate_x and translate_y are uniformly sampled
in [0.5 - translate, 0.5 + translate]
scale (float): scale factor, uniformly sampled in [1 - scale, 1 + scale]
shear (int): shear degree, shear_x and shear_y are uniformly sampled in [-shear, shear]
perspective (float): perspective_x and perspective_y are uniformly sampled in [-perspective, perspective]
area_thr (float): the area threshold of bbox to be kept after transformation, default 0.25
fill_value (tuple): value used in case of a constant border, default (114, 114, 114)
"""
def __init__(self,
degree=10,
translate=0.1,
scale=0.1,
shear=10,
perspective=0.0,
border=[0, 0],
area_thr=0.25,
fill_value=(114, 114, 114)):
super(RandomPerspective, self).__init__()
self.degree = degree
self.translate = translate
self.scale = scale
self.shear = shear
self.perspective = perspective
self.border = border
self.area_thr = area_thr
self.fill_value = fill_value
def apply(self, sample, context=None):
im = sample['image']
height = im.shape[0] + self.border[0] * 2
width = im.shape[1] + self.border[1] * 2
# center
C = np.eye(3)
C[0, 2] = -im.shape[1] / 2
C[1, 2] = -im.shape[0] / 2
# perspective
P = np.eye(3)
P[2, 0] = random.uniform(-self.perspective, self.perspective)
P[2, 1] = random.uniform(-self.perspective, self.perspective)
# Rotation and scale
R = np.eye(3)
a = random.uniform(-self.degree, self.degree)
s = random.uniform(1 - self.scale, 1 + self.scale)
R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
# Shear
S = np.eye(3)
# shear x (deg)
S[0, 1] = math.tan(
random.uniform(-self.shear, self.shear) * math.pi / 180)
# shear y (deg)
S[1, 0] = math.tan(
random.uniform(-self.shear, self.shear) * math.pi / 180)
# Translation
T = np.eye(3)
T[0, 2] = random.uniform(0.5 - self.translate,
0.5 + self.translate) * width
T[1, 2] = random.uniform(0.5 - self.translate,
0.5 + self.translate) * height
# matmul
# M = T @ S @ R @ P @ C
M = np.eye(3)
for cM in [T, S, R, P, C]:
M = np.matmul(M, cM)
if (self.border[0] != 0) or (self.border[1] != 0) or (
M != np.eye(3)).any():
if self.perspective:
im = cv2.warpPerspective(
im, M, dsize=(width, height), borderValue=self.fill_value)
else:
im = cv2.warpAffine(
im,
M[:2],
dsize=(width, height),
borderValue=self.fill_value)
sample['image'] = im
if sample['gt_bbox'].shape[0] > 0:
sample = transform_bbox(
sample,
M,
width,
height,
area_thr=self.area_thr,
perspective=self.perspective)
return sample
@register_op
class Mosaic(BaseOperator):
"""
Mosaic Data Augmentation, refer to https://github.com/ultralytics/yolov5/blob/develop/utils/datasets.py
"""
def __init__(self,
target_size,
mosaic_border=None,
fill_value=(114, 114, 114)):
super(Mosaic, self).__init__()
self.target_size = target_size
if mosaic_border is None:
mosaic_border = (-target_size // 2, -target_size // 2)
self.mosaic_border = mosaic_border
self.fill_value = fill_value
def __call__(self, sample, context=None):
if not isinstance(sample, Sequence):
return sample
s = self.target_size
yc, xc = [
int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border
]
boxes = [x['gt_bbox'] for x in sample]
labels = [x['gt_class'] for x in sample]
for i in range(len(sample)):
im = sample[i]['image']
h, w, c = im.shape
if i == 0: # top left
image = np.ones(
(s * 2, s * 2, c), dtype=np.uint8) * self.fill_value
# xmin, ymin, xmax, ymax (dst image)
x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc
# xmin, ymin, xmax, ymax (src image)
x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h
elif i == 1: # top right
x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
elif i == 2: # bottom left
x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, max(xc, w), min(
y2a - y1a, h)
elif i == 3: # bottom right
x1a, y1a, x2a, y2a = xc, yc, min(xc + w,
s * 2), min(s * 2, yc + h)
x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
image[y1a:y2a, x1a:x2a] = im[y1b:y2b, x1b:x2b]
padw = x1a - x1b
padh = y1a - y1b
boxes[i] = boxes[i] + (padw, padh, padw, padh)
boxes = np.concatenate(boxes, axis=0)
boxes = np.clip(boxes, 0, s * 2)
labels = np.concatenate(labels, axis=0)
if 'is_crowd' in sample[0]:
is_crowd = np.concatenate([x['is_crowd'] for x in sample], axis=0)
if 'difficult' in sample[0]:
difficult = np.concatenate([x['difficult'] for x in sample], axis=0)
sample = sample[0]
sample['image'] = image.astype(np.uint8)
sample['gt_bbox'] = boxes
sample['gt_class'] = labels
if 'is_crowd' in sample:
sample['is_crowd'] = is_crowd
if 'difficult' in sample:
sample['difficult'] = difficult
return sample
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册