diff --git a/configs/gcnet/README.md b/configs/gcnet/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b95e710336943ac0051235a017a9f61af9b136ae --- /dev/null +++ b/configs/gcnet/README.md @@ -0,0 +1,34 @@ +# GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond + +## Introduction + +- GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond +: [https://arxiv.org/abs/1904.11492](https://arxiv.org/abs/1904.11492) + +``` +@article{DBLP:journals/corr/abs-1904-11492, + author = {Yue Cao and + Jiarui Xu and + Stephen Lin and + Fangyun Wei and + Han Hu}, + title = {GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond}, + journal = {CoRR}, + volume = {abs/1904.11492}, + year = {2019}, + url = {http://arxiv.org/abs/1904.11492}, + archivePrefix = {arXiv}, + eprint = {1904.11492}, + timestamp = {Tue, 09 Jul 2019 16:48:55 +0200}, + biburl = {https://dblp.org/rec/bib/journals/corr/abs-1904-11492}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +``` + + +## Model Zoo + +| Backbone | Type | Context| Image/gpu | Lr schd | Inf time (fps) | Box AP | Mask AP | Download | +| :---------------------- | :-------------: | :-------------: | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: | +| ResNet50-vd-FPN | Mask | GC(c3-c5, r16, add) | 2 | 2x | 15.31 | 41.4 | 36.8 | [model](https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_vd_fpn_gcb_add_r16_2x.tar) | +| ResNet50-vd-FPN | Mask | GC(c3-c5, r16, mul) | 2 | 2x | 15.35 | 40.7 | 36.1 | [model](https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r50_vd_fpn_gcb_mul_r16_2x.tar) | diff --git a/configs/gcnet/mask_rcnn_r50_vd_fpn_gcb_add_r16_2x.yml b/configs/gcnet/mask_rcnn_r50_vd_fpn_gcb_add_r16_2x.yml new file mode 100644 index 0000000000000000000000000000000000000000..f053deeeff112d5b98e342cecf8a525d0691973b --- /dev/null +++ b/configs/gcnet/mask_rcnn_r50_vd_fpn_gcb_add_r16_2x.yml @@ -0,0 +1,119 @@ +architecture: MaskRCNN +use_gpu: true +max_iters: 180000 +snapshot_iter: 10000 +log_smooth_window: 20 +save_dir: output +pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar +metric: COCO +weights: output/mask_rcnn_r50_vd_fpn_gcb_add_r16_2x/model_final/ +num_classes: 81 + +MaskRCNN: + backbone: ResNet + fpn: FPN + rpn_head: FPNRPNHead + roi_extractor: FPNRoIAlign + bbox_head: BBoxHead + bbox_assigner: BBoxAssigner + +ResNet: + depth: 50 + feature_maps: [2, 3, 4, 5] + freeze_at: 2 + norm_type: bn + variant: d + gcb_stages: [3, 4, 5] + gcb_params: + ratio: 0.0625 + pooling_type: att + fusion_types: [channel_add] + +FPN: + max_level: 6 + min_level: 2 + num_chan: 256 + spatial_scale: [0.03125, 0.0625, 0.125, 0.25] + +FPNRPNHead: + anchor_generator: + aspect_ratios: [0.5, 1.0, 2.0] + variance: [1.0, 1.0, 1.0, 1.0] + anchor_start_size: 32 + max_level: 6 + min_level: 2 + num_chan: 256 + rpn_target_assign: + rpn_batch_size_per_im: 256 + rpn_fg_fraction: 0.5 + rpn_negative_overlap: 0.3 + rpn_positive_overlap: 0.7 + rpn_straddle_thresh: 0.0 + train_proposal: + min_size: 0.0 + nms_thresh: 0.7 + pre_nms_top_n: 2000 + post_nms_top_n: 2000 + test_proposal: + min_size: 0.0 + nms_thresh: 0.7 + pre_nms_top_n: 1000 + post_nms_top_n: 1000 + +FPNRoIAlign: + canconical_level: 4 + canonical_size: 224 + max_level: 5 + min_level: 2 + box_resolution: 7 + sampling_ratio: 2 + mask_resolution: 14 + +MaskHead: + dilation: 1 + conv_dim: 256 + num_convs: 4 + resolution: 28 + +BBoxAssigner: + batch_size_per_im: 512 + bbox_reg_weights: [0.1, 0.1, 0.2, 0.2] + bg_thresh_hi: 0.5 + bg_thresh_lo: 0.0 + fg_fraction: 0.25 + fg_thresh: 0.5 + +MaskAssigner: + resolution: 28 + +BBoxHead: + head: TwoFCHead + nms: + keep_top_k: 100 + nms_threshold: 0.5 + score_threshold: 0.05 + +TwoFCHead: + mlp_dim: 1024 + +LearningRate: + base_lr: 0.02 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [120000, 160000] + - !LinearWarmup + start_factor: 0.1 + steps: 1000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0001 + type: L2 + +_READER_: '../mask_fpn_reader.yml' +TrainReader: + batch_size: 2 diff --git a/configs/gcnet/mask_rcnn_r50_vd_fpn_gcb_mul_r16_2x.yml b/configs/gcnet/mask_rcnn_r50_vd_fpn_gcb_mul_r16_2x.yml new file mode 100644 index 0000000000000000000000000000000000000000..7eea2b4b0295580eef9f5f4bae7143198dd314be --- /dev/null +++ b/configs/gcnet/mask_rcnn_r50_vd_fpn_gcb_mul_r16_2x.yml @@ -0,0 +1,119 @@ +architecture: MaskRCNN +use_gpu: true +max_iters: 180000 +snapshot_iter: 10000 +log_smooth_window: 20 +save_dir: output +pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar +metric: COCO +weights: output/mask_rcnn_r50_vd_fpn_gcb_mul_r16_2x/model_final/ +num_classes: 81 + +MaskRCNN: + backbone: ResNet + fpn: FPN + rpn_head: FPNRPNHead + roi_extractor: FPNRoIAlign + bbox_head: BBoxHead + bbox_assigner: BBoxAssigner + +ResNet: + depth: 50 + feature_maps: [2, 3, 4, 5] + freeze_at: 2 + norm_type: bn + variant: d + gcb_stages: [3, 4, 5] + gcb_params: + ratio: 0.0625 + pooling_type: att + fusion_types: [channel_mul] + +FPN: + max_level: 6 + min_level: 2 + num_chan: 256 + spatial_scale: [0.03125, 0.0625, 0.125, 0.25] + +FPNRPNHead: + anchor_generator: + aspect_ratios: [0.5, 1.0, 2.0] + variance: [1.0, 1.0, 1.0, 1.0] + anchor_start_size: 32 + max_level: 6 + min_level: 2 + num_chan: 256 + rpn_target_assign: + rpn_batch_size_per_im: 256 + rpn_fg_fraction: 0.5 + rpn_negative_overlap: 0.3 + rpn_positive_overlap: 0.7 + rpn_straddle_thresh: 0.0 + train_proposal: + min_size: 0.0 + nms_thresh: 0.7 + pre_nms_top_n: 2000 + post_nms_top_n: 2000 + test_proposal: + min_size: 0.0 + nms_thresh: 0.7 + pre_nms_top_n: 1000 + post_nms_top_n: 1000 + +FPNRoIAlign: + canconical_level: 4 + canonical_size: 224 + max_level: 5 + min_level: 2 + box_resolution: 7 + sampling_ratio: 2 + mask_resolution: 14 + +MaskHead: + dilation: 1 + conv_dim: 256 + num_convs: 4 + resolution: 28 + +BBoxAssigner: + batch_size_per_im: 512 + bbox_reg_weights: [0.1, 0.1, 0.2, 0.2] + bg_thresh_hi: 0.5 + bg_thresh_lo: 0.0 + fg_fraction: 0.25 + fg_thresh: 0.5 + +MaskAssigner: + resolution: 28 + +BBoxHead: + head: TwoFCHead + nms: + keep_top_k: 100 + nms_threshold: 0.5 + score_threshold: 0.05 + +TwoFCHead: + mlp_dim: 1024 + +LearningRate: + base_lr: 0.02 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [120000, 160000] + - !LinearWarmup + start_factor: 0.1 + steps: 1000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0001 + type: L2 + +_READER_: '../mask_fpn_reader.yml' +TrainReader: + batch_size: 2 diff --git a/docs/MODEL_ZOO.md b/docs/MODEL_ZOO.md index 971801f1323a99efce0e56048465c7fdc0048e42..a50d38d936dc9c786810105b22f8d5e573c2c263 100644 --- a/docs/MODEL_ZOO.md +++ b/docs/MODEL_ZOO.md @@ -102,6 +102,9 @@ The backbone models pretrained on ImageNet are available. All backbone models ar ### IOU loss * GIOU loss and DIOU loss are included now. See more details in [IOU loss model zoo](../configs/iou_loss/README.md). +### GCNet +* See more details in [GCNet model zoo](../configs/gcnet/README.md). + ### Group Normalization | Backbone | Type | Image/gpu | Lr schd | Box AP | Mask AP | Download | diff --git a/docs/MODEL_ZOO_cn.md b/docs/MODEL_ZOO_cn.md index 6c98fa0e825efbbc6c6f4deaa92f279c4e495ee1..8aee6ddf56d1ed2d2124ff3cf628b57d366b3198 100644 --- a/docs/MODEL_ZOO_cn.md +++ b/docs/MODEL_ZOO_cn.md @@ -99,6 +99,9 @@ Paddle提供基于ImageNet的骨架网络预训练模型。所有预训练模型 ### IOU loss * 目前模型库中包括GIOU loss和DIOU loss,详情加[IOU loss模型库](../configs/iou_loss/README.md). +### GCNet +* 详情见[GCNet模型库](../configs/gcnet/README.md). + ### Group Normalization | 骨架网络 | 网络类型 | 每张GPU图片个数 | 学习率策略 | Box AP | Mask AP | 下载 | diff --git a/ppdet/modeling/backbones/gc_block.py b/ppdet/modeling/backbones/gc_block.py new file mode 100755 index 0000000000000000000000000000000000000000..fbd37422345b5676e6ac73b9a537c28a6c5f463e --- /dev/null +++ b/ppdet/modeling/backbones/gc_block.py @@ -0,0 +1,124 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import paddle +import paddle.fluid as fluid +from paddle.fluid import ParamAttr +from paddle.fluid.initializer import ConstantInitializer + + +def spatial_pool(x, pooling_type, name): + _, channel, height, width = x.shape + if pooling_type == 'att': + input_x = x + # [N, 1, C, H * W] + input_x = fluid.layers.reshape(input_x, shape=(0, 1, channel, -1)) + context_mask = fluid.layers.conv2d( + input=x, + num_filters=1, + filter_size=1, + stride=1, + padding=0, + param_attr=ParamAttr(name=name + "_weights"), + bias_attr=ParamAttr(name=name + "_bias")) + # [N, 1, H * W] + context_mask = fluid.layers.reshape(context_mask, shape=(0, 0, -1)) + # [N, 1, H * W] + context_mask = fluid.layers.softmax(context_mask, axis=2) + # [N, 1, H * W, 1] + context_mask = fluid.layers.reshape(context_mask, shape=(0, 0, -1, 1)) + # [N, 1, C, 1] + context = fluid.layers.matmul(input_x, context_mask) + # [N, C, 1, 1] + context = fluid.layers.reshape(context, shape=(0, channel, 1, 1)) + else: + # [N, C, 1, 1] + context = fluid.layers.pool2d( + input=x, pool_type='avg', global_pooling=True) + return context + + +def channel_conv(input, inner_ch, out_ch, name): + conv = fluid.layers.conv2d( + input=input, + num_filters=inner_ch, + filter_size=1, + stride=1, + padding=0, + param_attr=ParamAttr(name=name + "_conv1_weights"), + bias_attr=ParamAttr(name=name + "_conv1_bias"), + name=name + "_conv1", ) + conv = fluid.layers.layer_norm( + conv, + begin_norm_axis=1, + param_attr=ParamAttr(name=name + "_ln_weights"), + bias_attr=ParamAttr(name=name + "_ln_bias"), + act="relu", + name=name + "_ln") + + conv = fluid.layers.conv2d( + input=conv, + num_filters=out_ch, + filter_size=1, + stride=1, + padding=0, + param_attr=ParamAttr( + name=name + "_conv2_weights", + initializer=ConstantInitializer(value=0.0), ), + bias_attr=ParamAttr( + name=name + "_conv2_bias", + initializer=ConstantInitializer(value=0.0), ), + name=name + "_conv2") + return conv + + +def add_gc_block(x, + ratio=1.0 / 16, + pooling_type='att', + fusion_types=['channel_add'], + name=None): + ''' + GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond, see https://arxiv.org/abs/1904.11492 + Args: + ratio (float): channel reduction ratio + pooling_type (str): pooling type, support att and avg + fusion_types (list): fusion types, support channel_add and channel_mul + name (str): prefix name of gc block + ''' + assert pooling_type in ['avg', 'att'] + assert isinstance(fusion_types, (list, tuple)) + valid_fusion_types = ['channel_add', 'channel_mul'] + assert all([f in valid_fusion_types for f in fusion_types]) + assert len(fusion_types) > 0, 'at least one fusion should be used' + + inner_ch = int(ratio * x.shape[1]) + out_ch = x.shape[1] + context = spatial_pool(x, pooling_type, name + "_spatial_pool") + out = x + if 'channel_mul' in fusion_types: + inner_out = channel_conv(context, inner_ch, out_ch, name + "_mul") + channel_mul_term = fluid.layers.sigmoid(inner_out) + out = out * channel_mul_term + + if 'channel_add' in fusion_types: + channel_add_term = channel_conv(context, inner_ch, out_ch, + name + "_add") + out = out + channel_add_term + + return out diff --git a/ppdet/modeling/backbones/resnet.py b/ppdet/modeling/backbones/resnet.py index c40286348a401366674a623d6fb844d32e55c7a0..9c88605838ec56cb2b68218cce224f04e5059aab 100644 --- a/ppdet/modeling/backbones/resnet.py +++ b/ppdet/modeling/backbones/resnet.py @@ -28,6 +28,7 @@ from ppdet.core.workspace import register, serializable from numbers import Integral from .nonlocal_helper import add_space_nonlocal +from .gc_block import add_gc_block from .name_adapter import NameAdapter __all__ = ['ResNet', 'ResNetC5'] @@ -48,6 +49,10 @@ class ResNet(object): feature_maps (list): index of stages whose feature maps are returned dcn_v2_stages (list): index of stages who select deformable conv v2 nonlocal_stages (list): index of stages who select nonlocal networks + gcb_stages (list): index of stages who select gc blocks + gcb_params (dict): gc blocks config, includes ratio(default as 1.0/16), + pooling_type(default as "att") and + fusion_types(default as ['channel_add']) """ __shared__ = ['norm_type', 'freeze_norm', 'weight_prefix_name'] @@ -61,7 +66,9 @@ class ResNet(object): feature_maps=[2, 3, 4, 5], dcn_v2_stages=[], weight_prefix_name='', - nonlocal_stages=[]): + nonlocal_stages=[], + gcb_stages=[], + gcb_params=dict()): super(ResNet, self).__init__() if isinstance(feature_maps, Integral): @@ -97,15 +104,18 @@ class ResNet(object): self._c1_out_chan_num = 64 self.na = NameAdapter(self) self.prefix_name = weight_prefix_name - + self.nonlocal_stages = nonlocal_stages self.nonlocal_mod_cfg = { - 50 : 2, - 101 : 5, - 152 : 8, - 200 : 12, + 50: 2, + 101: 5, + 152: 8, + 200: 12, } + self.gcb_stages = gcb_stages + self.gcb_params = gcb_params + def _conv_offset(self, input, filter_size, @@ -257,7 +267,9 @@ class ResNet(object): stride, is_first, name, - dcn_v2=False): + dcn_v2=False, + gcb=False, + gcb_name=None): if self.variant == 'a': stride1, stride2 = stride, 1 else: @@ -309,6 +321,8 @@ class ResNet(object): if callable(getattr(self, '_squeeze_excitation', None)): residual = self._squeeze_excitation( input=residual, num_channels=num_filters, name='fc' + name) + if gcb: + residual = add_gc_block(residual, name=gcb_name, **self.gcb_params) return fluid.layers.elementwise_add( x=short, y=residual, act='relu', name=name + ".add.output.5") @@ -318,8 +332,11 @@ class ResNet(object): stride, is_first, name, - dcn_v2=False): + dcn_v2=False, + gcb=False, + gcb_name=None): assert dcn_v2 is False, "Not implemented yet." + assert gcb is False, "Not implemented yet." conv0 = self._conv_norm( input=input, num_filters=num_filters, @@ -354,11 +371,12 @@ class ResNet(object): ch_out = self.stage_filters[stage_num - 2] is_first = False if stage_num != 2 else True dcn_v2 = True if stage_num in self.dcn_v2_stages else False - + nonlocal_mod = 1000 if stage_num in self.nonlocal_stages: - nonlocal_mod = self.nonlocal_mod_cfg[self.depth] if stage_num==4 else 2 - + nonlocal_mod = self.nonlocal_mod_cfg[ + self.depth] if stage_num == 4 else 2 + # Make the layer name and parameter name consistent # with ImageNet pre-trained model conv = input @@ -366,21 +384,26 @@ class ResNet(object): conv_name = self.na.fix_layer_warp_name(stage_num, count, i) if self.depth < 50: is_first = True if i == 0 and stage_num == 2 else False + + gcb = stage_num in self.gcb_stages + gcb_name = "gcb_res{}_b{}".format(stage_num, i) conv = block_func( input=conv, num_filters=ch_out, stride=2 if i == 0 and stage_num != 2 else 1, is_first=is_first, name=conv_name, - dcn_v2=dcn_v2) - + dcn_v2=dcn_v2, + gcb=gcb, + gcb_name=gcb_name) + # add non local model dim_in = conv.shape[1] - nonlocal_name = "nonlocal_conv{}".format( stage_num ) + nonlocal_name = "nonlocal_conv{}".format(stage_num) if i % nonlocal_mod == nonlocal_mod - 1: - conv = add_space_nonlocal( - conv, dim_in, dim_in, - nonlocal_name + '_{}'.format(i), int(dim_in / 2) ) + conv = add_space_nonlocal(conv, dim_in, dim_in, + nonlocal_name + '_{}'.format(i), + int(dim_in / 2)) return conv def c1_stage(self, input):