Unverified commit e333d629, authored by wangguanzhong, committed by GitHub

Add faster rcnn fpn hrnet 1x & 2x (#2055)

* add hrnet, test=develop

* add comment, test=dygraph

* move config file to hrnet, test=dygraph

* add 2x model

* add hrnet README

* update latest config structure, test=dygraph
Parent 6a46b3c4
# High-resolution networks (HRNets) for object detection
## Introduction
- Deep High-Resolution Representation Learning for Human Pose Estimation: [https://arxiv.org/abs/1902.09212](https://arxiv.org/abs/1902.09212)
```
@inproceedings{SunXLW19,
title={Deep High-Resolution Representation Learning for Human Pose Estimation},
author={Ke Sun and Bin Xiao and Dong Liu and Jingdong Wang},
booktitle={CVPR},
year={2019}
}
```
- High-Resolution Representations for Labeling Pixels and Regions: [https://arxiv.org/abs/1904.04514](https://arxiv.org/abs/1904.04514)
```
@article{SunZJCXLMWLW19,
title={High-Resolution Representations for Labeling Pixels and Regions},
author={Ke Sun and Yang Zhao and Borui Jiang and Tianheng Cheng and Bin Xiao
and Dong Liu and Yadong Mu and Xinggang Wang and Wenyu Liu and Jingdong Wang},
journal = {CoRR},
volume = {abs/1904.04514},
year={2019}
}
```
## Model Zoo
| Backbone | Type | Deformable Conv | Images/GPU | LR schedule | Inf time (fps) | Box AP | Mask AP | Download | Configs |
| :---------------------- | :------------- | :---: | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: | :-----: |
| HRNetV2p_W18 | Faster | False | 2 | 1x | - | 35.7 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/faster_rcnn_hrnetv2p_w18_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_1x_coco.yml) |
| HRNetV2p_W18 | Faster | False | 2 | 2x | - | 37.7 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/faster_rcnn_hrnetv2p_w18_2x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_2x_coco.yml) |
architecture: FasterRCNN
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W18_C_pretrained.tar
weights: output/faster_rcnn_hrnetv2p_w18_1x_coco/model_final
load_static_weights: True
# Model Architecture
FasterRCNN:
  # model anchor info flow
  anchor: Anchor
  proposal: Proposal
  # model feat info flow
  backbone: HRNet
  neck: HRFPN
  rpn_head: RPNHead
  bbox_head: BBoxHead
  # post process
  bbox_post_process: BBoxPostProcess
HRNet:
  width: 18
  freeze_at: 0
  return_idx: [0, 1, 2, 3]

HRFPN:
  out_channel: 256
  share_conv: false

RPNHead:
  rpn_feat:
    name: RPNFeat
    feat_in: 256
    feat_out: 256
  anchor_per_position: 3
  rpn_channel: 256

Anchor:
  anchor_generator:
    name: AnchorGeneratorRPN
    aspect_ratios: [0.5, 1.0, 2.0]
    anchor_start_size: 32
    stride: [4., 4.]
  anchor_target_generator:
    name: AnchorTargetGeneratorRPN
    batch_size_per_im: 256
    fg_fraction: 0.5
    negative_overlap: 0.3
    positive_overlap: 0.7
    straddle_thresh: 0.0

Proposal:
  proposal_generator:
    name: ProposalGenerator
    min_size: 0.0
    nms_thresh: 0.7
    train_pre_nms_top_n: 2000
    train_post_nms_top_n: 2000
    infer_pre_nms_top_n: 1000
    infer_post_nms_top_n: 1000
  proposal_target_generator:
    name: ProposalTargetGenerator
    batch_size_per_im: 512
    bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
    bg_thresh_hi: [0.5,]
    bg_thresh_lo: [0.0,]
    fg_thresh: [0.5,]
    fg_fraction: 0.25

BBoxHead:
  bbox_feat:
    name: BBoxFeat
    roi_extractor:
      name: RoIAlign
      resolution: 7
      sampling_ratio: 2
    head_feat:
      name: TwoFCHead
      in_dim: 256
      mlp_dim: 1024
  in_feat: 1024

BBoxPostProcess:
  decode:
    name: RCNNBox
    num_classes: 81
    batch_size: 1
  nms:
    name: MultiClassNMS
    keep_top_k: 100
    score_threshold: 0.05
    nms_threshold: 0.5
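
The file above only declares components and their parameters; PaddleDetection's workspace registry wires them together at runtime. Below is a minimal sketch of building the detector from such a config, assuming the `load_config`/`create` helpers from `ppdet.core.workspace` that the dygraph `Trainer` itself uses; the config path is illustrative:

```python
# Sketch: instantiate the registered FasterRCNN architecture from YAML.
# load_config/create come from ppdet.core.workspace (the same API the
# dygraph Trainer uses); the path below is illustrative.
from ppdet.core.workspace import load_config, create

cfg = load_config('configs/hrnet/faster_rcnn_hrnetv2p_w18_1x_coco.yml')
model = create(cfg.architecture)  # builds HRNet, HRFPN, RPNHead, BBoxHead, ...
print(type(model).__name__)       # 'FasterRCNN'
```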
_BASE_: [
  '../datasets/coco_detection.yml',
  './_base_/faster_rcnn_hrnetv2p_w18.yml',
  '../faster_rcnn/_base_/optimizer_1x.yml',
  '../faster_rcnn/_base_/faster_fpn_reader.yml',
  '../runtime.yml',
]

LearningRate:
  base_lr: 0.02
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [8, 11]
  - !LinearWarmup
    start_factor: 0.1
    steps: 1000

TrainReader:
  batch_size: 2
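
With `base_lr: 0.02`, the 1x schedule warms up linearly from 0.002 (start_factor 0.1) to 0.02 over the first 1000 iterations, then drops by gamma=0.1 at epochs 8 and 11. A standalone sketch of the resulting value (not the ppdet scheduler itself):

```python
# Standalone sketch of the warmup + piecewise-decay schedule configured above.
def lr_at(step, epoch, base_lr=0.02, start_factor=0.1, warmup_steps=1000,
          milestones=(8, 11), gamma=0.1):
    lr = base_lr * gamma ** sum(epoch >= m for m in milestones)
    if step < warmup_steps:  # linear warmup from start_factor * lr to lr
        alpha = step / warmup_steps
        lr *= start_factor * (1 - alpha) + alpha
    return lr

print(lr_at(step=0, epoch=0))        # 0.002
print(lr_at(step=2000, epoch=0))     # 0.02
print(lr_at(step=100000, epoch=11))  # 0.0002 (after both decay steps)
```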
_BASE_: [
  '../datasets/coco_detection.yml',
  './_base_/faster_rcnn_hrnetv2p_w18.yml',
  '../faster_rcnn/_base_/optimizer_1x.yml',
  '../faster_rcnn/_base_/faster_fpn_reader.yml',
  '../runtime.yml',
]

weights: output/faster_rcnn_hrnetv2p_w18_2x_coco/model_final
epoch: 24

LearningRate:
  base_lr: 0.02
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [16, 22]
  - !LinearWarmup
    start_factor: 0.1
    steps: 1000

TrainReader:
  batch_size: 2
@@ -156,6 +156,7 @@ class Trainer(object):
     def train(self):
         assert self.mode == 'train', "Model not in 'train' mode"
+        self.model.train()
         # if no given weights loaded, load backbone pretrain weights as default
         if not self._weights_loaded:
@@ -184,7 +185,6 @@ class Trainer(object):
             self._compose_callback.on_step_begin(self.status)
             # model forward
-            self.model.train()
             outputs = self.model(data)
             loss = outputs['loss']
......
@@ -3,9 +3,11 @@ from . import resnet
 from . import darknet
 from . import mobilenet_v1
 from . import mobilenet_v3
+from . import hrnet
 from .vgg import *
 from .resnet import *
 from .darknet import *
 from .mobilenet_v1 import *
 from .mobilenet_v3 import *
+from .hrnet import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.regularizer import L2Decay
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Uniform
from numbers import Integral
import math
from ppdet.core.workspace import register, serializable
__all__ = ['HRNet']
class ConvNormLayer(nn.Layer):
def __init__(self,
ch_in,
ch_out,
filter_size,
stride=1,
norm_type='bn',
norm_groups=32,
use_dcn=False,
norm_decay=0.,
freeze_norm=False,
act=None,
name=None):
super(ConvNormLayer, self).__init__()
assert norm_type in ['bn', 'sync_bn', 'gn']
self.act = act
self.conv = nn.Conv2D(
in_channels=ch_in,
out_channels=ch_out,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=1,
weight_attr=ParamAttr(
name=name + "_weights", initializer=Normal(
mean=0., std=0.01)),
bias_attr=False)
norm_lr = 0. if freeze_norm else 1.
norm_name = name + '_bn'
param_attr = ParamAttr(
name=norm_name + "_scale",
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay))
bias_attr = ParamAttr(
name=norm_name + "_offset",
learning_rate=norm_lr,
regularizer=L2Decay(norm_decay))
global_stats = True if freeze_norm else False
if norm_type in ['bn', 'sync_bn']:
self.norm = nn.BatchNorm(
ch_out,
param_attr=param_attr,
bias_attr=bias_attr,
use_global_stats=global_stats,
moving_mean_name=norm_name + '_mean',
moving_variance_name=norm_name + '_variance')
elif norm_type == 'gn':
self.norm = nn.GroupNorm(
num_groups=norm_groups,
num_channels=ch_out,
weight_attr=param_attr,
bias_attr=bias_attr)
norm_params = self.norm.parameters()
if freeze_norm:
for param in norm_params:
param.stop_gradient = True
def forward(self, inputs):
out = self.conv(inputs)
out = self.norm(out)
if self.act == 'relu':
out = F.relu(out)
return out
class Layer1(nn.Layer):
def __init__(self, num_channels, has_se=False, freeze_norm=True, name=None):
super(Layer1, self).__init__()
self.bottleneck_block_list = []
for i in range(4):
bottleneck_block = self.add_sublayer(
"block_{}_{}".format(name, i + 1),
BottleneckBlock(
num_channels=num_channels if i == 0 else 256,
num_filters=64,
has_se=has_se,
stride=1,
downsample=True if i == 0 else False,
freeze_norm=freeze_norm,
name=name + '_' + str(i + 1)))
self.bottleneck_block_list.append(bottleneck_block)
def forward(self, input):
conv = input
for block_func in self.bottleneck_block_list:
conv = block_func(conv)
return conv
class TransitionLayer(nn.Layer):
def __init__(self, in_channels, out_channels, freeze_norm=True, name=None):
super(TransitionLayer, self).__init__()
num_in = len(in_channels)
num_out = len(out_channels)
out = []
self.conv_bn_func_list = []
for i in range(num_out):
residual = None
if i < num_in:
if in_channels[i] != out_channels[i]:
residual = self.add_sublayer(
"transition_{}_layer_{}".format(name, i + 1),
ConvNormLayer(
ch_in=in_channels[i],
ch_out=out_channels[i],
filter_size=3,
freeze_norm=freeze_norm,
act='relu',
name=name + '_layer_' + str(i + 1)))
else:
residual = self.add_sublayer(
"transition_{}_layer_{}".format(name, i + 1),
ConvNormLayer(
ch_in=in_channels[-1],
ch_out=out_channels[i],
filter_size=3,
stride=2,
freeze_norm=freeze_norm,
act='relu',
name=name + '_layer_' + str(i + 1)))
self.conv_bn_func_list.append(residual)
def forward(self, input):
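        # Reuse or adapt existing branches (3x3 conv when channel counts
        # differ, identity otherwise); any extra output branch is created by
        # downsampling the last input with a stride-2 3x3 conv.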
outs = []
for idx, conv_bn_func in enumerate(self.conv_bn_func_list):
if conv_bn_func is None:
outs.append(input[idx])
else:
if idx < len(input):
outs.append(conv_bn_func(input[idx]))
else:
outs.append(conv_bn_func(input[-1]))
return outs
class Branches(nn.Layer):
def __init__(self,
block_num,
in_channels,
out_channels,
has_se=False,
freeze_norm=True,
name=None):
super(Branches, self).__init__()
self.basic_block_list = []
for i in range(len(out_channels)):
self.basic_block_list.append([])
for j in range(block_num):
in_ch = in_channels[i] if j == 0 else out_channels[i]
basic_block_func = self.add_sublayer(
"bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1),
BasicBlock(
num_channels=in_ch,
num_filters=out_channels[i],
has_se=has_se,
freeze_norm=freeze_norm,
name=name + '_branch_layer_' + str(i + 1) + '_' +
str(j + 1)))
self.basic_block_list[i].append(basic_block_func)
def forward(self, inputs):
outs = []
for idx, input in enumerate(inputs):
conv = input
basic_block_list = self.basic_block_list[idx]
for basic_block_func in basic_block_list:
conv = basic_block_func(conv)
outs.append(conv)
return outs
class BottleneckBlock(nn.Layer):
def __init__(self,
num_channels,
num_filters,
has_se,
stride=1,
downsample=False,
freeze_norm=True,
name=None):
super(BottleneckBlock, self).__init__()
self.has_se = has_se
self.downsample = downsample
self.conv1 = ConvNormLayer(
ch_in=num_channels,
ch_out=num_filters,
filter_size=1,
freeze_norm=freeze_norm,
act="relu",
name=name + "_conv1")
self.conv2 = ConvNormLayer(
ch_in=num_filters,
ch_out=num_filters,
filter_size=3,
stride=stride,
freeze_norm=freeze_norm,
act="relu",
name=name + "_conv2")
self.conv3 = ConvNormLayer(
ch_in=num_filters,
ch_out=num_filters * 4,
filter_size=1,
freeze_norm=freeze_norm,
act=None,
name=name + "_conv3")
if self.downsample:
self.conv_down = ConvNormLayer(
ch_in=num_channels,
ch_out=num_filters * 4,
filter_size=1,
freeze_norm=freeze_norm,
act=None,
name=name + "_downsample")
if self.has_se:
self.se = SELayer(
num_channels=num_filters * 4,
num_filters=num_filters * 4,
reduction_ratio=16,
name='fc' + name)
def forward(self, input):
residual = input
conv1 = self.conv1(input)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
if self.downsample:
residual = self.conv_down(input)
if self.has_se:
conv3 = self.se(conv3)
y = paddle.add(x=residual, y=conv3)
y = F.relu(y)
return y
class BasicBlock(nn.Layer):
def __init__(self,
num_channels,
num_filters,
stride=1,
has_se=False,
downsample=False,
freeze_norm=True,
name=None):
super(BasicBlock, self).__init__()
self.has_se = has_se
self.downsample = downsample
self.conv1 = ConvNormLayer(
ch_in=num_channels,
ch_out=num_filters,
filter_size=3,
freeze_norm=freeze_norm,
stride=stride,
act="relu",
name=name + "_conv1")
self.conv2 = ConvNormLayer(
ch_in=num_filters,
ch_out=num_filters,
filter_size=3,
freeze_norm=freeze_norm,
stride=1,
act=None,
name=name + "_conv2")
if self.downsample:
self.conv_down = ConvNormLayer(
ch_in=num_channels,
ch_out=num_filters * 4,
filter_size=1,
freeze_norm=freeze_norm,
act=None,
name=name + "_downsample")
if self.has_se:
self.se = SELayer(
num_channels=num_filters,
num_filters=num_filters,
reduction_ratio=16,
name='fc' + name)
def forward(self, input):
residual = input
conv1 = self.conv1(input)
conv2 = self.conv2(conv1)
if self.downsample:
residual = self.conv_down(input)
if self.has_se:
conv2 = self.se(conv2)
y = paddle.add(x=residual, y=conv2)
y = F.relu(y)
return y
class SELayer(nn.Layer):
def __init__(self, num_channels, num_filters, reduction_ratio, name=None):
super(SELayer, self).__init__()
        self.pool2d_gap = nn.AdaptiveAvgPool2D(1)
self._num_channels = num_channels
med_ch = int(num_channels / reduction_ratio)
stdv = 1.0 / math.sqrt(num_channels * 1.0)
        self.squeeze = nn.Linear(
num_channels,
med_ch,
weight_attr=ParamAttr(
initializer=Uniform(-stdv, stdv), name=name + "_sqz_weights"),
bias_attr=ParamAttr(name=name + '_sqz_offset'))
stdv = 1.0 / math.sqrt(med_ch * 1.0)
        self.excitation = nn.Linear(
med_ch,
num_filters,
weight_attr=ParamAttr(
initializer=Uniform(-stdv, stdv), name=name + "_exc_weights"),
bias_attr=ParamAttr(name=name + '_exc_offset'))
def forward(self, input):
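        # Squeeze-and-excitation: global average pool to a vector, FC
        # bottleneck (reduction_ratio) with ReLU, FC back up with sigmoid,
        # then scale the input channel-wise.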
pool = self.pool2d_gap(input)
pool = paddle.squeeze(pool, axis=[2, 3])
squeeze = self.squeeze(pool)
squeeze = F.relu(squeeze)
excitation = self.excitation(squeeze)
excitation = F.sigmoid(excitation)
excitation = paddle.unsqueeze(excitation, axis=[2, 3])
out = input * excitation
return out
class Stage(nn.Layer):
def __init__(self,
num_channels,
num_modules,
num_filters,
has_se=False,
freeze_norm=True,
multi_scale_output=True,
name=None):
super(Stage, self).__init__()
self._num_modules = num_modules
self.stage_func_list = []
for i in range(num_modules):
if i == num_modules - 1 and not multi_scale_output:
stage_func = self.add_sublayer(
"stage_{}_{}".format(name, i + 1),
HighResolutionModule(
num_channels=num_channels,
num_filters=num_filters,
has_se=has_se,
freeze_norm=freeze_norm,
multi_scale_output=False,
name=name + '_' + str(i + 1)))
else:
stage_func = self.add_sublayer(
"stage_{}_{}".format(name, i + 1),
HighResolutionModule(
num_channels=num_channels,
num_filters=num_filters,
has_se=has_se,
freeze_norm=freeze_norm,
name=name + '_' + str(i + 1)))
self.stage_func_list.append(stage_func)
def forward(self, input):
out = input
for idx in range(self._num_modules):
out = self.stage_func_list[idx](out)
return out
class HighResolutionModule(nn.Layer):
def __init__(self,
num_channels,
num_filters,
has_se=False,
multi_scale_output=True,
freeze_norm=True,
name=None):
super(HighResolutionModule, self).__init__()
self.branches_func = Branches(
block_num=4,
in_channels=num_channels,
out_channels=num_filters,
has_se=has_se,
freeze_norm=freeze_norm,
name=name)
self.fuse_func = FuseLayers(
in_channels=num_filters,
out_channels=num_filters,
multi_scale_output=multi_scale_output,
freeze_norm=freeze_norm,
name=name)
def forward(self, input):
out = self.branches_func(input)
out = self.fuse_func(out)
return out
class FuseLayers(nn.Layer):
def __init__(self,
in_channels,
out_channels,
multi_scale_output=True,
freeze_norm=True,
name=None):
super(FuseLayers, self).__init__()
self._actual_ch = len(in_channels) if multi_scale_output else 1
self._in_channels = in_channels
self.residual_func_list = []
for i in range(self._actual_ch):
for j in range(len(in_channels)):
residual_func = None
if j > i:
residual_func = self.add_sublayer(
"residual_{}_layer_{}_{}".format(name, i + 1, j + 1),
ConvNormLayer(
ch_in=in_channels[j],
ch_out=out_channels[i],
filter_size=1,
stride=1,
act=None,
freeze_norm=freeze_norm,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1)))
self.residual_func_list.append(residual_func)
elif j < i:
pre_num_filters = in_channels[j]
for k in range(i - j):
if k == i - j - 1:
residual_func = self.add_sublayer(
"residual_{}_layer_{}_{}_{}".format(
name, i + 1, j + 1, k + 1),
ConvNormLayer(
ch_in=pre_num_filters,
ch_out=out_channels[i],
filter_size=3,
stride=2,
freeze_norm=freeze_norm,
act=None,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1)))
pre_num_filters = out_channels[i]
else:
residual_func = self.add_sublayer(
"residual_{}_layer_{}_{}_{}".format(
name, i + 1, j + 1, k + 1),
ConvNormLayer(
ch_in=pre_num_filters,
ch_out=out_channels[j],
filter_size=3,
stride=2,
freeze_norm=freeze_norm,
act="relu",
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1)))
pre_num_filters = out_channels[j]
self.residual_func_list.append(residual_func)
def forward(self, input):
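        # Fuse all resolutions into output branch i: coarser inputs (j > i)
        # pass a 1x1 conv and are upsampled by 2**(j - i); finer inputs
        # (j < i) are reduced by (i - j) stride-2 3x3 convs; sum, then ReLU.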
outs = []
residual_func_idx = 0
for i in range(self._actual_ch):
residual = input[i]
for j in range(len(self._in_channels)):
if j > i:
y = self.residual_func_list[residual_func_idx](input[j])
residual_func_idx += 1
y = F.interpolate(y, scale_factor=2**(j - i))
residual = paddle.add(x=residual, y=y)
elif j < i:
y = input[j]
for k in range(i - j):
y = self.residual_func_list[residual_func_idx](y)
residual_func_idx += 1
residual = paddle.add(x=residual, y=y)
residual = F.relu(residual)
outs.append(residual)
return outs
@register
class HRNet(nn.Layer):
"""
HRNet, see https://arxiv.org/abs/1908.07919
Args:
width (int): the width of HRNet
has_se (bool): whether to add SE block for each stage
freeze_at (int): the stage to freeze
freeze_norm (bool): whether to freeze norm in HRNet
return_idx (List): the stage to return
"""
def __init__(self,
width=18,
has_se=False,
freeze_at=0,
freeze_norm=True,
norm_decay=0.,
return_idx=[0, 1, 2, 3]):
super(HRNet, self).__init__()
self.width = width
self.has_se = has_se
if isinstance(return_idx, Integral):
return_idx = [return_idx]
assert len(return_idx) > 0, "need one or more return index"
self.freeze_at = freeze_at
self.return_idx = return_idx
self.channels = {
18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]],
30: [[30, 60], [30, 60, 120], [30, 60, 120, 240]],
32: [[32, 64], [32, 64, 128], [32, 64, 128, 256]],
40: [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
44: [[44, 88], [44, 88, 176], [44, 88, 176, 352]],
48: [[48, 96], [48, 96, 192], [48, 96, 192, 384]],
60: [[60, 120], [60, 120, 240], [60, 120, 240, 480]],
64: [[64, 128], [64, 128, 256], [64, 128, 256, 512]]
}
channels_2, channels_3, channels_4 = self.channels[width]
num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3
self.conv_layer1_1 = ConvNormLayer(
ch_in=3,
ch_out=64,
filter_size=3,
stride=2,
freeze_norm=freeze_norm,
act='relu',
name="layer1_1")
self.conv_layer1_2 = ConvNormLayer(
ch_in=64,
ch_out=64,
filter_size=3,
stride=2,
freeze_norm=freeze_norm,
act='relu',
name="layer1_2")
self.la1 = Layer1(
num_channels=64,
has_se=has_se,
freeze_norm=freeze_norm,
name="layer2")
self.tr1 = TransitionLayer(
in_channels=[256],
out_channels=channels_2,
freeze_norm=freeze_norm,
name="tr1")
self.st2 = Stage(
num_channels=channels_2,
num_modules=num_modules_2,
num_filters=channels_2,
has_se=self.has_se,
freeze_norm=freeze_norm,
name="st2")
self.tr2 = TransitionLayer(
in_channels=channels_2,
out_channels=channels_3,
freeze_norm=freeze_norm,
name="tr2")
self.st3 = Stage(
num_channels=channels_3,
num_modules=num_modules_3,
num_filters=channels_3,
has_se=self.has_se,
freeze_norm=freeze_norm,
name="st3")
self.tr3 = TransitionLayer(
in_channels=channels_3,
out_channels=channels_4,
freeze_norm=freeze_norm,
name="tr3")
self.st4 = Stage(
num_channels=channels_4,
num_modules=num_modules_4,
num_filters=channels_4,
has_se=self.has_se,
freeze_norm=freeze_norm,
name="st4")
def forward(self, inputs):
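        # Stem: two stride-2 3x3 convs (1/4 input resolution), then four
        # bottleneck blocks; each transition/stage pair adds a branch, ending
        # with four parallel features at strides 4, 8, 16 and 32.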
x = inputs['image']
conv1 = self.conv_layer1_1(x)
conv2 = self.conv_layer1_2(conv1)
la1 = self.la1(conv2)
tr1 = self.tr1([la1])
st2 = self.st2(tr1)
tr2 = self.tr2(st2)
st3 = self.st3(tr2)
tr3 = self.tr3(st3)
st4 = self.st4(tr3)
res = []
for i, layer in enumerate(st4):
if i == self.freeze_at:
layer.stop_gradient = True
if i in self.return_idx:
res.append(layer)
return res
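
A quick shape check for the backbone above, as a sketch with a dummy input (`HRNet` is the class just defined):

```python
# Sketch: run the HRNet backbone on a dummy batch. For width=18 the four
# returned maps have channels [18, 36, 72, 144] at strides [4, 8, 16, 32].
import paddle

backbone = HRNet(width=18, freeze_at=0, return_idx=[0, 1, 2, 3])
feats = backbone({'image': paddle.rand([1, 3, 512, 512])})
for f in feats:
    print(f.shape)
# [1, 18, 128, 128], [1, 36, 64, 64], [1, 72, 32, 32], [1, 144, 16, 16]
```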
@@ -14,6 +14,8 @@
 from . import fpn
 from . import yolo_fpn
+from . import hrfpn
 from .fpn import *
 from .yolo_fpn import *
+from .hrfpn import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn.functional as F
from paddle import ParamAttr
import paddle.nn as nn
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, serializable
__all__ = ['HRFPN']
@register
class HRFPN(nn.Layer):
"""
Args:
in_channel (int): number of input feature channels from backbone
out_channel (int): number of output feature channels
        share_conv (bool): whether to share a single 3x3 conv across all output levels
spatial_scale (list): feature map scaling factor
"""
def __init__(
self,
in_channel=270,
out_channel=256,
share_conv=False,
spatial_scale=[1. / 4, 1. / 8, 1. / 16, 1. / 32, 1. / 64], ):
super(HRFPN, self).__init__()
self.in_channel = in_channel
self.out_channel = out_channel
self.share_conv = share_conv
self.spatial_scale = spatial_scale
self.reduction = nn.Conv2D(
in_channels=in_channel,
out_channels=out_channel,
kernel_size=1,
weight_attr=ParamAttr(name='hrfpn_reduction_weights'),
bias_attr=False)
self.num_out = len(self.spatial_scale)
if share_conv:
self.fpn_conv = nn.Conv2D(
in_channels=out_channel,
out_channels=out_channel,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(name='fpn_conv_weights'),
bias_attr=False)
else:
self.fpn_conv = []
for i in range(self.num_out):
conv_name = "fpn_conv_" + str(i)
conv = self.add_sublayer(
conv_name,
nn.Conv2D(
in_channels=out_channel,
out_channels=out_channel,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(name=conv_name + "_weights"),
bias_attr=False))
self.fpn_conv.append(conv)
def forward(self, body_feats):
num_backbone_stages = len(body_feats)
outs = []
outs.append(body_feats[0])
# resize
for i in range(1, num_backbone_stages):
resized = F.interpolate(
body_feats[i], scale_factor=2**i, mode='bilinear')
outs.append(resized)
# concat
out = paddle.concat(outs, axis=1)
        assert out.shape[1] == self.in_channel, \
            'in_channel should be {}, but received {}'.format(
                self.in_channel, out.shape[1])
# reduction
out = self.reduction(out)
# conv
outs = [out]
for i in range(1, self.num_out):
outs.append(F.avg_pool2d(out, kernel_size=2**i, stride=2**i))
outputs = []
for i in range(self.num_out):
conv_func = self.fpn_conv if self.share_conv else self.fpn_conv[i]
conv = conv_func(outs[i])
outputs.append(conv)
fpn_feat = [outputs[k] for k in range(self.num_out)]
return fpn_feat, self.spatial_scale
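
The default `in_channel=270` is just the sum of the HRNetV2p-W18 branch channels (18 + 36 + 72 + 144) after everything is upsampled to the finest stride. A minimal sketch chaining the two modules defined in this commit:

```python
# Sketch: HRNet-W18 features concatenated (18+36+72+144 = 270 channels),
# reduced to 256, then pooled into 5 pyramid levels at strides 4..64.
import paddle

backbone = HRNet(width=18)
neck = HRFPN(in_channel=270, out_channel=256)
body_feats = backbone({'image': paddle.rand([1, 3, 512, 512])})
fpn_feats, spatial_scale = neck(body_feats)
for f in fpn_feats:
    print(f.shape)  # [1, 256, 128, 128] down to [1, 256, 8, 8]
```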
@@ -95,7 +95,7 @@ def load_weight(model, weight, optimizer=None):
     last_epoch = 0
     if optimizer is not None and os.path.exists(path + '.pdopt'):
         optim_state_dict = paddle.load(path + '.pdopt')
-        # to slove resume bug, will it be fixed in paddle 2.0
+        # to solve resume bug, will it be fixed in paddle 2.0
         for key in optimizer.state_dict().keys():
             if not key in optim_state_dict.keys():
                 optim_state_dict[key] = optimizer.state_dict()[key]
@@ -132,6 +132,9 @@ def load_pretrain_weight(model,
                 weight_name, pre_state_dict[weight_name].shape))
             param_state_dict[key] = pre_state_dict[weight_name]
         else:
+            if 'backbone' in key:
+                logger.info('Lack weight: {}, structure name: {}'.format(
+                    weight_name, key))
             param_state_dict[key] = model_dict[key]
     model.set_dict(param_state_dict)
     return
......