未验证 提交 0e228b11 编写于 作者: Y Yang Zhang 提交者: GitHub

Initial implementation of `EfficientDet` (#492)

上级 521a4a6a
architecture: EfficientDet
max_iters: 281250
use_gpu: true
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/EfficientNetB0_pretrained.tar
weights: output/efficientdet_d0/model_final
log_smooth_window: 20
snapshot_iter: 10000
metric: COCO
save_dir: output
num_classes: 81
use_ema: true
ema_decay: 0.9998
EfficientDet:
backbone: EfficientNet
fpn: BiFPN
efficient_head: EfficientHead
anchor_grid: AnchorGrid
box_loss_weight: 50.
EfficientNet:
# norm_type: sync_bn
# TODO
norm_type: bn
scale: b0
use_se: true
BiFPN:
num_chan: 64
repeat: 3
levels: 5
EfficientHead:
repeat: 3
num_chan: 64
prior_prob: 0.01
num_anchors: 9
gamma: 1.5
alpha: 0.25
delta: 0.1
output_decoder:
score_thresh: 0.05 # originally 0.
nms_thresh: 0.5
pre_nms_top_n: 1000 # originally 5000
detections_per_im: 100
nms_eta: 1.0
AnchorGrid:
anchor_base_scale: 4
num_scales: 3
aspect_ratios: [[1, 1], [1.4, 0.7], [0.7, 1.4]]
LearningRate:
base_lr: 0.16
schedulers:
- !CosineDecayWithSkip
total_steps: 281250
skip_steps: 938
- !LinearWarmup
start_factor: 0.05
steps: 938
OptimizerBuilder:
clip_grad_by_norm: 10.
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.00004
type: L2
TrainReader:
inputs_def:
fields: ['image', 'im_id', 'fg_num', 'gt_label', 'gt_target']
dataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco
sample_transforms:
- !DecodeImage
to_rgb: true
- !RandomFlipImage
prob: 0.5
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !RandomScaledCrop
target_dim: 512
scale_range: [.1, 2.]
interp: 1
- !Permute
to_bgr: false
channel_first: true
- !TargetAssign
image_size: 512
batch_size: 16
shuffle: true
worker_num: 32
bufsize: 16
use_process: true
drop_empty: false
EvalReader:
inputs_def:
fields: ['image', 'im_info', 'im_id']
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeAndPad
target_dim: 512
interp: 1
- !Permute
channel_first: true
to_bgr: false
drop_empty: false
batch_size: 16
shuffle: false
worker_num: 2
TestReader:
inputs_def:
fields: ['image', 'im_info', 'im_id']
image_shape: [3, 512, 512]
dataset:
!ImageFolder
anno_path: annotations/instances_val2017.json
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeAndPad
target_dim: 512
interp: 1
- !Permute
channel_first: true
to_bgr: false
batch_size: 16
shuffle: false
......@@ -181,6 +181,14 @@ results of image size 608/416/320 above. Deformable conv is added on stage 5 of
**Notes:** In RetinaNet, the base LR is changed to 0.01 for minibatch size 16.
### EfficientDet
| Scale | Image/gpu | Lr schd | Box AP | Download |
| :---------------: | :-----: | :-----: | :----: | :-------: |
| EfficientDet-D0 | 16 | 300 epochs | 33.8 | [model](https://paddlemodels.bj.bcebos.com/object_detection/efficientdet_d0.pdparams) |
**Notes:** base LR is 0.16 for minibatch size 128 (8x16).
### SSDLite
| Backbone | Size | Image/gpu | Lr schd | Inf time (fps) | Box AP | Download | Configs |
......
......@@ -37,6 +37,7 @@ import cv2
from PIL import Image, ImageEnhance
from ppdet.core.workspace import serializable
from ppdet.modeling.ops import AnchorGrid
from .op_helper import (satisfy_sample_constraint, filter_and_process,
generate_sample_bbox, clip_bbox, data_anchor_sampling,
......@@ -1971,3 +1972,202 @@ class CornerRatio(BaseOperator):
sample['ratios'] = np.array([height_ratio, width_ratio])
return sample
@register_op
class RandomScaledCrop(BaseOperator):
"""Resize image and bbox based on long side (with optional random scaling),
then crop or pad image to target size.
Args:
target_dim (int): target size.
scale_range (list): random scale range.
interp (int): interpolation method, default to `cv2.INTER_LINEAR`.
"""
def __init__(self,
target_dim=512,
scale_range=[.1, 2.],
interp=cv2.INTER_LINEAR):
super(RandomScaledCrop, self).__init__()
self.target_dim = target_dim
self.scale_range = scale_range
self.interp = interp
def __call__(self, sample, context=None):
w = sample['w']
h = sample['h']
random_scale = np.random.uniform(*self.scale_range)
dim = self.target_dim
random_dim = int(dim * random_scale)
dim_max = max(h, w)
scale = random_dim / dim_max
resize_w = int(round(w * scale))
resize_h = int(round(h * scale))
offset_x = int(max(0, np.random.uniform(0., resize_w - dim)))
offset_y = int(max(0, np.random.uniform(0., resize_h - dim)))
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
scale_array = np.array([scale, scale] * 2, dtype=np.float32)
shift_array = np.array([offset_x, offset_y] * 2, dtype=np.float32)
boxes = sample['gt_bbox'] * scale_array - shift_array
boxes = np.clip(boxes, 0, dim - 1)
# filter boxes with no area
area = np.prod(boxes[..., 2:] - boxes[..., :2], axis=1)
valid = (area > 1.).nonzero()[0]
sample['gt_bbox'] = boxes[valid]
sample['gt_class'] = sample['gt_class'][valid]
img = sample['image']
img = cv2.resize(img, (resize_w, resize_h), interpolation=self.interp)
img = np.array(img)
canvas = np.zeros((dim, dim, 3), dtype=img.dtype)
canvas[:min(dim, resize_h), :min(dim, resize_w), :] = img[
offset_y:offset_y + dim, offset_x:offset_x + dim, :]
sample['h'] = dim
sample['w'] = dim
sample['image'] = canvas
sample['im_info'] = [resize_h, resize_w, scale]
return sample
@register_op
class ResizeAndPad(BaseOperator):
"""Resize image and bbox, then pad image to target size.
Args:
target_dim (int): target size
interp (int): interpolation method, default to `cv2.INTER_LINEAR`.
"""
def __init__(self, target_dim=512, interp=cv2.INTER_LINEAR):
super(ResizeAndPad, self).__init__()
self.target_dim = target_dim
self.interp = interp
def __call__(self, sample, context=None):
w = sample['w']
h = sample['h']
interp = self.interp
dim = self.target_dim
dim_max = max(h, w)
scale = self.target_dim / dim_max
resize_w = int(round(w * scale))
resize_h = int(round(h * scale))
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
scale_array = np.array([scale, scale] * 2, dtype=np.float32)
sample['gt_bbox'] = np.clip(sample['gt_bbox'] * scale_array, 0,
dim - 1)
img = sample['image']
img = cv2.resize(img, (resize_w, resize_h), interpolation=interp)
img = np.array(img)
canvas = np.zeros((dim, dim, 3), dtype=img.dtype)
canvas[:resize_h, :resize_w, :] = img
sample['h'] = dim
sample['w'] = dim
sample['image'] = canvas
sample['im_info'] = [resize_h, resize_w, scale]
return sample
@register_op
class TargetAssign(BaseOperator):
"""Assign regression target and labels.
Args:
image_size (int or list): input image size, a single integer or list of
[h, w]. Default: 512
min_level (int): min level of the feature pyramid. Default: 3
max_level (int): max level of the feature pyramid. Default: 7
anchor_base_scale (int): base anchor scale. Default: 4
num_scales (int): number of anchor scales. Default: 3
aspect_ratios (list): aspect ratios.
Default: [(1, 1), (1.4, 0.7), (0.7, 1.4)]
match_threshold (float): threshold for foreground IoU. Default: 0.5
"""
def __init__(self,
image_size=512,
min_level=3,
max_level=7,
anchor_base_scale=4,
num_scales=3,
aspect_ratios=[(1, 1), (1.4, 0.7), (0.7, 1.4)],
match_threshold=0.5):
super(TargetAssign, self).__init__()
assert image_size % 2 ** max_level == 0, \
"image size should be multiple of the max level stride"
self.image_size = image_size
self.min_level = min_level
self.max_level = max_level
self.anchor_base_scale = anchor_base_scale
self.num_scales = num_scales
self.aspect_ratios = aspect_ratios
self.match_threshold = match_threshold
@property
def anchors(self):
if not hasattr(self, '_anchors'):
anchor_grid = AnchorGrid(self.image_size, self.min_level,
self.max_level, self.anchor_base_scale,
self.num_scales, self.aspect_ratios)
self._anchors = np.concatenate(anchor_grid.generate())
return self._anchors
def iou_matrix(self, a, b):
tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
area_o = (area_a[:, np.newaxis] + area_b - area_i)
# return area_i / (area_o + 1e-10)
return np.where(area_i == 0., np.zeros_like(area_i), area_i / area_o)
def match(self, anchors, gt_boxes):
# XXX put smaller matrix first would be a little bit faster
mat = self.iou_matrix(gt_boxes, anchors)
max_anchor_for_each_gt = mat.argmax(axis=1)
max_for_each_anchor = mat.max(axis=0)
anchor_to_gt = mat.argmax(axis=0)
anchor_to_gt[max_for_each_anchor < self.match_threshold] = -1
# XXX ensure each gt has at least one anchor assigned,
# see `force_match_for_each_row` in TF implementation
one_hot = np.zeros_like(mat)
one_hot[np.arange(mat.shape[0]), max_anchor_for_each_gt] = 1.
max_anchor_indices = one_hot.sum(axis=0).nonzero()[0]
max_gt_indices = one_hot.argmax(axis=0)[max_anchor_indices]
anchor_to_gt[max_anchor_indices] = max_gt_indices
return anchor_to_gt
def encode(self, anchors, boxes):
wha = anchors[..., 2:] - anchors[..., :2] + 1
ca = anchors[..., :2] + wha * .5
whb = boxes[..., 2:] - boxes[..., :2] + 1
cb = boxes[..., :2] + whb * .5
offsets = np.empty_like(anchors)
offsets[..., :2] = (cb - ca) / wha
offsets[..., 2:] = np.log(whb / wha)
return offsets
def __call__(self, sample, context=None):
gt_boxes = sample['gt_bbox']
gt_labels = sample['gt_class']
labels = np.full((self.anchors.shape[0], 1), 0, dtype=np.int32)
targets = np.full((self.anchors.shape[0], 4), 0., dtype=np.float32)
sample['gt_label'] = labels
sample['gt_target'] = targets
if len(gt_boxes) < 1:
sample['fg_num'] = np.array(0, dtype=np.int32)
return sample
anchor_to_gt = self.match(self.anchors, gt_boxes)
matched_indices = (anchor_to_gt >= 0).nonzero()[0]
labels[matched_indices] = gt_labels[anchor_to_gt[matched_indices]]
matched_boxes = gt_boxes[anchor_to_gt[matched_indices]]
matched_anchors = self.anchors[matched_indices]
matched_targets = self.encode(matched_anchors, matched_boxes)
targets[matched_indices] = matched_targets
sample['fg_num'] = np.array(len(matched_targets), dtype=np.int32)
return sample
......@@ -19,9 +19,11 @@ from . import yolo_head
from . import retina_head
from . import fcos_head
from . import corner_head
from . import efficient_head
from .rpn_head import *
from .yolo_head import *
from .retina_head import *
from .fcos_head import *
from .corner_head import *
from .efficient_head import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import TruncatedNormal, Constant
from paddle.fluid.regularizer import L2Decay
from ppdet.modeling.ops import RetinaOutputDecoder
from ppdet.core.workspace import register
__all__ = ['EfficientHead']
@register
class EfficientHead(object):
"""
EfficientDet Head
Args:
output_decoder (object): `RetinaOutputDecoder` instance.
repeat (int): Number of convolution layers.
num_chan (int): Number of octave output channels.
prior_prob (float): Initial value of the class prediction layer bias.
num_anchors (int): Number of anchors per cell.
num_classes (int): Number of classes.
gamma (float): Gamma parameter for focal loss.
alpha (float): Alpha parameter for focal loss.
sigma (float): Sigma parameter for smooth l1 loss.
"""
__inject__ = ['output_decoder']
__shared__ = ['num_classes']
def __init__(self,
output_decoder=RetinaOutputDecoder().__dict__,
repeat=3,
num_chan=64,
prior_prob=0.01,
num_anchors=9,
num_classes=81,
gamma=1.5,
alpha=0.25,
delta=0.1):
super(EfficientHead, self).__init__()
self.output_decoder = output_decoder
self.repeat = repeat
self.num_chan = num_chan
self.prior_prob = prior_prob
self.num_anchors = num_anchors
self.num_classes = num_classes
self.gamma = gamma
self.alpha = alpha
self.delta = delta
if isinstance(output_decoder, dict):
self.output_decoder = RetinaOutputDecoder(**output_decoder)
def _get_output(self, body_feats):
def separable_conv(inputs, num_chan, bias_init=None, name=''):
dw_conv_name = name + '_dw'
pw_conv_name = name + '_pw'
in_chan = inputs.shape[1]
fan_in = np.sqrt(1. / (in_chan * 3 * 3))
feat = fluid.layers.conv2d(
input=inputs,
num_filters=in_chan,
groups=in_chan,
filter_size=3,
stride=1,
padding='SAME',
param_attr=ParamAttr(
name=dw_conv_name + '_w',
initializer=TruncatedNormal(scale=fan_in)),
bias_attr=False)
fan_in = np.sqrt(1. / in_chan)
feat = fluid.layers.conv2d(
input=feat,
num_filters=num_chan,
filter_size=1,
stride=1,
param_attr=ParamAttr(
name=pw_conv_name + '_w',
initializer=TruncatedNormal(scale=fan_in)),
bias_attr=ParamAttr(
name=pw_conv_name + '_b',
initializer=bias_init,
regularizer=L2Decay(0.)))
return feat
def subnet(inputs, prefix, level):
feat = inputs
for i in range(self.repeat):
# NOTE share weight across FPN levels
conv_name = '{}_pred_conv_{}'.format(prefix, i)
feat = separable_conv(feat, self.num_chan, name=conv_name)
# NOTE batch norm params are not shared
bn_name = '{}_pred_bn_{}_{}'.format(prefix, level, i)
feat = fluid.layers.batch_norm(
input=feat,
act='swish',
momentum=0.997,
epsilon=1e-4,
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',
param_attr=ParamAttr(
name=bn_name + '_w',
initializer=Constant(value=1.),
regularizer=L2Decay(0.)),
bias_attr=ParamAttr(
name=bn_name + '_b', regularizer=L2Decay(0.)))
return feat
cls_preds = []
box_preds = []
for l, feat in enumerate(body_feats):
cls_out = subnet(feat, 'cls', l)
box_out = subnet(feat, 'box', l)
bias_init = float(-np.log((1 - self.prior_prob) / self.prior_prob))
bias_init = Constant(value=bias_init)
cls_pred = separable_conv(
cls_out,
self.num_anchors * (self.num_classes - 1),
bias_init=bias_init,
name='cls_pred')
cls_pred = fluid.layers.transpose(cls_pred, perm=[0, 2, 3, 1])
cls_pred = fluid.layers.reshape(
cls_pred, shape=(0, -1, self.num_classes - 1))
cls_preds.append(cls_pred)
box_pred = separable_conv(
box_out, self.num_anchors * 4, name='box_pred')
box_pred = fluid.layers.transpose(box_pred, perm=[0, 2, 3, 1])
box_pred = fluid.layers.reshape(box_pred, shape=(0, -1, 4))
box_preds.append(box_pred)
return cls_preds, box_preds
def get_prediction(self, body_feats, anchors, im_info):
cls_preds, box_preds = self._get_output(body_feats)
cls_preds = [fluid.layers.sigmoid(pred) for pred in cls_preds]
pred_result = self.output_decoder(
bboxes=box_preds,
scores=cls_preds,
anchors=anchors,
im_info=im_info)
return {'bbox': pred_result}
def get_loss(self, body_feats, gt_labels, gt_targets, fg_num):
cls_preds, box_preds = self._get_output(body_feats)
fg_num = fluid.layers.reduce_sum(fg_num, name='fg_num')
fg_num.stop_gradient = True
cls_pred = fluid.layers.concat(cls_preds, axis=1)
box_pred = fluid.layers.concat(box_preds, axis=1)
cls_pred_reshape = fluid.layers.reshape(
cls_pred, shape=(-1, self.num_classes - 1))
gt_labels_reshape = fluid.layers.reshape(gt_labels, shape=(-1, 1))
loss_cls = fluid.layers.sigmoid_focal_loss(
x=cls_pred_reshape,
label=gt_labels_reshape,
fg_num=fg_num,
gamma=self.gamma,
alpha=self.alpha)
loss_cls = fluid.layers.reduce_sum(loss_cls)
loss_bbox = fluid.layers.huber_loss(
input=box_pred, label=gt_targets, delta=self.delta)
mask = fluid.layers.expand(gt_labels, expand_times=[1, 1, 4]) > 0
loss_bbox *= fluid.layers.cast(mask, 'float32')
loss_bbox = fluid.layers.reduce_sum(loss_bbox) / (fg_num * 4)
return {'loss_cls': loss_cls, 'loss_bbox': loss_bbox}
......@@ -22,6 +22,7 @@ from . import cascade_rcnn_cls_aware
from . import yolov3
from . import ssd
from . import retinanet
from . import efficientdet
from . import blazeface
from . import faceboxes
from . import fcos
......@@ -35,6 +36,7 @@ from .cascade_rcnn_cls_aware import *
from .yolov3 import *
from .ssd import *
from .retinanet import *
from .efficientdet import *
from .blazeface import *
from .faceboxes import *
from .fcos import *
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from collections import OrderedDict
import paddle.fluid as fluid
from ppdet.experimental import mixed_precision_global_state
from ppdet.core.workspace import register
__all__ = ['EfficientDet']
@register
class EfficientDet(object):
"""
EfficientDet architecture, see https://arxiv.org/abs/1911.09070
Args:
backbone (object): backbone instance
fpn (object): feature pyramid network instance
retina_head (object): `RetinaHead` instance
"""
__category__ = 'architecture'
__inject__ = ['backbone', 'fpn', 'efficient_head', 'anchor_grid']
def __init__(self,
backbone,
fpn,
efficient_head,
anchor_grid,
box_loss_weight=50.):
super(EfficientDet, self).__init__()
self.backbone = backbone
self.fpn = fpn
self.efficient_head = efficient_head
self.anchor_grid = anchor_grid
self.box_loss_weight = box_loss_weight
def build(self, feed_vars, mode='train'):
im = feed_vars['image']
if mode == 'train':
gt_labels = feed_vars['gt_label']
gt_targets = feed_vars['gt_target']
fg_num = feed_vars['fg_num']
else:
im_info = feed_vars['im_info']
mixed_precision_enabled = mixed_precision_global_state() is not None
if mixed_precision_enabled:
im = fluid.layers.cast(im, 'float16')
body_feats = self.backbone(im)
if mixed_precision_enabled:
body_feats = [fluid.layers.cast(f, 'float32') for f in body_feats]
body_feats = self.fpn(body_feats)
# XXX not used for training, but the parameters are needed when
# exporting inference model
anchors = self.anchor_grid()
if mode == 'train':
loss = self.efficient_head.get_loss(body_feats, gt_labels,
gt_targets, fg_num)
loss_cls = loss['loss_cls']
loss_bbox = loss['loss_bbox']
total_loss = loss_cls + self.box_loss_weight * loss_bbox
loss.update({'loss': total_loss})
return loss
else:
pred = self.efficient_head.get_prediction(body_feats, anchors,
im_info)
return pred
def _inputs_def(self, image_shape):
im_shape = [None] + image_shape
inputs_def = {
'image': {
'shape': im_shape,
'dtype': 'float32'
},
'im_info': {
'shape': [None, 3],
'dtype': 'float32'
},
'im_id': {
'shape': [None, 1],
'dtype': 'int64'
},
'im_shape': {
'shape': [None, 3],
'dtype': 'float32'
},
'fg_num': {
'shape': [None, 1],
'dtype': 'int32'
},
'gt_label': {
'shape': [None, None, 1],
'dtype': 'int32'
},
'gt_target': {
'shape': [None, None, 4],
'dtype': 'float32'
},
}
return inputs_def
def build_inputs(self,
image_shape=[3, None, None],
fields=[
'image', 'im_info', 'im_id', 'fg_num', 'gt_label',
'gt_target'
],
use_dataloader=True,
iterable=False):
inputs_def = self._inputs_def(image_shape)
feed_vars = OrderedDict([(key, fluid.data(
name=key,
shape=inputs_def[key]['shape'],
dtype=inputs_def[key]['dtype'])) for key in fields])
loader = fluid.io.DataLoader.from_generator(
feed_list=list(feed_vars.values()),
capacity=16,
use_double_buffer=True,
iterable=iterable) if use_dataloader else None
return feed_vars, loader
def train(self, feed_vars):
return self.build(feed_vars, 'train')
def eval(self, feed_vars):
return self.build(feed_vars, 'test')
def test(self, feed_vars):
return self.build(feed_vars, 'test')
......@@ -30,6 +30,8 @@ from . import hrnet
from . import hrfpn
from . import bfp
from . import hourglass
from . import efficientnet
from . import bifpn
from .resnet import *
from .resnext import *
......@@ -47,3 +49,5 @@ from .hrnet import *
from .hrfpn import *
from .bfp import *
from .hourglass import *
from .efficientnet import *
from .bifpn import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.initializer import Constant, Xavier
from ppdet.core.workspace import register
__all__ = ['BiFPN']
class FusionConv(object):
def __init__(self, num_chan):
super(FusionConv, self).__init__()
self.num_chan = num_chan
def __call__(self, inputs, name=''):
x = fluid.layers.swish(inputs)
# depthwise
x = fluid.layers.conv2d(
x,
self.num_chan,
filter_size=3,
padding='SAME',
groups=self.num_chan,
param_attr=ParamAttr(
initializer=Xavier(), name=name + '_dw_w'),
bias_attr=False)
# pointwise
x = fluid.layers.conv2d(
x,
self.num_chan,
filter_size=1,
param_attr=ParamAttr(
initializer=Xavier(), name=name + '_pw_w'),
bias_attr=ParamAttr(
regularizer=L2Decay(0.), name=name + '_pw_b'))
# bn + act
x = fluid.layers.batch_norm(
x,
momentum=0.997,
epsilon=1e-04,
param_attr=ParamAttr(
initializer=Constant(1.0),
regularizer=L2Decay(0.),
name=name + '_bn_w'),
bias_attr=ParamAttr(
regularizer=L2Decay(0.), name=name + '_bn_b'))
return x
class BiFPNCell(object):
def __init__(self, num_chan, levels=5):
super(BiFPNCell, self).__init__()
self.levels = levels
self.num_chan = num_chan
num_trigates = levels - 2
num_bigates = levels
self.trigates = fluid.layers.create_parameter(
shape=[num_trigates, 3],
dtype='float32',
default_initializer=fluid.initializer.Constant(1.))
self.bigates = fluid.layers.create_parameter(
shape=[num_bigates, 2],
dtype='float32',
default_initializer=fluid.initializer.Constant(1.))
self.eps = 1e-4
def __call__(self, inputs, cell_name=''):
assert len(inputs) == self.levels
def upsample(feat):
return fluid.layers.resize_nearest(feat, scale=2.)
def downsample(feat):
return fluid.layers.pool2d(
feat,
pool_type='max',
pool_size=3,
pool_stride=2,
pool_padding='SAME')
fuse_conv = FusionConv(self.num_chan)
# normalize weight
trigates = fluid.layers.relu(self.trigates)
bigates = fluid.layers.relu(self.bigates)
trigates /= fluid.layers.reduce_sum(
trigates, dim=1, keep_dim=True) + self.eps
bigates /= fluid.layers.reduce_sum(
bigates, dim=1, keep_dim=True) + self.eps
feature_maps = list(inputs) # make a copy
# top down path
for l in range(self.levels - 1):
p = self.levels - l - 2
w1 = fluid.layers.slice(
bigates, axes=[0, 1], starts=[l, 0], ends=[l + 1, 1])
w2 = fluid.layers.slice(
bigates, axes=[0, 1], starts=[l, 1], ends=[l + 1, 2])
above = upsample(feature_maps[p + 1])
feature_maps[p] = fuse_conv(
w1 * above + w2 * inputs[p],
name='{}_tb_{}'.format(cell_name, l))
# bottom up path
for l in range(1, self.levels):
p = l
name = '{}_bt_{}'.format(cell_name, l)
below = downsample(feature_maps[p - 1])
if p == self.levels - 1:
# handle P7
w1 = fluid.layers.slice(
bigates, axes=[0, 1], starts=[p, 0], ends=[p + 1, 1])
w2 = fluid.layers.slice(
bigates, axes=[0, 1], starts=[p, 1], ends=[p + 1, 2])
feature_maps[p] = fuse_conv(
w1 * below + w2 * inputs[p], name=name)
else:
w1 = fluid.layers.slice(
trigates, axes=[0, 1], starts=[p - 1, 0], ends=[p, 1])
w2 = fluid.layers.slice(
trigates, axes=[0, 1], starts=[p - 1, 1], ends=[p, 2])
w3 = fluid.layers.slice(
trigates, axes=[0, 1], starts=[p - 1, 2], ends=[p, 3])
feature_maps[p] = fuse_conv(
w1 * feature_maps[p] + w2 * below + w3 * inputs[p],
name=name)
return feature_maps
@register
class BiFPN(object):
"""
Bidirectional Feature Pyramid Network, see https://arxiv.org/abs/1911.09070
Args:
num_chan (int): number of feature channels
repeat (int): number of repeats of the BiFPN module
level (int): number of FPN levels, default: 5
"""
def __init__(self, num_chan, repeat=3, levels=5):
super(BiFPN, self).__init__()
self.num_chan = num_chan
self.repeat = repeat
self.levels = levels
def __call__(self, inputs):
feats = []
# NOTE add two extra levels
for idx in range(self.levels):
if idx <= len(inputs):
if idx == len(inputs):
feat = inputs[-1]
else:
feat = inputs[idx]
if feat.shape[1] != self.num_chan:
feat = fluid.layers.conv2d(
feat,
self.num_chan,
filter_size=1,
padding='SAME',
param_attr=ParamAttr(initializer=Xavier()),
bias_attr=ParamAttr(regularizer=L2Decay(0.)))
feat = fluid.layers.batch_norm(
feat,
momentum=0.997,
epsilon=1e-04,
param_attr=ParamAttr(
initializer=Constant(1.0), regularizer=L2Decay(0.)),
bias_attr=ParamAttr(regularizer=L2Decay(0.)))
if idx >= len(inputs):
feat = fluid.layers.pool2d(
feat,
pool_type='max',
pool_size=3,
pool_stride=2,
pool_padding='SAME')
feats.append(feat)
biFPN = BiFPNCell(self.num_chan, self.levels)
for r in range(self.repeat):
feats = biFPN(feats, 'bifpn_{}'.format(r))
return feats
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
import collections
import math
import re
from paddle import fluid
from paddle.fluid.regularizer import L2Decay
from ppdet.core.workspace import register
__all__ = ['EfficientNet']
GlobalParams = collections.namedtuple('GlobalParams', [
'batch_norm_momentum', 'batch_norm_epsilon', 'width_coefficient',
'depth_coefficient', 'depth_divisor'
])
BlockArgs = collections.namedtuple('BlockArgs', [
'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
'expand_ratio', 'stride', 'se_ratio'
])
GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None, ) * len(BlockArgs._fields)
def _decode_block_string(block_string):
assert isinstance(block_string, str)
ops = block_string.split('_')
options = {}
for op in ops:
splits = re.split(r'(\d.*)', op)
if len(splits) >= 2:
key, value = splits[:2]
options[key] = value
assert (('s' in options and len(options['s']) == 1) or
(len(options['s']) == 2 and options['s'][0] == options['s'][1]))
return BlockArgs(
kernel_size=int(options['k']),
num_repeat=int(options['r']),
input_filters=int(options['i']),
output_filters=int(options['o']),
expand_ratio=int(options['e']),
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s'][0]))
def get_model_params(scale):
block_strings = [
'r1_k3_s11_e1_i32_o16_se0.25',
'r2_k3_s22_e6_i16_o24_se0.25',
'r2_k5_s22_e6_i24_o40_se0.25',
'r3_k3_s22_e6_i40_o80_se0.25',
'r3_k5_s11_e6_i80_o112_se0.25',
'r4_k5_s22_e6_i112_o192_se0.25',
'r1_k3_s11_e6_i192_o320_se0.25',
]
block_args = []
for block_string in block_strings:
block_args.append(_decode_block_string(block_string))
params_dict = {
# width, depth
'b0': (1.0, 1.0),
'b1': (1.0, 1.1),
'b2': (1.1, 1.2),
'b3': (1.2, 1.4),
'b4': (1.4, 1.8),
'b5': (1.6, 2.2),
'b6': (1.8, 2.6),
'b7': (2.0, 3.1),
}
w, d = params_dict[scale]
global_params = GlobalParams(
batch_norm_momentum=0.99,
batch_norm_epsilon=1e-3,
width_coefficient=w,
depth_coefficient=d,
depth_divisor=8)
return block_args, global_params
def round_filters(filters, global_params):
multiplier = global_params.width_coefficient
if not multiplier:
return filters
divisor = global_params.depth_divisor
filters *= multiplier
min_depth = divisor
new_filters = max(min_depth,
int(filters + divisor / 2) // divisor * divisor)
if new_filters < 0.9 * filters: # prevent rounding by more than 10%
new_filters += divisor
return int(new_filters)
def round_repeats(repeats, global_params):
multiplier = global_params.depth_coefficient
if not multiplier:
return repeats
return int(math.ceil(multiplier * repeats))
def conv2d(inputs,
num_filters,
filter_size,
stride=1,
padding='SAME',
groups=1,
use_bias=False,
name='conv2d'):
param_attr = fluid.ParamAttr(name=name + '_weights')
bias_attr = False
if use_bias:
bias_attr = fluid.ParamAttr(
name=name + '_offset', regularizer=L2Decay(0.))
feats = fluid.layers.conv2d(
inputs,
num_filters,
filter_size,
groups=groups,
name=name,
stride=stride,
padding=padding,
param_attr=param_attr,
bias_attr=bias_attr)
return feats
def batch_norm(inputs, momentum, eps, name=None):
param_attr = fluid.ParamAttr(name=name + '_scale', regularizer=L2Decay(0.))
bias_attr = fluid.ParamAttr(name=name + '_offset', regularizer=L2Decay(0.))
return fluid.layers.batch_norm(
input=inputs,
momentum=momentum,
epsilon=eps,
name=name,
moving_mean_name=name + '_mean',
moving_variance_name=name + '_variance',
param_attr=param_attr,
bias_attr=bias_attr)
def mb_conv_block(inputs,
input_filters,
output_filters,
expand_ratio,
kernel_size,
stride,
momentum,
eps,
se_ratio=None,
name=None):
feats = inputs
num_filters = input_filters * expand_ratio
if expand_ratio != 1:
feats = conv2d(feats, num_filters, 1, name=name + '_expand_conv')
feats = batch_norm(feats, momentum, eps, name=name + '_bn0')
feats = fluid.layers.swish(feats)
feats = conv2d(
feats,
num_filters,
kernel_size,
stride,
groups=num_filters,
name=name + '_depthwise_conv')
feats = batch_norm(feats, momentum, eps, name=name + '_bn1')
feats = fluid.layers.swish(feats)
if se_ratio is not None:
filter_squeezed = max(1, int(input_filters * se_ratio))
squeezed = fluid.layers.pool2d(
feats, pool_type='avg', global_pooling=True)
squeezed = conv2d(
squeezed,
filter_squeezed,
1,
use_bias=True,
name=name + '_se_reduce')
squeezed = fluid.layers.swish(squeezed)
squeezed = conv2d(
squeezed, num_filters, 1, use_bias=True, name=name + '_se_expand')
feats = feats * fluid.layers.sigmoid(squeezed)
feats = conv2d(feats, output_filters, 1, name=name + '_project_conv')
feats = batch_norm(feats, momentum, eps, name=name + '_bn2')
if stride == 1 and input_filters == output_filters:
feats = fluid.layers.elementwise_add(feats, inputs)
return feats
@register
class EfficientNet(object):
"""
EfficientNet, see https://arxiv.org/abs/1905.11946
Args:
scale (str): compounding scale factor, 'b0' - 'b7'.
use_se (bool): use squeeze and excite module.
norm_type (str): normalization type, 'bn' and 'sync_bn' are supported
"""
__shared__ = ['norm_type']
def __init__(self, scale='b0', use_se=True, norm_type='bn'):
assert scale in ['b' + str(i) for i in range(8)], \
"valid scales are b0 - b7"
assert norm_type in ['bn', 'sync_bn'], \
"only 'bn' and 'sync_bn' are supported"
super(EfficientNet, self).__init__()
self.norm_type = norm_type
self.scale = scale
self.use_se = use_se
def __call__(self, inputs):
blocks_args, global_params = get_model_params(self.scale)
momentum = global_params.batch_norm_momentum
eps = global_params.batch_norm_epsilon
num_filters = round_filters(32, global_params)
feats = conv2d(
inputs,
num_filters=num_filters,
filter_size=3,
stride=2,
name='_conv_stem')
feats = batch_norm(feats, momentum=momentum, eps=eps, name='_bn0')
feats = fluid.layers.swish(feats)
layer_count = 0
feature_maps = []
for b, block_arg in enumerate(blocks_args):
for r in range(block_arg.num_repeat):
input_filters = round_filters(block_arg.input_filters,
global_params)
output_filters = round_filters(block_arg.output_filters,
global_params)
kernel_size = block_arg.kernel_size
stride = block_arg.stride
se_ratio = None
if self.use_se:
se_ratio = block_arg.se_ratio
if r > 0:
input_filters = output_filters
stride = 1
feats = mb_conv_block(
feats,
input_filters,
output_filters,
block_arg.expand_ratio,
kernel_size,
stride,
momentum,
eps,
se_ratio=se_ratio,
name='_blocks.{}.'.format(layer_count))
layer_count += 1
feature_maps.append(feats)
return list(feature_maps[i] for i in [2, 4, 6])
......@@ -18,17 +18,19 @@ import math
import six
from paddle import fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from ppdet.core.workspace import register, serializable
from ppdet.utils.bbox_utils import bbox_overlaps, box_to_delta
__all__ = [
'AnchorGenerator', 'DropBlock', 'RPNTargetAssign', 'GenerateProposals',
'MultiClassNMS', 'BBoxAssigner', 'MaskAssigner', 'RoIAlign', 'RoIPool',
'MultiBoxHead', 'SSDLiteMultiBoxHead', 'SSDOutputDecoder',
'RetinaTargetAssign', 'RetinaOutputDecoder', 'ConvNorm', 'DeformConvNorm',
'MultiClassSoftNMS', 'LibraBBoxAssigner'
'AnchorGenerator', 'AnchorGrid', 'DropBlock', 'RPNTargetAssign',
'GenerateProposals', 'MultiClassNMS', 'BBoxAssigner', 'MaskAssigner',
'RoIAlign', 'RoIPool', 'MultiBoxHead', 'SSDLiteMultiBoxHead',
'SSDOutputDecoder', 'RetinaTargetAssign', 'RetinaOutputDecoder',
'ConvNorm', 'DeformConvNorm', 'MultiClassSoftNMS', 'LibraBBoxAssigner'
]
......@@ -324,6 +326,94 @@ class AnchorGenerator(object):
self.stride = stride
@register
@serializable
class AnchorGrid(object):
"""Generate anchor grid
Args:
image_size (int or list): input image size, may be a single integer or
list of [h, w]. Default: 512
min_level (int): min level of the feature pyramid. Default: 3
max_level (int): max level of the feature pyramid. Default: 7
anchor_base_scale: base anchor scale. Default: 4
num_scales: number of anchor scales. Default: 3
aspect_ratios: aspect ratios. default: [[1, 1], [1.4, 0.7], [0.7, 1.4]]
"""
def __init__(self,
image_size=512,
min_level=3,
max_level=7,
anchor_base_scale=4,
num_scales=3,
aspect_ratios=[[1, 1], [1.4, 0.7], [0.7, 1.4]]):
super(AnchorGrid, self).__init__()
if isinstance(image_size, Integral):
self.image_size = [image_size, image_size]
else:
self.image_size = image_size
for dim in self.image_size:
assert dim % 2 ** max_level == 0, \
"image size should be multiple of the max level stride"
self.min_level = min_level
self.max_level = max_level
self.anchor_base_scale = anchor_base_scale
self.num_scales = num_scales
self.aspect_ratios = aspect_ratios
@property
def base_cell(self):
if not hasattr(self, '_base_cell'):
self._base_cell = self.make_cell()
return self._base_cell
def make_cell(self):
scales = [2**(i / self.num_scales) for i in range(self.num_scales)]
scales = np.array(scales)
ratios = np.array(self.aspect_ratios)
ws = np.outer(scales, ratios[:, 0]).reshape(-1, 1)
hs = np.outer(scales, ratios[:, 1]).reshape(-1, 1)
anchors = np.hstack((-0.5 * ws, -0.5 * hs, 0.5 * ws, 0.5 * hs))
return anchors
def make_grid(self, stride):
cell = self.base_cell * stride * self.anchor_base_scale
x_steps = np.arange(stride // 2, self.image_size[1], stride)
y_steps = np.arange(stride // 2, self.image_size[0], stride)
offset_x, offset_y = np.meshgrid(x_steps, y_steps)
offset_x = offset_x.flatten()
offset_y = offset_y.flatten()
offsets = np.stack((offset_x, offset_y, offset_x, offset_y), axis=-1)
offsets = offsets[:, np.newaxis, :]
return (cell + offsets).reshape(-1, 4)
def generate(self):
return [
self.make_grid(2**l)
for l in range(self.min_level, self.max_level + 1)
]
def __call__(self):
if not hasattr(self, '_anchor_vars'):
anchor_vars = []
helper = LayerHelper('anchor_grid')
for idx, l in enumerate(range(self.min_level, self.max_level + 1)):
stride = 2**l
anchors = self.make_grid(stride)
var = helper.create_parameter(
attr=ParamAttr(name='anchors_{}'.format(idx)),
shape=anchors.shape,
dtype='float32',
stop_gradient=True,
default_initializer=NumpyArrayInitializer(anchors))
anchor_vars.append(var)
var.persistable = True
self._anchor_vars = anchor_vars
return self._anchor_vars
@register
@serializable
class RPNTargetAssign(object):
......
......@@ -16,12 +16,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import logging
from paddle import fluid
import paddle.fluid.optimizer as optimizer
import paddle.fluid.regularizer as regularizer
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
from paddle.fluid.layers.ops import cos
from ppdet.core.workspace import register, serializable
......@@ -75,6 +78,50 @@ class CosineDecay(object):
def __call__(self, base_lr=None, learning_rate=None):
assert base_lr is not None, "either base LR or values should be provided"
lr = fluid.layers.cosine_decay(base_lr, 1, self.max_iters)
@serializable
class CosineDecayWithSkip(object):
"""
Cosine decay, with explicit support for warm up
Args:
total_steps (int): total steps over which to apply the decay
skip_steps (int): skip some steps at the beginning, e.g., warm up
"""
def __init__(self, total_steps, skip_steps=None):
super(CosineDecayWithSkip, self).__init__()
assert (not skip_steps or skip_steps > 0), \
"skip steps must be greater than zero"
assert total_steps > 0, "total step must be greater than zero"
assert (not skip_steps or skip_steps < total_steps), \
"skip steps must be smaller than total steps"
self.total_steps = total_steps
self.skip_steps = skip_steps
def __call__(self, base_lr=None, learning_rate=None):
steps = _decay_step_counter()
total = self.total_steps
if self.skip_steps is not None:
total -= self.skip_steps
lr = fluid.layers.tensor.create_global_var(
shape=[1],
value=base_lr,
dtype='float32',
persistable=True,
name="learning_rate")
def decay():
cos_lr = base_lr * .5 * (cos(steps * (math.pi / total)) + 1)
fluid.layers.tensor.assign(input=cos_lr, output=lr)
if self.skip_steps is None:
decay()
else:
skipped = steps >= self.skip_steps
fluid.layers.cond(skipped, decay)
return lr
......@@ -140,10 +187,12 @@ class OptimizerBuilder():
__category__ = 'optim'
def __init__(self,
clip_grad_by_norm=None,
regularizer={'type': 'L2',
'factor': .0001},
optimizer={'type': 'Momentum',
'momentum': .9}):
self.clip_grad_by_norm = clip_grad_by_norm
self.regularizer = regularizer
self.optimizer = optimizer
......
......@@ -38,6 +38,8 @@ set_paddle_flags(
)
from paddle import fluid
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
from paddle.fluid.optimizer import ExponentialMovingAverage
from ppdet.experimental import mixed_precision_context
from ppdet.core.workspace import load_config, merge_config, create
......@@ -124,10 +126,21 @@ def main():
loss *= ctx.get_loss_scale_var()
lr = lr_builder()
optimizer = optim_builder(lr)
optimizer.minimize(loss)
clip = None
if optim_builder.clip_grad_by_norm is not None:
clip = fluid.clip.GradientClipByGlobalNorm(
clip_norm=optim_builder.clip_grad_by_norm)
optimizer.minimize(loss, grad_clip=clip)
if FLAGS.fp16:
loss /= ctx.get_loss_scale_var()
if 'use_ema' in cfg and cfg['use_ema']:
global_steps = _decay_step_counter()
ema = ExponentialMovingAverage(
cfg['ema_decay'], thres_steps=global_steps)
ema.update()
# parse train fetches
train_keys, train_values, _ = parse_fetches(train_fetches)
train_values.append(lr)
......@@ -265,6 +278,8 @@ def main():
if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \
and (not FLAGS.dist or trainer_id == 0):
save_name = str(it) if it != cfg.max_iters - 1 else "model_final"
if 'use_ema' in cfg and cfg['use_ema']:
exe.run(ema.apply_program)
checkpoint.save(exe, train_prog, os.path.join(save_dir, save_name))
if FLAGS.eval:
......@@ -299,6 +314,9 @@ def main():
logger.info("Best test box ap: {}, in iter: {}".format(
best_box_ap_list[0], best_box_ap_list[1]))
if 'use_ema' in cfg and cfg['use_ema']:
exe.run(ema.restore_program)
train_loader.reset()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册