未验证 提交 3b564170 编写于 作者: S shangliang Xu 提交者: GitHub

[ssd] add MLPerf ssd model (#4055)

上级 bd0527b8
epoch: 70
LearningRate:
base_lr: 0.05
schedulers:
- !PiecewiseDecay
milestones: [48, 60]
gamma: [0.1, 0.1]
use_warmup: false
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
architecture: SSD
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet34_pretrained.pdparams
SSD:
backbone: ResNet
ssd_head: SSDHead
post_process: BBoxPostProcess
r34_backbone: True
ResNet:
# index 0 stands for res2
depth: 34
norm_type: bn
freeze_norm: False
freeze_at: -1
return_idx: [2]
num_stages: 3
SSDHead:
anchor_generator:
steps: [8, 16, 32, 64, 100, 300]
aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]]
min_sizes: [21.0, 45.0, 99.0, 153.0, 207.0, 261.0]
max_sizes: [45.0, 99.0, 153.0, 207.0, 261.0, 315.0]
offset: 0.5
clip: True
min_max_aspect_ratios_order: True
use_extra_head: True
BBoxPostProcess:
decode:
name: SSDBox
nms:
name: MultiClassNMS
keep_top_k: 200
score_threshold: 0.05
nms_threshold: 0.5
nms_top_k: 400
worker_num: 3
TrainReader:
inputs_def:
num_max_boxes: 90
sample_transforms:
- Decode: {}
- RandomCrop: {num_attempts: 1}
- RandomFlip: {}
- Resize: {target_size: [300, 300], keep_ratio: False, interp: 1}
- RandomDistort: {brightness: [0.875, 1.125, 0.5], random_apply: False}
- NormalizeBox: {}
- PadBox: {num_max_boxes: 90}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
- Permute: {}
batch_size: 64
shuffle: true
drop_last: true
use_shared_memory: true
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [300, 300], keep_ratio: False, interp: 1}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
- Permute: {}
batch_size: 1
TestReader:
inputs_def:
image_shape: [3, 300, 300]
sample_transforms:
- Decode: {}
- Resize: {target_size: [300, 300], keep_ratio: False, interp: 1}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
- Permute: {}
batch_size: 1
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/optimizer_70e.yml',
'_base_/ssd_r34_300.yml',
'_base_/ssd_r34_reader.yml',
]
weights: output/ssd_r34_70e_coco/model_final
log_iter: 100
snapshot_epoch: 5
......@@ -36,11 +36,19 @@ class SSD(BaseArch):
__category__ = 'architecture'
__inject__ = ['post_process']
def __init__(self, backbone, ssd_head, post_process):
def __init__(self, backbone, ssd_head, post_process, r34_backbone=False):
super(SSD, self).__init__()
self.backbone = backbone
self.ssd_head = ssd_head
self.post_process = post_process
self.r34_backbone = r34_backbone
if self.r34_backbone:
from ppdet.modeling.backbones.resnet import ResNet
assert isinstance(self.backbone, ResNet) and \
self.backbone.depth == 34, \
"If you set r34_backbone=True, please use ResNet-34 as backbone."
self.backbone.res_layers[2].blocks[0].branch2a.conv._stride = [1, 1]
self.backbone.res_layers[2].blocks[0].short.conv._stride = [1, 1]
@classmethod
def from_config(cls, cfg, *args, **kwargs):
......
......@@ -28,7 +28,7 @@ class SepConvLayer(nn.Layer):
out_channels,
kernel_size=3,
padding=1,
conv_decay=0):
conv_decay=0.):
super(SepConvLayer, self).__init__()
self.dw_conv = nn.Conv2D(
in_channels=in_channels,
......@@ -61,6 +61,35 @@ class SepConvLayer(nn.Layer):
return x
class SSDExtraHead(nn.Layer):
def __init__(self,
in_channels=256,
out_channels=([256, 512], [256, 512], [128, 256], [128, 256],
[128, 256]),
strides=(2, 2, 2, 1, 1),
paddings=(1, 1, 1, 0, 0)):
super(SSDExtraHead, self).__init__()
self.convs = nn.LayerList()
for out_channel, stride, padding in zip(out_channels, strides,
paddings):
self.convs.append(
self._make_layers(in_channels, out_channel[0], out_channel[1],
stride, padding))
in_channels = out_channel[-1]
def _make_layers(self, c_in, c_hidden, c_out, stride_3x3, padding_3x3):
return nn.Sequential(
nn.Conv2D(c_in, c_hidden, 1),
nn.ReLU(),
nn.Conv2D(c_hidden, c_out, 3, stride_3x3, padding_3x3), nn.ReLU())
def forward(self, x):
out = [x]
for conv_layer in self.convs:
out.append(conv_layer(out[-1]))
return out
@register
class SSDHead(nn.Layer):
"""
......@@ -75,6 +104,7 @@ class SSDHead(nn.Layer):
use_sepconv (bool): Use SepConvLayer if true
conv_decay (float): Conv regularization coeff
loss (object): 'SSDLoss' instance
use_extra_head (bool): If use ResNet34 as baskbone, you should set `use_extra_head`=True
"""
__shared__ = ['num_classes']
......@@ -88,13 +118,19 @@ class SSDHead(nn.Layer):
padding=1,
use_sepconv=False,
conv_decay=0.,
loss='SSDLoss'):
loss='SSDLoss',
use_extra_head=False):
super(SSDHead, self).__init__()
# add background class
self.num_classes = num_classes + 1
self.in_channels = in_channels
self.anchor_generator = anchor_generator
self.loss = loss
self.use_extra_head = use_extra_head
if self.use_extra_head:
self.ssd_extra_head = SSDExtraHead()
self.in_channels = [256, 512, 512, 256, 256, 256]
if isinstance(anchor_generator, dict):
self.anchor_generator = AnchorGeneratorSSD(**anchor_generator)
......@@ -108,7 +144,7 @@ class SSDHead(nn.Layer):
box_conv = self.add_sublayer(
box_conv_name,
nn.Conv2D(
in_channels=in_channels[i],
in_channels=self.in_channels[i],
out_channels=num_prior * 4,
kernel_size=kernel_size,
padding=padding))
......@@ -116,7 +152,7 @@ class SSDHead(nn.Layer):
box_conv = self.add_sublayer(
box_conv_name,
SepConvLayer(
in_channels=in_channels[i],
in_channels=self.in_channels[i],
out_channels=num_prior * 4,
kernel_size=kernel_size,
padding=padding,
......@@ -128,7 +164,7 @@ class SSDHead(nn.Layer):
score_conv = self.add_sublayer(
score_conv_name,
nn.Conv2D(
in_channels=in_channels[i],
in_channels=self.in_channels[i],
out_channels=num_prior * self.num_classes,
kernel_size=kernel_size,
padding=padding))
......@@ -136,7 +172,7 @@ class SSDHead(nn.Layer):
score_conv = self.add_sublayer(
score_conv_name,
SepConvLayer(
in_channels=in_channels[i],
in_channels=self.in_channels[i],
out_channels=num_prior * self.num_classes,
kernel_size=kernel_size,
padding=padding,
......@@ -148,9 +184,13 @@ class SSDHead(nn.Layer):
return {'in_channels': [i.channels for i in input_shape], }
def forward(self, feats, image, gt_bbox=None, gt_class=None):
if self.use_extra_head:
assert len(feats) == 1, \
("If you set use_extra_head=True, backbone feature "
"list length should be 1.")
feats = self.ssd_extra_head(feats[0])
box_preds = []
cls_scores = []
prior_boxes = []
for feat, box_conv, score_conv in zip(feats, self.box_convs,
self.score_convs):
box_pred = box_conv(feat)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册