Unverified commit b8fc8e6c, author: littletomatodonkey, committer: GitHub

Add iou loss (#126)

* Add GIOU loss.
* Add DIOU/CIOU loss.
* Change loss box type in bbox_head.py.
* Remove fuse_elewise_add_act_ops flag in train.py (an error may occur in elem_div when the flag is set).
Parent 841870f1
......@@ -94,6 +94,7 @@ The goal of PaddleDetection is to provide industry and academia with rich, easy-to-use ...
 ### 12/2019
 - Add Res2Net model.
 - Add HRNet model.
+- Add GIOU loss and DIOU loss.
 ### 21/11/2019
......
......@@ -102,6 +102,7 @@ Advanced Features:
 #### 12/2019
 - Add Res2Net model.
 - Add HRNet model.
+- Add GIOU loss and DIOU loss.
 #### 21/11/2019
 - Add CascadeClsAware RCNN model.
......
# Improvements of IOU loss
## Introduction
- Generalized Intersection over Union: A Metric and A Loss for Bounding Box Regression: [https://arxiv.org/abs/1902.09630](https://arxiv.org/abs/1902.09630)
```
@article{DBLP:journals/corr/abs-1902-09630,
author = {Seyed Hamid Rezatofighi and
Nathan Tsoi and
JunYoung Gwak and
Amir Sadeghian and
Ian D. Reid and
Silvio Savarese},
title = {Generalized Intersection over Union: {A} Metric and {A} Loss for Bounding
Box Regression},
journal = {CoRR},
volume = {abs/1902.09630},
year = {2019},
url = {http://arxiv.org/abs/1902.09630},
archivePrefix = {arXiv},
eprint = {1902.09630},
timestamp = {Tue, 21 May 2019 18:03:36 +0200},
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1902-09630},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
- Distance-IoU Loss: Faster and Better Learning for Bounding Box Regression: [https://arxiv.org/abs/1911.08287](https://arxiv.org/abs/1911.08287)
```
@article{Zheng2019DistanceIoULF,
title={Distance-IoU Loss: Faster and Better Learning for Bounding Box Regression},
author={Zhaohui Zheng and Ping Wang and Wei Liu and Jinze Li and Rongguang Ye and Dongwei Ren},
journal={ArXiv},
year={2019},
volume={abs/1911.08287}
}
```
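As a rough illustration of how the losses cited above differ, the following NumPy sketch evaluates the IoU, GIoU, DIoU and CIoU losses for a single pair of corner-format boxes, following the definitions in the two papers. The function name `iou_losses` and the `eps` default are illustrative assumptions, not part of PaddleDetection, and the Paddle implementation added in this commit may differ in details.

```python
import numpy as np


def iou_losses(pred, gt, eps=1e-10):
    """Return (iou, giou, diou, ciou) losses for two (x1, y1, x2, y2) boxes."""
    x1, y1, x2, y2 = [float(v) for v in pred]
    x1g, y1g, x2g, y2g = [float(v) for v in gt]

    # Intersection rectangle; its area is zero when the boxes do not overlap.
    iw = max(min(x2, x2g) - max(x1, x1g), 0.0)
    ih = max(min(y2, y2g) - max(y1, y1g), 0.0)
    inter = iw * ih
    union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - inter + eps
    iou = inter / union

    # Smallest enclosing box, used by both GIoU and DIoU.
    cw = max(x2, x2g) - min(x1, x1g)
    ch = max(y2, y2g) - min(y1, y1g)
    area_c = cw * ch + eps
    giou = iou - (area_c - union) / area_c

    # DIoU penalty: squared centre distance over squared enclosing diagonal.
    rho2 = ((x1 + x2) / 2 - (x1g + x2g) / 2) ** 2 \
        + ((y1 + y2) / 2 - (y1g + y2g) / 2) ** 2
    d = rho2 / (cw ** 2 + ch ** 2 + eps)

    # CIoU penalty: aspect-ratio consistency term v with trade-off weight alpha.
    v = 4.0 / np.pi ** 2 * (np.arctan((x2g - x1g) / (y2g - y1g)) -
                            np.arctan((x2 - x1) / (y2 - y1))) ** 2
    alpha = v / (1.0 - iou + v + eps)

    return 1 - iou, 1 - giou, 1 - iou + d, 1 - iou + d + alpha * v
```

For two disjoint boxes the plain IoU loss saturates at 1, while the GIoU and DIoU losses keep growing with the gap between the boxes, which is the motivation both papers give for the extra terms.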
## Model Zoo
| Backbone | Type | Loss Type | Loss Weight | Image/gpu | Lr schd | Inf time (fps) | Box AP | Mask AP | Download |
| :---------------------- | :------------- | :---: | :---: | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: |
| ResNet50-vd-FPN | Faster | GIOU | 10 | 2 | 1x | 22.94 | 39.4 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_giou_loss_1x.tar) |
| ResNet50-vd-FPN | Faster | DIOU | 12 | 2 | 1x | 22.94 | 39.2 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_diou_loss_1x.tar) |
| ResNet50-vd-FPN | Faster | CIOU | 12 | 2 | 1x | 22.95 | 39.5 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r50_vd_fpn_ciou_loss_1x.tar) |
architecture: FasterRCNN
max_iters: 90000
snapshot_iter: 10000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar
weights: output/faster_rcnn_r50_vd_fpn_diou_loss_1x/model_final
metric: COCO
num_classes: 81

FasterRCNN:
  backbone: ResNet
  fpn: FPN
  rpn_head: FPNRPNHead
  roi_extractor: FPNRoIAlign
  bbox_head: BBoxHead
  bbox_assigner: BBoxAssigner

ResNet:
  depth: 50
  feature_maps: [2, 3, 4, 5]
  freeze_at: 2
  norm_type: bn
  variant: d

FPN:
  max_level: 6
  min_level: 2
  num_chan: 256
  spatial_scale: [0.03125, 0.0625, 0.125, 0.25]

FPNRPNHead:
  anchor_generator:
    anchor_sizes: [32, 64, 128, 256, 512]
    aspect_ratios: [0.5, 1.0, 2.0]
    stride: [16.0, 16.0]
    variance: [1.0, 1.0, 1.0, 1.0]
  anchor_start_size: 32
  max_level: 6
  min_level: 2
  num_chan: 256
  rpn_target_assign:
    rpn_batch_size_per_im: 256
    rpn_fg_fraction: 0.5
    rpn_negative_overlap: 0.3
    rpn_positive_overlap: 0.7
    rpn_straddle_thresh: 0.0
  train_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    post_nms_top_n: 2000
    pre_nms_top_n: 2000
  test_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    post_nms_top_n: 1000
    pre_nms_top_n: 1000

FPNRoIAlign:
  canconical_level: 4
  canonical_size: 224
  max_level: 5
  min_level: 2
  box_resolution: 7
  sampling_ratio: 2

BBoxAssigner:
  batch_size_per_im: 512
  bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
  bg_thresh_hi: 0.5
  bg_thresh_lo: 0.0
  fg_fraction: 0.25
  fg_thresh: 0.5

BBoxHead:
  head: TwoFCHead
  nms:
    keep_top_k: 100
    nms_threshold: 0.5
    score_threshold: 0.05
  bbox_loss: DiouLoss

DiouLoss:
  loss_weight: 10.0
  is_cls_agnostic: false
  use_complete_iou_loss: true

TwoFCHead:
  mlp_dim: 1024

LearningRate:
  base_lr: 0.02
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [60000, 80000]
  - !LinearWarmup
    start_factor: 0.1
    steps: 1000

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0001
    type: L2

_READER_: '../faster_fpn_reader.yml'
TrainReader:
  batch_size: 2
architecture: FasterRCNN
max_iters: 90000
snapshot_iter: 10000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar
weights: output/faster_rcnn_r50_vd_fpn_diou_loss_1x/model_final
metric: COCO
num_classes: 81

FasterRCNN:
  backbone: ResNet
  fpn: FPN
  rpn_head: FPNRPNHead
  roi_extractor: FPNRoIAlign
  bbox_head: BBoxHead
  bbox_assigner: BBoxAssigner

ResNet:
  depth: 50
  feature_maps: [2, 3, 4, 5]
  freeze_at: 2
  norm_type: bn
  variant: d

FPN:
  max_level: 6
  min_level: 2
  num_chan: 256
  spatial_scale: [0.03125, 0.0625, 0.125, 0.25]

FPNRPNHead:
  anchor_generator:
    anchor_sizes: [32, 64, 128, 256, 512]
    aspect_ratios: [0.5, 1.0, 2.0]
    stride: [16.0, 16.0]
    variance: [1.0, 1.0, 1.0, 1.0]
  anchor_start_size: 32
  max_level: 6
  min_level: 2
  num_chan: 256
  rpn_target_assign:
    rpn_batch_size_per_im: 256
    rpn_fg_fraction: 0.5
    rpn_negative_overlap: 0.3
    rpn_positive_overlap: 0.7
    rpn_straddle_thresh: 0.0
  train_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    post_nms_top_n: 2000
    pre_nms_top_n: 2000
  test_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    post_nms_top_n: 1000
    pre_nms_top_n: 1000

FPNRoIAlign:
  canconical_level: 4
  canonical_size: 224
  max_level: 5
  min_level: 2
  box_resolution: 7
  sampling_ratio: 2

BBoxAssigner:
  batch_size_per_im: 512
  bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
  bg_thresh_hi: 0.5
  bg_thresh_lo: 0.0
  fg_fraction: 0.25
  fg_thresh: 0.5

BBoxHead:
  head: TwoFCHead
  nms:
    keep_top_k: 100
    nms_threshold: 0.5
    score_threshold: 0.05
  bbox_loss: DiouLoss

DiouLoss:
  loss_weight: 12.0
  is_cls_agnostic: false
  use_complete_iou_loss: false

TwoFCHead:
  mlp_dim: 1024

LearningRate:
  base_lr: 0.02
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [60000, 80000]
  - !LinearWarmup
    start_factor: 0.1
    steps: 1000

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0001
    type: L2

_READER_: '../faster_fpn_reader.yml'
TrainReader:
  batch_size: 2
architecture: FasterRCNN
max_iters: 90000
snapshot_iter: 10000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar
weights: output/faster_rcnn_r50_vd_fpn_giou_loss_1x/model_final
metric: COCO
num_classes: 81

FasterRCNN:
  backbone: ResNet
  fpn: FPN
  rpn_head: FPNRPNHead
  roi_extractor: FPNRoIAlign
  bbox_head: BBoxHead
  bbox_assigner: BBoxAssigner

ResNet:
  depth: 50
  feature_maps: [2, 3, 4, 5]
  freeze_at: 2
  norm_type: bn
  variant: d

FPN:
  max_level: 6
  min_level: 2
  num_chan: 256
  spatial_scale: [0.03125, 0.0625, 0.125, 0.25]

FPNRPNHead:
  anchor_generator:
    anchor_sizes: [32, 64, 128, 256, 512]
    aspect_ratios: [0.5, 1.0, 2.0]
    stride: [16.0, 16.0]
    variance: [1.0, 1.0, 1.0, 1.0]
  anchor_start_size: 32
  max_level: 6
  min_level: 2
  num_chan: 256
  rpn_target_assign:
    rpn_batch_size_per_im: 256
    rpn_fg_fraction: 0.5
    rpn_negative_overlap: 0.3
    rpn_positive_overlap: 0.7
    rpn_straddle_thresh: 0.0
  train_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    post_nms_top_n: 2000
    pre_nms_top_n: 2000
  test_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    post_nms_top_n: 1000
    pre_nms_top_n: 1000

FPNRoIAlign:
  canconical_level: 4
  canonical_size: 224
  max_level: 5
  min_level: 2
  box_resolution: 7
  sampling_ratio: 2

BBoxAssigner:
  batch_size_per_im: 512
  bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
  bg_thresh_hi: 0.5
  bg_thresh_lo: 0.0
  fg_fraction: 0.25
  fg_thresh: 0.5

BBoxHead:
  head: TwoFCHead
  nms:
    keep_top_k: 100
    nms_threshold: 0.5
    score_threshold: 0.05
  bbox_loss: GiouLoss

GiouLoss:
  loss_weight: 10.0
  is_cls_agnostic: false

TwoFCHead:
  mlp_dim: 1024

LearningRate:
  base_lr: 0.02
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [60000, 80000]
  - !LinearWarmup
    start_factor: 0.1
    steps: 1000

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0001
    type: L2

_READER_: '../faster_fpn_reader.yml'
TrainReader:
  batch_size: 2
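The `LearningRate` section of the configs above combines a piecewise decay with a linear warmup. The sketch below shows one plausible reading of those numbers in plain Python; the function name `learning_rate` is hypothetical, and the exact boundary handling (for example whether the decay applies at or after a milestone) is an assumption rather than the verified PaddleDetection behaviour.

```python
def learning_rate(it, base_lr=0.02, milestones=(60000, 80000), gamma=0.1,
                  warmup_steps=1000, start_factor=0.1):
    """Learning rate at iteration `it` (assumed schedule semantics)."""
    # Piecewise decay: multiply by gamma once for every milestone reached.
    lr = base_lr * gamma ** sum(it >= m for m in milestones)
    # Linear warmup: ramp the factor from start_factor to 1 over warmup_steps.
    if it < warmup_steps:
        alpha = float(it) / warmup_steps
        lr = base_lr * (start_factor * (1 - alpha) + alpha)
    return lr


if __name__ == "__main__":
    for it in (0, 500, 1000, 60000, 80000):
        print(it, learning_rate(it))
```

Under these assumptions training starts at a tenth of `base_lr`, reaches 0.02 after 1000 iterations, and drops by a factor of ten at iterations 60000 and 80000 of the 90000-iteration schedule.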
......@@ -93,11 +93,14 @@ The backbone models pretrained on ImageNet are available. All backbone models ar
 ### HRNet
-* See more details in [HRNet model zoo](../configs/hrnet/README.md)
+* See more details in [HRNet model zoo](../configs/hrnet/README.md).
 ### Res2Net
-* See more details in [Res2Net model zoo](../configs/res2net/README.md)
+* See more details in [Res2Net model zoo](../configs/res2net/README.md).
+### IOU loss
+* GIOU loss and DIOU loss are included now. See more details in [IOU loss model zoo](../configs/iou_loss/README.md).
 ### Group Normalization
......
......@@ -54,6 +54,7 @@ Paddle provides backbone models pretrained on ImageNet. All pretrained models
| ResNet101-vd-FPN | Faster | 1 | 1x | 17.011 | 40.5 | - | [download](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_vd_fpn_1x.tar) |
| ResNet101-vd-FPN | Faster | 1 | 2x | 16.934 | 40.8 | - | [download](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_r101_vd_fpn_2x.tar) |
| ResNet101-vd-FPN | Mask | 1 | 1x | 13.105 | 41.4 | 36.8 | [download](https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_r101_vd_fpn_1x.tar) |
| CBResNet101-vd-FPN | Faster | 2 | 1x | - | 42.7 | - | [download](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_cbr101_vd_dual_fpn_1x.tar) |
| ResNeXt101-vd-FPN | Faster | 1 | 1x | 8.815 | 42.2 | - | [download](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_x101_vd_64x4d_fpn_1x.tar) |
| ResNeXt101-vd-FPN | Faster | 1 | 2x | 8.809 | 41.7 | - | [download](https://paddlemodels.bj.bcebos.com/object_detection/faster_rcnn_x101_vd_64x4d_fpn_2x.tar) |
| ResNeXt101-vd-FPN | Mask | 1 | 1x | 7.689 | 42.9 | 37.9 | [download](https://paddlemodels.bj.bcebos.com/object_detection/mask_rcnn_x101_vd_64x4d_fpn_1x.tar) |
......@@ -95,6 +96,9 @@ Paddle provides backbone models pretrained on ImageNet. All pretrained models
 ### Res2Net
 * See details in [Res2Net model zoo](../configs/res2net/README.md)
+### IOU loss
+* The model zoo currently includes GIOU loss and DIOU loss; see details in [IOU loss model zoo](../configs/iou_loss/README.md).
 ### Group Normalization
 | Backbone | Type | Images per GPU | Lr schd | Box AP | Mask AP | Download |
......
......@@ -15,5 +15,11 @@
from __future__ import absolute_import
from . import yolo_loss
from . import smooth_l1_loss
from . import giou_loss
from . import diou_loss
from .yolo_loss import *
from .smooth_l1_loss import *
from .giou_loss import *
from .diou_loss import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from paddle import fluid
from ppdet.core.workspace import register, serializable
from .giou_loss import GiouLoss

__all__ = ['DiouLoss']


@register
@serializable
class DiouLoss(GiouLoss):
    """
    Distance-IoU Loss, see https://arxiv.org/abs/1911.08287
    Args:
        loss_weight (float): diou loss weight, default as 10 in faster-rcnn
        is_cls_agnostic (bool): flag of class-agnostic
        num_classes (int): class num
        use_complete_iou_loss (bool): whether to use complete iou loss
    """

    def __init__(self,
                 loss_weight=10.,
                 is_cls_agnostic=False,
                 num_classes=81,
                 use_complete_iou_loss=True):
        super(DiouLoss, self).__init__(
            loss_weight=loss_weight,
            is_cls_agnostic=is_cls_agnostic,
            num_classes=num_classes)
        self.use_complete_iou_loss = use_complete_iou_loss

    def __call__(self,
                 x,
                 y,
                 inside_weight=None,
                 outside_weight=None,
                 bbox_reg_weight=[0.1, 0.1, 0.2, 0.2]):
        eps = 1.e-10
        x1, y1, x2, y2 = self.bbox_transform(x, bbox_reg_weight)
        x1g, y1g, x2g, y2g = self.bbox_transform(y, bbox_reg_weight)

        cx = (x1 + x2) / 2
        cy = (y1 + y2) / 2
        w = x2 - x1
        h = y2 - y1

        cxg = (x1g + x2g) / 2
        cyg = (y1g + y2g) / 2
        wg = x2g - x1g
        hg = y2g - y1g

        x2 = fluid.layers.elementwise_max(x1, x2)
        y2 = fluid.layers.elementwise_max(y1, y2)

        # A and B
        xkis1 = fluid.layers.elementwise_max(x1, x1g)
        ykis1 = fluid.layers.elementwise_max(y1, y1g)
        xkis2 = fluid.layers.elementwise_min(x2, x2g)
        ykis2 = fluid.layers.elementwise_min(y2, y2g)

        # A or B
        xc1 = fluid.layers.elementwise_min(x1, x1g)
        yc1 = fluid.layers.elementwise_min(y1, y1g)
        xc2 = fluid.layers.elementwise_max(x2, x2g)
        yc2 = fluid.layers.elementwise_max(y2, y2g)

        intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
        intsctk = intsctk * fluid.layers.greater_than(
            xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1)
        unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (
            y2g - y1g) - intsctk + eps
        iouk = intsctk / unionk

        ciou_term = 0
        if self.use_complete_iou_loss:
            dist_intersection = (cx - cxg) * (cx - cxg) + (cy - cyg) * (
                cy - cyg)
            dist_union = (xc2 - xc1) * (xc2 - xc1) + (yc2 - yc1) * (yc2 - yc1)
            d = (dist_intersection + eps) / (dist_union + eps)

            ar_gt = wg / hg
            ar_pred = w / h
            arctan = fluid.layers.atan(ar_gt) - fluid.layers.atan(ar_pred)
            ar_loss = 4. / np.pi / np.pi * arctan * arctan
            alpha = ar_loss / (1 - iouk + ar_loss + eps)
            alpha.stop_gradient = True
            ciou_term = d + alpha * ar_loss

        iou_weights = 1
        if inside_weight is not None and outside_weight is not None:
            inside_weight = fluid.layers.reshape(inside_weight, shape=(-1, 4))
            outside_weight = fluid.layers.reshape(outside_weight, shape=(-1, 4))

            inside_weight = fluid.layers.reduce_mean(inside_weight, dim=1)
            outside_weight = fluid.layers.reduce_mean(outside_weight, dim=1)

            iou_weights = inside_weight * outside_weight

        class_weight = 2 if self.is_cls_agnostic else self.num_classes
        diou = fluid.layers.reduce_mean(
            (1 - iouk + ciou_term) * iou_weights) * class_weight

        return diou * self.loss_weight
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from paddle import fluid
from ppdet.core.workspace import register, serializable

__all__ = ['GiouLoss']


@register
@serializable
class GiouLoss(object):
    '''
    Generalized Intersection over Union, see https://arxiv.org/abs/1902.09630
    Args:
        loss_weight (float): giou loss weight, default as 10 in faster-rcnn
        is_cls_agnostic (bool): flag of class-agnostic
        num_classes (int): class num
    '''
    __shared__ = ['num_classes']

    def __init__(self, loss_weight=10., is_cls_agnostic=False, num_classes=81):
        super(GiouLoss, self).__init__()
        self.loss_weight = loss_weight
        self.is_cls_agnostic = is_cls_agnostic
        self.num_classes = num_classes

    # deltas: NxMx4
    def bbox_transform(self, deltas, weights):
        wx, wy, ww, wh = weights

        deltas = fluid.layers.reshape(deltas, shape=(0, -1, 4))

        dx = fluid.layers.slice(deltas, axes=[2], starts=[0], ends=[1]) * wx
        dy = fluid.layers.slice(deltas, axes=[2], starts=[1], ends=[2]) * wy
        dw = fluid.layers.slice(deltas, axes=[2], starts=[2], ends=[3]) * ww
        dh = fluid.layers.slice(deltas, axes=[2], starts=[3], ends=[4]) * wh

        dw = fluid.layers.clip(dw, -1.e10, np.log(1000. / 16))
        dh = fluid.layers.clip(dh, -1.e10, np.log(1000. / 16))

        pred_ctr_x = dx
        pred_ctr_y = dy
        pred_w = fluid.layers.exp(dw)
        pred_h = fluid.layers.exp(dh)

        x1 = pred_ctr_x - 0.5 * pred_w
        y1 = pred_ctr_y - 0.5 * pred_h
        x2 = pred_ctr_x + 0.5 * pred_w
        y2 = pred_ctr_y + 0.5 * pred_h

        x1 = fluid.layers.reshape(x1, shape=(-1, ))
        y1 = fluid.layers.reshape(y1, shape=(-1, ))
        x2 = fluid.layers.reshape(x2, shape=(-1, ))
        y2 = fluid.layers.reshape(y2, shape=(-1, ))

        return x1, y1, x2, y2

    def __call__(self,
                 x,
                 y,
                 inside_weight=None,
                 outside_weight=None,
                 bbox_reg_weight=[0.1, 0.1, 0.2, 0.2]):
        eps = 1.e-10
        x1, y1, x2, y2 = self.bbox_transform(x, bbox_reg_weight)
        x1g, y1g, x2g, y2g = self.bbox_transform(y, bbox_reg_weight)

        x2 = fluid.layers.elementwise_max(x1, x2)
        y2 = fluid.layers.elementwise_max(y1, y2)

        xkis1 = fluid.layers.elementwise_max(x1, x1g)
        ykis1 = fluid.layers.elementwise_max(y1, y1g)
        xkis2 = fluid.layers.elementwise_min(x2, x2g)
        ykis2 = fluid.layers.elementwise_min(y2, y2g)

        xc1 = fluid.layers.elementwise_min(x1, x1g)
        yc1 = fluid.layers.elementwise_min(y1, y1g)
        xc2 = fluid.layers.elementwise_max(x2, x2g)
        yc2 = fluid.layers.elementwise_max(y2, y2g)

        intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
        intsctk = intsctk * fluid.layers.greater_than(
            xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1)
        unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (
            y2g - y1g) - intsctk + eps
        iouk = intsctk / unionk

        area_c = (xc2 - xc1) * (yc2 - yc1) + eps
        miouk = iouk - ((area_c - unionk) / area_c)

        iou_weights = 1
        if inside_weight is not None and outside_weight is not None:
            inside_weight = fluid.layers.reshape(inside_weight, shape=(-1, 4))
            outside_weight = fluid.layers.reshape(outside_weight, shape=(-1, 4))

            inside_weight = fluid.layers.reduce_mean(inside_weight, dim=1)
            outside_weight = fluid.layers.reduce_mean(outside_weight, dim=1)

            iou_weights = inside_weight * outside_weight

        class_weight = 2 if self.is_cls_agnostic else self.num_classes
        iouk = fluid.layers.reduce_mean((1 - iouk) * iou_weights) * class_weight
        miouk = fluid.layers.reduce_mean(
            (1 - miouk) * iou_weights) * class_weight

        return miouk * self.loss_weight
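To make the `bbox_transform` step above easier to follow, here is a minimal NumPy sketch of the same arithmetic: scaled deltas are treated as a centre and a log-size, and converted to corner coordinates. The name `decode_deltas` is hypothetical and the code is an illustration outside of Paddle, not the library's API.

```python
import numpy as np


def decode_deltas(deltas, weights=(0.1, 0.1, 0.2, 0.2)):
    """Convert (dx, dy, dw, dh) deltas into flat x1, y1, x2, y2 arrays."""
    deltas = np.asarray(deltas, dtype=np.float64).reshape(-1, 4)
    wx, wy, ww, wh = weights

    # Scale each component by its regression weight.
    dx = deltas[:, 0] * wx
    dy = deltas[:, 1] * wy
    # Clip log-sizes from above so that exp() cannot overflow.
    dw = np.clip(deltas[:, 2] * ww, -1.e10, np.log(1000. / 16))
    dh = np.clip(deltas[:, 3] * wh, -1.e10, np.log(1000. / 16))

    w = np.exp(dw)
    h = np.exp(dh)

    # Centre and size to corner coordinates.
    return dx - 0.5 * w, dy - 0.5 * h, dx + 0.5 * w, dy + 0.5 * h
```

Because both the prediction and the target go through the same transform, the loss compares two boxes expressed relative to a unit reference box rather than boxes in image coordinates.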
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import fluid
from ppdet.core.workspace import register, serializable

__all__ = ['SmoothL1Loss']


@register
@serializable
class SmoothL1Loss(object):
    '''
    Smooth L1 loss
    Args:
        sigma (float): hyper param in smooth l1 loss
    '''

    def __init__(self, sigma=1.0):
        super(SmoothL1Loss, self).__init__()
        self.sigma = sigma

    def __call__(self, x, y, inside_weight=None, outside_weight=None):
        return fluid.layers.smooth_l1(
            x,
            y,
            inside_weight=inside_weight,
            outside_weight=outside_weight,
            sigma=self.sigma)
......@@ -26,6 +26,7 @@ from paddle.fluid.initializer import MSRA
 from ppdet.modeling.ops import MultiClassNMS
 from ppdet.modeling.ops import ConvNorm
+from ppdet.modeling.losses import SmoothL1Loss
 from ppdet.core.workspace import register, serializable
 from ppdet.experimental import mixed_precision_global_state
......@@ -166,23 +167,27 @@ class BBoxHead(object):
         nms (object): `MultiClassNMS` instance
         num_classes: number of output classes
     """
-    __inject__ = ['head', 'box_coder', 'nms']
+    __inject__ = ['head', 'box_coder', 'nms', 'bbox_loss']
     __shared__ = ['num_classes']
     def __init__(self,
                  head,
                  box_coder=BoxCoder().__dict__,
                  nms=MultiClassNMS().__dict__,
+                 bbox_loss=SmoothL1Loss().__dict__,
                  num_classes=81):
         super(BBoxHead, self).__init__()
         self.head = head
         self.num_classes = num_classes
         self.box_coder = box_coder
         self.nms = nms
+        self.bbox_loss = bbox_loss
         if isinstance(box_coder, dict):
             self.box_coder = BoxCoder(**box_coder)
         if isinstance(nms, dict):
             self.nms = MultiClassNMS(**nms)
+        if isinstance(bbox_loss, dict):
+            self.bbox_loss = SmoothL1Loss(**bbox_loss)
         self.head_feat = None
     def get_head_feat(self, input=None):
......@@ -271,12 +276,11 @@ class BBoxHead(object):
         loss_cls = fluid.layers.softmax_with_cross_entropy(
             logits=cls_score, label=labels_int64, numeric_stable_mode=True)
         loss_cls = fluid.layers.reduce_mean(loss_cls)
-        loss_bbox = fluid.layers.smooth_l1(
+        loss_bbox = self.bbox_loss(
             x=bbox_pred,
             y=bbox_targets,
             inside_weight=bbox_inside_weights,
-            outside_weight=bbox_outside_weights,
-            sigma=1.0)
+            outside_weight=bbox_outside_weights)
         loss_bbox = fluid.layers.reduce_mean(loss_bbox)
         return {'loss_cls': loss_cls, 'loss_bbox': loss_bbox}
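The change to `BBoxHead` above replaces a hard-coded `smooth_l1` call with an injected loss object that is either passed as an instance or built from a dict of constructor arguments. The pattern can be sketched without Paddle as follows; `ToyLoss` and `ToyHead` are hypothetical stand-ins invented for this illustration and do not exist in ppdet.

```python
class ToyLoss(object):
    """Hypothetical stand-in for SmoothL1Loss: scaled mean absolute error."""

    def __init__(self, scale=1.0):
        self.scale = scale

    def __call__(self, x, y, inside_weight=None, outside_weight=None):
        return self.scale * sum(abs(a - b) for a, b in zip(x, y)) / len(x)


class ToyHead(object):
    """Hypothetical stand-in for BBoxHead, showing only the loss wiring."""

    def __init__(self, bbox_loss=ToyLoss().__dict__):
        self.bbox_loss = bbox_loss
        # A plain dict is treated as constructor arguments for the default
        # loss; any other callable is used as-is.
        if isinstance(bbox_loss, dict):
            self.bbox_loss = ToyLoss(**bbox_loss)

    def get_loss(self, pred, target):
        return self.bbox_loss(
            x=pred, y=target, inside_weight=None, outside_weight=None)


if __name__ == "__main__":
    print(ToyHead().get_loss([1, 2], [1, 4]))
    print(ToyHead(bbox_loss={'scale': 2.0}).get_loss([1, 2], [1, 4]))
```

Any object with the same call signature can then be swapped in, which is how a config naming `GiouLoss` or `DiouLoss` as `bbox_loss` can change the regression loss without touching the head's code.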
......
......@@ -154,7 +154,6 @@ def main():
     # compile program for multi-devices
     build_strategy = fluid.BuildStrategy()
     build_strategy.fuse_all_optimizer_ops = False
-    build_strategy.fuse_elewise_add_act_ops = True
     # only enable sync_bn in multi GPU devices
     sync_bn = getattr(model.backbone, 'norm_type', None) == 'sync_bn'
     build_strategy.sync_batch_norm = sync_bn and devices_num > 1 \
......@@ -248,6 +247,7 @@ def main():
it, np.mean(outs[-1]), logs, time_cost, eta)
logger.info(strs)
if (it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1) \
and (not FLAGS.dist or trainer_id == 0):
save_name = str(it) if it != cfg.max_iters - 1 else "model_final"
......