未验证 提交 2cda4b28 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

add libra rcnn (#198)

add libra rcnn, including
  * libra loss
  * libra fpn(bfp)
  * libra sampling
上级 4b053712
......@@ -4,6 +4,7 @@
# Byte-compiled / optimized / DLL files
__pycache__/
.ipynb_checkpoints/
*.py[cod]
# C extensions
......
# Libra R-CNN: Towards Balanced Learning for Object Detection
## Introduction
- Libra R-CNN: Towards Balanced Learning for Object Detection
: [https://arxiv.org/abs/1904.02701](https://arxiv.org/abs/1904.02701)
```
@inproceedings{pang2019libra,
title={Libra R-CNN: Towards Balanced Learning for Object Detection},
author={Pang, Jiangmiao and Chen, Kai and Shi, Jianping and Feng, Huajun and Ouyang, Wanli and Dahua Lin},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
year={2019}
}
```
## Model Zoo
| Backbone | Type | Image/gpu | Lr schd | Inf time (fps) | Box AP | Mask AP | Download |
| :---------------------- | :-------------: | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: |
| ResNet50-vd-BFP | Faster | 2 | 1x | 18.247 | 40.5 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/libra_rcnn_r50_vd_fpn_1x.tar) |
| ResNet101-vd-BFP | Faster | 2 | 1x | 14.865 | 42.5 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/libra_rcnn_r101_vd_fpn_1x.tar) |
architecture: FasterRCNN
max_iters: 90000
snapshot_iter: 10000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar
weights: output/libra_rcnn_r101_vd_fpn_1x/model_final
metric: COCO
num_classes: 81
FasterRCNN:
backbone: ResNet
fpn: BFP
rpn_head: FPNRPNHead
roi_extractor: FPNRoIAlign
bbox_head: BBoxHead
bbox_assigner: LibraBBoxAssigner
ResNet:
depth: 101
feature_maps: [2, 3, 4, 5]
freeze_at: 2
norm_type: bn
variant: d
BFP:
base_neck:
max_level: 6
min_level: 2
num_chan: 256
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
refine_level: 2
refine_type: nonlocal
nonlocal_reduction: 1.0
FPNRPNHead:
anchor_generator:
anchor_sizes: [32, 64, 128, 256, 512]
aspect_ratios: [0.5, 1.0, 2.0]
stride: [16.0, 16.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_start_size: 32
max_level: 6
min_level: 2
num_chan: 256
rpn_target_assign:
rpn_batch_size_per_im: 256
rpn_fg_fraction: 0.5
rpn_negative_overlap: 0.3
rpn_positive_overlap: 0.7
rpn_straddle_thresh: 0.0
train_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 2000
pre_nms_top_n: 2000
test_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 1000
pre_nms_top_n: 1000
FPNRoIAlign:
canconical_level: 4
canonical_size: 224
max_level: 5
min_level: 2
box_resolution: 7
sampling_ratio: 2
LibraBBoxAssigner:
batch_size_per_im: 512
bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
bg_thresh_hi: 0.5
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
BBoxHead:
head: TwoFCHead
nms:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
bbox_loss: BalancedL1Loss
BalancedL1Loss:
alpha: 0.5
gamma: 1.5
beta: 1.0
loss_weight: 1.0
TwoFCHead:
mlp_dim: 1024
LearningRate:
base_lr: 0.02
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [60000, 80000]
- !LinearWarmup
start_factor: 0.1
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
_READER_: '../faster_fpn_reader.yml'
TrainReader:
batch_size: 2
architecture: FasterRCNN
max_iters: 90000
snapshot_iter: 10000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar
weights: output/libra_rcnn_r50_vd_fpn_1x/model_final
metric: COCO
num_classes: 81
FasterRCNN:
backbone: ResNet
fpn: BFP
rpn_head: FPNRPNHead
roi_extractor: FPNRoIAlign
bbox_head: BBoxHead
bbox_assigner: LibraBBoxAssigner
ResNet:
depth: 50
feature_maps: [2, 3, 4, 5]
freeze_at: 2
norm_type: bn
variant: d
BFP:
base_neck:
max_level: 6
min_level: 2
num_chan: 256
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
refine_level: 2
refine_type: nonlocal
nonlocal_reduction: 1.0
FPNRPNHead:
anchor_generator:
anchor_sizes: [32, 64, 128, 256, 512]
aspect_ratios: [0.5, 1.0, 2.0]
stride: [16.0, 16.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_start_size: 32
max_level: 6
min_level: 2
num_chan: 256
rpn_target_assign:
rpn_batch_size_per_im: 256
rpn_fg_fraction: 0.5
rpn_negative_overlap: 0.3
rpn_positive_overlap: 0.7
rpn_straddle_thresh: 0.0
train_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 2000
pre_nms_top_n: 2000
test_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 1000
pre_nms_top_n: 1000
FPNRoIAlign:
canconical_level: 4
canonical_size: 224
max_level: 5
min_level: 2
box_resolution: 7
sampling_ratio: 2
LibraBBoxAssigner:
batch_size_per_im: 512
bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
bg_thresh_hi: 0.5
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
BBoxHead:
head: TwoFCHead
nms:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
bbox_loss: BalancedL1Loss
BalancedL1Loss:
alpha: 0.5
gamma: 1.5
beta: 1.0
loss_weight: 1.0
TwoFCHead:
mlp_dim: 1024
LearningRate:
base_lr: 0.02
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [60000, 80000]
- !LinearWarmup
start_factor: 0.1
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
_READER_: '../faster_fpn_reader.yml'
TrainReader:
batch_size: 2
......@@ -27,6 +27,7 @@ from . import cb_resnet
from . import res2net
from . import hrnet
from . import hrfpn
from . import bfp
from .resnet import *
from .resnext import *
......@@ -40,4 +41,5 @@ from .faceboxnet import *
from .cb_resnet import *
from .res2net import *
from .hrnet import *
from .hrfpn import *
\ No newline at end of file
from .hrfpn import *
from .bfp import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from collections import OrderedDict
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Xavier
from paddle.fluid.regularizer import L2Decay
from ppdet.core.workspace import register
from .nonlocal_helper import add_space_nonlocal
from .fpn import FPN
__all__ = ['BFP']
@register
class BFP(object):
"""
Libra R-CNN, see https://arxiv.org/abs/1904.02701
Args:
base_neck (dict): basic neck before balanced feature pyramid (bfp)
refine_level (int): index of integration and refine level of bfp
refine_type (str): refine type, None, conv or nonlocal
nonlocal_reduction (float): channel reduction level if refine_type is nonlocal
with_bias (bool): whether the nonlocal module contains bias
with_scale (bool): whether to scale feature in nonlocal module or not
"""
__inject__ = ['base_neck']
def __init__(self,
base_neck=FPN().__dict__,
refine_level=2,
refine_type="nonlocal",
nonlocal_reduction=1,
with_bias=True,
with_scale=False):
if isinstance(base_neck, dict):
self.base_neck = FPN(**base_neck)
self.refine_level = refine_level
self.refine_type = refine_type
self.nonlocal_reduction = nonlocal_reduction
self.with_bias = with_bias
self.with_scale = with_scale
def get_output(self, body_dict):
# top-down order
res_dict, spatial_scale = self.base_neck.get_output(body_dict)
res_dict = self.get_output_bfp(res_dict)
return res_dict, spatial_scale
def get_output_bfp(self, body_dict):
body_name_list = list(body_dict.keys())
num_backbone_stages = len(body_name_list)
self.num_levels = len(body_dict)
# step 1: gather multi-level features by resize and average
feats = []
refine_level_name = body_name_list[self.refine_level]
for i in range(self.num_levels):
curr_fpn_name = body_name_list[i]
pool_stride = 2**(i - self.refine_level)
pool_size = [
body_dict[refine_level_name].shape[2],
body_dict[refine_level_name].shape[3]
]
if i > self.refine_level:
gathered = fluid.layers.pool2d(
input=body_dict[curr_fpn_name],
pool_type='max',
pool_size=pool_stride,
pool_stride=pool_stride,
ceil_mode=True, )
else:
gathered = self._resize_input_tensor(
body_dict[curr_fpn_name], body_dict[refine_level_name],
1.0 / pool_stride)
feats.append(gathered)
bsf = sum(feats) / len(feats)
# step 2: refine gathered features
if self.refine_type == "conv":
bsf = fluid.layers.conv2d(
bsf,
bsf.shape[1],
filter_size=3,
padding=1,
param_attr=ParamAttr(name="bsf_w"),
bias_attr=ParamAttr(name="bsf_b"),
name="bsf")
elif self.refine_type == "nonlocal":
dim_in = bsf.shape[1]
nonlocal_name = "nonlocal_bsf"
bsf = add_space_nonlocal(
bsf,
bsf.shape[1],
bsf.shape[1],
nonlocal_name,
int(bsf.shape[1] / self.nonlocal_reduction),
with_bias=self.with_bias,
with_scale=self.with_scale)
# step 3: scatter refined features to multi-levels by a residual path
fpn_dict = {}
fpn_name_list = []
for i in range(self.num_levels):
curr_fpn_name = body_name_list[i]
pool_stride = 2**(self.refine_level - i)
if i >= self.refine_level:
residual = self._resize_input_tensor(
bsf, body_dict[curr_fpn_name], 1.0 / pool_stride)
else:
residual = fluid.layers.pool2d(
input=bsf,
pool_type='max',
pool_size=pool_stride,
pool_stride=pool_stride,
ceil_mode=True, )
fpn_dict[curr_fpn_name] = residual + body_dict[curr_fpn_name]
fpn_name_list.append(curr_fpn_name)
res_dict = OrderedDict([(k, fpn_dict[k]) for k in fpn_name_list])
return res_dict
def _resize_input_tensor(self, body_input, ref_output, scale):
shape = fluid.layers.shape(ref_output)
shape_hw = fluid.layers.slice(shape, axes=[0], starts=[2], ends=[4])
out_shape_ = shape_hw
out_shape = fluid.layers.cast(out_shape_, dtype='int32')
out_shape.stop_gradient = True
body_output = fluid.layers.resize_nearest(
body_input, scale=scale, actual_shape=out_shape)
return body_output
......@@ -5,95 +5,72 @@ from __future__ import unicode_literals
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
from paddle.fluid.initializer import ConstantInitializer
nonlocal_params = {
"use_zero_init_conv" : False,
"conv_init_std" : 0.01,
"no_bias" : True,
"use_maxpool" : False,
"use_softmax" : True,
"use_bn" : False,
"use_scale" : True, # vital for the model prformance!!!
"use_affine" : False,
"bn_momentum" : 0.9,
"bn_epsilon" : 1.0000001e-5,
"bn_init_gamma" : 0.9,
"weight_decay_bn":1.e-4,
}
def space_nonlocal(input, dim_in, dim_out, prefix, dim_inner, max_pool_stride = 2):
cur = input
theta = fluid.layers.conv2d(input = cur, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr=ParamAttr(name = prefix + '_theta' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0,
scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_theta' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) \
if not nonlocal_params["no_bias"] else False, \
name = prefix + '_theta')
def space_nonlocal(input,
dim_in,
dim_out,
prefix,
dim_inner,
with_bias=False,
with_scale=True):
theta = fluid.layers.conv2d(
input=input,
num_filters=dim_inner,
filter_size=1,
stride=1,
padding=0,
param_attr=ParamAttr(name=prefix + '_theta_w'),
bias_attr=ParamAttr(
name=prefix + '_theta_b', initializer=ConstantInitializer(value=0.))
if with_bias else False)
theta_shape = theta.shape
theta_shape_op = fluid.layers.shape( theta )
theta_shape_op = fluid.layers.shape(theta)
theta_shape_op.stop_gradient = True
if nonlocal_params["use_maxpool"]:
max_pool = fluid.layers.pool2d(input = cur, \
pool_size = [max_pool_stride, max_pool_stride], \
pool_type = 'max', \
pool_stride = [max_pool_stride, max_pool_stride], \
pool_padding = [0, 0], \
name = prefix + '_pool')
else:
max_pool = cur
phi = fluid.layers.conv2d(input = max_pool, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_phi' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0,
scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_phi' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) \
if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_phi')
phi_shape = phi.shape
g = fluid.layers.conv2d(input = max_pool, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_g' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0, scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_g' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_g')
g_shape = g.shape
# we have to use explicit batch size (to support arbitrary spacetime size)
# e.g. (8, 1024, 4, 14, 14) => (8, 1024, 784)
theta = fluid.layers.reshape(theta, shape=(0, 0, -1) )
theta = fluid.layers.reshape(theta, shape=(0, 0, -1))
theta = fluid.layers.transpose(theta, [0, 2, 1])
phi = fluid.layers.conv2d(
input=input,
num_filters=dim_inner,
filter_size=1,
stride=1,
padding=0,
param_attr=ParamAttr(name=prefix + '_phi_w'),
bias_attr=ParamAttr(
name=prefix + '_phi_b', initializer=ConstantInitializer(value=0.))
if with_bias else False,
name=prefix + '_phi')
phi = fluid.layers.reshape(phi, [0, 0, -1])
theta_phi = fluid.layers.matmul(theta, phi, name = prefix + '_affinity')
theta_phi = fluid.layers.matmul(theta, phi)
g = fluid.layers.conv2d(
input=input,
num_filters=dim_inner,
filter_size=1,
stride=1,
padding=0,
param_attr=ParamAttr(name=prefix + '_g_w'),
bias_attr=ParamAttr(
name=prefix + '_g_b', initializer=ConstantInitializer(value=0.))
if with_bias else False,
name=prefix + '_g')
g = fluid.layers.reshape(g, [0, 0, -1])
if nonlocal_params["use_softmax"]:
if nonlocal_params["use_scale"]:
theta_phi_sc = fluid.layers.scale(theta_phi, scale = dim_inner**-.5)
else:
theta_phi_sc = theta_phi
p = fluid.layers.softmax(theta_phi_sc, name = prefix + '_affinity' + '_prob')
else:
# not clear about what is doing in xlw's code
p = None # not implemented
raise "Not implemented when not use softmax"
# scale
if with_scale:
theta_phi = fluid.layers.scale(theta_phi, scale=dim_inner**-.5)
p = fluid.layers.softmax(theta_phi)
# note g's axis[2] corresponds to p's axis[2]
# e.g. g(8, 1024, 784_2) * p(8, 784_1, 784_2) => (8, 1024, 784_1)
p = fluid.layers.transpose(p, [0, 2, 1])
t = fluid.layers.matmul(g, p, name = prefix + '_y')
t = fluid.layers.matmul(g, p)
# reshape back
# e.g. (8, 1024, 784) => (8, 1024, 4, 14, 14)
......@@ -104,56 +81,40 @@ def space_nonlocal(input, dim_in, dim_out, prefix, dim_inner, max_pool_stride =
t_re = fluid.layers.reshape(t, shape=[n, ch, h, w])
blob_out = t_re
blob_out = fluid.layers.conv2d(input = blob_out, num_filters = dim_out, \
filter_size = [1, 1], stride = [1, 1], padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_out' + "_w", \
initializer = fluid.initializer.Constant(value = 0.) \
if nonlocal_params["use_zero_init_conv"] \
else fluid.initializer.Normal(loc = 0.0,
scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_out' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) \
if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_out')
blob_out = fluid.layers.conv2d(
input=blob_out,
num_filters=dim_out,
filter_size=1,
stride=1,
padding=0,
param_attr=ParamAttr(
name=prefix + '_out_w', initializer=ConstantInitializer(value=0.0)),
bias_attr=ParamAttr(
name=prefix + '_out_b', initializer=ConstantInitializer(value=0.0))
if with_bias else False,
name=prefix + '_out')
blob_out_shape = blob_out.shape
if nonlocal_params["use_bn"]:
bn_name = prefix + "_bn"
blob_out = fluid.layers.batch_norm(blob_out, \
# is_test = test_mode, \
momentum = nonlocal_params["bn_momentum"], \
epsilon = nonlocal_params["bn_epsilon"], \
name = bn_name, \
param_attr = ParamAttr(name = bn_name + "_s", \
initializer = fluid.initializer.Constant(value = nonlocal_params["bn_init_gamma"]), \
regularizer = fluid.regularizer.L2Decay(nonlocal_params["weight_decay_bn"])), \
bias_attr = ParamAttr(name = bn_name + "_b", \
regularizer = fluid.regularizer.L2Decay(nonlocal_params["weight_decay_bn"])), \
moving_mean_name = bn_name + "_rm", \
moving_variance_name = bn_name + "_riv") # add bn
if nonlocal_params["use_affine"]:
affine_scale = fluid.layers.create_parameter(\
shape=[blob_out_shape[1]], dtype = blob_out.dtype, \
attr=ParamAttr(name=prefix + '_affine' + '_s'), \
default_initializer = fluid.initializer.Constant(value = 1.))
affine_bias = fluid.layers.create_parameter(\
shape=[blob_out_shape[1]], dtype = blob_out.dtype, \
attr=ParamAttr(name=prefix + '_affine' + '_b'), \
default_initializer = fluid.initializer.Constant(value = 0.))
blob_out = fluid.layers.affine_channel(blob_out, scale = affine_scale, \
bias = affine_bias, name = prefix + '_affine') # add affine
return blob_out
def add_space_nonlocal(input, dim_in, dim_out, prefix, dim_inner ):
def add_space_nonlocal(input,
dim_in,
dim_out,
prefix,
dim_inner,
with_bias=False,
with_scale=True):
'''
add_space_nonlocal:
Non-local Neural Networks: see https://arxiv.org/abs/1711.07971
'''
conv = space_nonlocal(input, dim_in, dim_out, prefix, dim_inner)
output = fluid.layers.elementwise_add(input, conv, name = prefix + '_sum')
conv = space_nonlocal(
input,
dim_in,
dim_out,
prefix,
dim_inner,
with_bias=with_bias,
with_scale=with_scale)
output = input + conv
return output
......@@ -19,9 +19,11 @@ from . import smooth_l1_loss
from . import giou_loss
from . import diou_loss
from . import iou_loss
from . import balanced_l1_loss
from .yolo_loss import *
from .smooth_l1_loss import *
from .giou_loss import *
from .diou_loss import *
from .iou_loss import *
from .balanced_l1_loss import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from paddle import fluid
from ppdet.core.workspace import register, serializable
__all__ = ['BalancedL1Loss']
@register
@serializable
class BalancedL1Loss(object):
"""
Balanced L1 Loss, see https://arxiv.org/abs/1904.02701
Args:
alpha (float): hyper parameter of BalancedL1Loss, see more details in the paper
gamma (float): hyper parameter of BalancedL1Loss, see more details in the paper
beta (float): hyper parameter of BalancedL1Loss, see more details in the paper
loss_weights (float): loss weight
"""
def __init__(self, alpha=0.5, gamma=1.5, beta=1.0, loss_weight=1.0):
super(BalancedL1Loss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.beta = beta
self.loss_weight = loss_weight
def __call__(
self,
x,
y,
inside_weight=None,
outside_weight=None, ):
alpha = self.alpha
gamma = self.gamma
beta = self.beta
loss_weight = self.loss_weight
diff = fluid.layers.abs(x - y)
b = np.e**(gamma / alpha) - 1
less_beta = diff < beta
ge_beta = diff >= beta
less_beta = fluid.layers.cast(x=less_beta, dtype='float32')
ge_beta = fluid.layers.cast(x=ge_beta, dtype='float32')
less_beta.stop_gradient = True
ge_beta.stop_gradient = True
loss_1 = less_beta * (
alpha / b * (b * diff + 1) * fluid.layers.log(b * diff / beta + 1) -
alpha * diff)
loss_2 = ge_beta * (gamma * diff + gamma / b - alpha * beta)
iou_weights = 1.0
if inside_weight is not None and outside_weight is not None:
iou_weights = inside_weight * outside_weight
loss = (loss_1 + loss_2) * iou_weights
loss = fluid.layers.reduce_sum(loss, dim=-1) * loss_weight
return loss
此差异已折叠。
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import numpy as np
import paddle.fluid as fluid
__all__ = ["bbox_overlaps", "box_to_delta"]
logger = logging.getLogger(__name__)
def bbox_overlaps(boxes_1, boxes_2):
'''
bbox_overlaps
boxes_1: x1, y, x2, y2
boxes_2: x1, y, x2, y2
'''
assert boxes_1.shape[1] == 4 and boxes_2.shape[1] == 4
num_1 = boxes_1.shape[0]
num_2 = boxes_2.shape[0]
x1_1 = boxes_1[:, 0:1]
y1_1 = boxes_1[:, 1:2]
x2_1 = boxes_1[:, 2:3]
y2_1 = boxes_1[:, 3:4]
area_1 = (x2_1 - x1_1 + 1) * (y2_1 - y1_1 + 1)
x1_2 = boxes_2[:, 0].transpose()
y1_2 = boxes_2[:, 1].transpose()
x2_2 = boxes_2[:, 2].transpose()
y2_2 = boxes_2[:, 3].transpose()
area_2 = (x2_2 - x1_2 + 1) * (y2_2 - y1_2 + 1)
xx1 = np.maximum(x1_1, x1_2)
yy1 = np.maximum(y1_1, y1_2)
xx2 = np.minimum(x2_1, x2_2)
yy2 = np.minimum(y2_1, y2_2)
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (area_1 + area_2 - inter)
return ovr
def box_to_delta(ex_boxes, gt_boxes, weights):
""" box_to_delta """
ex_w = ex_boxes[:, 2] - ex_boxes[:, 0] + 1
ex_h = ex_boxes[:, 3] - ex_boxes[:, 1] + 1
ex_ctr_x = ex_boxes[:, 0] + 0.5 * ex_w
ex_ctr_y = ex_boxes[:, 1] + 0.5 * ex_h
gt_w = gt_boxes[:, 2] - gt_boxes[:, 0] + 1
gt_h = gt_boxes[:, 3] - gt_boxes[:, 1] + 1
gt_ctr_x = gt_boxes[:, 0] + 0.5 * gt_w
gt_ctr_y = gt_boxes[:, 1] + 0.5 * gt_h
dx = (gt_ctr_x - ex_ctr_x) / ex_w / weights[0]
dy = (gt_ctr_y - ex_ctr_y) / ex_h / weights[1]
dw = (np.log(gt_w / ex_w)) / weights[2]
dh = (np.log(gt_h / ex_h)) / weights[3]
targets = np.vstack([dx, dy, dw, dh]).transpose()
return targets
......@@ -87,6 +87,8 @@ def main():
reader = create_reader(cfg.EvalReader)
loader.set_sample_list_generator(reader, place)
dataset = cfg['EvalReader']['dataset']
# eval already exists json file
if FLAGS.json_eval:
logger.info(
......@@ -123,8 +125,6 @@ def main():
callable(model.is_bbox_normalized):
is_bbox_normalized = model.is_bbox_normalized()
dataset = cfg['EvalReader']['dataset']
sub_eval_prog = None
sub_keys = None
sub_values = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册