未验证 提交 2cda4b28 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

add libra rcnn (#198)

add libra rcnn, including
  * libra loss
  * libra fpn(bfp)
  * libra sampling
上级 4b053712
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/
.ipynb_checkpoints/
*.py[cod] *.py[cod]
# C extensions # C extensions
......
# Libra R-CNN: Towards Balanced Learning for Object Detection
## Introduction
- Libra R-CNN: Towards Balanced Learning for Object Detection
: [https://arxiv.org/abs/1904.02701](https://arxiv.org/abs/1904.02701)
```
@inproceedings{pang2019libra,
title={Libra R-CNN: Towards Balanced Learning for Object Detection},
author={Pang, Jiangmiao and Chen, Kai and Shi, Jianping and Feng, Huajun and Ouyang, Wanli and Dahua Lin},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
year={2019}
}
```
## Model Zoo
| Backbone | Type | Image/gpu | Lr schd | Inf time (fps) | Box AP | Mask AP | Download |
| :---------------------- | :-------------: | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: |
| ResNet50-vd-BFP | Faster | 2 | 1x | 18.247 | 40.5 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/libra_rcnn_r50_vd_fpn_1x.tar) |
| ResNet101-vd-BFP | Faster | 2 | 1x | 14.865 | 42.5 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/libra_rcnn_r101_vd_fpn_1x.tar) |
architecture: FasterRCNN
max_iters: 90000
snapshot_iter: 10000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet101_vd_pretrained.tar
weights: output/libra_rcnn_r101_vd_fpn_1x/model_final
metric: COCO
num_classes: 81
FasterRCNN:
backbone: ResNet
fpn: BFP
rpn_head: FPNRPNHead
roi_extractor: FPNRoIAlign
bbox_head: BBoxHead
bbox_assigner: LibraBBoxAssigner
ResNet:
depth: 101
feature_maps: [2, 3, 4, 5]
freeze_at: 2
norm_type: bn
variant: d
BFP:
base_neck:
max_level: 6
min_level: 2
num_chan: 256
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
refine_level: 2
refine_type: nonlocal
nonlocal_reduction: 1.0
FPNRPNHead:
anchor_generator:
anchor_sizes: [32, 64, 128, 256, 512]
aspect_ratios: [0.5, 1.0, 2.0]
stride: [16.0, 16.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_start_size: 32
max_level: 6
min_level: 2
num_chan: 256
rpn_target_assign:
rpn_batch_size_per_im: 256
rpn_fg_fraction: 0.5
rpn_negative_overlap: 0.3
rpn_positive_overlap: 0.7
rpn_straddle_thresh: 0.0
train_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 2000
pre_nms_top_n: 2000
test_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 1000
pre_nms_top_n: 1000
FPNRoIAlign:
canconical_level: 4
canonical_size: 224
max_level: 5
min_level: 2
box_resolution: 7
sampling_ratio: 2
LibraBBoxAssigner:
batch_size_per_im: 512
bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
bg_thresh_hi: 0.5
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
BBoxHead:
head: TwoFCHead
nms:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
bbox_loss: BalancedL1Loss
BalancedL1Loss:
alpha: 0.5
gamma: 1.5
beta: 1.0
loss_weight: 1.0
TwoFCHead:
mlp_dim: 1024
LearningRate:
base_lr: 0.02
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [60000, 80000]
- !LinearWarmup
start_factor: 0.1
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
_READER_: '../faster_fpn_reader.yml'
TrainReader:
batch_size: 2
architecture: FasterRCNN
max_iters: 90000
snapshot_iter: 10000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar
weights: output/libra_rcnn_r50_vd_fpn_1x/model_final
metric: COCO
num_classes: 81
FasterRCNN:
backbone: ResNet
fpn: BFP
rpn_head: FPNRPNHead
roi_extractor: FPNRoIAlign
bbox_head: BBoxHead
bbox_assigner: LibraBBoxAssigner
ResNet:
depth: 50
feature_maps: [2, 3, 4, 5]
freeze_at: 2
norm_type: bn
variant: d
BFP:
base_neck:
max_level: 6
min_level: 2
num_chan: 256
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
refine_level: 2
refine_type: nonlocal
nonlocal_reduction: 1.0
FPNRPNHead:
anchor_generator:
anchor_sizes: [32, 64, 128, 256, 512]
aspect_ratios: [0.5, 1.0, 2.0]
stride: [16.0, 16.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_start_size: 32
max_level: 6
min_level: 2
num_chan: 256
rpn_target_assign:
rpn_batch_size_per_im: 256
rpn_fg_fraction: 0.5
rpn_negative_overlap: 0.3
rpn_positive_overlap: 0.7
rpn_straddle_thresh: 0.0
train_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 2000
pre_nms_top_n: 2000
test_proposal:
min_size: 0.0
nms_thresh: 0.7
post_nms_top_n: 1000
pre_nms_top_n: 1000
FPNRoIAlign:
canconical_level: 4
canonical_size: 224
max_level: 5
min_level: 2
box_resolution: 7
sampling_ratio: 2
LibraBBoxAssigner:
batch_size_per_im: 512
bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
bg_thresh_hi: 0.5
bg_thresh_lo: 0.0
fg_fraction: 0.25
fg_thresh: 0.5
BBoxHead:
head: TwoFCHead
nms:
keep_top_k: 100
nms_threshold: 0.5
score_threshold: 0.05
bbox_loss: BalancedL1Loss
BalancedL1Loss:
alpha: 0.5
gamma: 1.5
beta: 1.0
loss_weight: 1.0
TwoFCHead:
mlp_dim: 1024
LearningRate:
base_lr: 0.02
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [60000, 80000]
- !LinearWarmup
start_factor: 0.1
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
_READER_: '../faster_fpn_reader.yml'
TrainReader:
batch_size: 2
...@@ -27,6 +27,7 @@ from . import cb_resnet ...@@ -27,6 +27,7 @@ from . import cb_resnet
from . import res2net from . import res2net
from . import hrnet from . import hrnet
from . import hrfpn from . import hrfpn
from . import bfp
from .resnet import * from .resnet import *
from .resnext import * from .resnext import *
...@@ -40,4 +41,5 @@ from .faceboxnet import * ...@@ -40,4 +41,5 @@ from .faceboxnet import *
from .cb_resnet import * from .cb_resnet import *
from .res2net import * from .res2net import *
from .hrnet import * from .hrnet import *
from .hrfpn import * from .hrfpn import *
\ No newline at end of file from .bfp import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from collections import OrderedDict
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Xavier
from paddle.fluid.regularizer import L2Decay
from ppdet.core.workspace import register
from .nonlocal_helper import add_space_nonlocal
from .fpn import FPN
__all__ = ['BFP']
@register
class BFP(object):
"""
Libra R-CNN, see https://arxiv.org/abs/1904.02701
Args:
base_neck (dict): basic neck before balanced feature pyramid (bfp)
refine_level (int): index of integration and refine level of bfp
refine_type (str): refine type, None, conv or nonlocal
nonlocal_reduction (float): channel reduction level if refine_type is nonlocal
with_bias (bool): whether the nonlocal module contains bias
with_scale (bool): whether to scale feature in nonlocal module or not
"""
__inject__ = ['base_neck']
def __init__(self,
base_neck=FPN().__dict__,
refine_level=2,
refine_type="nonlocal",
nonlocal_reduction=1,
with_bias=True,
with_scale=False):
if isinstance(base_neck, dict):
self.base_neck = FPN(**base_neck)
self.refine_level = refine_level
self.refine_type = refine_type
self.nonlocal_reduction = nonlocal_reduction
self.with_bias = with_bias
self.with_scale = with_scale
def get_output(self, body_dict):
# top-down order
res_dict, spatial_scale = self.base_neck.get_output(body_dict)
res_dict = self.get_output_bfp(res_dict)
return res_dict, spatial_scale
def get_output_bfp(self, body_dict):
body_name_list = list(body_dict.keys())
num_backbone_stages = len(body_name_list)
self.num_levels = len(body_dict)
# step 1: gather multi-level features by resize and average
feats = []
refine_level_name = body_name_list[self.refine_level]
for i in range(self.num_levels):
curr_fpn_name = body_name_list[i]
pool_stride = 2**(i - self.refine_level)
pool_size = [
body_dict[refine_level_name].shape[2],
body_dict[refine_level_name].shape[3]
]
if i > self.refine_level:
gathered = fluid.layers.pool2d(
input=body_dict[curr_fpn_name],
pool_type='max',
pool_size=pool_stride,
pool_stride=pool_stride,
ceil_mode=True, )
else:
gathered = self._resize_input_tensor(
body_dict[curr_fpn_name], body_dict[refine_level_name],
1.0 / pool_stride)
feats.append(gathered)
bsf = sum(feats) / len(feats)
# step 2: refine gathered features
if self.refine_type == "conv":
bsf = fluid.layers.conv2d(
bsf,
bsf.shape[1],
filter_size=3,
padding=1,
param_attr=ParamAttr(name="bsf_w"),
bias_attr=ParamAttr(name="bsf_b"),
name="bsf")
elif self.refine_type == "nonlocal":
dim_in = bsf.shape[1]
nonlocal_name = "nonlocal_bsf"
bsf = add_space_nonlocal(
bsf,
bsf.shape[1],
bsf.shape[1],
nonlocal_name,
int(bsf.shape[1] / self.nonlocal_reduction),
with_bias=self.with_bias,
with_scale=self.with_scale)
# step 3: scatter refined features to multi-levels by a residual path
fpn_dict = {}
fpn_name_list = []
for i in range(self.num_levels):
curr_fpn_name = body_name_list[i]
pool_stride = 2**(self.refine_level - i)
if i >= self.refine_level:
residual = self._resize_input_tensor(
bsf, body_dict[curr_fpn_name], 1.0 / pool_stride)
else:
residual = fluid.layers.pool2d(
input=bsf,
pool_type='max',
pool_size=pool_stride,
pool_stride=pool_stride,
ceil_mode=True, )
fpn_dict[curr_fpn_name] = residual + body_dict[curr_fpn_name]
fpn_name_list.append(curr_fpn_name)
res_dict = OrderedDict([(k, fpn_dict[k]) for k in fpn_name_list])
return res_dict
def _resize_input_tensor(self, body_input, ref_output, scale):
shape = fluid.layers.shape(ref_output)
shape_hw = fluid.layers.slice(shape, axes=[0], starts=[2], ends=[4])
out_shape_ = shape_hw
out_shape = fluid.layers.cast(out_shape_, dtype='int32')
out_shape.stop_gradient = True
body_output = fluid.layers.resize_nearest(
body_input, scale=scale, actual_shape=out_shape)
return body_output
...@@ -5,95 +5,72 @@ from __future__ import unicode_literals ...@@ -5,95 +5,72 @@ from __future__ import unicode_literals
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import ParamAttr from paddle.fluid import ParamAttr
from paddle.fluid.initializer import ConstantInitializer
nonlocal_params = {
"use_zero_init_conv" : False,
"conv_init_std" : 0.01,
"no_bias" : True,
"use_maxpool" : False,
"use_softmax" : True,
"use_bn" : False,
"use_scale" : True, # vital for the model prformance!!!
"use_affine" : False,
"bn_momentum" : 0.9,
"bn_epsilon" : 1.0000001e-5,
"bn_init_gamma" : 0.9,
"weight_decay_bn":1.e-4,
}
def space_nonlocal(input,
def space_nonlocal(input, dim_in, dim_out, prefix, dim_inner, max_pool_stride = 2): dim_in,
cur = input dim_out,
theta = fluid.layers.conv2d(input = cur, num_filters = dim_inner, \ prefix,
filter_size = [1, 1], stride = [1, 1], \ dim_inner,
padding = [0, 0], \ with_bias=False,
param_attr=ParamAttr(name = prefix + '_theta' + "_w", \ with_scale=True):
initializer = fluid.initializer.Normal(loc = 0.0, theta = fluid.layers.conv2d(
scale = nonlocal_params["conv_init_std"])), \ input=input,
bias_attr = ParamAttr(name = prefix + '_theta' + "_b", \ num_filters=dim_inner,
initializer = fluid.initializer.Constant(value = 0.)) \ filter_size=1,
if not nonlocal_params["no_bias"] else False, \ stride=1,
name = prefix + '_theta') padding=0,
param_attr=ParamAttr(name=prefix + '_theta_w'),
bias_attr=ParamAttr(
name=prefix + '_theta_b', initializer=ConstantInitializer(value=0.))
if with_bias else False)
theta_shape = theta.shape theta_shape = theta.shape
theta_shape_op = fluid.layers.shape( theta ) theta_shape_op = fluid.layers.shape(theta)
theta_shape_op.stop_gradient = True theta_shape_op.stop_gradient = True
if nonlocal_params["use_maxpool"]:
max_pool = fluid.layers.pool2d(input = cur, \
pool_size = [max_pool_stride, max_pool_stride], \
pool_type = 'max', \
pool_stride = [max_pool_stride, max_pool_stride], \
pool_padding = [0, 0], \
name = prefix + '_pool')
else:
max_pool = cur
phi = fluid.layers.conv2d(input = max_pool, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_phi' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0,
scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_phi' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) \
if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_phi')
phi_shape = phi.shape
g = fluid.layers.conv2d(input = max_pool, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_g' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0, scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_g' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_g')
g_shape = g.shape
# we have to use explicit batch size (to support arbitrary spacetime size) # we have to use explicit batch size (to support arbitrary spacetime size)
# e.g. (8, 1024, 4, 14, 14) => (8, 1024, 784) # e.g. (8, 1024, 4, 14, 14) => (8, 1024, 784)
theta = fluid.layers.reshape(theta, shape=(0, 0, -1) ) theta = fluid.layers.reshape(theta, shape=(0, 0, -1))
theta = fluid.layers.transpose(theta, [0, 2, 1]) theta = fluid.layers.transpose(theta, [0, 2, 1])
phi = fluid.layers.conv2d(
input=input,
num_filters=dim_inner,
filter_size=1,
stride=1,
padding=0,
param_attr=ParamAttr(name=prefix + '_phi_w'),
bias_attr=ParamAttr(
name=prefix + '_phi_b', initializer=ConstantInitializer(value=0.))
if with_bias else False,
name=prefix + '_phi')
phi = fluid.layers.reshape(phi, [0, 0, -1]) phi = fluid.layers.reshape(phi, [0, 0, -1])
theta_phi = fluid.layers.matmul(theta, phi, name = prefix + '_affinity')
theta_phi = fluid.layers.matmul(theta, phi)
g = fluid.layers.conv2d(
input=input,
num_filters=dim_inner,
filter_size=1,
stride=1,
padding=0,
param_attr=ParamAttr(name=prefix + '_g_w'),
bias_attr=ParamAttr(
name=prefix + '_g_b', initializer=ConstantInitializer(value=0.))
if with_bias else False,
name=prefix + '_g')
g = fluid.layers.reshape(g, [0, 0, -1]) g = fluid.layers.reshape(g, [0, 0, -1])
if nonlocal_params["use_softmax"]: # scale
if nonlocal_params["use_scale"]: if with_scale:
theta_phi_sc = fluid.layers.scale(theta_phi, scale = dim_inner**-.5) theta_phi = fluid.layers.scale(theta_phi, scale=dim_inner**-.5)
else: p = fluid.layers.softmax(theta_phi)
theta_phi_sc = theta_phi
p = fluid.layers.softmax(theta_phi_sc, name = prefix + '_affinity' + '_prob')
else:
# not clear about what is doing in xlw's code
p = None # not implemented
raise "Not implemented when not use softmax"
# note g's axis[2] corresponds to p's axis[2] # note g's axis[2] corresponds to p's axis[2]
# e.g. g(8, 1024, 784_2) * p(8, 784_1, 784_2) => (8, 1024, 784_1) # e.g. g(8, 1024, 784_2) * p(8, 784_1, 784_2) => (8, 1024, 784_1)
p = fluid.layers.transpose(p, [0, 2, 1]) p = fluid.layers.transpose(p, [0, 2, 1])
t = fluid.layers.matmul(g, p, name = prefix + '_y') t = fluid.layers.matmul(g, p)
# reshape back # reshape back
# e.g. (8, 1024, 784) => (8, 1024, 4, 14, 14) # e.g. (8, 1024, 784) => (8, 1024, 4, 14, 14)
...@@ -104,56 +81,40 @@ def space_nonlocal(input, dim_in, dim_out, prefix, dim_inner, max_pool_stride = ...@@ -104,56 +81,40 @@ def space_nonlocal(input, dim_in, dim_out, prefix, dim_inner, max_pool_stride =
t_re = fluid.layers.reshape(t, shape=[n, ch, h, w]) t_re = fluid.layers.reshape(t, shape=[n, ch, h, w])
blob_out = t_re blob_out = t_re
blob_out = fluid.layers.conv2d(input = blob_out, num_filters = dim_out, \ blob_out = fluid.layers.conv2d(
filter_size = [1, 1], stride = [1, 1], padding = [0, 0], \ input=blob_out,
param_attr = ParamAttr(name = prefix + '_out' + "_w", \ num_filters=dim_out,
initializer = fluid.initializer.Constant(value = 0.) \ filter_size=1,
if nonlocal_params["use_zero_init_conv"] \ stride=1,
else fluid.initializer.Normal(loc = 0.0, padding=0,
scale = nonlocal_params["conv_init_std"])), \ param_attr=ParamAttr(
bias_attr = ParamAttr(name = prefix + '_out' + "_b", \ name=prefix + '_out_w', initializer=ConstantInitializer(value=0.0)),
initializer = fluid.initializer.Constant(value = 0.)) \ bias_attr=ParamAttr(
if (nonlocal_params["no_bias"] == 0) else False, \ name=prefix + '_out_b', initializer=ConstantInitializer(value=0.0))
name = prefix + '_out') if with_bias else False,
name=prefix + '_out')
blob_out_shape = blob_out.shape blob_out_shape = blob_out.shape
if nonlocal_params["use_bn"]:
bn_name = prefix + "_bn"
blob_out = fluid.layers.batch_norm(blob_out, \
# is_test = test_mode, \
momentum = nonlocal_params["bn_momentum"], \
epsilon = nonlocal_params["bn_epsilon"], \
name = bn_name, \
param_attr = ParamAttr(name = bn_name + "_s", \
initializer = fluid.initializer.Constant(value = nonlocal_params["bn_init_gamma"]), \
regularizer = fluid.regularizer.L2Decay(nonlocal_params["weight_decay_bn"])), \
bias_attr = ParamAttr(name = bn_name + "_b", \
regularizer = fluid.regularizer.L2Decay(nonlocal_params["weight_decay_bn"])), \
moving_mean_name = bn_name + "_rm", \
moving_variance_name = bn_name + "_riv") # add bn
if nonlocal_params["use_affine"]:
affine_scale = fluid.layers.create_parameter(\
shape=[blob_out_shape[1]], dtype = blob_out.dtype, \
attr=ParamAttr(name=prefix + '_affine' + '_s'), \
default_initializer = fluid.initializer.Constant(value = 1.))
affine_bias = fluid.layers.create_parameter(\
shape=[blob_out_shape[1]], dtype = blob_out.dtype, \
attr=ParamAttr(name=prefix + '_affine' + '_b'), \
default_initializer = fluid.initializer.Constant(value = 0.))
blob_out = fluid.layers.affine_channel(blob_out, scale = affine_scale, \
bias = affine_bias, name = prefix + '_affine') # add affine
return blob_out return blob_out
def add_space_nonlocal(input, dim_in, dim_out, prefix, dim_inner ): def add_space_nonlocal(input,
dim_in,
dim_out,
prefix,
dim_inner,
with_bias=False,
with_scale=True):
''' '''
add_space_nonlocal: add_space_nonlocal:
Non-local Neural Networks: see https://arxiv.org/abs/1711.07971 Non-local Neural Networks: see https://arxiv.org/abs/1711.07971
''' '''
conv = space_nonlocal(input, dim_in, dim_out, prefix, dim_inner) conv = space_nonlocal(
output = fluid.layers.elementwise_add(input, conv, name = prefix + '_sum') input,
dim_in,
dim_out,
prefix,
dim_inner,
with_bias=with_bias,
with_scale=with_scale)
output = input + conv
return output return output
...@@ -19,9 +19,11 @@ from . import smooth_l1_loss ...@@ -19,9 +19,11 @@ from . import smooth_l1_loss
from . import giou_loss from . import giou_loss
from . import diou_loss from . import diou_loss
from . import iou_loss from . import iou_loss
from . import balanced_l1_loss
from .yolo_loss import * from .yolo_loss import *
from .smooth_l1_loss import * from .smooth_l1_loss import *
from .giou_loss import * from .giou_loss import *
from .diou_loss import * from .diou_loss import *
from .iou_loss import * from .iou_loss import *
from .balanced_l1_loss import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from paddle import fluid
from ppdet.core.workspace import register, serializable
__all__ = ['BalancedL1Loss']
@register
@serializable
class BalancedL1Loss(object):
"""
Balanced L1 Loss, see https://arxiv.org/abs/1904.02701
Args:
alpha (float): hyper parameter of BalancedL1Loss, see more details in the paper
gamma (float): hyper parameter of BalancedL1Loss, see more details in the paper
beta (float): hyper parameter of BalancedL1Loss, see more details in the paper
loss_weights (float): loss weight
"""
def __init__(self, alpha=0.5, gamma=1.5, beta=1.0, loss_weight=1.0):
super(BalancedL1Loss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.beta = beta
self.loss_weight = loss_weight
def __call__(
self,
x,
y,
inside_weight=None,
outside_weight=None, ):
alpha = self.alpha
gamma = self.gamma
beta = self.beta
loss_weight = self.loss_weight
diff = fluid.layers.abs(x - y)
b = np.e**(gamma / alpha) - 1
less_beta = diff < beta
ge_beta = diff >= beta
less_beta = fluid.layers.cast(x=less_beta, dtype='float32')
ge_beta = fluid.layers.cast(x=ge_beta, dtype='float32')
less_beta.stop_gradient = True
ge_beta.stop_gradient = True
loss_1 = less_beta * (
alpha / b * (b * diff + 1) * fluid.layers.log(b * diff / beta + 1) -
alpha * diff)
loss_2 = ge_beta * (gamma * diff + gamma / b - alpha * beta)
iou_weights = 1.0
if inside_weight is not None and outside_weight is not None:
iou_weights = inside_weight * outside_weight
loss = (loss_1 + loss_2) * iou_weights
loss = fluid.layers.reduce_sum(loss, dim=-1) * loss_weight
return loss
...@@ -19,12 +19,13 @@ from paddle import fluid ...@@ -19,12 +19,13 @@ from paddle import fluid
from paddle.fluid.param_attr import ParamAttr from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay from paddle.fluid.regularizer import L2Decay
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from ppdet.utils.bbox_utils import bbox_overlaps, box_to_delta
__all__ = [ __all__ = [
'AnchorGenerator', 'DropBlock', 'RPNTargetAssign', 'GenerateProposals', 'AnchorGenerator', 'DropBlock', 'RPNTargetAssign', 'GenerateProposals',
'MultiClassNMS', 'BBoxAssigner', 'MaskAssigner', 'RoIAlign', 'RoIPool', 'MultiClassNMS', 'BBoxAssigner', 'MaskAssigner', 'RoIAlign', 'RoIPool',
'MultiBoxHead', 'SSDOutputDecoder', 'RetinaTargetAssign', 'MultiBoxHead', 'SSDOutputDecoder', 'RetinaTargetAssign',
'RetinaOutputDecoder', 'ConvNorm', 'MultiClassSoftNMS' 'RetinaOutputDecoder', 'ConvNorm', 'MultiClassSoftNMS', 'LibraBBoxAssigner'
] ]
...@@ -546,6 +547,460 @@ class BBoxAssigner(object): ...@@ -546,6 +547,460 @@ class BBoxAssigner(object):
self.use_random = shuffle_before_sample self.use_random = shuffle_before_sample
@register
class LibraBBoxAssigner(object):
def __init__(self,
batch_size_per_im=512,
fg_fraction=.25,
fg_thresh=.5,
bg_thresh_hi=.5,
bg_thresh_lo=0.,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
num_classes=81,
shuffle_before_sample=True,
is_cls_agnostic=False,
num_bins=3):
super(LibraBBoxAssigner, self).__init__()
self.batch_size_per_im = batch_size_per_im
self.fg_fraction = fg_fraction
self.fg_thresh = fg_thresh
self.bg_thresh_hi = bg_thresh_hi
self.bg_thresh_lo = bg_thresh_lo
self.bbox_reg_weights = bbox_reg_weights
self.class_nums = num_classes
self.use_random = shuffle_before_sample
self.is_cls_agnostic = is_cls_agnostic
self.num_bins = num_bins
def __call__(
self,
rpn_rois,
gt_classes,
is_crowd,
gt_boxes,
im_info, ):
return self.generate_proposal_label_libra(
rpn_rois=rpn_rois,
gt_classes=gt_classes,
is_crowd=is_crowd,
gt_boxes=gt_boxes,
im_info=im_info,
batch_size_per_im=self.batch_size_per_im,
fg_fraction=self.fg_fraction,
fg_thresh=self.fg_thresh,
bg_thresh_hi=self.bg_thresh_hi,
bg_thresh_lo=self.bg_thresh_lo,
bbox_reg_weights=self.bbox_reg_weights,
class_nums=self.class_nums,
use_random=self.use_random,
is_cls_agnostic=self.is_cls_agnostic,
is_cascade_rcnn=False)
def generate_proposal_label_libra(
self, rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi,
bg_thresh_lo, bbox_reg_weights, class_nums, use_random,
is_cls_agnostic, is_cascade_rcnn):
num_bins = self.num_bins
def create_tmp_var(program, name, dtype, shape, lod_level=None):
return program.current_block().create_var(
name=name, dtype=dtype, shape=shape, lod_level=lod_level)
def _sample_pos(max_overlaps, max_classes, pos_inds, num_expected):
if len(pos_inds) <= num_expected:
return pos_inds
else:
unique_gt_inds = np.unique(max_classes[pos_inds])
num_gts = len(unique_gt_inds)
num_per_gt = int(round(num_expected / float(num_gts)) + 1)
sampled_inds = []
for i in unique_gt_inds:
inds = np.nonzero(max_classes == i)[0]
before_len = len(inds)
inds = list(set(inds) & set(pos_inds))
after_len = len(inds)
if len(inds) > num_per_gt:
inds = np.random.choice(
inds, size=num_per_gt, replace=False)
sampled_inds.extend(list(inds)) # combine as a new sampler
if len(sampled_inds) < num_expected:
num_extra = num_expected - len(sampled_inds)
extra_inds = np.array(
list(set(pos_inds) - set(sampled_inds)))
assert len(sampled_inds)+len(extra_inds) == len(pos_inds), \
"sum of sampled_inds({}) and extra_inds({}) length must be equal with pos_inds({})!".format(
len(sampled_inds), len(extra_inds), len(pos_inds))
if len(extra_inds) > num_extra:
extra_inds = np.random.choice(
extra_inds, size=num_extra, replace=False)
sampled_inds.extend(extra_inds.tolist())
elif len(sampled_inds) > num_expected:
sampled_inds = np.random.choice(
sampled_inds, size=num_expected, replace=False)
return sampled_inds
def sample_via_interval(max_overlaps, full_set, num_expected, floor_thr,
num_bins, bg_thresh_hi):
max_iou = max_overlaps.max()
iou_interval = (max_iou - floor_thr) / num_bins
per_num_expected = int(num_expected / num_bins)
sampled_inds = []
for i in range(num_bins):
start_iou = floor_thr + i * iou_interval
end_iou = floor_thr + (i + 1) * iou_interval
tmp_set = set(
np.where(
np.logical_and(max_overlaps >= start_iou, max_overlaps <
end_iou))[0])
tmp_inds = list(tmp_set & full_set)
if len(tmp_inds) > per_num_expected:
tmp_sampled_set = np.random.choice(
tmp_inds, size=per_num_expected, replace=False)
else:
tmp_sampled_set = np.array(tmp_inds, dtype=np.int)
sampled_inds.append(tmp_sampled_set)
sampled_inds = np.concatenate(sampled_inds)
if len(sampled_inds) < num_expected:
num_extra = num_expected - len(sampled_inds)
extra_inds = np.array(list(full_set - set(sampled_inds)))
assert len(sampled_inds)+len(extra_inds) == len(full_set), \
"sum of sampled_inds({}) and extra_inds({}) length must be equal with full_set({})!".format(
len(sampled_inds), len(extra_inds), len(full_set))
if len(extra_inds) > num_extra:
extra_inds = np.random.choice(
extra_inds, num_extra, replace=False)
sampled_inds = np.concatenate([sampled_inds, extra_inds])
return sampled_inds
def _sample_neg(max_overlaps,
max_classes,
neg_inds,
num_expected,
floor_thr=-1,
floor_fraction=0,
num_bins=3,
bg_thresh_hi=0.5):
if len(neg_inds) <= num_expected:
return neg_inds
else:
# balance sampling for negative samples
neg_set = set(neg_inds)
if floor_thr > 0:
floor_set = set(
np.where(
np.logical_and(max_overlaps >= 0, max_overlaps <
floor_thr))[0])
iou_sampling_set = set(
np.where(max_overlaps >= floor_thr)[0])
elif floor_thr == 0:
floor_set = set(np.where(max_overlaps == 0)[0])
iou_sampling_set = set(
np.where(max_overlaps > floor_thr)[0])
else:
floor_set = set()
iou_sampling_set = set(
np.where(max_overlaps > floor_thr)[0])
floor_thr = 0
floor_neg_inds = list(floor_set & neg_set)
iou_sampling_neg_inds = list(iou_sampling_set & neg_set)
num_expected_iou_sampling = int(num_expected *
(1 - floor_fraction))
if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
if num_bins >= 2:
iou_sampled_inds = sample_via_interval(
max_overlaps,
set(iou_sampling_neg_inds),
num_expected_iou_sampling, floor_thr, num_bins,
bg_thresh_hi)
else:
iou_sampled_inds = np.random.choice(
iou_sampling_neg_inds,
size=num_expected_iou_sampling,
replace=False)
else:
iou_sampled_inds = np.array(
iou_sampling_neg_inds, dtype=np.int)
num_expected_floor = num_expected - len(iou_sampled_inds)
if len(floor_neg_inds) > num_expected_floor:
sampled_floor_inds = np.random.choice(
floor_neg_inds, size=num_expected_floor, replace=False)
else:
sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int)
sampled_inds = np.concatenate(
(sampled_floor_inds, iou_sampled_inds))
if len(sampled_inds) < num_expected:
num_extra = num_expected - len(sampled_inds)
extra_inds = np.array(list(neg_set - set(sampled_inds)))
if len(extra_inds) > num_extra:
extra_inds = np.random.choice(
extra_inds, size=num_extra, replace=False)
sampled_inds = np.concatenate((sampled_inds, extra_inds))
return sampled_inds
def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
batch_size_per_im, fg_fraction, fg_thresh,
bg_thresh_hi, bg_thresh_lo, bbox_reg_weights,
class_nums, use_random, is_cls_agnostic,
is_cascade_rcnn):
rois_per_image = int(batch_size_per_im)
fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))
# Roidb
im_scale = im_info[2]
inv_im_scale = 1. / im_scale
rpn_rois = rpn_rois * inv_im_scale
if is_cascade_rcnn:
rpn_rois = rpn_rois[gt_boxes.shape[0]:, :]
boxes = np.vstack([gt_boxes, rpn_rois])
gt_overlaps = np.zeros((boxes.shape[0], class_nums))
box_to_gt_ind_map = np.zeros((boxes.shape[0]), dtype=np.int32)
if len(gt_boxes) > 0:
proposal_to_gt_overlaps = bbox_overlaps(boxes, gt_boxes)
overlaps_argmax = proposal_to_gt_overlaps.argmax(axis=1)
overlaps_max = proposal_to_gt_overlaps.max(axis=1)
# Boxes which with non-zero overlap with gt boxes
overlapped_boxes_ind = np.where(overlaps_max > 0)[0]
overlapped_boxes_gt_classes = gt_classes[overlaps_argmax[
overlapped_boxes_ind]]
for idx in range(len(overlapped_boxes_ind)):
gt_overlaps[overlapped_boxes_ind[
idx], overlapped_boxes_gt_classes[idx]] = overlaps_max[
overlapped_boxes_ind[idx]]
box_to_gt_ind_map[overlapped_boxes_ind[
idx]] = overlaps_argmax[overlapped_boxes_ind[idx]]
crowd_ind = np.where(is_crowd)[0]
gt_overlaps[crowd_ind] = -1
max_overlaps = gt_overlaps.max(axis=1)
max_classes = gt_overlaps.argmax(axis=1)
# Cascade RCNN Decode Filter
if is_cascade_rcnn:
ws = boxes[:, 2] - boxes[:, 0] + 1
hs = boxes[:, 3] - boxes[:, 1] + 1
keep = np.where((ws > 0) & (hs > 0))[0]
boxes = boxes[keep]
fg_inds = np.where(max_overlaps >= fg_thresh)[0]
bg_inds = np.where((max_overlaps < bg_thresh_hi) & (
max_overlaps >= bg_thresh_lo))[0]
fg_rois_per_this_image = fg_inds.shape[0]
bg_rois_per_this_image = bg_inds.shape[0]
else:
# Foreground
fg_inds = np.where(max_overlaps >= fg_thresh)[0]
fg_rois_per_this_image = np.minimum(fg_rois_per_im,
fg_inds.shape[0])
# Sample foreground if there are too many
if fg_inds.shape[0] > fg_rois_per_this_image:
if use_random:
fg_inds = _sample_pos(max_overlaps, max_classes,
fg_inds, fg_rois_per_this_image)
fg_inds = fg_inds[:fg_rois_per_this_image]
# Background
bg_inds = np.where((max_overlaps < bg_thresh_hi) & (
max_overlaps >= bg_thresh_lo))[0]
bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image
bg_rois_per_this_image = np.minimum(bg_rois_per_this_image,
bg_inds.shape[0])
assert bg_rois_per_this_image >= 0, "bg_rois_per_this_image must be >= 0 but got {}".format(
bg_rois_per_this_image)
# Sample background if there are too many
if bg_inds.shape[0] > bg_rois_per_this_image:
if use_random:
# libra neg sample
bg_inds = _sample_neg(
max_overlaps,
max_classes,
bg_inds,
bg_rois_per_this_image,
num_bins=num_bins,
bg_thresh_hi=bg_thresh_hi)
bg_inds = bg_inds[:bg_rois_per_this_image]
keep_inds = np.append(fg_inds, bg_inds)
sampled_labels = max_classes[keep_inds] # N x 1
sampled_labels[fg_rois_per_this_image:] = 0
sampled_boxes = boxes[keep_inds] # N x 324
sampled_gts = gt_boxes[box_to_gt_ind_map[keep_inds]]
sampled_gts[fg_rois_per_this_image:, :] = gt_boxes[0]
bbox_label_targets = _compute_targets(
sampled_boxes, sampled_gts, sampled_labels, bbox_reg_weights)
bbox_targets, bbox_inside_weights = _expand_bbox_targets(
bbox_label_targets, class_nums, is_cls_agnostic)
bbox_outside_weights = np.array(
bbox_inside_weights > 0, dtype=bbox_inside_weights.dtype)
# Scale rois
sampled_rois = sampled_boxes * im_scale
# Faster RCNN blobs
frcn_blobs = dict(
rois=sampled_rois,
labels_int32=sampled_labels,
bbox_targets=bbox_targets,
bbox_inside_weights=bbox_inside_weights,
bbox_outside_weights=bbox_outside_weights)
return frcn_blobs
def _compute_targets(roi_boxes, gt_boxes, labels, bbox_reg_weights):
assert roi_boxes.shape[0] == gt_boxes.shape[0]
assert roi_boxes.shape[1] == 4
assert gt_boxes.shape[1] == 4
targets = np.zeros(roi_boxes.shape)
bbox_reg_weights = np.asarray(bbox_reg_weights)
targets = box_to_delta(
ex_boxes=roi_boxes, gt_boxes=gt_boxes, weights=bbox_reg_weights)
return np.hstack([labels[:, np.newaxis], targets]).astype(
np.float32, copy=False)
def _expand_bbox_targets(bbox_targets_input, class_nums,
is_cls_agnostic):
class_labels = bbox_targets_input[:, 0]
fg_inds = np.where(class_labels > 0)[0]
bbox_targets = np.zeros((class_labels.shape[0], 4 * class_nums
if not is_cls_agnostic else 4 * 2))
bbox_inside_weights = np.zeros(bbox_targets.shape)
for ind in fg_inds:
class_label = int(class_labels[
ind]) if not is_cls_agnostic else 1
start_ind = class_label * 4
end_ind = class_label * 4 + 4
bbox_targets[ind, start_ind:end_ind] = bbox_targets_input[ind,
1:]
bbox_inside_weights[ind, start_ind:end_ind] = (1.0, 1.0, 1.0,
1.0)
return bbox_targets, bbox_inside_weights
def generate_func(
rpn_rois,
gt_classes,
is_crowd,
gt_boxes,
im_info, ):
rpn_rois_lod = rpn_rois.lod()[0]
gt_classes_lod = gt_classes.lod()[0]
# convert
rpn_rois = np.array(rpn_rois)
gt_classes = np.array(gt_classes)
is_crowd = np.array(is_crowd)
gt_boxes = np.array(gt_boxes)
im_info = np.array(im_info)
rois = []
labels_int32 = []
bbox_targets = []
bbox_inside_weights = []
bbox_outside_weights = []
lod = [0]
for idx in range(len(rpn_rois_lod) - 1):
rois_si = rpn_rois_lod[idx]
rois_ei = rpn_rois_lod[idx + 1]
gt_si = gt_classes_lod[idx]
gt_ei = gt_classes_lod[idx + 1]
frcn_blobs = _sample_rois(
rpn_rois[rois_si:rois_ei], gt_classes[gt_si:gt_ei],
is_crowd[gt_si:gt_ei], gt_boxes[gt_si:gt_ei], im_info[idx],
batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi,
bg_thresh_lo, bbox_reg_weights, class_nums, use_random,
is_cls_agnostic, is_cascade_rcnn)
lod.append(frcn_blobs['rois'].shape[0] + lod[-1])
rois.append(frcn_blobs['rois'])
labels_int32.append(frcn_blobs['labels_int32'].reshape(-1, 1))
bbox_targets.append(frcn_blobs['bbox_targets'])
bbox_inside_weights.append(frcn_blobs['bbox_inside_weights'])
bbox_outside_weights.append(frcn_blobs['bbox_outside_weights'])
rois = np.vstack(rois)
labels_int32 = np.vstack(labels_int32)
bbox_targets = np.vstack(bbox_targets)
bbox_inside_weights = np.vstack(bbox_inside_weights)
bbox_outside_weights = np.vstack(bbox_outside_weights)
# create lod-tensor for return
# notice that the func create_lod_tensor does not work well here
ret_rois = fluid.LoDTensor()
ret_rois.set_lod([lod])
ret_rois.set(rois.astype("float32"), fluid.CPUPlace())
ret_labels_int32 = fluid.LoDTensor()
ret_labels_int32.set_lod([lod])
ret_labels_int32.set(labels_int32.astype("int32"), fluid.CPUPlace())
ret_bbox_targets = fluid.LoDTensor()
ret_bbox_targets.set_lod([lod])
ret_bbox_targets.set(
bbox_targets.astype("float32"), fluid.CPUPlace())
ret_bbox_inside_weights = fluid.LoDTensor()
ret_bbox_inside_weights.set_lod([lod])
ret_bbox_inside_weights.set(
bbox_inside_weights.astype("float32"), fluid.CPUPlace())
ret_bbox_outside_weights = fluid.LoDTensor()
ret_bbox_outside_weights.set_lod([lod])
ret_bbox_outside_weights.set(
bbox_outside_weights.astype("float32"), fluid.CPUPlace())
return ret_rois, ret_labels_int32, ret_bbox_targets, ret_bbox_inside_weights, ret_bbox_outside_weights
rois = create_tmp_var(
fluid.default_main_program(),
name=None, #'rois',
dtype='float32',
shape=[-1, 4], )
bbox_inside_weights = create_tmp_var(
fluid.default_main_program(),
name=None, #'bbox_inside_weights',
dtype='float32',
shape=[-1, 8 if self.is_cls_agnostic else self.class_nums * 4], )
bbox_outside_weights = create_tmp_var(
fluid.default_main_program(),
name=None, #'bbox_outside_weights',
dtype='float32',
shape=[-1, 8 if self.is_cls_agnostic else self.class_nums * 4], )
bbox_targets = create_tmp_var(
fluid.default_main_program(),
name=None, #'bbox_targets',
dtype='float32',
shape=[-1, 8 if self.is_cls_agnostic else self.class_nums * 4], )
labels_int32 = create_tmp_var(
fluid.default_main_program(),
name=None, #'labels_int32',
dtype='int32',
shape=[-1, 1], )
outs = [
rois, labels_int32, bbox_targets, bbox_inside_weights,
bbox_outside_weights
]
fluid.layers.py_func(
func=generate_func,
x=[rpn_rois, gt_classes, is_crowd, gt_boxes, im_info],
out=outs)
return outs
@register @register
class RoIAlign(object): class RoIAlign(object):
__op__ = fluid.layers.roi_align __op__ = fluid.layers.roi_align
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import numpy as np
import paddle.fluid as fluid
__all__ = ["bbox_overlaps", "box_to_delta"]
logger = logging.getLogger(__name__)
def bbox_overlaps(boxes_1, boxes_2):
'''
bbox_overlaps
boxes_1: x1, y, x2, y2
boxes_2: x1, y, x2, y2
'''
assert boxes_1.shape[1] == 4 and boxes_2.shape[1] == 4
num_1 = boxes_1.shape[0]
num_2 = boxes_2.shape[0]
x1_1 = boxes_1[:, 0:1]
y1_1 = boxes_1[:, 1:2]
x2_1 = boxes_1[:, 2:3]
y2_1 = boxes_1[:, 3:4]
area_1 = (x2_1 - x1_1 + 1) * (y2_1 - y1_1 + 1)
x1_2 = boxes_2[:, 0].transpose()
y1_2 = boxes_2[:, 1].transpose()
x2_2 = boxes_2[:, 2].transpose()
y2_2 = boxes_2[:, 3].transpose()
area_2 = (x2_2 - x1_2 + 1) * (y2_2 - y1_2 + 1)
xx1 = np.maximum(x1_1, x1_2)
yy1 = np.maximum(y1_1, y1_2)
xx2 = np.minimum(x2_1, x2_2)
yy2 = np.minimum(y2_1, y2_2)
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (area_1 + area_2 - inter)
return ovr
def box_to_delta(ex_boxes, gt_boxes, weights):
""" box_to_delta """
ex_w = ex_boxes[:, 2] - ex_boxes[:, 0] + 1
ex_h = ex_boxes[:, 3] - ex_boxes[:, 1] + 1
ex_ctr_x = ex_boxes[:, 0] + 0.5 * ex_w
ex_ctr_y = ex_boxes[:, 1] + 0.5 * ex_h
gt_w = gt_boxes[:, 2] - gt_boxes[:, 0] + 1
gt_h = gt_boxes[:, 3] - gt_boxes[:, 1] + 1
gt_ctr_x = gt_boxes[:, 0] + 0.5 * gt_w
gt_ctr_y = gt_boxes[:, 1] + 0.5 * gt_h
dx = (gt_ctr_x - ex_ctr_x) / ex_w / weights[0]
dy = (gt_ctr_y - ex_ctr_y) / ex_h / weights[1]
dw = (np.log(gt_w / ex_w)) / weights[2]
dh = (np.log(gt_h / ex_h)) / weights[3]
targets = np.vstack([dx, dy, dw, dh]).transpose()
return targets
...@@ -87,6 +87,8 @@ def main(): ...@@ -87,6 +87,8 @@ def main():
reader = create_reader(cfg.EvalReader) reader = create_reader(cfg.EvalReader)
loader.set_sample_list_generator(reader, place) loader.set_sample_list_generator(reader, place)
dataset = cfg['EvalReader']['dataset']
# eval already exists json file # eval already exists json file
if FLAGS.json_eval: if FLAGS.json_eval:
logger.info( logger.info(
...@@ -123,8 +125,6 @@ def main(): ...@@ -123,8 +125,6 @@ def main():
callable(model.is_bbox_normalized): callable(model.is_bbox_normalized):
is_bbox_normalized = model.is_bbox_normalized() is_bbox_normalized = model.is_bbox_normalized()
dataset = cfg['EvalReader']['dataset']
sub_eval_prog = None sub_eval_prog = None
sub_keys = None sub_keys = None
sub_values = None sub_values = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册