未验证 提交 b83e39ef 编写于 作者: F Feng Ni 提交者: GitHub

[dygraph] from_config for hrnet, test=dygraph (#2218)

* from_config hrnet, test=dygraph

* update hrnet modelzoo, test=dygraph
上级 28e5a3ab
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
## Model Zoo ## Model Zoo
| Backbone | Type | deformable Conv | Image/gpu | Lr schd | Inf time (fps) | Box AP | Mask AP | Download | Configs | | Backbone | Type | Image/gpu | Lr schd | Inf time (fps) | Box AP | Mask AP | Download | Configs |
| :---------------------- | :------------- | :---: | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: | :-----: | | :---------------------- | :------------- | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: | :-----: |
| HRNetV2p_W18 | Faster | False | 2 | 1x | - | 35.7 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/faster_rcnn_hrnetv2p_w18_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_1x_coco.yml) | | HRNetV2p_W18 | Faster | 1 | 1x | - | 36.8 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/faster_rcnn_hrnetv2p_w18_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_1x_coco.yml) |
| HRNetV2p_W18 | Faster | False | 2 | 2x | - | 37.7 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/faster_rcnn_hrnetv2p_w18_2x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_2x_coco.yml) | | HRNetV2p_W18 | Faster | 1 | 2x | - | 39.0 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/faster_rcnn_hrnetv2p_w18_2x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_2x_coco.yml) |
architecture: FasterRCNN architecture: FasterRCNN
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W18_C_pretrained.tar pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W18_C_pretrained.tar
weights: output/faster_rcnn_hrnetv2p_w18_1x_coco/model_final
load_static_weights: True load_static_weights: True
# Model Achitecture
FasterRCNN: FasterRCNN:
# model anchor info flow
anchor: Anchor
proposal: Proposal
# model feat info flow
backbone: HRNet backbone: HRNet
neck: HRFPN neck: HRFPN
rpn_head: RPNHead rpn_head: RPNHead
...@@ -26,65 +20,51 @@ HRFPN: ...@@ -26,65 +20,51 @@ HRFPN:
share_conv: false share_conv: false
RPNHead: RPNHead:
rpn_feat:
name: RPNFeat
feat_in: 256
feat_out: 256
anchor_per_position: 3
rpn_channel: 256
Anchor:
anchor_generator: anchor_generator:
name: AnchorGeneratorRPN
aspect_ratios: [0.5, 1.0, 2.0] aspect_ratios: [0.5, 1.0, 2.0]
anchor_start_size: 32 anchor_sizes: [[32], [64], [128], [256], [512]]
stride: [4., 4.] strides: [4, 8, 16, 32, 64]
anchor_target_generator: rpn_target_assign:
name: AnchorTargetGeneratorRPN
batch_size_per_im: 256 batch_size_per_im: 256
fg_fraction: 0.5 fg_fraction: 0.5
negative_overlap: 0.3 negative_overlap: 0.3
positive_overlap: 0.7 positive_overlap: 0.7
straddle_thresh: 0.0 use_random: True
train_proposal:
Proposal: min_size: 0.0
proposal_generator: nms_thresh: 0.7
name: ProposalGenerator pre_nms_top_n: 2000
post_nms_top_n: 2000
topk_after_collect: True
test_proposal:
min_size: 0.0 min_size: 0.0
nms_thresh: 0.7 nms_thresh: 0.7
train_pre_nms_top_n: 2000 pre_nms_top_n: 1000
train_post_nms_top_n: 2000 post_nms_top_n: 1000
infer_pre_nms_top_n: 1000
infer_post_nms_top_n: 1000
proposal_target_generator:
name: ProposalTargetGenerator
batch_size_per_im: 512
bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
bg_thresh_hi: [0.5,]
bg_thresh_lo: [0.0,]
fg_thresh: [0.5,]
fg_fraction: 0.25
BBoxHead: BBoxHead:
bbox_feat: head: TwoFCHead
name: BBoxFeat roi_extractor:
roi_extractor: resolution: 7
name: RoIAlign sampling_ratio: 0
resolution: 7 aligned: True
sampling_ratio: 2 bbox_assigner: BBoxAssigner
head_feat:
name: TwoFCHead BBoxAssigner:
in_dim: 256 batch_size_per_im: 512
mlp_dim: 1024 bg_thresh: 0.5
in_feat: 1024 fg_thresh: 0.5
fg_fraction: 0.25
use_random: True
TwoFCHead:
mlp_dim: 1024
BBoxPostProcess: BBoxPostProcess:
decode: decode: RCNNBox
name: RCNNBox
num_classes: 81
batch_size: 1
nms: nms:
name: MultiClassNMS name: MultiClassNMS
keep_top_k: 100 keep_top_k: 100
score_threshold: 0.05 score_threshold: 0.05
nms_threshold: 0.5 nms_threshold: 0.5
normalized: true
...@@ -6,6 +6,9 @@ _BASE_: [ ...@@ -6,6 +6,9 @@ _BASE_: [
'../runtime.yml', '../runtime.yml',
] ]
weights: output/faster_rcnn_hrnetv2p_w18_1x_coco/model_final
epoch: 12
LearningRate: LearningRate:
base_lr: 0.02 base_lr: 0.02
schedulers: schedulers:
......
...@@ -22,6 +22,7 @@ from numbers import Integral ...@@ -22,6 +22,7 @@ from numbers import Integral
import math import math
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from ..shape_spec import ShapeSpec
__all__ = ['HRNet'] __all__ = ['HRNet']
...@@ -577,6 +578,8 @@ class HRNet(nn.Layer): ...@@ -577,6 +578,8 @@ class HRNet(nn.Layer):
channels_2, channels_3, channels_4 = self.channels[width] channels_2, channels_3, channels_4 = self.channels[width]
num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3 num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3
self._out_channels = channels_4
self._out_strides = [4, 8, 16, 32]
self.conv_layer1_1 = ConvNormLayer( self.conv_layer1_1 = ConvNormLayer(
ch_in=3, ch_in=3,
...@@ -666,3 +669,11 @@ class HRNet(nn.Layer): ...@@ -666,3 +669,11 @@ class HRNet(nn.Layer):
res.append(layer) res.append(layer)
return res return res
@property
def out_shape(self):
return [
ShapeSpec(
channels=self._out_channels[i], stride=self._out_strides[i])
for i in self.return_idx
]
...@@ -18,6 +18,7 @@ from paddle import ParamAttr ...@@ -18,6 +18,7 @@ from paddle import ParamAttr
import paddle.nn as nn import paddle.nn as nn
from paddle.regularizer import L2Decay from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from ..shape_spec import ShapeSpec
__all__ = ['HRFPN'] __all__ = ['HRFPN']
...@@ -26,23 +27,28 @@ __all__ = ['HRFPN'] ...@@ -26,23 +27,28 @@ __all__ = ['HRFPN']
class HRFPN(nn.Layer): class HRFPN(nn.Layer):
""" """
Args: Args:
in_channel (int): number of input feature channels from backbone in_channels (list): number of input feature channels from backbone
out_channel (int): number of output feature channels out_channel (int): number of output feature channels
share_conv (bool): whether to share conv for different layers' reduction share_conv (bool): whether to share conv for different layers' reduction
spatial_scale (list): feature map scaling factor spatial_scales (list): feature map scaling factor
extra_stage (int): add extra stage for returning HRFPN fpn_feats
""" """
def __init__( def __init__(self,
self, in_channels=[18, 36, 72, 144],
in_channel=270, out_channel=256,
out_channel=256, share_conv=False,
share_conv=False, extra_stage=1,
spatial_scale=[1. / 4, 1. / 8, 1. / 16, 1. / 32, 1. / 64], ): spatial_scales=[1. / 4, 1. / 8, 1. / 16, 1. / 32]):
super(HRFPN, self).__init__() super(HRFPN, self).__init__()
in_channel = sum(in_channels)
self.in_channel = in_channel self.in_channel = in_channel
self.out_channel = out_channel self.out_channel = out_channel
self.share_conv = share_conv self.share_conv = share_conv
self.spatial_scale = spatial_scale for i in range(extra_stage):
spatial_scales = spatial_scales + [spatial_scales[-1] / 2.]
self.spatial_scales = spatial_scales
self.num_out = len(self.spatial_scales)
self.reduction = nn.Conv2D( self.reduction = nn.Conv2D(
in_channels=in_channel, in_channels=in_channel,
...@@ -50,7 +56,7 @@ class HRFPN(nn.Layer): ...@@ -50,7 +56,7 @@ class HRFPN(nn.Layer):
kernel_size=1, kernel_size=1,
weight_attr=ParamAttr(name='hrfpn_reduction_weights'), weight_attr=ParamAttr(name='hrfpn_reduction_weights'),
bias_attr=False) bias_attr=False)
self.num_out = len(self.spatial_scale)
if share_conv: if share_conv:
self.fpn_conv = nn.Conv2D( self.fpn_conv = nn.Conv2D(
in_channels=out_channel, in_channels=out_channel,
...@@ -106,5 +112,20 @@ class HRFPN(nn.Layer): ...@@ -106,5 +112,20 @@ class HRFPN(nn.Layer):
conv = conv_func(outs[i]) conv = conv_func(outs[i])
outputs.append(conv) outputs.append(conv)
fpn_feat = [outputs[k] for k in range(self.num_out)] fpn_feats = [outputs[k] for k in range(self.num_out)]
return fpn_feat, self.spatial_scale return fpn_feats
@classmethod
def from_config(cls, cfg, input_shape):
return {
'in_channels': [i.channels for i in input_shape],
'spatial_scales': [1.0 / i.stride for i in input_shape],
}
@property
def out_shape(self):
return [
ShapeSpec(
channels=self.out_channel, stride=1. / s)
for s in self.spatial_scales
]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册