未验证 提交 e333d629 编写于 作者: W wangguanzhong 提交者: GitHub

Add faster rcnn fpn hrnet 1x & 2x (#2055)

* add hrnet, test=develop

* add comment, test=dygraph

* move config file to hrnet, test=dygraph

* add 2x model

* add hrnet README

* update latest config structure, test=dygraph
上级 6a46b3c4
# High-resolution networks (HRNets) for object detection
## Introduction
- Deep High-Resolution Representation Learning for Human Pose Estimation: [https://arxiv.org/abs/1902.09212](https://arxiv.org/abs/1902.09212)
```
@inproceedings{SunXLW19,
title={Deep High-Resolution Representation Learning for Human Pose Estimation},
author={Ke Sun and Bin Xiao and Dong Liu and Jingdong Wang},
booktitle={CVPR},
year={2019}
}
```
- High-Resolution Representations for Labeling Pixels and Regions: [https://arxiv.org/abs/1904.04514](https://arxiv.org/abs/1904.04514)
```
@article{SunZJCXLMWLW19,
title={High-Resolution Representations for Labeling Pixels and Regions},
author={Ke Sun and Yang Zhao and Borui Jiang and Tianheng Cheng and Bin Xiao
and Dong Liu and Yadong Mu and Xinggang Wang and Wenyu Liu and Jingdong Wang},
journal = {CoRR},
volume = {abs/1904.04514},
year={2019}
}
```
## Model Zoo
| Backbone | Type | deformable Conv | Image/gpu | Lr schd | Inf time (fps) | Box AP | Mask AP | Download | Configs |
| :---------------------- | :------------- | :---: | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: | :-----: |
| HRNetV2p_W18 | Faster | False | 2 | 1x | - | 35.7 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/faster_rcnn_hrnetv2p_w18_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_1x_coco.yml) |
| HRNetV2p_W18 | Faster | False | 2 | 2x | - | 37.7 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/faster_rcnn_hrnetv2p_w18_2x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_2x_coco.yml) |
architecture: FasterRCNN
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W18_C_pretrained.tar
weights: output/faster_rcnn_hrnetv2p_w18_1x_coco/model_final
load_static_weights: True
# Model Achitecture
FasterRCNN:
# model anchor info flow
anchor: Anchor
proposal: Proposal
# model feat info flow
backbone: HRNet
neck: HRFPN
rpn_head: RPNHead
bbox_head: BBoxHead
# post process
bbox_post_process: BBoxPostProcess
HRNet:
width: 18
freeze_at: 0
return_idx: [0, 1, 2, 3]
HRFPN:
out_channel: 256
share_conv: false
RPNHead:
rpn_feat:
name: RPNFeat
feat_in: 256
feat_out: 256
anchor_per_position: 3
rpn_channel: 256
Anchor:
anchor_generator:
name: AnchorGeneratorRPN
aspect_ratios: [0.5, 1.0, 2.0]
anchor_start_size: 32
stride: [4., 4.]
anchor_target_generator:
name: AnchorTargetGeneratorRPN
batch_size_per_im: 256
fg_fraction: 0.5
negative_overlap: 0.3
positive_overlap: 0.7
straddle_thresh: 0.0
Proposal:
proposal_generator:
name: ProposalGenerator
min_size: 0.0
nms_thresh: 0.7
train_pre_nms_top_n: 2000
train_post_nms_top_n: 2000
infer_pre_nms_top_n: 1000
infer_post_nms_top_n: 1000
proposal_target_generator:
name: ProposalTargetGenerator
batch_size_per_im: 512
bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
bg_thresh_hi: [0.5,]
bg_thresh_lo: [0.0,]
fg_thresh: [0.5,]
fg_fraction: 0.25
BBoxHead:
bbox_feat:
name: BBoxFeat
roi_extractor:
name: RoIAlign
resolution: 7
sampling_ratio: 2
head_feat:
name: TwoFCHead
in_dim: 256
mlp_dim: 1024
in_feat: 1024
BBoxPostProcess:
decode:
name: RCNNBox
num_classes: 81
batch_size: 1
nms:
name: MultiClassNMS
keep_top_k: 100
score_threshold: 0.05
nms_threshold: 0.5
_BASE_: [
'../datasets/coco_detection.yml',
'./_base_/faster_rcnn_hrnetv2p_w18.yml',
'../faster_rcnn/_base_/optimizer_1x.yml',
'../faster_rcnn/_base_/faster_fpn_reader.yml',
'../runtime.yml',
]
LearningRate:
base_lr: 0.02
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [8, 11]
- !LinearWarmup
start_factor: 0.1
steps: 1000
TrainReader:
batch_size: 2
_BASE_: [
'../datasets/coco_detection.yml',
'./_base_/faster_rcnn_hrnetv2p_w18.yml',
'../faster_rcnn/_base_/optimizer_1x.yml',
'../faster_rcnn/_base_/faster_fpn_reader.yml',
'../runtime.yml',
]
weights: output/faster_rcnn_hrnetv2p_w18_2x_coco/model_final
epoch: 24
LearningRate:
base_lr: 0.02
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [16, 22]
- !LinearWarmup
start_factor: 0.1
steps: 1000
TrainReader:
batch_size: 2
......@@ -156,6 +156,7 @@ class Trainer(object):
def train(self):
assert self.mode == 'train', "Model not in 'train' mode"
self.model.train()
# if no given weights loaded, load backbone pretrain weights as default
if not self._weights_loaded:
......@@ -184,7 +185,6 @@ class Trainer(object):
self._compose_callback.on_step_begin(self.status)
# model forward
self.model.train()
outputs = self.model(data)
loss = outputs['loss']
......
......@@ -3,9 +3,11 @@ from . import resnet
from . import darknet
from . import mobilenet_v1
from . import mobilenet_v3
from . import hrnet
from .vgg import *
from .resnet import *
from .darknet import *
from .mobilenet_v1 import *
from .mobilenet_v3 import *
from .hrnet import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
......
此差异已折叠。
......@@ -14,6 +14,8 @@
from . import fpn
from . import yolo_fpn
from . import hrfpn
from .fpn import *
from .yolo_fpn import *
from .hrfpn import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn.functional as F
from paddle import ParamAttr
import paddle.nn as nn
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, serializable
__all__ = ['HRFPN']
@register
class HRFPN(nn.Layer):
"""
Args:
in_channel (int): number of input feature channels from backbone
out_channel (int): number of output feature channels
share_conv (bool): whether to share conv for different layers' reduction
spatial_scale (list): feature map scaling factor
"""
def __init__(
self,
in_channel=270,
out_channel=256,
share_conv=False,
spatial_scale=[1. / 4, 1. / 8, 1. / 16, 1. / 32, 1. / 64], ):
super(HRFPN, self).__init__()
self.in_channel = in_channel
self.out_channel = out_channel
self.share_conv = share_conv
self.spatial_scale = spatial_scale
self.reduction = nn.Conv2D(
in_channels=in_channel,
out_channels=out_channel,
kernel_size=1,
weight_attr=ParamAttr(name='hrfpn_reduction_weights'),
bias_attr=False)
self.num_out = len(self.spatial_scale)
if share_conv:
self.fpn_conv = nn.Conv2D(
in_channels=out_channel,
out_channels=out_channel,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(name='fpn_conv_weights'),
bias_attr=False)
else:
self.fpn_conv = []
for i in range(self.num_out):
conv_name = "fpn_conv_" + str(i)
conv = self.add_sublayer(
conv_name,
nn.Conv2D(
in_channels=out_channel,
out_channels=out_channel,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(name=conv_name + "_weights"),
bias_attr=False))
self.fpn_conv.append(conv)
def forward(self, body_feats):
num_backbone_stages = len(body_feats)
outs = []
outs.append(body_feats[0])
# resize
for i in range(1, num_backbone_stages):
resized = F.interpolate(
body_feats[i], scale_factor=2**i, mode='bilinear')
outs.append(resized)
# concat
out = paddle.concat(outs, axis=1)
assert out.shape[
1] == self.in_channel, 'in_channel should be {}, be received {}'.format(
out.shape[1], self.in_channel)
# reduction
out = self.reduction(out)
# conv
outs = [out]
for i in range(1, self.num_out):
outs.append(F.avg_pool2d(out, kernel_size=2**i, stride=2**i))
outputs = []
for i in range(self.num_out):
conv_func = self.fpn_conv if self.share_conv else self.fpn_conv[i]
conv = conv_func(outs[i])
outputs.append(conv)
fpn_feat = [outputs[k] for k in range(self.num_out)]
return fpn_feat, self.spatial_scale
......@@ -95,7 +95,7 @@ def load_weight(model, weight, optimizer=None):
last_epoch = 0
if optimizer is not None and os.path.exists(path + '.pdopt'):
optim_state_dict = paddle.load(path + '.pdopt')
# to slove resume bug, will it be fixed in paddle 2.0
# to solve resume bug, will it be fixed in paddle 2.0
for key in optimizer.state_dict().keys():
if not key in optim_state_dict.keys():
optim_state_dict[key] = optimizer.state_dict()[key]
......@@ -132,6 +132,9 @@ def load_pretrain_weight(model,
weight_name, pre_state_dict[weight_name].shape))
param_state_dict[key] = pre_state_dict[weight_name]
else:
if 'backbone' in key:
logger.info('Lack weight: {}, structure name: {}'.format(
weight_name, key))
param_state_dict[key] = model_dict[key]
model.set_dict(param_state_dict)
return
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册