diff --git a/configs/anchor_free/fcos_rt_dla34_fpn_4x.yml b/configs/anchor_free/fcos_rt_dla34_fpn_4x.yml
new file mode 100644
index 0000000000000000000000000000000000000000..41a4e5a01a651a1db25c9ba200ef6fe600a95d9a
--- /dev/null
+++ b/configs/anchor_free/fcos_rt_dla34_fpn_4x.yml
@@ -0,0 +1,180 @@
+architecture: FCOS
+max_iters: 360000
+use_gpu: true
+snapshot_iter: 20000
+log_smooth_window: 20
+log_iter: 20
+save_dir: output
+pretrain_weights: ""
+metric: COCO
+weights: output/fcos_rt_dla34_fpn_4x/model_final
+num_classes: 80
+
+FCOS:
+  backbone: DLA
+  fpn: FPN
+  fcos_head: FCOSHead
+
+DLA:
+  norm_type: sync_bn
+  levels: [1, 1, 1, 2, 2, 1]
+  channels: [16, 32, 64, 128, 256, 512]
+  feature_maps: [3, 4, 5]
+
+FPN:
+  min_level: 3
+  max_level: 5
+  num_chan: 256
+  use_c5: false
+  spatial_scale: [0.03125, 0.0625, 0.125]
+  has_extra_convs: false
+
+FCOSHead:
+  num_classes: 80
+  fpn_stride: [8, 16, 32]
+  num_convs: 4
+  norm_type: "gn"
+  fcos_loss: FCOSLoss
+  norm_reg_targets: True
+  centerness_on_reg: True
+  use_dcn_in_tower: False
+  nms: MultiClassNMS
+
+MultiClassNMS:
+  score_threshold: 0.025
+  nms_top_k: 1000
+  keep_top_k: 100
+  nms_threshold: 0.6
+  background_label: -1
+
+FCOSLoss:
+  loss_alpha: 0.25
+  loss_gamma: 2.0
+  iou_loss_type: "giou"
+  reg_weights: 1.0
+
+LearningRate:
+  base_lr: 0.01
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones: [300000, 340000]
+  - !LinearWarmup
+    start_factor: 0.3333333333333333
+    steps: 500
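+# Note: training warms up from base_lr * start_factor = 0.01 / 3 ≈ 0.0033 to
+# 0.01 over the first 500 iterations, then the LR is multiplied by gamma = 0.1
+# at iterations 300000 and 340000 (of 360000 total).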
+
+OptimizerBuilder:
+  optimizer:
+    momentum: 0.9
+    type: Momentum
+  regularizer:
+    factor: 0.0001
+    type: L2
+
+TrainReader:
+  inputs_def:
+    fields: ['image', 'im_info', 'fcos_target']
+  dataset:
+    !COCODataSet
+    image_dir: train2017
+    anno_path: annotations/instances_train2017.json
+    dataset_dir: dataset/coco
+    with_background: false
+  sample_transforms:
+  - !DecodeImage
+    to_rgb: false
+  - !RandomFlipImage
+    prob: 0.5
+  - !NormalizeImage
+    is_channel_first: false
+    is_scale: false
+    mean: [103.53, 116.28, 123.675]
+    std: [1.0, 1.0, 1.0]
+  - !ResizeImage
+    target_size: [256, 288, 320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
+    max_size: 900
+    interp: 1
+    use_cv2: true
+  - !Permute
+    to_bgr: false
+    channel_first: true
+  batch_transforms:
+  - !PadBatch
+    pad_to_stride: 32
+    use_padded_im_info: false
+  - !Gt2FCOSTarget
+    object_sizes_boundary: [64, 128]
+    center_sampling_radius: 1.5
+    downsample_ratios: [8, 16, 32]
+    norm_reg_targets: True
+  batch_size: 16
+  shuffle: true
+  worker_num: 4
+  use_process: false
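+
+# Note: the ResizeImage list above gives multi-scale training over short sides
+# 256-608 (long side capped at 900). Gt2FCOSTarget splits regression ranges
+# across FPN levels at object_sizes_boundary [64, 128]: targets up to 64 px
+# train the stride-8 level, 64-128 px the stride-16 level, and larger ones the
+# stride-32 level.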
+
+EvalReader:
+  inputs_def:
+    fields: ['image', 'im_id', 'im_shape', 'im_info']
+  dataset:
+    !COCODataSet
+    image_dir: val2017
+    anno_path: annotations/instances_val2017.json
+    dataset_dir: dataset/coco
+    with_background: false
+  sample_transforms:
+  - !DecodeImage
+    to_rgb: false
+    with_mixup: false
+  - !NormalizeImage
+    is_channel_first: false
+    is_scale: false
+    mean: [103.53, 116.28, 123.675]
+    std: [1.0, 1.0, 1.0]
+  - !ResizeImage
+    target_size: 512
+    max_size: 736
+    interp: 1
+    use_cv2: true
+  - !Permute
+    channel_first: true
+    to_bgr: false
+  batch_transforms:
+  - !PadBatch
+    pad_to_stride: 32
+    use_padded_im_info: true
+  batch_size: 1
+  shuffle: false
+  worker_num: 2
+  use_process: false
+
+TestReader:
+  inputs_def:
+    # set image_shape if needed
+    fields: ['image', 'im_id', 'im_shape', 'im_info']
+  dataset:
+    !ImageFolder
+    anno_path: annotations/instances_val2017.json
+    with_background: false
+  sample_transforms:
+  - !DecodeImage
+    to_rgb: false
+    with_mixup: false
+  - !NormalizeImage
+    is_channel_first: false
+    is_scale: false
+    mean: [103.53, 116.28, 123.675]
+    std: [1.0, 1.0, 1.0]
+  - !ResizeImage
+    interp: 1
+    max_size: 736
+    target_size: 512
+    use_cv2: true
+  - !Permute
+    channel_first: true
+    to_bgr: false
+  batch_transforms:
+  - !PadBatch
+    pad_to_stride: 32
+    use_padded_im_info: true
+  batch_size: 1
+  shuffle: false
diff --git a/configs/anchor_free/fcos_rt_r50_fpn_4x.yml b/configs/anchor_free/fcos_rt_r50_fpn_4x.yml
new file mode 100644
index 0000000000000000000000000000000000000000..52c1a84814791b7b69947c3a65c2a816cd6485a9
--- /dev/null
+++ b/configs/anchor_free/fcos_rt_r50_fpn_4x.yml
@@ -0,0 +1,182 @@
+architecture: FCOS
+max_iters: 360000
+use_gpu: true
+snapshot_iter: 20000
+log_smooth_window: 20
+log_iter: 20
+save_dir: output
+pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
+metric: COCO
+weights: output/fcos_rt_r50_fpn_4x/model_final
+num_classes: 80
+
+FCOS:
+  backbone: ResNet
+  fpn: FPN
+  fcos_head: FCOSHead
+
+ResNet:
+  norm_type: sync_bn
+  norm_decay: 0.
+  depth: 50
+  feature_maps: [3, 4, 5]
+  variant: b
+  freeze_at: 2
+
+FPN:
+  min_level: 3
+  max_level: 5
+  num_chan: 256
+  use_c5: false
+  spatial_scale: [0.03125, 0.0625, 0.125]
+  has_extra_convs: false
+
+FCOSHead:
+  num_classes: 80
+  fpn_stride: [8, 16, 32]
+  num_convs: 4
+  norm_type: "gn"
+  fcos_loss: FCOSLoss
+  norm_reg_targets: True
+  centerness_on_reg: True
+  use_dcn_in_tower: False
+  nms: MultiClassNMS
+
+MultiClassNMS:
+  score_threshold: 0.025
+  nms_top_k: 1000
+  keep_top_k: 100
+  nms_threshold: 0.6
+  background_label: -1
+
+FCOSLoss:
+  loss_alpha: 0.25
+  loss_gamma: 2.0
+  iou_loss_type: "giou"
+  reg_weights: 1.0
+
+LearningRate:
+  base_lr: 0.01
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones: [300000, 340000]
+  - !LinearWarmup
+    start_factor: 0.3333333333333333
+    steps: 500
+
+OptimizerBuilder:
+  optimizer:
+    momentum: 0.9
+    type: Momentum
+  regularizer:
+    factor: 0.0001
+    type: L2
+
+TrainReader:
+  inputs_def:
+    fields: ['image', 'im_info', 'fcos_target']
+  dataset:
+    !COCODataSet
+    image_dir: train2017
+    anno_path: annotations/instances_train2017.json
+    dataset_dir: dataset/coco
+    with_background: false
+  sample_transforms:
+  - !DecodeImage
+    to_rgb: false
+  - !RandomFlipImage
+    prob: 0.5
+  - !NormalizeImage
+    is_channel_first: false
+    is_scale: false
+    mean: [103.53, 116.28, 123.675]
+    std: [1.0, 1.0, 1.0]
+  - !ResizeImage
+    target_size: [256, 288, 320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
+    max_size: 900
+    interp: 1
+    use_cv2: true
+  - !Permute
+    to_bgr: false
+    channel_first: true
+  batch_transforms:
+  - !PadBatch
+    pad_to_stride: 32
+    use_padded_im_info: false
+  - !Gt2FCOSTarget
+    object_sizes_boundary: [64, 128]
+    center_sampling_radius: 1.5
+    downsample_ratios: [8, 16, 32]
+    norm_reg_targets: True
+  batch_size: 16
+  shuffle: true
+  worker_num: 4
+  use_process: false
+
+EvalReader:
+  inputs_def:
+    fields: ['image', 'im_id', 'im_shape', 'im_info']
+  dataset:
+    !COCODataSet
+    image_dir: val2017
+    anno_path: annotations/instances_val2017.json
+    dataset_dir: dataset/coco
+    with_background: false
+  sample_transforms:
+  - !DecodeImage
+    to_rgb: false
+    with_mixup: false
+  - !NormalizeImage
+    is_channel_first: false
+    is_scale: false
+    mean: [103.53, 116.28, 123.675]
+    std: [1.0, 1.0, 1.0]
+  - !ResizeImage
+    target_size: 512
+    max_size: 736
+    interp: 1
+    use_cv2: true
+  - !Permute
+    channel_first: true
+    to_bgr: false
+  batch_transforms:
+  - !PadBatch
+    pad_to_stride: 32
+    use_padded_im_info: true
+  batch_size: 1
+  shuffle: false
+  worker_num: 2
+  use_process: false
+
+TestReader:
+  inputs_def:
+    # set image_shape if needed
+    fields: ['image', 'im_id', 'im_shape', 'im_info']
+  dataset:
+    !ImageFolder
+    anno_path: annotations/instances_val2017.json
+    with_background: false
+  sample_transforms:
+  - !DecodeImage
+    to_rgb: false
+    with_mixup: false
+  - !NormalizeImage
+    is_channel_first: false
+    is_scale: false
+    mean: [103.53, 116.28, 123.675]
+    std: [1.0, 1.0, 1.0]
+  - !ResizeImage
+    interp: 1
+    max_size: 736
+    target_size: 512
+    use_cv2: true
+  - !Permute
+    channel_first: true
+    to_bgr: false
+  batch_transforms:
+  - !PadBatch
+    pad_to_stride: 32
+    use_padded_im_info: true
+  batch_size: 1
+  shuffle: false
diff --git a/ppdet/modeling/backbones/__init__.py b/ppdet/modeling/backbones/__init__.py
index a6d2eb18fad8e8099be4ce26562f4b8e33c73c92..3b34a78fa54b2ccf43b62a5b61953225f6103a8a 100644
--- a/ppdet/modeling/backbones/__init__.py
+++ b/ppdet/modeling/backbones/__init__.py
@@ -35,6 +35,7 @@ from . import bifpn
 from . import cspdarknet
 from . import acfpn
 from . import ghostnet
+from . import dla
 
 from .resnet import *
 from .resnext import *
@@ -57,3 +58,4 @@ from .bifpn import *
 from .cspdarknet import *
 from .acfpn import *
 from .ghostnet import *
+from .dla import *
diff --git a/ppdet/modeling/backbones/dla.py b/ppdet/modeling/backbones/dla.py
new file mode 100644
index 0000000000000000000000000000000000000000..83f8be418de1099eca83ac3a6b4492b7276e19c6
--- /dev/null
+++ b/ppdet/modeling/backbones/dla.py
@@ -0,0 +1,183 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import OrderedDict
+import paddle.fluid.layers as L
+
+from ppdet.core.workspace import register
+from ppdet.modeling.ops import Conv2dUnit
+
+
+def get_norm(norm_type):
+    assert norm_type in ['bn', 'sync_bn', 'gn', 'affine_channel']
+    bn = 0
+    gn = 0
+    af = 0
+    if norm_type in ['bn', 'sync_bn']:
+        bn = 1
+    elif norm_type == 'gn':
+        gn = 1
+    elif norm_type == 'affine_channel':
+        af = 1
+    return bn, gn, af
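+
+
+# Note: the three flags are mutually exclusive switches consumed by Conv2dUnit
+# (bn -> fluid.layers.batch_norm, gn -> group_norm, af -> affine_channel).
+# 'sync_bn' maps to the plain bn flag here; cross-GPU synchronization is
+# presumably enabled by the trainer's build strategy rather than per layer.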
+
+
+class BasicBlock(object):
+    def __init__(self, norm_type, inplanes, planes, stride=1, name=''):
+        super(BasicBlock, self).__init__()
+        bn, gn, af = get_norm(norm_type)
+        self.conv1 = Conv2dUnit(inplanes, planes, 3, stride=stride, bias_attr=False, bn=bn, gn=gn, af=af, act='relu', name=name+'.conv1')
+        self.conv2 = Conv2dUnit(planes, planes, 3, stride=1, bias_attr=False, bn=bn, gn=gn, af=af, act=None, name=name+'.conv2')
+        self.stride = stride
+
+    def __call__(self, x, residual=None):
+        if residual is None:
+            residual = x
+        out = self.conv1(x)
+        out = self.conv2(out)
+        out = L.elementwise_add(x=out, y=residual, act=None)
+        out = L.relu(out)
+        return out
+
+
+class Root(object):
+    def __init__(self, norm_type, in_channels, out_channels, kernel_size, residual, name=''):
+        super(Root, self).__init__()
+        bn, gn, af = get_norm(norm_type)
+        self.conv = Conv2dUnit(in_channels, out_channels, kernel_size, stride=1, bias_attr=False, bn=bn, gn=gn, af=af, act=None, name=name+'.conv')
+        self.residual = residual
+
+    def __call__(self, *x):
+        children = x
+        x = L.concat(list(x), axis=1)
+        x = self.conv(x)
+        if self.residual:
+            x += children[0]
+        x = L.relu(x)
+        return x
+
+
+class Tree(object):
+    def __init__(self, norm_type, levels, block, in_channels, out_channels, stride=1,
+                 level_root=False, root_dim=0, root_kernel_size=1, root_residual=False, name=''):
+        super(Tree, self).__init__()
+        if root_dim == 0:
+            root_dim = 2 * out_channels
+        if level_root:
+            root_dim += in_channels
+        if levels == 1:
+            self.tree1 = block(norm_type, in_channels, out_channels, stride, name=name+'.tree1')
+            self.tree2 = block(norm_type, out_channels, out_channels, 1, name=name+'.tree2')
+        else:
+            self.tree1 = Tree(norm_type, levels - 1, block, in_channels, out_channels,
+                              stride, root_dim=0,
+                              root_kernel_size=root_kernel_size, root_residual=root_residual, name=name+'.tree1')
+            self.tree2 = Tree(norm_type, levels - 1, block, out_channels, out_channels,
+                              root_dim=root_dim + out_channels,
+                              root_kernel_size=root_kernel_size, root_residual=root_residual, name=name+'.tree2')
+        if levels == 1:
+            self.root = Root(norm_type, root_dim, out_channels, root_kernel_size, root_residual, name=name+'.root')
+        self.level_root = level_root
+        self.root_dim = root_dim
+        self.downsample = False
+        self.stride = stride
+        self.project = None
+        self.levels = levels
+        if stride > 1:
+            self.downsample = True
+        if in_channels != out_channels:
+            bn, gn, af = get_norm(norm_type)
+            self.project = Conv2dUnit(in_channels, out_channels, 1, stride=1, bias_attr=False, bn=bn, gn=gn, af=af, act=None, name=name+'.project')
+
+    def __call__(self, x, residual=None, children=None):
+        children = [] if children is None else children
+        bottom = L.pool2d(input=x, pool_size=self.stride, pool_stride=self.stride, pool_type='max') if self.downsample else x
+        residual = self.project(bottom) if self.project else bottom
+        if self.level_root:
+            children.append(bottom)
+        x1 = self.tree1(x, residual)
+        if self.levels == 1:
+            x2 = self.tree2(x1)
+            x = self.root(x2, x1, *children)
+        else:
+            children.append(x1)
+            x = self.tree2(x1, children=children)
+        return x
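+
+
+# Note: Tree implements the hierarchical aggregation of DLA (Yu et al., "Deep
+# Layer Aggregation", CVPR 2018): tree1 moves to the next stage (optionally
+# strided), tree2 refines at that scale, and at the leaves a Root node fuses
+# the accumulated children with a 1x1 conv (root_kernel_size) over their
+# channel-wise concatenation.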
+
+
+__all__ = ['DLA']
+
+
+@register
+class DLA(object):
+    __shared__ = ['norm_type', 'levels', 'channels']
+
+    def __init__(self,
+                 norm_type="sync_bn",
+                 levels=[1, 1, 1, 2, 2, 1],
+                 channels=[16, 32, 64, 128, 256, 512],
+                 block=BasicBlock,
+                 residual_root=False,
+                 feature_maps=[3, 4, 5]):
+        self.norm_type = norm_type
+        self.channels = channels
+        self.feature_maps = feature_maps
+
+        self._out_features = ["level{}".format(i) for i in range(6)]  # name of each feature map
+        self._out_feature_channels = {k: channels[i] for i, k in enumerate(self._out_features)}  # output channels of each feature map
+        self._out_feature_strides = {k: 2 ** i for i, k in enumerate(self._out_features)}  # downsampling stride of each feature map
+
+        bn, gn, af = get_norm(norm_type)
+        self.base_layer = Conv2dUnit(3, channels[0], 7, stride=1, bias_attr=False, bn=bn, gn=gn, af=af, act='relu', name='dla.base_layer')
+        self.level0 = self._make_conv_level(channels[0], channels[0], levels[0], name='dla.level0')
+        self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2, name='dla.level1')
+        self.level2 = Tree(norm_type, levels[2], block, channels[1], channels[2], 2,
+                           level_root=False, root_residual=residual_root, name='dla.level2')
+        self.level3 = Tree(norm_type, levels[3], block, channels[2], channels[3], 2,
+                           level_root=True, root_residual=residual_root, name='dla.level3')
+        self.level4 = Tree(norm_type, levels[4], block, channels[3], channels[4], 2,
+                           level_root=True, root_residual=residual_root, name='dla.level4')
+        self.level5 = Tree(norm_type, levels[5], block, channels[4], channels[5], 2,
+                           level_root=True, root_residual=residual_root, name='dla.level5')
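+
+    # Note: __call__ below exports the selected feature maps under
+    # ResNet-style names ('res3_sum', 'res4_sum', 'res5_sum'), presumably so
+    # the existing FPN neck can consume DLA outputs without modification.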
+
+    def _make_conv_level(self, inplanes, planes, convs, stride=1, name=''):
+        modules = []
+        for i in range(convs):
+            bn, gn, af = get_norm(self.norm_type)
+            modules.append(Conv2dUnit(inplanes, planes, 3, stride=stride if i == 0 else 1, bias_attr=False, bn=bn, gn=gn, af=af, act='relu', name=name+'.conv%d' % i))
+            inplanes = planes
+        return modules
+
+    def __call__(self, x):
+        outs = []
+        x = self.base_layer(x)
+        for i in range(6):
+            name = 'level{}'.format(i)
+            level = getattr(self, name)
+            if isinstance(level, list):
+                for ly in level:
+                    x = ly(x)
+            else:
+                x = level(x)
+            if i in self.feature_maps:
+                outs.append(x)
+        return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat)
+                            for idx, feat in enumerate(outs)])
\ No newline at end of file
diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py
index a288e5de97321c1fb8f0455b2918a7be660c0be3..990a5c4988943d57bec95d7618f3f60978f78a1c 100644
--- a/ppdet/modeling/ops.py
+++ b/ppdet/modeling/ops.py
@@ -1565,3 +1565,128 @@ class RetinaOutputDecoder(object):
         self.nms_top_k = pre_nms_top_n
         self.keep_top_k = detections_per_im
         self.nms_eta = nms_eta
+
+
+class Conv2dUnit(object):
+    def __init__(self,
+                 input_dim,
+                 filters,
+                 filter_size,
+                 stride=1,
+                 bias_attr=False,
+                 bn=0,
+                 gn=0,
+                 af=0,
+                 groups=32,
+                 act=None,
+                 freeze_norm=False,
+                 is_test=False,
+                 norm_decay=0.,
+                 use_dcn=False,
+                 bias_init_value=None,
+                 name=''):
+        super(Conv2dUnit, self).__init__()
+        self.input_dim = input_dim
+        self.filters = filters
+        self.filter_size = filter_size
+        self.stride = stride
+        self.bias_attr = bias_attr
+        self.bn = bn
+        self.gn = gn
+        self.af = af
+        self.groups = groups
+        self.act = act
+        self.freeze_norm = freeze_norm
+        self.is_test = is_test
+        self.norm_decay = norm_decay
+        self.use_dcn = use_dcn
+        self.bias_init_value = bias_init_value
+        self.name = name
+
+    def __call__(self, x):
+        conv_name = self.name + ".conv"
+        if self.use_dcn:
+            # Deformable conv is not wired up in this unit; fail fast instead
+            # of silently skipping the convolution.
+            raise NotImplementedError("Conv2dUnit does not support use_dcn yet.")
+        battr = None
+        if self.bias_attr:
+            initializer = None
+            # Compare against None so that an explicit 0.0 still builds the
+            # constant initializer.
+            if self.bias_init_value is not None:
+                initializer = fluid.initializer.Constant(value=self.bias_init_value)
+            battr = ParamAttr(name=conv_name + ".bias", initializer=initializer)
+        x = fluid.layers.conv2d(
+            input=x,
+            num_filters=self.filters,
+            filter_size=self.filter_size,
+            stride=self.stride,
+            padding=(self.filter_size - 1) // 2,
+            act=None,
+            param_attr=ParamAttr(name=conv_name + ".weight"),
+            bias_attr=battr,
+            name=conv_name + '.output.1')
+        if self.bn:
+            bn_name = self.name + ".bn"
+            norm_lr = 0. if self.freeze_norm else 1.
+            norm_decay = self.norm_decay
+            pattr = ParamAttr(
+                name=bn_name + '.scale',
+                learning_rate=norm_lr,
+                regularizer=L2Decay(norm_decay))
+            battr = ParamAttr(
+                name=bn_name + '.offset',
+                learning_rate=norm_lr,
+                regularizer=L2Decay(norm_decay))
+            x = fluid.layers.batch_norm(
+                input=x,
+                name=bn_name + '.output.1',
+                is_test=self.is_test,
+                param_attr=pattr,
+                bias_attr=battr,
+                moving_mean_name=bn_name + '.mean',
+                moving_variance_name=bn_name + '.var')
+        if self.gn:
+            gn_name = self.name + ".gn"
+            norm_lr = 0. if self.freeze_norm else 1.
+            norm_decay = self.norm_decay
+            pattr = ParamAttr(
+                name=gn_name + '.scale',
+                learning_rate=norm_lr,
+                regularizer=L2Decay(norm_decay))
+            battr = ParamAttr(
+                name=gn_name + '.offset',
+                learning_rate=norm_lr,
+                regularizer=L2Decay(norm_decay))
+            x = fluid.layers.group_norm(
+                input=x,
+                groups=self.groups,
+                name=gn_name + '.output.1',
+                param_attr=pattr,
+                bias_attr=battr)
+        if self.af:
+            af_name = self.name + ".af"
+            norm_lr = 0. if self.freeze_norm else 1.
+            norm_decay = self.norm_decay
+            pattr = ParamAttr(
+                name=af_name + '.scale',
+                learning_rate=norm_lr,
+                regularizer=L2Decay(norm_decay))
+            battr = ParamAttr(
+                name=af_name + '.offset',
+                learning_rate=norm_lr,
+                regularizer=L2Decay(norm_decay))
+            scale = fluid.layers.create_parameter(
+                shape=[x.shape[1]],
+                dtype=x.dtype,
+                attr=pattr,
+                default_initializer=fluid.initializer.Constant(1.))
+            bias = fluid.layers.create_parameter(
+                shape=[x.shape[1]],
+                dtype=x.dtype,
+                attr=battr,
+                default_initializer=fluid.initializer.Constant(0.))
+            x = fluid.layers.affine_channel(x, scale=scale, bias=bias)
+        if self.act == 'leaky':
+            x = fluid.layers.leaky_relu(x, alpha=0.1)
+        elif self.act == 'relu':
+            x = fluid.layers.relu(x)
+        return x
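+
+
+# Example (sketch): a 3x3 conv followed by GroupNorm and ReLU, the
+# conv-norm-act pattern the DLA backbone above is built from; the unit name is
+# illustrative only.
+#
+#   unit = Conv2dUnit(256, 256, 3, stride=1, gn=1, groups=32, act='relu',
+#                     name='demo.conv')
+#   y = unit(x)  # x: an [N, 256, H, W] fluid Variable inside a Program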