未验证 提交 79270ca2 编写于 作者: L littletomatodonkey 提交者: GitHub

Add image augmentation methods (#1006)

Add gridmask and random erasing methods
上级 776ca93e
# GridMask Data Augmentation
## Introduction
- GridMask Data Augmentation
: [https://arxiv.org/abs/2001.04086](https://arxiv.org/abs/2001.04086)
```
@article{chen2020gridmask,
title={GridMask data augmentation},
author={Chen, Pengguang},
journal={arXiv preprint arXiv:2001.04086},
year={2020}
}
```
## Model Zoo
| Backbone | Type | Image/gpu | Lr schd | Inf time (fps) | Box AP | Mask AP | Download | Configs |
| :---------------------- | :-------------: | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: | :-----: |
| ResNet50-vd-FPN | Faster | 2 | 4x | 21.847 | 39.1% | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_gridmask_4x.tar) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/gridmask/faster_rcnn_r50_vd_fpn_gridmask_4x.yml) |
architecture: FasterRCNN
max_iters: 360000
snapshot_iter: 40000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar
weights: output/faster_rcnn_r50_vd_fpn_gridmask_4x/model_final
metric: COCO
num_classes: 81
FasterRCNN:
backbone: ResNet
fpn: FPN
rpn_head: FPNRPNHead
roi_extractor: FPNRoIAlign
bbox_head: BBoxHead
bbox_assigner: BBoxAssigner
ResNet:
depth: 50
feature_maps: [2, 3, 4, 5]
freeze_at: 2
norm_type: bn
variant: d
FPN:
max_level: 6
min_level: 2
num_chan: 256
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
FPNRPNHead:
anchor_generator:
anchor_sizes: [32, 64, 128, 256, 512]
aspect_ratios: [0.5, 1.0, 2.0]
stride: [16.0, 16.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_start_size: 32
max_level: 6
min_level: 2
num_chan: 256
rpn_target_assign:
rpn_batch_size_per_im: 256
rpn_fg_fraction: 0.5
rpn_negative_overlap: 0.3
rpn_positive_overlap: 0.7
rpn_straddle_thresh: 0.0
train_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 2000
pre_nms_top_n: 2000
test_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 1000
pre_nms_top_n: 1000
FPNRoIAlign:
canconical_level: 4
canonical_size: 224
max_level: 5
min_level: 2
box_resolution: 7
sampling_ratio: 2
BBoxAssigner:
batch_size_per_im: 512
bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
bg_thresh_hi: 0.5
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
BBoxHead:
head: TwoFCHead
nms:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
TwoFCHead:
mlp_dim: 1024
LearningRate:
base_lr: 0.02
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [240000, 320000]
- !LinearWarmup
start_factor: 0.1
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
_READER_: '../faster_fpn_reader.yml'
TrainReader:
sample_transforms:
- !DecodeImage
to_rgb: true
- !RandomFlipImage
prob: 0.5
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !GridMaskOp
use_h: true
use_w: true
rotate: 1
offset: false
ratio: 0.5
mode: 1
prob: 0.7
upper_iter: 360000
- !ResizeImage
target_size: 800
max_size: 1333
interp: 1
use_cv2: true
- !Permute
to_bgr: false
channel_first: true
batch_size: 2
worker_num: 2
use_process: true
# Random Erasing Data Augmentation
## Introduction
- Random Erasing Data Augmentation
: [https://arxiv.org/abs/1708.04896](https://arxiv.org/abs/1708.04896)
```
@article{zhong1708random,
title={Random erasing data augmentation. arXiv 2017},
author={Zhong, Z and Zheng, L and Kang, G and Li, S and Yang, Y},
journal={arXiv preprint arXiv:1708.04896}
}
```
## Model Zoo
| Backbone | Type | Image/gpu | Lr schd | Inf time (fps) | Box AP | Mask AP | Download | Configs |
| :---------------------- | :-------------: | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: | :-----: |
| ResNet50-vd-FPN | Faster | 2 | 4x | 21.847 | 39.0% | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_random_erasing_4x.tar) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/random_erasing/faster_rcnn_r50_vd_fpn_random_erasing_4x.yml) |
architecture: FasterRCNN
max_iters: 360000
snapshot_iter: 40000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar
weights: output/faster_rcnn_r50_vd_fpn_random_erasing_4x/model_final
metric: COCO
num_classes: 81
FasterRCNN:
backbone: ResNet
fpn: FPN
rpn_head: FPNRPNHead
roi_extractor: FPNRoIAlign
bbox_head: BBoxHead
bbox_assigner: BBoxAssigner
ResNet:
depth: 50
feature_maps: [2, 3, 4, 5]
freeze_at: 2
norm_type: bn
variant: d
FPN:
max_level: 6
min_level: 2
num_chan: 256
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
FPNRPNHead:
anchor_generator:
anchor_sizes: [32, 64, 128, 256, 512]
aspect_ratios: [0.5, 1.0, 2.0]
stride: [16.0, 16.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_start_size: 32
max_level: 6
min_level: 2
num_chan: 256
rpn_target_assign:
rpn_batch_size_per_im: 256
rpn_fg_fraction: 0.5
rpn_negative_overlap: 0.3
rpn_positive_overlap: 0.7
rpn_straddle_thresh: 0.0
train_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 2000
pre_nms_top_n: 2000
test_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 1000
pre_nms_top_n: 1000
FPNRoIAlign:
canconical_level: 4
canonical_size: 224
max_level: 5
min_level: 2
box_resolution: 7
sampling_ratio: 2
BBoxAssigner:
batch_size_per_im: 512
bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
bg_thresh_hi: 0.5
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
BBoxHead:
head: TwoFCHead
nms:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
TwoFCHead:
mlp_dim: 1024
LearningRate:
base_lr: 0.02
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [240000, 320000]
- !LinearWarmup
start_factor: 0.1
steps: 500
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
_READER_: '../faster_fpn_reader.yml'
TrainReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd']
dataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco
sample_transforms:
- !DecodeImage
to_rgb: true
- !RandomFlipImage
prob: 0.5
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !RandomErasingImage
prob: 0.5
sl: 0.02
sh: 0.4
r1: 0.3
- !ResizeImage
target_size: 800
max_size: 1333
interp: 1
use_cv2: true
- !Permute
to_bgr: false
channel_first: true
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: false
batch_size: 2
shuffle: true
worker_num: 2
use_process: false
......@@ -258,6 +258,8 @@ class Reader(object):
self._pos = -1
self._epoch = -1
self._curr_iter = 0
# multi-process
self._worker_num = worker_num
self._parallel = None
......@@ -317,6 +319,7 @@ class Reader(object):
if self.drained():
raise StopIteration
batch = self._load_batch()
self._curr_iter += 1
if self._drop_last and len(batch) < self._batch_size:
raise StopIteration
if self._worker_num > -1:
......@@ -332,6 +335,7 @@ class Reader(object):
break
pos = self.indexes[self._pos]
sample = copy.deepcopy(self._roidbs[pos])
sample["curr_iter"] = self._curr_iter
self._pos += 1
if self._drop_empty and self._fields and 'gt_mask' in self._fields:
......@@ -354,6 +358,7 @@ class Reader(object):
mix_idx = np.random.randint(1, num)
mix_idx = self.indexes[(mix_idx + self._pos - 1) % num]
sample['mixup'] = copy.deepcopy(self._roidbs[mix_idx])
sample['mixup']["curr_iter"] = self._curr_iter
if self._load_img:
sample['mixup']['image'] = self._load_image(sample['mixup'][
'im_file'])
......@@ -361,6 +366,7 @@ class Reader(object):
num = len(self.indexes)
mix_idx = np.random.randint(1, num)
sample['cutmix'] = copy.deepcopy(self._roidbs[mix_idx])
sample['cutmix']["curr_iter"] = self._curr_iter
if self._load_img:
sample['cutmix']['image'] = self._load_image(sample[
'cutmix']['im_file'])
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import numpy as np
from PIL import Image
class GridMask(object):
def __init__(self,
use_h=True,
use_w=True,
rotate=1,
offset=False,
ratio=0.5,
mode=1,
prob=0.7,
upper_iter=360000):
super(GridMask, self).__init__()
self.use_h = use_h
self.use_w = use_w
self.rotate = rotate
self.offset = offset
self.ratio = ratio
self.mode = mode
self.prob = prob
self.st_prob = prob
self.upper_iter = upper_iter
def __call__(self, x, curr_iter):
self.prob = self.st_prob * min(1, 1.0 * curr_iter / self.upper_iter)
if np.random.rand() > self.prob:
return x
_, h, w = x.shape
hh = int(1.5 * h)
ww = int(1.5 * w)
d = np.random.randint(2, h)
self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)
mask = np.ones((hh, ww), np.float32)
st_h = np.random.randint(d)
st_w = np.random.randint(d)
if self.use_h:
for i in range(hh // d):
s = d * i + st_h
t = min(s + self.l, hh)
mask[s:t, :] *= 0
if self.use_w:
for i in range(ww // d):
s = d * i + st_w
t = min(s + self.l, ww)
mask[:, s:t] *= 0
r = np.random.randint(self.rotate)
mask = Image.fromarray(np.uint8(mask))
mask = mask.rotate(r)
mask = np.asarray(mask)
mask = mask[(hh - h) // 2:(hh - h) // 2 + h, (ww - w) // 2:(ww - w) // 2
+ w].astype(np.float32)
if self.mode == 1:
mask = 1 - mask
mask = np.expand_dims(mask, axis=0)
if self.offset:
offset = (2 * (np.random.rand(h, w) - 0.5)).astype(np.float32)
x = (x * mask + offset * (1 - mask)).astype(x.dtype)
else:
x = (x * mask).astype(x.dtype)
return x
......@@ -467,6 +467,121 @@ class RandomFlipImage(BaseOperator):
return sample
@register_op
class RandomErasingImage(BaseOperator):
def __init__(self, prob=0.5, sl=0.02, sh=0.4, r1=0.3):
"""
Random Erasing Data Augmentation, see https://arxiv.org/abs/1708.04896
Args:
prob (float): probability to carry out random erasing
sl (float): lower limit of the erasing area ratio
sh (float): upper limit of the erasing area ratio
r1 (float): aspect ratio of the erasing region
"""
super(RandomErasingImage, self).__init__()
self.prob = prob
self.sl = sl
self.sh = sh
self.r1 = r1
def __call__(self, sample, context=None):
samples = sample
batch_input = True
if not isinstance(samples, Sequence):
batch_input = False
samples = [samples]
for sample in samples:
gt_bbox = sample['gt_bbox']
im = sample['image']
if not isinstance(im, np.ndarray):
raise TypeError("{}: image is not a numpy array.".format(self))
if len(im.shape) != 3:
raise ImageError("{}: image is not 3-dimensional.".format(self))
for idx in range(gt_bbox.shape[0]):
if self.prob <= np.random.rand():
continue
x1, y1, x2, y2 = gt_bbox[idx, :]
w_bbox = x2 - x1 + 1
h_bbox = y2 - y1 + 1
area = w_bbox * h_bbox
target_area = random.uniform(self.sl, self.sh) * area
aspect_ratio = random.uniform(self.r1, 1 / self.r1)
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < w_bbox and h < h_bbox:
off_y1 = random.randint(0, int(h_bbox - h))
off_x1 = random.randint(0, int(w_bbox - w))
im[int(y1 + off_y1):int(y1 + off_y1 + h), int(x1 + off_x1):
int(x1 + off_x1 + w), :] = 0
sample['image'] = im
sample = samples if batch_input else samples[0]
return sample
@register_op
class GridMaskOp(BaseOperator):
def __init__(self,
use_h=True,
use_w=True,
rotate=1,
offset=False,
ratio=0.5,
mode=1,
prob=0.7,
upper_iter=360000):
"""
GridMask Data Augmentation, see https://arxiv.org/abs/2001.04086
Args:
use_h (bool): whether to mask vertically
use_w (boo;): whether to mask horizontally
rotate (float): angle for the mask to rotate
offset (float): mask offset
ratio (float): mask ratio
mode (int): gridmask mode
prob (float): max probability to carry out gridmask
upper_iter (int): suggested to be equal to global max_iter
"""
super(GridMaskOp, self).__init__()
self.use_h = use_h
self.use_w = use_w
self.rotate = rotate
self.offset = offset
self.ratio = ratio
self.mode = mode
self.prob = prob
self.upper_iter = upper_iter
from .gridmask_utils import GridMask
self.gridmask_op = GridMask(
use_h,
use_w,
rotate=rotate,
offset=offset,
ratio=ratio,
mode=mode,
prob=prob,
upper_iter=upper_iter)
def __call__(self, sample, context=None):
samples = sample
batch_input = True
if not isinstance(samples, Sequence):
batch_input = False
samples = [samples]
for sample in samples:
sample['image'] = self.gridmask_op(sample['image'],
sample['curr_iter'])
if not batch_input:
samples = samples[0]
return sample
@register_op
class AutoAugmentImage(BaseOperator):
def __init__(self, is_normalized=False, autoaug_type="v1"):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册