未验证 提交 c65eb1b5 编写于 作者: C CodesFarmer 提交者: GitHub

fcos initialization, with github user as committer (#405)

* fcos initialization

* delete the codes not used

* pre-commit the committed code

* delete the unused function

* modify the capacity in loader from 64 to 16
上级 b2c75899
architecture: FCOS
max_iters: 90000
use_gpu: true
snapshot_iter: 10000
log_smooth_window: 20
log_iter: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO
weights: output/fcos_r50_fpn_1x/model_final
num_classes: 81
FCOS:
backbone: ResNet
fpn: FPN
fcos_head: FCOSHead
ResNet:
norm_type: affine_channel
norm_decay: 0.
depth: 50
feature_maps: [3, 4, 5]
freeze_at: 2
FPN:
min_level: 3
max_level: 7
num_chan: 256
use_c5: false
spatial_scale: [0.03125, 0.0625, 0.125]
has_extra_convs: true
FCOSHead:
num_classes: 81
fpn_stride: [8, 16, 32, 64, 128]
num_convs: 4
norm_type: "gn"
fcos_loss: FCOSLoss
norm_reg_targets: True
centerness_on_reg: True
use_dcn_in_tower: False
nms: MultiClassNMS
MultiClassNMS:
score_threshold: 0.025
nms_top_k: 1000
keep_top_k: 100
nms_threshold: 0.6
background_label: -1
FCOSLoss:
loss_alpha: 0.25
loss_gamma: 2.0
iou_loss_type: "giou"
reg_weights: 1.0
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [60000, 80000]
- !LinearWarmup
start_factor: 0.3333333333333333
steps: 500
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
TrainReader:
inputs_def:
fields: ['image', 'gt_bbox', 'gt_class', 'gt_score', 'im_info']
dataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco
with_background: true
sample_transforms:
- !DecodeImage
to_rgb: true
- !RandomFlipImage
prob: 0.5
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
target_size: 800
max_size: 1333
interp: 1
use_cv2: true
- !Permute
to_bgr: false
channel_first: true
batch_transforms:
- !PadBatch
pad_to_stride: 128
use_padded_im_info: false
- !Gt2FCOSTarget
object_sizes_boundary: [64, 128, 256, 512]
center_sampling_radius: 1.5
downsample_ratios: [8, 16, 32, 64, 128]
norm_reg_targets: True
batch_size: 2
shuffle: true
worker_num: 16
use_process: false
EvalReader:
inputs_def:
fields: ['image', 'im_id', 'im_shape', 'im_info']
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
target_size: 800
max_size: 1333
interp: 1
use_cv2: true
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 128
use_padded_im_info: true
batch_size: 8
shuffle: false
worker_num: 2
use_process: false
TestReader:
inputs_def:
# set image_shape if needed
fields: ['image', 'im_id', 'im_shape', 'im_info']
dataset:
!ImageFolder
anno_path: annotations/instances_val2017.json
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
interp: 1
max_size: 1333
target_size: 800
use_cv2: true
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 128
use_padded_im_info: true
batch_size: 1
shuffle: false
architecture: FCOS
max_iters: 180000
use_gpu: true
snapshot_iter: 20000
log_smooth_window: 20
log_iter: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO
weights: output/fcos_r50_fpn_multiscale_2x/model_final
num_classes: 81
FCOS:
backbone: ResNet
fpn: FPN
fcos_head: FCOSHead
ResNet:
norm_type: affine_channel
norm_decay: 0.
depth: 50
feature_maps: [3, 4, 5]
freeze_at: 2
FPN:
min_level: 3
max_level: 7
num_chan: 256
use_c5: false
spatial_scale: [0.03125, 0.0625, 0.125]
has_extra_convs: true
FCOSHead:
num_classes: 81
fpn_stride: [8, 16, 32, 64, 128]
num_convs: 4
norm_type: "gn"
fcos_loss: FCOSLoss
norm_reg_targets: True
centerness_on_reg: True
use_dcn_in_tower: False
nms: MultiClassNMS
MultiClassNMS:
score_threshold: 0.025
nms_top_k: 1000
keep_top_k: 100
nms_threshold: 0.6
background_label: -1
FCOSLoss:
loss_alpha: 0.25
loss_gamma: 2.0
iou_loss_type: "giou"
reg_weights: 1.0
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [120000, 160000]
- !LinearWarmup
start_factor: 0.3333333333333333
steps: 500
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
TrainReader:
inputs_def:
fields: ['image', 'gt_bbox', 'gt_class', 'gt_score', 'im_info']
dataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco
with_background: true
sample_transforms:
- !DecodeImage
to_rgb: true
- !RandomFlipImage
prob: 0.5
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
target_size: [640, 672, 704, 736, 768, 800]
max_size: 1333
interp: 1
use_cv2: true
- !Permute
to_bgr: false
channel_first: true
batch_transforms:
- !PadBatch
pad_to_stride: 128
use_padded_im_info: false
- !Gt2FCOSTarget
object_sizes_boundary: [64, 128, 256, 512]
center_sampling_radius: 1.5
downsample_ratios: [8, 16, 32, 64, 128]
norm_reg_targets: True
batch_size: 2
shuffle: true
worker_num: 16
use_process: false
EvalReader:
inputs_def:
fields: ['image', 'im_id', 'im_shape', 'im_info']
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
target_size: 800
max_size: 1333
interp: 1
use_cv2: true
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 128
use_padded_im_info: true
batch_size: 8
shuffle: false
worker_num: 2
use_process: false
TestReader:
inputs_def:
# set image_shape if needed
fields: ['image', 'im_id', 'im_shape', 'im_info']
dataset:
!ImageFolder
anno_path: annotations/instances_val2017.json
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
interp: 1
max_size: 1333
target_size: 800
use_cv2: true
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 128
use_padded_im_info: true
batch_size: 1
shuffle: false
......@@ -30,7 +30,10 @@ from .op_helper import jaccard_overlap
logger = logging.getLogger(__name__)
__all__ = ['PadBatch', 'RandomShape', 'PadMultiScaleTest', 'Gt2YoloTarget']
__all__ = [
'PadBatch', 'RandomShape', 'PadMultiScaleTest', 'Gt2YoloTarget',
'Gt2FCOSTarget'
]
@register_op
......@@ -245,3 +248,203 @@ class Gt2YoloTarget(BaseOperator):
target[best_n, 6 + cls, gj, gi] = 1.
sample['target{}'.format(i)] = target
return samples
@register_op
class Gt2FCOSTarget(BaseOperator):
"""
Generate FCOS targets by groud truth data
"""
def __init__(self,
object_sizes_boundary,
center_sampling_radius,
downsample_ratios,
norm_reg_targets=False):
super(Gt2FCOSTarget, self).__init__()
self.center_sampling_radius = center_sampling_radius
self.downsample_ratios = downsample_ratios
self.INF = np.inf
self.object_sizes_boundary = [-1] + object_sizes_boundary + [self.INF]
object_sizes_of_interest = []
for i in range(len(self.object_sizes_boundary) - 1):
object_sizes_of_interest.append([
self.object_sizes_boundary[i], self.object_sizes_boundary[i + 1]
])
self.object_sizes_of_interest = object_sizes_of_interest
self.norm_reg_targets = norm_reg_targets
def _compute_points(self, w, h):
"""
compute the corresponding points in each feature map
:param h: image height
:param w: image width
:return: points from all feature map
"""
locations = []
for stride in self.downsample_ratios:
shift_x = np.arange(0, w, stride).astype(np.float32)
shift_y = np.arange(0, h, stride).astype(np.float32)
shift_x, shift_y = np.meshgrid(shift_x, shift_y)
shift_x = shift_x.flatten()
shift_y = shift_y.flatten()
location = np.stack([shift_x, shift_y], axis=1) + stride // 2
locations.append(location)
num_points_each_level = [len(location) for location in locations]
locations = np.concatenate(locations, axis=0)
return locations, num_points_each_level
def _convert_xywh2xyxy(self, gt_bbox, w, h):
"""
convert the bounding box from style xywh to xyxy
:param gt_bbox: bounding boxes normalized into [0, 1]
:param w: image width
:param h: image height
:return: bounding boxes in xyxy style
"""
bboxes = gt_bbox.copy()
bboxes[:, [0, 2]] = bboxes[:, [0, 2]] * w
bboxes[:, [1, 3]] = bboxes[:, [1, 3]] * h
bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2]
bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3]
return bboxes
def _check_inside_boxes_limited(self, gt_bbox, xs, ys,
num_points_each_level):
"""
check if points is within the clipped boxes
:param gt_bbox: bounding boxes
:param xs: horizontal coordinate of points
:param ys: vertical coordinate of points
:return: the mask of points is within gt_box or not
"""
bboxes = np.reshape(
gt_bbox, newshape=[1, gt_bbox.shape[0], gt_bbox.shape[1]])
bboxes = np.tile(bboxes, reps=[xs.shape[0], 1, 1])
ct_x = (bboxes[:, :, 0] + bboxes[:, :, 2]) / 2
ct_y = (bboxes[:, :, 1] + bboxes[:, :, 3]) / 2
beg = 0
clipped_box = bboxes.copy()
for lvl, stride in enumerate(self.downsample_ratios):
end = beg + num_points_each_level[lvl]
stride_exp = self.center_sampling_radius * stride
clipped_box[beg:end, :, 0] = np.maximum(
bboxes[beg:end, :, 0], ct_x[beg:end, :] - stride_exp)
clipped_box[beg:end, :, 1] = np.maximum(
bboxes[beg:end, :, 1], ct_y[beg:end, :] - stride_exp)
clipped_box[beg:end, :, 2] = np.minimum(
bboxes[beg:end, :, 2], ct_x[beg:end, :] + stride_exp)
clipped_box[beg:end, :, 3] = np.minimum(
bboxes[beg:end, :, 3], ct_y[beg:end, :] + stride_exp)
beg = end
l_res = xs - clipped_box[:, :, 0]
r_res = clipped_box[:, :, 2] - xs
t_res = ys - clipped_box[:, :, 1]
b_res = clipped_box[:, :, 3] - ys
clipped_box_reg_targets = np.stack([l_res, t_res, r_res, b_res], axis=2)
inside_gt_box = np.min(clipped_box_reg_targets, axis=2) > 0
return inside_gt_box
def __call__(self, samples, context=None):
assert len(self.object_sizes_of_interest) == len(self.downsample_ratios), \
"object_sizes_of_interest', and 'downsample_ratios' should have same length."
for sample in samples:
# im, gt_bbox, gt_class, gt_score = sample
im = sample['image']
im_info = sample['im_info']
bboxes = sample['gt_bbox']
gt_class = sample['gt_class']
gt_score = sample['gt_score']
bboxes[:, [0, 2]] = bboxes[:, [0, 2]] * np.floor(im_info[1]) / \
np.floor(im_info[1] / im_info[2])
bboxes[:, [1, 3]] = bboxes[:, [1, 3]] * np.floor(im_info[0]) / \
np.floor(im_info[0] / im_info[2])
# calculate the locations
h, w = sample['image'].shape[1:3]
points, num_points_each_level = self._compute_points(w, h)
object_scale_exp = []
for i, num_pts in enumerate(num_points_each_level):
object_scale_exp.append(
np.tile(
np.array([self.object_sizes_of_interest[i]]),
reps=[num_pts, 1]))
object_scale_exp = np.concatenate(object_scale_exp, axis=0)
gt_area = (bboxes[:, 2] - bboxes[:, 0]) * (
bboxes[:, 3] - bboxes[:, 1])
xs, ys = points[:, 0], points[:, 1]
xs = np.reshape(xs, newshape=[xs.shape[0], 1])
xs = np.tile(xs, reps=[1, bboxes.shape[0]])
ys = np.reshape(ys, newshape=[ys.shape[0], 1])
ys = np.tile(ys, reps=[1, bboxes.shape[0]])
l_res = xs - bboxes[:, 0]
r_res = bboxes[:, 2] - xs
t_res = ys - bboxes[:, 1]
b_res = bboxes[:, 3] - ys
reg_targets = np.stack([l_res, t_res, r_res, b_res], axis=2)
if self.center_sampling_radius > 0:
is_inside_box = self._check_inside_boxes_limited(
bboxes, xs, ys, num_points_each_level)
else:
is_inside_box = np.min(reg_targets, axis=2) > 0
# check if the targets is inside the corresponding level
max_reg_targets = np.max(reg_targets, axis=2)
lower_bound = np.tile(
np.expand_dims(
object_scale_exp[:, 0], axis=1),
reps=[1, max_reg_targets.shape[1]])
high_bound = np.tile(
np.expand_dims(
object_scale_exp[:, 1], axis=1),
reps=[1, max_reg_targets.shape[1]])
is_match_current_level = \
(max_reg_targets > lower_bound) & \
(max_reg_targets < high_bound)
points2gtarea = np.tile(
np.expand_dims(
gt_area, axis=0), reps=[xs.shape[0], 1])
points2gtarea[is_inside_box == 0] = self.INF
points2gtarea[is_match_current_level == 0] = self.INF
points2min_area = points2gtarea.min(axis=1)
points2min_area_ind = points2gtarea.argmin(axis=1)
labels = gt_class[points2min_area_ind]
labels[points2min_area == self.INF] = 0
reg_targets = reg_targets[range(xs.shape[0]), points2min_area_ind]
ctn_targets = np.sqrt((reg_targets[:, [0, 2]].min(axis=1) / \
reg_targets[:, [0, 2]].max(axis=1)) * \
(reg_targets[:, [1, 3]].min(axis=1) / \
reg_targets[:, [1, 3]].max(axis=1))).astype(np.float32)
ctn_targets = np.reshape(
ctn_targets, newshape=[ctn_targets.shape[0], 1])
ctn_targets[labels <= 0] = 0
pos_ind = np.nonzero(labels != 0)
reg_targets_pos = reg_targets[pos_ind[0], :]
split_sections = []
beg = 0
for lvl in range(len(num_points_each_level)):
end = beg + num_points_each_level[lvl]
split_sections.append(end)
beg = end
labels_by_level = np.split(labels, split_sections, axis=0)
reg_targets_by_level = np.split(reg_targets, split_sections, axis=0)
ctn_targets_by_level = np.split(ctn_targets, split_sections, axis=0)
for lvl in range(len(self.downsample_ratios)):
grid_w = int(np.ceil(w / self.downsample_ratios[lvl]))
grid_h = int(np.ceil(h / self.downsample_ratios[lvl]))
if self.norm_reg_targets:
sample['reg_target{}'.format(lvl)] = \
np.reshape(
reg_targets_by_level[lvl] / \
self.downsample_ratios[lvl],
newshape=[grid_h, grid_w, 4])
else:
sample['reg_target{}'.format(lvl)] = np.reshape(
reg_targets_by_level[lvl],
newshape=[grid_h, grid_w, 4])
sample['labels{}'.format(lvl)] = np.reshape(
labels_by_level[lvl], newshape=[grid_h, grid_w, 1])
sample['centerness{}'.format(lvl)] = np.reshape(
ctn_targets_by_level[lvl], newshape=[grid_h, grid_w, 1])
return samples
......@@ -17,7 +17,9 @@ from __future__ import absolute_import
from . import rpn_head
from . import yolo_head
from . import retina_head
from . import fcos_head
from .rpn_head import *
from .yolo_head import *
from .retina_head import *
from .fcos_head import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from paddle.fluid.regularizer import L2Decay
from ppdet.modeling.ops import ConvNorm
from ppdet.modeling.ops import MultiClassNMS
from ppdet.core.workspace import register
__all__ = ['FCOSHead']
@register
class FCOSHead(object):
"""
FCOSHead
Args:
num_classes (int): Number of classes
fpn_stride (list): The stride of each FPN Layer
prior_prob (float): Used to set the bias init for the class prediction layer
num_convs (int): The layer number in fcos head
norm_type (str): Normalization type, 'bn'/'sync_bn'/'affine_channel'
fcos_loss (object): Instance of 'FCOSLoss'
norm_reg_targets (bool): Normalization the regression target if true
centerness_on_reg(bool): The prediction of centerness on regression or clssification branch
use_dcn_in_tower (bool): Ues deformable conv on FCOSHead if true
nms (object): Instance of 'MultiClassNMS'
"""
__inject__ = ['fcos_loss', 'nms']
__shared__ = ['num_classes']
def __init__(self,
num_classes=81,
fpn_stride=[8, 16, 32, 64, 128],
prior_prob=0.01,
num_convs=4,
norm_type="gn",
fcos_loss=None,
norm_reg_targets=False,
centerness_on_reg=False,
use_dcn_in_tower=False,
nms=MultiClassNMS(
score_threshold=0.01,
nms_top_k=1000,
keep_top_k=100,
nms_threshold=0.45,
background_label=-1).__dict__):
self.num_classes = num_classes - 1
self.fpn_stride = fpn_stride[::-1]
self.prior_prob = prior_prob
self.num_convs = num_convs
self.norm_reg_targets = norm_reg_targets
self.centerness_on_reg = centerness_on_reg
self.use_dcn_in_tower = use_dcn_in_tower
self.norm_type = norm_type
self.fcos_loss = fcos_loss
self.batch_size = 8
self.nms = nms
if isinstance(nms, dict):
self.nms = MultiClassNMS(**nms)
def _fcos_head(self, features, fpn_stride, fpn_scale, is_training=False):
"""
Args:
features (Variables): feature map from FPN
fpn_stride (int): the stride of current feature map
is_training (bool): whether is train or test mode
"""
subnet_blob_cls = features
subnet_blob_reg = features
in_channles = features.shape[1]
for lvl in range(0, self.num_convs):
conv_cls_name = 'fcos_head_cls_tower_conv_{}'.format(lvl)
subnet_blob_cls = ConvNorm(
input=subnet_blob_cls,
num_filters=in_channles,
filter_size=3,
stride=1,
norm_type=self.norm_type,
act='relu',
initializer=Normal(
loc=0., scale=0.01),
bias_attr=True,
norm_name=conv_cls_name + "_norm",
name=conv_cls_name)
conv_reg_name = 'fcos_head_reg_tower_conv_{}'.format(lvl)
subnet_blob_reg = ConvNorm(
input=subnet_blob_reg,
num_filters=in_channles,
filter_size=3,
stride=1,
norm_type=self.norm_type,
act='relu',
initializer=Normal(
loc=0., scale=0.01),
bias_attr=True,
norm_name=conv_reg_name + "_norm",
name=conv_reg_name)
conv_cls_name = "fcos_head_cls"
bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)
cls_logits = fluid.layers.conv2d(
input=subnet_blob_cls,
num_filters=self.num_classes,
filter_size=3,
stride=1,
padding=1,
param_attr=ParamAttr(
name=conv_cls_name + "_weights",
initializer=Normal(
loc=0., scale=0.01)),
bias_attr=ParamAttr(
name=conv_cls_name + "_bias",
initializer=Constant(value=bias_init_value)),
name=conv_cls_name)
conv_reg_name = "fcos_head_reg"
bbox_reg = fluid.layers.conv2d(
input=subnet_blob_reg,
num_filters=4,
filter_size=3,
stride=1,
padding=1,
param_attr=ParamAttr(
name=conv_reg_name + "_weights",
initializer=Normal(
loc=0., scale=0.01)),
bias_attr=ParamAttr(
name=conv_reg_name + "_bias", initializer=Constant(value=0)),
name=conv_reg_name)
bbox_reg = bbox_reg * fpn_scale
if self.norm_reg_targets:
bbox_reg = fluid.layers.relu(bbox_reg)
if not is_training:
bbox_reg = bbox_reg * fpn_stride
else:
bbox_reg = fluid.layers.exp(bbox_reg)
conv_centerness_name = "fcos_head_centerness"
if self.centerness_on_reg:
subnet_blob_ctn = subnet_blob_reg
else:
subnet_blob_ctn = subnet_blob_cls
centerness = fluid.layers.conv2d(
input=subnet_blob_ctn,
num_filters=1,
filter_size=3,
stride=1,
padding=1,
param_attr=ParamAttr(
name=conv_centerness_name + "_weights",
initializer=Normal(
loc=0., scale=0.01)),
bias_attr=ParamAttr(
name=conv_centerness_name + "_bias",
initializer=Constant(value=0)),
name=conv_centerness_name)
return cls_logits, bbox_reg, centerness
def _get_output(self, body_feats, is_training=False):
"""
Args:
body_feates (list): the list of fpn feature maps
is_training (bool): whether is train or test mode
Return:
cls_logits (Variables): prediction for classification
bboxes_reg (Variables): prediction for bounding box
centerness (Variables): prediction for ceterness
"""
cls_logits = []
bboxes_reg = []
centerness = []
assert len(body_feats) == len(self.fpn_stride), \
"The size of body_feats is not equal to size of fpn_stride"
for fpn_name, fpn_stride in zip(body_feats, self.fpn_stride):
features = body_feats[fpn_name]
scale = fluid.layers.create_parameter(
shape=[1, ],
dtype="float32",
name="%s_scale_on_reg" % fpn_name,
default_initializer=fluid.initializer.Constant(1.))
cls_pred, bbox_pred, ctn_pred = self._fcos_head(
features, fpn_stride, scale, is_training=is_training)
cls_logits.append(cls_pred)
bboxes_reg.append(bbox_pred)
centerness.append(ctn_pred)
return cls_logits, bboxes_reg, centerness
def _compute_locations(self, features):
"""
Args:
features (list): List of Variables for FPN feature maps
Return:
Anchor points for each feature map pixel
"""
locations = []
for lvl, fpn_name in enumerate(features):
feature = features[fpn_name]
shape_fm = fluid.layers.shape(feature)
shape_fm.stop_gradient = True
h = shape_fm[2]
w = shape_fm[3]
fpn_stride = self.fpn_stride[lvl]
shift_x = fluid.layers.range(
0, w * fpn_stride, fpn_stride, dtype='float32')
shift_y = fluid.layers.range(
0, h * fpn_stride, fpn_stride, dtype='float32')
shift_x = fluid.layers.unsqueeze(shift_x, axes=[0])
shift_y = fluid.layers.unsqueeze(shift_y, axes=[1])
shift_x = fluid.layers.expand_as(
shift_x, target_tensor=feature[0, 0, :, :])
shift_y = fluid.layers.expand_as(
shift_y, target_tensor=feature[0, 0, :, :])
shift_x.stop_gradient = True
shift_y.stop_gradient = True
shift_x = fluid.layers.reshape(shift_x, shape=[-1])
shift_y = fluid.layers.reshape(shift_y, shape=[-1])
location = fluid.layers.stack(
[shift_x, shift_y], axis=-1) + fpn_stride // 2
location.stop_gradient = True
locations.append(location)
return locations
def __merge_hw(self, input, ch_type="channel_first"):
"""
Args:
input (Variables): Feature map whose H and W will be merged into one dimension
ch_type (str): channel_first / channel_last
Return:
new_shape (Variables): The new shape after h and w merged into one dimension
"""
shape_ = fluid.layers.shape(input)
bs = shape_[0]
ch = shape_[1]
hi = shape_[2]
wi = shape_[3]
img_size = hi * wi
img_size.stop_gradient = True
if ch_type == "channel_first":
new_shape = fluid.layers.concat([bs, ch, img_size])
elif ch_type == "channel_last":
new_shape = fluid.layers.concat([bs, img_size, ch])
else:
raise KeyError("Wrong ch_type %s" % ch_type)
new_shape.stop_gradient = True
return new_shape
def _postprocessing_by_level(self, locations, box_cls, box_reg, box_ctn,
im_info):
"""
Args:
locations (Variables): anchor points for current layer
box_cls (Variables): categories prediction
box_reg (Variables): bounding box prediction
box_ctn (Variables): centerness prediction
im_info (Variables): [h, w, scale] for input images
Return:
box_cls_ch_last (Variables): score for each category, in [N, C, M]
C is the number of classes and M is the number of anchor points
box_reg_decoding (Variables): decoded bounding box, in [N, M, 4]
last dimension is [x1, y1, x2, y2]
"""
act_shape_cls = self.__merge_hw(box_cls)
box_cls_ch_last = fluid.layers.reshape(
x=box_cls,
shape=[self.batch_size, self.num_classes, -1],
actual_shape=act_shape_cls)
box_cls_ch_last = fluid.layers.sigmoid(box_cls_ch_last)
act_shape_reg = self.__merge_hw(box_reg, "channel_last")
box_reg_ch_last = fluid.layers.transpose(box_reg, perm=[0, 2, 3, 1])
box_reg_ch_last = fluid.layers.reshape(
x=box_reg_ch_last,
shape=[self.batch_size, -1, 4],
actual_shape=act_shape_reg)
act_shape_ctn = self.__merge_hw(box_ctn)
box_ctn_ch_last = fluid.layers.reshape(
x=box_ctn,
shape=[self.batch_size, 1, -1],
actual_shape=act_shape_ctn)
box_ctn_ch_last = fluid.layers.sigmoid(box_ctn_ch_last)
box_reg_decoding = fluid.layers.stack(
[
locations[:, 0] - box_reg_ch_last[:, :, 0],
locations[:, 1] - box_reg_ch_last[:, :, 1],
locations[:, 0] + box_reg_ch_last[:, :, 2],
locations[:, 1] + box_reg_ch_last[:, :, 3]
],
axis=1)
box_reg_decoding = fluid.layers.transpose(
box_reg_decoding, perm=[0, 2, 1])
# recover the location to original image
im_scale = im_info[:, 2]
box_reg_decoding = box_reg_decoding / im_scale
box_cls_ch_last = box_cls_ch_last * box_ctn_ch_last
return box_cls_ch_last, box_reg_decoding
def _post_processing(self, locations, cls_logits, bboxes_reg, centerness,
im_info):
"""
Args:
locations (list): List of Variables composed by center of each anchor point
cls_logits (list): List of Variables for class prediction
bboxes_reg (list): List of Variables for bounding box prediction
centerness (list): List of Variables for centerness prediction
im_info(Variables): [h, w, scale] for input images
Return:
pred (LoDTensor): predicted bounding box after nms,
the shape is n x 6, last dimension is [label, score, xmin, ymin, xmax, ymax]
"""
pred_boxes_ = []
pred_scores_ = []
for _, (
pts, cls, box, ctn
) in enumerate(zip(locations, cls_logits, bboxes_reg, centerness)):
pred_scores_lvl, pred_boxes_lvl = self._postprocessing_by_level(
pts, cls, box, ctn, im_info)
pred_boxes_.append(pred_boxes_lvl)
pred_scores_.append(pred_scores_lvl)
pred_boxes = fluid.layers.concat(pred_boxes_, axis=1)
pred_scores = fluid.layers.concat(pred_scores_, axis=2)
pred = self.nms(pred_boxes, pred_scores)
return pred
def get_loss(self, input, tag_labels, tag_bboxes, tag_centerness):
"""
Calculate the loss for FCOS
Args:
input (list): List of Variables for feature maps from FPN layers
tag_labels (Variables): category targets for each anchor point
tag_bboxes (Variables): bounding boxes targets for positive samples
tag_centerness (Variables): centerness targets for positive samples
Return:
loss (dict): loss composed by classification loss, bounding box
regression loss and centerness regression loss
"""
cls_logits, bboxes_reg, centerness = self._get_output(
input, is_training=True)
loss = self.fcos_loss(cls_logits, bboxes_reg, centerness, tag_labels,
tag_bboxes, tag_centerness)
return loss
def get_prediction(self, input, im_info):
"""
Decode the prediction
Args:
input (list): List of Variables for feature maps from FPN layers
im_info(Variables): [h, w, scale] for input images
Return:
the bounding box prediction
"""
cls_logits, bboxes_reg, centerness = self._get_output(
input, is_training=False)
locations = self._compute_locations(input)
pred = self._post_processing(locations, cls_logits, bboxes_reg,
centerness, im_info)
return {"bbox": pred}
......@@ -24,6 +24,7 @@ from . import ssd
from . import retinanet
from . import blazeface
from . import faceboxes
from . import fcos
from .faster_rcnn import *
from .mask_rcnn import *
......@@ -35,3 +36,4 @@ from .ssd import *
from .retinanet import *
from .blazeface import *
from .faceboxes import *
from .fcos import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import paddle.fluid as fluid
from ppdet.experimental import mixed_precision_global_state
from ppdet.core.workspace import register
__all__ = ['FCOS']
@register
class FCOS(object):
"""
FCOS architecture, see https://arxiv.org/abs/1904.01355
Args:
backbone (object): backbone instance
fpn (object): feature pyramid network instance
fcos_head (object): `FCOSHead` instance
"""
__category__ = 'architecture'
__inject__ = ['backbone', 'fpn', 'fcos_head']
def __init__(self, backbone, fpn, fcos_head):
super(FCOS, self).__init__()
self.backbone = backbone
self.fpn = fpn
self.fcos_head = fcos_head
def build(self, feed_vars, mode='train'):
im = feed_vars['image']
im_info = feed_vars['im_info']
mixed_precision_enabled = mixed_precision_global_state() is not None
# cast inputs to FP16
if mixed_precision_enabled:
im = fluid.layers.cast(im, 'float16')
# backbone
body_feats = self.backbone(im)
# cast features back to FP32
if mixed_precision_enabled:
body_feats = OrderedDict((k, fluid.layers.cast(v, 'float32'))
for k, v in body_feats.items())
# FPN
body_feats, spatial_scale = self.fpn.get_output(body_feats)
# fcosnet head
if mode == 'train':
tag_labels = []
tag_bboxes = []
tag_centerness = []
for i in range(len(self.fcos_head.fpn_stride)):
# reg_target, labels, scores, centerness
k_lbl = 'labels{}'.format(i)
if k_lbl in feed_vars:
tag_labels.append(feed_vars[k_lbl])
k_box = 'reg_target{}'.format(i)
if k_box in feed_vars:
tag_bboxes.append(feed_vars[k_box])
k_ctn = 'centerness{}'.format(i)
if k_ctn in feed_vars:
tag_centerness.append(feed_vars[k_ctn])
# tag_labels, tag_bboxes, tag_centerness
loss = self.fcos_head.get_loss(body_feats, tag_labels, tag_bboxes,
tag_centerness)
total_loss = fluid.layers.sum(list(loss.values()))
loss.update({'loss': total_loss})
return loss
else:
pred = self.fcos_head.get_prediction(body_feats, im_info)
return pred
def _inputs_def(self, image_shape, fields):
im_shape = [None] + image_shape
# yapf: disable
inputs_def = {
'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0},
'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
'im_info': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0},
'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1},
'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
'gt_score': {'shape': [None, 1], 'dtype': 'float32', 'lod_level': 1},
'is_crowd': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
'is_difficult': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}
}
# yapf: disable
if 'gt_bbox' in fields:
targets_def = {
'labels0': {'shape': [None, None, None, 1], 'dtype': 'int32', 'lod_level': 0},
'reg_target0': {'shape': [None, None, None, 4], 'dtype': 'float32', 'lod_level': 0},
'centerness0': {'shape': [None, None, None, 1], 'dtype': 'float32', 'lod_level': 0},
'labels1': {'shape': [None, None, None, 1], 'dtype': 'int32', 'lod_level': 0},
'reg_target1': {'shape': [None, None, None, 4], 'dtype': 'float32', 'lod_level': 0},
'centerness1': {'shape': [None, None, None, 1], 'dtype': 'float32', 'lod_level': 0},
'labels2': {'shape': [None, None, None, 1], 'dtype': 'int32', 'lod_level': 0},
'reg_target2': {'shape': [None, None, None, 4], 'dtype': 'float32', 'lod_level': 0},
'centerness2': {'shape': [None, None, None, 1], 'dtype': 'float32', 'lod_level': 0},
'labels3': {'shape': [None, None, None, 1], 'dtype': 'int32', 'lod_level': 0},
'reg_target3': {'shape': [None, None, None, 4], 'dtype': 'float32', 'lod_level': 0},
'centerness3': {'shape': [None, None, None, 1], 'dtype': 'float32', 'lod_level': 0},
'labels4': {'shape': [None, None, None, 1], 'dtype': 'int32', 'lod_level': 0},
'reg_target4': {'shape': [None, None, None, 4], 'dtype': 'float32', 'lod_level': 0},
'centerness4': {'shape': [None, None, None, 1], 'dtype': 'float32', 'lod_level': 0},
}
# yapf: enable
# downsample = 128
for k, stride in enumerate(self.fcos_head.fpn_stride):
k_lbl = 'labels{}'.format(k)
k_box = 'reg_target{}'.format(k)
k_ctn = 'centerness{}'.format(k)
grid_y = image_shape[-2] // stride if image_shape[-2] else None
grid_x = image_shape[-1] // stride if image_shape[-1] else None
if grid_x is not None:
num_pts = grid_x * grid_y
num_dim2 = 1
else:
num_pts = None
num_dim2 = None
targets_def[k_lbl]['shape'][1] = num_pts
targets_def[k_box]['shape'][1] = num_pts
targets_def[k_ctn]['shape'][1] = num_pts
targets_def[k_lbl]['shape'][2] = num_dim2
targets_def[k_box]['shape'][2] = num_dim2
targets_def[k_ctn]['shape'][2] = num_dim2
inputs_def.update(targets_def)
return inputs_def
def build_inputs(
self,
image_shape=[3, None, None],
fields=[
'image', 'im_shape', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd'
], # for-train
use_dataloader=True,
iterable=False):
inputs_def = self._inputs_def(image_shape, fields)
if "gt_bbox" in fields:
for i in range(len(self.fcos_head.fpn_stride)):
fields.extend(
['labels%d' % i, 'reg_target%d' % i, 'centerness%d' % i])
feed_vars = OrderedDict([(key, fluid.layers.data(
name=key,
shape=inputs_def[key]['shape'],
dtype=inputs_def[key]['dtype'],
lod_level=inputs_def[key]['lod_level'])) for key in fields])
loader = fluid.io.DataLoader.from_generator(
feed_list=list(feed_vars.values()),
capacity=16,
use_double_buffer=True,
iterable=iterable) if use_dataloader else None
return feed_vars, loader
def train(self, feed_vars):
return self.build(feed_vars, 'train')
def eval(self, feed_vars):
return self.build(feed_vars, 'test')
def test(self, feed_vars):
return self.build(feed_vars, 'test')
......@@ -51,7 +51,8 @@ class FPN(object):
spatial_scale=[1. / 32., 1. / 16., 1. / 8., 1. / 4.],
has_extra_convs=False,
norm_type=None,
freeze_norm=False):
freeze_norm=False,
use_c5=True):
self.freeze_norm = freeze_norm
self.num_chan = num_chan
self.min_level = min_level
......@@ -59,6 +60,7 @@ class FPN(object):
self.spatial_scale = spatial_scale
self.has_extra_convs = has_extra_convs
self.norm_type = norm_type
self.use_c5 = use_c5
def _add_topdown_lateral(self, body_name, body_input, upper_output):
lateral_name = 'fpn_inner_' + body_name + '_lateral'
......@@ -189,7 +191,10 @@ class FPN(object):
# Coarser FPN levels introduced for RetinaNet
highest_backbone_level = self.min_level + len(spatial_scale) - 1
if self.has_extra_convs and self.max_level > highest_backbone_level:
fpn_blob = body_dict[body_name_list[0]]
if self.use_c5:
fpn_blob = body_dict[body_name_list[0]]
else:
fpn_blob = fpn_dict[fpn_name_list[0]]
for i in range(highest_backbone_level + 1, self.max_level + 1):
fpn_blob_in = fpn_blob
fpn_name = 'fpn_' + str(i)
......
......@@ -20,6 +20,7 @@ from . import giou_loss
from . import diou_loss
from . import iou_loss
from . import balanced_l1_loss
from . import fcos_loss
from .yolo_loss import *
from .smooth_l1_loss import *
......@@ -27,3 +28,4 @@ from .giou_loss import *
from .diou_loss import *
from .iou_loss import *
from .balanced_l1_loss import *
from .fcos_loss import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from ppdet.core.workspace import register, serializable
INF = 1e8
__all__ = ['FCOSLoss']
@register
@serializable
class FCOSLoss(object):
"""
FCOSLoss
Args:
loss_alpha (float): alpha in focal loss
loss_gamma (float): gamma in focal loss
iou_loss_type(str): location loss type, IoU/GIoU/LINEAR_IoU
reg_weights(float): weight for location loss
"""
def __init__(self,
loss_alpha=0.25,
loss_gamma=2.0,
iou_loss_type="IoU",
reg_weights=1.0):
self.loss_alpha = loss_alpha
self.loss_gamma = loss_gamma
self.iou_loss_type = iou_loss_type
self.reg_weights = reg_weights
def __flatten_tensor(self, input, channel_first=False):
"""
Flatten a Tensor
Args:
input (Variables): Input Tensor
channel_first(bool): if true the dimension order of
Tensor is [N, C, H, W], otherwise is [N, H, W, C]
Return:
input_channel_last (Variables): The flattened Tensor in channel_last style
"""
if channel_first:
input_channel_last = fluid.layers.transpose(
input, perm=[0, 2, 3, 1])
else:
input_channel_last = input
input_channel_last = fluid.layers.flatten(input_channel_last, axis=3)
return input_channel_last
def __iou_loss(self, pred, targets, positive_mask, weights=None):
"""
Calculate the loss for location prediction
Args:
pred (Variables): bounding boxes prediction
targets (Variables): targets for positive samples
positive_mask (Variables): mask of positive samples
weights (Variables): weights for each positive samples
Return:
loss (Varialbes): location loss
"""
plw = pred[:, 0] * positive_mask
pth = pred[:, 1] * positive_mask
prw = pred[:, 2] * positive_mask
pbh = pred[:, 3] * positive_mask
tlw = targets[:, 0] * positive_mask
tth = targets[:, 1] * positive_mask
trw = targets[:, 2] * positive_mask
tbh = targets[:, 3] * positive_mask
tlw.stop_gradient = True
trw.stop_gradient = True
tth.stop_gradient = True
tbh.stop_gradient = True
area_target = (tlw + trw) * (tth + tbh)
area_predict = (plw + prw) * (pth + pbh)
ilw = fluid.layers.elementwise_min(plw, tlw)
irw = fluid.layers.elementwise_min(prw, trw)
ith = fluid.layers.elementwise_min(pth, tth)
ibh = fluid.layers.elementwise_min(pbh, tbh)
clw = fluid.layers.elementwise_max(plw, tlw)
crw = fluid.layers.elementwise_max(prw, trw)
cth = fluid.layers.elementwise_max(pth, tth)
cbh = fluid.layers.elementwise_max(pbh, tbh)
area_inter = (ilw + irw) * (ith + ibh)
ious = (area_inter + 1.0) / (
area_predict + area_target - area_inter + 1.0)
ious = ious * positive_mask
if self.iou_loss_type.lower() == "linear_iou":
loss = 1.0 - ious
elif self.iou_loss_type.lower() == "giou":
area_uniou = area_predict + area_target - area_inter
area_circum = (clw + crw) * (cth + cbh) + 1e-7
giou = ious - (area_circum - area_uniou) / area_circum
loss = 1.0 - giou
elif self.iou_loss_type.lower() == "iou":
loss = 0.0 - fluid.layers.log(ious)
else:
raise KeyError
if weights is not None:
loss = loss * weights
return loss
def __call__(self, cls_logits, bboxes_reg, centerness, tag_labels,
tag_bboxes, tag_center):
"""
Calculate the loss for classification, location and centerness
Args:
cls_logits (list): list of Variables, which is predicted
score for all anchor points with shape [N, M, C]
bboxes_reg (list): list of Variables, which is predicted
offsets for all anchor points with shape [N, M, 4]
centerness (list): list of Variables, which is predicted
centerness for all anchor points with shape [N, M, 1]
tag_labels (list): list of Variables, which is category
targets for each anchor point
tag_bboxes (list): list of Variables, which is bounding
boxes targets for positive samples
tag_center (list): list of Variables, which is centerness
targets for positive samples
Return:
loss (dict): loss composed by classification loss, bounding box
"""
cls_logits_flatten_list = []
bboxes_reg_flatten_list = []
centerness_flatten_list = []
tag_labels_flatten_list = []
tag_bboxes_flatten_list = []
tag_center_flatten_list = []
num_lvl = len(cls_logits)
for lvl in range(num_lvl):
cls_logits_flatten_list.append(
self.__flatten_tensor(cls_logits[num_lvl - 1 - lvl], True))
bboxes_reg_flatten_list.append(
self.__flatten_tensor(bboxes_reg[num_lvl - 1 - lvl], True))
centerness_flatten_list.append(
self.__flatten_tensor(centerness[num_lvl - 1 - lvl], True))
tag_labels_flatten_list.append(
self.__flatten_tensor(tag_labels[lvl], False))
tag_bboxes_flatten_list.append(
self.__flatten_tensor(tag_bboxes[lvl], False))
tag_center_flatten_list.append(
self.__flatten_tensor(tag_center[lvl], False))
cls_logits_flatten = fluid.layers.concat(
cls_logits_flatten_list, axis=0)
bboxes_reg_flatten = fluid.layers.concat(
bboxes_reg_flatten_list, axis=0)
centerness_flatten = fluid.layers.concat(
centerness_flatten_list, axis=0)
tag_labels_flatten = fluid.layers.concat(
tag_labels_flatten_list, axis=0)
tag_bboxes_flatten = fluid.layers.concat(
tag_bboxes_flatten_list, axis=0)
tag_center_flatten = fluid.layers.concat(
tag_center_flatten_list, axis=0)
tag_labels_flatten.stop_gradient = True
tag_bboxes_flatten.stop_gradient = True
tag_center_flatten.stop_gradient = True
mask_positive = tag_labels_flatten > 0
mask_positive.stop_gradient = True
mask_positive_float = fluid.layers.cast(mask_positive, dtype="float32")
mask_positive_float.stop_gradient = True
num_positive_fp32 = fluid.layers.reduce_sum(mask_positive_float)
num_positive_int32 = fluid.layers.cast(num_positive_fp32, dtype="int32")
num_positive_int32 = num_positive_int32 * 0 + 1
num_positive_fp32.stop_gradient = True
num_positive_int32.stop_gradient = True
normalize_sum = fluid.layers.sum(tag_center_flatten)
normalize_sum.stop_gradient = True
normalize_sum = fluid.layers.reduce_sum(mask_positive_float *
normalize_sum)
normalize_sum.stop_gradient = True
cls_loss = fluid.layers.sigmoid_focal_loss(
cls_logits_flatten, tag_labels_flatten,
num_positive_int32) / num_positive_fp32
reg_loss = self.__iou_loss(
bboxes_reg_flatten, tag_bboxes_flatten, mask_positive_float,
tag_center_flatten) * mask_positive_float / normalize_sum
ctn_loss = fluid.layers.sigmoid_cross_entropy_with_logits(
x=centerness_flatten,
label=tag_center_flatten) * mask_positive_float / num_positive_fp32
loss_all = {
"loss_centerness": fluid.layers.reduce_sum(ctn_loss),
"loss_cls": fluid.layers.reduce_sum(cls_loss),
"loss_box": fluid.layers.reduce_sum(reg_loss)
}
return loss_all
......@@ -77,11 +77,6 @@ class IouLoss(object):
xkis2 = fluid.layers.elementwise_min(x2, x2g)
ykis2 = fluid.layers.elementwise_min(y2, y2g)
xc1 = fluid.layers.elementwise_min(x1, x1g)
yc1 = fluid.layers.elementwise_min(y1, y1g)
xc2 = fluid.layers.elementwise_max(x2, x2g)
yc2 = fluid.layers.elementwise_max(y2, y2g)
intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
intsctk = intsctk * fluid.layers.greater_than(
xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1)
......
......@@ -46,8 +46,16 @@ def ConvNorm(input,
act=None,
norm_name=None,
initializer=None,
bias_attr=False,
name=None):
fan = num_filters
if bias_attr:
bias_para = ParamAttr(
name=name + "_bias",
initializer=fluid.initializer.Constant(value=0),
learning_rate=lr_scale * 2)
else:
bias_para = False
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
......@@ -61,7 +69,7 @@ def ConvNorm(input,
name=name + "_weights",
initializer=initializer,
learning_rate=lr_scale),
bias_attr=False,
bias_attr=bias_para,
name=name + '.conv2d.output.1')
norm_lr = 0. if freeze_norm else 1.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册