diff --git a/configs/datasets/dota.yml b/configs/datasets/dota.yml
new file mode 100644
index 0000000000000000000000000000000000000000..07602d41ce032f9b90e37cec76c5a1b97d6c2067
--- /dev/null
+++ b/configs/datasets/dota.yml
@@ -0,0 +1,20 @@
+metric: COCO
+num_classes: 15
+
+TrainDataset:
+  !COCODataSet
+    image_dir: trainval_split/images
+    anno_path: trainval_split/s2anet_trainval_paddle_coco.json
+    dataset_dir: /paddle/dataset/DOTA_1024_s2anet
+    data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd', 'gt_rbox']
+
+EvalDataset:
+  !COCODataSet
+    image_dir: trainval_split/images
+    anno_path: trainval_split/s2anet_trainval_paddle_coco.json
+    dataset_dir: /paddle/dataset/DOTA_1024_s2anet/
+
+TestDataset:
+  !ImageFolder
+    anno_path: trainval_split/s2anet_trainval_paddle_coco.json
+    dataset_dir: /paddle/dataset/DOTA_1024_s2anet/
diff --git a/configs/dota/README.md b/configs/dota/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7d33bc7270cb032a8230a16456493c7165b0568e
--- /dev/null
+++ b/configs/dota/README.md
@@ -0,0 +1,120 @@
+# S2ANet Model
+
+## Contents
+- [Introduction](#Introduction)
+- [DOTA Dataset](#DOTA-Dataset)
+- [Model Zoo](#Model-Zoo)
+- [Training Notes](#Training-Notes)
+- [Evaluation](#Evaluation)
+- [Deployment](#Deployment)
+
+## Introduction
+
+[S2ANet](https://arxiv.org/pdf/2008.09397.pdf) is a model for detecting rotated bounding boxes. It requires PaddlePaddle 2.0.1 (installable via pip) or a suitable [develop build](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/install/Tables.html#whl-release).
+
+
+## DOTA Dataset
+[DOTA](http://captain.whu.edu.cn/DOTAweb/) is a dataset for object detection in aerial images. It contains 2806 images whose resolutions range from 800x800 to 4000x4000.
+
+| Version | Classes | Images | Image size | Instances | Annotation |
+|:--------:|:-------:|:---------:|:---------:| :---------:| :------------: |
+|   v1.0   |   15    |   2806    | 800~4000  |   118282   |   OBB + HBB    |
+|   v1.5   |   16    |   2806    | 800~4000  |   400000   |   OBB + HBB    |
+
+Note: OBB annotations are arbitrary quadrilaterals with vertices listed in clockwise order; HBB annotations are the axis-aligned bounding rectangles of the instances.
+
+Of the 2806 images, 1411 are used for training, 458 for evaluation, and the remaining 937 for testing.
+
+If the images need to be split into tiles, refer to [DOTA_devkit](https://github.com/CAPTAIN-WHU/DOTA_devkit).
+
+Splitting the data with `crop_size=1024, stride=824, gap=200` yields 15749 training images, 5297 evaluation images, and 10833 test images.
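+A minimal splitting sketch in the spirit of DOTA_devkit's `ImgSplit` module is shown below. The `splitbase` class, its argument names, and the `splitdata` call are quoted from memory of that toolkit and may differ between versions — verify against the repository; the paths are placeholders:
+
+```python
+# assumes DOTA_devkit is on PYTHONPATH; API details may vary by version
+from ImgSplit import splitbase
+
+split = splitbase('/paddle/dataset/DOTA/trainval',     # source images and labelTxt
+                  '/paddle/dataset/DOTA_1024_s2anet',  # output directory for tiles
+                  gap=200,       # overlap between adjacent tiles (stride = 1024 - 200 = 824)
+                  subsize=1024)  # tile size, matching crop_size above
+split.splitdata(1)  # rate=1.0, i.e. no rescaling before splitting
+```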
+## Model Zoo
+
+### S2ANet Model
+
+| Model | GPUs | Conv type | mAP | Download | Config |
+|:-----------:|:-------:|:----------:|:--------:| :----------:| :---------: |
+| S2ANet | 8 | Conv | 71.42 | [model](https://paddledet.bj.bcebos.com/models/s2anet_conv_1x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/dota/s2anet_conv_1x_dota.yml) |
+
+**Note:** `multiclass_nms` is used here, which differs slightly from the NMS used by the original authors and yields a mAP 0.15 higher than the paper (71.27 --> 71.42).
+
+## Training Notes
+
+### 1. Rotated-box IoU OP
+
+The rotated-box IoU OP [ext_op](../../ppdet/ext_op) is implemented as a Paddle [custom external operator](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/07_new_op/new_custom_op.html).
+
+Building it requires:
+- PaddlePaddle >= 2.0.1
+- GCC == 8.2
+
+The docker image [paddle:2.0.1-gpu-cuda10.1-cudnn7](registry.baidubce.com/paddlepaddle/paddle:2.0.1-gpu-cuda10.1-cudnn7) is recommended.
+
+Pull the image and start a container:
+```
+sudo nvidia-docker run -it --name paddle_s2anet -v $PWD:/paddle --network=host registry.baidubce.com/paddlepaddle/paddle:2.0.1-gpu-cuda10.1-cudnn7 /bin/bash
+```
+
+Inside the container, install the required python packages:
+```
+python3.7 -m pip install Cython wheel tqdm opencv-python==4.2.0.32 scipy PyYAML shapely pycocotools
+```
+
+Paddle 2.0.1 is already installed in the image. Start python3.7 and run the following to verify the installation:
+```
+import paddle
+print(paddle.__version__)
+paddle.utils.run_check()
+```
+
+Enter the `ext_op` directory and install the op:
+```
+python3.7 setup.py install
+```
+
+After installation, check that the custom op compiles and computes correctly:
+```
+cd PaddleDetection/ppdet/ext_op
+python3.7 test.py
+```
+
+### 2. Data format
+Instances in the DOTA dataset are annotated as arbitrary quadrilaterals. Before training, they must be converted to the `[xc, yc, box_w, box_h, angle]` format and stored as COCO-style JSON; see [DOTA2COCO](https://github.com/CAPTAIN-WHU/DOTA_devkit/blob/master/DOTA2COCO.py).
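+For intuition, a minimal sketch of that quadrilateral-to-rbox conversion (it mirrors the `poly_to_rbox` helper added to `ppdet/data/transform/operators.py` later in this PR; this standalone function is illustrative, not part of the PR):
+
+```python
+import numpy as np
+
+def quad_to_rbox(poly):
+    """[x0, y0, x1, y1, x2, y2, x3, y3] -> [xc, yc, w, h, angle]."""
+    pts = np.array(poly[:8], dtype=np.float32).reshape(4, 2)
+    edge1 = np.linalg.norm(pts[0] - pts[1])          # first side
+    edge2 = np.linalg.norm(pts[1] - pts[2])          # second side
+    w, h = max(edge1, edge2), min(edge1, edge2)      # the long side is the width
+    ref = pts[1] - pts[0] if edge1 > edge2 else pts[3] - pts[0]
+    angle = np.arctan2(ref[1], ref[0])               # angle of the long side
+    angle = (angle + np.pi / 4) % np.pi - np.pi / 4  # normalize to [-pi/4, 3pi/4)
+    xc, yc = (pts[0] + pts[2]) / 2.0                 # midpoint of a diagonal
+    return [float(xc), float(yc), float(w), float(h), float(angle)]
+```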
+
+## Evaluation
+
+The following command saves each image's predictions to a txt file of the same name under `output_dir`:
+```
+python3.7 tools/infer.py -c configs/dota/s2anet_1x_dota.yml -o weights=./weights/s2anet_1x_dota.pdparams --infer_dir=dota_test_images --draw_threshold=0.05 --save_txt=True --output_dir=output
+```
+
+Refer to [DOTA_devkit](https://github.com/CAPTAIN-WHU/DOTA_devkit) to generate the evaluation files; the required format is described in [DOTA Test](http://captain.whu.edu.cn/DOTAweb/tasks.html). Build a zip archive with one txt file per class, where each line has the format `image_id score x1 y1 x2 y2 x3 y3 x4 y4`, and submit it to the evaluation server.
+
+## Deployment
+
+The `multiclass_nms` operator in Paddle accepts quadrilateral inputs, so deployment does not depend on the rotated-box IoU custom operator.
+
+```bash
+# inference
+CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/dota/s2anet_1x_dota.yml -o weights=model.pdparams --infer_img=demo/P0072__1.0__0___0.png --use_gpu=True
+```
+
+
+## Citations
+```
+@article{han2021align,
+  author={J. {Han} and J. {Ding} and J. {Li} and G. -S. {Xia}},
+  journal={IEEE Transactions on Geoscience and Remote Sensing},
+  title={Align Deep Features for Oriented Object Detection},
+  year={2021},
+  pages={1-11},
+  doi={10.1109/TGRS.2021.3062048}}
+
+@inproceedings{xia2018dota,
+  title={DOTA: A large-scale dataset for object detection in aerial images},
+  author={Xia, Gui-Song and Bai, Xiang and Ding, Jian and Zhu, Zhen and Belongie, Serge and Luo, Jiebo and Datcu, Mihai and Pelillo, Marcello and Zhang, Liangpei},
+  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
+  pages={3974--3983},
+  year={2018}
+}
+```
diff --git a/configs/dota/_base_/s2anet.yml b/configs/dota/_base_/s2anet.yml
new file mode 100644
index 0000000000000000000000000000000000000000..1761d3542827afeaaae756fdd390dc9d079ef491
--- /dev/null
+++ b/configs/dota/_base_/s2anet.yml
@@ -0,0 +1,53 @@
+architecture: S2ANet
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
+weights: output/s2anet_r50_fpn_1x_dota/model_final.pdparams
+
+
+# Model Architecture
+S2ANet:
+  backbone: ResNet
+  neck: FPN
+  s2anet_head: S2ANetHead
+  s2anet_bbox_post_process: S2ANetBBoxPostProcess
+
+ResNet:
+  depth: 50
+  norm_type: bn
+  return_idx: [1,2,3]
+  num_stages: 4
+
+FPN:
+  in_channels: [256, 512, 1024]
+  out_channel: 256
+  spatial_scales: [0.25, 0.125, 0.0625]
+  has_extra_convs: True
+  extra_stage: 2
+  relu_before_extra_convs: False
+
+S2ANetHead:
+  anchor_strides: [8, 16, 32, 64, 128]
+  anchor_scales: [4]
+  anchor_ratios: [1.0]
+  anchor_assign: RBoxAssigner
+  stacked_convs: 2
+  feat_in: 256
+  feat_out: 256
+  num_classes: 15
+  align_conv_type: 'AlignConv'  # AlignConv or Conv
+  align_conv_size: 3
+  use_sigmoid_cls: True
+
+RBoxAssigner:
+  pos_iou_thr: 0.5
+  neg_iou_thr: 0.4
+  min_iou_thr: 0.0
+  ignore_iof_thr: -2
+
+S2ANetBBoxPostProcess:
+  nms_pre: 2000
+  min_bbox_size: 0.0
+  nms:
+    name: MultiClassNMS
+    keep_top_k: -1
+    score_threshold: 0.05
+    nms_threshold: 0.1
diff --git a/configs/dota/_base_/s2anet_optimizer_1x.yml b/configs/dota/_base_/s2anet_optimizer_1x.yml
new file mode 100644
index 0000000000000000000000000000000000000000..65f794dc34c55f5d597b94eb1b305b28a28707f7
--- /dev/null
+++ b/configs/dota/_base_/s2anet_optimizer_1x.yml
@@ -0,0 +1,20 @@
+epoch: 12
+
+LearningRate:
+  base_lr: 0.01
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones: [7, 10]
+  - !LinearWarmup
+    start_factor: 0.3333333333333333
+    steps: 500
+
+OptimizerBuilder:
+  optimizer:
+    momentum: 0.9
+    type: Momentum
+  regularizer:
+    factor: 0.0001
+    type: L2
+  clip_grad_by_norm: 35
diff --git a/configs/dota/_base_/s2anet_reader.yml b/configs/dota/_base_/s2anet_reader.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c3df7a089ae0c3fa0bc9336330a06c3edfa94788
--- /dev/null
+++ b/configs/dota/_base_/s2anet_reader.yml
@@ -0,0 +1,42 @@
+worker_num: 0
+TrainReader:
+  sample_transforms:
+  - Decode: {}
+  - Rbox2Poly: {}
+  # Resize can process rbox
+  - Resize: {target_size: [1024, 1024], interp: 2, keep_ratio: False}
+  - RandomFlip: {prob: 0.5}
+  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+  - Permute: {}
+  batch_transforms:
+  - RboxPadBatch: {pad_to_stride: 32, pad_gt: true}
+  batch_size: 1
+  shuffle: true
+  drop_last: true
+
+
+EvalReader:
+  sample_transforms:
+  - Decode: {}
+  - Resize: {interp: 2, target_size: [1024, 1024], keep_ratio: True}
+  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+  - Permute: {}
+  batch_transforms:
+  - RboxPadBatch: {pad_to_stride: 32, pad_gt: false}
+  batch_size: 1
+  shuffle: false
+  drop_last: false
+  drop_empty: false
+
+
+TestReader:
+  sample_transforms:
+  - Decode: {}
+  - Resize: {interp: 2, target_size: [1024, 1024], keep_ratio: True}
+  - NormalizeImage: {is_scale: true, mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]}
+  - Permute: {}
+  batch_transforms:
+  - RboxPadBatch: {pad_to_stride: 32, pad_gt: false}
+  batch_size: 1
+  shuffle: false
+  drop_last: false
diff --git a/configs/dota/s2anet_1x_dota.yml b/configs/dota/s2anet_1x_dota.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d480c1c8669402727d16cfb1c3fbdd0d1d7464af
--- /dev/null
+++ b/configs/dota/s2anet_1x_dota.yml
@@ -0,0 +1,8 @@
+_BASE_: [
+  '../datasets/dota.yml',
+  '../runtime.yml',
+  '_base_/s2anet_optimizer_1x.yml',
+  '_base_/s2anet.yml',
+  '_base_/s2anet_reader.yml',
+]
+weights: output/s2anet_1x_dota/model_final
diff --git a/configs/dota/s2anet_conv_1x_dota.yml b/configs/dota/s2anet_conv_1x_dota.yml
new file mode 100644
index 0000000000000000000000000000000000000000..feab8c04ae2858d2887a06f656be2d854b74e832
--- /dev/null
+++ b/configs/dota/s2anet_conv_1x_dota.yml
@@ -0,0 +1,20 @@
+_BASE_: [
+  '../datasets/dota.yml',
+  '../runtime.yml',
+  '_base_/s2anet_optimizer_1x.yml',
+  '_base_/s2anet.yml',
+  '_base_/s2anet_reader.yml',
+]
+weights: output/s2anet_conv_1x_dota/model_final
+S2ANetHead:
+  anchor_strides: [8, 16, 32, 64, 128]
+  anchor_scales: [4]
+  anchor_ratios: [1.0]
+  anchor_assign: RBoxAssigner
+  stacked_convs: 2
+  feat_in: 256
+  feat_out: 256
+  num_classes: 15
+  align_conv_type: 'Conv'  # AlignConv or Conv
+  align_conv_size: 3
+  use_sigmoid_cls: True
diff --git a/demo/P0072__1.0__0___0.png b/demo/P0072__1.0__0___0.png
new file mode 100644
index 0000000000000000000000000000000000000000..aaf9c59bc18f09a342b13e88d966c590b7c16024
Binary files /dev/null and b/demo/P0072__1.0__0___0.png differ
diff --git a/demo/P0861__1.0__1154___824.png b/demo/P0861__1.0__1154___824.png
new file mode 100644
index 0000000000000000000000000000000000000000..47ab7ae3b4698c18a70513e262274f2bfeb98622
Binary files /dev/null and b/demo/P0861__1.0__1154___824.png differ
diff --git a/docs/tutorials/GETTING_STARTED_cn.md b/docs/tutorials/GETTING_STARTED_cn.md
index f67671d7d4c1192ddba9d063d2ae793db6af2c8b..d04cac9f37578d48cde2ec5b539a451d06abf446 100644
--- a/docs/tutorials/GETTING_STARTED_cn.md
+++ b/docs/tutorials/GETTING_STARTED_cn.md
@@ -52,6 +52,7 @@ python tools/infer.py -c configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.yml --i
 | --draw_threshold | infer | 可视化时分数阈值 | 0.5 | 例如`--draw_threshold=0.7` |
 | --infer_dir | infer | 用于预测的图片文件夹路径 | None | `--infer_img`和`--infer_dir`必须至少设置一个 |
 | --infer_img | infer | 用于预测的图片路径 | None | `--infer_img`和`--infer_dir`必须至少设置一个,`infer_img`具有更高优先级 |
+| --save_txt | infer | 是否在文件夹下将图片的预测结果保存到文本文件中 | False | 可选 |
 
 ## 使用示例
diff --git a/ppdet/data/source/coco.py b/ppdet/data/source/coco.py
index d5691c7be86d82db8fdc85eaa2987bbbfb4c1535..f896502b3622e3ef61611b24c27965d2a065c13a 100644
--- a/ppdet/data/source/coco.py
+++ b/ppdet/data/source/coco.py
@@ -102,14 +102,26 @@ class COCODataSet(DetDataset):
                 else:
                     if not any(np.array(inst['bbox'])):
                         continue
-                    x1, y1, box_w, box_h = inst['bbox']
-                    x2 = x1 + box_w
-                    y2 = y1 + box_h
+
+                    # read rbox anno or not: a 5-number bbox is an rbox
+                    is_rbox_anno = len(inst['bbox']) == 5
+                    if is_rbox_anno:
+                        xc, yc, box_w, box_h, angle = inst['bbox']
+                        x1 = xc - box_w / 2.0
+                        y1 = yc - box_h / 2.0
+                        x2 = x1 + box_w
+                        y2 = y1 + box_h
+                    else:
+                        x1, y1, box_w, box_h = inst['bbox']
+                        x2 = x1 + box_w
+                        y2 = y1 + box_h
                     eps = 1e-5
                     if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
                         inst['clean_bbox'] = [
                             round(float(x), 3) for x in [x1, y1, x2, y2]
                         ]
+                        if is_rbox_anno:
+                            inst['clean_rbox'] = [xc, yc, box_w, box_h, angle]
                         bboxes.append(inst)
                     else:
                         logger.warning(
@@ -122,6 +134,9 @@ class COCODataSet(DetDataset):
                 continue
 
             gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
+            if is_rbox_anno:
+                gt_rbox = np.zeros((num_bbox, 5), dtype=np.float32)
+                gt_theta = np.zeros((num_bbox, 1), dtype=np.int32)
             gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
             is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
             difficult = np.zeros((num_bbox, 1), dtype=np.int32)
@@ -132,6 +147,9 @@ class COCODataSet(DetDataset):
                 catid = box['category_id']
                 gt_class[i][0] = self.catid2clsid[catid]
                 gt_bbox[i, :] = box['clean_bbox']
+                # xc, yc, w, h, theta
+                if is_rbox_anno:
+                    gt_rbox[i, :] = box['clean_rbox']
                 is_crowd[i][0] = box['iscrowd']
                 # check RLE format
                 if 'segmentation' in box and box['iscrowd'] == 1:
@@ -150,12 +168,22 @@ class COCODataSet(DetDataset):
                 'w': im_w,
             } if 'image' in self.data_fields else {}
 
-            gt_rec = {
-                'is_crowd': is_crowd,
-                'gt_class': gt_class,
-                'gt_bbox': gt_bbox,
-                'gt_poly': gt_poly,
-            }
+            if is_rbox_anno:
+                gt_rec = {
+                    'is_crowd': is_crowd,
+                    'gt_class': gt_class,
+                    'gt_bbox': gt_bbox,
+                    'gt_rbox': gt_rbox,
+                    'gt_poly': gt_poly,
+                }
+            else:
+                gt_rec = {
+                    'is_crowd': is_crowd,
+                    'gt_class': gt_class,
+                    'gt_bbox': gt_bbox,
+                    'gt_poly': gt_poly,
+                }
+
             for k, v in gt_rec.items():
                 if k in self.data_fields:
                     coco_rec[k] = v
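With this change, `COCODataSet` accepts COCO-style records whose `bbox` field carries five numbers instead of four. A hypothetical annotation entry of the kind the new branch parses (field values are illustrative only):

```python
# one COCO-style annotation with a rotated box: [xc, yc, w, h, angle (radians)]
rbox_anno = {
    "image_id": 1,
    "category_id": 4,
    "bbox": [512.0, 486.5, 120.0, 36.0, 0.35],
    "area": 4320.0,  # w * h
    "iscrowd": 0,
}
```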
diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py
index b749111d3bd424a4192d2a3ea5cd64aae955e173..8ca42bfa9e1a64a216d9d12aca43508ca2f06600 100644
--- a/ppdet/data/transform/batch_operators.py
+++ b/ppdet/data/transform/batch_operators.py
@@ -31,12 +31,8 @@ from ppdet.utils.logger import setup_logger
 logger = setup_logger(__name__)
 
 __all__ = [
-    'PadBatch',
-    'BatchRandomResize',
-    'Gt2YoloTarget',
-    'Gt2FCOSTarget',
-    'Gt2TTFTarget',
-    'Gt2Solov2Target',
+    'PadBatch', 'BatchRandomResize', 'Gt2YoloTarget', 'Gt2FCOSTarget',
+    'Gt2TTFTarget', 'Gt2Solov2Target', 'RboxPadBatch'
 ]
 
 
@@ -739,3 +735,155 @@ class Gt2Solov2Target(BaseOperator):
                 data['grid_order{}'.format(idx)] = gt_grid_order
 
         return samples
+
+
+@register_op
+class RboxPadBatch(BaseOperator):
+    """
+    Pad a batch of samples so that they can be divisible by a stride.
+    The layout of each image should be 'CHW'. Also converts the poly
+    annotations to rbox format.
+    Args:
+        pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure
+            height and width is divisible by `pad_to_stride`.
+    """
+
+    def __init__(self, pad_to_stride=0, pad_gt=False):
+        super(RboxPadBatch, self).__init__()
+        self.pad_to_stride = pad_to_stride
+        self.pad_gt = pad_gt
+
+    def poly_to_rbox(self, polys):
+        """
+        poly: [x0,y0,x1,y1,x2,y2,x3,y3]
+        to
+        rotated_boxes: [x_ctr,y_ctr,w,h,angle]
+        """
+        rotated_boxes = []
+        for poly in polys:
+            poly = np.array(poly[:8], dtype=np.float32)
+
+            pt1 = (poly[0], poly[1])
+            pt2 = (poly[2], poly[3])
+            pt3 = (poly[4], poly[5])
+            pt4 = (poly[6], poly[7])
+
+            edge1 = np.sqrt((pt1[0] - pt2[0]) * (pt1[0] - pt2[0]) + (
+                pt1[1] - pt2[1]) * (pt1[1] - pt2[1]))
+            edge2 = np.sqrt((pt2[0] - pt3[0]) * (pt2[0] - pt3[0]) + (
+                pt2[1] - pt3[1]) * (pt2[1] - pt3[1]))
+
+            width = max(edge1, edge2)
+            height = min(edge1, edge2)
+
+            angle = 0
+            if edge1 > edge2:
+                angle = np.arctan2(
+                    float(pt2[1] - pt1[1]), float(pt2[0] - pt1[0]))
+            elif edge2 >= edge1:
+                angle = np.arctan2(
+                    float(pt4[1] - pt1[1]), float(pt4[0] - pt1[0]))
+
+            def norm_angle(angle, range=[-np.pi / 4, np.pi]):
+                return (angle - range[0]) % range[1] + range[0]
+
+            angle = norm_angle(angle)
+
+            x_ctr = float(pt1[0] + pt3[0]) / 2.0
+            y_ctr = float(pt1[1] + pt3[1]) / 2.0
+            rotated_box = np.array([x_ctr, y_ctr, width, height, angle])
+            rotated_boxes.append(rotated_box)
+        ret_rotated_boxes = np.array(rotated_boxes)
+        assert ret_rotated_boxes.shape[1] == 5
+        return ret_rotated_boxes
+
+    def __call__(self, samples, context=None):
+        """
+        Args:
+            samples (list): a batch of sample, each is dict.
+        """
+        coarsest_stride = self.pad_to_stride
+
+        max_shape = np.array([data['image'].shape for data in samples]).max(
+            axis=0)
+        if coarsest_stride > 0:
+            max_shape[1] = int(
+                np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
+            max_shape[2] = int(
+                np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride)
+
+        for data in samples:
+            im = data['image']
+            im_c, im_h, im_w = im.shape[:]
+            padding_im = np.zeros(
+                (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
+            padding_im[:, :im_h, :im_w] = im
+            data['image'] = padding_im
+            if 'semantic' in data and data['semantic'] is not None:
+                semantic = data['semantic']
+                padding_sem = np.zeros(
+                    (1, max_shape[1], max_shape[2]), dtype=np.float32)
+                padding_sem[:, :im_h, :im_w] = semantic
+                data['semantic'] = padding_sem
+            if 'gt_segm' in data and data['gt_segm'] is not None:
+                gt_segm = data['gt_segm']
+                padding_segm = np.zeros(
+                    (gt_segm.shape[0], max_shape[1], max_shape[2]),
+                    dtype=np.uint8)
+                padding_segm[:, :im_h, :im_w] = gt_segm
+                data['gt_segm'] = padding_segm
+
+        if self.pad_gt:
+            gt_num = []
+            if 'gt_poly' in data and data['gt_poly'] is not None and len(data[
+                    'gt_poly']) > 0:
+                pad_mask = True
+            else:
+                pad_mask = False
+
+            if pad_mask:
+                poly_num = []
+                poly_part_num = []
+                point_num = []
+            for data in samples:
+                gt_num.append(data['gt_bbox'].shape[0])
+                if pad_mask:
+                    poly_num.append(len(data['gt_poly']))
+                    for poly in data['gt_poly']:
+                        poly_part_num.append(int(len(poly)))
+                        for p_p in poly:
+                            point_num.append(int(len(p_p) / 2))
+            gt_num_max = max(gt_num)
+
+            for i, sample in enumerate(samples):
+                assert 'gt_rbox' in sample
+                assert 'gt_rbox2poly' in sample
+                gt_box_data = -np.ones([gt_num_max, 4], dtype=np.float32)
+                gt_class_data = -np.ones([gt_num_max], dtype=np.int32)
+                is_crowd_data = np.ones([gt_num_max], dtype=np.int32)
+
+                if pad_mask:
+                    poly_num_max = max(poly_num)
+                    poly_part_num_max = max(poly_part_num)
+                    point_num_max = max(point_num)
+                    gt_masks_data = -np.ones(
+                        [poly_num_max, poly_part_num_max, point_num_max, 2],
+                        dtype=np.float32)
+
+                gt_num = sample['gt_bbox'].shape[0]
+                gt_box_data[0:gt_num, :] = sample['gt_bbox']
+                gt_class_data[0:gt_num] = np.squeeze(sample['gt_class'])
+                is_crowd_data[0:gt_num] = np.squeeze(sample['is_crowd'])
+                if pad_mask:
+                    for j, poly in enumerate(sample['gt_poly']):
+                        for k, p_p in enumerate(poly):
+                            pp_np = np.array(p_p).reshape(-1, 2)
+                            gt_masks_data[j, k, :pp_np.shape[0], :] = pp_np
+                    sample['gt_poly'] = gt_masks_data
+                sample['gt_bbox'] = gt_box_data
+                sample['gt_class'] = gt_class_data
+                sample['is_crowd'] = is_crowd_data
+                # poly to rbox
+                polys = sample['gt_rbox2poly']
+                rbox = self.poly_to_rbox(polys)
+                sample['gt_rbox'] = rbox
+
+        return samples
diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py
index b07ee0cba1dfa22b7114f0054d3844df614a500e..06cc6165b0b13159515b9caaaee48364dca27f48 100644
--- a/ppdet/data/transform/operators.py
+++ b/ppdet/data/transform/operators.py
@@ -536,6 +536,17 @@ class RandomFlip(BaseOperator):
             bbox[:, 2] = width - oldx1
         return bbox
 
+    def apply_rbox(self, bbox, width):
+        oldx1 = bbox[:, 0].copy()
+        oldx2 = bbox[:, 2].copy()
+        oldx3 = bbox[:, 4].copy()
+        oldx4 = bbox[:, 6].copy()
+        bbox[:, 0] = width - oldx2
+        bbox[:, 2] = width - oldx1
+        bbox[:, 4] = width - oldx3
+        bbox[:, 6] = width - oldx4
+        return bbox
+
     def apply(self, sample, context=None):
         """Filp the image and bounding box.
         Operators:
@@ -567,6 +578,10 @@ class RandomFlip(BaseOperator):
         if 'gt_segm' in sample and sample['gt_segm'].any():
             sample['gt_segm'] = sample['gt_segm'][:, :, ::-1]
 
+        if 'gt_rbox2poly' in sample and sample['gt_rbox2poly'].any():
+            sample['gt_rbox2poly'] = self.apply_rbox(sample['gt_rbox2poly'],
+                                                     width)
+
         sample['flipped'] = True
         sample['image'] = im
         return sample
@@ -704,6 +719,16 @@ class Resize(BaseOperator):
                                                 [im_scale_x, im_scale_y],
                                                 [resize_w, resize_h])
 
+        # apply rbox
+        if 'gt_rbox2poly' in sample:
+            if np.array(sample['gt_rbox2poly']).shape[1] != 8:
+                logger.warning(
+                    "gt_rbox2poly's length should be 8, but actually is {}".
+                    format(len(sample['gt_rbox2poly'])))
+            sample['gt_rbox2poly'] = self.apply_bbox(sample['gt_rbox2poly'],
+                                                     [im_scale_x, im_scale_y],
+                                                     [resize_w, resize_h])
+
         # apply polygon
         if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
             sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_shape[:2],
+ """ + + def __init__(self): + super(Rbox2Poly, self).__init__() + + def apply(self, sample, context=None): + assert 'gt_rbox' in sample + assert sample['gt_rbox'].shape[1] == 5 + rrect = sample['gt_rbox'] + bbox_num = rrect.shape[0] + x_ctr = rrect[:, 0] + y_ctr = rrect[:, 1] + width = rrect[:, 2] + height = rrect[:, 3] + angle = rrect[:, 4] + tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2 + # rect 2x4 + rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]]) + R = np.array([[np.cos(angle), -np.sin(angle)], + [np.sin(angle), np.cos(angle)]]) + + poly = [] + for i in range(R.shape[2]): + tmp_r = R[:, :, i].reshape(2, 2) + poly.append(tmp_r.dot(rect[:, :, i])) + + # poly:[M, 2, 4] + poly = np.array(poly) + coor_x = poly[:, 0, :4] + x_ctr.reshape(bbox_num, 1) + coor_y = poly[:, 1, :4] + y_ctr.reshape(bbox_num, 1) + poly = np.stack( + [ + coor_x[:, 0], coor_y[:, 0], coor_x[:, 1], coor_y[:, 1], + coor_x[:, 2], coor_y[:, 2], coor_x[:, 3], coor_y[:, 3] + ], + axis=1) + x1 = x_ctr - width / 2.0 + y1 = y_ctr - height / 2.0 + x2 = x_ctr + width / 2.0 + y2 = y_ctr + height / 2.0 + sample['gt_bbox'] = np.stack([x1, y1, x2, y2], axis=1) + sample['gt_rbox2poly'] = poly + return sample + + +@register_op +class Poly2Rbox(BaseOperator): + """ + Convert poly format to rbbox format. + """ + + def __init__(self): + super(Poly2Rbox, self).__init__() + + def poly_to_rbox(self, polys): + """ + poly:[x0,y0,x1,y1,x2,y2,x3,y3] + to + rotated_boxes:[x_ctr,y_ctr,w,h,angle] + """ + rotated_boxes = [] + for poly in polys: + poly = np.array(poly[:8], dtype=np.float32) + + pt1 = (poly[0], poly[1]) + pt2 = (poly[2], poly[3]) + pt3 = (poly[4], poly[5]) + pt4 = (poly[6], poly[7]) + + edge1 = np.sqrt((pt1[0] - pt2[0]) * (pt1[0] - pt2[0]) + (pt1[ + 1] - pt2[1]) * (pt1[1] - pt2[1])) + edge2 = np.sqrt((pt2[0] - pt3[0]) * (pt2[0] - pt3[0]) + (pt2[ + 1] - pt3[1]) * (pt2[1] - pt3[1])) + + width = max(edge1, edge2) + height = min(edge1, edge2) + + angle = 0 + if edge1 > edge2: + angle = np.arctan2( + np.float(pt2[1] - pt1[1]), np.float(pt2[0] - pt1[0])) + elif edge2 >= edge1: + angle = np.arctan2( + np.float(pt4[1] - pt1[1]), np.float(pt4[0] - pt1[0])) + + def norm_angle(angle, range=[-np.pi / 4, np.pi]): + return (angle - range[0]) % range[1] + range[0] + + angle = norm_angle(angle) + + x_ctr = np.float(pt1[0] + pt3[0]) / 2 + y_ctr = np.float(pt1[1] + pt3[1]) / 2 + rotated_box = np.array([x_ctr, y_ctr, width, height, angle]) + rotated_boxes.append(rotated_box) + ret_rotated_boxes = np.array(rotated_boxes) + assert ret_rotated_boxes.shape[1] == 5 + return ret_rotated_boxes + + def apply(self, sample, context=None): + assert 'gt_rbox2poly' in sample + poly = sample['gt_rbox2poly'] + rbox = self.poly_to_rbox(poly) + sample['gt_rbox'] = rbox + return sample diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py index ab626084fb108f4d3454da42d1f046b44b0251cc..29a5c7513998e528fe504c7dd20eac9e5c4247ec 100644 --- a/ppdet/engine/export_utils.py +++ b/ppdet/engine/export_utils.py @@ -31,6 +31,7 @@ TRT_MIN_SUBGRAPH = { 'SSD': 60, 'RCNN': 40, 'RetinaNet': 40, + 'S2ANet': 40, 'EfficientDet': 40, 'Face': 3, 'TTFNet': 3, diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index be2c14e68f6d7c0eefd0a084c3b7faf4c2fc73c2..c64bbe5b83de08f33d10edbdb34e15c139621038 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -31,7 +31,7 @@ from paddle.static import InputSpec from ppdet.core.workspace import create from ppdet.utils.checkpoint import load_weight, 
diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py
index ab626084fb108f4d3454da42d1f046b44b0251cc..29a5c7513998e528fe504c7dd20eac9e5c4247ec 100644
--- a/ppdet/engine/export_utils.py
+++ b/ppdet/engine/export_utils.py
@@ -31,6 +31,7 @@ TRT_MIN_SUBGRAPH = {
     'SSD': 60,
     'RCNN': 40,
     'RetinaNet': 40,
+    'S2ANet': 40,
     'EfficientDet': 40,
     'Face': 3,
     'TTFNet': 3,
diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py
index be2c14e68f6d7c0eefd0a084c3b7faf4c2fc73c2..c64bbe5b83de08f33d10edbdb34e15c139621038 100644
--- a/ppdet/engine/trainer.py
+++ b/ppdet/engine/trainer.py
@@ -31,7 +31,7 @@ from paddle.static import InputSpec
 from ppdet.core.workspace import create
 from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
-from ppdet.utils.visualizer import visualize_results
+from ppdet.utils.visualizer import visualize_results, save_result
 from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results
 from ppdet.data.source.category import get_categories
 import ppdet.utils.stats as stats
@@ -333,7 +333,11 @@ class Trainer(object):
     def evaluate(self):
         self._eval_with_loader(self.loader)
 
-    def predict(self, images, draw_threshold=0.5, output_dir='output'):
+    def predict(self,
+                images,
+                draw_threshold=0.5,
+                output_dir='output',
+                save_txt=False):
         self.dataset.set_images(images)
         loader = create('TestReader')(self.dataset, 0)
 
@@ -369,6 +373,7 @@ class Trainer(object):
                     if 'mask' in batch_res else None
                 segm_res = batch_res['segm'][start:end] \
                     if 'segm' in batch_res else None
+
                 image = visualize_results(image, bbox_res, mask_res, segm_res,
                                           int(outs['im_id']), catid2name,
                                           draw_threshold)
@@ -380,6 +385,9 @@ class Trainer(object):
                 logger.info("Detection bbox results save in {}".format(
                     save_name))
                 image.save(save_name, quality=95)
+                if save_txt:
+                    save_path = os.path.splitext(save_name)[0] + '.txt'
+                    save_result(save_path, bbox_res, catid2name, draw_threshold)
                 start = end
 
     def _get_save_image_name(self, output_dir, image_path):
diff --git a/ppdet/ext_op/README.md b/ppdet/ext_op/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7ada0acf7fd75266fed6c66a9a010debc645bee8
--- /dev/null
+++ b/ppdet/ext_op/README.md
@@ -0,0 +1,38 @@
+# Building the custom OP
+The rotated-box IoU OP is implemented as a Paddle [custom external operator](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/07_new_op/new_custom_op.html).
+
+## 1. Requirements
+- Paddle >= 2.0.1
+- gcc 8.2
+
+## 2. Installation
+```
+python3.7 setup.py install
+```
+
+Use it as follows (note that `numpy` and `paddle` must be imported first):
+```
+import numpy as np
+import paddle
+
+# import the custom op
+from rbox_iou_ops import rbox_iou
+
+paddle.set_device('gpu:0')
+paddle.disable_static()
+
+rbox1 = np.random.rand(13000, 5)
+rbox2 = np.random.rand(7, 5)
+
+pd_rbox1 = paddle.to_tensor(rbox1)
+pd_rbox2 = paddle.to_tensor(rbox2)
+
+iou = rbox_iou(pd_rbox1, pd_rbox2)
+print('iou', iou)
+```
+
+## 3. Unit test
+The unit test in `test.py` compares the custom OP's output against a pure-python implementation.
+
+Because the python and C++ implementations differ in small numerical details, the tolerance is set to 0.02:
+```
+python3.7 test.py
+```
+The message `rbox_iou OP compute right!` indicates that the OP passed the test.
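+As a quick numerical sanity check (an illustrative snippet, not part of the test suite): the IoU of a box with itself should be 1.
+
+```python
+import numpy as np
+import paddle
+from rbox_iou_ops import rbox_iou
+
+paddle.set_device('gpu:0')
+
+# one box, compared against itself: [xc, yc, w, h, angle]
+box = np.array([[50.0, 50.0, 20.0, 10.0, 0.3]])
+iou = rbox_iou(paddle.to_tensor(box), paddle.to_tensor(box))
+print(iou.numpy())  # expect a value very close to 1.0
+```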
diff --git a/ppdet/ext_op/rbox_iou_op.cc b/ppdet/ext_op/rbox_iou_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..05890fd2bf7fe2e10299e608a7fb852b175f3507
--- /dev/null
+++ b/ppdet/ext_op/rbox_iou_op.cc
@@ -0,0 +1,46 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/extension.h"
+
+#include <vector>
+
+std::vector<paddle::Tensor> RboxIouCPUForward(const paddle::Tensor& rbox1, const paddle::Tensor& rbox2);
+std::vector<paddle::Tensor> RboxIouCUDAForward(const paddle::Tensor& rbox1, const paddle::Tensor& rbox2);
+
+
+#define CHECK_INPUT_SAME(x1, x2) PD_CHECK(x1.place() == x2.place(), "input must be on the same place.")
+std::vector<paddle::Tensor> RboxIouForward(const paddle::Tensor& rbox1, const paddle::Tensor& rbox2) {
+    CHECK_INPUT_SAME(rbox1, rbox2);
+    if (rbox1.place() == paddle::PlaceType::kCPU) {
+        return RboxIouCPUForward(rbox1, rbox2);
+    }
+    else if (rbox1.place() == paddle::PlaceType::kGPU) {
+        return RboxIouCUDAForward(rbox1, rbox2);
+    }
+}
+
+std::vector<std::vector<int64_t>> InferShape(std::vector<int64_t> rbox1_shape, std::vector<int64_t> rbox2_shape) {
+    return {{rbox1_shape[0], rbox2_shape[0]}};
+}
+
+std::vector<paddle::DataType> InferDtype(paddle::DataType t1, paddle::DataType t2) {
+    return {t1};
+}
+
+PD_BUILD_OP(rbox_iou)
+    .Inputs({"RBOX1", "RBOX2"})
+    .Outputs({"Output"})
+    .SetKernelFn(PD_KERNEL(RboxIouForward))
+    .SetInferShapeFn(PD_INFER_SHAPE(InferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(InferDtype));
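For reference, the same rotated IoU can be computed on the host with shapely (this mirrors the pure-python check used in `test.py` below; illustrative only):

```python
import numpy as np
from shapely.geometry import Polygon

def rbox_iou_py(rbox1, rbox2):
    """IoU of two [xc, yc, w, h, angle] boxes via their corner polygons."""
    polys = []
    for xc, yc, w, h, a in (rbox1, rbox2):
        # corners of the axis-aligned box, rotated by a and shifted to (xc, yc)
        rect = np.array([[-w / 2, w / 2, w / 2, -w / 2],
                         [-h / 2, -h / 2, h / 2, h / 2]])
        rot = np.array([[np.cos(a), -np.sin(a)], [np.sin(a), np.cos(a)]])
        pts = rot.dot(rect).T + [xc, yc]
        polys.append(Polygon(pts))
    inter = polys[0].intersection(polys[1]).area
    union = polys[0].area + polys[1].area - inter
    return inter / union if union > 0 else 0.0

print(rbox_iou_py([50, 50, 20, 10, 0.0], [50, 50, 20, 10, 0.0]))  # 1.0
```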
diff --git a/ppdet/ext_op/rbox_iou_op.cu b/ppdet/ext_op/rbox_iou_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..0581f782f3ed1e2e8c7be81f58adc88a431387d8
--- /dev/null
+++ b/ppdet/ext_op/rbox_iou_op.cu
@@ -0,0 +1,507 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+
+#include <cassert>
+#include <cmath>
+
+#ifdef __CUDACC__
+// Designates functions callable from the host (CPU) and the device (GPU)
+#define HOST_DEVICE __host__ __device__
+#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
+#else
+#include <algorithm>
+#define HOST_DEVICE
+#define HOST_DEVICE_INLINE HOST_DEVICE inline
+#endif
+
+#include "paddle/extension.h"
+
+#include <vector>
+
+namespace {
+
+template <typename T>
+struct RotatedBox {
+  T x_ctr, y_ctr, w, h, a;
+};
+
+template <typename T>
+struct Point {
+  T x, y;
+  HOST_DEVICE_INLINE Point(const T& px = 0, const T& py = 0) : x(px), y(py) {}
+  HOST_DEVICE_INLINE Point operator+(const Point& p) const {
+    return Point(x + p.x, y + p.y);
+  }
+  HOST_DEVICE_INLINE Point& operator+=(const Point& p) {
+    x += p.x;
+    y += p.y;
+    return *this;
+  }
+  HOST_DEVICE_INLINE Point operator-(const Point& p) const {
+    return Point(x - p.x, y - p.y);
+  }
+  HOST_DEVICE_INLINE Point operator*(const T coeff) const {
+    return Point(x * coeff, y * coeff);
+  }
+};
+
+template <typename T>
+HOST_DEVICE_INLINE T dot_2d(const Point<T>& A, const Point<T>& B) {
+  return A.x * B.x + A.y * B.y;
+}
+
+template <typename T>
+HOST_DEVICE_INLINE T cross_2d(const Point<T>& A, const Point<T>& B) {
+  return A.x * B.y - B.x * A.y;
+}
+
+template <typename T>
+HOST_DEVICE_INLINE void get_rotated_vertices(
+    const RotatedBox<T>& box,
+    Point<T> (&pts)[4]) {
+  // M_PI / 180. == 0.01745329251
+  //double theta = box.a * 0.01745329251;
+  //MODIFIED: the angle is already in radians here
+  double theta = box.a;
+  T cosTheta2 = (T)cos(theta) * 0.5f;
+  T sinTheta2 = (T)sin(theta) * 0.5f;
+
+  // y: top --> down; x: left --> right
+  pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w;
+  pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
+  pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w;
+  pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
+  pts[2].x = 2 * box.x_ctr - pts[0].x;
+  pts[2].y = 2 * box.y_ctr - pts[0].y;
+  pts[3].x = 2 * box.x_ctr - pts[1].x;
+  pts[3].y = 2 * box.y_ctr - pts[1].y;
+}
+
+template <typename T>
+HOST_DEVICE_INLINE int get_intersection_points(
+    const Point<T> (&pts1)[4],
+    const Point<T> (&pts2)[4],
+    Point<T> (&intersections)[24]) {
+  // Line vector
+  // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
+  Point<T> vec1[4], vec2[4];
+  for (int i = 0; i < 4; i++) {
+    vec1[i] = pts1[(i + 1) % 4] - pts1[i];
+    vec2[i] = pts2[(i + 1) % 4] - pts2[i];
+  }
+
+  // Line test - test all line combos for intersection
+  int num = 0;  // number of intersections
+  for (int i = 0; i < 4; i++) {
+    for (int j = 0; j < 4; j++) {
+      // Solve for 2x2 Ax=b
+      T det = cross_2d<T>(vec2[j], vec1[i]);
+
+      // This takes care of parallel lines
+      if (fabs(det) <= 1e-14) {
+        continue;
+      }
+
+      auto vec12 = pts2[j] - pts1[i];
+
+      T t1 = cross_2d<T>(vec2[j], vec12) / det;
+      T t2 = cross_2d<T>(vec1[i], vec12) / det;
+
+      if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
+        intersections[num++] = pts1[i] + vec1[i] * t1;
+      }
+    }
+  }
+
+  // Check for vertices of rect1 inside rect2
+  {
+    const auto& AB = vec2[0];
+    const auto& DA = vec2[3];
+    auto ABdotAB = dot_2d<T>(AB, AB);
+    auto ADdotAD = dot_2d<T>(DA, DA);
+    for (int i = 0; i < 4; i++) {
+      // assume ABCD is the rectangle, and P is the point to be judged
+      // P is inside ABCD iff. P's projection on AB lies within AB
+      // and P's projection on AD lies within AD
+
+      auto AP = pts1[i] - pts2[0];
+
+      auto APdotAB = dot_2d<T>(AP, AB);
+      auto APdotAD = -dot_2d<T>(AP, DA);
+
+      if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
+          (APdotAD <= ADdotAD)) {
+        intersections[num++] = pts1[i];
+      }
+    }
+  }
+
+  // Reverse the check - check for vertices of rect2 inside rect1
+  {
+    const auto& AB = vec1[0];
+    const auto& DA = vec1[3];
+    auto ABdotAB = dot_2d<T>(AB, AB);
+    auto ADdotAD = dot_2d<T>(DA, DA);
+    for (int i = 0; i < 4; i++) {
+      auto AP = pts2[i] - pts1[0];
+
+      auto APdotAB = dot_2d<T>(AP, AB);
+      auto APdotAD = -dot_2d<T>(AP, DA);
+
+      if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
+          (APdotAD <= ADdotAD)) {
+        intersections[num++] = pts2[i];
+      }
+    }
+  }
+
+  return num;
+}
+template <typename T>
+HOST_DEVICE_INLINE int convex_hull_graham(
+    const Point<T> (&p)[24],
+    const int& num_in,
+    Point<T> (&q)[24],
+    bool shift_to_zero = false) {
+  assert(num_in >= 2);
+
+  // Step 1:
+  // Find point with minimum y
+  // if more than 1 points have the same minimum y,
+  // pick the one with the minimum x.
+  int t = 0;
+  for (int i = 1; i < num_in; i++) {
+    if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
+      t = i;
+    }
+  }
+  auto& start = p[t];  // starting point
+
+  // Step 2:
+  // Subtract starting point from every points (for sorting in the next step)
+  for (int i = 0; i < num_in; i++) {
+    q[i] = p[i] - start;
+  }
+
+  // Swap the starting point to position 0
+  auto tmp = q[0];
+  q[0] = q[t];
+  q[t] = tmp;
+
+  // Step 3:
+  // Sort point 1 ~ num_in according to their relative cross-product values
+  // (essentially sorting according to angles)
+  // If the angles are the same, sort according to their distance to origin
+  T dist[24];
+  for (int i = 0; i < num_in; i++) {
+    dist[i] = dot_2d<T>(q[i], q[i]);
+  }
+
+#ifdef __CUDACC__
+  // CUDA version
+  // In the future, we can potentially use thrust
+  // for sorting here to improve speed (though not guaranteed)
+  for (int i = 1; i < num_in - 1; i++) {
+    for (int j = i + 1; j < num_in; j++) {
+      T crossProduct = cross_2d<T>(q[i], q[j]);
+      if ((crossProduct < -1e-6) ||
+          (fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) {
+        auto q_tmp = q[i];
+        q[i] = q[j];
+        q[j] = q_tmp;
+        auto dist_tmp = dist[i];
+        dist[i] = dist[j];
+        dist[j] = dist_tmp;
+      }
+    }
+  }
+#else
+  // CPU version
+  std::sort(
+      q + 1, q + num_in, [](const Point<T>& A, const Point<T>& B) -> bool {
+        T temp = cross_2d<T>(A, B);
+        if (fabs(temp) < 1e-6) {
+          return dot_2d<T>(A, A) < dot_2d<T>(B, B);
+        } else {
+          return temp > 0;
+        }
+      });
+#endif
+
+  // Step 4:
+  // Make sure there are at least 2 points (that don't overlap with each other)
+  // in the stack
+  int k;  // index of the non-overlapped second point
+  for (k = 1; k < num_in; k++) {
+    if (dist[k] > 1e-8) {
+      break;
+    }
+  }
+  if (k == num_in) {
+    // We reach the end, which means the convex hull is just one point
+    q[0] = p[t];
+    return 1;
+  }
+  q[1] = q[k];
+  int m = 2;  // 2 points in the stack
+  // Step 5:
+  // Finally we can start the scanning process.
+  // When a non-convex relationship between the 3 points is found
+  // (either concave shape or duplicated points),
+  // we pop the previous point from the stack
+  // until the 3-point relationship is convex again, or
+  // until the stack only contains two points
+  for (int i = k + 1; i < num_in; i++) {
+    while (m > 1 && cross_2d<T>(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) {
+      m--;
+    }
+    q[m++] = q[i];
+  }
+
+  // Step 6 (Optional):
+  // In general sense we need the original coordinates, so we
+  // need to shift the points back (reverting Step 2)
+  // But if we're only interested in getting the area/perimeter of the shape
+  // We can simply return.
+  if (!shift_to_zero) {
+    for (int i = 0; i < m; i++) {
+      q[i] += start;
+    }
+  }
+
+  return m;
+}
+
+template <typename T>
+HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) {
+  if (m <= 2) {
+    return 0;
+  }
+
+  T area = 0;
+  for (int i = 1; i < m - 1; i++) {
+    area += fabs(cross_2d<T>(q[i] - q[0], q[i + 1] - q[0]));
+  }
+
+  return area / 2.0;
+}
+
+template <typename T>
+HOST_DEVICE_INLINE T rboxes_intersection(
+    const RotatedBox<T>& box1,
+    const RotatedBox<T>& box2) {
+  // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
+  // from rotated_rect_intersection_pts
+  Point<T> intersectPts[24], orderedPts[24];
+
+  Point<T> pts1[4];
+  Point<T> pts2[4];
+  get_rotated_vertices<T>(box1, pts1);
+  get_rotated_vertices<T>(box2, pts2);
+
+  int num = get_intersection_points<T>(pts1, pts2, intersectPts);
+
+  if (num <= 2) {
+    return 0.0;
+  }
+
+  // Convex Hull to order the intersection points in clockwise order and find
+  // the contour area.
+  int num_convex = convex_hull_graham<T>(intersectPts, num, orderedPts, true);
+  return polygon_area<T>(orderedPts, num_convex);
+}
+
+}  // namespace
+
+template <typename T>
+HOST_DEVICE_INLINE T
+rbox_iou_single(T const* const box1_raw, T const* const box2_raw) {
+  // shift center to the middle point to achieve higher precision in result
+  RotatedBox<T> box1, box2;
+  auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0;
+  auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0;
+  box1.x_ctr = box1_raw[0] - center_shift_x;
+  box1.y_ctr = box1_raw[1] - center_shift_y;
+  box1.w = box1_raw[2];
+  box1.h = box1_raw[3];
+  box1.a = box1_raw[4];
+  box2.x_ctr = box2_raw[0] - center_shift_x;
+  box2.y_ctr = box2_raw[1] - center_shift_y;
+  box2.w = box2_raw[2];
+  box2.h = box2_raw[3];
+  box2.a = box2_raw[4];
+
+  const T area1 = box1.w * box1.h;
+  const T area2 = box2.w * box2.h;
+  if (area1 < 1e-14 || area2 < 1e-14) {
+    return 0.f;
+  }
+
+  const T intersection = rboxes_intersection<T>(box1, box2);
+  const T iou = intersection / (area1 + area2 - intersection);
+  return iou;
+}
+
+
+// 2D block with 32 * 16 = 512 threads per block
+const int BLOCK_DIM_X = 32;
+const int BLOCK_DIM_Y = 16;
+
+/**
+   Computes ceil(a / b)
+*/
+template <typename T>
+__host__ __device__ __forceinline__ T CeilDiv0(T a, T b) {
+  return (a + b - 1) / b;
+}
+
+static inline int CeilDiv(const int a, const int b) {
+  return (a + b - 1) / b;
+}
+
+template <typename T>
+__global__ void rbox_iou_cuda_kernel(
+    const int rbox1_num,
+    const int rbox2_num,
+    const T* rbox1_data_ptr,
+    const T* rbox2_data_ptr,
+    T* output_data_ptr) {
+
+  // get row_start and col_start
+  const int rbox1_block_idx = blockIdx.x * blockDim.x;
+  const int rbox2_block_idx = blockIdx.y * blockDim.y;
+
+  const int rbox1_thread_num = min(rbox1_num - rbox1_block_idx, blockDim.x);
+  const int rbox2_thread_num = min(rbox2_num - rbox2_block_idx, blockDim.y);
+
+  __shared__ T block_boxes1[BLOCK_DIM_X * 5];
+  __shared__ T block_boxes2[BLOCK_DIM_Y * 5];
+
+
+  // It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y
+  if (threadIdx.x < rbox1_thread_num && threadIdx.y == 0) {
+    block_boxes1[threadIdx.x * 5 + 0] =
+        rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 0];
+    block_boxes1[threadIdx.x * 5 + 1] =
+        rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 1];
+    block_boxes1[threadIdx.x * 5 + 2] =
+        rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 2];
+    block_boxes1[threadIdx.x * 5 + 3] =
+        rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 3];
+    block_boxes1[threadIdx.x * 5 + 4] =
+        rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 4];
+  }
+
+  // threadIdx.x < BLOCK_DIM_Y=rbox2_thread_num, just use same condition as above: threadIdx.y == 0
+  if (threadIdx.x < rbox2_thread_num && threadIdx.y == 0) {
+    block_boxes2[threadIdx.x * 5 + 0] =
+        rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 0];
+    block_boxes2[threadIdx.x * 5 + 1] =
+        rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 1];
+    block_boxes2[threadIdx.x * 5 + 2] =
+        rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 2];
+    block_boxes2[threadIdx.x * 5 + 3] =
+        rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 3];
+    block_boxes2[threadIdx.x * 5 + 4] =
+        rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 4];
+  }
+
+  // sync
+  __syncthreads();
+
+  if (threadIdx.x < rbox1_thread_num && threadIdx.y < rbox2_thread_num) {
+    int offset = (rbox1_block_idx + threadIdx.x) * rbox2_num + rbox2_block_idx + threadIdx.y;
+    output_data_ptr[offset] = rbox_iou_single<T>(block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5);
+  }
+}
+
+#define CHECK_INPUT_GPU(x) PD_CHECK(x.place() == paddle::PlaceType::kGPU, #x " must be a GPU Tensor.")
+
+std::vector<paddle::Tensor> RboxIouCUDAForward(const paddle::Tensor& rbox1, const paddle::Tensor& rbox2) {
+    CHECK_INPUT_GPU(rbox1);
+    CHECK_INPUT_GPU(rbox2);
+
+    auto rbox1_num = rbox1.shape()[0];
+    auto rbox2_num = rbox2.shape()[0];
+
+    auto output = paddle::Tensor(paddle::PlaceType::kGPU);
+    output.reshape({rbox1_num, rbox2_num});
+
+    const int blocks_x = CeilDiv(rbox1_num, BLOCK_DIM_X);
+    const int blocks_y = CeilDiv(rbox2_num, BLOCK_DIM_Y);
+
+    dim3 blocks(blocks_x, blocks_y);
+    dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y);
+
+    PD_DISPATCH_FLOATING_TYPES(
+        rbox1.type(),
+        "rbox_iou_cuda_kernel",
+        ([&] {
+            rbox_iou_cuda_kernel<data_t><<<blocks, threads>>>(
+                rbox1_num,
+                rbox2_num,
+                rbox1.data<data_t>(),
+                rbox2.data<data_t>(),
+                output.mutable_data<data_t>());
+        }));
+
+    return {output};
+}
+
+
+template <typename T>
+void rbox_iou_cpu_kernel(
+    const int rbox1_num,
+    const int rbox2_num,
+    const T* rbox1_data_ptr,
+    const T* rbox2_data_ptr,
+    T* output_data_ptr) {
+
+    int i, j;
+    for (i = 0; i < rbox1_num; i++) {
+        for (j = 0; j < rbox2_num; j++) {
+            int offset = i * rbox2_num + j;
+            output_data_ptr[offset] = rbox_iou_single<T>(rbox1_data_ptr + i * 5, rbox2_data_ptr + j * 5);
+        }
+    }
+}
+
+
+#define CHECK_INPUT_CPU(x) PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")
+
+std::vector<paddle::Tensor> RboxIouCPUForward(const paddle::Tensor& rbox1, const paddle::Tensor& rbox2) {
+    CHECK_INPUT_CPU(rbox1);
+    CHECK_INPUT_CPU(rbox2);
+
+    auto rbox1_num = rbox1.shape()[0];
+    auto rbox2_num = rbox2.shape()[0];
+
+    auto output = paddle::Tensor(paddle::PlaceType::kCPU);
+    output.reshape({rbox1_num, rbox2_num});
+
+    PD_DISPATCH_FLOATING_TYPES(
+        rbox1.type(),
+        "rbox_iou_cpu_kernel",
+        ([&] {
+            rbox_iou_cpu_kernel<data_t>(
+                rbox1_num,
+                rbox2_num,
+                rbox1.data<data_t>(),
+                rbox2.data<data_t>(),
+                output.mutable_data<data_t>());
+        }));
+
+    return {output};
+}
diff --git a/ppdet/ext_op/setup.py b/ppdet/ext_op/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..6859f0cc29b80a171534eb385654f24f92a60921
--- /dev/null
+++ b/ppdet/ext_op/setup.py
@@ -0,0 +1,6 @@
+from paddle.utils.cpp_extension import CppExtension, CUDAExtension, setup
+
+if __name__ == "__main__":
+    setup(
+        name='rbox_iou_ops',
+        ext_modules=CUDAExtension(sources=['rbox_iou_op.cc', 'rbox_iou_op.cu']))
diff --git a/ppdet/ext_op/test.py b/ppdet/ext_op/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..83403edd3a9e6a34accd386aac26d0bdb1d77b20
--- /dev/null
+++ b/ppdet/ext_op/test.py
@@ -0,0 +1,154 @@
+import numpy as np
+import os
+import sys
+import cv2
+import time
+import shapely
+from shapely.geometry import Polygon
+import paddle
+
+paddle.set_device('gpu:0')
+paddle.disable_static()
+
+try:
+    from rbox_iou_ops import rbox_iou
+except Exception as e:
+    print('import custom_ops error', e)
+    sys.exit(-1)
+
+# generate random data
+rbox1 = np.random.rand(13000, 5)
+rbox2 = np.random.rand(7, 5)
+
+# x1 y1 w h in [0, 0.5]
+rbox1[:, 0:4] = rbox1[:, 0:4] * 0.45 + 0.001
+rbox2[:, 0:4] = rbox2[:, 0:4] * 0.45 + 0.001
+
+# generate rbox angles in [-0.5, 0.5]
+rbox1[:, 4] = rbox1[:, 4] - 0.5
+rbox2[:, 4] = rbox2[:, 4] - 0.5
+
+print('rbox1', rbox1.shape, 'rbox2', rbox2.shape)
+
+# to paddle tensor
+pd_rbox1 = paddle.to_tensor(rbox1)
+pd_rbox2 = paddle.to_tensor(rbox2)
+
+# start the clock before the op call, not after it
+start_time = time.time()
+iou = rbox_iou(pd_rbox1, pd_rbox2)
+print('paddle time:', time.time() - start_time)
+print('iou is', iou.cpu().shape)
+
+
+# get gt
+def rbox2poly_single(rrect, get_best_begin_point=False):
+    """
rrect:[x_ctr,y_ctr,w,h,angle] + to + poly:[x0,y0,x1,y1,x2,y2,x3,y3] + """ + x_ctr, y_ctr, width, height, angle = rrect[:5] + tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2 + # rect 2x4 + rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]]) + R = np.array([[np.cos(angle), -np.sin(angle)], + [np.sin(angle), np.cos(angle)]]) + # poly + poly = R.dot(rect) + x0, x1, x2, x3 = poly[0, :4] + x_ctr + y0, y1, y2, y3 = poly[1, :4] + y_ctr + poly = np.array([x0, y0, x1, y1, x2, y2, x3, y3], dtype=np.float32) + return poly + + +def intersection(g, p): + """ + Intersection. + """ + + g = g[:8].reshape((4, 2)) + p = p[:8].reshape((4, 2)) + + a = g + b = p + + use_filter = True + if use_filter: + # step1: + inter_x1 = np.maximum(np.min(a[:, 0]), np.min(b[:, 0])) + inter_x2 = np.minimum(np.max(a[:, 0]), np.max(b[:, 0])) + inter_y1 = np.maximum(np.min(a[:, 1]), np.min(b[:, 1])) + inter_y2 = np.minimum(np.max(a[:, 1]), np.max(b[:, 1])) + if inter_x1 >= inter_x2 or inter_y1 >= inter_y2: + return 0. + x1 = np.minimum(np.min(a[:, 0]), np.min(b[:, 0])) + x2 = np.maximum(np.max(a[:, 0]), np.max(b[:, 0])) + y1 = np.minimum(np.min(a[:, 1]), np.min(b[:, 1])) + y2 = np.maximum(np.max(a[:, 1]), np.max(b[:, 1])) + if x1 >= x2 or y1 >= y2 or (x2 - x1) < 2 or (y2 - y1) < 2: + return 0. + + g = Polygon(g) + p = Polygon(p) + #g = g.buffer(0) + #p = p.buffer(0) + if not g.is_valid or not p.is_valid: + return 0 + + inter = Polygon(g).intersection(Polygon(p)).area + union = g.area + p.area - inter + if union == 0: + return 0 + else: + return inter / union + + +# rbox_iou by python +def rbox_overlaps(anchors, gt_bboxes, use_cv2=False): + """ + + Args: + anchors: [NA, 5] x1,y1,x2,y2,angle + gt_bboxes: [M, 5] x1,y1,x2,y2,angle + + Returns: + + """ + assert anchors.shape[1] == 5 + assert gt_bboxes.shape[1] == 5 + + gt_bboxes_ploy = [rbox2poly_single(e) for e in gt_bboxes] + anchors_ploy = [rbox2poly_single(e) for e in anchors] + + num_gt, num_anchors = len(gt_bboxes_ploy), len(anchors_ploy) + iou = np.zeros((num_gt, num_anchors), dtype=np.float32) + + start_time = time.time() + for i in range(num_gt): + for j in range(num_anchors): + try: + iou[i, j] = intersection(gt_bboxes_ploy[i], anchors_ploy[j]) + except Exception as e: + print('cur gt_bboxes_ploy[i]', gt_bboxes_ploy[i], + 'anchors_ploy[j]', anchors_ploy[j], e) + iou = iou.T + print('intersection all sp_time', time.time() - start_time) + return iou + + +# make coor as int +ploy_rbox1 = rbox1 +ploy_rbox2 = rbox2 +ploy_rbox1[:, 0:4] = rbox1[:, 0:4] * 1024 +ploy_rbox2[:, 0:4] = rbox2[:, 0:4] * 1024 + +start_time = time.time() +iou_py = rbox_overlaps(ploy_rbox1, ploy_rbox2, use_cv2=False) +print('rbox time', time.time() - start_time) +print(iou_py.shape) + +iou_pd = iou.cpu().numpy() +sum_abs_diff = np.sum(np.abs(iou_pd - iou_py)) +print('sum of abs diff', sum_abs_diff) +if sum_abs_diff < 0.02: + print("rbox_iou OP compute right!") diff --git a/ppdet/metrics/coco_utils.py b/ppdet/metrics/coco_utils.py index 82d5cc85628a4d1a98507bbee0d1b3c5b6c13d2e..a7ac32226566d56f9c993a6b02f6c21397c36be5 100644 --- a/ppdet/metrics/coco_utils.py +++ b/ppdet/metrics/coco_utils.py @@ -21,7 +21,7 @@ import sys import numpy as np import itertools -from ppdet.metrics.json_results import get_det_res, get_seg_res, get_solov2_segm_res +from ppdet.metrics.json_results import get_det_res, get_det_poly_res, get_seg_res, get_solov2_segm_res from ppdet.metrics.map_utils import draw_pr_curve from ppdet.utils.logger import setup_logger @@ -45,8 +45,12 @@ def 
get_infer_results(outs, catid, bias=0):
     infer_res = {}
     if 'bbox' in outs:
-        infer_res['bbox'] = get_det_res(
-            outs['bbox'], outs['bbox_num'], im_id, catid, bias=bias)
+        if len(outs['bbox']) > 0 and len(outs['bbox'][0]) > 6:
+            infer_res['bbox'] = get_det_poly_res(
+                outs['bbox'], outs['bbox_num'], im_id, catid, bias=bias)
+        else:
+            infer_res['bbox'] = get_det_res(
+                outs['bbox'], outs['bbox_num'], im_id, catid, bias=bias)
 
     if 'mask' in outs:
         # mask post process
diff --git a/ppdet/metrics/json_results.py b/ppdet/metrics/json_results.py
index 0c02cdba317290829422b8723de570490b098b0d..f5607666103e804bc4cd42ca87797c032e3b0a97 100755
--- a/ppdet/metrics/json_results.py
+++ b/ppdet/metrics/json_results.py
@@ -43,6 +43,30 @@ def get_det_res(bboxes, bbox_nums, image_id, label_to_cat_id_map, bias=0):
     return det_res
 
 
+def get_det_poly_res(bboxes, bbox_nums, image_id, label_to_cat_id_map, bias=0):
+    det_res = []
+    k = 0
+    for i in range(len(bbox_nums)):
+        cur_image_id = int(image_id[i][0])
+        det_nums = bbox_nums[i]
+        for j in range(det_nums):
+            dt = bboxes[k]
+            k = k + 1
+            num_id, score, x1, y1, x2, y2, x3, y3, x4, y4 = dt.tolist()
+            if int(num_id) < 0:
+                continue
+            category_id = int(num_id)
+            rbox = [x1, y1, x2, y2, x3, y3, x4, y4]
+            dt_res = {
+                'image_id': cur_image_id,
+                'category_id': category_id,
+                'bbox': rbox,
+                'score': score
+            }
+            det_res.append(dt_res)
+    return det_res
+
+
 def get_seg_res(masks, bboxes, mask_nums, image_id, label_to_cat_id_map):
     import pycocotools.mask as mask_util
     seg_res = []
diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py
index 6ffb47115548fc513245e517ebb776fbb9b72fc7..ae881607c6544709dbfdc7f6e73ae4ae30bbe48b 100644
--- a/ppdet/modeling/architectures/__init__.py
+++ b/ppdet/modeling/architectures/__init__.py
@@ -14,6 +14,7 @@ from . import ssd
 from . import fcos
 from . import solov2
 from . import ttfnet
+from . import s2anet
 
 from .meta_arch import *
 from .faster_rcnn import *
@@ -24,3 +25,4 @@ from .ssd import *
 from .fcos import *
 from .solov2 import *
 from .ttfnet import *
+from .s2anet import *
diff --git a/ppdet/modeling/architectures/s2anet.py b/ppdet/modeling/architectures/s2anet.py
new file mode 100644
index 0000000000000000000000000000000000000000..72e9e820adcf230c5dd4a0d6c51c0496779e424a
--- /dev/null
+++ b/ppdet/modeling/architectures/s2anet.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from ppdet.core.workspace import register, create +from .meta_arch import BaseArch +import numpy as np + +__all__ = ['S2ANet'] + + +@register +class S2ANet(BaseArch): + __category__ = 'architecture' + __inject__ = [ + 's2anet_head', + 's2anet_bbox_post_process', + ] + + def __init__(self, backbone, neck, s2anet_head, s2anet_bbox_post_process): + """ + S2ANet, see https://arxiv.org/pdf/2008.09397.pdf + + Args: + backbone (object): backbone instance + neck (object): `FPN` instance + s2anet_head (object): `S2ANetHead` instance + s2anet_bbox_post_process (object): `S2ANetBBoxPostProcess` instance + """ + super(S2ANet, self).__init__() + self.backbone = backbone + self.neck = neck + self.s2anet_head = s2anet_head + self.s2anet_bbox_post_process = s2anet_bbox_post_process + + @classmethod + def from_config(cls, cfg, *args, **kwargs): + backbone = create(cfg['backbone']) + kwargs = {'input_shape': backbone.out_shape} + neck = cfg['neck'] and create(cfg['neck'], **kwargs) + + out_shape = neck and neck.out_shape or backbone.out_shape + kwargs = {'input_shape': out_shape} + s2anet_head = create(cfg['s2anet_head'], **kwargs) + s2anet_bbox_post_process = create(cfg['s2anet_bbox_post_process'], + **kwargs) + + return { + 'backbone': backbone, + 'neck': neck, + "s2anet_head": s2anet_head, + "s2anet_bbox_post_process": s2anet_bbox_post_process, + } + + def _forward(self): + body_feats = self.backbone(self.inputs) + if self.neck is not None: + body_feats = self.neck(body_feats) + self.s2anet_head(body_feats) + if self.training: + loss = self.s2anet_head.get_loss(self.inputs) + total_loss = paddle.add_n(list(loss.values())) + loss.update({'loss': total_loss}) + return loss + else: + im_shape = self.inputs['im_shape'] + scale_factor = self.inputs['scale_factor'] + nms_pre = self.s2anet_bbox_post_process.nms_pre + pred_scores, pred_bboxes = self.s2anet_head.get_prediction(nms_pre) + + # post_process + pred_cls_score_bbox, bbox_num, index = self.s2anet_bbox_post_process.get_prediction( + pred_scores, pred_bboxes, im_shape, scale_factor) + + # output + output = {'bbox': pred_cls_score_bbox, 'bbox_num': bbox_num} + return output + + def get_loss(self, ): + loss = self._forward() + return loss + + def get_pred(self): + output = self._forward() + return output diff --git a/ppdet/modeling/bbox_utils.py b/ppdet/modeling/bbox_utils.py index 3308684aa6c483d407e70852e2fa4fdb7a8609f8..8db8b2345047490b8ceba64fe788d503eed293ec 100644 --- a/ppdet/modeling/bbox_utils.py +++ b/ppdet/modeling/bbox_utils.py @@ -16,6 +16,7 @@ import math import paddle import paddle.nn.functional as F import math +import numpy as np def bbox2delta(src_boxes, tgt_boxes, weights): @@ -260,3 +261,147 @@ def bbox_iou(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9): return iou - (rho2 / c2 + v * alpha) else: return iou + + +def rect2rbox(bboxes): + """ + :param bboxes: shape (n, 4) (xmin, ymin, xmax, ymax) + :return: dbboxes: shape (n, 5) (x_ctr, y_ctr, w, h, angle) + """ + bboxes = bboxes.reshape(-1, 4) + num_boxes = bboxes.shape[0] + + x_ctr = (bboxes[:, 2] + bboxes[:, 0]) / 2.0 + y_ctr = (bboxes[:, 3] + bboxes[:, 1]) / 2.0 + edges1 = np.abs(bboxes[:, 2] - bboxes[:, 0]) + edges2 = np.abs(bboxes[:, 3] - bboxes[:, 1]) + angles = np.zeros([num_boxes], dtype=bboxes.dtype) + + inds = edges1 < edges2 + + rboxes 
= np.stack((x_ctr, y_ctr, edges1, edges2, angles), axis=1) + rboxes[inds, 2] = edges2[inds] + rboxes[inds, 3] = edges1[inds] + rboxes[inds, 4] = np.pi / 2.0 + return rboxes + + +def delta2rbox(Rrois, + deltas, + means=[0, 0, 0, 0, 0], + stds=[1, 1, 1, 1, 1], + wh_ratio_clip=1e-6): + """ + :param Rrois: (cx, cy, w, h, theta) + :param deltas: (dx, dy, dw, dh, dtheta) + :param means: + :param stds: + :param wh_ratio_clip: + :return: + """ + means = paddle.to_tensor(means) + stds = paddle.to_tensor(stds) + deltas = paddle.reshape(deltas, [-1, deltas.shape[-1]]) + denorm_deltas = deltas * stds + means + + dx = denorm_deltas[:, 0] + dy = denorm_deltas[:, 1] + dw = denorm_deltas[:, 2] + dh = denorm_deltas[:, 3] + dangle = denorm_deltas[:, 4] + + max_ratio = np.abs(np.log(wh_ratio_clip)) + dw = paddle.clip(dw, min=-max_ratio, max=max_ratio) + dh = paddle.clip(dh, min=-max_ratio, max=max_ratio) + + Rroi_x = Rrois[:, 0] + Rroi_y = Rrois[:, 1] + Rroi_w = Rrois[:, 2] + Rroi_h = Rrois[:, 3] + Rroi_angle = Rrois[:, 4] + + gx = dx * Rroi_w * paddle.cos(Rroi_angle) - dy * Rroi_h * paddle.sin( + Rroi_angle) + Rroi_x + gy = dx * Rroi_w * paddle.sin(Rroi_angle) + dy * Rroi_h * paddle.cos( + Rroi_angle) + Rroi_y + gw = Rroi_w * dw.exp() + gh = Rroi_h * dh.exp() + ga = np.pi * dangle + Rroi_angle + ga = (ga + np.pi / 4) % np.pi - np.pi / 4 + ga = paddle.to_tensor(ga) + + gw = paddle.to_tensor(gw, dtype='float32') + gh = paddle.to_tensor(gh, dtype='float32') + bboxes = paddle.stack([gx, gy, gw, gh, ga], axis=-1) + return bboxes + + +def rbox2delta(proposals, gt, means=[0, 0, 0, 0, 0], stds=[1, 1, 1, 1, 1]): + """ + + Args: + proposals: + gt: + means: 1x5 + stds: 1x5 + + Returns: + + """ + proposals = proposals.astype(np.float64) + + PI = np.pi + + gt_widths = gt[..., 2] + gt_heights = gt[..., 3] + gt_angle = gt[..., 4] + + proposals_widths = proposals[..., 2] + proposals_heights = proposals[..., 3] + proposals_angle = proposals[..., 4] + + coord = gt[..., 0:2] - proposals[..., 0:2] + dx = (np.cos(proposals[..., 4]) * coord[..., 0] + np.sin(proposals[..., 4]) + * coord[..., 1]) / proposals_widths + dy = (-np.sin(proposals[..., 4]) * coord[..., 0] + np.cos(proposals[..., 4]) + * coord[..., 1]) / proposals_heights + dw = np.log(gt_widths / proposals_widths) + dh = np.log(gt_heights / proposals_heights) + da = (gt_angle - proposals_angle) + + da = (da + PI / 4) % PI - PI / 4 + da /= PI + + deltas = np.stack([dx, dy, dw, dh, da], axis=-1) + means = np.array(means, dtype=deltas.dtype) + stds = np.array(stds, dtype=deltas.dtype) + deltas = (deltas - means) / stds + deltas = deltas.astype(np.float32) + return deltas + + +def bbox_decode(bbox_preds, + anchors, + means=[0, 0, 0, 0, 0], + stds=[1, 1, 1, 1, 1]): + """decode bbox from deltas + Args: + bbox_preds: [N,H,W,5] + anchors: [H*W,5] + return: + bboxes: [N,H,W,5] + """ + means = paddle.to_tensor(means) + stds = paddle.to_tensor(stds) + num_imgs, H, W, _ = bbox_preds.shape + bboxes_list = [] + for img_id in range(num_imgs): + bbox_pred = bbox_preds[img_id] + # bbox_pred.shape=[5,H,W] + bbox_delta = bbox_pred + anchors = paddle.to_tensor(anchors) + bboxes = delta2rbox( + anchors, bbox_delta, means, stds, wh_ratio_clip=1e-6) + bboxes = paddle.reshape(bboxes, [H, W, 5]) + bboxes_list.append(bboxes) + return paddle.stack(bboxes_list, axis=0) \ No newline at end of file diff --git a/ppdet/modeling/heads/__init__.py b/ppdet/modeling/heads/__init__.py index 9ed5fc2e3e5e8b8989084d1f1b3eb6dac93d24ff..9263aa812f7c112f21615649be1f7a946a16f83d 100644 --- 
a/ppdet/modeling/heads/__init__.py +++ b/ppdet/modeling/heads/__init__.py @@ -22,6 +22,7 @@ from . import solov2_head from . import ttf_head from . import cascade_head from . import face_head +from . import s2anet_head from .bbox_head import * from .mask_head import * @@ -33,3 +34,4 @@ from .solov2_head import * from .ttf_head import * from .cascade_head import * from .face_head import * +from .s2anet_head import * diff --git a/ppdet/modeling/heads/s2anet_head.py b/ppdet/modeling/heads/s2anet_head.py new file mode 100644 index 0000000000000000000000000000000000000000..3a64205a405c90474c03daa78548123f5d33b2bd --- /dev/null +++ b/ppdet/modeling/heads/s2anet_head.py @@ -0,0 +1,872 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn.initializer import Normal, Constant +from ppdet.core.workspace import register +from ppdet.modeling import ops +from ppdet.modeling import bbox_utils +from ppdet.modeling.proposal_generator.target_layer import RBoxAssigner +import numpy as np + + +class S2ANetAnchorGenerator(object): + """ + S2ANetAnchorGenerator by np + """ + + def __init__(self, + base_size=8, + scales=1.0, + ratios=1.0, + scale_major=True, + ctr=None): + self.base_size = base_size + self.scales = scales + self.ratios = ratios + self.scale_major = scale_major + self.ctr = ctr + self.base_anchors = self.gen_base_anchors() + + @property + def num_base_anchors(self): + return self.base_anchors.shape[0] + + def gen_base_anchors(self): + w = self.base_size + h = self.base_size + if self.ctr is None: + x_ctr = 0.5 * (w - 1) + y_ctr = 0.5 * (h - 1) + else: + x_ctr, y_ctr = self.ctr + + h_ratios = np.sqrt(self.ratios) + w_ratios = 1 / h_ratios + if self.scale_major: + ws = (w * w_ratios[:] * self.scales[:]).reshape([-1]) + hs = (h * h_ratios[:] * self.scales[:]).reshape([-1]) + else: + ws = (w * self.scales[:] * w_ratios[:]).reshape([-1]) + hs = (h * self.scales[:] * h_ratios[:]).reshape([-1]) + + # yapf: disable + base_anchors = np.stack( + [ + x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1), + x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1) + ], + axis=-1) + base_anchors = np.round(base_anchors) + # yapf: enable + + return base_anchors + + def _meshgrid(self, x, y, row_major=True): + xx, yy = np.meshgrid(x, y) + xx = xx.reshape(-1) + yy = yy.reshape(-1) + if row_major: + return xx, yy + else: + return yy, xx + + def grid_anchors(self, featmap_size, stride=16): + # featmap_size*stride project it to original area + base_anchors = self.base_anchors + feat_h, feat_w = featmap_size + shift_x = np.arange(0, feat_w, 1, 'int32') * stride + shift_y = np.arange(0, feat_h, 1, 'int32') * stride + shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) + shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1) + # shifts = shifts.type_as(base_anchors) + # first feat_w elements correspond to the first row of shifts + # add A anchors 
(1, A, 4) to K shifts (K, 1, 4) to get + # shifted anchors (K, A, 4), reshape to (K*A, 4) + + #all_anchors = base_anchors[:, :] + shifts[:, :] + all_anchors = base_anchors[None, :, :] + shifts[:, None, :] + # all_anchors = all_anchors.reshape([-1, 4]) + # first A rows correspond to A anchors of (0, 0) in feature map, + # then (0, 1), (0, 2), ... + return all_anchors + + def valid_flags(self, featmap_size, valid_size): + feat_h, feat_w = featmap_size + valid_h, valid_w = valid_size + assert valid_h <= feat_h and valid_w <= feat_w + valid_x = np.zeros([feat_w], dtype='uint8') + valid_y = np.zeros([feat_h], dtype='uint8') + valid_x[:valid_w] = 1 + valid_y[:valid_h] = 1 + valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) + valid = valid_xx & valid_yy + valid = valid.reshape([-1]) + + # valid = valid[:, None].expand( + # [valid.size(0), self.num_base_anchors]).reshape([-1]) + return valid + + +class AlignConv(nn.Layer): + def __init__(self, in_channels, out_channels, kernel_size=3, groups=1): + super(AlignConv, self).__init__() + self.kernel_size = kernel_size + self.align_conv = paddle.vision.ops.DeformConv2D( + in_channels, + out_channels, + kernel_size=self.kernel_size, + padding=(self.kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(initializer=Normal(0, 0.01)), + bias_attr=None) + + @paddle.no_grad() + def get_offset(self, anchors, featmap_size, stride): + """ + Args: + anchors: [M,5] xc,yc,w,h,angle + featmap_size: (feat_h, feat_w) + stride: 8 + Returns: + + """ + anchors = paddle.reshape(anchors, [-1, 5]) # (NA,5) + dtype = anchors.dtype + feat_h, feat_w = featmap_size + pad = (self.kernel_size - 1) // 2 + idx = paddle.arange(-pad, pad + 1, dtype=dtype) + + yy, xx = paddle.meshgrid(idx, idx) + xx = paddle.reshape(xx, [-1]) + yy = paddle.reshape(yy, [-1]) + + # get sampling locations of default conv + xc = paddle.arange(0, feat_w, dtype=dtype) + yc = paddle.arange(0, feat_h, dtype=dtype) + yc, xc = paddle.meshgrid(yc, xc) + + xc = paddle.reshape(xc, [-1, 1]) + yc = paddle.reshape(yc, [-1, 1]) + x_conv = xc + xx + y_conv = yc + yy + + # get sampling locations of anchors + # x_ctr, y_ctr, w, h, a = np.unbind(anchors, dim=1) + x_ctr = anchors[:, 0] + y_ctr = anchors[:, 1] + w = anchors[:, 2] + h = anchors[:, 3] + a = anchors[:, 4] + + x_ctr = paddle.reshape(x_ctr, [x_ctr.shape[0], 1]) + y_ctr = paddle.reshape(y_ctr, [y_ctr.shape[0], 1]) + w = paddle.reshape(w, [w.shape[0], 1]) + h = paddle.reshape(h, [h.shape[0], 1]) + a = paddle.reshape(a, [a.shape[0], 1]) + + x_ctr = x_ctr / stride + y_ctr = y_ctr / stride + w_s = w / stride + h_s = h / stride + cos, sin = paddle.cos(a), paddle.sin(a) + dw, dh = w_s / self.kernel_size, h_s / self.kernel_size + x, y = dw * xx, dh * yy + xr = cos * x - sin * y + yr = sin * x + cos * y + x_anchor, y_anchor = xr + x_ctr, yr + y_ctr + # get offset filed + offset_x = x_anchor - x_conv + offset_y = y_anchor - y_conv + # x, y in anchors is opposite in image coordinates, + # so we stack them with y, x other than x, y + offset = paddle.stack([offset_y, offset_x], axis=-1) + # NA,ks*ks*2 + # [NA, ks, ks, 2] --> [NA, ks*ks*2] + offset = paddle.reshape(offset, [offset.shape[0], -1]) + # [NA, ks*ks*2] --> [ks*ks*2, NA] + offset = paddle.transpose(offset, [1, 0]) + # [NA, ks*ks*2] --> [1, ks*ks*2, H, W] + offset = paddle.reshape(offset, [1, -1, feat_h, feat_w]) + return offset + + def forward(self, x, refine_anchors, stride): + featmap_size = (x.shape[2], x.shape[3]) + offset = self.get_offset(refine_anchors, featmap_size, stride) + x = 
F.relu(self.align_conv(x, offset)) + return x + + +@register +class S2ANetHead(nn.Layer): + """ + S2Anet head + Args: + stacked_convs (int): number of stacked_convs + feat_in (int): input channels of feat + feat_out (int): output channels of feat + num_classes (int): num_classes + anchor_strides (list): stride of anchors + anchor_scales (list): scale of anchors + anchor_ratios (list): ratios of anchors + target_means (list): target_means + target_stds (list): target_stds + align_conv_type (str): align_conv_type ['Conv', 'AlignConv'] + align_conv_size (int): kernel size of align_conv + use_sigmoid_cls (bool): use sigmoid_cls or not + reg_loss_weight (list): reg loss weight + """ + __shared__ = ['num_classes'] + __inject__ = ['anchor_assign'] + + def __init__(self, + stacked_convs=2, + feat_in=256, + feat_out=256, + num_classes=15, + anchor_strides=[8, 16, 32, 64, 128], + anchor_scales=[4], + anchor_ratios=[1.0], + target_means=(.0, .0, .0, .0, .0), + target_stds=(1.0, 1.0, 1.0, 1.0, 1.0), + align_conv_type='AlignConv', + align_conv_size=3, + use_sigmoid_cls=True, + anchor_assign=RBoxAssigner().__dict__, + reg_loss_weight=[1.0, 1.0, 1.0, 1.0, 1.0]): + super(S2ANetHead, self).__init__() + self.stacked_convs = stacked_convs + self.feat_in = feat_in + self.feat_out = feat_out + self.anchor_list = None + self.anchor_scales = anchor_scales + self.anchor_ratios = anchor_ratios + self.anchor_strides = anchor_strides + self.anchor_base_sizes = list(anchor_strides) + self.target_means = target_means + self.target_stds = target_stds + assert align_conv_type in ['AlignConv', 'Conv'] + self.align_conv_type = align_conv_type + self.align_conv_size = align_conv_size + + self.use_sigmoid_cls = use_sigmoid_cls + self.cls_out_channels = num_classes if self.use_sigmoid_cls else 1 + self.sampling = False + self.anchor_assign = anchor_assign + self.reg_loss_weight = reg_loss_weight + + self.s2anet_head_out = None + + # anchor + self.anchor_generators = [] + for anchor_base in self.anchor_base_sizes: + self.anchor_generators.append( + S2ANetAnchorGenerator(anchor_base, anchor_scales, + anchor_ratios)) + + # featmap_sizes + self.featmap_sizes = [] + self.base_anchors = [] + self.rbox_anchors = [] + self.refine_anchor_list = [] + + self.fam_cls_convs = nn.Sequential() + self.fam_reg_convs = nn.Sequential() + + for i in range(self.stacked_convs): + chan_in = self.feat_in if i == 0 else self.feat_out + + self.fam_cls_convs.add_sublayer( + 'fam_cls_conv_{}'.format(i), + nn.Conv2D( + in_channels=chan_in, + out_channels=self.feat_out, + kernel_size=3, + padding=1, + weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)), + bias_attr=ParamAttr(initializer=Constant(0)))) + + self.fam_cls_convs.add_sublayer('fam_cls_conv_{}_act'.format(i), + nn.ReLU()) + + self.fam_reg_convs.add_sublayer( + 'fam_reg_conv_{}'.format(i), + nn.Conv2D( + in_channels=chan_in, + out_channels=self.feat_out, + kernel_size=3, + padding=1, + weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)), + bias_attr=ParamAttr(initializer=Constant(0)))) + + self.fam_reg_convs.add_sublayer('fam_reg_conv_{}_act'.format(i), + nn.ReLU()) + + self.fam_reg = nn.Conv2D( + self.feat_out, + 5, + 1, + weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)), + bias_attr=ParamAttr(initializer=Constant(0))) + prior_prob = 0.01 + bias_init = float(-np.log((1 - prior_prob) / prior_prob)) + self.fam_cls = nn.Conv2D( + self.feat_out, + self.cls_out_channels, + 1, + weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)), + bias_attr=ParamAttr(initializer=Constant(bias_init))) + + if 
self.align_conv_type == "AlignConv": + self.align_conv = AlignConv(self.feat_out, self.feat_out, + self.align_conv_size) + elif self.align_conv_type == "Conv": + self.align_conv = nn.Conv2D( + self.feat_out, + self.feat_out, + self.align_conv_size, + padding=(self.align_conv_size - 1) // 2, + bias_attr=ParamAttr(initializer=Constant(0))) + + self.or_conv = nn.Conv2D( + self.feat_out, + self.feat_out, + kernel_size=3, + padding=1, + weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)), + bias_attr=ParamAttr(initializer=Constant(0))) + + # ODM + self.odm_cls_convs = nn.Sequential() + self.odm_reg_convs = nn.Sequential() + + for i in range(self.stacked_convs): + ch_in = self.feat_out + # ch_in = int(self.feat_out / 8) if i == 0 else self.feat_out + + self.odm_cls_convs.add_sublayer( + 'odm_cls_conv_{}'.format(i), + nn.Conv2D( + in_channels=ch_in, + out_channels=self.feat_out, + kernel_size=3, + stride=1, + padding=1, + weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)), + bias_attr=ParamAttr(initializer=Constant(0)))) + + self.odm_cls_convs.add_sublayer('odm_cls_conv_{}_act'.format(i), + nn.ReLU()) + + self.odm_reg_convs.add_sublayer( + 'odm_reg_conv_{}'.format(i), + nn.Conv2D( + in_channels=self.feat_out, + out_channels=self.feat_out, + kernel_size=3, + stride=1, + padding=1, + weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)), + bias_attr=ParamAttr(initializer=Constant(0)))) + + self.odm_reg_convs.add_sublayer('odm_reg_conv_{}_act'.format(i), + nn.ReLU()) + + self.odm_cls = nn.Conv2D( + self.feat_out, + self.cls_out_channels, + 3, + padding=1, + weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)), + bias_attr=ParamAttr(initializer=Constant(bias_init))) + self.odm_reg = nn.Conv2D( + self.feat_out, + 5, + 3, + padding=1, + weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)), + bias_attr=ParamAttr(initializer=Constant(0))) + + def forward(self, feats): + fam_reg_branch_list = [] + fam_cls_branch_list = [] + + odm_reg_branch_list = [] + odm_cls_branch_list = [] + + self.featmap_sizes = dict() + self.base_anchors = dict() + self.refine_anchor_list = [] + + for i, feat in enumerate(feats): + fam_cls_feat = self.fam_cls_convs(feat) + + fam_cls = self.fam_cls(fam_cls_feat) + # [N, CLS, H, W] --> [N, H, W, CLS] + fam_cls = fam_cls.transpose([0, 2, 3, 1]) + fam_cls_reshape = paddle.reshape( + fam_cls, [fam_cls.shape[0], -1, self.cls_out_channels]) + fam_cls_branch_list.append(fam_cls_reshape) + + fam_reg_feat = self.fam_reg_convs(feat) + + fam_reg = self.fam_reg(fam_reg_feat) + # [N, 5, H, W] --> [N, H, W, 5] + fam_reg = fam_reg.transpose([0, 2, 3, 1]) + fam_reg_reshape = paddle.reshape(fam_reg, [fam_reg.shape[0], -1, 5]) + fam_reg_branch_list.append(fam_reg_reshape) + + # prepare anchor + featmap_size = feat.shape[-2:] + self.featmap_sizes[i] = featmap_size + init_anchors = self.anchor_generators[i].grid_anchors( + featmap_size, self.anchor_strides[i]) + + init_anchors = bbox_utils.rect2rbox(init_anchors) + self.base_anchors[(i, featmap_size[0])] = init_anchors + + #fam_reg1 = fam_reg + #fam_reg1.stop_gradient = True + refine_anchor = bbox_utils.bbox_decode( + fam_reg.detach(), init_anchors, self.target_means, + self.target_stds) + + self.refine_anchor_list.append(refine_anchor) + + if self.align_conv_type == 'AlignConv': + align_feat = self.align_conv(feat, + refine_anchor.clone(), + self.anchor_strides[i]) + elif self.align_conv_type == 'DCN': + align_offset = self.align_conv_offset(feat) + align_feat = self.align_conv(feat, align_offset) + elif self.align_conv_type == 'GA_DCN': + 
align_offset = self.align_conv_offset(feat) + align_feat = self.align_conv(feat, align_offset) + elif self.align_conv_type == 'Conv': + align_feat = self.align_conv(feat) + + or_feat = self.or_conv(align_feat) + odm_reg_feat = or_feat + odm_cls_feat = or_feat + + odm_reg_feat = self.odm_reg_convs(odm_reg_feat) + odm_cls_feat = self.odm_cls_convs(odm_cls_feat) + + odm_cls_score = self.odm_cls(odm_cls_feat) + # [N, CLS, H, W] --> [N, H, W, CLS] + odm_cls_score = odm_cls_score.transpose([0, 2, 3, 1]) + odm_cls_score_reshape = paddle.reshape( + odm_cls_score, + [odm_cls_score.shape[0], -1, self.cls_out_channels]) + + odm_cls_branch_list.append(odm_cls_score_reshape) + + odm_bbox_pred = self.odm_reg(odm_reg_feat) + # [N, 5, H, W] --> [N, H, W, 5] + odm_bbox_pred = odm_bbox_pred.transpose([0, 2, 3, 1]) + odm_bbox_pred_reshape = paddle.reshape( + odm_bbox_pred, [odm_bbox_pred.shape[0], -1, 5]) + odm_reg_branch_list.append(odm_bbox_pred_reshape) + + self.s2anet_head_out = (fam_cls_branch_list, fam_reg_branch_list, + odm_cls_branch_list, odm_reg_branch_list) + return self.s2anet_head_out + + def get_prediction(self, nms_pre): + refine_anchors = self.refine_anchor_list + fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = self.s2anet_head_out + pred_scores, pred_bboxes = self.get_bboxes( + odm_cls_branch_list, + odm_reg_branch_list, + refine_anchors, + nms_pre, + cls_out_channels=self.cls_out_channels, + use_sigmoid_cls=self.use_sigmoid_cls) + return pred_scores, pred_bboxes + + def smooth_l1_loss(self, pred, label, delta=1.0 / 9.0): + """ + Args: + pred: pred score + label: label + delta: delta + Returns: loss + """ + assert pred.shape == label.shape and label.numel() > 0 + assert delta > 0 + diff = paddle.abs(pred - label) + loss = paddle.where(diff < delta, 0.5 * diff * diff / delta, + diff - 0.5 * delta) + return loss + + def get_fam_loss(self, fam_target, s2anet_head_out): + (labels, label_weights, bbox_targets, bbox_weights, pos_inds, + neg_inds) = fam_target + fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out + + fam_cls_losses = [] + fam_bbox_losses = [] + st_idx = 0 + featmap_sizes = [self.featmap_sizes[e] for e in self.featmap_sizes] + num_total_samples = len(pos_inds) + len( + neg_inds) if self.sampling else len(pos_inds) + num_total_samples = max(1, num_total_samples) + + for idx, feat_size in enumerate(featmap_sizes): + feat_anchor_num = feat_size[0] * feat_size[1] + + # step1: get data + feat_labels = labels[st_idx:st_idx + feat_anchor_num] + feat_label_weights = label_weights[st_idx:st_idx + feat_anchor_num] + + feat_bbox_targets = bbox_targets[st_idx:st_idx + feat_anchor_num, :] + feat_bbox_weights = bbox_weights[st_idx:st_idx + feat_anchor_num, :] + st_idx += feat_anchor_num + + # step2: calc cls loss + feat_labels = feat_labels.reshape(-1) + feat_label_weights = feat_label_weights.reshape(-1) + + fam_cls_score = fam_cls_branch_list[idx] + fam_cls_score = paddle.squeeze(fam_cls_score, axis=0) + fam_cls_score1 = fam_cls_score + + # gt_classes 0~14(data), feat_labels 0~14, sigmoid_focal_loss need class>=1 + feat_labels = paddle.to_tensor(feat_labels) + feat_labels_one_hot = paddle.nn.functional.one_hot( + feat_labels, self.cls_out_channels + 1) + feat_labels_one_hot = feat_labels_one_hot[:, 1:] + feat_labels_one_hot.stop_gradient = True + + num_total_samples = paddle.to_tensor( + num_total_samples, dtype='float32', stop_gradient=True) + + fam_cls = F.sigmoid_focal_loss( + fam_cls_score1, + 
feat_labels_one_hot, + normalizer=num_total_samples, + reduction='none') + + feat_label_weights = feat_label_weights.reshape( + feat_label_weights.shape[0], 1) + feat_label_weights = np.repeat( + feat_label_weights, self.cls_out_channels, axis=1) + feat_label_weights = paddle.to_tensor( + feat_label_weights, stop_gradient=True) + + fam_cls = fam_cls * feat_label_weights + fam_cls_total = paddle.sum(fam_cls) + fam_cls_losses.append(fam_cls_total) + + # step3: regression loss + fam_bbox_pred = fam_reg_branch_list[idx] + feat_bbox_targets = paddle.to_tensor( + feat_bbox_targets, dtype='float32', stop_gradient=True) + feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5]) + + fam_bbox_pred = fam_reg_branch_list[idx] + fam_bbox_pred = paddle.squeeze(fam_bbox_pred, axis=0) + fam_bbox_pred = paddle.reshape(fam_bbox_pred, [-1, 5]) + fam_bbox = self.smooth_l1_loss(fam_bbox_pred, feat_bbox_targets) + feat_bbox_weights = paddle.to_tensor( + feat_bbox_weights, stop_gradient=True) + fam_bbox = fam_bbox * feat_bbox_weights + fam_bbox_total = paddle.sum(fam_bbox) / num_total_samples + + fam_bbox_losses.append(fam_bbox_total) + + fam_cls_loss = paddle.add_n(fam_cls_losses) + fam_reg_loss = paddle.add_n(fam_bbox_losses) + return fam_cls_loss, fam_reg_loss + + def get_odm_loss(self, odm_target, s2anet_head_out): + (labels, label_weights, bbox_targets, bbox_weights, pos_inds, + neg_inds) = odm_target + fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out + + odm_cls_losses = [] + odm_bbox_losses = [] + st_idx = 0 + featmap_sizes = [self.featmap_sizes[e] for e in self.featmap_sizes] + num_total_samples = len(pos_inds) + len( + neg_inds) if self.sampling else len(pos_inds) + num_total_samples = max(1, num_total_samples) + for idx, feat_size in enumerate(featmap_sizes): + feat_anchor_num = feat_size[0] * feat_size[1] + + # step1: get data + feat_labels = labels[st_idx:st_idx + feat_anchor_num] + feat_label_weights = label_weights[st_idx:st_idx + feat_anchor_num] + + feat_bbox_targets = bbox_targets[st_idx:st_idx + feat_anchor_num, :] + feat_bbox_weights = bbox_weights[st_idx:st_idx + feat_anchor_num, :] + st_idx += feat_anchor_num + + # step2: calc cls loss + feat_labels = feat_labels.reshape(-1) + feat_label_weights = feat_label_weights.reshape(-1) + + odm_cls_score = odm_cls_branch_list[idx] + odm_cls_score = paddle.squeeze(odm_cls_score, axis=0) + odm_cls_score1 = odm_cls_score + + # gt_classes 0~14(data), feat_labels 0~14, sigmoid_focal_loss need class>=1 + feat_labels = paddle.to_tensor(feat_labels) + feat_labels_one_hot = paddle.nn.functional.one_hot( + feat_labels, self.cls_out_channels + 1) + feat_labels_one_hot = feat_labels_one_hot[:, 1:] + feat_labels_one_hot.stop_gradient = True + + num_total_samples = paddle.to_tensor( + num_total_samples, dtype='float32', stop_gradient=True) + odm_cls = F.sigmoid_focal_loss( + odm_cls_score1, + feat_labels_one_hot, + normalizer=num_total_samples, + reduction='none') + + feat_label_weights = feat_label_weights.reshape( + feat_label_weights.shape[0], 1) + feat_label_weights = np.repeat( + feat_label_weights, self.cls_out_channels, axis=1) + feat_label_weights = paddle.to_tensor(feat_label_weights) + feat_label_weights.stop_gradient = True + + odm_cls = odm_cls * feat_label_weights + odm_cls_total = paddle.sum(odm_cls) + odm_cls_losses.append(odm_cls_total) + + # # step3: regression loss + feat_bbox_targets = paddle.to_tensor( + feat_bbox_targets, dtype='float32') + feat_bbox_targets = 
paddle.reshape(feat_bbox_targets, [-1, 5]) + feat_bbox_targets.stop_gradient = True + + odm_bbox_pred = odm_reg_branch_list[idx] + odm_bbox_pred = paddle.squeeze(odm_bbox_pred, axis=0) + odm_bbox_pred = paddle.reshape(odm_bbox_pred, [-1, 5]) + odm_bbox = self.smooth_l1_loss(odm_bbox_pred, feat_bbox_targets) + feat_bbox_weights = paddle.to_tensor( + feat_bbox_weights, stop_gradient=True) + odm_bbox = odm_bbox * feat_bbox_weights + odm_bbox_total = paddle.sum(odm_bbox) / num_total_samples + odm_bbox_losses.append(odm_bbox_total) + + odm_cls_loss = paddle.add_n(odm_cls_losses) + odm_reg_loss = paddle.add_n(odm_bbox_losses) + return odm_cls_loss, odm_reg_loss + + def get_loss(self, inputs): + # inputs: im_id image im_shape scale_factor gt_bbox gt_class is_crowd + + # compute loss + fam_cls_loss_lst = [] + fam_reg_loss_lst = [] + odm_cls_loss_lst = [] + odm_reg_loss_lst = [] + + im_shape = inputs['im_shape'] + for im_id in range(im_shape.shape[0]): + np_im_shape = inputs['im_shape'][im_id].numpy() + np_scale_factor = inputs['scale_factor'][im_id].numpy() + # data_format: (xc, yc, w, h, theta) + gt_bboxes = inputs['gt_rbox'][im_id].numpy() + gt_labels = inputs['gt_class'][im_id].numpy() + is_crowd = inputs['is_crowd'][im_id].numpy() + gt_labels = gt_labels + 1 + + # featmap_sizes + featmap_sizes = [self.featmap_sizes[e] for e in self.featmap_sizes] + anchors_list, valid_flag_list = self.get_init_anchors(featmap_sizes, + np_im_shape) + anchors_list_all = [] + for ii, anchor in enumerate(anchors_list): + anchor = anchor.reshape(-1, 4) + anchor = bbox_utils.rect2rbox(anchor) + anchors_list_all.extend(anchor) + anchors_list_all = np.array(anchors_list_all) + + # get im_feat + fam_cls_feats_list = [e[im_id] for e in self.s2anet_head_out[0]] + fam_reg_feats_list = [e[im_id] for e in self.s2anet_head_out[1]] + odm_cls_feats_list = [e[im_id] for e in self.s2anet_head_out[2]] + odm_reg_feats_list = [e[im_id] for e in self.s2anet_head_out[3]] + im_s2anet_head_out = (fam_cls_feats_list, fam_reg_feats_list, + odm_cls_feats_list, odm_reg_feats_list) + + # FAM + im_fam_target = self.anchor_assign(anchors_list_all, gt_bboxes, + gt_labels, is_crowd) + if im_fam_target is not None: + im_fam_cls_loss, im_fam_reg_loss = self.get_fam_loss( + im_fam_target, im_s2anet_head_out) + fam_cls_loss_lst.append(im_fam_cls_loss) + fam_reg_loss_lst.append(im_fam_reg_loss) + + # ODM + refine_anchors_list, valid_flag_list = self.get_refine_anchors( + featmap_sizes, image_shape=np_im_shape) + refine_anchors_list = np.array(refine_anchors_list) + im_odm_target = self.anchor_assign(refine_anchors_list, gt_bboxes, + gt_labels, is_crowd) + + if im_odm_target is not None: + im_odm_cls_loss, im_odm_reg_loss = self.get_odm_loss( + im_odm_target, im_s2anet_head_out) + odm_cls_loss_lst.append(im_odm_cls_loss) + odm_reg_loss_lst.append(im_odm_reg_loss) + fam_cls_loss = paddle.add_n(fam_cls_loss_lst) + fam_reg_loss = paddle.add_n(fam_reg_loss_lst) + odm_cls_loss = paddle.add_n(odm_cls_loss_lst) + odm_reg_loss = paddle.add_n(odm_reg_loss_lst) + return { + 'fam_cls_loss': fam_cls_loss, + 'fam_reg_loss': fam_reg_loss, + 'odm_cls_loss': odm_cls_loss, + 'odm_reg_loss': odm_reg_loss + } + + def get_init_anchors(self, featmap_sizes, image_shape): + """Get anchors according to feature map sizes. + + Args: + featmap_sizes (list[tuple]): Multi-level feature map sizes. + image_shape (list[dict]): Image meta info. 
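+
+        Example (illustrative; assumes two FPN levels with strides 8 and 16
+        and a 1024x1024 input, where image_shape is the [h, w] pair):
+
+            featmap_sizes = [(128, 128), (64, 64)]
+            anchor_list, valid_flag_list = self.get_init_anchors(
+                featmap_sizes, [1024., 1024.])
+            # len(anchor_list) == 2; anchor_list[0].shape == (128 * 128, 1, 4)
+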
+ Returns: + tuple: anchors of each image, valid flags of each image + """ + num_levels = len(featmap_sizes) + + # since feature map sizes of all images are the same, we only compute + # anchors for one time + anchor_list = [] + for i in range(num_levels): + anchors = self.anchor_generators[i].grid_anchors( + featmap_sizes[i], self.anchor_strides[i]) + anchor_list.append(anchors) + + # for each image, we compute valid flags of multi level anchors + valid_flag_list = [] + for i in range(num_levels): + anchor_stride = self.anchor_strides[i] + feat_h, feat_w = featmap_sizes[i] + h, w = image_shape + valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h) + valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w) + flags = self.anchor_generators[i].valid_flags( + (feat_h, feat_w), (valid_feat_h, valid_feat_w)) + valid_flag_list.append(flags) + + return anchor_list, valid_flag_list + + def get_refine_anchors(self, featmap_sizes, image_shape): + num_levels = len(featmap_sizes) + + refine_anchors_list = [] + for i in range(num_levels): + refine_anchor = self.refine_anchor_list[i] + refine_anchor = paddle.squeeze(refine_anchor, axis=0) + refine_anchor = refine_anchor.numpy() + refine_anchor = np.reshape(refine_anchor, + [-1, refine_anchor.shape[-1]]) + refine_anchors_list.extend(refine_anchor) + + # for each image, we compute valid flags of multi level anchors + valid_flag_list = [] + for i in range(num_levels): + anchor_stride = self.anchor_strides[i] + feat_h, feat_w = featmap_sizes[i] + h, w = image_shape + valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h) + valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w) + flags = self.anchor_generators[i].valid_flags( + (feat_h, feat_w), (valid_feat_h, valid_feat_w)) + valid_flag_list.append(flags) + + return refine_anchors_list, valid_flag_list + + def rbox2poly_single(self, rrect, get_best_begin_point=False): + """ + rrect:[x_ctr,y_ctr,w,h,angle] + to + poly:[x0,y0,x1,y1,x2,y2,x3,y3] + """ + x_ctr, y_ctr, width, height, angle = rrect[:5] + tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2 + # rect 2x4 + rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]]) + R = np.array([[np.cos(angle), -np.sin(angle)], + [np.sin(angle), np.cos(angle)]]) + # poly + poly = R.dot(rect) + x0, x1, x2, x3 = poly[0, :4] + x_ctr + y0, y1, y2, y3 = poly[1, :4] + y_ctr + poly = np.array([x0, y0, x1, y1, x2, y2, x3, y3], dtype=np.float32) + return poly + + def get_bboxes(self, cls_score_list, bbox_pred_list, mlvl_anchors, nms_pre, + cls_out_channels, use_sigmoid_cls): + assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors) + + mlvl_bboxes = [] + mlvl_scores = [] + + idx = 0 + for cls_score, bbox_pred, anchors in zip(cls_score_list, bbox_pred_list, + mlvl_anchors): + cls_score = paddle.reshape(cls_score, [-1, cls_out_channels]) + if use_sigmoid_cls: + scores = F.sigmoid(cls_score) + else: + scores = F.softmax(cls_score, axis=-1) + + # bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 5) + bbox_pred = paddle.transpose(bbox_pred, [1, 2, 0]) + bbox_pred = paddle.reshape(bbox_pred, [-1, 5]) + anchors = paddle.reshape(anchors, [-1, 5]) + + if nms_pre > 0 and scores.shape[0] > nms_pre: + # Get maximum scores for foreground classes. 
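+                # With sigmoid-based classification every channel is a
+                # foreground class, so the max runs over all channels; with
+                # softmax, channel 0 is the background class and is skipped,
+                # and only the top `nms_pre` anchors are kept for decoding.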
+                if use_sigmoid_cls:
+                    max_scores = paddle.max(scores, axis=1)
+                else:
+                    max_scores = paddle.max(scores[:, 1:], axis=1)
+
+                topk_val, topk_inds = paddle.topk(max_scores, nms_pre)
+                anchors = paddle.gather(anchors, topk_inds)
+                bbox_pred = paddle.gather(bbox_pred, topk_inds)
+                scores = paddle.gather(scores, topk_inds)
+
+            target_means = (.0, .0, .0, .0, .0)
+            target_stds = (1.0, 1.0, 1.0, 1.0, 1.0)
+            bboxes = bbox_utils.delta2rbox(anchors, bbox_pred, target_means,
+                                           target_stds)
+
+            mlvl_bboxes.append(bboxes)
+            mlvl_scores.append(scores)
+
+            idx += 1
+
+        mlvl_bboxes = paddle.concat(mlvl_bboxes, axis=0)
+        mlvl_scores = paddle.concat(mlvl_scores)
+        if use_sigmoid_cls:
+            # Add a dummy background class to the front when using sigmoid
+            padding = paddle.zeros(
+                [mlvl_scores.shape[0], 1], dtype=mlvl_scores.dtype)
+            mlvl_scores = paddle.concat([padding, mlvl_scores], axis=1)
+
+        return mlvl_scores, mlvl_bboxes
diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py
index 00f30ff6ff4efdba85aa1e338497c937ee0527eb..d48d9aa6a8041f2d0c8221eeea21b8ad3ff550d8 100644
--- a/ppdet/modeling/post_process.py
+++ b/ppdet/modeling/post_process.py
@@ -218,3 +218,128 @@ class FCOSPostProcess(object):
                                                   centerness, scale_factor)
         bbox_pred, bbox_num, _ = self.nms(bboxes, score)
         return bbox_pred, bbox_num
+
+
+@register
+class S2ANetBBoxPostProcess(object):
+    __inject__ = ['nms']
+
+    def __init__(self, nms_pre=2000, min_bbox_size=0, nms=None):
+        super(S2ANetBBoxPostProcess, self).__init__()
+        self.nms_pre = nms_pre
+        self.min_bbox_size = min_bbox_size
+        self.nms = nms
+        self.origin_shape_list = []
+
+    def rbox2poly(self, rrect, get_best_begin_point=True):
+        """
+        rrect: [N, 5] (x_ctr, y_ctr, w, h, angle)
+        to
+        poly: [N, 8] (x0, y0, x1, y1, x2, y2, x3, y3)
+        """
+        bbox_num = rrect.shape[0]
+        x_ctr = rrect[:, 0]
+        y_ctr = rrect[:, 1]
+        width = rrect[:, 2]
+        height = rrect[:, 3]
+        angle = rrect[:, 4]
+
+        tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2
+        # rect 2x4
+        rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]])
+        R = np.array([[np.cos(angle), -np.sin(angle)],
+                      [np.sin(angle), np.cos(angle)]])
+
+        # rotate each rect by its own angle; R: [2,2,M], rect: [2,4,M]
+        poly = []
+        for i in range(R.shape[2]):
+            poly.append(R[:, :, i].dot(rect[:, :, i]))
+        # poly: [M, 2, 4]
+        poly = np.array(poly)
+        coor_x = poly[:, 0, :4] + x_ctr.reshape(bbox_num, 1)
+        coor_y = poly[:, 1, :4] + y_ctr.reshape(bbox_num, 1)
+        poly = np.stack(
+            [
+                coor_x[:, 0], coor_y[:, 0], coor_x[:, 1], coor_y[:, 1],
+                coor_x[:, 2], coor_y[:, 2], coor_x[:, 3], coor_y[:, 3]
+            ],
+            axis=1)
+        if get_best_begin_point:
+            poly_lst = [get_best_begin_point_single(e) for e in poly]
+            poly = np.array(poly_lst)
+        return poly
+
+    def get_prediction(self, pred_scores, pred_bboxes, im_shape, scale_factor):
+        """
+        pred_scores : [N, M] score
+        pred_bboxes : [N, 5] xc, yc, w, h, a
+        im_shape : [N, 2] im_shape
+        scale_factor : [N, 2] scale_factor
+        """
+        # TODO: support bs>1
+        pred_polys = self.rbox2poly(pred_bboxes.numpy(), False)
+        pred_polys = paddle.to_tensor(pred_polys)
+        pred_polys = paddle.reshape(
+            pred_polys, [1, pred_polys.shape[0], pred_polys.shape[1]])
+
+        pred_scores = paddle.to_tensor(pred_scores)
+        # pred_scores: [NA, num_classes + 1] --> [num_classes + 1, NA]
+        pred_scores = paddle.transpose(pred_scores, [1, 0])
+        pred_scores = paddle.reshape(
+            pred_scores, [1, pred_scores.shape[0], pred_scores.shape[1]])
+        pred_cls_score_bbox, bbox_num, index = self.nms(pred_polys, pred_scores)
+
+        # post process scale
+        # result [n, 10]
+        if bbox_num > 0:
+            pred_bbox,
bbox_num = self.post_process(pred_cls_score_bbox[:, 2:], + bbox_num, im_shape[0], + scale_factor[0]) + + pred_cls_score_bbox = paddle.concat( + [pred_cls_score_bbox[:, 0:2], pred_bbox], axis=1) + else: + pred_cls_score_bbox = paddle.to_tensor( + np.array( + [[-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], + dtype='float32')) + bbox_num = paddle.to_tensor(np.array([1], dtype='int32')) + return pred_cls_score_bbox, bbox_num, index + + def post_process(self, bboxes, bbox_num, im_shape, scale_factor): + """ + Rescale, clip and filter the bbox from the output of NMS to + get final prediction. + + Args: + bboxes(Tensor): bboxes [N, 8] + bbox_num(Tensor): bbox_num + im_shape(Tensor): [1 2] + scale_factor(Tensor): [1 2] + Returns: + bbox_pred(Tensor): The output is the prediction with shape [N, 8] + including labels, scores and bboxes. The size of + bboxes are corresponding to the original image. + """ + + origin_shape = paddle.floor(im_shape / scale_factor + 0.5) + + origin_h = origin_shape[0] + origin_w = origin_shape[1] + + bboxes[:, 0::2] = bboxes[:, 0::2] / scale_factor[0] + bboxes[:, 1::2] = bboxes[:, 1::2] / scale_factor[1] + + zeros = paddle.zeros_like(origin_h) + x1 = paddle.maximum(paddle.minimum(bboxes[:, 0], origin_w), zeros) + y1 = paddle.maximum(paddle.minimum(bboxes[:, 1], origin_h), zeros) + x2 = paddle.maximum(paddle.minimum(bboxes[:, 2], origin_w), zeros) + y2 = paddle.maximum(paddle.minimum(bboxes[:, 3], origin_h), zeros) + x3 = paddle.maximum(paddle.minimum(bboxes[:, 4], origin_w), zeros) + y3 = paddle.maximum(paddle.minimum(bboxes[:, 5], origin_h), zeros) + x4 = paddle.maximum(paddle.minimum(bboxes[:, 6], origin_w), zeros) + y4 = paddle.maximum(paddle.minimum(bboxes[:, 7], origin_h), zeros) + bbox = paddle.stack([x1, y1, x2, y2, x3, y3, x4, y4], axis=-1) + bboxes = (bbox, bbox_num) + return bboxes diff --git a/ppdet/modeling/proposal_generator/target_layer.py b/ppdet/modeling/proposal_generator/target_layer.py index 6ad82dad156a7dba23b6a8638d9f99077a8b4010..54aed4f85b9fd20712d27445701861c0bc0c7f19 100644 --- a/ppdet/modeling/proposal_generator/target_layer.py +++ b/ppdet/modeling/proposal_generator/target_layer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,12 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
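+
+# NOTE: RBoxAssigner below matches rotated anchors to rotated gt boxes with
+# the custom `rbox_iou` operator compiled from ppdet/ext_op, which must be
+# installed before training S2ANet.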
-
+import sys
 import paddle
-
 from ppdet.core.workspace import register, serializable
-
 from .target import rpn_anchor_target, generate_proposal_target, generate_mask_target
+from ppdet.modeling import bbox_utils
+import numpy as np
@@ -176,3 +176,170 @@ class MaskAssigner(object):
         # mask_rois, mask_rois_num, tgt_classes, tgt_masks, mask_index, tgt_weights
         return outs
+
+
+@register
+class RBoxAssigner(object):
+    """
+    Assigner of rotated boxes (rbox).
+    Args:
+        pos_iou_thr (float): IoU threshold above which an anchor is positive
+        neg_iou_thr (float): IoU threshold below which an anchor is negative
+        min_iou_thr (float): minimum IoU for an anchor to be considered at all
+        ignore_iof_thr (int): label value assigned to ignored anchors
+    """
+
+    def __init__(self,
+                 pos_iou_thr=0.5,
+                 neg_iou_thr=0.4,
+                 min_iou_thr=0.0,
+                 ignore_iof_thr=-2):
+        super(RBoxAssigner, self).__init__()
+
+        self.pos_iou_thr = pos_iou_thr
+        self.neg_iou_thr = neg_iou_thr
+        self.min_iou_thr = min_iou_thr
+        self.ignore_iof_thr = ignore_iof_thr
+
+    def anchor_valid(self, anchors):
+        """
+        Args:
+            anchors (np.ndarray): [M, 4] anchors
+        Returns:
+            anchor_inds (np.ndarray): indices of valid anchors
+        """
+        if anchors.ndim == 3:
+            anchors = anchors.reshape(-1, anchors.shape[-1])
+        assert anchors.ndim == 2
+        anchor_num = anchors.shape[0]
+        anchor_inds = np.arange(anchor_num)
+        return anchor_inds
+
+    def assign_anchor(self,
+                      anchors,
+                      gt_bboxes,
+                      gt_labels,
+                      pos_iou_thr,
+                      neg_iou_thr,
+                      min_iou_thr=0.0,
+                      ignore_iof_thr=-2):
+        """
+        Args:
+            anchors (np.ndarray): [N, 5] anchors (xc, yc, w, h, angle)
+            gt_bboxes (np.ndarray): [M, 5] gt rboxes (xc, yc, w, h, angle)
+            gt_labels (np.ndarray): [M] gt class labels
+
+        Returns:
+            anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels
+        """
+        assert anchors.shape[1] == 4 or anchors.shape[1] == 5
+        assert gt_bboxes.shape[1] == 4 or gt_bboxes.shape[1] == 5
+        anchors_xc_yc = anchors
+        gt_bboxes_xc_yc = gt_bboxes
+
+        # calc rbox iou
+        anchors_xc_yc = anchors_xc_yc.astype(np.float32)
+        gt_bboxes_xc_yc = gt_bboxes_xc_yc.astype(np.float32)
+        anchors_xc_yc = paddle.to_tensor(anchors_xc_yc, place=paddle.CPUPlace())
+        gt_bboxes_xc_yc = paddle.to_tensor(
+            gt_bboxes_xc_yc, place=paddle.CPUPlace())
+
+        try:
+            from rbox_iou_ops import rbox_iou
+        except Exception as e:
+            print('import rbox_iou_ops error, please install ext_op first:', e)
+            sys.exit(-1)
+
+        iou = rbox_iou(gt_bboxes_xc_yc, anchors_xc_yc)
+        iou = iou.numpy()
+        iou = iou.T
+
+        # every gt's anchor's index
+        gt_bbox_anchor_inds = iou.argmax(axis=0)
+        gt_bbox_anchor_iou = iou[gt_bbox_anchor_inds, np.arange(iou.shape[1])]
+        gt_bbox_anchor_iou_inds = np.where(iou == gt_bbox_anchor_iou)[0]
+
+        # every anchor's gt bbox's index
+        anchor_gt_bbox_inds = iou.argmax(axis=1)
+        anchor_gt_bbox_iou = iou[np.arange(iou.shape[0]), anchor_gt_bbox_inds]
+
+        # (1) set labels=-2 as default
+        labels = np.ones((iou.shape[0], ), dtype=np.int32) * ignore_iof_thr
+
+        # (2) assign ignore
+        labels[anchor_gt_bbox_iou < min_iou_thr] = ignore_iof_thr
+
+        # (3) assign neg_ids -1
+        assign_neg_ids1 = anchor_gt_bbox_iou >= min_iou_thr
+        assign_neg_ids2 = anchor_gt_bbox_iou < neg_iou_thr
+        assign_neg_ids = np.logical_and(assign_neg_ids1, assign_neg_ids2)
+        labels[assign_neg_ids] = -1
+
+        # (4) assign each gt's best-matching anchor as positive (label >= 0)
+        anchor_gt_bbox_iou_inds = anchor_gt_bbox_inds[gt_bbox_anchor_iou_inds]
+        labels[gt_bbox_anchor_iou_inds] = gt_labels[anchor_gt_bbox_iou_inds]
+
+        # (5) assign anchors with IoU >= pos_iou_thr as positive
+        iou_pos_iou_thr_ids = anchor_gt_bbox_iou >= pos_iou_thr
+        iou_pos_iou_thr_ids_box_inds = anchor_gt_bbox_inds[iou_pos_iou_thr_ids]
+        labels[iou_pos_iou_thr_ids] = gt_labels[iou_pos_iou_thr_ids_box_inds]
+        return anchor_gt_bbox_inds,
anchor_gt_bbox_iou, labels + + def __call__(self, anchors, gt_bboxes, gt_labels, is_crowd): + + assert anchors.ndim == 2 + assert anchors.shape[1] == 5 + assert gt_bboxes.ndim == 2 + assert gt_bboxes.shape[1] == 5 + + pos_iou_thr = self.pos_iou_thr + neg_iou_thr = self.neg_iou_thr + min_iou_thr = self.min_iou_thr + ignore_iof_thr = self.ignore_iof_thr + + anchor_num = anchors.shape[0] + anchors_inds = self.anchor_valid(anchors) + anchors = anchors[anchors_inds] + gt_bboxes = gt_bboxes + is_crowd_slice = is_crowd + not_crowd_inds = np.where(is_crowd_slice == 0) + + # Step1: match anchor and gt_bbox + anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels = self.assign_anchor( + anchors, gt_bboxes, + gt_labels.reshape(-1), pos_iou_thr, neg_iou_thr, min_iou_thr, + ignore_iof_thr) + + # Step2: sample anchor + pos_inds = np.where(labels >= 0)[0] + neg_inds = np.where(labels == -1)[0] + + # Step3: make output + anchors_num = anchors.shape[0] + bbox_targets = np.zeros_like(anchors) + bbox_weights = np.zeros_like(anchors) + pos_labels = np.ones(anchors_num, dtype=np.int32) * -1 + pos_labels_weights = np.zeros(anchors_num, dtype=np.float32) + + pos_sampled_anchors = anchors[pos_inds] + #print('ancho target pos_inds', pos_inds, len(pos_inds)) + pos_sampled_gt_boxes = gt_bboxes[anchor_gt_bbox_inds[pos_inds]] + if len(pos_inds) > 0: + pos_bbox_targets = bbox_utils.rbox2delta(pos_sampled_anchors, + pos_sampled_gt_boxes) + bbox_targets[pos_inds, :] = pos_bbox_targets + bbox_weights[pos_inds, :] = 1.0 + + pos_labels[pos_inds] = labels[pos_inds] + pos_labels_weights[pos_inds] = 1.0 + + if len(neg_inds) > 0: + pos_labels_weights[neg_inds] = 1.0 + return (pos_labels, pos_labels_weights, bbox_targets, bbox_weights, + pos_inds, neg_inds) diff --git a/ppdet/utils/visualizer.py b/ppdet/utils/visualizer.py index 5327fef1d2cc92910347ae96f014d79453f70802..ecf95954107ebad2f43f37d135c7f381c50bd7cc 100644 --- a/ppdet/utils/visualizer.py +++ b/ppdet/utils/visualizer.py @@ -20,8 +20,9 @@ from __future__ import unicode_literals import numpy as np from PIL import Image, ImageDraw import cv2 - from .colormap import colormap +from ppdet.utils.logger import setup_logger +logger = setup_logger(__name__) __all__ = ['visualize_results'] @@ -86,21 +87,32 @@ def draw_bbox(image, im_id, catid2name, bboxes, threshold): if score < threshold: continue - xmin, ymin, w, h = bbox - xmax = xmin + w - ymax = ymin + h - if catid not in catid2color: idx = np.random.randint(len(color_list)) catid2color[catid] = color_list[idx] color = tuple(catid2color[catid]) # draw bbox - draw.line( - [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), - (xmin, ymin)], - width=2, - fill=color) + if len(bbox) == 4: + # draw bbox + xmin, ymin, w, h = bbox + xmax = xmin + w + ymax = ymin + h + draw.line( + [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), + (xmin, ymin)], + width=2, + fill=color) + elif len(bbox) == 8: + x1, y1, x2, y2, x3, y3, x4, y4 = bbox + draw.line( + [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x1, y1)], + width=2, + fill=color) + xmin = min(x1, x2, x3, x4) + ymin = min(y1, y2, y3, y4) + else: + logger.error('the shape of bbox must be [M, 4] or [M, 8]!') # draw label text = "{} {:.2f}".format(catid2name[catid], score) @@ -112,6 +124,23 @@ def draw_bbox(image, im_id, catid2name, bboxes, threshold): return image +def save_result(save_path, bbox_res, catid2name, threshold): + """ + save result as txt + """ + with open(save_path, 'w') as f: + for dt in bbox_res: + catid, bbox, score = dt['category_id'], dt['bbox'], dt['score'] + if score 
< threshold:
+                continue
+            # each bbox result is written as one line
+            # for rbox: classname score x1 y1 x2 y2 x3 y3 x4 y4
+            # for bbox: classname score x1 y1 w h
+            bbox_pred = '{} {} '.format(catid2name[catid], score) + ' '.join(
+                [str(e) for e in bbox])
+            f.write(bbox_pred + '\n')
+
+
 def draw_segm(image,
               im_id,
               catid2name,
diff --git a/tools/infer.py b/tools/infer.py
index c33d7a438b832edc714a13660a50c3fda2736952..a73e14a67cb3ee9705477bba449fa15418fa8092 100755
--- a/tools/infer.py
+++ b/tools/infer.py
@@ -71,6 +71,11 @@ def parse_args():
         type=str,
         default="vdl_log_dir/image",
         help='VisualDL logging directory for image.')
+    parser.add_argument(
+        "--save_txt",
+        type=lambda x: str(x).lower() in ('true', '1', 'yes'),
+        default=False,
+        help="Whether to save the predicted results as txt files, "
+        "one file per image.")
     args = parser.parse_args()
     return args
@@ -120,7 +125,8 @@ def run(FLAGS, cfg):
     trainer.predict(
         images,
         draw_threshold=FLAGS.draw_threshold,
-        output_dir=FLAGS.output_dir)
+        output_dir=FLAGS.output_dir,
+        save_txt=FLAGS.save_txt)
 def main():