From d4a6d3240c71baf8ebe284977d87aa15aa1efc3e Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Mon, 9 Nov 2020 14:48:01 +0800 Subject: [PATCH] [Dygraph] update YOLOv3 config (#1668) * update YOLOv3 config --- configs/yolov3_darknet.yml | 24 ++-- ppdet/modeling/architecture/yolo.py | 27 ++-- ppdet/modeling/head/__init__.py | 14 +++ ppdet/modeling/head/yolo_head.py | 186 +++++++--------------------- ppdet/modeling/layers.py | 47 ------- ppdet/modeling/loss/__init__.py | 17 +++ ppdet/modeling/loss/yolo_loss.py | 62 ++++++++++ ppdet/modeling/neck/__init__.py | 16 +++ ppdet/modeling/neck/yolo_fpn.py | 108 ++++++++++++++++ 9 files changed, 285 insertions(+), 216 deletions(-) create mode 100644 ppdet/modeling/loss/__init__.py create mode 100644 ppdet/modeling/loss/yolo_loss.py create mode 100644 ppdet/modeling/neck/yolo_fpn.py diff --git a/configs/yolov3_darknet.yml b/configs/yolov3_darknet.yml index 47392330e..a7a3e7bc2 100644 --- a/configs/yolov3_darknet.yml +++ b/configs/yolov3_darknet.yml @@ -12,8 +12,8 @@ use_fine_grained_loss: false load_static_weights: True YOLOv3: - anchor: AnchorYOLO backbone: DarkNet + neck: YOLOv3FPN yolo_head: YOLOv3Head post_process: BBoxPostProcess @@ -21,21 +21,25 @@ DarkNet: depth: 53 return_idx: [2, 3, 4] +YOLOv3FPN: + feat_channels: [1024, 768, 384] + YOLOv3Head: - yolo_feat: - name: YOLOFeat - feat_in_list: [1024, 768, 384] + anchors: [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] + anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + loss: YOLOv3Loss + +YOLOv3Loss: ignore_thresh: 0.7 downsample: 32 label_smooth: true - anchor_per_position: 3 BBoxPostProcess: decode: name: YOLOBox conf_thresh: 0.005 downsample_ratio: 32 - clip_bbox: True + clip_bbox: true nms: name: MultiClassNMS keep_top_k: 100 @@ -45,14 +49,6 @@ BBoxPostProcess: normalized: false background_label: -1 - -AnchorYOLO: - anchor_generator: - name: AnchorGeneratorYOLO - anchors: [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326] - anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]] - - LearningRate: base_lr: 0.001 schedulers: diff --git a/ppdet/modeling/architecture/yolo.py b/ppdet/modeling/architecture/yolo.py index 7f274bc48..e09dd4f66 100644 --- a/ppdet/modeling/architecture/yolo.py +++ b/ppdet/modeling/architecture/yolo.py @@ -12,16 +12,20 @@ __all__ = ['YOLOv3'] class YOLOv3(BaseArch): __category__ = 'architecture' __inject__ = [ - 'anchor', 'backbone', + 'neck', 'yolo_head', 'post_process', ] - def __init__(self, anchor, backbone, yolo_head, post_process): + def __init__(self, + backbone='DarkNet', + neck='YOLOv3FPN', + yolo_head='YOLOv3Head', + post_process='BBoxPostProcess'): super(YOLOv3, self).__init__() - self.anchor = anchor self.backbone = backbone + self.neck = neck self.yolo_head = yolo_head self.post_process = post_process @@ -29,21 +33,20 @@ class YOLOv3(BaseArch): # Backbone body_feats = self.backbone(self.inputs) - # YOLO Head - self.yolo_head_out = self.yolo_head(body_feats) + # neck + body_feats = self.neck(body_feats) - # Anchor - self.anchors, self.anchor_masks, self.mask_anchors = self.anchor() + # YOLO Head + self.yolo_head_outs = self.yolo_head(body_feats) def loss(self, ): - yolo_loss = self.yolo_head.loss(self.inputs, self.yolo_head_out, - self.anchors, self.anchor_masks, - self.mask_anchors) + yolo_loss = self.yolo_head.loss(self.inputs, self.yolo_head_outs) return yolo_loss def infer(self, ): - bbox, bbox_num = self.post_process( - self.yolo_head_out, self.mask_anchors, self.inputs['im_size']) + bbox, bbox_num = self.post_process(self.yolo_head_outs, + self.yolo_head.mask_anchors, + self.inputs['im_size']) outs = { "bbox": bbox.numpy(), "bbox_num": bbox_num.numpy(), diff --git a/ppdet/modeling/head/__init__.py b/ppdet/modeling/head/__init__.py index 42324f0f4..619b3ccf2 100644 --- a/ppdet/modeling/head/__init__.py +++ b/ppdet/modeling/head/__init__.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from . import rpn_head from . import bbox_head from . import mask_head diff --git a/ppdet/modeling/head/yolo_head.py b/ppdet/modeling/head/yolo_head.py index 480cdec6a..aa4bf3a17 100644 --- a/ppdet/modeling/head/yolo_head.py +++ b/ppdet/modeling/head/yolo_head.py @@ -8,125 +8,33 @@ from ppdet.core.workspace import register from ..backbone.darknet import ConvBNLayer -class YoloDetBlock(nn.Layer): - def __init__(self, ch_in, channel, name): - super(YoloDetBlock, self).__init__() - self.ch_in = ch_in - self.channel = channel - assert channel % 2 == 0, \ - "channel {} cannot be divided by 2".format(channel) - conv_def = [ - ['conv0', ch_in, channel, 1, '.0.0'], - ['conv1', channel, channel * 2, 3, '.0.1'], - ['conv2', channel * 2, channel, 1, '.1.0'], - ['conv3', channel, channel * 2, 3, '.1.1'], - ['route', channel * 2, channel, 1, '.2'], - #['tip', channel, channel * 2, 3], - ] - - self.conv_module = nn.Sequential() - for idx, (conv_name, ch_in, ch_out, filter_size, - post_name) in enumerate(conv_def): - self.conv_module.add_sublayer( - conv_name, - ConvBNLayer( - ch_in=ch_in, - ch_out=ch_out, - filter_size=filter_size, - padding=(filter_size - 1) // 2, - name=name + post_name)) - - self.tip = ConvBNLayer( - ch_in=channel, - ch_out=channel * 2, - filter_size=3, - padding=1, - name=name + '.tip') - - def forward(self, inputs): - route = self.conv_module(inputs) - tip = self.tip(route) - return route, tip - - -@register -class YOLOFeat(nn.Layer): - __shared__ = ['num_levels'] - - def __init__(self, feat_in_list=[1024, 768, 384], num_levels=3): - super(YOLOFeat, self).__init__() - self.feat_in_list = feat_in_list - self.yolo_blocks = [] - self.route_blocks = [] - self.num_levels = num_levels - for i in range(self.num_levels): - name = 'yolo_block.{}'.format(i) - yolo_block = self.add_sublayer( - name, - YoloDetBlock( - feat_in_list[i], channel=512 // (2**i), name=name)) - self.yolo_blocks.append(yolo_block) - - if i < self.num_levels - 1: - name = 'yolo_transition.{}'.format(i) - route = self.add_sublayer( - name, - ConvBNLayer( - ch_in=512 // (2**i), - ch_out=256 // (2**i), - filter_size=1, - stride=1, - padding=0, - name=name)) - self.route_blocks.append(route) - - def forward(self, body_feats): - assert len(body_feats) == self.num_levels - body_feats = body_feats[::-1] - yolo_feats = [] - for i, block in enumerate(body_feats): - if i > 0: - block = paddle.concat([route, block], axis=1) - route, tip = self.yolo_blocks[i](block) - yolo_feats.append(tip) - - if i < self.num_levels - 1: - route = self.route_blocks[i](route) - route = F.interpolate(route, scale_factor=2.) - - return yolo_feats - - @register class YOLOv3Head(nn.Layer): - __shared__ = ['num_classes', 'num_levels', 'use_fine_grained_loss'] - __inject__ = ['yolo_feat'] + __shared__ = ['num_classes'] + __inject__ = ['loss'] def __init__(self, - yolo_feat, + anchors=[ + 10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, + 156, 198, 373, 326 + ], + anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]], num_classes=80, - anchor_per_position=3, - num_levels=3, - use_fine_grained_loss=False, - ignore_thresh=0.7, - downsample=32, - label_smooth=True): + loss='YOLOv3Loss'): super(YOLOv3Head, self).__init__() + self.anchors = anchors + self.anchor_masks = anchor_masks self.num_classes = num_classes - self.anchor_per_position = anchor_per_position - self.yolo_feat = yolo_feat - self.num_levels = num_levels - self.use_fine_grained_loss = use_fine_grained_loss - self.ignore_thresh = ignore_thresh - self.downsample = downsample - self.label_smooth = label_smooth - self.yolo_out_list = [] - for i in range(num_levels): - # TODO: optim here - #num_filters = len(cfg.anchor_masks[i]) * (self.num_classes + 5) - num_filters = self.anchor_per_position * (self.num_classes + 5) + self.loss = loss + + self.mask_anchors = self.parse_anchor(self.anchors, self.anchor_masks) + self.num_outputs = len(self.mask_anchors) + + self.yolo_outputs = [] + for i in range(len(self.mask_anchors)): + num_filters = self.num_outputs * (self.num_classes + 5) name = 'yolo_output.{}'.format(i) - yolo_out = self.add_sublayer( + yolo_output = self.add_sublayer( name, nn.Conv2D( in_channels=1024 // (2**i), @@ -137,35 +45,27 @@ class YOLOv3Head(nn.Layer): weight_attr=ParamAttr(name=name + '.conv.weights'), bias_attr=ParamAttr( name=name + '.conv.bias', regularizer=L2Decay(0.)))) - self.yolo_out_list.append(yolo_out) - - def forward(self, body_feats): - assert len(body_feats) == self.num_levels - yolo_feats = self.yolo_feat(body_feats) - yolo_head_out = [] - for i, feat in enumerate(yolo_feats): - yolo_out = self.yolo_out_list[i](feat) - yolo_head_out.append(yolo_out) - return yolo_head_out - - def loss(self, inputs, head_out, anchors, anchor_masks, mask_anchors): - if self.use_fine_grained_loss: - raise NotImplementedError - - yolo_losses = [] - for i, out in enumerate(head_out): - loss = fluid.layers.yolov3_loss( - x=out, - gt_box=inputs['gt_bbox'], - gt_label=inputs['gt_class'], - gt_score=inputs['gt_score'], - anchors=anchors, - anchor_mask=anchor_masks[i], - class_num=self.num_classes, - ignore_thresh=self.ignore_thresh, - downsample_ratio=self.downsample // 2**i, - use_label_smooth=self.label_smooth, - name='yolo_loss_' + str(i)) - loss = fluid.layers.reduce_mean(loss) - yolo_losses.append(loss) - return {'loss': sum(yolo_losses)} + self.yolo_outputs.append(yolo_output) + + def parse_anchor(self, anchors, anchor_masks): + anchor_num = len(self.anchors) + mask_anchors = [] + for i in range(len(self.anchor_masks)): + mask_anchor = [] + for m in self.anchor_masks[i]: + assert m < anchor_num, "anchor mask index overflow" + mask_anchor.extend(self.anchors[2 * m:2 * m + 2]) + mask_anchors.append(mask_anchor) + + return mask_anchors + + def forward(self, feats): + assert len(feats) == len(self.mask_anchors) + yolo_outputs = [] + for i, feat in enumerate(feats): + yolo_output = self.yolo_outputs[i](feat) + yolo_outputs.append(yolo_output) + return yolo_outputs + + def loss(self, inputs, head_outputs): + return self.loss(inputs, head_outputs, anchors, anchor_masks) diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py index d3da712a3..029309504 100644 --- a/ppdet/modeling/layers.py +++ b/ppdet/modeling/layers.py @@ -104,53 +104,6 @@ class AnchorTargetGeneratorRPN(object): return pred_cls_logits, pred_bbox_pred, tgt_labels, tgt_bboxes, bbox_inside_weights -@register -@serializable -class AnchorGeneratorYOLO(object): - def __init__(self, - anchors=[ - 10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, - 156, 198, 373, 326 - ], - anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]]): - super(AnchorGeneratorYOLO, self).__init__() - self.anchors = anchors - self.anchor_masks = anchor_masks - - def __call__(self): - anchor_num = len(self.anchors) - mask_anchors = [] - for i in range(len(self.anchor_masks)): - mask_anchor = [] - for m in self.anchor_masks[i]: - assert m < anchor_num, "anchor mask index overflow" - mask_anchor.extend(self.anchors[2 * m:2 * m + 2]) - mask_anchors.append(mask_anchor) - return self.anchors, self.anchor_masks, mask_anchors - - -@register -@serializable -class AnchorTargetGeneratorYOLO(object): - def __init__(self, - ignore_thresh=0.7, - downsample_ratio=32, - label_smooth=True): - super(AnchorTargetGeneratorYOLO, self).__init__() - self.ignore_thresh = ignore_thresh - self.downsample_ratio = downsample_ratio - self.label_smooth = label_smooth - - def __call__(self, ): - # TODO: split yolov3_loss into here - outs = { - 'ignore_thresh': self.ignore_thresh, - 'downsample_ratio': self.downsample_ratio, - 'label_smooth': self.label_smooth - } - return outs - - @register @serializable class ProposalGenerator(object): diff --git a/ppdet/modeling/loss/__init__.py b/ppdet/modeling/loss/__init__.py new file mode 100644 index 000000000..f18cadcc8 --- /dev/null +++ b/ppdet/modeling/loss/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import yolo_loss + +from .yolo_loss import * diff --git a/ppdet/modeling/loss/yolo_loss.py b/ppdet/modeling/loss/yolo_loss.py new file mode 100644 index 000000000..8fcc121aa --- /dev/null +++ b/ppdet/modeling/loss/yolo_loss.py @@ -0,0 +1,62 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.fluid.regularizer import L2Decay +from ppdet.core.workspace import register +from ..backbone.darknet import ConvBNLayer + + +@register +class YOLOv3Loss(nn.Layer): + __shared__ = ['num_classes'] + + def __init__(self, + num_classes=80, + ignore_thresh=0.7, + label_smooth=False, + downsample=32, + use_fine_grained_loss=False): + super(YOLOv3Loss, self).__init__() + self.ignore_thresh = ignore_thresh + self.label_smooth = label_smooth + self.downsample = downsample + self.use_fine_grained_loss = use_fine_grained_loss + + def forward(self, inputs, head_outputs, anchors, anchor_masks): + if self.use_fine_grained_loss: + raise NotImplementedError( + "fine grained loss not implement currently") + + yolo_losses = [] + for i, out in enumerate(head_outputs): + loss = fluid.layers.yolov3_loss( + x=out, + gt_box=inputs['gt_bbox'], + gt_label=inputs['gt_class'], + gt_score=inputs['gt_score'], + anchors=anchors, + anchor_mask=anchor_masks[i], + class_num=self.num_classes, + ignore_thresh=self.ignore_thresh, + downsample_ratio=self.downsample // 2**i, + use_label_smooth=self.label_smooth, + name='yolo_loss_' + str(i)) + loss = paddle.mean(loss) + yolo_losses.append(loss) + return {'loss': sum(yolo_losses)} diff --git a/ppdet/modeling/neck/__init__.py b/ppdet/modeling/neck/__init__.py index 4991079b2..0b61c3292 100644 --- a/ppdet/modeling/neck/__init__.py +++ b/ppdet/modeling/neck/__init__.py @@ -1,3 +1,19 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from . import fpn +from . import yolo_fpn from .fpn import * +from .yolo_fpn import * diff --git a/ppdet/modeling/neck/yolo_fpn.py b/ppdet/modeling/neck/yolo_fpn.py new file mode 100644 index 000000000..a7cfa349f --- /dev/null +++ b/ppdet/modeling/neck/yolo_fpn.py @@ -0,0 +1,108 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from ppdet.core.workspace import register, serializable +from ..backbone.darknet import ConvBNLayer + + +class YoloDetBlock(nn.Layer): + def __init__(self, ch_in, channel, name): + super(YoloDetBlock, self).__init__() + self.ch_in = ch_in + self.channel = channel + assert channel % 2 == 0, \ + "channel {} cannot be divided by 2".format(channel) + conv_def = [ + ['conv0', ch_in, channel, 1, '.0.0'], + ['conv1', channel, channel * 2, 3, '.0.1'], + ['conv2', channel * 2, channel, 1, '.1.0'], + ['conv3', channel, channel * 2, 3, '.1.1'], + ['route', channel * 2, channel, 1, '.2'], + ] + + self.conv_module = nn.Sequential() + for idx, (conv_name, ch_in, ch_out, filter_size, + post_name) in enumerate(conv_def): + self.conv_module.add_sublayer( + conv_name, + ConvBNLayer( + ch_in=ch_in, + ch_out=ch_out, + filter_size=filter_size, + padding=(filter_size - 1) // 2, + name=name + post_name)) + + self.tip = ConvBNLayer( + ch_in=channel, + ch_out=channel * 2, + filter_size=3, + padding=1, + name=name + '.tip') + + def forward(self, inputs): + route = self.conv_module(inputs) + tip = self.tip(route) + return route, tip + + +@register +@serializable +class YOLOv3FPN(nn.Layer): + def __init__(self, feat_channels=[1024, 768, 384]): + super(YOLOv3FPN, self).__init__() + assert len(feat_channels) > 0, "feat_channels length should > 0" + self.feat_channels = feat_channels + self.num_blocks = len(feat_channels) + self.yolo_blocks = [] + self.routes = [] + for i in range(self.num_blocks): + name = 'yolo_block.{}'.format(i) + yolo_block = self.add_sublayer( + name, + YoloDetBlock( + feat_channels[i], channel=512 // (2**i), name=name)) + self.yolo_blocks.append(yolo_block) + + if i < self.num_blocks - 1: + name = 'yolo_transition.{}'.format(i) + route = self.add_sublayer( + name, + ConvBNLayer( + ch_in=512 // (2**i), + ch_out=256 // (2**i), + filter_size=1, + stride=1, + padding=0, + name=name)) + self.routes.append(route) + + def forward(self, blocks): + assert len(blocks) == self.num_blocks + blocks = blocks[::-1] + yolo_feats = [] + for i, block in enumerate(blocks): + if i > 0: + block = paddle.concat([route, block], axis=1) + route, tip = self.yolo_blocks[i](block) + yolo_feats.append(tip) + + if i < self.num_blocks - 1: + route = self.routes[i](route) + route = F.interpolate(route, scale_factor=2.) + + return yolo_feats -- GitLab