From 38d420bbf133317dfd54f8f475e96191eb29e8c0 Mon Sep 17 00:00:00 2001
From: sunxl1988 <47514455+sunxl1988@users.noreply.github.com>
Date: Mon, 20 Jul 2020 14:04:21 +0800
Subject: [PATCH] test=master add htc model (#1081)

add htc model
---
 configs/htc/README.md                         |  26 +
 configs/htc/htc_r50_fpn_1x.yml                | 212 ++++++++
 ppdet/data/source/coco.py                     |  21 +-
 ppdet/data/transform/batch_operators.py       |   7 +
 ppdet/data/transform/operators.py             |  28 +-
 ppdet/modeling/architectures/__init__.py      |   2 +
 ppdet/modeling/architectures/htc.py           | 466 ++++++++++++++++++
 ppdet/modeling/roi_heads/__init__.py          |   6 +
 ppdet/modeling/roi_heads/htc_bbox_head.py     | 265 ++++++++++
 ppdet/modeling/roi_heads/htc_mask_head.py     | 205 ++++++++
 ppdet/modeling/roi_heads/htc_semantic_head.py |  88 ++++
 tools/eval.py                                 |   2 +-
 12 files changed, 1318 insertions(+), 10 deletions(-)
 create mode 100644 configs/htc/README.md
 create mode 100644 configs/htc/htc_r50_fpn_1x.yml
 create mode 100644 ppdet/modeling/architectures/htc.py
 create mode 100644 ppdet/modeling/roi_heads/htc_bbox_head.py
 create mode 100644 ppdet/modeling/roi_heads/htc_mask_head.py
 create mode 100644 ppdet/modeling/roi_heads/htc_semantic_head.py

diff --git a/configs/htc/README.md b/configs/htc/README.md
new file mode 100644
index 000000000..e34b1b233
--- /dev/null
+++ b/configs/htc/README.md
@@ -0,0 +1,26 @@
+# Hybrid Task Cascade for Instance Segmentation
+
+## Introduction
+
+We provide config files to reproduce the results of the CVPR 2019 paper [Hybrid Task Cascade](https://arxiv.org/abs/1901.07518).
+
+```
+@inproceedings{chen2019hybrid,
+  title={Hybrid task cascade for instance segmentation},
+  author={Chen, Kai and Pang, Jiangmiao and Wang, Jiaqi and Xiong, Yu and Li, Xiaoxiao and Sun, Shuyang and Feng, Wansen and Liu, Ziwei and Shi, Jianping and Ouyang, Wanli and Loy, Chen Change and Lin, Dahua},
+  booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+  year={2019}
+}
+```
+
+## Dataset
+
+HTC requires the COCO and COCO-stuff datasets for training.
+
+## Results and Models
+
+The results on COCO 2017val are shown in the table below.
+(results on test-dev are usually slightly higher than val)
+
+ | Backbone  | Lr schd | Inf time (fps) | box AP | mask AP | Download |
+ |:---------:|:-------:|:--------------:|:------:|:-------:|:--------:|
+ | R-50-FPN  | 1x      | 11             | 42.2   | 36.5    | [model](https://paddlemodels.bj.bcebos.com/object_detection/htc_r50_fpn_1x.pdparams) |
diff --git a/configs/htc/htc_r50_fpn_1x.yml b/configs/htc/htc_r50_fpn_1x.yml
new file mode 100644
index 000000000..0601704ce
--- /dev/null
+++ b/configs/htc/htc_r50_fpn_1x.yml
@@ -0,0 +1,212 @@
+architecture: HybridTaskCascade
+use_gpu: true
+max_iters: 100000
+snapshot_iter: 10000
+log_smooth_window: 50
+save_dir: output
+pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar
+metric: COCO
+weights: output/htc_r50_fpn_1x/model_final
+num_classes: 81
+
+HybridTaskCascade:
+  backbone: ResNet
+  fpn: FPN
+  rpn_head: FPNRPNHead
+  roi_extractor: FPNRoIAlign
+  bbox_head: HTCBBoxHead
+  bbox_assigner: CascadeBBoxAssigner
+  mask_assigner: MaskAssigner
+  mask_head: HTCMaskHead
+  fused_semantic_head: FusedSemanticHead
+
+ResNet:
+  depth: 50
+  feature_maps: [2, 3, 4, 5]
+  freeze_at: 2
+  norm_type: affine_channel
+
+FPN:
+  max_level: 6
+  min_level: 2
+  num_chan: 256
+  spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
+
+FPNRPNHead:
+  anchor_generator:
+    aspect_ratios: [0.5, 1.0, 2.0]
+    variance: [1.0, 1.0, 1.0, 1.0]
+  anchor_start_size: 32
+  max_level: 6
+  min_level: 2
+  num_chan: 256
+  rpn_target_assign:
+    rpn_batch_size_per_im: 256
+    rpn_fg_fraction: 0.5
+    rpn_negative_overlap: 0.3
+    rpn_positive_overlap: 0.7
+    rpn_straddle_thresh: 0.0
+  train_proposal:
+    min_size: 0.0
+    nms_thresh: 0.7
+    pre_nms_top_n: 2000
+    post_nms_top_n: 2000
+  test_proposal:
+    min_size: 0.0
+    nms_thresh: 0.7
+    pre_nms_top_n: 2000
+    post_nms_top_n: 1000
+
+# bbox roi extractor
+FPNRoIAlign:
+  canconical_level: 4
+  canonical_size: 224
+  max_level: 5
+  min_level: 2
+  sampling_ratio: 2
+  box_resolution: 7
+  mask_resolution: 14
+
+# semantic roi extractor
+RoIAlign:
+  resolution: 14
+  sampling_ratio: 2
+
+HTCMaskHead:
+  dilation: 1
+  conv_dim: 256
+  num_convs: 4
+  resolution: 28
+  lr_ratio: 2.0
+
+FusedSemanticHead:
+  semantic_num_class: 183
+
+CascadeBBoxAssigner:
+  batch_size_per_im: 512
+  bbox_reg_weights: [10, 20, 30]
+  bg_thresh_hi: [0.5, 0.6, 0.7]
+  bg_thresh_lo: [0.0, 0.0, 0.0]
+  fg_fraction: 0.25
+  fg_thresh: [0.5, 0.6, 0.7]
+
+MaskAssigner:
+  resolution: 28
+
+HTCBBoxHead:
+  head: CascadeTwoFCHead
+  nms: MultiClassSoftNMS
+
+MultiClassSoftNMS:
+  score_threshold: 0.01
+  keep_top_k: 300
+  softnms_sigma: 0.5
+
+CascadeTwoFCHead:
+  mlp_dim: 1024
+
+LearningRate:
+  base_lr: 0.01
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones: [60000, 80000]
+  - !LinearWarmup
+    start_factor: 0.1
+    steps: 1000
+
+OptimizerBuilder:
+  optimizer:
+    momentum: 0.9
+    type: Momentum
+  regularizer:
+    factor: 0.0001
+    type: L2
+
+TrainReader:
+  batch_size: 2
+  worker_num: 2
+  dataset:
+    !COCODataSet
+    dataset_dir: dataset/coco
+    anno_path: annotations/instances_train2017.json
+    image_dir: train2017
+    load_semantic: True
+  inputs_def:
+    fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd', 'gt_mask', 'semantic']
+  sample_transforms:
+  - !DecodeImage
+    to_rgb: true
+  - !RandomFlipImage
+    prob: 0.5
+    is_mask_flip: true
+  - !NormalizeImage
+    is_channel_first: false
+    is_scale: true
+    mean: [0.485,0.456,0.406]
+    std: [0.229, 0.224,0.225]
+  - !ResizeImage
+    target_size: 800
+    max_size: 1333
+    interp: 1
+    use_cv2: true
+  - !Permute
+    to_bgr: false
+    channel_first: true
+  batch_transforms:
+  - !PadBatch
+    pad_to_stride: 32
+    use_padded_im_info: false
+
+EvalReader:
+  inputs_def:
+    fields: ['image', 'im_info', 'im_id', 'im_shape']
+  dataset:
+    !COCODataSet
+    image_dir: val2017
+    anno_path: annotations/instances_val2017.json
+    dataset_dir: dataset/coco
+  sample_transforms:
+  - !DecodeImage
+    to_rgb: true
+  - !NormalizeImage
+    is_channel_first: false
+    is_scale: true
+    mean: [0.485,0.456,0.406]
+    std: [0.229, 0.224,0.225]
+  - !ResizeImage
+    interp: 1
+    max_size: 1333
+    target_size: 800
+    use_cv2: true
+  - !Permute
+    channel_first: true
+    to_bgr: false
+  batch_transforms:
+  - !PadBatch
+    pad_to_stride: 32
+    use_padded_im_info: false
+  batch_size: 1
+  shuffle: false
+  drop_last: false
+  drop_empty: false
+  worker_num: 2
+
+TestReader:
+  inputs_def:
+    fields: ['image', 'im_info', 'im_id', 'im_shape']
+  dataset:
+    !ImageFolder
+    anno_path: annotations/instances_val2017.json
+  sample_transforms:
+  - !DecodeImage
+    to_rgb: true
+    with_mixup: false
+  - !NormalizeImage
+    is_channel_first: false
+    is_scale: true
+    mean: [0.485,0.456,0.406]
+    std: [0.229, 0.224,0.225]
+  - !ResizeImage
+    interp: 1
+    max_size: 1333
+    target_size: 800
+    use_cv2: true
diff --git a/ppdet/data/source/coco.py b/ppdet/data/source/coco.py
index aaeaed58e..6b31ccb35 100644
--- a/ppdet/data/source/coco.py
+++ b/ppdet/data/source/coco.py
@@ -42,7 +42,8 @@ class COCODataSet(DataSet):
                  anno_path=None,
                  dataset_dir=None,
                  sample_num=-1,
-                 with_background=True):
+                 with_background=True,
+                 load_semantic=False):
         super(COCODataSet, self).__init__(
             image_dir=image_dir,
             anno_path=anno_path,
@@ -68,6 +69,7 @@ class COCODataSet(DataSet):
         # a dict used to map category name to class id
         self.cname2cid = None
         self.load_image_only = False
+        self.load_semantic = load_semantic

     def load_roidb_and_cname2cid(self):
         anno_path = os.path.join(self.dataset_dir, self.anno_path)
@@ -104,11 +106,11 @@ class COCODataSet(DataSet):
             im_w = float(img_anno['width'])
             im_h = float(img_anno['height'])

-            im_fname = os.path.join(image_dir,
-                                    im_fname) if image_dir else im_fname
-            if not os.path.exists(im_fname):
+            im_path = os.path.join(image_dir,
+                                   im_fname) if image_dir else im_fname
+            if not os.path.exists(im_path):
                 logger.warn('Illegal image file: {}, and it will be '
-                            'ignored'.format(im_fname))
+                            'ignored'.format(im_path))
                 continue

             if im_w < 0 or im_h < 0:
@@ -118,7 +120,7 @@ class COCODataSet(DataSet):
                 continue

             coco_rec = {
-                'im_file': im_fname,
+                'im_file': im_path,
                 'im_id': np.array([img_id]),
                 'h': im_h,
                 'w': im_w,
@@ -168,8 +170,13 @@ class COCODataSet(DataSet):
                     'gt_poly': gt_poly,
                 })

+            if self.load_semantic:
+                seg_path = os.path.join(self.dataset_dir, 'stuffthingmaps',
+                                        'train2017', im_fname[:-3] + 'png')
+                coco_rec.update({'semantic': seg_path})
+
             logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
-                im_fname, img_id, im_h, im_w))
+                im_path, img_id, im_h, im_w))
             records.append(coco_rec)
             ct += 1
             if self.sample_num > 0 and ct >= self.sample_num:
diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py
index 1bed5edaf..331752d77 100644
--- a/ppdet/data/transform/batch_operators.py
+++ b/ppdet/data/transform/batch_operators.py
@@ -82,6 +82,13 @@ class PadBatch(BaseOperator):
             data['image'] = padding_im
             if self.use_padded_im_info:
                 data['im_info'][:2] = max_shape[1:3]
+            if 'semantic' in data.keys() and data['semantic'] is not None:
+                semantic = data['semantic']
+                padding_sem = np.zeros(
+                    (1, max_shape[1], max_shape[2]), dtype=np.float32)
+                padding_sem[:, :im_h, :im_w] = semantic
+                data['semantic'] = padding_sem
+
         return samples
diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py
index db73e4174..d18b42a5c 100644
--- a/ppdet/data/transform/operators.py
+++ b/ppdet/data/transform/operators.py
@@ -106,8 +106,6 @@ class DecodeImage(BaseOperator):
             raise TypeError("{}: input type is invalid.".format(self))
         if not isinstance(self.with_mixup, bool):
             raise TypeError("{}: input type is invalid.".format(self))
-        if not isinstance(self.with_cutmix, bool):
-            raise TypeError("{}: input type is invalid.".format(self))

     def __call__(self, sample, context=None):
         """ load image if 'im_file' field is not empty but 'image' is"""
@@ -143,13 +141,21 @@ class DecodeImage(BaseOperator):
         # make default im_info with [h, w, 1]
         sample['im_info'] = np.array(
             [im.shape[0], im.shape[1], 1.], dtype=np.float32)
+        # decode mixup image
         if self.with_mixup and 'mixup' in sample:
             self.__call__(sample['mixup'], context)
+        # decode cutmix image
         if self.with_cutmix and 'cutmix' in sample:
             self.__call__(sample['cutmix'], context)
+        # decode semantic label
+        if 'semantic' in sample.keys() and sample['semantic'] is not None:
+            sem_file = sample['semantic']
+            sem = cv2.imread(sem_file, cv2.IMREAD_GRAYSCALE)
+            sample['semantic'] = sem.astype('int32')
+
         return sample


@@ -342,6 +348,18 @@ class ResizeImage(BaseOperator):
                 fx=im_scale_x,
                 fy=im_scale_y,
                 interpolation=self.interp)
+            if 'semantic' in sample.keys() and sample['semantic'] is not None:
+                semantic = sample['semantic']
+                semantic = cv2.resize(
+                    semantic.astype('float32'),
+                    None,
+                    None,
+                    fx=im_scale_x,
+                    fy=im_scale_y,
+                    interpolation=self.interp)
+                semantic = np.asarray(semantic).astype('int32')
+                semantic = np.expand_dims(semantic, 0)
+                sample['semantic'] = semantic
         else:
             if self.max_size != 0:
                 raise TypeError(
@@ -455,9 +473,15 @@ class RandomFlipImage(BaseOperator):
                 if self.is_mask_flip and len(sample['gt_poly']) != 0:
                     sample['gt_poly'] = self.flip_segms(sample['gt_poly'],
                                                         height, width)
+
                 if 'gt_keypoint' in sample.keys():
                     sample['gt_keypoint'] = self.flip_keypoint(
                         sample['gt_keypoint'], width)
+
+                if 'semantic' in sample.keys() and sample[
+                        'semantic'] is not None:
+                    sample['semantic'] = sample['semantic'][:, ::-1]
+
                 sample['flipped'] = True
                 sample['image'] = im
         sample = samples if batch_input else samples[0]
diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py
index f29be6cbd..a8a77bb61 100644
--- a/ppdet/modeling/architectures/__init__.py
+++ b/ppdet/modeling/architectures/__init__.py
@@ -28,6 +28,7 @@ from . import faceboxes
 from . import fcos
 from . import cornernet_squeeze
 from . import ttfnet
+from . import htc

 from .faster_rcnn import *
 from .mask_rcnn import *
@@ -43,3 +44,4 @@ from .faceboxes import *
 from .fcos import *
 from .cornernet_squeeze import *
 from .ttfnet import *
+from .htc import *
diff --git a/ppdet/modeling/architectures/htc.py b/ppdet/modeling/architectures/htc.py
new file mode 100644
index 000000000..9c0f72749
--- /dev/null
+++ b/ppdet/modeling/architectures/htc.py
@@ -0,0 +1,466 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import OrderedDict
+import copy
+import numpy as np
+
+import paddle.fluid as fluid
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.initializer import MSRA
+from paddle.fluid.regularizer import L2Decay
+from ppdet.experimental import mixed_precision_global_state
+from ppdet.core.workspace import register
+
+from .input_helper import multiscale_def
+
+__all__ = ['HybridTaskCascade']
+
+
+@register
+class HybridTaskCascade(object):
+    """
+    Hybrid Task Cascade Mask R-CNN architecture, see https://arxiv.org/abs/1901.07518
+
+    Args:
+        backbone (object): backbone instance
+        rpn_head (object): `RPNHead` instance
+        bbox_assigner (object): `BBoxAssigner` instance
+        roi_extractor (object): ROI extractor instance
+        bbox_head (object): `HTCBBoxHead` instance
+        mask_assigner (object): `MaskAssigner` instance
+        mask_head (object): `HTCMaskHead` instance
+        fpn (object): feature pyramid network instance
+        semantic_roi_extractor (object): ROI extractor instance
+        fused_semantic_head (object): `FusedSemanticHead` instance
+    """
+
+    __category__ = 'architecture'
+    __inject__ = [
+        'backbone', 'rpn_head', 'bbox_assigner', 'roi_extractor', 'bbox_head',
+        'mask_assigner', 'mask_head', 'fpn', 'semantic_roi_extractor',
+        'fused_semantic_head'
+    ]
+
+    def __init__(self,
+                 backbone,
+                 rpn_head,
+                 roi_extractor='FPNRoIAlign',
+                 semantic_roi_extractor='RoIAlign',
+                 fused_semantic_head='FusedSemanticHead',
+                 bbox_head='HTCBBoxHead',
+                 bbox_assigner='CascadeBBoxAssigner',
+                 mask_assigner='MaskAssigner',
+                 mask_head='HTCMaskHead',
+                 rpn_only=False,
+                 fpn='FPN'):
+        super(HybridTaskCascade, self).__init__()
+        assert fpn is not None, "HTC requires FPN"
+        self.backbone = backbone
+        self.fpn = fpn
+        self.rpn_head = rpn_head
+        self.bbox_assigner = bbox_assigner
+        self.roi_extractor = roi_extractor
+        self.semantic_roi_extractor = semantic_roi_extractor
+        self.fused_semantic_head = fused_semantic_head
+        self.bbox_head = bbox_head
+        self.mask_assigner = mask_assigner
+        self.mask_head = mask_head
+        self.rpn_only = rpn_only
+        # Cascade local cfg
+        self.cls_agnostic_bbox_reg = 2
+        (brw0, brw1, brw2) = self.bbox_assigner.bbox_reg_weights
+        self.cascade_bbox_reg_weights = [
+            [1. / brw0, 1. / brw0, 2. / brw0, 2. / brw0],
+            [1. / brw1, 1. / brw1, 2. / brw1, 2. / brw1],
+            [1. / brw2, 1. / brw2, 2. / brw2, 2. / brw2]
+        ]
+        self.cascade_rcnn_loss_weight = [1.0, 0.5, 0.25]
+        self.num_stage = 3
+        self.with_mask = True
+        self.interleaved = True
+        self.mask_info_flow = True
+        self.with_semantic = True
+        self.use_bias_scalar = True
+
+    def build(self, feed_vars, mode='train'):
+        if mode == 'train':
+            required_fields = [
+                'gt_class', 'gt_bbox', 'gt_mask', 'is_crowd', 'im_info',
+                'semantic'
+            ]
+        else:
+            required_fields = ['im_shape', 'im_info']
+        self._input_check(required_fields, feed_vars)
+
+        im = feed_vars['image']
+        if mode == 'train':
+            gt_bbox = feed_vars['gt_bbox']
+            is_crowd = feed_vars['is_crowd']
+
+        im_info = feed_vars['im_info']
+
+        # backbone
+        body_feats = self.backbone(im)
+
+        loss = {}
+        # FPN
+        if self.fpn is not None:
+            body_feats, spatial_scale = self.fpn.get_output(body_feats)
+
+        if self.with_semantic:
+            # TODO: use cfg
+            semantic_feat, seg_pred = self.fused_semantic_head.get_out(
+                body_feats)
+            if mode == 'train':
+                s_label = feed_vars['semantic']
+                semantic_loss = self.fused_semantic_head.get_loss(seg_pred,
+                                                                  s_label) * 0.2
+                loss.update({"semantic_loss": semantic_loss})
+        else:
+            semantic_feat = None
+
+        # rpn proposals
+        rpn_rois = self.rpn_head.get_proposals(body_feats, im_info, mode=mode)
+        if mode == 'train':
+            rpn_loss = self.rpn_head.get_loss(im_info, gt_bbox, is_crowd)
+            loss.update(rpn_loss)
+        else:
+            if self.rpn_only:
+                im_scale = fluid.layers.slice(
+                    im_info, [1], starts=[2], ends=[3])
+                im_scale = fluid.layers.sequence_expand(im_scale, rpn_rois)
+                rois = rpn_rois / im_scale
+                return {'proposal': rois}
+
+        proposal_list = []
+        roi_feat_list = []
+        rcnn_pred_list = []
+        rcnn_target_list = []
+        mask_logits_list = []
+        mask_target_list = []
+        proposals = None
+        bbox_pred = None
+        outs = None
+        refined_bbox = rpn_rois
+        for i in range(self.num_stage):
+            # BBox Branch
+            if mode == 'train':
+                outs = self.bbox_assigner(
+                    input_rois=refined_bbox, feed_vars=feed_vars, curr_stage=i)
+                proposals = outs[0]
+                rcnn_target_list.append(outs)
+            else:
+                proposals = refined_bbox
+            proposal_list.append(proposals)
+
+            # extract roi features
+            roi_feat = self.roi_extractor(body_feats, proposals, spatial_scale)
+            if self.with_semantic:
+                semantic_roi_feat = self.semantic_roi_extractor(semantic_feat,
+                                                                proposals)
+                if semantic_roi_feat is not None:
+                    semantic_roi_feat = fluid.layers.pool2d(
+                        semantic_roi_feat,
+                        pool_size=2,
+                        pool_stride=2,
+                        pool_padding='SAME')
+                    roi_feat = fluid.layers.sum([roi_feat, semantic_roi_feat])
+            roi_feat_list.append(roi_feat)
+
+            # bbox head
+            cls_score, bbox_pred = self.bbox_head.get_output(
+                roi_feat,
+                wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i],
+                name='_' + str(i))
+            rcnn_pred_list.append((cls_score, bbox_pred))
+
+            # Mask Branch
+            if self.with_mask:
+                if mode == 'train':
+                    labels_int32 = outs[1]
+                    if self.interleaved:
+                        refined_bbox = self._decode_box(
+                            proposals, bbox_pred, curr_stage=i)
+                        proposals = refined_bbox
+
+                    mask_rois, roi_has_mask_int32, mask_int32 = self.mask_assigner(
+                        rois=proposals,
+                        gt_classes=feed_vars['gt_class'],
+                        is_crowd=feed_vars['is_crowd'],
+                        gt_segms=feed_vars['gt_mask'],
+                        im_info=feed_vars['im_info'],
+                        labels_int32=labels_int32)
+                    mask_target_list.append(mask_int32)
+
+                    mask_feat = self.roi_extractor(
+                        body_feats, mask_rois, spatial_scale, is_mask=True)
+
+                    if self.with_semantic:
+                        semantic_roi_feat = self.semantic_roi_extractor(
+                            semantic_feat, mask_rois)
+                        if semantic_roi_feat is not None:
+                            mask_feat = fluid.layers.sum(
+                                [mask_feat, semantic_roi_feat])
+
+                    if self.mask_info_flow:
+                        last_feat = None
+                        for j in range(i):
+                            last_feat = self.mask_head.get_output(
+                                mask_feat,
+                                last_feat,
+                                return_logits=False,
+                                return_feat=True,
+                                wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i]
+                                if self.use_bias_scalar else 1.0,
+                                name='_' + str(i) + '_' + str(j))
+                        mask_logits = self.mask_head.get_output(
+                            mask_feat,
+                            last_feat,
+                            return_logits=True,
+                            return_feat=False,
+                            wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i]
+                            if self.use_bias_scalar else 1.0,
+                            name='_' + str(i))
+                    else:
+                        mask_logits = self.mask_head.get_output(
+                            mask_feat,
+                            return_logits=True,
+                            wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i]
+                            if self.use_bias_scalar else 1.0,
+                            name='_' + str(i))
+                    mask_logits_list.append(mask_logits)
+
+            if i < self.num_stage - 1 and not self.interleaved:
+                refined_bbox = self._decode_box(
+                    proposals, bbox_pred, curr_stage=i)
+            elif i < self.num_stage - 1 and mode != 'train':
+                refined_bbox = self._decode_box(
+                    proposals, bbox_pred, curr_stage=i)
+
+        if mode == 'train':
+            bbox_loss = self.bbox_head.get_loss(
+                rcnn_pred_list, rcnn_target_list, self.cascade_rcnn_loss_weight)
+            loss.update(bbox_loss)
+            mask_loss = self.mask_head.get_loss(mask_logits_list,
+                                                mask_target_list,
+                                                self.cascade_rcnn_loss_weight)
+            loss.update(mask_loss)
+            total_loss = fluid.layers.sum(list(loss.values()))
+            loss.update({'loss': total_loss})
+            return loss
+        else:
+            mask_name = 'mask_pred'
+            mask_pred, bbox_pred = self.single_scale_eval(
+                body_feats,
+                spatial_scale,
+                im_info,
+                mask_name,
+                bbox_pred,
+                roi_feat_list,
+                rcnn_pred_list,
+                proposal_list,
+                feed_vars['im_shape'],
+                semantic_feat=semantic_feat if self.with_semantic else None)
+            return {'bbox': bbox_pred, 'mask': mask_pred}
+
+    def single_scale_eval(self,
+                          body_feats,
+                          spatial_scale,
+                          im_info,
+                          mask_name,
+                          bbox_pred,
+                          roi_feat_list=None,
+                          rcnn_pred_list=None,
+                          proposal_list=None,
+                          im_shape=None,
+                          use_multi_test=False,
+                          semantic_feat=None):
+
+        if not use_multi_test:
+            bbox_pred = self.bbox_head.get_prediction(
+                im_info, im_shape, roi_feat_list, rcnn_pred_list, proposal_list,
+                self.cascade_bbox_reg_weights)
+            bbox_pred = bbox_pred['bbox']
+
+        # share weight
+        bbox_shape = fluid.layers.shape(bbox_pred)
+        bbox_size = fluid.layers.reduce_prod(bbox_shape)
+        bbox_size = fluid.layers.reshape(bbox_size, [1, 1])
+        size = fluid.layers.fill_constant([1, 1], value=6, dtype='int32')
+        cond = fluid.layers.less_than(x=bbox_size, y=size)
+
+        mask_pred = fluid.layers.create_global_var(
+            shape=[1],
+            value=0.0,
+            dtype='float32',
+            persistable=False,
+            name=mask_name)
+
+        def noop():
+            fluid.layers.assign(input=bbox_pred, output=mask_pred)
+
+        def process_boxes():
+            bbox = fluid.layers.slice(bbox_pred, [1], starts=[2], ends=[6])
+
+            im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3])
+            im_scale = fluid.layers.sequence_expand(im_scale, bbox)
+
+            bbox = fluid.layers.cast(bbox, dtype='float32')
+            im_scale = fluid.layers.cast(im_scale, dtype='float32')
+            mask_rois = bbox * im_scale
+
+            mask_feat = self.roi_extractor(
+                body_feats, mask_rois, spatial_scale, is_mask=True)
+
+            if self.with_semantic:
+                semantic_roi_feat = self.semantic_roi_extractor(semantic_feat,
+                                                                mask_rois)
+                if semantic_roi_feat is not None:
+                    mask_feat = fluid.layers.sum([mask_feat, semantic_roi_feat])
+
+            mask_logits_list = []
+            mask_pred_list = []
+            for i in range(self.num_stage):
+                if self.mask_info_flow:
+                    last_feat = None
+                    for j in range(i):
+                        last_feat = self.mask_head.get_output(
+                            mask_feat,
+                            last_feat,
+                            return_logits=False,
+                            return_feat=True,
+                            wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i]
+                            if self.use_bias_scalar else 1.0,
+                            name='_' + str(i) + '_' + str(j))
+                    mask_logits = self.mask_head.get_output(
+                        mask_feat,
+                        last_feat,
+                        return_logits=True,
+                        return_feat=False,
+                        wb_scalar=1.0 / self.cascade_rcnn_loss_weight[i]
+                        if self.use_bias_scalar else 1.0,
+                        name='_' + str(i))
+                    mask_logits_list.append(mask_logits)
+                else:
+                    mask_logits = self.mask_head.get_output(
+                        mask_feat,
+                        return_logits=True,
+                        return_feat=False,
+                        name='_' + str(i))
+                mask_pred_out = self.mask_head.get_prediction(mask_logits, bbox)
+                mask_pred_list.append(mask_pred_out)
+
+            mask_pred_out = fluid.layers.sum(mask_pred_list) / float(
+                len(mask_pred_list))
+            fluid.layers.assign(input=mask_pred_out, output=mask_pred)
+
+        fluid.layers.cond(cond, noop, process_boxes)
+        return mask_pred, bbox_pred
+
+    def _input_check(self, require_fields, feed_vars):
+        for var in require_fields:
+            assert var in feed_vars, \
+                "{} has no {} field".format(feed_vars, var)
+
+    def _decode_box(self, proposals, bbox_pred, curr_stage):
+        rcnn_loc_delta_r = fluid.layers.reshape(
+            bbox_pred, (-1, self.cls_agnostic_bbox_reg, 4))
+        # only use fg box delta to decode box
+        rcnn_loc_delta_s = fluid.layers.slice(
+            rcnn_loc_delta_r, axes=[1], starts=[1], ends=[2])
+        refined_bbox = fluid.layers.box_coder(
+            prior_box=proposals,
+            prior_box_var=self.cascade_bbox_reg_weights[curr_stage],
+            target_box=rcnn_loc_delta_s,
+            code_type='decode_center_size',
+            box_normalized=False,
+            axis=1, )
+        refined_bbox = fluid.layers.reshape(refined_bbox, shape=[-1, 4])
+
+        return refined_bbox
+
+    def _inputs_def(self, image_shape):
+        im_shape = [None] + image_shape
+        # yapf: disable
+        inputs_def = {
+            'image':    {'shape': im_shape,  'dtype': 'float32', 'lod_level': 0},
+            'im_info':  {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
+            'im_id':    {'shape': [None, 1], 'dtype': 'int64',   'lod_level': 0},
+            'im_shape': {'shape': [None, 3], 'dtype': 'float32', 'lod_level': 0},
+            'gt_bbox':  {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1},
+            'gt_class': {'shape': [None, 1], 'dtype': 'int32',   'lod_level': 1},
+            'is_crowd': {'shape': [None, 1], 'dtype': 'int32',   'lod_level': 1},
+            'gt_mask':  {'shape': [None, 2], 'dtype': 'float32', 'lod_level': 3},  # polygon coordinates
+            'semantic': {'shape': [None, 1, None, None], 'dtype': 'int32', 'lod_level': 0},
+        }
+        # yapf: enable
+        return inputs_def
+
+    def build_inputs(self,
+                     image_shape=[3, None, None],
+                     fields=[
+                         'image', 'im_info', 'im_id', 'gt_bbox', 'gt_class',
+                         'is_crowd', 'gt_mask', 'semantic'
+                     ],
+                     multi_scale=False,
+                     num_scales=-1,
+                     use_flip=None,
+                     use_dataloader=True,
+                     iterable=False,
+                     mask_branch=False):
+        inputs_def = self._inputs_def(image_shape)
+        fields = copy.deepcopy(fields)
+        if multi_scale:
+            ms_def, ms_fields = multiscale_def(image_shape, num_scales,
+                                               use_flip)
+            inputs_def.update(ms_def)
+            fields += ms_fields
+            self.im_info_names = ['image', 'im_info'] + ms_fields
+        if mask_branch:
+            box_fields = ['bbox', 'bbox_flip'] if use_flip else ['bbox']
+            for key in box_fields:
+                inputs_def[key] = {
+                    'shape': [6],
+                    'dtype': 'float32',
+                    'lod_level': 1
+                }
+            fields += box_fields
+        feed_vars = OrderedDict([(key, fluid.data(
+            name=key,
+            shape=inputs_def[key]['shape'],
+            dtype=inputs_def[key]['dtype'],
+            lod_level=inputs_def[key]['lod_level'])) for key in fields])
+        use_dataloader = use_dataloader and not mask_branch
+        loader = fluid.io.DataLoader.from_generator(
+            feed_list=list(feed_vars.values()),
+            capacity=64,
+            use_double_buffer=True,
+            iterable=iterable) if use_dataloader else None
+        return feed_vars, loader
+
+    def train(self, feed_vars):
+        return self.build(feed_vars, 'train')
+
+    def eval(self, feed_vars, multi_scale=None, mask_branch=False):
+        if multi_scale:
+            return self.build_multi_scale(feed_vars, mask_branch)
+        return self.build(feed_vars, 'test')
+
+    def test(self, feed_vars):
+        return self.build(feed_vars, 'test')
diff --git a/ppdet/modeling/roi_heads/__init__.py b/ppdet/modeling/roi_heads/__init__.py
index 345a0eb3e..bb5f47d6f 100644
--- a/ppdet/modeling/roi_heads/__init__.py
+++ b/ppdet/modeling/roi_heads/__init__.py
@@ -17,7 +17,13 @@ from __future__ import absolute_import
 from . import bbox_head
 from . import mask_head
 from . import cascade_head
+from . import htc_bbox_head
+from . import htc_mask_head
+from . import htc_semantic_head

 from .bbox_head import *
 from .mask_head import *
 from .cascade_head import *
+from .htc_bbox_head import *
+from .htc_mask_head import *
+from .htc_semantic_head import *
diff --git a/ppdet/modeling/roi_heads/htc_bbox_head.py b/ppdet/modeling/roi_heads/htc_bbox_head.py
new file mode 100644
index 000000000..d43c7d9b8
--- /dev/null
+++ b/ppdet/modeling/roi_heads/htc_bbox_head.py
@@ -0,0 +1,265 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle.fluid as fluid
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.initializer import Normal, Xavier
+from paddle.fluid.regularizer import L2Decay
+from paddle.fluid.initializer import MSRA
+
+from ppdet.modeling.ops import MultiClassNMS
+from ppdet.modeling.ops import ConvNorm
+from ppdet.modeling.losses import SmoothL1Loss
+from ppdet.core.workspace import register
+
+__all__ = ['HTCBBoxHead']
+
+
+@register
+class HTCBBoxHead(object):
+    """
+    HTC bbox head
+
+    Args:
+        head (object): the head module instance
+        nms (object): `MultiClassNMS` instance
+        num_classes: number of output classes
+    """
+    __inject__ = ['head', 'nms', 'bbox_loss']
+    __shared__ = ['num_classes']
+
+    def __init__(self,
+                 head,
+                 nms=MultiClassNMS().__dict__,
+                 bbox_loss=SmoothL1Loss().__dict__,
+                 num_classes=81,
+                 lr_ratio=2.0):
+        super(HTCBBoxHead, self).__init__()
+        self.head = head
+        self.nms = nms
+        self.bbox_loss = bbox_loss
+        self.num_classes = num_classes
+        self.lr_ratio = lr_ratio
+
+        if isinstance(nms, dict):
+            self.nms = MultiClassNMS(**nms)
+        if isinstance(bbox_loss, dict):
+            self.bbox_loss = SmoothL1Loss(**bbox_loss)
+
+    def get_output(self,
+                   roi_feat,
+                   cls_agnostic_bbox_reg=2,
+                   wb_scalar=1.0,
+                   name=''):
+        """
+        Get bbox head output.
+
+        Args:
+            roi_feat (Variable): RoI feature from RoIExtractor.
+            cls_agnostic_bbox_reg(Int): Number of box groups to regress;
+                2 when the bbox regressor is class agnostic.
+            wb_scalar(Float): Learning rate scale for the weights and biases.
+            name(String): Layer's name
+
+        Returns:
+            cls_score(Variable): cls score.
+            bbox_pred(Variable): bbox regression.
+ """ + head_feat = self.head(roi_feat, wb_scalar, name) + cls_score = fluid.layers.fc(input=head_feat, + size=self.num_classes, + act=None, + name='cls_score' + name, + param_attr=ParamAttr( + name='cls_score%s_w' % name, + initializer=Normal( + loc=0.0, scale=0.01), + learning_rate=wb_scalar), + bias_attr=ParamAttr( + name='cls_score%s_b' % name, + learning_rate=wb_scalar * self.lr_ratio, + regularizer=L2Decay(0.))) + bbox_pred = fluid.layers.fc(input=head_feat, + size=4 * cls_agnostic_bbox_reg, + act=None, + name='bbox_pred' + name, + param_attr=ParamAttr( + name='bbox_pred%s_w' % name, + initializer=Normal( + loc=0.0, scale=0.001), + learning_rate=wb_scalar), + bias_attr=ParamAttr( + name='bbox_pred%s_b' % name, + learning_rate=wb_scalar * self.lr_ratio, + regularizer=L2Decay(0.))) + return cls_score, bbox_pred + + def get_loss(self, rcnn_pred_list, rcnn_target_list, rcnn_loss_weight_list): + """ + Get bbox_head loss. + + Args: + rcnn_pred_list(List): Cascade RCNN's head's output including + bbox_pred and cls_score + rcnn_target_list(List): Cascade rcnn's bbox and label target + rcnn_loss_weight_list(List): The weight of location and class loss + + Return: + loss_cls(Variable): bbox_head loss. + loss_bbox(Variable): bbox_head loss. + """ + loss_dict = {} + for i, (rcnn_pred, rcnn_target + ) in enumerate(zip(rcnn_pred_list, rcnn_target_list)): + labels_int64 = fluid.layers.cast(x=rcnn_target[1], dtype='int64') + labels_int64.stop_gradient = True + + loss_cls = fluid.layers.softmax_with_cross_entropy( + logits=rcnn_pred[0], + label=labels_int64, + numeric_stable_mode=True, ) + loss_cls = fluid.layers.reduce_mean( + loss_cls, name='loss_cls_' + str(i)) * rcnn_loss_weight_list[i] + + loss_bbox = self.bbox_loss( + x=rcnn_pred[1], + y=rcnn_target[2], + inside_weight=rcnn_target[3], + outside_weight=rcnn_target[4]) + loss_bbox = fluid.layers.reduce_mean( + loss_bbox, + name='loss_bbox_' + str(i)) * rcnn_loss_weight_list[i] + + loss_dict['loss_cls_%d' % i] = loss_cls + loss_dict['loss_loc_%d' % i] = loss_bbox + + return loss_dict + + def get_prediction(self, + im_info, + im_shape, + roi_feat_list, + rcnn_pred_list, + proposal_list, + cascade_bbox_reg_weights, + cls_agnostic_bbox_reg=2, + return_box_score=False): + """ + Get prediction bounding box in test stage. + : + Args: + im_info (Variable): A 2-D LoDTensor with shape [B, 3]. B is the + number of input images, each element consists + of im_height, im_width, im_scale. + im_shape (Variable): Actual shape of original image with shape + [B, 3]. B is the number of images, each element consists of + original_height, original_width, 1 + rois_feat_list (List): RoI feature from RoIExtractor. + rcnn_pred_list (Variable): Cascade rcnn's head's output + including bbox_pred and cls_score + proposal_list (List): RPN proposal boxes. + cascade_bbox_reg_weights (List): BBox decode var. + cls_agnostic_bbox_reg(Int): BBox regressor are class agnostic + + Returns: + pred_result(Variable): Prediction result with shape [N, 6]. Each + row has 6 values: [label, confidence, xmin, ymin, xmax, ymax]. + N is the total number of prediction. 
+ """ + repeat_num = 3 + # cls score + boxes_cls_prob_l = [] + for i in range(repeat_num): + cls_score = rcnn_pred_list[i][0] + cls_prob = fluid.layers.softmax(cls_score, use_cudnn=False) + boxes_cls_prob_l.append(cls_prob) + + boxes_cls_prob_mean = fluid.layers.sum(boxes_cls_prob_l) / float( + len(boxes_cls_prob_l)) + + # bbox pred + im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3]) + bbox_pred_l = [] + for i in range(repeat_num): + if i < 2: + continue + bbox_reg_w = cascade_bbox_reg_weights[i] + proposals_boxes = proposal_list[i] + im_scale_lod = fluid.layers.sequence_expand(im_scale, + proposals_boxes) + proposals_boxes = proposals_boxes / im_scale_lod + bbox_pred = rcnn_pred_list[i][1] + bbox_pred_new = fluid.layers.reshape(bbox_pred, + (-1, cls_agnostic_bbox_reg, 4)) + bbox_pred_l.append(bbox_pred_new) + + bbox_pred_new = bbox_pred_l[-1] + if cls_agnostic_bbox_reg == 2: + # only use fg box delta to decode box + bbox_pred_new = fluid.layers.slice( + bbox_pred_new, axes=[1], starts=[1], ends=[2]) + bbox_pred_new = fluid.layers.expand(bbox_pred_new, + [1, self.num_classes, 1]) + decoded_box = fluid.layers.box_coder( + prior_box=proposals_boxes, + prior_box_var=bbox_reg_w, + target_box=bbox_pred_new, + code_type='decode_center_size', + box_normalized=False, + axis=1) + + box_out = fluid.layers.box_clip(input=decoded_box, im_info=im_shape) + if return_box_score: + return {'bbox': box_out, 'score': boxes_cls_prob_mean} + pred_result = self.nms(bboxes=box_out, scores=boxes_cls_prob_mean) + return {"bbox": pred_result} + + def get_prediction_cls_aware(self, + im_info, + im_shape, + cascade_cls_prob, + cascade_decoded_box, + cascade_bbox_reg_weights, + return_box_score=False): + ''' + get_prediction_cls_aware: predict bbox for each class + ''' + cascade_num_stage = 3 + cascade_eval_weight = [0.2, 0.3, 0.5] + # merge 3 stages results + sum_cascade_cls_prob = sum([ + prob * cascade_eval_weight[idx] + for idx, prob in enumerate(cascade_cls_prob) + ]) + sum_cascade_decoded_box = sum([ + bbox * cascade_eval_weight[idx] + for idx, bbox in enumerate(cascade_decoded_box) + ]) + self.im_scale = fluid.layers.slice(im_info, [1], starts=[2], ends=[3]) + im_scale_lod = fluid.layers.sequence_expand(self.im_scale, + sum_cascade_decoded_box) + + sum_cascade_decoded_box = sum_cascade_decoded_box / im_scale_lod + + decoded_bbox = sum_cascade_decoded_box + decoded_bbox = fluid.layers.reshape( + decoded_bbox, shape=(-1, self.num_classes, 4)) + + box_out = fluid.layers.box_clip(input=decoded_bbox, im_info=im_shape) + if return_box_score: + return {'bbox': box_out, 'score': sum_cascade_cls_prob} + pred_result = self.nms(bboxes=box_out, scores=sum_cascade_cls_prob) + return {"bbox": pred_result} diff --git a/ppdet/modeling/roi_heads/htc_mask_head.py b/ppdet/modeling/roi_heads/htc_mask_head.py new file mode 100644 index 000000000..bf4816164 --- /dev/null +++ b/ppdet/modeling/roi_heads/htc_mask_head.py @@ -0,0 +1,205 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import fluid
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.initializer import MSRA
+from paddle.fluid.regularizer import L2Decay
+
+from ppdet.core.workspace import register
+from ppdet.modeling.ops import ConvNorm
+
+__all__ = ['HTCMaskHead']
+
+
+@register
+class HTCMaskHead(object):
+    """
+    htc mask head
+
+    Args:
+        num_convs (int): num of convolutions, 4 for FPN, 1 otherwise
+        conv_dim (int): num of channels after first convolution
+        resolution (int): size of the output mask
+        dilation (int): dilation rate
+        num_classes (int): number of output classes
+    """
+
+    __shared__ = ['num_classes']
+
+    def __init__(self,
+                 num_convs=0,
+                 conv_dim=256,
+                 resolution=14,
+                 dilation=1,
+                 num_classes=81,
+                 norm_type=None,
+                 lr_ratio=2.0,
+                 share_mask_conv=False):
+        super(HTCMaskHead, self).__init__()
+        self.num_convs = num_convs
+        self.conv_dim = conv_dim
+        self.resolution = resolution
+        self.dilation = dilation
+        self.num_classes = num_classes
+        self.norm_type = norm_type
+        self.lr_ratio = lr_ratio
+        self.share_mask_conv = share_mask_conv
+
+    def _mask_conv_head(self,
+                        roi_feat,
+                        num_convs,
+                        norm_type,
+                        wb_scalar=1.0,
+                        name=''):
+        if norm_type == 'gn':
+            for i in range(num_convs):
+                layer_name = "mask_inter_feat_" + str(i + 1)
+                if not self.share_mask_conv:
+                    layer_name += name
+                fan = self.conv_dim * 3 * 3
+                initializer = MSRA(uniform=False, fan_in=fan)
+                roi_feat = ConvNorm(
+                    roi_feat,
+                    self.conv_dim,
+                    3,
+                    act='relu',
+                    dilation=self.dilation,
+                    initializer=initializer,
+                    norm_type=self.norm_type,
+                    name=layer_name,
+                    norm_name=layer_name)
+        else:
+            for i in range(num_convs):
+                layer_name = "mask_inter_feat_" + str(i + 1)
+                if not self.share_mask_conv:
+                    layer_name += name
+                fan = self.conv_dim * 3 * 3
+                initializer = MSRA(uniform=False, fan_in=fan)
+                roi_feat = fluid.layers.conv2d(
+                    input=roi_feat,
+                    num_filters=self.conv_dim,
+                    filter_size=3,
+                    padding=1 * self.dilation,
+                    act='relu',
+                    stride=1,
+                    dilation=self.dilation,
+                    name=layer_name,
+                    param_attr=ParamAttr(
+                        name=layer_name + '_w', initializer=initializer),
+                    bias_attr=ParamAttr(
+                        name=layer_name + '_b',
+                        learning_rate=wb_scalar * self.lr_ratio,
+                        regularizer=L2Decay(0.)))
+        return roi_feat
+
+    def get_output(self,
+                   roi_feat,
+                   res_feat=None,
+                   return_logits=True,
+                   return_feat=False,
+                   wb_scalar=1.0,
+                   name=''):
+        class_num = self.num_classes
+        if res_feat is not None:
+            res_feat = fluid.layers.conv2d(
+                res_feat, roi_feat.shape[1], 1, name='res_net' + name)
+            roi_feat = fluid.layers.sum([roi_feat, res_feat])
+        # configure the conv number for FPN if necessary
+        head_feat = self._mask_conv_head(roi_feat, self.num_convs,
+                                         self.norm_type, wb_scalar, name)
+
+        if return_logits:
+            fan0 = roi_feat.shape[1] * 2 * 2
+            up_head_feat = fluid.layers.conv2d_transpose(
+                input=head_feat,
+                num_filters=self.conv_dim,
+                filter_size=2,
+                stride=2,
+                act='relu',
+                param_attr=ParamAttr(
+                    name='conv5_mask_w' + name,
+                    initializer=MSRA(
+                        uniform=False, fan_in=fan0)),
+                bias_attr=ParamAttr(
+                    name='conv5_mask_b' + name,
+                    learning_rate=wb_scalar * self.lr_ratio,
+                    regularizer=L2Decay(0.)))
+
+            fan = class_num
+            mask_logits = fluid.layers.conv2d(
+                input=up_head_feat,
+                num_filters=class_num,
+                filter_size=1,
+                act=None,
+                param_attr=ParamAttr(
+                    name='mask_fcn_logits_w' + name,
+                    initializer=MSRA(
+                        uniform=False, fan_in=fan)),
+                bias_attr=ParamAttr(
+                    name="mask_fcn_logits_b" + name,
+                    learning_rate=wb_scalar * self.lr_ratio,
+                    regularizer=L2Decay(0.)))
+            if return_feat:
+                return mask_logits, head_feat
+            else:
+                return mask_logits
+
+        if return_feat:
+            return head_feat
+
+    def get_loss(self,
+                 mask_logits_list,
+                 mask_int32_list,
+                 cascade_loss_weights=[1.0, 0.5, 0.25]):
+        num_classes = self.num_classes
+        resolution = self.resolution
+        dim = num_classes * resolution * resolution
+        loss_mask_dict = {}
+        for i, (mask_logits, mask_int32
+                ) in enumerate(zip(mask_logits_list, mask_int32_list)):
+
+            mask_logits = fluid.layers.reshape(mask_logits, (-1, dim))
+            mask_label = fluid.layers.cast(x=mask_int32, dtype='float32')
+            mask_label.stop_gradient = True
+            loss_name = 'loss_mask_' + str(i)
+            loss_mask = fluid.layers.sigmoid_cross_entropy_with_logits(
+                x=mask_logits,
+                label=mask_label,
+                ignore_index=-1,
+                normalize=True,
+                name=loss_name)
+            loss_mask = fluid.layers.reduce_sum(
+                loss_mask) * cascade_loss_weights[i]
+            loss_mask_dict[loss_name] = loss_mask
+        return loss_mask_dict
+
+    def get_prediction(self, mask_logits, bbox_pred):
+        """
+        Get prediction mask in test stage.
+
+        Args:
+            mask_logits (Variable): mask head output features.
+            bbox_pred (Variable): predicted bbox.
+
+        Returns:
+            mask_pred (Variable): Prediction mask with shape
+                [N, num_classes, resolution, resolution].
+        """
+        mask_prob = fluid.layers.sigmoid(mask_logits)
+        mask_prob = fluid.layers.lod_reset(mask_prob, bbox_pred)
+        return mask_prob
diff --git a/ppdet/modeling/roi_heads/htc_semantic_head.py b/ppdet/modeling/roi_heads/htc_semantic_head.py
new file mode 100644
index 000000000..227889885
--- /dev/null
+++ b/ppdet/modeling/roi_heads/htc_semantic_head.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import fluid
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.initializer import MSRA
+from paddle.fluid.regularizer import L2Decay
+
+from ppdet.core.workspace import register
+from ppdet.modeling.ops import ConvNorm
+
+__all__ = ['FusedSemanticHead']
+
+
+@register
+class FusedSemanticHead(object):
+    def __init__(self, semantic_num_class=183):
+        super(FusedSemanticHead, self).__init__()
+        self.semantic_num_class = semantic_num_class
+
+    def get_out(self,
+                fpn_feats,
+                out_c=256,
+                num_convs=4,
+                fusion_level='fpn_res3_sum'):
+        new_feat = fpn_feats[fusion_level]
+        new_feat_list = [new_feat, ]
+        target_shape = fluid.layers.shape(new_feat)[2:]
+        for k, v in fpn_feats.items():
+            if k != fusion_level:
+                v = fluid.layers.resize_bilinear(
+                    v, target_shape, align_corners=True)
+                v = fluid.layers.conv2d(v, out_c, 1)
+                new_feat_list.append(v)
+        new_feat = fluid.layers.sum(new_feat_list)
+
+        for i in range(num_convs):
+            new_feat = fluid.layers.conv2d(new_feat, out_c, 3, padding=1)
+
+        # conv embedding
+        semantic_feat = fluid.layers.conv2d(new_feat, out_c, 1)
+        # conv logits
+        seg_pred = fluid.layers.conv2d(new_feat, self.semantic_num_class, 1)
+        return semantic_feat, seg_pred
+
+    def get_loss(self, logit, label, ignore_index=255):
+        label = fluid.layers.resize_nearest(label,
+                                            fluid.layers.shape(logit)[2:])
+        label = fluid.layers.reshape(label, [-1, 1])
+        label = fluid.layers.cast(label, 'int64')
+
+        logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
+        logit = fluid.layers.reshape(logit, [-1, self.semantic_num_class])
+
+        loss, probs = fluid.layers.softmax_with_cross_entropy(
+            logit,
+            label,
+            soft_label=False,
+            ignore_index=ignore_index,
+            return_softmax=True)
+
+        ignore_mask = (label.astype('int32') != ignore_index).astype('int32')
+        if ignore_mask is not None:
+            ignore_mask = fluid.layers.cast(ignore_mask, 'float32')
+            ignore_mask = fluid.layers.reshape(ignore_mask, [-1, 1])
+            loss = loss * ignore_mask
+            avg_loss = fluid.layers.mean(loss) / fluid.layers.mean(ignore_mask)
+            ignore_mask.stop_gradient = True
+        else:
+            avg_loss = fluid.layers.mean(loss)
+        label.stop_gradient = True
+
+        return avg_loss
diff --git a/tools/eval.py b/tools/eval.py
index 2c822638f..84da9b49a 100644
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -141,7 +141,7 @@ def main():
     checkpoint.load_params(exe, startup_prog, cfg.weights)

     resolution = None
-    if 'Mask' in cfg.architecture:
+    if 'Mask' in cfg.architecture or cfg.architecture == 'HybridTaskCascade':
         resolution = model.mask_head.resolution
     results = eval_run(exe, compile_program, loader, keys, values, cls, cfg,
                        sub_eval_prog, sub_keys, sub_values, resolution)
--
GitLab
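
A minimal usage sketch for the model this patch adds, assuming the standard PaddleDetection entry points (`tools/train.py` and the `tools/eval.py` modified above) and that COCO plus the COCO-stuff `stuffthingmaps` annotations are laid out under `dataset/coco`, as `configs/htc/htc_r50_fpn_1x.yml` and the `coco.py` changes expect:

```bash
# Train HTC R-50-FPN 1x using the config added by this patch.
python tools/train.py -c configs/htc/htc_r50_fpn_1x.yml

# Evaluate on COCO 2017val. The eval.py change above makes evaluation pick up
# the mask head resolution for the HybridTaskCascade architecture, so mask AP
# is reported alongside box AP.
python tools/eval.py -c configs/htc/htc_r50_fpn_1x.yml \
    -o weights=output/htc_r50_fpn_1x/model_final
```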