未验证 提交 38d420bb 编写于 作者: S sunxl1988 提交者: GitHub

test=master add htc model (#1081)

add htc model 
上级 64a2a78e
# Hybrid Task Cascade for Instance Segmentation
## Introduction
We provide config files to reproduce the results in the CVPR 2019 paper for [Hybrid Task Cascade](https://arxiv.org/abs/1901.07518).
```
@inproceedings{chen2019hybrid,
title={Hybrid task cascade for instance segmentation},
author={Chen, Kai and Pang, Jiangmiao and Wang, Jiaqi and Xiong, Yu and Li, Xiaoxiao and Sun, Shuyang and Feng, Wansen and Liu, Ziwei and Shi, Jianping and Ouyang, Wanli and Chen Change Loy and Dahua Lin},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
year={2019}
}
```
## Dataset
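HTC requires both the COCO and the COCO-stuff datasets for training. The data loader looks for the COCO-stuff semantic maps as PNG files under `stuffthingmaps/train2017/` inside the dataset directory (`dataset/coco` in the provided config).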
## Results and Models
The results on COCO val2017 are shown in the table below (results on test-dev are usually slightly higher than on val).
| Backbone | Lr schd | Inf time (fps) | box AP | mask AP | Download |
|:---------:|:-------:|:--------------:|:------:|:-------:|:--------:|
| R-50-FPN | 1x | 11 | 42.2 | 36.5 | [model](https://paddlemodels.bj.bcebos.com/object_detection/htc_r50_fpn_1x.pdparams) |
# Global run settings for the HTC ResNet-50 FPN "1x" configuration.
architecture: HybridTaskCascade
use_gpu: true
# Total iteration budget; the LearningRate milestones (60000, 80000) fall inside it.
max_iters: 100000
snapshot_iter: 10000
log_smooth_window: 50
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
metric: COCO
weights: output/htc_r50_fpn_1x/model_final
# HTCBBoxHead and HTCMaskHead both list num_classes as a shared value
# (their own constructor default is also 81).
num_classes: 81
# Component wiring for the architecture. Each key here is one of the names in
# HybridTaskCascade.__inject__. 'semantic_roi_extractor' is not set here, so the
# constructor default ('RoIAlign') is what applies.
HybridTaskCascade:
backbone: ResNet
fpn: FPN
rpn_head: FPNRPNHead
roi_extractor: FPNRoIAlign
bbox_head: HTCBBoxHead
bbox_assigner: CascadeBBoxAssigner
mask_assigner: MaskAssigner
mask_head: HTCMaskHead
fused_semantic_head: FusedSemanticHead
# Backbone settings (depth 50).
ResNet:
depth: 50
feature_maps: [2, 3, 4, 5]
freeze_at: 2
norm_type: affine_channel
# Feature pyramid settings.
FPN:
max_level: 6
min_level: 2
num_chan: 256
# Numerically these are 1/32, 1/16, 1/8 and 1/4.
spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
# Region proposal head settings.
FPNRPNHead:
anchor_generator:
aspect_ratios: [0.5, 1.0, 2.0]
variance: [1.0, 1.0, 1.0, 1.0]
anchor_start_size: 32
max_level: 6
min_level: 2
num_chan: 256
rpn_target_assign:
rpn_batch_size_per_im: 256
rpn_fg_fraction: 0.5
rpn_negative_overlap: 0.3
rpn_positive_overlap: 0.7
rpn_straddle_thresh: 0.0
# Training and testing use the same NMS threshold and pre-NMS count;
# they differ only in post_nms_top_n (2000 vs 1000).
train_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 2000
post_nms_top_n: 2000
test_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 2000
post_nms_top_n: 1000
# bbox roi extractor
FPNRoIAlign:
# NOTE(review): 'canconical_level' looks like a misspelling of 'canonical_level',
# but it is a key read by the component -- confirm the parameter name before renaming.
canconical_level: 4
canonical_size: 224
max_level: 5
min_level: 2
sampling_ratio: 2
box_resolution: 7
mask_resolution: 14
# semantic roi extractor
# Its resolution (14) equals mask_resolution above. HybridTaskCascade.build sums the
# semantic RoI features with the mask features directly, and applies a 2x2 / stride-2
# pool to them before summing with the box features.
RoIAlign:
resolution: 14
sampling_ratio: 2
HTCMaskHead:
dilation: 1
conv_dim: 256
num_convs: 4
# Same value as MaskAssigner.resolution below.
resolution: 28
lr_ratio: 2.0
FusedSemanticHead:
semantic_num_class: 183
CascadeBBoxAssigner:
batch_size_per_im: 512
# HybridTaskCascade.__init__ unpacks this into exactly three per-stage weights w
# and builds [1/w, 1/w, 2/w, 2/w] for each stage.
bbox_reg_weights: [10, 20, 30]
# The remaining lists also carry three entries (presumably one per cascade stage -- confirm).
bg_thresh_hi: [0.5, 0.6, 0.7]
bg_thresh_lo: [0.0, 0.0, 0.0]
fg_fraction: 0.25
fg_thresh: [0.5, 0.6, 0.7]
MaskAssigner:
resolution: 28
# 'head' and 'nms' are injected into HTCBBoxHead (both are in its __inject__ list).
HTCBBoxHead:
head: CascadeTwoFCHead
nms: MultiClassSoftNMS
MultiClassSoftNMS:
score_threshold: 0.01
keep_top_k: 300
softnms_sigma: 0.5
CascadeTwoFCHead:
mlp_dim: 1024
# Learning-rate schedule: base rate plus two schedulers.
LearningRate:
base_lr: 0.01
schedulers:
# Presumably scales the rate by gamma at each milestone iteration -- confirm
# against the PiecewiseDecay implementation.
- !PiecewiseDecay
gamma: 0.1
milestones: [60000, 80000]
# Presumably ramps the rate from start_factor * base_lr over the first `steps`
# iterations -- confirm against the LinearWarmup implementation.
- !LinearWarmup
start_factor: 0.1
steps: 1000
# Momentum optimizer with L2 regularization.
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
# Training data pipeline.
TrainReader:
batch_size: 2
worker_num: 2
dataset:
!COCODataSet
dataset_dir: dataset/coco
anno_path: annotations/instances_train2017.json
image_dir: train2017
# Enables the COCODataSet 'load_semantic' option, which adds a 'semantic' entry
# (path to the stuff/thing map PNG) to every record.
load_semantic: True
inputs_def:
# Covers every field HybridTaskCascade.build requires in 'train' mode
# (gt_class, gt_bbox, gt_mask, is_crowd, im_info, semantic) plus image and im_id.
fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd', 'gt_mask', 'semantic']
# Per-sample transforms, listed in the order decode -> flip -> normalize -> resize -> permute.
sample_transforms:
- !DecodeImage
to_rgb: true
- !RandomFlipImage
prob: 0.5
is_mask_flip: true
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
target_size: 800
max_size: 1333
interp: 1
use_cv2: true
- !Permute
to_bgr: false
channel_first: true
# Per-batch transforms.
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: false
# Evaluation data pipeline.
EvalReader:
inputs_def:
# Includes the two fields HybridTaskCascade.build requires outside training
# (im_shape, im_info); no ground-truth or semantic fields are fed.
fields: ['image', 'im_info', 'im_id', 'im_shape']
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
# Same normalize/resize/permute settings as training, without the random flip.
sample_transforms:
- !DecodeImage
to_rgb: true
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
interp: 1
max_size: 1333
target_size: 800
use_cv2: true
- !Permute
channel_first: true
to_bgr: false
batch_transforms:
- !PadBatch
pad_to_stride: 32
use_padded_im_info: false
batch_size: 1
shuffle: false
drop_last: false
drop_empty: false
worker_num: 2
TestReader:
inputs_def:
fields: ['image', 'im_info', 'im_id', 'im_shape']
dataset:
!ImageFolder
anno_path: annotations/instances_val2017.json
sample_transforms:
- !DecodeImage
to_rgb: true
with_mixup: false
- !NormalizeImage
is_channel_first: false
is_scale: true
mean: [0.485,0.456,0.406]
std: [0.229, 0.224,0.225]
- !ResizeImage
interp: 1
max_size: 1333
......@@ -42,7 +42,8 @@ class COCODataSet(DataSet):
anno_path=None,
dataset_dir=None,
sample_num=-1,
with_background=True):
with_background=True,
load_semantic=False):
super(COCODataSet, self).__init__(
image_dir=image_dir,
anno_path=anno_path,
......@@ -68,6 +69,7 @@ class COCODataSet(DataSet):
# a dict used to map category name to class id
self.cname2cid = None
self.load_image_only = False
self.load_semantic = load_semantic
def load_roidb_and_cname2cid(self):
anno_path = os.path.join(self.dataset_dir, self.anno_path)
......@@ -104,11 +106,11 @@ class COCODataSet(DataSet):
im_w = float(img_anno['width'])
im_h = float(img_anno['height'])
im_fname = os.path.join(image_dir,
im_fname) if image_dir else im_fname
if not os.path.exists(im_fname):
im_path = os.path.join(image_dir,
im_fname) if image_dir else im_fname
if not os.path.exists(im_path):
logger.warn('Illegal image file: {}, and it will be '
'ignored'.format(im_fname))
'ignored'.format(im_path))
continue
if im_w < 0 or im_h < 0:
......@@ -118,7 +120,7 @@ class COCODataSet(DataSet):
continue
coco_rec = {
'im_file': im_fname,
'im_file': im_path,
'im_id': np.array([img_id]),
'h': im_h,
'w': im_w,
......@@ -168,8 +170,13 @@ class COCODataSet(DataSet):
'gt_poly': gt_poly,
})
if self.load_semantic:
seg_path = os.path.join(self.dataset_dir, 'stuffthingmaps',
'train2017', im_fname[:-3] + 'png')
coco_rec.update({'semantic': seg_path})
logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
im_fname, img_id, im_h, im_w))
im_path, img_id, im_h, im_w))
records.append(coco_rec)
ct += 1
if self.sample_num > 0 and ct >= self.sample_num:
......
......@@ -82,6 +82,13 @@ class PadBatch(BaseOperator):
data['image'] = padding_im
if self.use_padded_im_info:
data['im_info'][:2] = max_shape[1:3]
if 'semantic' in data.keys() and data['semantic'] is not None:
semantic = data['semantic']
padding_sem = np.zeros(
(1, max_shape[1], max_shape[2]), dtype=np.float32)
padding_sem[:, :im_h, :im_w] = semantic
data['semantic'] = padding_sem
return samples
......
......@@ -106,8 +106,6 @@ class DecodeImage(BaseOperator):
raise TypeError("{}: input type is invalid.".format(self))
if not isinstance(self.with_mixup, bool):
raise TypeError("{}: input type is invalid.".format(self))
if not isinstance(self.with_cutmix, bool):
raise TypeError("{}: input type is invalid.".format(self))
def __call__(self, sample, context=None):
""" load image if 'im_file' field is not empty but 'image' is"""
......@@ -143,13 +141,21 @@ class DecodeImage(BaseOperator):
# make default im_info with [h, w, 1]
sample['im_info'] = np.array(
[im.shape[0], im.shape[1], 1.], dtype=np.float32)
# decode mixup image
if self.with_mixup and 'mixup' in sample:
self.__call__(sample['mixup'], context)
# decode cutmix image
if self.with_cutmix and 'cutmix' in sample:
self.__call__(sample['cutmix'], context)
# decode semantic label
if 'semantic' in sample.keys() and sample['semantic'] is not None:
sem_file = sample['semantic']
sem = cv2.imread(sem_file, cv2.IMREAD_GRAYSCALE)
sample['semantic'] = sem.astype('int32')
return sample
......@@ -342,6 +348,18 @@ class ResizeImage(BaseOperator):
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
if 'semantic' in sample.keys() and sample['semantic'] is not None:
semantic = sample['semantic']
semantic = cv2.resize(
semantic.astype('float32'),
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
semantic = np.asarray(semantic).astype('int32')
semantic = np.expand_dims(semantic, 0)
sample['semantic'] = semantic
else:
if self.max_size != 0:
raise TypeError(
......@@ -455,9 +473,15 @@ class RandomFlipImage(BaseOperator):
if self.is_mask_flip and len(sample['gt_poly']) != 0:
sample['gt_poly'] = self.flip_segms(sample['gt_poly'],
height, width)
if 'gt_keypoint' in sample.keys():
sample['gt_keypoint'] = self.flip_keypoint(
sample['gt_keypoint'], width)
if 'semantic' in sample.keys() and sample[
'semantic'] is not None:
sample['semantic'] = sample['semantic'][:, ::-1]
sample['flipped'] = True
sample['image'] = im
sample = samples if batch_input else samples[0]
......
......@@ -28,6 +28,7 @@ from . import faceboxes
from . import fcos
from . import cornernet_squeeze
from . import ttfnet
from . import htc
from .faster_rcnn import *
from .mask_rcnn import *
......@@ -43,3 +44,4 @@ from .faceboxes import *
from .fcos import *
from .cornernet_squeeze import *
from .ttfnet import *
from .htc import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import copy
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import MSRA
from paddle.fluid.regularizer import L2Decay
from ppdet.experimental import mixed_precision_global_state
from ppdet.core.workspace import register
from .input_helper import multiscale_def
__all__ = ['HybridTaskCascade']
@register
class HybridTaskCascade(object):
"""
Hybrid Task Cascade Mask R-CNN architecture, see https://arxiv.org/abs/1901.07518
Args:
backbone (object): backbone instance
rpn_head (object): `RPNhead` instance
bbox_assigner (object): `BBoxAssigner` instance
roi_extractor (object): ROI extractor instance
bbox_head (object): `HTCBBoxHead` instance
mask_assigner (object): `MaskAssigner` instance
mask_head (object): `HTCMaskHead` instance
fpn (object): feature pyramid network instance
semantic_roi_extractor(object): ROI extractor instance
fused_semantic_head (object): `FusedSemanticHead` instance
"""
__category__ = 'architecture'
__inject__ = [
'backbone', 'rpn_head', 'bbox_assigner', 'roi_extractor', 'bbox_head',
'mask_assigner', 'mask_head', 'fpn', 'semantic_roi_extractor',
'fused_semantic_head'
]
def __init__(self,
backbone,
rpn_head,
roi_extractor='FPNRoIAlign',
semantic_roi_extractor='RoIAlign',
fused_semantic_head='FusedSemanticHead',
bbox_head='HTCBBoxHead',
bbox_assigner='CascadeBBoxAssigner',
mask_assigner='MaskAssigner',
mask_head='HTCMaskHead',
rpn_only=False,
fpn='FPN'):
super(HybridTaskCascade, self).__init__()
assert fpn is not None, "HTC requires FPN"
self.backbone = backbone
self.fpn = fpn
self.rpn_head = rpn_head
self.bbox_assigner = bbox_assigner
self.roi_extractor = roi_extractor
self.semantic_roi_extractor = semantic_roi_extractor
self.fused_semantic_head = fused_semantic_head
self.bbox_head = bbox_head
self.mask_assigner = mask_assigner
self.mask_head = mask_head
self.rpn_only = rpn_only
# Cascade local cfg
self.cls_agnostic_bbox_reg = 2
(brw0, brw1, brw2) = self.bbox_assigner.bbox_reg_weights
self.cascade_bbox_reg_weights = [
[1. / brw0, 1. / brw0, 2. / brw0, 2. / brw0],
[1. / brw1, 1. / brw1, 2. / brw1, 2. / brw1],
[1. / brw2, 1. / brw2, 2. / brw2, 2. / brw2]
]
self.cascade_rcnn_loss_weight = [1.0, 0.5, 0.25]
self.num_stage = 3
self.with_mask = True
self.interleaved = True
self.mask_info_flow = True
self.with_semantic = True
self.use_bias_scalar = True
def build(self, feed_vars, mode='train'):
# Assemble the HTC graph for one mode.
# Returns the loss dict (including the summed 'loss' entry) when mode == 'train';
# {'proposal': rois} when not training and self.rpn_only is set;
# otherwise {'bbox': ..., 'mask': ...}.
# Training needs the ground-truth and 'semantic' fields; any other mode only
# needs 'im_shape' and 'im_info'.
if mode == 'train':
required_fields = [
'gt_class', 'gt_bbox', 'gt_mask', 'is_crowd', 'im_info',
'semantic'
]
else:
required_fields = ['im_shape', 'im_info']
self._input_check(required_fields, feed_vars)
im = feed_vars['image']
if mode == 'train':
gt_bbox = feed_vars['gt_bbox']
is_crowd = feed_vars['is_crowd']
im_info = feed_vars['im_info']
# backbone
body_feats = self.backbone(im)
loss = {}
# FPN
# spatial_scale is only assigned here; __init__ asserts that fpn is not None.
if self.fpn is not None:
body_feats, spatial_scale = self.fpn.get_output(body_feats)
# Semantic branch: produces the fused semantic feature and, in training, a
# segmentation loss that is scaled by the hard-coded factor 0.2.
if self.with_semantic:
# TODO: use cfg
semantic_feat, seg_pred = self.fused_semantic_head.get_out(
body_feats)
if mode == 'train':
s_label = feed_vars['semantic']
semantic_loss = self.fused_semantic_head.get_loss(seg_pred,
s_label) * 0.2
loss.update({"semantic_loss": semantic_loss})
else:
semantic_feat = None
# rpn proposals
rpn_rois = self.rpn_head.get_proposals(body_feats, im_info, mode=mode)
if mode == 'train':
rpn_loss = self.rpn_head.get_loss(im_info, gt_bbox, is_crowd)
loss.update(rpn_loss)
else:
# rpn_only: divide the proposals by the third column of im_info and return early.
if self.rpn_only:
im_scale = fluid.layers.slice(
im_info, [1], starts=[2], ends=[3])
im_scale = fluid.layers.sequence_expand(im_scale, rpn_rois)
rois = rpn_rois / im_scale
return {'proposal': rois}
# Per-stage accumulators consumed by the loss / prediction code after the loop.
proposal_list = []
roi_feat_list = []
rcnn_pred_list = []
rcnn_target_list = []
mask_logits_list = []
mask_target_list = []
proposals = None
bbox_pred = None
outs = None
# The first stage starts from the RPN proposals.
refined_bbox = rpn_rois
for i in range(self.num_stage):
# BBox Branch
# In training the assigner output is kept whole as this stage's target and its
# first element is used as the stage's proposals; otherwise the current boxes
# are used directly.
if mode == 'train':
outs = self.bbox_assigner(
input_rois=refined_bbox, feed_vars=feed_vars, curr_stage=i)
proposals = outs[0]
rcnn_target_list.append(outs)
else:
proposals = refined_bbox
proposal_list.append(proposals)
# extract roi features
roi_feat = self.roi_extractor(body_feats, proposals, spatial_scale)
# Semantic RoI features go through a 2x2, stride-2 pool before being summed
# with the box RoI features (presumably to match the box feature resolution).
if self.with_semantic:
semantic_roi_feat = self.semantic_roi_extractor(semantic_feat,
proposals)
if semantic_roi_feat is not None:
semantic_roi_feat = fluid.layers.pool2d(
semantic_roi_feat,
pool_size=2,
pool_stride=2,
pool_padding='SAME')
roi_feat = fluid.layers.sum([roi_feat, semantic_roi_feat])
roi_feat_list.append(roi_feat)
# bbox head
# wb_scalar is the reciprocal of this stage's loss weight; layer names get the
# stage index as suffix.
cls_score, bbox_pred = self.bbox_head.get_output(
roi_feat,
wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i],
name='_' + str(i))
rcnn_pred_list.append((cls_score, bbox_pred))
# Mask Branch
if self.with_mask:
if mode == 'train':
labels_int32 = outs[1]
# Interleaved execution: boxes decoded from this stage's bbox_pred replace
# the proposals before mask targets are assigned.
if self.interleaved:
refined_bbox = self._decode_box(
proposals, bbox_pred, curr_stage=i)
proposals = refined_bbox
mask_rois, roi_has_mask_int32, mask_int32 = self.mask_assigner(
rois=proposals,
gt_classes=feed_vars['gt_class'],
is_crowd=feed_vars['is_crowd'],
gt_segms=feed_vars['gt_mask'],
im_info=feed_vars['im_info'],
labels_int32=labels_int32)
mask_target_list.append(mask_int32)
mask_feat = self.roi_extractor(
body_feats, mask_rois, spatial_scale, is_mask=True)
# Semantic RoI features are summed with the mask features without pooling.
if self.with_semantic:
semantic_roi_feat = self.semantic_roi_extractor(
semantic_feat, mask_rois)
if semantic_roi_feat is not None:
mask_feat = fluid.layers.sum(
[mask_feat, semantic_roi_feat])
# Mask information flow: the mask head is first run i times in feature-only
# mode, chaining last_feat, and then once more to produce this stage's logits.
if self.mask_info_flow:
last_feat = None
for j in range(i):
last_feat = self.mask_head.get_output(
mask_feat,
last_feat,
return_logits=False,
return_feat=True,
wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i]
if self.use_bias_scalar else 1.0,
name='_' + str(i) + '_' + str(j))
mask_logits = self.mask_head.get_output(
mask_feat,
last_feat,
return_logits=True,
return_feat=False,
wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i]
if self.use_bias_scalar else 1.0,
name='_' + str(i))
else:
mask_logits = self.mask_head.get_output(
mask_feat,
return_logits=True,
wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i]
if self.use_bias_scalar else 1.0,
name='_' + str(i))
mask_logits_list.append(mask_logits)
# Decode boxes for the next stage when that has not already happened in the
# interleaved training path above: always when not interleaved, and outside
# training when interleaved. Nothing is decoded after the last stage here.
if i < self.num_stage - 1 and not self.interleaved:
refined_bbox = self._decode_box(
proposals, bbox_pred, curr_stage=i)
elif i < self.num_stage - 1 and mode != 'train':
refined_bbox = self._decode_box(
proposals, bbox_pred, curr_stage=i)
# Training: add the box and mask losses, then sum every entry into 'loss'.
if mode == 'train':
bbox_loss = self.bbox_head.get_loss(
rcnn_pred_list, rcnn_target_list, self.cascade_rcnn_loss_weight)
loss.update(bbox_loss)
mask_loss = self.mask_head.get_loss(mask_logits_list,
mask_target_list,
self.cascade_rcnn_loss_weight)
loss.update(mask_loss)
total_loss = fluid.layers.sum(list(loss.values()))
loss.update({'loss': total_loss})
return loss
else:
# Inference: box and mask predictions are produced by single_scale_eval.
mask_name = 'mask_pred'
mask_pred, bbox_pred = self.single_scale_eval(
body_feats,
spatial_scale,
im_info,
mask_name,
bbox_pred,
roi_feat_list,
rcnn_pred_list,
proposal_list,
feed_vars['im_shape'],
semantic_feat=semantic_feat if self.with_semantic else None)
return {'bbox': bbox_pred, 'mask': mask_pred}
def single_scale_eval(self,
body_feats,
spatial_scale,
im_info,
mask_name,
bbox_pred,
roi_feat_list=None,
rcnn_pred_list=None,
proposal_list=None,
im_shape=None,
use_multi_test=False,
semantic_feat=None):
# Produce the final box and mask predictions for inference and return them as
# (mask_pred, bbox_pred).
# Unless use_multi_test is set, the incoming bbox_pred is replaced by the
# 'bbox' entry of bbox_head.get_prediction(...).
if not use_multi_test:
bbox_pred = self.bbox_head.get_prediction(
im_info, im_shape, roi_feat_list, rcnn_pred_list, proposal_list,
self.cascade_bbox_reg_weights)
bbox_pred = bbox_pred['bbox']
# share weight
# cond is true when bbox_pred holds fewer than 6 values in total, i.e. less
# than one full row (get_prediction documents rows of 6 values).
bbox_shape = fluid.layers.shape(bbox_pred)
bbox_size = fluid.layers.reduce_prod(bbox_shape)
bbox_size = fluid.layers.reshape(bbox_size, [1, 1])
size = fluid.layers.fill_constant([1, 1], value=6, dtype='int32')
cond = fluid.layers.less_than(x=bbox_size, y=size)
# Global variable that receives the mask output from whichever branch runs.
mask_pred = fluid.layers.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=False,
name=mask_name)
# Branch taken when cond is true: copy bbox_pred into mask_pred unchanged.
def noop():
fluid.layers.assign(input=bbox_pred, output=mask_pred)
# Branch taken otherwise: run the mask head on the predicted boxes.
def process_boxes():
# Columns 2..5 of bbox_pred are taken as the box and multiplied by the third
# column of im_info before RoI extraction.
bbox = fluid.layers.slice(bbox_pred, [1], starts=[2], ends=[6])
im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
im_scale = fluid.layers.sequence_expand(im_scale, bbox)
bbox = fluid.layers.cast(bbox, dtype='float32')
im_scale = fluid.layers.cast(im_scale, dtype='float32')
mask_rois = bbox * im_scale
mask_feat = self.roi_extractor(
body_feats, mask_rois, spatial_scale, is_mask=True)
if self.with_semantic:
semantic_roi_feat = self.semantic_roi_extractor(semantic_feat,
mask_rois)
if semantic_roi_feat is not None:
mask_feat = fluid.layers.sum([mask_feat, semantic_roi_feat])
# NOTE(review): mask_logits_list is appended to below but never read in this method.
mask_logits_list = []
mask_pred_list = []
# Each stage yields one mask prediction; layer names use the same '_<i>' and
# '_<i>_<j>' suffixes as the training graph in build().
for i in range(self.num_stage):
if self.mask_info_flow:
last_feat = None
for j in range(i):
last_feat = self.mask_head.get_output(
mask_feat,
last_feat,
return_logits=False,
return_feat=True,
wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i]
if self.use_bias_scalar else 1.0,
name='_' + str(i) + '_' + str(j))
mask_logits = self.mask_head.get_output(
mask_feat,
last_feat,
return_logits=True,
return_feat=False,
wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i]
if self.use_bias_scalar else 1.0,
name='_' + str(i))
mask_logits_list.append(mask_logits)
else:
# NOTE(review): unlike the branch above, wb_scalar is not passed here, so the
# mask head's default (1.0) is used.
mask_logits = self.mask_head.get_output(
mask_feat,
return_logits=True,
return_feat=False,
name='_' + str(i))
mask_pred_out = self.mask_head.get_prediction(mask_logits, bbox)
mask_pred_list.append(mask_pred_out)
# The final mask is the mean of the per-stage predictions.
mask_pred_out = fluid.layers.sum(mask_pred_list) / float(
len(mask_pred_list))
fluid.layers.assign(input=mask_pred_out, output=mask_pred)
fluid.layers.cond(cond, noop, process_boxes)
return mask_pred, bbox_pred
def _input_check(self, require_fields, feed_vars):
for var in require_fields:
assert var in feed_vars, \
"{} has no {} field".format(feed_vars, var)
def _decode_box(self, proposals, bbox_pred, curr_stage):
    """Decode one stage's regression output against its proposals.

    ``bbox_pred`` is reshaped to ``(-1, self.cls_agnostic_bbox_reg, 4)``,
    the slice at index 1 of the second axis is decoded with the stage's
    entry of ``self.cascade_bbox_reg_weights``, and the result is returned
    reshaped to ``[-1, 4]``.
    """
    deltas_by_slot = fluid.layers.reshape(
        bbox_pred, (-1, self.cls_agnostic_bbox_reg, 4))
    # only use fg box delta to decode box
    fg_deltas = fluid.layers.slice(
        deltas_by_slot, axes=[1], starts=[1], ends=[2])
    decoded = fluid.layers.box_coder(
        prior_box=proposals,
        prior_box_var=self.cascade_bbox_reg_weights[curr_stage],
        target_box=fg_deltas,
        code_type='decode_center_size',
        box_normalized=False,
        axis=1)
    return fluid.layers.reshape(decoded, shape=[-1, 4])
def _inputs_def(self, image_shape):
im_shape = [None] + image_shape
# yapf: disable
inputs_def = {
'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0},
'im_info': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0},
'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1},
'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
'is_crowd': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1},
'gt_mask': {'shape': [None, 2], 'dtype': 'float32', 'lod_level': 3}, # polygon coordinates
'semantic': {'shape': [None, 1, None, None], 'dtype': 'int32', 'lod_level': 0},
}
# yapf: enable
return inputs_def
def build_inputs(self,
image_shape=[3, None, None],
fields=[
'image', 'im_info', 'im_id', 'gt_bbox', 'gt_class',
'is_crowd', 'gt_mask', 'semantic'
],
multi_scale=False,
num_scales=-1,
use_flip=None,
use_dataloader=True,
iterable=False,
mask_branch=False):
# Create one fluid.data variable per requested field and, optionally, a
# DataLoader feeding them. Returns (feed_vars, loader); loader is None when
# use_dataloader is false or mask_branch is true.
# NOTE(review): image_shape and fields use mutable list defaults. Within this
# method fields is deep-copied before being extended and image_shape is only
# concatenated; whether multiscale_def mutates image_shape is not visible here.
inputs_def = self._inputs_def(image_shape)
# Copy so that the += below never touches the caller's (or the default) list.
fields = copy.deepcopy(fields)
# Multi-scale: add the extra definitions/fields returned by multiscale_def and
# remember the image/im_info field names on the instance.
if multi_scale:
ms_def, ms_fields = multiscale_def(image_shape, num_scales,
use_flip)
inputs_def.update(ms_def)
fields += ms_fields
self.im_info_names = ['image', 'im_info'] + ms_fields
# mask_branch: also feed boxes ('bbox', plus 'bbox_flip' when use_flip is truthy),
# each declared with shape [6] and lod_level 1.
if mask_branch:
box_fields = ['bbox', 'bbox_flip'] if use_flip else ['bbox']
for key in box_fields:
inputs_def[key] = {
'shape': [6],
'dtype': 'float32',
'lod_level': 1
}
fields += box_fields
# Variables are created in the order of `fields`; a field missing from
# inputs_def raises KeyError here.
feed_vars = OrderedDict([(key, fluid.data(
name=key,
shape=inputs_def[key]['shape'],
dtype=inputs_def[key]['dtype'],
lod_level=inputs_def[key]['lod_level'])) for key in fields])
# No DataLoader is created when mask_branch is set.
use_dataloader = use_dataloader and not mask_branch
loader = fluid.io.DataLoader.from_generator(
feed_list=list(feed_vars.values()),
capacity=64,
use_double_buffer=True,
iterable=iterable) if use_dataloader else None
return feed_vars, loader
def train(self, feed_vars):
return self.build(feed_vars, 'train')
def eval(self, feed_vars, multi_scale=None, mask_branch=False):
if multi_scale:
return self.build_multi_scale(feed_vars, mask_branch)
return self.build(feed_vars, 'test')
def test(self, feed_vars):
return self.build(feed_vars, 'test')
......@@ -17,7 +17,13 @@ from __future__ import absolute_import
from . import bbox_head
from . import mask_head
from . import cascade_head
from . import htc_bbox_head
from . import htc_mask_head
from . import htc_semantic_head
from .bbox_head import *
from .mask_head import *
from .cascade_head import *
from .htc_bbox_head import *
from .htc_mask_head import *
from .htc_semantic_head import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal, Xavier
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.initializer import MSRA
from ppdet.modeling.ops import MultiClassNMS
from ppdet.modeling.ops import ConvNorm
from ppdet.modeling.losses import SmoothL1Loss
from ppdet.core.workspace import register
__all__ = ['HTCBBoxHead']
@register
class HTCBBoxHead(object):
"""
HTC bbox head
Args:
head (object): the head module instance
nms (object): `MultiClassNMS` instance
num_classes: number of output classes
"""
__inject__ = ['head', 'nms', 'bbox_loss']
__shared__ = ['num_classes']
def __init__(self,
             head,
             nms=MultiClassNMS().__dict__,
             bbox_loss=SmoothL1Loss().__dict__,
             num_classes=81,
             lr_ratio=2.0):
    """Store the head module and set up the NMS and box-loss ops.

    ``nms`` and ``bbox_loss`` may each be given either as a ready object
    or as a dict of keyword arguments, in which case a ``MultiClassNMS`` /
    ``SmoothL1Loss`` instance is built from it.
    """
    super(HTCBBoxHead, self).__init__()
    self.head = head
    self.num_classes = num_classes
    self.lr_ratio = lr_ratio
    # dict configs are expanded into op instances; anything else is kept as given
    self.nms = MultiClassNMS(**nms) if isinstance(nms, dict) else nms
    self.bbox_loss = (SmoothL1Loss(**bbox_loss)
                      if isinstance(bbox_loss, dict) else bbox_loss)
def get_output(self,
               roi_feat,
               cls_agnostic_bbox_reg=2,
               wb_scalar=1.0,
               name=''):
    """
    Run the head module and the two fully connected output layers.

    Args:
        roi_feat (Variable): RoI feature from RoIExtractor.
        cls_agnostic_bbox_reg (int): multiplier for the box output; the
            regression layer has ``4 * cls_agnostic_bbox_reg`` outputs.
        wb_scalar (float): learning-rate factor for the weights; biases
            use ``wb_scalar * self.lr_ratio``.
        name (str): suffix appended to the layer and parameter names.

    Returns:
        cls_score (Variable): classification output, ``num_classes`` wide.
        bbox_pred (Variable): box regression output.
    """
    head_feat = self.head(roi_feat, wb_scalar, name)

    def _output_fc(prefix, out_dim, init_std):
        # both output layers differ only in name prefix, width and init scale
        return fluid.layers.fc(
            input=head_feat,
            size=out_dim,
            act=None,
            name=prefix + name,
            param_attr=ParamAttr(
                name='%s%s_w' % (prefix, name),
                initializer=Normal(
                    loc=0.0, scale=init_std),
                learning_rate=wb_scalar),
            bias_attr=ParamAttr(
                name='%s%s_b' % (prefix, name),
                learning_rate=wb_scalar * self.lr_ratio,
                regularizer=L2Decay(0.)))

    cls_score = _output_fc('cls_score', self.num_classes, 0.01)
    bbox_pred = _output_fc('bbox_pred', 4 * cls_agnostic_bbox_reg, 0.001)
    return cls_score, bbox_pred
def get_loss(self, rcnn_pred_list, rcnn_target_list, rcnn_loss_weight_list):
    """
    Compute the weighted classification and box losses of every stage.

    Args:
        rcnn_pred_list (list): per-stage head outputs; element 0 is used
            as the class logits and element 1 as the box regression.
        rcnn_target_list (list): per-stage targets; element 1 is used as
            the labels, elements 2, 3 and 4 as the box target and its
            inside / outside weights.
        rcnn_loss_weight_list (list): per-stage multiplier applied to both
            losses of that stage.

    Returns:
        dict: ``loss_cls_<i>`` and ``loss_loc_<i>`` for every stage ``i``.
    """
    stage_losses = {}
    paired = zip(rcnn_pred_list, rcnn_target_list)
    for stage, (pred, target) in enumerate(paired):
        # labels are cast to int64 and excluded from back-propagation
        labels = fluid.layers.cast(x=target[1], dtype='int64')
        labels.stop_gradient = True
        cls_ce = fluid.layers.softmax_with_cross_entropy(
            logits=pred[0],
            label=labels,
            numeric_stable_mode=True)
        cls_term = fluid.layers.reduce_mean(
            cls_ce,
            name='loss_cls_' + str(stage)) * rcnn_loss_weight_list[stage]

        reg_raw = self.bbox_loss(
            x=pred[1],
            y=target[2],
            inside_weight=target[3],
            outside_weight=target[4])
        reg_term = fluid.layers.reduce_mean(
            reg_raw,
            name='loss_bbox_' + str(stage)) * rcnn_loss_weight_list[stage]

        stage_losses['loss_cls_%d' % stage] = cls_term
        stage_losses['loss_loc_%d' % stage] = reg_term
    return stage_losses
def get_prediction(self,
im_info,
im_shape,
roi_feat_list,
rcnn_pred_list,
proposal_list,
cascade_bbox_reg_weights,
cls_agnostic_bbox_reg=2,
return_box_score=False):
# NOTE(review): in the docstring below, the lone ':' line is stray and the
# 'rois_feat_list' entry refers to the 'roi_feat_list' parameter, which is not
# used anywhere in this method.
"""
Get prediction bounding box in test stage.
:
Args:
im_info (Variable): A 2-D LoDTensor with shape [B, 3]. B is the
number of input images, each element consists
of im_height, im_width, im_scale.
im_shape (Variable): Actual shape of original image with shape
[B, 3]. B is the number of images, each element consists of
original_height, original_width, 1
rois_feat_list (List): RoI feature from RoIExtractor.
rcnn_pred_list (Variable): Cascade rcnn's head's output
including bbox_pred and cls_score
proposal_list (List): RPN proposal boxes.
cascade_bbox_reg_weights (List): BBox decode var.
cls_agnostic_bbox_reg(Int): BBox regressor are class agnostic
Returns:
pred_result(Variable): Prediction result with shape [N, 6]. Each
row has 6 values: [label, confidence, xmin, ymin, xmax, ymax].
N is the total number of prediction.
"""
# Number of cascade stages is hard-coded; the lists passed in must have at
# least this many entries.
repeat_num = 3
# cls score
# Class probabilities: softmax of every stage's score, averaged over the stages.
boxes_cls_prob_l = []
for i in range(repeat_num):
cls_score = rcnn_pred_list[i][0]
cls_prob = fluid.layers.softmax(cls_score, use_cudnn=False)
boxes_cls_prob_l.append(cls_prob)
boxes_cls_prob_mean = fluid.layers.sum(boxes_cls_prob_l) / float(
len(boxes_cls_prob_l))
# bbox pred
im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
bbox_pred_l = []
# Indices 0 and 1 are skipped, so only the stage at index 2 contributes its
# regression weights, proposals (divided by im_scale) and box deltas.
for i in range(repeat_num):
if i < 2:
continue
bbox_reg_w = cascade_bbox_reg_weights[i]
proposals_boxes = proposal_list[i]
im_scale_lod = fluid.layers.sequence_expand(im_scale,
proposals_boxes)
proposals_boxes = proposals_boxes / im_scale_lod
bbox_pred = rcnn_pred_list[i][1]
bbox_pred_new = fluid.layers.reshape(bbox_pred,
(-1, cls_agnostic_bbox_reg, 4))
bbox_pred_l.append(bbox_pred_new)
bbox_pred_new = bbox_pred_l[-1]
if cls_agnostic_bbox_reg == 2:
# only use fg box delta to decode box
bbox_pred_new = fluid.layers.slice(
bbox_pred_new, axes=[1], starts=[1], ends=[2])
# The deltas are repeated num_classes times along axis 1 before decoding.
bbox_pred_new = fluid.layers.expand(bbox_pred_new,
[1, self.num_classes, 1])
decoded_box = fluid.layers.box_coder(
prior_box=proposals_boxes,
prior_box_var=bbox_reg_w,
target_box=bbox_pred_new,
code_type='decode_center_size',
box_normalized=False,
axis=1)
box_out = fluid.layers.box_clip(input=decoded_box, im_info=im_shape)
# With return_box_score set, the clipped boxes and averaged scores are returned
# as-is and NMS is not applied.
if return_box_score:
return {'bbox': box_out, 'score': boxes_cls_prob_mean}
pred_result = self.nms(bboxes=box_out, scores=boxes_cls_prob_mean)
return {"bbox": pred_result}
def get_prediction_cls_aware(self,
                             im_info,
                             im_shape,
                             cascade_cls_prob,
                             cascade_decoded_box,
                             cascade_bbox_reg_weights,
                             return_box_score=False):
    """
    Predict boxes per class from already decoded per-stage results.

    The per-stage class probabilities and decoded boxes are each merged
    with the fixed weights [0.2, 0.3, 0.5]. The merged boxes are divided
    by the third column of ``im_info`` (also stored on ``self.im_scale``),
    reshaped to ``(-1, num_classes, 4)`` and clipped against ``im_shape``.
    ``cascade_bbox_reg_weights`` is accepted but not used.

    Returns:
        dict: ``{'bbox': boxes, 'score': probs}`` when ``return_box_score``
        is set, otherwise ``{'bbox': nms_result}``.
    """
    cascade_eval_weight = [0.2, 0.3, 0.5]

    def _merge_stages(per_stage):
        # weighted sum over the stages, weight picked by stage index
        return sum([
            item * cascade_eval_weight[idx]
            for idx, item in enumerate(per_stage)
        ])

    # merge 3 stages results
    merged_prob = _merge_stages(cascade_cls_prob)
    merged_box = _merge_stages(cascade_decoded_box)

    self.im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
    scale_per_box = fluid.layers.sequence_expand(self.im_scale, merged_box)
    merged_box = merged_box / scale_per_box

    per_class_box = fluid.layers.reshape(
        merged_box, shape=(-1, self.num_classes, 4))
    clipped_box = fluid.layers.box_clip(input=per_class_box, im_info=im_shape)
    if return_box_score:
        return {'bbox': clipped_box, 'score': merged_prob}
    return {"bbox": self.nms(bboxes=clipped_box, scores=merged_prob)}
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import MSRA
from paddle.fluid.regularizer import L2Decay
from ppdet.core.workspace import register
from ppdet.modeling.ops import ConvNorm
__all__ = ['HTCMaskHead']
@register
class HTCMaskHead(object):
    """
    Mask head of the Hybrid Task Cascade (HTC) model.

    Args:
        num_convs (int): num of 3x3 convolutions applied to the RoI feature
            before the upsampling step
        conv_dim (int): num of output channels of every intermediate
            convolution and of the transposed convolution
        resolution (int): size of the output mask
        dilation (int): dilation rate of the intermediate convolutions
        num_classes (int): number of output classes
        norm_type (str|None): when equal to 'gn' the intermediate
            convolutions are built with `ConvNorm`; any other value builds
            plain `conv2d` layers
        lr_ratio (float): multiplier applied to the learning rate of the
            bias parameters created by this head
        share_mask_conv (bool): when True the intermediate convolutions use
            the same parameter names regardless of the `name` suffix
    """

    __shared__ = ['num_classes']

    def __init__(self,
                 num_convs=0,
                 conv_dim=256,
                 resolution=14,
                 dilation=1,
                 num_classes=81,
                 norm_type=None,
                 lr_ratio=2.0,
                 share_mask_conv=False):
        super(HTCMaskHead, self).__init__()
        self.num_convs = num_convs
        self.conv_dim = conv_dim
        self.resolution = resolution
        self.dilation = dilation
        self.num_classes = num_classes
        self.norm_type = norm_type
        self.lr_ratio = lr_ratio
        self.share_mask_conv = share_mask_conv

    def _mask_conv_head(self,
                        roi_feat,
                        num_convs,
                        norm_type,
                        wb_scalar=1.0,
                        name=''):
        """
        Stack `num_convs` 3x3 convolutions (ReLU activated) on `roi_feat`.

        Args:
            roi_feat (Variable): input RoI feature.
            num_convs (int): number of convolutions to stack.
            norm_type (str|None): 'gn' selects `ConvNorm`, anything else
                selects a plain `conv2d`.
            wb_scalar (float): extra factor on the bias learning rate
                (only used by the plain `conv2d` branch).
            name (str): suffix appended to the layer names unless
                `share_mask_conv` is set.

        Returns:
            Variable: the feature after the last convolution, or `roi_feat`
            unchanged when `num_convs` is 0.
        """
        use_gn = norm_type == 'gn'
        # fan-in passed to the MSRA initializer of every stacked convolution
        fan = self.conv_dim * 3 * 3
        for i in range(num_convs):
            layer_name = "mask_inter_feat_" + str(i + 1)
            if not self.share_mask_conv:
                layer_name += name
            initializer = MSRA(uniform=False, fan_in=fan)
            if use_gn:
                # Forward the norm type that selected this branch rather
                # than re-reading it from the instance.
                roi_feat = ConvNorm(
                    roi_feat,
                    self.conv_dim,
                    3,
                    act='relu',
                    dilation=self.dilation,
                    initializer=initializer,
                    norm_type=norm_type,
                    name=layer_name,
                    norm_name=layer_name)
            else:
                roi_feat = fluid.layers.conv2d(
                    input=roi_feat,
                    num_filters=self.conv_dim,
                    filter_size=3,
                    padding=1 * self.dilation,
                    act='relu',
                    stride=1,
                    dilation=self.dilation,
                    name=layer_name,
                    param_attr=ParamAttr(
                        name=layer_name + '_w', initializer=initializer),
                    bias_attr=ParamAttr(
                        name=layer_name + '_b',
                        learning_rate=wb_scalar * self.lr_ratio,
                        regularizer=L2Decay(0.)))
        return roi_feat

    def get_output(self,
                   roi_feat,
                   res_feat=None,
                   return_logits=True,
                   return_feat=False,
                   wb_scalar=1.0,
                   name=''):
        """
        Build the mask branch for one stage.

        Args:
            roi_feat (Variable): RoI feature fed to the mask branch.
            res_feat (Variable|None): optional feature that is projected by
                a 1x1 convolution to the channel count of `roi_feat` and
                added to it.
            return_logits (bool): whether to build and return mask logits.
            return_feat (bool): whether to return the intermediate feature.
            wb_scalar (float): extra factor on the bias learning rates.
            name (str): suffix appended to the parameter names.

        Returns:
            `mask_logits` when only `return_logits` is set,
            `(mask_logits, head_feat)` when both flags are set,
            `head_feat` when only `return_feat` is set,
            and None when neither flag is set.
        """
        class_num = self.num_classes
        if res_feat is not None:
            res_feat = fluid.layers.conv2d(
                res_feat, roi_feat.shape[1], 1, name='res_net' + name)
            roi_feat = fluid.layers.sum([roi_feat, res_feat])
        # configure the conv number for FPN if necessary
        head_feat = self._mask_conv_head(roi_feat, self.num_convs,
                                         self.norm_type, wb_scalar, name)

        if return_logits:
            fan0 = roi_feat.shape[1] * 2 * 2
            # 2x upsampling of the head feature
            up_head_feat = fluid.layers.conv2d_transpose(
                input=head_feat,
                num_filters=self.conv_dim,
                filter_size=2,
                stride=2,
                act='relu',
                param_attr=ParamAttr(
                    name='conv5_mask_w' + name,
                    initializer=MSRA(
                        uniform=False, fan_in=fan0)),
                bias_attr=ParamAttr(
                    name='conv5_mask_b' + name,
                    learning_rate=wb_scalar * self.lr_ratio,
                    regularizer=L2Decay(0.)))

            fan = class_num
            # 1x1 convolution producing one mask channel per class
            mask_logits = fluid.layers.conv2d(
                input=up_head_feat,
                num_filters=class_num,
                filter_size=1,
                act=None,
                param_attr=ParamAttr(
                    name='mask_fcn_logits_w' + name,
                    initializer=MSRA(
                        uniform=False, fan_in=fan)),
                bias_attr=ParamAttr(
                    name="mask_fcn_logits_b" + name,
                    learning_rate=wb_scalar * self.lr_ratio,
                    regularizer=L2Decay(0.)))
            if return_feat:
                return mask_logits, head_feat
            else:
                return mask_logits

        if return_feat:
            return head_feat

    def get_loss(self,
                 mask_logits_list,
                 mask_int32_list,
                 cascade_loss_weights=(1.0, 0.5, 0.25)):
        """
        Compute the weighted mask loss of every cascade stage.

        Args:
            mask_logits_list (list): mask logits, one entry per stage.
            mask_int32_list (list): mask targets, one entry per stage;
                entries equal to -1 are ignored by the loss.
            cascade_loss_weights (sequence of float): per-stage loss
                weights, indexed by stage; it needs at least as many
                entries as there are stages.

        Returns:
            dict: maps 'loss_mask_<stage index>' to the summed, weighted
            sigmoid cross-entropy loss of that stage.
        """
        num_classes = self.num_classes
        resolution = self.resolution
        dim = num_classes * resolution * resolution
        loss_mask_dict = {}
        for i, (mask_logits, mask_int32
                ) in enumerate(zip(mask_logits_list, mask_int32_list)):
            # flatten the logits to one row of `dim` values per RoI
            mask_logits = fluid.layers.reshape(mask_logits, (-1, dim))
            mask_label = fluid.layers.cast(x=mask_int32, dtype='float32')
            mask_label.stop_gradient = True
            loss_name = 'loss_mask_' + str(i)
            loss_mask = fluid.layers.sigmoid_cross_entropy_with_logits(
                x=mask_logits,
                label=mask_label,
                ignore_index=-1,
                normalize=True,
                name=loss_name)
            loss_mask = fluid.layers.reduce_sum(
                loss_mask) * cascade_loss_weights[i]
            loss_mask_dict[loss_name] = loss_mask
        return loss_mask_dict

    def get_prediction(self, mask_logits, bbox_pred):
        """
        Get prediction mask in test stage.

        Args:
            mask_logits (Variable): mask head output features.
            bbox_pred (Variable): predicted bbox.

        Returns:
            mask_pred (Variable): Prediction mask with shape
                [N, num_classes, resolution, resolution].
        """
        mask_prob = fluid.layers.sigmoid(mask_logits)
        # give the mask probabilities the same LoD as the predicted boxes
        mask_prob = fluid.layers.lod_reset(mask_prob, bbox_pred)
        return mask_prob
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import MSRA
from paddle.fluid.regularizer import L2Decay
from ppdet.core.workspace import register
from ppdet.modeling.ops import ConvNorm
__all__ = ['FusedSemanticHead']
@register
class FusedSemanticHead(object):
    """
    Semantic segmentation branch that fuses multi-level FPN features.

    Args:
        semantic_num_class (int): number of channels of the semantic
            segmentation logits
    """

    def __init__(self, semantic_num_class=183):
        super(FusedSemanticHead, self).__init__()
        self.semantic_num_class = semantic_num_class

    def get_out(self,
                fpn_feats,
                out_c=256,
                num_convs=4,
                fusion_level='fpn_res3_sum'):
        """
        Fuse the FPN features and build the semantic outputs.

        Args:
            fpn_feats (dict): maps level names to FPN feature maps; it must
                contain the key `fusion_level`.
            out_c (int): number of channels of the 1x1 projections, of the
                stacked 3x3 convolutions and of the embedding output.
            num_convs (int): number of 3x3 convolutions applied to the
                fused feature.
            fusion_level (str): key of the level whose spatial size all
                other levels are resized to.

        Returns:
            tuple: `(semantic_feat, seg_pred)`, the 1x1-convolved embedding
            and the 1x1-convolved logits with `semantic_num_class` channels.
        """
        # The fusion-level feature is summed as is (no projection);
        # presumably it already has `out_c` channels -- TODO confirm.
        new_feat = fpn_feats[fusion_level]
        new_feat_list = [new_feat, ]
        target_shape = fluid.layers.shape(new_feat)[2:]
        for k, v in fpn_feats.items():
            if k != fusion_level:
                v = fluid.layers.resize_bilinear(
                    v, target_shape, align_corners=True)
                v = fluid.layers.conv2d(v, out_c, 1)
                new_feat_list.append(v)
        new_feat = fluid.layers.sum(new_feat_list)

        for _ in range(num_convs):
            new_feat = fluid.layers.conv2d(new_feat, out_c, 3, padding=1)

        # conv embedding
        semantic_feat = fluid.layers.conv2d(new_feat, out_c, 1)
        # conv logits
        seg_pred = fluid.layers.conv2d(new_feat, self.semantic_num_class, 1)
        return semantic_feat, seg_pred

    def get_loss(self, logit, label, ignore_index=255):
        """
        Softmax cross-entropy loss averaged over the non-ignored pixels.

        Args:
            logit (Variable): semantic logits; axis 1 is moved last and
                flattened to `semantic_num_class` columns.
            label (Variable): semantic label map; it is resized with
                nearest-neighbour interpolation to the spatial size of
                `logit`.
            ignore_index (int): label value excluded from the loss and from
                the normalisation.

        Returns:
            Variable: sum of the per-pixel losses divided by the number of
            non-ignored pixels.
        """
        label = fluid.layers.resize_nearest(label,
                                            fluid.layers.shape(logit)[2:])
        label = fluid.layers.reshape(label, [-1, 1])
        label = fluid.layers.cast(label, 'int64')

        logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
        logit = fluid.layers.reshape(logit, [-1, self.semantic_num_class])

        loss, _probs = fluid.layers.softmax_with_cross_entropy(
            logit,
            label,
            soft_label=False,
            ignore_index=ignore_index,
            return_softmax=True)

        # Build the valid-pixel mask from the same `ignore_index` that was
        # handed to the loss op, so both agree on which pixels are skipped.
        ignore_mask = (label.astype('int32') != ignore_index).astype('int32')
        ignore_mask = fluid.layers.cast(ignore_mask, 'float32')
        ignore_mask = fluid.layers.reshape(ignore_mask, [-1, 1])
        loss = loss * ignore_mask
        # mean(loss) / mean(mask) == sum(loss) / number of valid pixels.
        # NOTE(review): the denominator is 0 when every label equals
        # `ignore_index`; confirm whether an epsilon is wanted here.
        avg_loss = fluid.layers.mean(loss) / fluid.layers.mean(ignore_mask)
        ignore_mask.stop_gradient = True
        label.stop_gradient = True
        return avg_loss
......@@ -141,7 +141,7 @@ def main():
checkpoint.load_params(exe, startup_prog, cfg.weights)
resolution = None
if 'Mask' in cfg.architecture:
if 'Mask' in cfg.architecture or cfg.architecture == 'HybridTaskCascade':
resolution = model.mask_head.resolution
results = eval_run(exe, compile_program, loader, keys, values, cls, cfg,
sub_eval_prog, sub_keys, sub_values, resolution)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册