未验证 提交 d5702896 编写于 作者: X xinyingxinying 提交者: GitHub

Add cutmix (#958)

* Add cutmix op(#88)
上级 83caf99f
......@@ -31,6 +31,8 @@
| FCOS | ResNet50 | 2 | [ResNet50\_cos\_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar) | 39.8 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/fcos_r50_fpn_1x.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/anchor_free/fcos_r50_fpn_1x.yml) |
| FCOS+multiscale_train | ResNet50 | 2 | [ResNet50\_cos\_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar) | 42.0 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/fcos_r50_fpn_multiscale_2x.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/anchor_free/fcos_r50_fpn_multiscale_2x.yml) |
| FCOS+DCN | ResNet50 | 2 | [ResNet50\_cos\_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar) | 44.4 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/fcos_dcn_r50_fpn_1x.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/anchor_free/fcos_dcn_r50_fpn_1x.yml) |
| FCOS+DCN+cutmix | ResNet50 | 2 | [ResNet50\_cos\_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar) | 44.5 | - | [下载链接]
(https://paddlemodels.bj.bcebos.com/object_detection/fcos_dcn_r50_fpn_1x_cutmix.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/anchor_free/fcos_dcn_r50_fpn_1x_cutmix.yml) |
**注意:**
......
architecture: FCOS
max_iters: 90000
use_gpu: true
snapshot_iter: 5000
log_smooth_window: 20
log_iter: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO
weights: output/fcos_dcn_r50_fpn_1x_cutmix/model_final
num_classes: 80
FCOS:
backbone: ResNet
fpn: FPN
fcos_head: FCOSHead
ResNet:
norm_type: affine_channel
norm_decay: 0.
depth: 50
feature_maps: [3, 4, 5]
freeze_at: 2
dcn_v2_stages: [3, 4, 5]
FPN:
min_level: 3
max_level: 7
num_chan: 256
use_c5: false
spatial_scale: [0.03125, 0.0625, 0.125]
has_extra_convs: true
FCOSHead:
num_classes: 80
fpn_stride: [8, 16, 32, 64, 128]
num_convs: 4
norm_type: "gn"
fcos_loss: FCOSLoss
norm_reg_targets: True
centerness_on_reg: True
use_dcn_in_tower: True
nms: MultiClassNMS
MultiClassNMS:
score_threshold: 0.025
nms_top_k: 1000
keep_top_k: 100
nms_threshold: 0.6
background_label: -1
FCOSLoss:
loss_alpha: 0.25
loss_gamma: 2.0
iou_loss_type: "giou"
reg_weights: 1.0
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [60000, 80000]
- !LinearWarmup
start_factor: 0.3333333333333333
steps: 500
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
TrainReader:
inputs_def:
fields: ['image', 'im_info', 'fcos_target']
dataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: true
with_cutmix: True
- !CutmixImage
alpha: 1.5
beta: 1.5
- !RandomFlipImage
prob: 0.5
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
target_size: 800
max_size: 1333
interp: 1
use_cv2: true
- !Permute
to_bgr: false
channel_first: true
batch_transforms:
- !PadBatch
pad_to_stride: 128
use_padded_im_info: false
- !Gt2FCOSTarget
object_sizes_boundary: [64, 128, 256, 512]
center_sampling_radius: 1.5
downsample_ratios: [8, 16, 32, 64, 128]
norm_reg_targets: True
batch_size: 2
shuffle: true
worker_num: 4
use_process: false
cutmix_epoch: 10
EvalReader:
inputs_def:
fields: ['image', 'im_id', 'im_shape', 'im_info']
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
target_size: 800
max_size: 1333
interp: 1
use_cv2: true
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 128
use_padded_im_info: true
batch_size: 1
shuffle: false
worker_num: 1
use_process: false
TestReader:
inputs_def:
# set image_shape if needed
fields: ['image', 'im_id', 'im_shape', 'im_info']
dataset:
!ImageFolder
anno_path: annotations/instances_val2017.json
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
interp: 1
max_size: 1333
target_size: 800
use_cv2: true
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 128
use_padded_im_info: true
batch_size: 1
shuffle: false
......@@ -167,6 +167,8 @@ class Reader(object):
Default True.
mixup_epoch (int): mixup epoc number. Default is -1, meaning
not use mixup.
cutmix_epoch (int): cutmix epoc number. Default is -1, meaning
not use cutmix.
class_aware_sampling (bool): whether use class-aware sampling or not.
Default False.
worker_num (int): number of working threads/processes.
......@@ -191,6 +193,7 @@ class Reader(object):
drop_last=False,
drop_empty=True,
mixup_epoch=-1,
cutmix_epoch=-1,
class_aware_sampling=False,
worker_num=-1,
use_process=False,
......@@ -241,6 +244,7 @@ class Reader(object):
# sampling
self._mixup_epoch = mixup_epoch
self._cutmix_epoch = cutmix_epoch
self._class_aware_sampling = class_aware_sampling
self._load_img = False
......@@ -289,6 +293,10 @@ class Reader(object):
logger.debug("Disable mixup for dataset samples "
"less than 2 samples")
self._mixup_epoch = -1
if self._cutmix_epoch > 0 and len(self.indexes) < 2:
logger.info("Disable cutmix for dataset samples "
"less than 2 samples")
self._cutmix_epoch = -1
if self._epoch < 0:
self._epoch = 0
......@@ -346,6 +354,13 @@ class Reader(object):
if self._load_img:
sample['mixup']['image'] = self._load_image(sample['mixup'][
'im_file'])
if self._epoch < self._cutmix_epoch:
num = len(self.indexes)
mix_idx = np.random.randint(1, num)
sample['cutmix'] = copy.deepcopy(self._roidbs[mix_idx])
if self._load_img:
sample['cutmix']['image'] = self._load_image(sample[
'cutmix']['im_file'])
batch.append(sample)
bs += 1
......
......@@ -89,21 +89,25 @@ class BaseOperator(object):
@register_op
class DecodeImage(BaseOperator):
def __init__(self, to_rgb=True, with_mixup=False):
def __init__(self, to_rgb=True, with_mixup=False, with_cutmix=False):
""" Transform the image data to numpy format.
Args:
to_rgb (bool): whether to convert BGR to RGB
with_mixup (bool): whether or not to mixup image and gt_bbbox/gt_score
with_cutmix (bool): whether or not to cutmix image and gt_bbbox/gt_score
"""
super(DecodeImage, self).__init__()
self.to_rgb = to_rgb
self.with_mixup = with_mixup
self.with_cutmix = with_cutmix
if not isinstance(self.to_rgb, bool):
raise TypeError("{}: input type is invalid.".format(self))
if not isinstance(self.with_mixup, bool):
raise TypeError("{}: input type is invalid.".format(self))
if not isinstance(self.with_cutmix, bool):
raise TypeError("{}: input type is invalid.".format(self))
def __call__(self, sample, context=None):
""" load image if 'im_file' field is not empty but 'image' is"""
......@@ -142,6 +146,10 @@ class DecodeImage(BaseOperator):
# decode mixup image
if self.with_mixup and 'mixup' in sample:
self.__call__(sample['mixup'], context)
# decode cutmix image
if self.with_cutmix and 'cutmix' in sample:
self.__call__(sample['cutmix'], context)
return sample
......@@ -1094,6 +1102,84 @@ class MixupImage(BaseOperator):
return sample
@register_op
class CutmixImage(BaseOperator):
def __init__(self, alpha=1.5, beta=1.5):
"""
CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features, see https://https://arxiv.org/abs/1905.04899
Cutmix image and gt_bbbox/gt_score
Args:
alpha (float): alpha parameter of beta distribute
beta (float): beta parameter of beta distribute
"""
super(CutmixImage, self).__init__()
self.alpha = alpha
self.beta = beta
if self.alpha <= 0.0:
raise ValueError("alpha shold be positive in {}".format(self))
if self.beta <= 0.0:
raise ValueError("beta shold be positive in {}".format(self))
def _rand_bbox(self, img1, img2, factor):
""" _rand_bbox """
h = max(img1.shape[0], img2.shape[0])
w = max(img1.shape[1], img2.shape[1])
cut_rat = np.sqrt(1. - factor)
cut_w = np.int(w * cut_rat)
cut_h = np.int(h * cut_rat)
# uniform
cx = np.random.randint(w)
cy = np.random.randint(h)
bbx1 = np.clip(cx - cut_w // 2, 0, w)
bby1 = np.clip(cy - cut_h // 2, 0, h)
bbx2 = np.clip(cx + cut_w // 2, 0, w)
bby2 = np.clip(cy + cut_h // 2, 0, h)
img_1 = np.zeros((h, w, img1.shape[2]), 'float32')
img_1[:img1.shape[0], :img1.shape[1], :] = \
img1.astype('float32')
img_2 = np.zeros((h, w, img2.shape[2]), 'float32')
img_2[:img2.shape[0], :img2.shape[1], :] = \
img2.astype('float32')
img_1[bby1:bby2, bbx1:bbx2, :] = img2[bby1:bby2, bbx1:bbx2, :]
return img_1
def __call__(self, sample, context=None):
if 'cutmix' not in sample:
return sample
factor = np.random.beta(self.alpha, self.beta)
factor = max(0.0, min(1.0, factor))
if factor >= 1.0:
sample.pop('cutmix')
return sample
if factor <= 0.0:
return sample['cutmix']
img1 = sample['image']
img2 = sample['cutmix']['image']
img = self._rand_bbox(img1, img2, factor)
gt_bbox1 = sample['gt_bbox']
gt_bbox2 = sample['cutmix']['gt_bbox']
gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
gt_class1 = sample['gt_class']
gt_class2 = sample['cutmix']['gt_class']
gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
gt_score1 = sample['gt_score']
gt_score2 = sample['cutmix']['gt_score']
gt_score = np.concatenate(
(gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
sample['image'] = img
sample['gt_bbox'] = gt_bbox
sample['gt_score'] = gt_score
sample['gt_class'] = gt_class
sample['h'] = img.shape[0]
sample['w'] = img.shape[1]
sample.pop('cutmix')
return sample
@register_op
class RandomInterpImage(BaseOperator):
def __init__(self, target_size=0, max_size=0):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册