未验证 提交 576b06f8 编写于 作者: Y Yang Zhang 提交者: GitHub

Optimize data transforms for yolo training (#28)

* Optimize data transforms for yolo training

* Simplify and add docstring
上级 3ff10601
......@@ -27,7 +27,8 @@ from ppdet.data.reader import Reader
from ppdet.data.transform.operators import (
DecodeImage, MixupImage, NormalizeBox, NormalizeImage, RandomDistort,
RandomFlipImage, RandomInterpImage, ResizeImage, ExpandImage, CropImage,
Permute, MultiscaleTestResize)
Permute, MultiscaleTestResize, Resize, ColorDistort, NormalizePermute,
RandomExpand, RandomCrop)
from ppdet.data.transform.arrange_sample import (
ArrangeRCNN, ArrangeEvalRCNN, ArrangeTestRCNN, ArrangeSSD, ArrangeEvalSSD,
ArrangeTestSSD, ArrangeYOLO, ArrangeEvalYOLO, ArrangeTestYOLO)
......@@ -195,7 +196,7 @@ class RandomShape(object):
class PadMSTest(object):
"""
Padding for multi-scale test
Args:
pad_to_stride (int): pad to multiple of strides, e.g., 32
"""
......@@ -896,25 +897,15 @@ class YoloTrainFeed(DataFeed):
sample_transforms=[
DecodeImage(to_rgb=True, with_mixup=True),
MixupImage(alpha=1.5, beta=1.5),
ColorDistort(),
RandomExpand(fill_value=[123.675, 116.28, 103.53]),
RandomCrop(),
RandomFlipImage(is_normalized=False),
Resize(target_dim=608, interp='random'),
NormalizePermute(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.120, 57.375]),
NormalizeBox(),
RandomDistort(),
ExpandImage(max_ratio=4., prob=.5,
mean=[123.675, 116.28, 103.53]),
CropImage([[1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0]]),
RandomInterpImage(target_size=608),
RandomFlipImage(is_normalized=True),
NormalizeImage(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
is_scale=True,
is_channel_first=False),
Permute(to_bgr=False),
],
batch_transforms=[
RandomShape(sizes=[
......@@ -1010,6 +1001,8 @@ class YoloEvalFeed(DataFeed):
sample_transforms[i] = ResizeImage(
target_size=self.image_shape[-1],
interp=trans.interp)
if isinstance(trans, Resize):
sample_transforms[i].target_dim = self.image_shape[-1]
@register
......@@ -1066,4 +1059,6 @@ class YoloTestFeed(DataFeed):
sample_transforms[i] = ResizeImage(
target_size=self.image_shape[-1],
interp=trans.interp)
if isinstance(trans, Resize):
sample_transforms[i].target_dim = self.image_shape[-1]
# yapf: enable
......@@ -20,6 +20,13 @@ from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
try:
from collections.abc import Sequence
except Exception:
from collections import Sequence
from numbers import Number
import uuid
import logging
import random
......@@ -234,11 +241,11 @@ class ResizeImage(BaseOperator):
target size.
Args:
target_size (int|list): the target size of image's short side,
target_size (int|list): the target size of image's short side,
multi-scale training is adopted when type is list.
max_size (int): the max size of image
interp (int): the interpolation method
use_cv2 (bool): use the cv2 interpolation method or use PIL
use_cv2 (bool): use the cv2 interpolation method or use PIL
interpolation method
"""
super(ResizeImage, self).__init__()
......@@ -642,7 +649,7 @@ class CropImage(BaseOperator):
[max sample, max trial, min scale, max scale,
min aspect ratio, max aspect ratio,
min overlap, max overlap]
avoid_no_bbox (bool): whether to to avoid the
avoid_no_bbox (bool): whether to to avoid the
situation where the box does not appear.
"""
super(CropImage, self).__init__()
......@@ -989,3 +996,367 @@ class RandomInterpImage(BaseOperator):
"""Resise the image numpy by random resizer."""
resizer = random.choice(self.resizers)
return resizer(sample, context)
@register_op
class Resize(BaseOperator):
"""Resize image and bbox.
Args:
target_dim (int or list): target size, can be a single number or a list
(for random shape).
interp (int or str): interpolation method, can be an integer or
'random' (for randomized interpolation).
default to `cv2.INTER_LINEAR`.
"""
def __init__(self,
target_dim=[],
interp=cv2.INTER_LINEAR):
super(Resize, self).__init__()
self.target_dim = target_dim
self.interp = interp # 'random' for yolov3
def __call__(self, sample, context=None):
w = sample['w']
h = sample['h']
interp = self.interp
if interp == 'random':
interp = np.random.choice(range(5))
if isinstance(self.target_dim, Sequence):
dim = np.random.choice(self.target_dim)
else:
dim = self.target_dim
resize_w = resize_h = dim
scale_x = dim / w
scale_y = dim / h
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
if self.scale_box or self.scale_box is None:
scale_array = np.array([scale_x, scale_y] * 2,
dtype=np.float32)
sample['gt_bbox'] = np.clip(
sample['gt_bbox'] * scale_array, 0, dim - 1)
sample['h'] = resize_h
sample['w'] = resize_w
sample['image'] = cv2.resize(
sample['image'], (resize_w, resize_h), interpolation=interp)
return sample
@register_op
class ColorDistort(BaseOperator):
"""Random color distortion.
Args:
hue (list): hue settings.
in [lower, upper, probability] format.
saturation (list): saturation settings.
in [lower, upper, probability] format.
contrast (list): contrast settings.
in [lower, upper, probability] format.
brightness (list): brightness settings.
in [lower, upper, probability] format.
random_apply (bool): whether to apply in random (yolo) or fixed (SSD)
order.
"""
def __init__(self,
hue=[-18, 18, 0.5],
saturation=[0.5, 1.5, 0.5],
contrast=[0.5, 1.5, 0.5],
brightness=[0.5, 1.5, 0.5],
random_apply=True):
super(ColorDistort, self).__init__()
self.hue = hue
self.saturation = saturation
self.contrast = contrast
self.brightness = brightness
self.random_apply = random_apply
def apply_hue(self, img):
low, high, prob = self.hue
if np.random.uniform(0., 1.) < prob:
return img
img = img.astype(np.float32)
# XXX works, but result differ from HSV version
delta = np.random.uniform(low, high)
u = np.cos(delta * np.pi)
w = np.sin(delta * np.pi)
bt = np.array([[1.0, 0.0, 0.0],
[0.0, u, -w],
[0.0, w, u]])
tyiq = np.array([[0.299, 0.587, 0.114],
[0.596, -0.274, -0.321],
[0.211, -0.523, 0.311]])
ityiq = np.array([[1.0, 0.956, 0.621],
[1.0, -0.272, -0.647],
[1.0, -1.107, 1.705]])
t = np.dot(np.dot(ityiq, bt), tyiq).T
img = np.dot(img, t)
return img
def apply_saturation(self, img):
low, high, prob = self.saturation
if np.random.uniform(0., 1.) < prob:
return img
delta = np.random.uniform(low, high)
img = img.astype(np.float32)
gray = img * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32)
gray = gray.sum(axis=2, keepdims=True)
gray *= (1.0 - delta)
img *= delta
img += gray
return img
def apply_contrast(self, img):
low, high, prob = self.contrast
if np.random.uniform(0., 1.) < prob:
return img
delta = np.random.uniform(low, high)
img = img.astype(np.float32)
img *= delta
return img
def apply_brightness(self, img):
low, high, prob = self.brightness
if np.random.uniform(0., 1.) < prob:
return img
delta = np.random.uniform(low, high)
img = img.astype(np.float32)
img += delta
return img
def __call__(self, sample, context=None):
img = sample['image']
if self.random_apply:
distortions = np.random.permutation([
self.apply_brightness,
self.apply_contrast,
self.apply_saturation,
self.apply_hue
])
for func in distortions:
img = func(img)
sample['image'] = img
return sample
img = self.apply_brightness(img)
if np.random.randint(0, 2):
img = self.apply_contrast(img)
img = self.apply_saturation(img)
img = self.apply_hue(img)
else:
img = self.apply_saturation(img)
img = self.apply_hue(img)
img = self.apply_contrast(img)
sample['image'] = img
return sample
@register_op
class NormalizePermute(BaseOperator):
"""Normalize and permute channel order.
Args:
mean (list): mean values in RGB order.
std (list): std values in RGB order.
"""
def __init__(self,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.120, 57.375]):
super(NormalizePermute, self).__init__()
self.mean = mean
self.std = std
def __call__(self, sample, context=None):
img = sample['image']
img = img.astype(np.float32)
img = img.transpose((2, 0, 1))
mean = np.array(self.mean, dtype=np.float32)
std = np.array(self.std, dtype=np.float32)
invstd = 1. / std
for v, m, s in zip(img, mean, invstd):
v.__isub__(m).__imul__(s)
sample['image'] = img
return sample
@register_op
class RandomExpand(BaseOperator):
"""Random expand the canvas.
Args:
ratio (float): maximum expansion ratio.
prob (float): probability to expand.
fill_value (list): color value used to fill the canvas. in RGB order.
"""
def __init__(self, ratio=4., prob=0.5, fill_value=(127.5,) * 3):
super(RandomExpand, self).__init__()
assert ratio > 1.01, "expand ratio must be larger than 1.01"
self.ratio = ratio
self.prob = prob
assert isinstance(fill_value, (Number, Sequence)), \
"fill value must be either float or sequence"
if isinstance(fill_value, Number):
fill_value = (fill_value,) * 3
if not isinstance(fill_value, tuple):
fill_value = tuple(fill_value)
self.fill_value = fill_value
def __call__(self, sample, context=None):
if np.random.uniform(0., 1.) < self.prob:
return sample
img = sample['image']
height = int(sample['h'])
width = int(sample['w'])
expand_ratio = np.random.uniform(1., self.ratio)
h = int(height * expand_ratio)
w = int(width * expand_ratio)
if not h > height or not w > width:
return sample
y = np.random.randint(0, h - height)
x = np.random.randint(0, w - width)
canvas = np.ones((h, w, 3), dtype=np.uint8)
canvas *= np.array(self.fill_value, dtype=np.uint8)
canvas[y:y + height, x:x + width, :] = img.astype(np.uint8)
sample['h'] = h
sample['w'] = w
sample['image'] = canvas
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
sample['gt_bbox'] += np.array([x, y] * 2, dtype=np.float32)
return sample
@register_op
class RandomCrop(BaseOperator):
"""Random crop image and bboxes.
Args:
aspect_ratio (list): aspect ratio of cropped region.
in [min, max] format.
thresholds (list): iou thresholds for decide a valid bbox crop.
scaling (list): ratio between a cropped region and the original image.
in [min, max] format.
num_attempts (int): number of tries before giving up.
allow_no_crop (bool): allow return without actually cropping them.
cover_all_box (bool): ensure all bboxes are covered in the final crop.
"""
def __init__(self,
aspect_ratio=[.5, 2.],
thresholds=[.0, .1, .3, .5, .7, .9],
scaling=[.3, 1.],
num_attempts=50,
allow_no_crop=True,
cover_all_box=False):
super(RandomCrop, self).__init__()
self.aspect_ratio = aspect_ratio
self.thresholds = thresholds
self.scaling = scaling
self.num_attempts = num_attempts
self.allow_no_crop = allow_no_crop
self.cover_all_box = cover_all_box
def __call__(self, sample, context=None):
if 'gt_bbox' in sample and len(sample['gt_bbox']) == 0:
return sample
h = sample['h']
w = sample['w']
gt_bbox = sample['gt_bbox']
# NOTE Original method attempts to generate one candidate for each
# threshold then randomly sample one from the resulting list.
# Here a short circuit approach is taken, i.e., randomly choose a
# threshold and attempt to find a valid crop, and simply return the
# first one found.
# The probability is not exactly the same, kinda resembling the
# "Monty Hall" problem. Actually carrying out the attempts will affect
# observability (just like opening doors in the "Monty Hall" game).
thresholds = list(self.thresholds)
if self.allow_no_crop:
thresholds.append('no_crop')
np.random.shuffle(thresholds)
for thresh in thresholds:
if thresh == 'no_crop':
return sample
found = False
for i in range(self.num_attempts):
scale = np.random.uniform(*self.scaling)
min_ar, max_ar = self.aspect_ratio
aspect_ratio = np.random.uniform(max(min_ar, scale**2),
min(max_ar, scale**-2))
crop_h = int(h * scale / np.sqrt(aspect_ratio))
crop_w = int(w * scale * np.sqrt(aspect_ratio))
crop_y = np.random.randint(0, h - crop_h)
crop_x = np.random.randint(0, w - crop_w)
crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
iou = self._iou_matrix(gt_bbox,
np.array([crop_box], dtype=np.float32))
if iou.max() < thresh:
continue
if self.cover_all_box and iou.min() < thresh:
continue
cropped_box, valid_ids = self._crop_box_with_center_constraint(
gt_bbox, np.array(crop_box, dtype=np.float32))
if valid_ids.size > 0:
found = True
break
if found:
sample['image'] = self._crop_image(sample['image'], crop_box)
sample['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0)
sample['gt_class'] = np.take(
sample['gt_class'], valid_ids, axis=0)
sample['w'] = crop_box[2] - crop_box[0]
sample['h'] = crop_box[3] - crop_box[1]
if 'gt_score' in sample:
sample['gt_score'] = np.take(
sample['gt_score'], valid_ids, axis=0)
return sample
return sample
def _iou_matrix(self, a, b):
tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
area_o = (area_a[:, np.newaxis] + area_b - area_i)
return area_i / (area_o + 1e-10)
def _crop_box_with_center_constraint(self, box, crop):
cropped_box = box.copy()
cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2])
cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:])
cropped_box[:, :2] -= crop[:2]
cropped_box[:, 2:] -= crop[:2]
centers = (box[:, :2] + box[:, 2:]) / 2
valid = np.logical_and(
crop[:2] <= centers, centers < crop[2:]).all(axis=1)
valid = np.logical_and(
valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
return cropped_box, np.where(valid)[0]
def _crop_image(self, img, crop):
x1, y1, x2, y2 = crop
return img[y1:y2, x1:x2, :]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册