未验证 提交 4b7917cd 编写于 作者: K Kaipeng Deng 提交者: GitHub

From config for single stage model YOLOv3/PPYOLO/SSD (#2112)

* fit for YOLOvd/PPYOLO/SSD
上级 c82274bb
...@@ -22,7 +22,6 @@ ResNet: ...@@ -22,7 +22,6 @@ ResNet:
norm_decay: 0. norm_decay: 0.
PPYOLOFPN: PPYOLOFPN:
feat_channels: [2048, 1280, 640]
coord_conv: true coord_conv: true
drop_block: true drop_block: true
block_size: 3 block_size: 3
......
...@@ -16,22 +16,19 @@ MobileNet: ...@@ -16,22 +16,19 @@ MobileNet:
feature_maps: [11, 13, 14, 15, 16, 17] feature_maps: [11, 13, 14, 15, 16, 17]
SSDHead: SSDHead:
in_channels: [512, 1024, 512, 256, 256, 128]
anchor_generator: AnchorGeneratorSSD
kernel_size: 1 kernel_size: 1
padding: 0 padding: 0
anchor_generator:
AnchorGeneratorSSD: steps: [0, 0, 0, 0, 0, 0]
steps: [0, 0, 0, 0, 0, 0] aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2., 3.]]
aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2., 3.]] min_ratio: 20
min_ratio: 20 max_ratio: 90
max_ratio: 90 base_size: 300
base_size: 300 min_sizes: [60.0, 105.0, 150.0, 195.0, 240.0, 285.0]
min_sizes: [60.0, 105.0, 150.0, 195.0, 240.0, 285.0] max_sizes: [[], 150.0, 195.0, 240.0, 285.0, 300.0]
max_sizes: [[], 150.0, 195.0, 240.0, 285.0, 300.0] offset: 0.5
offset: 0.5 flip: true
flip: true min_max_aspect_ratios_order: false
min_max_aspect_ratios_order: false
BBoxPostProcess: BBoxPostProcess:
decode: decode:
......
architecture: SSD architecture: SSD
pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/dygraph/VGG16_caffe_pretrained.pdparams pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/dygraph/VGG16_caffe_pretrained.pdparams
load_static_weights: True
# Model Achitecture # Model Achitecture
SSD: SSD:
...@@ -15,19 +14,16 @@ VGG: ...@@ -15,19 +14,16 @@ VGG:
normalizations: [20., -1, -1, -1, -1, -1] normalizations: [20., -1, -1, -1, -1, -1]
SSDHead: SSDHead:
in_channels: [512, 1024, 512, 256, 256, 256] anchor_generator:
anchor_generator: AnchorGeneratorSSD steps: [8, 16, 32, 64, 100, 300]
aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]]
AnchorGeneratorSSD: min_ratio: 20
steps: [8, 16, 32, 64, 100, 300] max_ratio: 90
aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]] min_sizes: [30.0, 60.0, 111.0, 162.0, 213.0, 264.0]
min_ratio: 20 max_sizes: [60.0, 111.0, 162.0, 213.0, 264.0, 315.0]
max_ratio: 90 offset: 0.5
min_sizes: [30.0, 60.0, 111.0, 162.0, 213.0, 264.0] flip: true
max_sizes: [60.0, 111.0, 162.0, 213.0, 264.0, 315.0] min_max_aspect_ratios_order: true
offset: 0.5
flip: true
min_max_aspect_ratios_order: true
BBoxPostProcess: BBoxPostProcess:
decode: decode:
......
...@@ -15,23 +15,20 @@ MobileNet: ...@@ -15,23 +15,20 @@ MobileNet:
feature_maps: [11, 13, 14, 15, 16, 17] feature_maps: [11, 13, 14, 15, 16, 17]
SSDHead: SSDHead:
in_channels: [512, 1024, 512, 256, 256, 128]
anchor_generator: AnchorGeneratorSSD
use_sepconv: True use_sepconv: True
conv_decay: 0.00004 conv_decay: 0.00004
anchor_generator:
AnchorGeneratorSSD: steps: [16, 32, 64, 100, 150, 300]
steps: [16, 32, 64, 100, 150, 300] aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2., 3.]]
aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2., 3.]] min_ratio: 20
min_ratio: 20 max_ratio: 95
max_ratio: 95 base_size: 300
base_size: 300 min_sizes: []
min_sizes: [] max_sizes: []
max_sizes: [] offset: 0.5
offset: 0.5 flip: true
flip: true clip: true
clip: true min_max_aspect_ratios_order: False
min_max_aspect_ratios_order: False
BBoxPostProcess: BBoxPostProcess:
decode: decode:
......
...@@ -18,23 +18,20 @@ MobileNetV3: ...@@ -18,23 +18,20 @@ MobileNetV3:
multiplier: 0.5 multiplier: 0.5
SSDHead: SSDHead:
in_channels: [672, 480, 512, 256, 256, 128]
anchor_generator: AnchorGeneratorSSD
use_sepconv: True use_sepconv: True
conv_decay: 0.00004 conv_decay: 0.00004
anchor_generator:
AnchorGeneratorSSD: steps: [16, 32, 64, 107, 160, 320]
steps: [16, 32, 64, 107, 160, 320] aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2., 3.]]
aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2., 3.]] min_ratio: 20
min_ratio: 20 max_ratio: 95
max_ratio: 95 base_size: 320
base_size: 320 min_sizes: []
min_sizes: [] max_sizes: []
max_sizes: [] offset: 0.5
offset: 0.5 flip: true
flip: true clip: true
clip: true min_max_aspect_ratios_order: false
min_max_aspect_ratios_order: false
BBoxPostProcess: BBoxPostProcess:
decode: decode:
......
...@@ -18,23 +18,20 @@ MobileNetV3: ...@@ -18,23 +18,20 @@ MobileNetV3:
multiplier: 0.5 multiplier: 0.5
SSDHead: SSDHead:
in_channels: [288, 288, 512, 256, 256, 128]
anchor_generator: AnchorGeneratorSSD
use_sepconv: True use_sepconv: True
conv_decay: 0.00004 conv_decay: 0.00004
anchor_generator:
AnchorGeneratorSSD: steps: [16, 32, 64, 107, 160, 320]
steps: [16, 32, 64, 107, 160, 320] aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2., 3.]]
aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2., 3.]] min_ratio: 20
min_ratio: 20 max_ratio: 95
max_ratio: 95 base_size: 320
base_size: 320 min_sizes: []
min_sizes: [] max_sizes: []
max_sizes: [] offset: 0.5
offset: 0.5 flip: true
flip: true clip: true
clip: true min_max_aspect_ratios_order: false
min_max_aspect_ratios_order: false
BBoxPostProcess: BBoxPostProcess:
decode: decode:
......
...@@ -14,8 +14,8 @@ DarkNet: ...@@ -14,8 +14,8 @@ DarkNet:
depth: 53 depth: 53
return_idx: [2, 3, 4] return_idx: [2, 3, 4]
YOLOv3FPN: # use default config
feat_channels: [1024, 768, 384] # YOLOv3FPN:
YOLOv3Head: YOLOv3Head:
anchors: [[10, 13], [16, 30], [33, 23], anchors: [[10, 13], [16, 30], [33, 23],
......
...@@ -15,8 +15,8 @@ MobileNet: ...@@ -15,8 +15,8 @@ MobileNet:
with_extra_blocks: false with_extra_blocks: false
extra_block_filters: [] extra_block_filters: []
YOLOv3FPN: # use default config
feat_channels: [1024, 768, 384] # YOLOv3FPN:
YOLOv3Head: YOLOv3Head:
anchors: [[10, 13], [16, 30], [33, 23], anchors: [[10, 13], [16, 30], [33, 23],
......
...@@ -16,8 +16,8 @@ MobileNetV3: ...@@ -16,8 +16,8 @@ MobileNetV3:
extra_block_filters: [] extra_block_filters: []
feature_maps: [7, 13, 16] feature_maps: [7, 13, 16]
YOLOv3FPN: # use default config
feat_channels: [160, 368, 168] # YOLOv3FPN:
YOLOv3Head: YOLOv3Head:
anchors: [[10, 13], [16, 30], [33, 23], anchors: [[10, 13], [16, 30], [33, 23],
......
...@@ -16,8 +16,8 @@ MobileNetV3: ...@@ -16,8 +16,8 @@ MobileNetV3:
extra_block_filters: [] extra_block_filters: []
feature_maps: [4, 9, 12] feature_maps: [4, 9, 12]
YOLOv3FPN: # use default config
feat_channels: [96, 304, 152] # YOLOv3FPN:
YOLOv3Head: YOLOv3Head:
anchors: [[10, 13], [16, 30], [33, 23], anchors: [[10, 13], [16, 30], [33, 23],
......
...@@ -2,7 +2,7 @@ from __future__ import absolute_import ...@@ -2,7 +2,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from ppdet.core.workspace import register from ppdet.core.workspace import register, create
from .meta_arch import BaseArch from .meta_arch import BaseArch
__all__ = ['SSD'] __all__ = ['SSD']
...@@ -11,38 +11,47 @@ __all__ = ['SSD'] ...@@ -11,38 +11,47 @@ __all__ = ['SSD']
@register @register
class SSD(BaseArch): class SSD(BaseArch):
__category__ = 'architecture' __category__ = 'architecture'
__inject__ = ['backbone', 'neck', 'ssd_head', 'post_process'] __inject__ = ['post_process']
def __init__(self, backbone, ssd_head, post_process, neck=None): def __init__(self, backbone, ssd_head, post_process):
super(SSD, self).__init__() super(SSD, self).__init__()
self.backbone = backbone self.backbone = backbone
self.neck = neck
self.ssd_head = ssd_head self.ssd_head = ssd_head
self.post_process = post_process self.post_process = post_process
def model_arch(self): @classmethod
def from_config(cls, cfg, *args, **kwargs):
# backbone
backbone = create(cfg['backbone'])
# head
kwargs = {'input_shape': backbone.out_shape}
ssd_head = create(cfg['ssd_head'], **kwargs)
return {
'backbone': backbone,
"ssd_head": ssd_head,
}
def _forward(self):
# Backbone # Backbone
body_feats = self.backbone(self.inputs) body_feats = self.backbone(self.inputs)
# Neck
if self.neck is not None:
body_feats, spatial_scale = self.neck(body_feats)
# SSD Head # SSD Head
self.ssd_head_outs, self.anchors = self.ssd_head(body_feats, if self.training:
self.inputs['image']) return self.ssd_head(body_feats, self.inputs['image'],
self.inputs['gt_bbox'],
self.inputs['gt_class'])
else:
boxes, scores, anchors = self.ssd_head(body_feats,
self.inputs['image'])
bbox, bbox_num = self.post_process((boxes, scores), anchors,
self.inputs['im_shape'],
self.inputs['scale_factor'])
return bbox, bbox_num
def get_loss(self, ): def get_loss(self, ):
loss = self.ssd_head.get_loss(self.ssd_head_outs, self.inputs, return {"loss": self._forward()}
self.anchors)
return {"loss": loss}
def get_pred(self): def get_pred(self):
bbox, bbox_num = self.post_process(self.ssd_head_outs, self.anchors, return dict(zip(['bbox', 'bbox_num'], self._forward()))
self.inputs['im_shape'],
self.inputs['scale_factor'])
outs = {
"bbox": bbox,
"bbox_num": bbox_num,
}
return outs
...@@ -2,7 +2,7 @@ from __future__ import absolute_import ...@@ -2,7 +2,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from ppdet.core.workspace import register from ppdet.core.workspace import register, create
from .meta_arch import BaseArch from .meta_arch import BaseArch
__all__ = ['YOLOv3'] __all__ = ['YOLOv3']
...@@ -11,12 +11,7 @@ __all__ = ['YOLOv3'] ...@@ -11,12 +11,7 @@ __all__ = ['YOLOv3']
@register @register
class YOLOv3(BaseArch): class YOLOv3(BaseArch):
__category__ = 'architecture' __category__ = 'architecture'
__inject__ = [ __inject__ = ['post_process']
'backbone',
'neck',
'yolo_head',
'post_process',
]
def __init__(self, def __init__(self,
backbone='DarkNet', backbone='DarkNet',
...@@ -29,27 +24,50 @@ class YOLOv3(BaseArch): ...@@ -29,27 +24,50 @@ class YOLOv3(BaseArch):
self.yolo_head = yolo_head self.yolo_head = yolo_head
self.post_process = post_process self.post_process = post_process
def model_arch(self, ): @classmethod
# Backbone def from_config(cls, cfg, *args, **kwargs):
body_feats = self.backbone(self.inputs) # backbone
backbone = create(cfg['backbone'])
# fpn
kwargs = {'input_shape': backbone.out_shape}
neck = create(cfg['neck'], **kwargs)
# head
kwargs = {'input_shape': neck.out_shape}
yolo_head = create(cfg['yolo_head'], **kwargs)
# neck return {
'backbone': backbone,
'neck': neck,
"yolo_head": yolo_head,
}
def _forward(self):
body_feats = self.backbone(self.inputs)
body_feats = self.neck(body_feats) body_feats = self.neck(body_feats)
# YOLO Head if self.training:
self.yolo_head_outs = self.yolo_head(body_feats) return self.yolo_head(body_feats, self.inputs)
else:
yolo_head_outs = self.yolo_head(body_feats)
bbox, bbox_num = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor'])
return bbox, bbox_num
def get_loss(self, ): def get_loss(self):
loss = self.yolo_head.get_loss(self.yolo_head_outs, self.inputs) return self._forward()
return loss
def get_pred(self): def get_pred(self):
yolo_head_outs = self.yolo_head.get_outputs(self.yolo_head_outs) bbox_pred, bbox_num = self._forward()
bbox, bbox_num = self.post_process( label = bbox_pred[:, 0]
yolo_head_outs, self.yolo_head.mask_anchors, score = bbox_pred[:, 1]
self.inputs['im_shape'], self.inputs['scale_factor']) bbox = bbox_pred[:, 2:]
outs = { output = {
"bbox": bbox, 'bbox': bbox,
"bbox_num": bbox_num, 'score': score,
'label': label,
'bbox_num': bbox_num
} }
return outs return output
...@@ -19,6 +19,7 @@ from paddle import ParamAttr ...@@ -19,6 +19,7 @@ from paddle import ParamAttr
from paddle.regularizer import L2Decay from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from ppdet.modeling.ops import batch_norm from ppdet.modeling.ops import batch_norm
from ..shape_spec import ShapeSpec
__all__ = ['DarkNet', 'ConvBNLayer'] __all__ = ['DarkNet', 'ConvBNLayer']
...@@ -193,6 +194,7 @@ class DarkNet(nn.Layer): ...@@ -193,6 +194,7 @@ class DarkNet(nn.Layer):
norm_decay=norm_decay, norm_decay=norm_decay,
name='yolo_input.downsample') name='yolo_input.downsample')
self._out_channels = []
self.darknet_conv_block_list = [] self.darknet_conv_block_list = []
self.downsample_list = [] self.downsample_list = []
ch_in = [64, 128, 256, 512, 1024] ch_in = [64, 128, 256, 512, 1024]
...@@ -208,6 +210,8 @@ class DarkNet(nn.Layer): ...@@ -208,6 +210,8 @@ class DarkNet(nn.Layer):
norm_decay=norm_decay, norm_decay=norm_decay,
name=name)) name=name))
self.darknet_conv_block_list.append(conv_block) self.darknet_conv_block_list.append(conv_block)
if i in return_idx:
self._out_channels.append(64 * (2**i))
for i in range(num_stages - 1): for i in range(num_stages - 1):
down_name = 'stage.{}.downsample'.format(i) down_name = 'stage.{}.downsample'.format(i)
downsample = self.add_sublayer( downsample = self.add_sublayer(
...@@ -235,3 +239,7 @@ class DarkNet(nn.Layer): ...@@ -235,3 +239,7 @@ class DarkNet(nn.Layer):
if i < self.num_stages - 1: if i < self.num_stages - 1:
out = self.downsample_list[i](out) out = self.downsample_list[i](out)
return blocks return blocks
@property
def out_shape(self):
return [ShapeSpec(channels=c) for c in self._out_channels]
...@@ -24,6 +24,7 @@ from paddle.regularizer import L2Decay ...@@ -24,6 +24,7 @@ from paddle.regularizer import L2Decay
from paddle.nn.initializer import KaimingNormal from paddle.nn.initializer import KaimingNormal
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from numbers import Integral from numbers import Integral
from ..shape_spec import ShapeSpec
__all__ = ['MobileNet'] __all__ = ['MobileNet']
...@@ -201,6 +202,8 @@ class MobileNet(nn.Layer): ...@@ -201,6 +202,8 @@ class MobileNet(nn.Layer):
self.with_extra_blocks = with_extra_blocks self.with_extra_blocks = with_extra_blocks
self.extra_block_filters = extra_block_filters self.extra_block_filters = extra_block_filters
self._out_channels = []
self.conv1 = ConvBNLayer( self.conv1 = ConvBNLayer(
in_channels=3, in_channels=3,
out_channels=int(32 * scale), out_channels=int(32 * scale),
...@@ -229,6 +232,7 @@ class MobileNet(nn.Layer): ...@@ -229,6 +232,7 @@ class MobileNet(nn.Layer):
norm_type=norm_type, norm_type=norm_type,
name="conv2_1")) name="conv2_1"))
self.dwsl.append(dws21) self.dwsl.append(dws21)
self._update_out_channels(64, len(self.dwsl), feature_maps)
dws22 = self.add_sublayer( dws22 = self.add_sublayer(
"conv2_2", "conv2_2",
sublayer=DepthwiseSeparable( sublayer=DepthwiseSeparable(
...@@ -244,6 +248,7 @@ class MobileNet(nn.Layer): ...@@ -244,6 +248,7 @@ class MobileNet(nn.Layer):
norm_type=norm_type, norm_type=norm_type,
name="conv2_2")) name="conv2_2"))
self.dwsl.append(dws22) self.dwsl.append(dws22)
self._update_out_channels(128, len(self.dwsl), feature_maps)
# 1/4 # 1/4
dws31 = self.add_sublayer( dws31 = self.add_sublayer(
"conv3_1", "conv3_1",
...@@ -260,6 +265,7 @@ class MobileNet(nn.Layer): ...@@ -260,6 +265,7 @@ class MobileNet(nn.Layer):
norm_type=norm_type, norm_type=norm_type,
name="conv3_1")) name="conv3_1"))
self.dwsl.append(dws31) self.dwsl.append(dws31)
self._update_out_channels(128, len(self.dwsl), feature_maps)
dws32 = self.add_sublayer( dws32 = self.add_sublayer(
"conv3_2", "conv3_2",
sublayer=DepthwiseSeparable( sublayer=DepthwiseSeparable(
...@@ -275,6 +281,7 @@ class MobileNet(nn.Layer): ...@@ -275,6 +281,7 @@ class MobileNet(nn.Layer):
norm_type=norm_type, norm_type=norm_type,
name="conv3_2")) name="conv3_2"))
self.dwsl.append(dws32) self.dwsl.append(dws32)
self._update_out_channels(256, len(self.dwsl), feature_maps)
# 1/8 # 1/8
dws41 = self.add_sublayer( dws41 = self.add_sublayer(
"conv4_1", "conv4_1",
...@@ -291,6 +298,7 @@ class MobileNet(nn.Layer): ...@@ -291,6 +298,7 @@ class MobileNet(nn.Layer):
norm_type=norm_type, norm_type=norm_type,
name="conv4_1")) name="conv4_1"))
self.dwsl.append(dws41) self.dwsl.append(dws41)
self._update_out_channels(256, len(self.dwsl), feature_maps)
dws42 = self.add_sublayer( dws42 = self.add_sublayer(
"conv4_2", "conv4_2",
sublayer=DepthwiseSeparable( sublayer=DepthwiseSeparable(
...@@ -306,6 +314,7 @@ class MobileNet(nn.Layer): ...@@ -306,6 +314,7 @@ class MobileNet(nn.Layer):
norm_type=norm_type, norm_type=norm_type,
name="conv4_2")) name="conv4_2"))
self.dwsl.append(dws42) self.dwsl.append(dws42)
self._update_out_channels(512, len(self.dwsl), feature_maps)
# 1/16 # 1/16
for i in range(5): for i in range(5):
tmp = self.add_sublayer( tmp = self.add_sublayer(
...@@ -323,6 +332,7 @@ class MobileNet(nn.Layer): ...@@ -323,6 +332,7 @@ class MobileNet(nn.Layer):
norm_type=norm_type, norm_type=norm_type,
name="conv5_" + str(i + 1))) name="conv5_" + str(i + 1)))
self.dwsl.append(tmp) self.dwsl.append(tmp)
self._update_out_channels(512, len(self.dwsl), feature_maps)
dws56 = self.add_sublayer( dws56 = self.add_sublayer(
"conv5_6", "conv5_6",
sublayer=DepthwiseSeparable( sublayer=DepthwiseSeparable(
...@@ -338,6 +348,7 @@ class MobileNet(nn.Layer): ...@@ -338,6 +348,7 @@ class MobileNet(nn.Layer):
norm_type=norm_type, norm_type=norm_type,
name="conv5_6")) name="conv5_6"))
self.dwsl.append(dws56) self.dwsl.append(dws56)
self._update_out_channels(1024, len(self.dwsl), feature_maps)
# 1/32 # 1/32
dws6 = self.add_sublayer( dws6 = self.add_sublayer(
"conv6", "conv6",
...@@ -354,6 +365,7 @@ class MobileNet(nn.Layer): ...@@ -354,6 +365,7 @@ class MobileNet(nn.Layer):
norm_type=norm_type, norm_type=norm_type,
name="conv6")) name="conv6"))
self.dwsl.append(dws6) self.dwsl.append(dws6)
self._update_out_channels(1024, len(self.dwsl), feature_maps)
if self.with_extra_blocks: if self.with_extra_blocks:
self.extra_blocks = [] self.extra_blocks = []
...@@ -371,6 +383,13 @@ class MobileNet(nn.Layer): ...@@ -371,6 +383,13 @@ class MobileNet(nn.Layer):
norm_type=norm_type, norm_type=norm_type,
name="conv7_" + str(i + 1))) name="conv7_" + str(i + 1)))
self.extra_blocks.append(conv_extra) self.extra_blocks.append(conv_extra)
self._update_out_channels(
block_filter[1],
len(self.dwsl) + len(self.extra_blocks), feature_maps)
def _update_out_channels(self, channel, feature_idx, feature_maps):
if feature_idx in feature_maps:
self._out_channels.append(channel)
def forward(self, inputs): def forward(self, inputs):
outs = [] outs = []
...@@ -390,3 +409,7 @@ class MobileNet(nn.Layer): ...@@ -390,3 +409,7 @@ class MobileNet(nn.Layer):
if idx + 1 in self.feature_maps: if idx + 1 in self.feature_maps:
outs.append(y) outs.append(y)
return outs return outs
@property
def out_shape(self):
return [ShapeSpec(channels=c) for c in self._out_channels]
...@@ -23,6 +23,7 @@ from paddle import ParamAttr ...@@ -23,6 +23,7 @@ from paddle import ParamAttr
from paddle.regularizer import L2Decay from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from numbers import Integral from numbers import Integral
from ..shape_spec import ShapeSpec
__all__ = ['MobileNetV3'] __all__ = ['MobileNetV3']
...@@ -383,6 +384,7 @@ class MobileNetV3(nn.Layer): ...@@ -383,6 +384,7 @@ class MobileNetV3(nn.Layer):
freeze_norm=freeze_norm, freeze_norm=freeze_norm,
name="conv1") name="conv1")
self._out_channels = []
self.block_list = [] self.block_list = []
i = 0 i = 0
inplanes = make_divisible(inplanes * scale) inplanes = make_divisible(inplanes * scale)
...@@ -413,6 +415,9 @@ class MobileNetV3(nn.Layer): ...@@ -413,6 +415,9 @@ class MobileNetV3(nn.Layer):
self.block_list.append(block) self.block_list.append(block)
inplanes = make_divisible(scale * c) inplanes = make_divisible(scale * c)
i += 1 i += 1
self._update_out_channels(
make_divisible(scale * exp)
if return_list else inplanes, i + 1, feature_maps)
if self.with_extra_blocks: if self.with_extra_blocks:
self.extra_block_list = [] self.extra_block_list = []
...@@ -438,6 +443,7 @@ class MobileNetV3(nn.Layer): ...@@ -438,6 +443,7 @@ class MobileNetV3(nn.Layer):
name="conv" + str(i + 2))) name="conv" + str(i + 2)))
self.extra_block_list.append(conv_extra) self.extra_block_list.append(conv_extra)
i += 1 i += 1
self._update_out_channels(extra_out_c, i + 1, feature_maps)
for j, block_filter in enumerate(self.extra_block_filters): for j, block_filter in enumerate(self.extra_block_filters):
in_c = extra_out_c if j == 0 else self.extra_block_filters[j - in_c = extra_out_c if j == 0 else self.extra_block_filters[j -
...@@ -457,6 +463,11 @@ class MobileNetV3(nn.Layer): ...@@ -457,6 +463,11 @@ class MobileNetV3(nn.Layer):
name='conv' + str(i + 2))) name='conv' + str(i + 2)))
self.extra_block_list.append(conv_extra) self.extra_block_list.append(conv_extra)
i += 1 i += 1
self._update_out_channels(block_filter[1], i + 1, feature_maps)
def _update_out_channels(self, channel, feature_idx, feature_maps):
if feature_idx in feature_maps:
self._out_channels.append(channel)
def forward(self, inputs): def forward(self, inputs):
x = self.conv1(inputs['image']) x = self.conv1(inputs['image'])
...@@ -479,3 +490,7 @@ class MobileNetV3(nn.Layer): ...@@ -479,3 +490,7 @@ class MobileNetV3(nn.Layer):
if idx + 2 in self.feature_maps: if idx + 2 in self.feature_maps:
outs.append(x) outs.append(x)
return outs return outs
@property
def out_shape(self):
return [ShapeSpec(channels=c) for c in self._out_channels]
...@@ -7,6 +7,7 @@ from paddle import ParamAttr ...@@ -7,6 +7,7 @@ from paddle import ParamAttr
from paddle.regularizer import L2Decay from paddle.regularizer import L2Decay
from paddle.nn import Conv2D, MaxPool2D from paddle.nn import Conv2D, MaxPool2D
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from ..shape_spec import ShapeSpec
__all__ = ['VGG'] __all__ = ['VGG']
...@@ -129,6 +130,8 @@ class VGG(nn.Layer): ...@@ -129,6 +130,8 @@ class VGG(nn.Layer):
self.normalizations = normalizations self.normalizations = normalizations
self.extra_block_filters = extra_block_filters self.extra_block_filters = extra_block_filters
self._out_channels = []
self.conv_block_0 = ConvBlock( self.conv_block_0 = ConvBlock(
3, 64, self.groups[0], 2, 2, 0, name="conv1_") 3, 64, self.groups[0], 2, 2, 0, name="conv1_")
self.conv_block_1 = ConvBlock( self.conv_block_1 = ConvBlock(
...@@ -139,6 +142,7 @@ class VGG(nn.Layer): ...@@ -139,6 +142,7 @@ class VGG(nn.Layer):
256, 512, self.groups[3], 2, 2, 0, name="conv4_") 256, 512, self.groups[3], 2, 2, 0, name="conv4_")
self.conv_block_4 = ConvBlock( self.conv_block_4 = ConvBlock(
512, 512, self.groups[4], 3, 1, 1, name="conv5_") 512, 512, self.groups[4], 3, 1, 1, name="conv5_")
self._out_channels.append(512)
self.fc6 = Conv2D( self.fc6 = Conv2D(
in_channels=512, in_channels=512,
...@@ -153,6 +157,7 @@ class VGG(nn.Layer): ...@@ -153,6 +157,7 @@ class VGG(nn.Layer):
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0) padding=0)
self._out_channels.append(1024)
# extra block # extra block
self.extra_convs = [] self.extra_convs = []
...@@ -164,6 +169,7 @@ class VGG(nn.Layer): ...@@ -164,6 +169,7 @@ class VGG(nn.Layer):
v[2], v[3], v[4])) v[2], v[3], v[4]))
last_channels = v[1] last_channels = v[1]
self.extra_convs.append(extra_conv) self.extra_convs.append(extra_conv)
self._out_channels.append(last_channels)
self.norms = [] self.norms = []
for i, n in enumerate(self.normalizations): for i, n in enumerate(self.normalizations):
...@@ -192,7 +198,7 @@ class VGG(nn.Layer): ...@@ -192,7 +198,7 @@ class VGG(nn.Layer):
outputs.append(out) outputs.append(out)
if not self.extra_block_filters: if not self.extra_block_filters:
return out return outputs
# extra block # extra block
for extra_conv in self.extra_convs: for extra_conv in self.extra_convs:
...@@ -204,3 +210,7 @@ class VGG(nn.Layer): ...@@ -204,3 +210,7 @@ class VGG(nn.Layer):
outputs[i] = self.norms[i](outputs[i]) outputs[i] = self.norms[i](outputs[i])
return outputs return outputs
@property
def out_shape(self):
return [ShapeSpec(channels=c) for c in self._out_channels]
...@@ -5,6 +5,8 @@ from ppdet.core.workspace import register ...@@ -5,6 +5,8 @@ from ppdet.core.workspace import register
from paddle.regularizer import L2Decay from paddle.regularizer import L2Decay
from paddle import ParamAttr from paddle import ParamAttr
from ..layers import AnchorGeneratorSSD
class SepConvLayer(nn.Layer): class SepConvLayer(nn.Layer):
def __init__(self, def __init__(self,
...@@ -58,7 +60,7 @@ class SSDHead(nn.Layer): ...@@ -58,7 +60,7 @@ class SSDHead(nn.Layer):
def __init__(self, def __init__(self,
num_classes=81, num_classes=81,
in_channels=(512, 1024, 512, 256, 256, 256), in_channels=(512, 1024, 512, 256, 256, 256),
anchor_generator='AnchorGeneratorSSD', anchor_generator=AnchorGeneratorSSD().__dict__,
kernel_size=3, kernel_size=3,
padding=1, padding=1,
use_sepconv=False, use_sepconv=False,
...@@ -69,8 +71,11 @@ class SSDHead(nn.Layer): ...@@ -69,8 +71,11 @@ class SSDHead(nn.Layer):
self.in_channels = in_channels self.in_channels = in_channels
self.anchor_generator = anchor_generator self.anchor_generator = anchor_generator
self.loss = loss self.loss = loss
self.num_priors = self.anchor_generator.num_priors
if isinstance(anchor_generator, dict):
self.anchor_generator = AnchorGeneratorSSD(**anchor_generator)
self.num_priors = self.anchor_generator.num_priors
self.box_convs = [] self.box_convs = []
self.score_convs = [] self.score_convs = []
for i, num_prior in enumerate(self.num_priors): for i, num_prior in enumerate(self.num_priors):
...@@ -116,7 +121,11 @@ class SSDHead(nn.Layer): ...@@ -116,7 +121,11 @@ class SSDHead(nn.Layer):
name=score_conv_name)) name=score_conv_name))
self.score_convs.append(score_conv) self.score_convs.append(score_conv)
def forward(self, feats, image): @classmethod
def from_config(cls, cfg, input_shape):
return {'in_channels': [i.channels for i in input_shape], }
def forward(self, feats, image, gt_bbox=None, gt_class=None):
box_preds = [] box_preds = []
cls_scores = [] cls_scores = []
prior_boxes = [] prior_boxes = []
...@@ -134,10 +143,11 @@ class SSDHead(nn.Layer): ...@@ -134,10 +143,11 @@ class SSDHead(nn.Layer):
prior_boxes = self.anchor_generator(feats, image) prior_boxes = self.anchor_generator(feats, image)
outputs = {} if self.training:
outputs['boxes'] = box_preds return self.get_loss(box_preds, cls_scores, gt_bbox, gt_class,
outputs['scores'] = cls_scores prior_boxes)
return outputs, prior_boxes else:
return box_preds, cls_scores, prior_boxes
def get_loss(self, inputs, targets, prior_boxes): def get_loss(self, boxes, scores, gt_bbox, gt_class, prior_boxes):
return self.loss(inputs, targets, prior_boxes) return self.loss(boxes, scores, gt_bbox, gt_class, prior_boxes)
...@@ -67,38 +67,36 @@ class YOLOv3Head(nn.Layer): ...@@ -67,38 +67,36 @@ class YOLOv3Head(nn.Layer):
assert mask < anchor_num, "anchor mask index overflow" assert mask < anchor_num, "anchor mask index overflow"
self.mask_anchors[-1].extend(anchors[mask]) self.mask_anchors[-1].extend(anchors[mask])
def forward(self, feats): def forward(self, feats, targets=None):
assert len(feats) == len(self.anchors) assert len(feats) == len(self.anchors)
yolo_outputs = [] yolo_outputs = []
for i, feat in enumerate(feats): for i, feat in enumerate(feats):
yolo_output = self.yolo_outputs[i](feat) yolo_output = self.yolo_outputs[i](feat)
yolo_outputs.append(yolo_output) yolo_outputs.append(yolo_output)
return yolo_outputs
def get_loss(self, inputs, targets): if self.training:
return self.loss(inputs, targets, self.anchors) return self.loss(yolo_outputs, targets, self.anchors)
def get_outputs(self, outputs):
if self.iou_aware:
y = []
for i, out in enumerate(outputs):
na = len(self.anchors[i])
ioup, x = out[:, 0:na, :, :], out[:, na:, :, :]
b, c, h, w = x.shape
no = c // na
x = x.reshape((b, na, no, h * w))
ioup = ioup.reshape((b, na, 1, h * w))
obj = x[:, :, 4:5, :]
ioup = F.sigmoid(ioup)
obj = F.sigmoid(obj)
obj_t = (obj**(1 - self.iou_aware_factor)) * (
ioup**self.iou_aware_factor)
obj_t = _de_sigmoid(obj_t)
loc_t = x[:, :, :4, :]
cls_t = x[:, :, 5:, :]
y_t = paddle.concat([loc_t, obj_t, cls_t], axis=2)
y_t = y_t.reshape((b, c, h, w))
y.append(y_t)
return y
else: else:
return outputs if self.iou_aware:
y = []
for i, out in enumerate(yolo_outputs):
na = len(self.anchors[i])
ioup, x = out[:, 0:na, :, :], out[:, na:, :, :]
b, c, h, w = x.shape
no = c // na
x = x.reshape((b, na, no, h * w))
ioup = ioup.reshape((b, na, 1, h * w))
obj = x[:, :, 4:5, :]
ioup = F.sigmoid(ioup)
obj = F.sigmoid(obj)
obj_t = (obj**(1 - self.iou_aware_factor)) * (
ioup**self.iou_aware_factor)
obj_t = _de_sigmoid(obj_t)
loc_t = x[:, :, :4, :]
cls_t = x[:, :, 5:, :]
y_t = paddle.concat([loc_t, obj_t, cls_t], axis=2)
y_t = y_t.reshape((b, c, h, w))
y.append(y_t)
return y
else:
return yolo_outputs
...@@ -403,7 +403,7 @@ class MatrixNMS(object): ...@@ -403,7 +403,7 @@ class MatrixNMS(object):
self.gaussian_sigma = gaussian_sigma self.gaussian_sigma = gaussian_sigma
self.background_label = background_label self.background_label = background_label
def __call__(self, bbox, score): def __call__(self, bbox, score, *args):
return ops.matrix_nms( return ops.matrix_nms(
bboxes=bbox, bboxes=bbox,
scores=score, scores=score,
...@@ -469,7 +469,7 @@ class SSDBox(object): ...@@ -469,7 +469,7 @@ class SSDBox(object):
im_shape, im_shape,
scale_factor, scale_factor,
var_weight=None): var_weight=None):
boxes, scores = preds['boxes'], preds['scores'] boxes, scores = preds
outputs = [] outputs = []
for box, score, prior_box in zip(boxes, scores, prior_boxes): for box, score, prior_box in zip(boxes, scores, prior_boxes):
pb_w = prior_box[:, 2] - prior_box[:, 0] + self.norm_delta pb_w = prior_box[:, 2] - prior_box[:, 0] + self.norm_delta
......
...@@ -109,12 +109,11 @@ class SSDLoss(nn.Layer): ...@@ -109,12 +109,11 @@ class SSDLoss(nn.Layer):
neg_mask = (idx_rank < num_neg).astype(conf_loss.dtype) neg_mask = (idx_rank < num_neg).astype(conf_loss.dtype)
return neg_mask return neg_mask
def forward(self, inputs, targets, anchors): def forward(self, boxes, scores, gt_box, gt_class, anchors):
boxes = paddle.concat(inputs['boxes'], axis=1) boxes = paddle.concat(boxes, axis=1)
scores = paddle.concat(inputs['scores'], axis=1) scores = paddle.concat(scores, axis=1)
prior_boxes = paddle.concat(anchors, axis=0) prior_boxes = paddle.concat(anchors, axis=0)
gt_box = targets['gt_bbox'] gt_label = gt_class.unsqueeze(-1)
gt_label = targets['gt_class'].unsqueeze(-1)
batch_size, num_priors, num_classes = scores.shape batch_size, num_priors, num_classes = scores.shape
def _reshape_to_2d(x): def _reshape_to_2d(x):
......
...@@ -23,6 +23,8 @@ from paddle.regularizer import L2Decay ...@@ -23,6 +23,8 @@ from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from ..shape_spec import ShapeSpec from ..shape_spec import ShapeSpec
__all__ = ['FPN']
@register @register
@serializable @serializable
......
...@@ -25,6 +25,8 @@ from ppdet.modeling.layers import DeformableConvV2 ...@@ -25,6 +25,8 @@ from ppdet.modeling.layers import DeformableConvV2
import math import math
from ppdet.modeling.ops import batch_norm from ppdet.modeling.ops import batch_norm
__all__ = ['TTFFPN']
class Upsample(nn.Layer): class Upsample(nn.Layer):
def __init__(self, ch_in, ch_out, name=None): def __init__(self, ch_in, ch_out, name=None):
......
...@@ -20,6 +20,10 @@ from ppdet.core.workspace import register, serializable ...@@ -20,6 +20,10 @@ from ppdet.core.workspace import register, serializable
from ..backbones.darknet import ConvBNLayer from ..backbones.darknet import ConvBNLayer
import numpy as np import numpy as np
from ..shape_spec import ShapeSpec
__all__ = ['YOLOv3FPN', 'PPYOLOFPN']
class YoloDetBlock(nn.Layer): class YoloDetBlock(nn.Layer):
def __init__(self, ch_in, channel, norm_type, name): def __init__(self, ch_in, channel, norm_type, name):
...@@ -163,23 +167,30 @@ class PPYOLODetBlock(nn.Layer): ...@@ -163,23 +167,30 @@ class PPYOLODetBlock(nn.Layer):
class YOLOv3FPN(nn.Layer): class YOLOv3FPN(nn.Layer):
__shared__ = ['norm_type'] __shared__ = ['norm_type']
def __init__(self, feat_channels=[1024, 768, 384], norm_type='bn'): def __init__(self, in_channels=[256, 512, 1024], norm_type='bn'):
super(YOLOv3FPN, self).__init__() super(YOLOv3FPN, self).__init__()
assert len(feat_channels) > 0, "feat_channels length should > 0" assert len(in_channels) > 0, "in_channels length should > 0"
self.feat_channels = feat_channels self.in_channels = in_channels
self.num_blocks = len(feat_channels) self.num_blocks = len(in_channels)
self._out_channels = []
self.yolo_blocks = [] self.yolo_blocks = []
self.routes = [] self.routes = []
for i in range(self.num_blocks): for i in range(self.num_blocks):
name = 'yolo_block.{}'.format(i) name = 'yolo_block.{}'.format(i)
in_channel = in_channels[-i - 1]
if i > 0:
in_channel += 512 // (2**i)
yolo_block = self.add_sublayer( yolo_block = self.add_sublayer(
name, name,
YoloDetBlock( YoloDetBlock(
feat_channels[i], in_channel,
channel=512 // (2**i), channel=512 // (2**i),
norm_type=norm_type, norm_type=norm_type,
name=name)) name=name))
self.yolo_blocks.append(yolo_block) self.yolo_blocks.append(yolo_block)
# tip layer output channel doubled
self._out_channels.append(1024 // (2**i))
if i < self.num_blocks - 1: if i < self.num_blocks - 1:
name = 'yolo_transition.{}'.format(i) name = 'yolo_transition.{}'.format(i)
...@@ -211,20 +222,25 @@ class YOLOv3FPN(nn.Layer): ...@@ -211,20 +222,25 @@ class YOLOv3FPN(nn.Layer):
return yolo_feats return yolo_feats
@classmethod
def from_config(cls, cfg, input_shape):
return {'in_channels': [i.channels for i in input_shape], }
@property
def out_shape(self):
return [ShapeSpec(channels=c) for c in self._out_channels]
@register @register
@serializable @serializable
class PPYOLOFPN(nn.Layer): class PPYOLOFPN(nn.Layer):
__shared__ = ['norm_type'] __shared__ = ['norm_type']
def __init__(self, def __init__(self, in_channels=[512, 1024, 2048], norm_type='bn', **kwargs):
feat_channels=[2048, 1280, 640],
norm_type='bn',
**kwargs):
super(PPYOLOFPN, self).__init__() super(PPYOLOFPN, self).__init__()
assert len(feat_channels) > 0, "feat_channels length should > 0" assert len(in_channels) > 0, "in_channels length should > 0"
self.feat_channels = feat_channels self.in_channels = in_channels
self.num_blocks = len(feat_channels) self.num_blocks = len(in_channels)
# parse kwargs # parse kwargs
self.coord_conv = kwargs.get('coord_conv', False) self.coord_conv = kwargs.get('coord_conv', False)
self.drop_block = kwargs.get('drop_block', False) self.drop_block = kwargs.get('drop_block', False)
...@@ -246,9 +262,12 @@ class PPYOLOFPN(nn.Layer): ...@@ -246,9 +262,12 @@ class PPYOLOFPN(nn.Layer):
else: else:
dropblock_cfg = [] dropblock_cfg = []
self._out_channels = []
self.yolo_blocks = [] self.yolo_blocks = []
self.routes = [] self.routes = []
for i, ch_in in enumerate(self.feat_channels): for i, ch_in in enumerate(self.in_channels[::-1]):
if i > 0:
ch_in += 512 // (2**i)
channel = 64 * (2**self.num_blocks) // (2**i) channel = 64 * (2**self.num_blocks) // (2**i)
base_cfg = [ base_cfg = [
# name of layer, Layer, args # name of layer, Layer, args
...@@ -279,6 +298,7 @@ class PPYOLOFPN(nn.Layer): ...@@ -279,6 +298,7 @@ class PPYOLOFPN(nn.Layer):
name = 'yolo_block.{}'.format(i) name = 'yolo_block.{}'.format(i)
yolo_block = self.add_sublayer(name, PPYOLODetBlock(cfg, name)) yolo_block = self.add_sublayer(name, PPYOLODetBlock(cfg, name))
self.yolo_blocks.append(yolo_block) self.yolo_blocks.append(yolo_block)
self._out_channels.append(channel * 2)
if i < self.num_blocks - 1: if i < self.num_blocks - 1:
name = 'yolo_transition.{}'.format(i) name = 'yolo_transition.{}'.format(i)
route = self.add_sublayer( route = self.add_sublayer(
...@@ -307,4 +327,12 @@ class PPYOLOFPN(nn.Layer): ...@@ -307,4 +327,12 @@ class PPYOLOFPN(nn.Layer):
route = self.routes[i](route) route = self.routes[i](route)
route = F.interpolate(route, scale_factor=2.) route = F.interpolate(route, scale_factor=2.)
return yolo_feats return yolo_feats
\ No newline at end of file
@classmethod
def from_config(cls, cfg, input_shape):
return {'in_channels': [i.channels for i in input_shape], }
@property
def out_shape(self):
return [ShapeSpec(channels=c) for c in self._out_channels]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册