Unverified · Commit bc8b0d44 · Author: cnn · Committer: GitHub

[dev] add s2anet (#2432)

* Add S2ANet model for Oriented Object Detection.
Parent 4d566f1c
metric: COCO
num_classes: 15

TrainDataset:
  !COCODataSet
    image_dir: trainval_split/images
    anno_path: trainval_split/s2anet_trainval_paddle_coco.json
    dataset_dir: /paddle/dataset/DOTA_1024_s2anet
    data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd', 'gt_rbox']

EvalDataset:
  !COCODataSet
    image_dir: trainval_split/images
    anno_path: trainval_split/s2anet_trainval_paddle_coco.json
    dataset_dir: /paddle/dataset/DOTA_1024_s2anet/

TestDataset:
  !ImageFolder
    anno_path: trainval_split/s2anet_trainval_paddle_coco.json
    dataset_dir: /paddle/dataset/DOTA_1024_s2anet/
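# S2ANet Model
## Contents
- [Introduction](#introduction)
- [DOTA Dataset](#dota-dataset)
- [Model Zoo](#model-zoo)
- [Training Notes](#training-notes)
## Introduction
[S2ANet](https://arxiv.org/pdf/2008.09397.pdf) is a model for detecting rotated boxes. It requires PaddlePaddle 2.0.1 (installable via pip) or a suitable [develop version](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/install/Tables.html#whl-release).
## DOTA Dataset
[DOTA Dataset] is a dataset for object detection in aerial images. It contains 2806 images, each up to about 4000 × 4000 pixels.
| Version | Classes | Images | Image size | Instances | Annotation |
|:--------:|:-------:|:---------:|:---------:| :---------:| :------------: |
| v1.0 | 15 | 2806 | 800~4000 | 118282 | OBB + HBB |
| v1.5 | 16 | 2806 | 800~4000 | 400000 | OBB + HBB |
Note: OBB annotates an arbitrary quadrilateral, with vertices listed in clockwise order. HBB annotates the bounding rectangle of the instance.
The DOTA dataset contains 2806 images in total: 1411 for training, 458 for evaluation, and the remaining 937 for testing.
To crop the images into tiles, refer to [DOTA_devkit](https://github.com/CAPTAIN-WHU/DOTA_devkit).
After cropping with `crop_size=1024, stride=824, gap=200`, the training set has 15749 images, the evaluation set 5297 images, and the test set 10833 images.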
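The relation between the cropping parameters can be sketched as follows. This is a minimal illustration, not the DOTA_devkit implementation: `tile_starts` and `tile_boxes` are hypothetical helpers, and the rule that the last window is pushed flush against the image border is an assumption. Note that `gap = crop_size - stride = 200`, so neighbouring tiles overlap by 200 pixels.

```python
def tile_starts(length, crop_size=1024, stride=824):
    """Start offsets of the sliding windows along one image axis (hypothetical helper)."""
    if length <= crop_size:
        return [0]
    starts = list(range(0, length - crop_size, stride))
    # Assumption: the final window is aligned with the image border.
    starts.append(length - crop_size)
    return starts


def tile_boxes(width, height, crop_size=1024, stride=824):
    """All (x1, y1, x2, y2) crop windows for an image of the given size."""
    return [(x, y, min(x + crop_size, width), min(y + crop_size, height))
            for y in tile_starts(height, crop_size, stride)
            for x in tile_starts(width, crop_size, stride)]


if __name__ == "__main__":
    print(tile_starts(4000))
    print(len(tile_boxes(4000, 4000)))
```

Under these assumptions a 4000 × 4000 image yields a 5 × 5 grid of tiles, while an image smaller than `crop_size` is kept as a single tile.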
## Model Zoo
### S2ANet Model
| Model | GPUs | Conv type | mAP | Download | Config |
|:-----------:|:-------:|:----------:|:--------:| :----------:| :---------: |
| S2ANet | 8 | Conv | 71.42 | [model](https://paddledet.bj.bcebos.com/models/s2anet_conv_1x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/dota/s2anet_conv_1x_dota.yml) |
**Note:** `multiclass_nms` is used here, which differs slightly from the NMS used by the original authors. The mAP is 0.15 higher than reported in the original paper (71.27 --> 71.42).
## Training Notes
### 1. Rotated-box IoU OP
The rotated-box IoU OP [ext_op](../../ppdet/ext_op) is developed following Paddle's [custom external operator](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/07_new_op/new_custom_op.html) mechanism.
To use the rotated-box IoU OP, the environment must satisfy:
- PaddlePaddle >= 2.0.1
- GCC == 8.2
The docker image [paddle:2.0.1-gpu-cuda10.1-cudnn7](registry.baidubce.com/paddlepaddle/paddle:2.0.1-gpu-cuda10.1-cudnn7) is recommended.
Run the following command to pull the image and start a container:
```
sudo nvidia-docker run -it --name paddle_s2anet -v $PWD:/paddle --network=host registry.baidubce.com/paddlepaddle/paddle:2.0.1-gpu-cuda10.1-cudnn7 /bin/bash
```
Inside the container, install the required Python packages:
```
python3.7 -m pip install Cython wheel tqdm opencv-python==4.2.0.32 scipy PyYAML shapely pycocotools
```
The image ships with Paddle 2.0.1 preinstalled. Start python3.7 and run the following code to check that Paddle is installed correctly:
```
import paddle
print(paddle.__version__)
paddle.utils.run_check()
```
Go to the `ext_op` folder and install:
```
python3.7 setup.py install
```
After installation, test that the custom OP compiles and computes correct results:
```
cd PaddleDetection/ppdet/ext_op
python3.7 test.py
```
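### 2. Data Format
Instances in the DOTA dataset are annotated as arbitrary quadrilaterals. Before training, convert them to the `[xc, yc, box_w, box_h, angle]` format following [DOTA2COCO](https://github.com/CAPTAIN-WHU/DOTA_devkit/blob/master/DOTA2COCO.py), and store them in COCO data format.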
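The conversion from a four-point annotation to `[xc, yc, box_w, box_h, angle]` can be sketched as below. This is a minimal sketch that mirrors the `poly_to_rbox` logic of the `Poly2Rbox` operator added in this commit, not the DOTA2COCO script itself. It assumes the four points already form a rectangle, takes the longer edge as the width, and wraps the angle (in radians) into `[-pi/4, 3*pi/4)`.

```python
import math


def poly_to_rbox(poly):
    """Convert [x1, y1, x2, y2, x3, y3, x4, y4] to [xc, yc, box_w, box_h, angle]."""
    x1, y1, x2, y2, x3, y3, x4, y4 = [float(v) for v in poly[:8]]
    edge1 = math.hypot(x1 - x2, y1 - y2)  # length of edge pt1 -> pt2
    edge2 = math.hypot(x2 - x3, y2 - y3)  # length of edge pt2 -> pt3
    box_w, box_h = max(edge1, edge2), min(edge1, edge2)
    # The angle is measured along the longer edge.
    if edge1 > edge2:
        angle = math.atan2(y2 - y1, x2 - x1)
    else:
        angle = math.atan2(y4 - y1, x4 - x1)
    # Wrap the angle into [-pi/4, 3*pi/4).
    angle = (angle + math.pi / 4) % math.pi - math.pi / 4
    # The center is the midpoint of the diagonal pt1 -> pt3.
    xc, yc = (x1 + x3) / 2.0, (y1 + y3) / 2.0
    return [xc, yc, box_w, box_h, angle]


if __name__ == "__main__":
    print(poly_to_rbox([0, 0, 4, 0, 4, 2, 0, 2]))
    print(poly_to_rbox([0, 0, 2, 0, 2, 4, 0, 4]))
```

An axis-aligned box that is wider than tall gets angle 0, and the same box stored with its short edge first gets angle pi/2 with width and height swapped.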
## Evaluation
Run the following command to save each image's predictions under the `output_dir` folder, in a txt file with the same base name as the image.
```
python3.7 tools/infer.py -c configs/dota/s2anet_1x_dota.yml -o weights=./weights/s2anet_1x_dota.pdparams --infer_dir=dota_test_images --draw_threshold=0.05 --save_txt=True --output_dir=output
```
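Refer to [DOTA_devkit](https://github.com/CAPTAIN-WHU/DOTA_devkit) to generate the evaluation files; for the file format, see [DOTA Test](http://captain.whu.edu.cn/DOTAweb/tasks.html). Generate a zip file containing one txt file per class, where each line has the format `image_id score x1 y1 x2 y2 x3 y3 x4 y4`, and submit it to the server for evaluation.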
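Grouping detections into per-class lines of the form `image_id score x1 y1 x2 y2 x3 y3 x4 y4` could look like the sketch below. The helper name `format_dota_results`, the input tuple layout, and the number of decimals are assumptions for illustration. Writing the files and zipping them is left out.

```python
from collections import defaultdict


def format_dota_results(detections):
    """Group detections by class (hypothetical helper).

    detections: iterable of (image_id, class_name, score, poly8) tuples,
    where poly8 is [x1, y1, x2, y2, x3, y3, x4, y4].
    Returns {class_name: [line, ...]} with one text line per detection.
    """
    lines_by_class = defaultdict(list)
    for image_id, class_name, score, poly in detections:
        coords = " ".join("{:.1f}".format(v) for v in poly)
        lines_by_class[class_name].append(
            "{} {:.3f} {}".format(image_id, score, coords))
    return dict(lines_by_class)


if __name__ == "__main__":
    dets = [("P0001", "plane", 0.987, [1, 2, 3, 4, 5, 6, 7, 8])]
    for name, lines in format_dota_results(dets).items():
        print(name, lines)
```

Each dictionary entry then corresponds to the content of one per-class txt file.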
## Inference and Deployment
The `multiclass_nms` operator in Paddle accepts quadrilateral inputs, so deployment does not depend on the rotated-box IoU operator.
```bash
# Inference
CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/dota/s2anet_1x_dota.yml -o weights=model.pdparams --infer_img=demo/P0072__1.0__0___0.png --use_gpu=True
```
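Turning a rotated box into the four-point form that a quadrilateral-based NMS consumes can be sketched as follows. This is a minimal sketch following the rotation used by the `Rbox2Poly` operator in this commit. The function name `rbox_to_poly` is hypothetical, the angle is assumed to be in radians, and the vertex order expected by a given operator is not verified here.

```python
import math


def rbox_to_poly(xc, yc, w, h, angle):
    """Return [x1, y1, x2, y2, x3, y3, x4, y4] for a rotated box (angle in radians)."""
    cos_a, sin_a = math.cos(angle), math.sin(angle)
    # Corners of the unrotated box relative to its center.
    corners = [(-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2), (-w / 2, h / 2)]
    poly = []
    for dx, dy in corners:
        # Rotate the corner offset, then translate by the center.
        poly.append(xc + cos_a * dx - sin_a * dy)
        poly.append(yc + sin_a * dx + cos_a * dy)
    return poly


if __name__ == "__main__":
    print(rbox_to_poly(2.0, 1.0, 4.0, 2.0, 0.0))
```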
## Citations
```
@article{han2021align,
author={J. {Han} and J. {Ding} and J. {Li} and G. -S. {Xia}},
journal={IEEE Transactions on Geoscience and Remote Sensing},
title={Align Deep Features for Oriented Object Detection},
year={2021},
pages={1-11},
doi={10.1109/TGRS.2021.3062048}}
@inproceedings{xia2018dota,
title={DOTA: A large-scale dataset for object detection in aerial images},
author={Xia, Gui-Song and Bai, Xiang and Ding, Jian and Zhu, Zhen and Belongie, Serge and Luo, Jiebo and Datcu, Mihai and Pelillo, Marcello and Zhang, Liangpei},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
pages={3974--3983},
year={2018}
}
```
architecture: S2ANet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
weights: output/s2anet_r50_fpn_1x_dota/model_final.pdparams

# Model Architecture
S2ANet:
  backbone: ResNet
  neck: FPN
  s2anet_head: S2ANetHead
  s2anet_bbox_post_process: S2ANetBBoxPostProcess

ResNet:
  depth: 50
  norm_type: bn
  return_idx: [1,2,3]
  num_stages: 4

FPN:
  in_channels: [256, 512, 1024]
  out_channel: 256
  spatial_scales: [0.25, 0.125, 0.0625]
  has_extra_convs: True
  extra_stage: 2
  relu_before_extra_convs: False

S2ANetHead:
  anchor_strides: [8, 16, 32, 64, 128]
  anchor_scales: [4]
  anchor_ratios: [1.0]
  anchor_assign: RBoxAssigner
  stacked_convs: 2
  feat_in: 256
  feat_out: 256
  num_classes: 15
  align_conv_type: 'AlignConv'  # one of: AlignConv, Conv
  align_conv_size: 3
  use_sigmoid_cls: True

RBoxAssigner:
  pos_iou_thr: 0.5
  neg_iou_thr: 0.4
  min_iou_thr: 0.0
  ignore_iof_thr: -2

S2ANetBBoxPostProcess:
  nms_pre: 2000
  min_bbox_size: 0.0
  nms:
    name: MultiClassNMS
    keep_top_k: -1
    score_threshold: 0.05
    nms_threshold: 0.1
epoch: 12

LearningRate:
  base_lr: 0.01
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [7, 10]
  - !LinearWarmup
    start_factor: 0.3333333333333333
    steps: 500

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0001
    type: L2
  clip_grad_by_norm: 35
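The learning-rate schedule described by this optimizer config can be sketched numerically as below. This is an illustration only: `steps_per_epoch` is a hypothetical parameter (it depends on dataset size, batch size, and GPU count), and the way warmup and step decay are combined is an assumption, not a statement about the framework's scheduler code.

```python
def learning_rate(step, steps_per_epoch, base_lr=0.01, gamma=0.1,
                  milestones=(7, 10), warmup_steps=500, start_factor=1.0 / 3):
    """Learning rate at a global step under the config values above (sketch)."""
    epoch = step // steps_per_epoch
    # Piecewise decay: multiply by gamma once for every milestone already reached.
    lr = base_lr * gamma ** sum(epoch >= m for m in milestones)
    # Linear warmup: scale from start_factor up to 1 over the first warmup_steps steps.
    if step < warmup_steps:
        alpha = step / float(warmup_steps)
        lr *= start_factor * (1 - alpha) + alpha
    return lr


if __name__ == "__main__":
    for s in (0, 250, 500, 7000, 10000):
        print(s, learning_rate(s, steps_per_epoch=1000))
```

With these values the rate starts at one third of `base_lr`, reaches 0.01 after 500 steps, and drops by a factor of 10 at epochs 7 and 10.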
worker_num: 0

TrainReader:
  sample_transforms:
  - Decode: {}
  - Rbox2Poly: {}
  # Resize can process rbox
  - Resize: {target_size: [1024, 1024], interp: 2, keep_ratio: False}
  - RandomFlip: {prob: 0.5}
  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
  - Permute: {}
  batch_transforms:
  - RboxPadBatch: {pad_to_stride: 32, pad_gt: true}
  batch_size: 1
  shuffle: true
  drop_last: true

EvalReader:
  sample_transforms:
  - Decode: {}
  - Resize: {interp: 2, target_size: [1024, 1024], keep_ratio: True}
  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
  - Permute: {}
  batch_transforms:
  - RboxPadBatch: {pad_to_stride: 32, pad_gt: false}
  batch_size: 1
  shuffle: false
  drop_last: false
  drop_empty: false

TestReader:
  sample_transforms:
  - Decode: {}
  - Resize: {interp: 2, target_size: [1024, 1024], keep_ratio: True}
  - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
  - Permute: {}
  batch_transforms:
  - RboxPadBatch: {pad_to_stride: 32, pad_gt: false}
  batch_size: 1
  shuffle: false
  drop_last: false
_BASE_: [
  '../datasets/dota.yml',
  '../runtime.yml',
  '_base_/s2anet_optimizer_1x.yml',
  '_base_/s2anet.yml',
  '_base_/s2anet_reader.yml',
]
weights: output/s2anet_1x_dota/model_final

_BASE_: [
  '../datasets/dota_debug.yml',
  '../runtime.yml',
  '_base_/s2anet_optimizer_1x.yml',
  '_base_/s2anet.yml',
  '_base_/s2anet_reader.yml',
]
weights: output/s2anet_1x_dota/model_final

S2ANetHead:
  anchor_strides: [ 8, 16, 32, 64, 128 ]
  anchor_scales: [ 4 ]
  anchor_ratios: [ 1.0 ]
  anchor_assign: RBoxAssigner
  stacked_convs: 2
  feat_in: 256
  feat_out: 256
  num_classes: 15
  align_conv_type: 'Conv'  # one of: AlignConv, Conv
  align_conv_size: 3
  use_sigmoid_cls: True
......@@ -52,6 +52,7 @@ python tools/infer.py -c configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.yml --i
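| --draw_threshold | infer | Score threshold for visualization | 0.5 | e.g. `--draw_threshold=0.7` |
| --infer_dir | infer | Path of the image folder to run inference on | None | At least one of `--infer_img` and `--infer_dir` must be set |
| --infer_img | infer | Path of the image to run inference on | None | At least one of `--infer_img` and `--infer_dir` must be set; `infer_img` takes priority |
| --save_txt | infer | Whether to save each image's predictions to a text file in the output folder | False | Optional |
## Usage examples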
......
......@@ -102,6 +102,16 @@ class COCODataSet(DetDataset):
else:
if not any(np.array(inst['bbox'])):
continue
# read rbox anno or not
is_rbox_anno = len(inst['bbox']) == 5
if is_rbox_anno:
xc, yc, box_w, box_h, angle = inst['bbox']
x1 = xc - box_w / 2.0
y1 = yc - box_h / 2.0
x2 = x1 + box_w
y2 = y1 + box_h
else:
x1, y1, box_w, box_h = inst['bbox']
x2 = x1 + box_w
y2 = y1 + box_h
......@@ -110,6 +120,8 @@ class COCODataSet(DetDataset):
inst['clean_bbox'] = [
round(float(x), 3) for x in [x1, y1, x2, y2]
]
if is_rbox_anno:
inst['clean_rbox'] = [xc, yc, box_w, box_h, angle]
bboxes.append(inst)
else:
logger.warning(
......@@ -122,6 +134,9 @@ class COCODataSet(DetDataset):
continue
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
if is_rbox_anno:
gt_rbox = np.zeros((num_bbox, 5), dtype=np.float32)
gt_theta = np.zeros((num_bbox, 1), dtype=np.int32)
gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
difficult = np.zeros((num_bbox, 1), dtype=np.int32)
......@@ -132,6 +147,9 @@ class COCODataSet(DetDataset):
catid = box['category_id']
gt_class[i][0] = self.catid2clsid[catid]
gt_bbox[i, :] = box['clean_bbox']
# xc, yc, w, h, theta
if is_rbox_anno:
gt_rbox[i, :] = box['clean_rbox']
is_crowd[i][0] = box['iscrowd']
# check RLE format
if 'segmentation' in box and box['iscrowd'] == 1:
......@@ -150,12 +168,22 @@ class COCODataSet(DetDataset):
'w': im_w,
} if 'image' in self.data_fields else {}
if is_rbox_anno:
gt_rec = {
'is_crowd': is_crowd,
'gt_class': gt_class,
'gt_bbox': gt_bbox,
'gt_rbox': gt_rbox,
'gt_poly': gt_poly,
}
else:
gt_rec = {
'is_crowd': is_crowd,
'gt_class': gt_class,
'gt_bbox': gt_bbox,
'gt_poly': gt_poly,
}
for k, v in gt_rec.items():
if k in self.data_fields:
coco_rec[k] = v
......
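The branch added to the COCO loader above can be summarized in a small standalone sketch. A 5-element `bbox` is read as `[xc, yc, box_w, box_h, angle]`, while a 4-element `bbox` keeps the usual COCO `[x1, y1, box_w, box_h]` meaning. The helper name `bbox_to_xyxy` is hypothetical. Note that the horizontal box ignores the angle, so for a rotated box it is the unrotated extent around the center, not the enclosing rectangle of the rotated shape.

```python
def bbox_to_xyxy(bbox):
    """Horizontal [x1, y1, x2, y2] box from a COCO bbox or a 5-value rbox annotation."""
    if len(bbox) == 5:
        # Rotated annotation: center-based, the angle is not used here.
        xc, yc, box_w, box_h, _angle = bbox
        x1 = xc - box_w / 2.0
        y1 = yc - box_h / 2.0
    else:
        # Plain COCO annotation: top-left corner plus size.
        x1, y1, box_w, box_h = bbox
    return [x1, y1, x1 + box_w, y1 + box_h]


if __name__ == "__main__":
    print(bbox_to_xyxy([50, 40, 20, 10, 0.3]))
    print(bbox_to_xyxy([40, 35, 20, 10]))
```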
......@@ -31,12 +31,8 @@ from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = [
'PadBatch',
'BatchRandomResize',
'Gt2YoloTarget',
'Gt2FCOSTarget',
'Gt2TTFTarget',
'Gt2Solov2Target',
'PadBatch', 'BatchRandomResize', 'Gt2YoloTarget', 'Gt2FCOSTarget',
'Gt2TTFTarget', 'Gt2Solov2Target', 'RboxPadBatch'
]
......@@ -739,3 +735,155 @@ class Gt2Solov2Target(BaseOperator):
data['grid_order{}'.format(idx)] = gt_grid_order
return samples
@register_op
class RboxPadBatch(BaseOperator):
    """
    Pad a batch of samples so they can be divisible by a stride.
    The layout of each image should be 'CHW'. And convert poly to rbox.
    Args:
        pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure
            height and width is divisible by `pad_to_stride`.
        pad_gt (bool): If True, pad the ground-truth fields to the largest
            box count in the batch and rebuild `gt_rbox` from `gt_rbox2poly`.
    """

    def __init__(self, pad_to_stride=0, pad_gt=False):
        super(RboxPadBatch, self).__init__()
        self.pad_to_stride = pad_to_stride
        self.pad_gt = pad_gt

    def poly_to_rbox(self, polys):
        """
        poly:[x0,y0,x1,y1,x2,y2,x3,y3]
        to
        rotated_boxes:[x_ctr,y_ctr,w,h,angle]
        """
        rotated_boxes = []
        for poly in polys:
            poly = np.array(poly[:8], dtype=np.float32)
            pt1 = (poly[0], poly[1])
            pt2 = (poly[2], poly[3])
            pt3 = (poly[4], poly[5])
            pt4 = (poly[6], poly[7])
            edge1 = np.sqrt((pt1[0] - pt2[0]) * (pt1[0] - pt2[0]) +
                            (pt1[1] - pt2[1]) * (pt1[1] - pt2[1]))
            edge2 = np.sqrt((pt2[0] - pt3[0]) * (pt2[0] - pt3[0]) +
                            (pt2[1] - pt3[1]) * (pt2[1] - pt3[1]))
            width = max(edge1, edge2)
            height = min(edge1, edge2)
            angle = 0
            # `float` replaces the removed `np.float` alias
            if edge1 > edge2:
                angle = np.arctan2(
                    float(pt2[1] - pt1[1]), float(pt2[0] - pt1[0]))
            elif edge2 >= edge1:
                angle = np.arctan2(
                    float(pt4[1] - pt1[1]), float(pt4[0] - pt1[0]))

            def norm_angle(angle, range=[-np.pi / 4, np.pi]):
                return (angle - range[0]) % range[1] + range[0]

            angle = norm_angle(angle)
            x_ctr = float(pt1[0] + pt3[0]) / 2.0
            y_ctr = float(pt1[1] + pt3[1]) / 2.0
            rotated_box = np.array([x_ctr, y_ctr, width, height, angle])
            rotated_boxes.append(rotated_box)
        ret_rotated_boxes = np.array(rotated_boxes)
        assert ret_rotated_boxes.shape[1] == 5
        return ret_rotated_boxes

    def __call__(self, samples, context=None):
        """
        Args:
            samples (list): a batch of sample, each is dict.
        """
        coarsest_stride = self.pad_to_stride
        max_shape = np.array([data['image'].shape for data in samples]).max(
            axis=0)
        if coarsest_stride > 0:
            max_shape[1] = int(
                np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
            max_shape[2] = int(
                np.ceil(max_shape[2] / coarsest_stride) * coarsest_stride)

        for data in samples:
            im = data['image']
            im_c, im_h, im_w = im.shape[:]
            padding_im = np.zeros(
                (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
            padding_im[:, :im_h, :im_w] = im
            data['image'] = padding_im
            if 'semantic' in data and data['semantic'] is not None:
                semantic = data['semantic']
                padding_sem = np.zeros(
                    (1, max_shape[1], max_shape[2]), dtype=np.float32)
                padding_sem[:, :im_h, :im_w] = semantic
                data['semantic'] = padding_sem
            if 'gt_segm' in data and data['gt_segm'] is not None:
                gt_segm = data['gt_segm']
                padding_segm = np.zeros(
                    (gt_segm.shape[0], max_shape[1], max_shape[2]),
                    dtype=np.uint8)
                padding_segm[:, :im_h, :im_w] = gt_segm
                data['gt_segm'] = padding_segm

        if self.pad_gt:
            gt_num = []
            if 'gt_poly' in data and data['gt_poly'] is not None and len(data[
                    'gt_poly']) > 0:
                pad_mask = True
            else:
                pad_mask = False

            if pad_mask:
                poly_num = []
                poly_part_num = []
                point_num = []
            for data in samples:
                gt_num.append(data['gt_bbox'].shape[0])
                if pad_mask:
                    poly_num.append(len(data['gt_poly']))
                    for poly in data['gt_poly']:
                        poly_part_num.append(int(len(poly)))
                        for p_p in poly:
                            point_num.append(int(len(p_p) / 2))
            gt_num_max = max(gt_num)

            for i, sample in enumerate(samples):
                assert 'gt_rbox' in sample
                assert 'gt_rbox2poly' in sample
                gt_box_data = -np.ones([gt_num_max, 4], dtype=np.float32)
                gt_class_data = -np.ones([gt_num_max], dtype=np.int32)
                is_crowd_data = np.ones([gt_num_max], dtype=np.int32)

                if pad_mask:
                    poly_num_max = max(poly_num)
                    poly_part_num_max = max(poly_part_num)
                    point_num_max = max(point_num)
                    gt_masks_data = -np.ones(
                        [poly_num_max, poly_part_num_max, point_num_max, 2],
                        dtype=np.float32)

                gt_num = sample['gt_bbox'].shape[0]
                gt_box_data[0:gt_num, :] = sample['gt_bbox']
                gt_class_data[0:gt_num] = np.squeeze(sample['gt_class'])
                is_crowd_data[0:gt_num] = np.squeeze(sample['is_crowd'])
                if pad_mask:
                    for j, poly in enumerate(sample['gt_poly']):
                        for k, p_p in enumerate(poly):
                            pp_np = np.array(p_p).reshape(-1, 2)
                            gt_masks_data[j, k, :pp_np.shape[0], :] = pp_np
                    sample['gt_poly'] = gt_masks_data
                sample['gt_bbox'] = gt_box_data
                sample['gt_class'] = gt_class_data
                sample['is_crowd'] = is_crowd_data
                # poly to rbox
                polys = sample['gt_rbox2poly']
                rbox = self.poly_to_rbox(polys)
                sample['gt_rbox'] = rbox

        return samples
......@@ -536,6 +536,17 @@ class RandomFlip(BaseOperator):
bbox[:, 2] = width - oldx1
return bbox
def apply_rbox(self, bbox, width):
# bbox holds 4 vertices per row: [x1, y1, x2, y2, x3, y3, x4, y4];
# a horizontal flip mirrors the x coordinate of every vertex
oldx1 = bbox[:, 0].copy()
oldx2 = bbox[:, 2].copy()
oldx3 = bbox[:, 4].copy()
oldx4 = bbox[:, 6].copy()
bbox[:, 0] = width - oldx1
bbox[:, 2] = width - oldx2
bbox[:, 4] = width - oldx3
bbox[:, 6] = width - oldx4
return bbox
def apply(self, sample, context=None):
"""Flip the image and bounding box.
Operators:
......@@ -567,6 +578,10 @@ class RandomFlip(BaseOperator):
if 'gt_segm' in sample and sample['gt_segm'].any():
sample['gt_segm'] = sample['gt_segm'][:, :, ::-1]
if 'gt_rbox2poly' in sample and sample['gt_rbox2poly'].any():
# use the 8-coordinate flip; apply_bbox only handles [x1, y1, x2, y2]
sample['gt_rbox2poly'] = self.apply_rbox(sample['gt_rbox2poly'],
width)
sample['flipped'] = True
sample['image'] = im
return sample
......@@ -704,6 +719,16 @@ class Resize(BaseOperator):
[im_scale_x, im_scale_y],
[resize_w, resize_h])
# apply rbox
if 'gt_rbox2poly' in sample:
if np.array(sample['gt_rbox2poly']).shape[1] != 8:
logger.warning(
"gt_rbox2poly's length should be 8, but actually is {}".
format(np.array(sample['gt_rbox2poly']).shape[1]))
sample['gt_rbox2poly'] = self.apply_bbox(sample['gt_rbox2poly'],
[im_scale_x, im_scale_y],
[resize_w, resize_h])
# apply polygon
if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_shape[:2],
......@@ -1933,3 +1958,113 @@ class Poly2Mask(BaseOperator):
]
sample['gt_segm'] = np.asarray(masks).astype(np.uint8)
return sample
@register_op
class Rbox2Poly(BaseOperator):
    """
    Convert rbox format to poly format.
    """

    def __init__(self):
        super(Rbox2Poly, self).__init__()

    def apply(self, sample, context=None):
        assert 'gt_rbox' in sample
        assert sample['gt_rbox'].shape[1] == 5
        rrect = sample['gt_rbox']
        bbox_num = rrect.shape[0]
        x_ctr = rrect[:, 0]
        y_ctr = rrect[:, 1]
        width = rrect[:, 2]
        height = rrect[:, 3]
        angle = rrect[:, 4]
        tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2
        # rect: [2, 4, M], corner offsets relative to the box center
        rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]])
        # R: [2, 2, M], one rotation matrix per box
        R = np.array([[np.cos(angle), -np.sin(angle)],
                      [np.sin(angle), np.cos(angle)]])
        poly = []
        for i in range(R.shape[2]):
            tmp_r = R[:, :, i].reshape(2, 2)
            poly.append(tmp_r.dot(rect[:, :, i]))
        # poly:[M, 2, 4]
        poly = np.array(poly)
        coor_x = poly[:, 0, :4] + x_ctr.reshape(bbox_num, 1)
        coor_y = poly[:, 1, :4] + y_ctr.reshape(bbox_num, 1)
        poly = np.stack(
            [
                coor_x[:, 0], coor_y[:, 0], coor_x[:, 1], coor_y[:, 1],
                coor_x[:, 2], coor_y[:, 2], coor_x[:, 3], coor_y[:, 3]
            ],
            axis=1)
        x1 = x_ctr - width / 2.0
        y1 = y_ctr - height / 2.0
        x2 = x_ctr + width / 2.0
        y2 = y_ctr + height / 2.0
        sample['gt_bbox'] = np.stack([x1, y1, x2, y2], axis=1)
        sample['gt_rbox2poly'] = poly
        return sample
@register_op
class Poly2Rbox(BaseOperator):
    """
    Convert poly format to rbox format.
    """

    def __init__(self):
        super(Poly2Rbox, self).__init__()

    def poly_to_rbox(self, polys):
        """
        poly:[x0,y0,x1,y1,x2,y2,x3,y3]
        to
        rotated_boxes:[x_ctr,y_ctr,w,h,angle]
        """
        rotated_boxes = []
        for poly in polys:
            poly = np.array(poly[:8], dtype=np.float32)
            pt1 = (poly[0], poly[1])
            pt2 = (poly[2], poly[3])
            pt3 = (poly[4], poly[5])
            pt4 = (poly[6], poly[7])
            edge1 = np.sqrt((pt1[0] - pt2[0]) * (pt1[0] - pt2[0]) +
                            (pt1[1] - pt2[1]) * (pt1[1] - pt2[1]))
            edge2 = np.sqrt((pt2[0] - pt3[0]) * (pt2[0] - pt3[0]) +
                            (pt2[1] - pt3[1]) * (pt2[1] - pt3[1]))
            width = max(edge1, edge2)
            height = min(edge1, edge2)
            angle = 0
            # `float` replaces the removed `np.float` alias
            if edge1 > edge2:
                angle = np.arctan2(
                    float(pt2[1] - pt1[1]), float(pt2[0] - pt1[0]))
            elif edge2 >= edge1:
                angle = np.arctan2(
                    float(pt4[1] - pt1[1]), float(pt4[0] - pt1[0]))

            def norm_angle(angle, range=[-np.pi / 4, np.pi]):
                return (angle - range[0]) % range[1] + range[0]

            angle = norm_angle(angle)
            x_ctr = float(pt1[0] + pt3[0]) / 2
            y_ctr = float(pt1[1] + pt3[1]) / 2
            rotated_box = np.array([x_ctr, y_ctr, width, height, angle])
            rotated_boxes.append(rotated_box)
        ret_rotated_boxes = np.array(rotated_boxes)
        assert ret_rotated_boxes.shape[1] == 5
        return ret_rotated_boxes

    def apply(self, sample, context=None):
        assert 'gt_rbox2poly' in sample
        poly = sample['gt_rbox2poly']
        rbox = self.poly_to_rbox(poly)
        sample['gt_rbox'] = rbox
        return sample
......@@ -31,6 +31,7 @@ TRT_MIN_SUBGRAPH = {
'SSD': 60,
'RCNN': 40,
'RetinaNet': 40,
'S2ANet': 40,
'EfficientDet': 40,
'Face': 3,
'TTFNet': 3,
......
......@@ -31,7 +31,7 @@ from paddle.static import InputSpec
from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.utils.visualizer import visualize_results
from ppdet.utils.visualizer import visualize_results, save_result
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results
from ppdet.data.source.category import get_categories
import ppdet.utils.stats as stats
......@@ -333,7 +333,11 @@ class Trainer(object):
def evaluate(self):
self._eval_with_loader(self.loader)
def predict(self, images, draw_threshold=0.5, output_dir='output'):
def predict(self,
images,
draw_threshold=0.5,
output_dir='output',
save_txt=False):
self.dataset.set_images(images)
loader = create('TestReader')(self.dataset, 0)
......@@ -369,6 +373,7 @@ class Trainer(object):
if 'mask' in batch_res else None
segm_res = batch_res['segm'][start:end] \
if 'segm' in batch_res else None
image = visualize_results(image, bbox_res, mask_res, segm_res,
int(outs['im_id']), catid2name,
draw_threshold)
......@@ -380,6 +385,9 @@ class Trainer(object):
logger.info("Detection bbox results save in {}".format(
save_name))
image.save(save_name, quality=95)
if save_txt:
save_path = os.path.splitext(save_name)[0] + '.txt'
save_result(save_path, bbox_res, catid2name, draw_threshold)
start = end
def _get_save_image_name(self, output_dir, image_path):
......
# Building the Custom OP
The rotated-box IoU OP is implemented following the [custom external operator](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/07_new_op/new_custom_op.html) guide.
## 1. Requirements
- Paddle >= 2.0.1
- gcc 8.2
## 2. Installation
```
python3.7 setup.py install
```
Use it as follows:
```
import numpy as np
import paddle
# import the custom op
from rbox_iou_ops import rbox_iou
paddle.set_device('gpu:0')
paddle.disable_static()
rbox1 = np.random.rand(13000, 5)
rbox2 = np.random.rand(7, 5)
pd_rbox1 = paddle.to_tensor(rbox1)
pd_rbox2 = paddle.to_tensor(rbox2)
iou = rbox_iou(pd_rbox1, pd_rbox2)
print('iou', iou)
```
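## 3. Unit Test
The unit test in `test.py` compares the result of a Python implementation against the result of the custom OP.
Because the Python and C++ computations differ slightly in detail, the error tolerance is set to 0.02.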
```
python3.7 test.py
```
The message `rbox_iou OP compute right!` indicates that the OP passed the test.
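A pure-Python reference for rotated-box IoU, of the kind such a comparison needs, could look like the sketch below. It is not the code in `test.py`: it clips one rectangle against the other with Sutherland-Hodgman polygon clipping instead of the intersection-point and convex-hull approach of the C++ kernel, and all function names are hypothetical. Boxes are `[xc, yc, w, h, angle]` with the angle in radians.

```python
import math


def rbox_to_points(box):
    """Four vertices of a rotated box, in a consistent winding order."""
    xc, yc, w, h, a = box
    c, s = math.cos(a), math.sin(a)
    return [(xc + c * dx - s * dy, yc + s * dx + c * dy)
            for dx, dy in ((-w / 2, -h / 2), (w / 2, -h / 2),
                           (w / 2, h / 2), (-w / 2, h / 2))]


def polygon_area(pts):
    """Shoelace formula, returned as an absolute value."""
    total = 0.0
    for i in range(len(pts)):
        x1, y1 = pts[i]
        x2, y2 = pts[(i + 1) % len(pts)]
        total += x1 * y2 - x2 * y1
    return abs(total) / 2.0


def clip_convex(subject, clipper):
    """Sutherland-Hodgman: clip `subject` against the convex polygon `clipper`."""
    output = list(subject)
    for i in range(len(clipper)):
        if not output:
            break
        ax, ay = clipper[i]
        bx, by = clipper[(i + 1) % len(clipper)]

        def side(p):
            # >= 0 means p lies on the inner side of the edge a -> b
            return (bx - ax) * (p[1] - ay) - (by - ay) * (p[0] - ax)

        inputs, output = output, []
        for j in range(len(inputs)):
            cur, prev = inputs[j], inputs[j - 1]
            s_cur, s_prev = side(cur), side(prev)
            if (s_cur >= 0) != (s_prev >= 0):
                # The segment prev -> cur crosses the edge: add the crossing point.
                t = s_prev / (s_prev - s_cur)
                output.append((prev[0] + t * (cur[0] - prev[0]),
                               prev[1] + t * (cur[1] - prev[1])))
            if s_cur >= 0:
                output.append(cur)
    return output


def rbox_iou_reference(box1, box2):
    area1, area2 = box1[2] * box1[3], box2[2] * box2[3]
    if area1 < 1e-14 or area2 < 1e-14:
        return 0.0
    pts = clip_convex(rbox_to_points(box1), rbox_to_points(box2))
    inter = polygon_area(pts) if len(pts) > 2 else 0.0
    return inter / (area1 + area2 - inter)


if __name__ == "__main__":
    print(rbox_iou_reference([0, 0, 4, 2, 0], [2, 0, 4, 2, 0]))
```

The values returned by such a reference can be compared element-wise with the OP output under a tolerance, for example the 0.02 mentioned above.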
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/extension.h"
#include <vector>
std::vector<paddle::Tensor> RboxIouCPUForward(const paddle::Tensor& rbox1, const paddle::Tensor& rbox2);
std::vector<paddle::Tensor> RboxIouCUDAForward(const paddle::Tensor& rbox1, const paddle::Tensor& rbox2);
#define CHECK_INPUT_SAME(x1, x2) PD_CHECK(x1.place() == x2.place(), "inputs must be on the same place.")
std::vector<paddle::Tensor> RboxIouForward(const paddle::Tensor& rbox1, const paddle::Tensor& rbox2) {
CHECK_INPUT_SAME(rbox1, rbox2);
if (rbox1.place() == paddle::PlaceType::kCPU) {
return RboxIouCPUForward(rbox1, rbox2);
}
else if (rbox1.place() == paddle::PlaceType::kGPU) {
return RboxIouCUDAForward(rbox1, rbox2);
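}
// Neither CPU nor GPU: fail explicitly instead of falling off the end without a return value.
PD_THROW("rbox_iou only supports CPU and GPU tensors.");
}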
std::vector<std::vector<int64_t>> InferShape(std::vector<int64_t> rbox1_shape, std::vector<int64_t> rbox2_shape) {
return {{rbox1_shape[0], rbox2_shape[0]}};
}
std::vector<paddle::DataType> InferDtype(paddle::DataType t1, paddle::DataType t2) {
return {t1};
}
PD_BUILD_OP(rbox_iou)
.Inputs({"RBOX1", "RBOX2"})
.Outputs({"Output"})
.SetKernelFn(PD_KERNEL(RboxIouForward))
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDtype));
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cassert>
#include <cmath>
#ifdef __CUDACC__
// Designates functions callable from the host (CPU) and the device (GPU)
#define HOST_DEVICE __host__ __device__
#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
#else
#include <algorithm>
#define HOST_DEVICE
#define HOST_DEVICE_INLINE HOST_DEVICE inline
#endif
#include "paddle/extension.h"
#include <vector>
namespace {
template <typename T>
struct RotatedBox {
T x_ctr, y_ctr, w, h, a;
};
template <typename T>
struct Point {
T x, y;
HOST_DEVICE_INLINE Point(const T& px = 0, const T& py = 0) : x(px), y(py) {}
HOST_DEVICE_INLINE Point operator+(const Point& p) const {
return Point(x + p.x, y + p.y);
}
HOST_DEVICE_INLINE Point& operator+=(const Point& p) {
x += p.x;
y += p.y;
return *this;
}
HOST_DEVICE_INLINE Point operator-(const Point& p) const {
return Point(x - p.x, y - p.y);
}
HOST_DEVICE_INLINE Point operator*(const T coeff) const {
return Point(x * coeff, y * coeff);
}
};
template <typename T>
HOST_DEVICE_INLINE T dot_2d(const Point<T>& A, const Point<T>& B) {
return A.x * B.x + A.y * B.y;
}
template <typename T>
HOST_DEVICE_INLINE T cross_2d(const Point<T>& A, const Point<T>& B) {
return A.x * B.y - B.x * A.y;
}
template <typename T>
HOST_DEVICE_INLINE void get_rotated_vertices(
const RotatedBox<T>& box,
Point<T> (&pts)[4]) {
// M_PI / 180. == 0.01745329251
//double theta = box.a * 0.01745329251;
//MODIFIED
double theta = box.a;
T cosTheta2 = (T)cos(theta) * 0.5f;
T sinTheta2 = (T)sin(theta) * 0.5f;
// y: top --> down; x: left --> right
pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w;
pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w;
pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
pts[2].x = 2 * box.x_ctr - pts[0].x;
pts[2].y = 2 * box.y_ctr - pts[0].y;
pts[3].x = 2 * box.x_ctr - pts[1].x;
pts[3].y = 2 * box.y_ctr - pts[1].y;
}
template <typename T>
HOST_DEVICE_INLINE int get_intersection_points(
const Point<T> (&pts1)[4],
const Point<T> (&pts2)[4],
Point<T> (&intersections)[24]) {
// Line vector
// A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
Point<T> vec1[4], vec2[4];
for (int i = 0; i < 4; i++) {
vec1[i] = pts1[(i + 1) % 4] - pts1[i];
vec2[i] = pts2[(i + 1) % 4] - pts2[i];
}
// Line test - test all line combos for intersection
int num = 0; // number of intersections
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
// Solve for 2x2 Ax=b
T det = cross_2d<T>(vec2[j], vec1[i]);
// This takes care of parallel lines
if (fabs(det) <= 1e-14) {
continue;
}
auto vec12 = pts2[j] - pts1[i];
T t1 = cross_2d<T>(vec2[j], vec12) / det;
T t2 = cross_2d<T>(vec1[i], vec12) / det;
if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
intersections[num++] = pts1[i] + vec1[i] * t1;
}
}
}
// Check for vertices of rect1 inside rect2
{
const auto& AB = vec2[0];
const auto& DA = vec2[3];
auto ABdotAB = dot_2d<T>(AB, AB);
auto ADdotAD = dot_2d<T>(DA, DA);
for (int i = 0; i < 4; i++) {
// assume ABCD is the rectangle, and P is the point to be judged
// P is inside ABCD iff. P's projection on AB lies within AB
// and P's projection on AD lies within AD
auto AP = pts1[i] - pts2[0];
auto APdotAB = dot_2d<T>(AP, AB);
auto APdotAD = -dot_2d<T>(AP, DA);
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
(APdotAD <= ADdotAD)) {
intersections[num++] = pts1[i];
}
}
}
// Reverse the check - check for vertices of rect2 inside rect1
{
const auto& AB = vec1[0];
const auto& DA = vec1[3];
auto ABdotAB = dot_2d<T>(AB, AB);
auto ADdotAD = dot_2d<T>(DA, DA);
for (int i = 0; i < 4; i++) {
auto AP = pts2[i] - pts1[0];
auto APdotAB = dot_2d<T>(AP, AB);
auto APdotAD = -dot_2d<T>(AP, DA);
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
(APdotAD <= ADdotAD)) {
intersections[num++] = pts2[i];
}
}
}
return num;
}
template <typename T>
HOST_DEVICE_INLINE int convex_hull_graham(
const Point<T> (&p)[24],
const int& num_in,
Point<T> (&q)[24],
bool shift_to_zero = false) {
assert(num_in >= 2);
// Step 1:
// Find point with minimum y
// if more than 1 points have the same minimum y,
// pick the one with the minimum x.
int t = 0;
for (int i = 1; i < num_in; i++) {
if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
t = i;
}
}
auto& start = p[t]; // starting point
// Step 2:
// Subtract starting point from every points (for sorting in the next step)
for (int i = 0; i < num_in; i++) {
q[i] = p[i] - start;
}
// Swap the starting point to position 0
auto tmp = q[0];
q[0] = q[t];
q[t] = tmp;
// Step 3:
// Sort point 1 ~ num_in according to their relative cross-product values
// (essentially sorting according to angles)
// If the angles are the same, sort according to their distance to origin
T dist[24];
for (int i = 0; i < num_in; i++) {
dist[i] = dot_2d<T>(q[i], q[i]);
}
#ifdef __CUDACC__
// CUDA version
// In the future, we can potentially use thrust
// for sorting here to improve speed (though not guaranteed)
for (int i = 1; i < num_in - 1; i++) {
for (int j = i + 1; j < num_in; j++) {
T crossProduct = cross_2d<T>(q[i], q[j]);
if ((crossProduct < -1e-6) ||
(fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) {
auto q_tmp = q[i];
q[i] = q[j];
q[j] = q_tmp;
auto dist_tmp = dist[i];
dist[i] = dist[j];
dist[j] = dist_tmp;
}
}
}
#else
// CPU version
std::sort(
q + 1, q + num_in, [](const Point<T>& A, const Point<T>& B) -> bool {
T temp = cross_2d<T>(A, B);
if (fabs(temp) < 1e-6) {
return dot_2d<T>(A, A) < dot_2d<T>(B, B);
} else {
return temp > 0;
}
});
#endif
// Step 4:
// Make sure there are at least 2 points (that don't overlap with each other)
// in the stack
int k; // index of the non-overlapped second point
for (k = 1; k < num_in; k++) {
if (dist[k] > 1e-8) {
break;
}
}
if (k == num_in) {
// We reach the end, which means the convex hull is just one point
q[0] = p[t];
return 1;
}
q[1] = q[k];
int m = 2; // 2 points in the stack
// Step 5:
// Finally we can start the scanning process.
// When a non-convex relationship between the 3 points is found
// (either concave shape or duplicated points),
// we pop the previous point from the stack
// until the 3-point relationship is convex again, or
// until the stack only contains two points
for (int i = k + 1; i < num_in; i++) {
while (m > 1 && cross_2d<T>(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) {
m--;
}
q[m++] = q[i];
}
// Step 6 (Optional):
// In general sense we need the original coordinates, so we
// need to shift the points back (reverting Step 2)
// But if we're only interested in getting the area/perimeter of the shape
// We can simply return.
if (!shift_to_zero) {
for (int i = 0; i < m; i++) {
q[i] += start;
}
}
return m;
}
template <typename T>
HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) {
if (m <= 2) {
return 0;
}
T area = 0;
for (int i = 1; i < m - 1; i++) {
area += fabs(cross_2d<T>(q[i] - q[0], q[i + 1] - q[0]));
}
return area / 2.0;
}
template <typename T>
HOST_DEVICE_INLINE T rboxes_intersection(
const RotatedBox<T>& box1,
const RotatedBox<T>& box2) {
// There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
// from get_intersection_points
Point<T> intersectPts[24], orderedPts[24];
Point<T> pts1[4];
Point<T> pts2[4];
get_rotated_vertices<T>(box1, pts1);
get_rotated_vertices<T>(box2, pts2);
int num = get_intersection_points<T>(pts1, pts2, intersectPts);
if (num <= 2) {
return 0.0;
}
// Convex Hull to order the intersection points in clockwise order and find
// the contour area.
int num_convex = convex_hull_graham<T>(intersectPts, num, orderedPts, true);
return polygon_area<T>(orderedPts, num_convex);
}
} // namespace
template <typename T>
HOST_DEVICE_INLINE T
rbox_iou_single(T const* const box1_raw, T const* const box2_raw) {
// shift center to the middle point to achieve higher precision in result
RotatedBox<T> box1, box2;
auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0;
auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0;
box1.x_ctr = box1_raw[0] - center_shift_x;
box1.y_ctr = box1_raw[1] - center_shift_y;
box1.w = box1_raw[2];
box1.h = box1_raw[3];
box1.a = box1_raw[4];
box2.x_ctr = box2_raw[0] - center_shift_x;
box2.y_ctr = box2_raw[1] - center_shift_y;
box2.w = box2_raw[2];
box2.h = box2_raw[3];
box2.a = box2_raw[4];
const T area1 = box1.w * box1.h;
const T area2 = box2.w * box2.h;
if (area1 < 1e-14 || area2 < 1e-14) {
return 0.f;
}
const T intersection = rboxes_intersection<T>(box1, box2);
const T iou = intersection / (area1 + area2 - intersection);
return iou;
}
// 2D block with 32 * 16 = 512 threads per block
const int BLOCK_DIM_X = 32;
const int BLOCK_DIM_Y = 16;
/**
Computes ceil(a / b)
*/
template <typename T>
__host__ __device__ __forceinline__ T CeilDiv0(T a, T b) {
return (a + b - 1) / b;
}
static inline int CeilDiv(const int a, const int b) {
return (a + b -1) / b;
}
template <typename T>
__global__ void rbox_iou_cuda_kernel(
const int rbox1_num,
const int rbox2_num,
const T* rbox1_data_ptr,
const T* rbox2_data_ptr,
T* output_data_ptr) {
// get row_start and col_start
const int rbox1_block_idx = blockIdx.x * blockDim.x;
const int rbox2_block_idx = blockIdx.y * blockDim.y;
const int rbox1_thread_num = min(rbox1_num - rbox1_block_idx, blockDim.x);
const int rbox2_thread_num = min(rbox2_num - rbox2_block_idx, blockDim.y);
__shared__ T block_boxes1[BLOCK_DIM_X * 5];
__shared__ T block_boxes2[BLOCK_DIM_Y * 5];
// It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y
if (threadIdx.x < rbox1_thread_num && threadIdx.y == 0) {
block_boxes1[threadIdx.x * 5 + 0] =
rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 0];
block_boxes1[threadIdx.x * 5 + 1] =
rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 1];
block_boxes1[threadIdx.x * 5 + 2] =
rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 2];
block_boxes1[threadIdx.x * 5 + 3] =
rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 3];
block_boxes1[threadIdx.x * 5 + 4] =
rbox1_data_ptr[(rbox1_block_idx + threadIdx.x) * 5 + 4];
}
// rbox2_thread_num <= BLOCK_DIM_Y <= BLOCK_DIM_X, so the first row of threads (threadIdx.y == 0) can load the second box set as well
if (threadIdx.x < rbox2_thread_num && threadIdx.y == 0) {
block_boxes2[threadIdx.x * 5 + 0] =
rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 0];
block_boxes2[threadIdx.x * 5 + 1] =
rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 1];
block_boxes2[threadIdx.x * 5 + 2] =
rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 2];
block_boxes2[threadIdx.x * 5 + 3] =
rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 3];
block_boxes2[threadIdx.x * 5 + 4] =
rbox2_data_ptr[(rbox2_block_idx + threadIdx.x) * 5 + 4];
}
// sync
__syncthreads();
if (threadIdx.x < rbox1_thread_num && threadIdx.y < rbox2_thread_num) {
int offset = (rbox1_block_idx + threadIdx.x) * rbox2_num + rbox2_block_idx + threadIdx.y;
output_data_ptr[offset] = rbox_iou_single<T>(block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5);
}
}
#define CHECK_INPUT_GPU(x) PD_CHECK(x.place() == paddle::PlaceType::kGPU, #x " must be a GPU Tensor.")
std::vector<paddle::Tensor> RboxIouCUDAForward(const paddle::Tensor& rbox1, const paddle::Tensor& rbox2) {
CHECK_INPUT_GPU(rbox1);
CHECK_INPUT_GPU(rbox2);
auto rbox1_num = rbox1.shape()[0];
auto rbox2_num = rbox2.shape()[0];
auto output = paddle::Tensor(paddle::PlaceType::kGPU);
output.reshape({rbox1_num, rbox2_num});
const int blocks_x = CeilDiv(rbox1_num, BLOCK_DIM_X);
const int blocks_y = CeilDiv(rbox2_num, BLOCK_DIM_Y);
dim3 blocks(blocks_x, blocks_y);
dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y);
PD_DISPATCH_FLOATING_TYPES(
rbox1.type(),
"rbox_iou_cuda_kernel",
([&] {
rbox_iou_cuda_kernel<data_t><<<blocks, threads, 0, rbox1.stream()>>>(
rbox1_num,
rbox2_num,
rbox1.data<data_t>(),
rbox2.data<data_t>(),
output.mutable_data<data_t>());
}));
return {output};
}
template <typename T>
void rbox_iou_cpu_kernel(
const int rbox1_num,
const int rbox2_num,
const T* rbox1_data_ptr,
const T* rbox2_data_ptr,
T* output_data_ptr) {
int i, j;
for (i = 0; i < rbox1_num; i++) {
for (j = 0; j < rbox2_num; j++) {
int offset = i * rbox2_num + j;
output_data_ptr[offset] = rbox_iou_single<T>(rbox1_data_ptr + i * 5, rbox2_data_ptr + j * 5);
}
}
}
#define CHECK_INPUT_CPU(x) PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")
std::vector<paddle::Tensor> RboxIouCPUForward(const paddle::Tensor& rbox1, const paddle::Tensor& rbox2) {
CHECK_INPUT_CPU(rbox1);
CHECK_INPUT_CPU(rbox2);
auto rbox1_num = rbox1.shape()[0];
auto rbox2_num = rbox2.shape()[0];
auto output = paddle::Tensor(paddle::PlaceType::kCPU);
output.reshape({rbox1_num, rbox2_num});
PD_DISPATCH_FLOATING_TYPES(
rbox1.type(),
"rbox_iou_cpu_kernel",
([&] {
rbox_iou_cpu_kernel<data_t>(
rbox1_num,
rbox2_num,
rbox1.data<data_t>(),
rbox2.data<data_t>(),
output.mutable_data<data_t>());
}));
return {output};
}
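For readers who want to sanity-check the kernels without building the extension, the sketch below computes rotated-box IoU in plain NumPy. It is not a port of the code above: instead of collecting intersection points and ordering them with a Graham scan, it swaps in Sutherland-Hodgman clipping, which works here because both boxes are convex. The function names (`rbox_to_corners`, `clip_polygon`, `rbox_iou_ref`) are made up for this illustration; the box layout `(x_ctr, y_ctr, w, h, angle)` with the angle in radians is taken from `rbox_iou_single`.

```python
import numpy as np


def rbox_to_corners(rbox):
    """(x_ctr, y_ctr, w, h, angle) -> (4, 2) corners in counter-clockwise order."""
    x_ctr, y_ctr, w, h, angle = rbox
    local = np.array([[-w / 2, -h / 2], [w / 2, -h / 2],
                      [w / 2, h / 2], [-w / 2, h / 2]], dtype=np.float64)
    c, s = np.cos(angle), np.sin(angle)
    rot = np.array([[c, -s], [s, c]])
    return local @ rot.T + np.array([x_ctr, y_ctr], dtype=np.float64)


def cross2(u, v):
    """z-component of the 2D cross product."""
    return u[0] * v[1] - u[1] * v[0]


def polygon_area(pts):
    """Shoelace formula; returns 0 for degenerate polygons."""
    if len(pts) < 3:
        return 0.0
    x, y = pts[:, 0], pts[:, 1]
    return 0.5 * abs(np.dot(x, np.roll(y, -1)) - np.dot(y, np.roll(x, -1)))


def clip_polygon(subject, clip):
    """Sutherland-Hodgman: clip `subject` against a convex, counter-clockwise `clip`."""
    output = [np.asarray(p, dtype=np.float64) for p in subject]
    n = len(clip)
    for i in range(n):
        a, b = clip[i], clip[(i + 1) % n]
        edge = b - a
        inputs, output = output, []
        if not inputs:
            break
        prev = inputs[-1]
        prev_side = cross2(edge, prev - a)
        for cur in inputs:
            cur_side = cross2(edge, cur - a)
            # the segment prev -> cur crosses the clip edge: keep the crossing point
            if (cur_side >= 0) != (prev_side >= 0):
                t = prev_side / (prev_side - cur_side)
                output.append(prev + t * (cur - prev))
            # points on the inner side of the edge are kept
            if cur_side >= 0:
                output.append(cur)
            prev, prev_side = cur, cur_side
    return np.array(output).reshape(-1, 2)


def rbox_iou_ref(box1, box2):
    """IoU of two rotated boxes; mirrors the tiny-area early exit of rbox_iou_single."""
    area1 = box1[2] * box1[3]
    area2 = box2[2] * box2[3]
    if area1 < 1e-14 or area2 < 1e-14:
        return 0.0
    inter = polygon_area(
        clip_polygon(rbox_to_corners(box1), rbox_to_corners(box2)))
    return inter / (area1 + area2 - inter)


print(rbox_iou_ref([0, 0, 2, 2, 0], [1, 0, 2, 2, 0]))  # two half-overlapping squares
```

A pure-Python loop like this is far too slow for 13000 x 7 box pairs in training, but it may be handy as an independent check on a handful of boxes.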
from paddle.utils.cpp_extension import CppExtension, CUDAExtension, setup
if __name__ == "__main__":
setup(
name='rbox_iou_ops',
ext_modules=CUDAExtension(sources=['rbox_iou_op.cc', 'rbox_iou_op.cu']))
import numpy as np
import os
import sys
import cv2
import time
import shapely
from shapely.geometry import Polygon
import paddle
paddle.set_device('gpu:0')
paddle.disable_static()
try:
from rbox_iou_ops import rbox_iou
except Exception as e:
print('import custom_ops error', e)
sys.exit(-1)
# generate random data
rbox1 = np.random.rand(13000, 5)
rbox2 = np.random.rand(7, 5)
# x_ctr, y_ctr, w, h in [0.001, 0.451)
rbox1[:, 0:4] = rbox1[:, 0:4] * 0.45 + 0.001
rbox2[:, 0:4] = rbox2[:, 0:4] * 0.45 + 0.001
# angle in [-0.5, 0.5) rad
rbox1[:, 4] = rbox1[:, 4] - 0.5
rbox2[:, 4] = rbox2[:, 4] - 0.5
print('rbox1', rbox1.shape, 'rbox2', rbox2.shape)
# to paddle tensor
pd_rbox1 = paddle.to_tensor(rbox1)
pd_rbox2 = paddle.to_tensor(rbox2)
start_time = time.time()
iou = rbox_iou(pd_rbox1, pd_rbox2)
print('paddle time:', time.time() - start_time)
print('iou is', iou.cpu().shape)
# get gt
def rbox2poly_single(rrect, get_best_begin_point=False):
"""
rrect:[x_ctr,y_ctr,w,h,angle]
to
poly:[x0,y0,x1,y1,x2,y2,x3,y3]
"""
x_ctr, y_ctr, width, height, angle = rrect[:5]
tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2
# rect 2x4
rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]])
R = np.array([[np.cos(angle), -np.sin(angle)],
[np.sin(angle), np.cos(angle)]])
# poly
poly = R.dot(rect)
x0, x1, x2, x3 = poly[0, :4] + x_ctr
y0, y1, y2, y3 = poly[1, :4] + y_ctr
poly = np.array([x0, y0, x1, y1, x2, y2, x3, y3], dtype=np.float32)
return poly
def intersection(g, p):
"""
Intersection.
"""
g = g[:8].reshape((4, 2))
p = p[:8].reshape((4, 2))
a = g
b = p
use_filter = True
if use_filter:
# step1:
inter_x1 = np.maximum(np.min(a[:, 0]), np.min(b[:, 0]))
inter_x2 = np.minimum(np.max(a[:, 0]), np.max(b[:, 0]))
inter_y1 = np.maximum(np.min(a[:, 1]), np.min(b[:, 1]))
inter_y2 = np.minimum(np.max(a[:, 1]), np.max(b[:, 1]))
if inter_x1 >= inter_x2 or inter_y1 >= inter_y2:
return 0.
x1 = np.minimum(np.min(a[:, 0]), np.min(b[:, 0]))
x2 = np.maximum(np.max(a[:, 0]), np.max(b[:, 0]))
y1 = np.minimum(np.min(a[:, 1]), np.min(b[:, 1]))
y2 = np.maximum(np.max(a[:, 1]), np.max(b[:, 1]))
if x1 >= x2 or y1 >= y2 or (x2 - x1) < 2 or (y2 - y1) < 2:
return 0.
g = Polygon(g)
p = Polygon(p)
#g = g.buffer(0)
#p = p.buffer(0)
if not g.is_valid or not p.is_valid:
return 0
inter = Polygon(g).intersection(Polygon(p)).area
union = g.area + p.area - inter
if union == 0:
return 0
else:
return inter / union
# rbox_iou by python
def rbox_overlaps(anchors, gt_bboxes, use_cv2=False):
"""
Args:
anchors: [NA, 5] x_ctr,y_ctr,w,h,angle
gt_bboxes: [M, 5] x_ctr,y_ctr,w,h,angle
Returns:
iou: [NA, M] pairwise IoU matrix
"""
assert anchors.shape[1] == 5
assert gt_bboxes.shape[1] == 5
gt_bboxes_ploy = [rbox2poly_single(e) for e in gt_bboxes]
anchors_ploy = [rbox2poly_single(e) for e in anchors]
num_gt, num_anchors = len(gt_bboxes_ploy), len(anchors_ploy)
iou = np.zeros((num_gt, num_anchors), dtype=np.float32)
start_time = time.time()
for i in range(num_gt):
for j in range(num_anchors):
try:
iou[i, j] = intersection(gt_bboxes_ploy[i], anchors_ploy[j])
except Exception as e:
print('cur gt_bboxes_ploy[i]', gt_bboxes_ploy[i],
'anchors_ploy[j]', anchors_ploy[j], e)
iou = iou.T
print('intersection all sp_time', time.time() - start_time)
return iou
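# scale the normalized coordinates to a 1024-pixel image (IoU is scale-invariant);
# note that ploy_rbox1/ploy_rbox2 alias rbox1/rbox2, so this modifies them in place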
ploy_rbox1 = rbox1
ploy_rbox2 = rbox2
ploy_rbox1[:, 0:4] = rbox1[:, 0:4] * 1024
ploy_rbox2[:, 0:4] = rbox2[:, 0:4] * 1024
start_time = time.time()
iou_py = rbox_overlaps(ploy_rbox1, ploy_rbox2, use_cv2=False)
print('rbox time', time.time() - start_time)
print(iou_py.shape)
iou_pd = iou.cpu().numpy()
sum_abs_diff = np.sum(np.abs(iou_pd - iou_py))
print('sum of abs diff', sum_abs_diff)
if sum_abs_diff < 0.02:
print("rbox_iou OP compute right!")
......@@ -21,7 +21,7 @@ import sys
import numpy as np
import itertools
from ppdet.metrics.json_results import get_det_res, get_seg_res, get_solov2_segm_res
from ppdet.metrics.json_results import get_det_res, get_det_poly_res, get_seg_res, get_solov2_segm_res
from ppdet.metrics.map_utils import draw_pr_curve
from ppdet.utils.logger import setup_logger
......@@ -45,6 +45,10 @@ def get_infer_results(outs, catid, bias=0):
infer_res = {}
if 'bbox' in outs:
if len(outs['bbox']) > 0 and len(outs['bbox'][0]) > 6:
infer_res['bbox'] = get_det_poly_res(
outs['bbox'], outs['bbox_num'], im_id, catid, bias=bias)
else:
infer_res['bbox'] = get_det_res(
outs['bbox'], outs['bbox_num'], im_id, catid, bias=bias)
......
......@@ -43,6 +43,54 @@ def get_det_res(bboxes, bbox_nums, image_id, label_to_cat_id_map, bias=0):
return det_res
def get_det_poly_res(bboxes, bbox_nums, image_id, label_to_cat_id_map, bias=0):
det_res = []
k = 0
for i in range(len(bbox_nums)):
cur_image_id = int(image_id[i][0])
det_nums = bbox_nums[i]
for j in range(det_nums):
dt = bboxes[k]
k = k + 1
num_id, score, x1, y1, x2, y2, x3, y3, x4, y4 = dt.tolist()
if int(num_id) < 0:
continue
category_id = int(num_id)
rbox = [x1, y1, x2, y2, x3, y3, x4, y4]
dt_res = {
'image_id': cur_image_id,
'category_id': category_id,
'bbox': rbox,
'score': score
}
det_res.append(dt_res)
return det_res
def get_seg_res(masks, bboxes, mask_nums, image_id, label_to_cat_id_map):
import pycocotools.mask as mask_util
seg_res = []
......
......@@ -14,6 +14,7 @@ from . import ssd
from . import fcos
from . import solov2
from . import ttfnet
from . import s2anet
from .meta_arch import *
from .faster_rcnn import *
......@@ -24,3 +25,4 @@ from .ssd import *
from .fcos import *
from .solov2 import *
from .ttfnet import *
from .s2anet import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
import numpy as np
__all__ = ['S2ANet']
@register
class S2ANet(BaseArch):
__category__ = 'architecture'
__inject__ = [
's2anet_head',
's2anet_bbox_post_process',
]
def __init__(self, backbone, neck, s2anet_head, s2anet_bbox_post_process):
"""
S2ANet, see https://arxiv.org/pdf/2008.09397.pdf
Args:
backbone (object): backbone instance
neck (object): `FPN` instance
s2anet_head (object): `S2ANetHead` instance
s2anet_bbox_post_process (object): `S2ANetBBoxPostProcess` instance
"""
super(S2ANet, self).__init__()
self.backbone = backbone
self.neck = neck
self.s2anet_head = s2anet_head
self.s2anet_bbox_post_process = s2anet_bbox_post_process
@classmethod
def from_config(cls, cfg, *args, **kwargs):
backbone = create(cfg['backbone'])
kwargs = {'input_shape': backbone.out_shape}
neck = cfg['neck'] and create(cfg['neck'], **kwargs)
out_shape = neck and neck.out_shape or backbone.out_shape
kwargs = {'input_shape': out_shape}
s2anet_head = create(cfg['s2anet_head'], **kwargs)
s2anet_bbox_post_process = create(cfg['s2anet_bbox_post_process'],
**kwargs)
return {
'backbone': backbone,
'neck': neck,
"s2anet_head": s2anet_head,
"s2anet_bbox_post_process": s2anet_bbox_post_process,
}
def _forward(self):
body_feats = self.backbone(self.inputs)
if self.neck is not None:
body_feats = self.neck(body_feats)
self.s2anet_head(body_feats)
if self.training:
loss = self.s2anet_head.get_loss(self.inputs)
total_loss = paddle.add_n(list(loss.values()))
loss.update({'loss': total_loss})
return loss
else:
im_shape = self.inputs['im_shape']
scale_factor = self.inputs['scale_factor']
nms_pre = self.s2anet_bbox_post_process.nms_pre
pred_scores, pred_bboxes = self.s2anet_head.get_prediction(nms_pre)
# post_process
pred_cls_score_bbox, bbox_num, index = self.s2anet_bbox_post_process.get_prediction(
pred_scores, pred_bboxes, im_shape, scale_factor)
# output
output = {'bbox': pred_cls_score_bbox, 'bbox_num': bbox_num}
return output
def get_loss(self):
loss = self._forward()
return loss
def get_pred(self):
output = self._forward()
return output
......@@ -16,6 +16,7 @@ import math
import paddle
import paddle.nn.functional as F
import math
import numpy as np
def bbox2delta(src_boxes, tgt_boxes, weights):
......@@ -260,3 +261,147 @@ def bbox_iou(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9):
return iou - (rho2 / c2 + v * alpha)
else:
return iou
def rect2rbox(bboxes):
"""
:param bboxes: shape (n, 4) (xmin, ymin, xmax, ymax)
:return: rboxes: shape (n, 5) (x_ctr, y_ctr, w, h, angle), with angle 0 or pi/2
"""
bboxes = bboxes.reshape(-1, 4)
num_boxes = bboxes.shape[0]
x_ctr = (bboxes[:, 2] + bboxes[:, 0]) / 2.0
y_ctr = (bboxes[:, 3] + bboxes[:, 1]) / 2.0
edges1 = np.abs(bboxes[:, 2] - bboxes[:, 0])
edges2 = np.abs(bboxes[:, 3] - bboxes[:, 1])
angles = np.zeros([num_boxes], dtype=bboxes.dtype)
inds = edges1 < edges2
rboxes = np.stack((x_ctr, y_ctr, edges1, edges2, angles), axis=1)
rboxes[inds, 2] = edges2[inds]
rboxes[inds, 3] = edges1[inds]
rboxes[inds, 4] = np.pi / 2.0
return rboxes
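A quick worked example of the conversion above, restated in NumPy so that it runs on its own. The helper name `rect2rbox_demo` exists only for this illustration; the logic follows `rect2rbox`: the longer side becomes `w`, and boxes that are taller than wide get an angle of pi/2.

```python
import numpy as np


def rect2rbox_demo(bboxes):
    """(n, 4) xmin, ymin, xmax, ymax -> (n, 5) x_ctr, y_ctr, w, h, angle."""
    bboxes = np.asarray(bboxes, dtype=np.float64).reshape(-1, 4)
    x_ctr = (bboxes[:, 0] + bboxes[:, 2]) / 2.0
    y_ctr = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
    dx = np.abs(bboxes[:, 2] - bboxes[:, 0])
    dy = np.abs(bboxes[:, 3] - bboxes[:, 1])
    tall = dx < dy
    w = np.where(tall, dy, dx)                  # longer side
    h = np.where(tall, dx, dy)                  # shorter side
    angle = np.where(tall, np.pi / 2.0, 0.0)    # tall boxes are rotated by 90 degrees
    return np.stack([x_ctr, y_ctr, w, h, angle], axis=1)


boxes = [[0, 0, 20, 10],   # wide box: stays at angle 0
         [0, 0, 10, 20]]   # tall box: sides swap, angle becomes pi/2
rboxes = rect2rbox_demo(boxes)
print(rboxes)
```

Both inputs end up with `w = 20` and `h = 10`; only the centre and the angle differ.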
def delta2rbox(Rrois,
deltas,
means=[0, 0, 0, 0, 0],
stds=[1, 1, 1, 1, 1],
wh_ratio_clip=1e-6):
"""
:param Rrois: (cx, cy, w, h, theta)
:param deltas: (dx, dy, dw, dh, dtheta)
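:param means: per-coordinate means used to denormalize the deltas
:param stds: per-coordinate stds used to denormalize the deltas
:param wh_ratio_clip: dw and dh are clipped to +/- abs(log(wh_ratio_clip))
:return: decoded boxes (cx, cy, w, h, theta); theta is wrapped with (theta + pi/4) % pi - pi/4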
"""
means = paddle.to_tensor(means)
stds = paddle.to_tensor(stds)
deltas = paddle.reshape(deltas, [-1, deltas.shape[-1]])
denorm_deltas = deltas * stds + means
dx = denorm_deltas[:, 0]
dy = denorm_deltas[:, 1]
dw = denorm_deltas[:, 2]
dh = denorm_deltas[:, 3]
dangle = denorm_deltas[:, 4]
max_ratio = np.abs(np.log(wh_ratio_clip))
dw = paddle.clip(dw, min=-max_ratio, max=max_ratio)
dh = paddle.clip(dh, min=-max_ratio, max=max_ratio)
Rroi_x = Rrois[:, 0]
Rroi_y = Rrois[:, 1]
Rroi_w = Rrois[:, 2]
Rroi_h = Rrois[:, 3]
Rroi_angle = Rrois[:, 4]
gx = dx * Rroi_w * paddle.cos(Rroi_angle) - dy * Rroi_h * paddle.sin(
Rroi_angle) + Rroi_x
gy = dx * Rroi_w * paddle.sin(Rroi_angle) + dy * Rroi_h * paddle.cos(
Rroi_angle) + Rroi_y
gw = Rroi_w * dw.exp()
gh = Rroi_h * dh.exp()
ga = np.pi * dangle + Rroi_angle
ga = (ga + np.pi / 4) % np.pi - np.pi / 4
ga = paddle.to_tensor(ga)
gw = paddle.to_tensor(gw, dtype='float32')
gh = paddle.to_tensor(gh, dtype='float32')
bboxes = paddle.stack([gx, gy, gw, gh, ga], axis=-1)
return bboxes
def rbox2delta(proposals, gt, means=[0, 0, 0, 0, 0], stds=[1, 1, 1, 1, 1]):
"""
Args:
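proposals: [..., 5] (x_ctr, y_ctr, w, h, angle)
gt: [..., 5] ground-truth boxes in the same layout
means: 1x5
stds: 1x5
Returns:
deltas: [..., 5] normalized (dx, dy, dw, dh, da) as float32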
"""
proposals = proposals.astype(np.float64)
PI = np.pi
gt_widths = gt[..., 2]
gt_heights = gt[..., 3]
gt_angle = gt[..., 4]
proposals_widths = proposals[..., 2]
proposals_heights = proposals[..., 3]
proposals_angle = proposals[..., 4]
coord = gt[..., 0:2] - proposals[..., 0:2]
dx = (np.cos(proposals[..., 4]) * coord[..., 0] + np.sin(proposals[..., 4])
* coord[..., 1]) / proposals_widths
dy = (-np.sin(proposals[..., 4]) * coord[..., 0] + np.cos(proposals[..., 4])
* coord[..., 1]) / proposals_heights
dw = np.log(gt_widths / proposals_widths)
dh = np.log(gt_heights / proposals_heights)
da = (gt_angle - proposals_angle)
da = (da + PI / 4) % PI - PI / 4
da /= PI
deltas = np.stack([dx, dy, dw, dh, da], axis=-1)
means = np.array(means, dtype=deltas.dtype)
stds = np.array(stds, dtype=deltas.dtype)
deltas = (deltas - means) / stds
deltas = deltas.astype(np.float32)
return deltas
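The encoding in `rbox2delta` and the decoding in `delta2rbox` are meant to be inverses of each other. The sketch below restates both for a single box in NumPy and checks the round trip. It assumes `means = 0` and `stds = 1`, leaves out the `wh_ratio_clip` clipping, and uses the hypothetical names `encode_rbox` and `decode_rbox`; the formulas themselves are copied from the two functions above.

```python
import numpy as np


def encode_rbox(proposal, gt):
    """Deltas (dx, dy, dw, dh, da) that move `proposal` onto `gt`."""
    px, py, pw, ph, pa = proposal
    gx, gy, gw, gh, ga = gt
    cx, cy = gx - px, gy - py
    # centre shift expressed in the proposal's rotated frame, relative to its size
    dx = (np.cos(pa) * cx + np.sin(pa) * cy) / pw
    dy = (-np.sin(pa) * cx + np.cos(pa) * cy) / ph
    dw = np.log(gw / pw)
    dh = np.log(gh / ph)
    # angle difference wrapped to a period of pi, then scaled by 1/pi
    da = ((ga - pa + np.pi / 4) % np.pi - np.pi / 4) / np.pi
    return np.array([dx, dy, dw, dh, da])


def decode_rbox(proposal, delta):
    """Apply deltas to `proposal`; inverse of encode_rbox."""
    px, py, pw, ph, pa = proposal
    dx, dy, dw, dh, da = delta
    gx = dx * pw * np.cos(pa) - dy * ph * np.sin(pa) + px
    gy = dx * pw * np.sin(pa) + dy * ph * np.cos(pa) + py
    gw = pw * np.exp(dw)
    gh = ph * np.exp(dh)
    ga = (np.pi * da + pa + np.pi / 4) % np.pi - np.pi / 4
    return np.array([gx, gy, gw, gh, ga])


proposal = np.array([10.0, 10.0, 16.0, 8.0, 0.1])
gt = np.array([12.0, 8.0, 20.0, 10.0, 0.3])
delta = encode_rbox(proposal, gt)
recovered = decode_rbox(proposal, delta)
print(delta)
print(recovered)
```

The round trip only recovers the original angle when it already lies in `[-pi/4, 3*pi/4)`, since both directions wrap angles into that interval.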
def bbox_decode(bbox_preds,
anchors,
means=[0, 0, 0, 0, 0],
stds=[1, 1, 1, 1, 1]):
"""decode bbox from deltas
Args:
bbox_preds: [N,H,W,5]
anchors: [H*W,5]
return:
bboxes: [N,H,W,5]
"""
means = paddle.to_tensor(means)
stds = paddle.to_tensor(stds)
num_imgs, H, W, _ = bbox_preds.shape
bboxes_list = []
for img_id in range(num_imgs):
bbox_pred = bbox_preds[img_id]
# bbox_pred.shape=[H,W,5]
bbox_delta = bbox_pred
anchors = paddle.to_tensor(anchors)
bboxes = delta2rbox(
anchors, bbox_delta, means, stds, wh_ratio_clip=1e-6)
bboxes = paddle.reshape(bboxes, [H, W, 5])
bboxes_list.append(bboxes)
return paddle.stack(bboxes_list, axis=0)
\ No newline at end of file
......@@ -22,6 +22,7 @@ from . import solov2_head
from . import ttf_head
from . import cascade_head
from . import face_head
from . import s2anet_head
from .bbox_head import *
from .mask_head import *
......@@ -33,3 +34,4 @@ from .solov2_head import *
from .ttf_head import *
from .cascade_head import *
from .face_head import *
from .s2anet_head import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, Constant
from ppdet.core.workspace import register
from ppdet.modeling import ops
from ppdet.modeling import bbox_utils
from ppdet.modeling.proposal_generator.target_layer import RBoxAssigner
import numpy as np
class S2ANetAnchorGenerator(object):
"""
S2ANetAnchorGenerator by np
"""
def __init__(self,
base_size=8,
scales=1.0,
ratios=1.0,
scale_major=True,
ctr=None):
self.base_size = base_size
self.scales = scales
self.ratios = ratios
self.scale_major = scale_major
self.ctr = ctr
self.base_anchors = self.gen_base_anchors()
@property
def num_base_anchors(self):
return self.base_anchors.shape[0]
def gen_base_anchors(self):
w = self.base_size
h = self.base_size
if self.ctr is None:
x_ctr = 0.5 * (w - 1)
y_ctr = 0.5 * (h - 1)
else:
x_ctr, y_ctr = self.ctr
h_ratios = np.sqrt(self.ratios)
w_ratios = 1 / h_ratios
if self.scale_major:
ws = (w * w_ratios[:] * self.scales[:]).reshape([-1])
hs = (h * h_ratios[:] * self.scales[:]).reshape([-1])
else:
ws = (w * self.scales[:] * w_ratios[:]).reshape([-1])
hs = (h * self.scales[:] * h_ratios[:]).reshape([-1])
# yapf: disable
base_anchors = np.stack(
[
x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1),
x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1)
],
axis=-1)
base_anchors = np.round(base_anchors)
# yapf: enable
return base_anchors
def _meshgrid(self, x, y, row_major=True):
xx, yy = np.meshgrid(x, y)
xx = xx.reshape(-1)
yy = yy.reshape(-1)
if row_major:
return xx, yy
else:
return yy, xx
def grid_anchors(self, featmap_size, stride=16):
# multiplying feature-map coordinates by stride projects them back onto the input image
base_anchors = self.base_anchors
feat_h, feat_w = featmap_size
shift_x = np.arange(0, feat_w, 1, 'int32') * stride
shift_y = np.arange(0, feat_h, 1, 'int32') * stride
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1)
# shifts = shifts.type_as(base_anchors)
# first feat_w elements correspond to the first row of shifts
# add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
# shifted anchors (K, A, 4), reshape to (K*A, 4)
#all_anchors = base_anchors[:, :] + shifts[:, :]
all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
# all_anchors = all_anchors.reshape([-1, 4])
# first A rows correspond to A anchors of (0, 0) in feature map,
# then (0, 1), (0, 2), ...
return all_anchors
def valid_flags(self, featmap_size, valid_size):
feat_h, feat_w = featmap_size
valid_h, valid_w = valid_size
assert valid_h <= feat_h and valid_w <= feat_w
valid_x = np.zeros([feat_w], dtype='uint8')
valid_y = np.zeros([feat_h], dtype='uint8')
valid_x[:valid_w] = 1
valid_y[:valid_h] = 1
valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
valid = valid_xx & valid_yy
valid = valid.reshape([-1])
# valid = valid[:, None].expand(
# [valid.size(0), self.num_base_anchors]).reshape([-1])
return valid
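To make the broadcasting in `grid_anchors` concrete, the following standalone sketch builds the anchors for a tiny 2 x 3 feature map. The values (`base_size = 8`, one scale of 4, one ratio of 1.0, stride 8) correspond to the first pyramid level of the head defaults further down; the variable names are local to this example.

```python
import numpy as np

base_size, scale, stride = 8, 4.0, 8

# one square base anchor centred on the first cell, as in gen_base_anchors
ctr = 0.5 * (base_size - 1)
side = base_size * scale
half = 0.5 * (side - 1)
base_anchors = np.round(
    np.array([[ctr - half, ctr - half, ctr + half, ctr + half]]))  # (A, 4), A = 1

# one shift per feature-map cell, in input-image pixels
feat_h, feat_w = 2, 3
shift_x = np.arange(feat_w) * stride
shift_y = np.arange(feat_h) * stride
xx, yy = np.meshgrid(shift_x, shift_y)
shifts = np.stack([xx.ravel(), yy.ravel(), xx.ravel(), yy.ravel()], axis=-1)  # (K, 4)

# (1, A, 4) + (K, 1, 4) broadcasts to (K, A, 4)
all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
print(all_anchors.shape)  # (6, 1, 4)
print(all_anchors[:, 0])
```

Cell `(row 0, col 1)` is index 1 along the first axis, so its anchor is the base anchor moved right by one stride.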
class AlignConv(nn.Layer):
def __init__(self, in_channels, out_channels, kernel_size=3, groups=1):
super(AlignConv, self).__init__()
self.kernel_size = kernel_size
self.align_conv = paddle.vision.ops.DeformConv2D(
in_channels,
out_channels,
kernel_size=self.kernel_size,
padding=(self.kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(initializer=Normal(0, 0.01)),
bias_attr=None)
@paddle.no_grad()
def get_offset(self, anchors, featmap_size, stride):
"""
Args:
anchors: [M,5] xc,yc,w,h,angle
featmap_size: (feat_h, feat_w)
stride: 8
Returns:
"""
anchors = paddle.reshape(anchors, [-1, 5]) # (NA,5)
dtype = anchors.dtype
feat_h, feat_w = featmap_size
pad = (self.kernel_size - 1) // 2
idx = paddle.arange(-pad, pad + 1, dtype=dtype)
yy, xx = paddle.meshgrid(idx, idx)
xx = paddle.reshape(xx, [-1])
yy = paddle.reshape(yy, [-1])
# get sampling locations of default conv
xc = paddle.arange(0, feat_w, dtype=dtype)
yc = paddle.arange(0, feat_h, dtype=dtype)
yc, xc = paddle.meshgrid(yc, xc)
xc = paddle.reshape(xc, [-1, 1])
yc = paddle.reshape(yc, [-1, 1])
x_conv = xc + xx
y_conv = yc + yy
# get sampling locations of anchors
# x_ctr, y_ctr, w, h, a = np.unbind(anchors, dim=1)
x_ctr = anchors[:, 0]
y_ctr = anchors[:, 1]
w = anchors[:, 2]
h = anchors[:, 3]
a = anchors[:, 4]
x_ctr = paddle.reshape(x_ctr, [x_ctr.shape[0], 1])
y_ctr = paddle.reshape(y_ctr, [y_ctr.shape[0], 1])
w = paddle.reshape(w, [w.shape[0], 1])
h = paddle.reshape(h, [h.shape[0], 1])
a = paddle.reshape(a, [a.shape[0], 1])
x_ctr = x_ctr / stride
y_ctr = y_ctr / stride
w_s = w / stride
h_s = h / stride
cos, sin = paddle.cos(a), paddle.sin(a)
dw, dh = w_s / self.kernel_size, h_s / self.kernel_size
x, y = dw * xx, dh * yy
xr = cos * x - sin * y
yr = sin * x + cos * y
x_anchor, y_anchor = xr + x_ctr, yr + y_ctr
# get offset field
offset_x = x_anchor - x_conv
offset_y = y_anchor - y_conv
# x, y in anchors are in the opposite order from image coordinates,
# so we stack them as (y, x) rather than (x, y)
offset = paddle.stack([offset_y, offset_x], axis=-1)
# [NA, ks*ks, 2] --> [NA, ks*ks*2]
offset = paddle.reshape(offset, [offset.shape[0], -1])
# [NA, ks*ks*2] --> [ks*ks*2, NA]
offset = paddle.transpose(offset, [1, 0])
# [ks*ks*2, NA] --> [1, ks*ks*2, H, W]
offset = paddle.reshape(offset, [1, -1, feat_h, feat_w])
return offset
def forward(self, x, refine_anchors, stride):
featmap_size = (x.shape[2], x.shape[3])
offset = self.get_offset(refine_anchors, featmap_size, stride)
x = F.relu(self.align_conv(x, offset))
return x
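The offset computation in `get_offset` can be easier to follow for a single feature-map location. The NumPy sketch below is a simplified restatement under a few assumptions: one anchor per location (which the broadcasting in `get_offset` also relies on), `paddle.meshgrid` behaving like `np.meshgrid(..., indexing='ij')`, and a helper name, `align_offsets_at`, invented for this example.

```python
import numpy as np


def align_offsets_at(anchor, loc_xy, stride, kernel_size=3):
    """Offsets (dy, dx) of the kernel taps at one feature-map location."""
    x_ctr, y_ctr, w, h, a = anchor
    pad = (kernel_size - 1) // 2
    idx = np.arange(-pad, pad + 1, dtype=np.float64)
    yy, xx = np.meshgrid(idx, idx, indexing='ij')
    xx, yy = xx.ravel(), yy.ravel()

    # regular sampling grid of a plain convolution
    x_conv = loc_xy[0] + xx
    y_conv = loc_xy[1] + yy

    # sampling grid spread over the rotated anchor, in feature-map units
    x = (w / stride / kernel_size) * xx
    y = (h / stride / kernel_size) * yy
    x_anchor = np.cos(a) * x - np.sin(a) * y + x_ctr / stride
    y_anchor = np.sin(a) * x + np.cos(a) * y + y_ctr / stride

    # deformable conv expects (y, x) order
    return np.stack([y_anchor - y_conv, x_anchor - x_conv], axis=-1)


stride = 8
loc = (5, 3)  # (x, y) index of the feature-map cell
# anchor centred on that cell, exactly kernel_size * stride wide and high, not rotated
aligned = align_offsets_at([40.0, 24.0, 24.0, 24.0, 0.0], loc, stride)
# same centre, but twice as large: the taps have to spread out
larger = align_offsets_at([40.0, 24.0, 48.0, 48.0, 0.0], loc, stride)
print(aligned)
print(larger)
```

When the anchor coincides with the receptive field of the regular 3 x 3 kernel, all offsets are zero and the aligned convolution reduces to an ordinary one; a larger or rotated anchor pushes the taps outwards or around the centre.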
@register
class S2ANetHead(nn.Layer):
"""
S2ANet head
Args:
stacked_convs (int): number of stacked_convs
feat_in (int): input channels of feat
feat_out (int): output channels of feat
num_classes (int): num_classes
anchor_strides (list): stride of anchors
anchor_scales (list): scale of anchors
anchor_ratios (list): ratios of anchors
target_means (list): target_means
target_stds (list): target_stds
align_conv_type (str): align_conv_type ['Conv', 'AlignConv']
align_conv_size (int): kernel size of align_conv
use_sigmoid_cls (bool): use sigmoid_cls or not
reg_loss_weight (list): reg loss weight
"""
__shared__ = ['num_classes']
__inject__ = ['anchor_assign']
def __init__(self,
stacked_convs=2,
feat_in=256,
feat_out=256,
num_classes=15,
anchor_strides=[8, 16, 32, 64, 128],
anchor_scales=[4],
anchor_ratios=[1.0],
target_means=(.0, .0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0, 1.0),
align_conv_type='AlignConv',
align_conv_size=3,
use_sigmoid_cls=True,
anchor_assign=RBoxAssigner().__dict__,
reg_loss_weight=[1.0, 1.0, 1.0, 1.0, 1.0]):
super(S2ANetHead, self).__init__()
self.stacked_convs = stacked_convs
self.feat_in = feat_in
self.feat_out = feat_out
self.anchor_list = None
self.anchor_scales = anchor_scales
self.anchor_ratios = anchor_ratios
self.anchor_strides = anchor_strides
self.anchor_base_sizes = list(anchor_strides)
self.target_means = target_means
self.target_stds = target_stds
assert align_conv_type in ['AlignConv', 'Conv']
self.align_conv_type = align_conv_type
self.align_conv_size = align_conv_size
self.use_sigmoid_cls = use_sigmoid_cls
self.cls_out_channels = num_classes if self.use_sigmoid_cls else 1
self.sampling = False
self.anchor_assign = anchor_assign
self.reg_loss_weight = reg_loss_weight
self.s2anet_head_out = None
# anchor
self.anchor_generators = []
for anchor_base in self.anchor_base_sizes:
self.anchor_generators.append(
S2ANetAnchorGenerator(anchor_base, anchor_scales,
anchor_ratios))
# featmap_sizes
self.featmap_sizes = []
self.base_anchors = []
self.rbox_anchors = []
self.refine_anchor_list = []
self.fam_cls_convs = nn.Sequential()
self.fam_reg_convs = nn.Sequential()
for i in range(self.stacked_convs):
chan_in = self.feat_in if i == 0 else self.feat_out
self.fam_cls_convs.add_sublayer(
'fam_cls_conv_{}'.format(i),
nn.Conv2D(
in_channels=chan_in,
out_channels=self.feat_out,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
bias_attr=ParamAttr(initializer=Constant(0))))
self.fam_cls_convs.add_sublayer('fam_cls_conv_{}_act'.format(i),
nn.ReLU())
self.fam_reg_convs.add_sublayer(
'fam_reg_conv_{}'.format(i),
nn.Conv2D(
in_channels=chan_in,
out_channels=self.feat_out,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
bias_attr=ParamAttr(initializer=Constant(0))))
self.fam_reg_convs.add_sublayer('fam_reg_conv_{}_act'.format(i),
nn.ReLU())
self.fam_reg = nn.Conv2D(
self.feat_out,
5,
1,
weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
bias_attr=ParamAttr(initializer=Constant(0)))
prior_prob = 0.01
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
self.fam_cls = nn.Conv2D(
self.feat_out,
self.cls_out_channels,
1,
weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
bias_attr=ParamAttr(initializer=Constant(bias_init)))
if self.align_conv_type == "AlignConv":
self.align_conv = AlignConv(self.feat_out, self.feat_out,
self.align_conv_size)
elif self.align_conv_type == "Conv":
self.align_conv = nn.Conv2D(
self.feat_out,
self.feat_out,
self.align_conv_size,
padding=(self.align_conv_size - 1) // 2,
bias_attr=ParamAttr(initializer=Constant(0)))
self.or_conv = nn.Conv2D(
self.feat_out,
self.feat_out,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
bias_attr=ParamAttr(initializer=Constant(0)))
# ODM
self.odm_cls_convs = nn.Sequential()
self.odm_reg_convs = nn.Sequential()
for i in range(self.stacked_convs):
ch_in = self.feat_out
# ch_in = int(self.feat_out / 8) if i == 0 else self.feat_out
self.odm_cls_convs.add_sublayer(
'odm_cls_conv_{}'.format(i),
nn.Conv2D(
in_channels=ch_in,
out_channels=self.feat_out,
kernel_size=3,
stride=1,
padding=1,
weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
bias_attr=ParamAttr(initializer=Constant(0))))
self.odm_cls_convs.add_sublayer('odm_cls_conv_{}_act'.format(i),
nn.ReLU())
self.odm_reg_convs.add_sublayer(
'odm_reg_conv_{}'.format(i),
nn.Conv2D(
in_channels=self.feat_out,
out_channels=self.feat_out,
kernel_size=3,
stride=1,
padding=1,
weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
bias_attr=ParamAttr(initializer=Constant(0))))
self.odm_reg_convs.add_sublayer('odm_reg_conv_{}_act'.format(i),
nn.ReLU())
self.odm_cls = nn.Conv2D(
self.feat_out,
self.cls_out_channels,
3,
padding=1,
weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
bias_attr=ParamAttr(initializer=Constant(bias_init)))
self.odm_reg = nn.Conv2D(
self.feat_out,
5,
3,
padding=1,
weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
bias_attr=ParamAttr(initializer=Constant(0)))
def forward(self, feats):
fam_reg_branch_list = []
fam_cls_branch_list = []
odm_reg_branch_list = []
odm_cls_branch_list = []
self.featmap_sizes = dict()
self.base_anchors = dict()
self.refine_anchor_list = []
for i, feat in enumerate(feats):
fam_cls_feat = self.fam_cls_convs(feat)
fam_cls = self.fam_cls(fam_cls_feat)
# [N, CLS, H, W] --> [N, H, W, CLS]
fam_cls = fam_cls.transpose([0, 2, 3, 1])
fam_cls_reshape = paddle.reshape(
fam_cls, [fam_cls.shape[0], -1, self.cls_out_channels])
fam_cls_branch_list.append(fam_cls_reshape)
fam_reg_feat = self.fam_reg_convs(feat)
fam_reg = self.fam_reg(fam_reg_feat)
# [N, 5, H, W] --> [N, H, W, 5]
fam_reg = fam_reg.transpose([0, 2, 3, 1])
fam_reg_reshape = paddle.reshape(fam_reg, [fam_reg.shape[0], -1, 5])
fam_reg_branch_list.append(fam_reg_reshape)
# prepare anchor
featmap_size = feat.shape[-2:]
self.featmap_sizes[i] = featmap_size
init_anchors = self.anchor_generators[i].grid_anchors(
featmap_size, self.anchor_strides[i])
init_anchors = bbox_utils.rect2rbox(init_anchors)
self.base_anchors[(i, featmap_size[0])] = init_anchors
#fam_reg1 = fam_reg
#fam_reg1.stop_gradient = True
refine_anchor = bbox_utils.bbox_decode(
fam_reg.detach(), init_anchors, self.target_means,
self.target_stds)
self.refine_anchor_list.append(refine_anchor)
if self.align_conv_type == 'AlignConv':
align_feat = self.align_conv(feat,
refine_anchor.clone(),
self.anchor_strides[i])
elif self.align_conv_type == 'DCN':
align_offset = self.align_conv_offset(feat)
align_feat = self.align_conv(feat, align_offset)
elif self.align_conv_type == 'GA_DCN':
align_offset = self.align_conv_offset(feat)
align_feat = self.align_conv(feat, align_offset)
elif self.align_conv_type == 'Conv':
align_feat = self.align_conv(feat)
or_feat = self.or_conv(align_feat)
odm_reg_feat = or_feat
odm_cls_feat = or_feat
odm_reg_feat = self.odm_reg_convs(odm_reg_feat)
odm_cls_feat = self.odm_cls_convs(odm_cls_feat)
odm_cls_score = self.odm_cls(odm_cls_feat)
# [N, CLS, H, W] --> [N, H, W, CLS]
odm_cls_score = odm_cls_score.transpose([0, 2, 3, 1])
odm_cls_score_reshape = paddle.reshape(
odm_cls_score,
[odm_cls_score.shape[0], -1, self.cls_out_channels])
odm_cls_branch_list.append(odm_cls_score_reshape)
odm_bbox_pred = self.odm_reg(odm_reg_feat)
# [N, 5, H, W] --> [N, H, W, 5]
odm_bbox_pred = odm_bbox_pred.transpose([0, 2, 3, 1])
odm_bbox_pred_reshape = paddle.reshape(
odm_bbox_pred, [odm_bbox_pred.shape[0], -1, 5])
odm_reg_branch_list.append(odm_bbox_pred_reshape)
self.s2anet_head_out = (fam_cls_branch_list, fam_reg_branch_list,
odm_cls_branch_list, odm_reg_branch_list)
return self.s2anet_head_out
def get_prediction(self, nms_pre):
refine_anchors = self.refine_anchor_list
fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = self.s2anet_head_out
pred_scores, pred_bboxes = self.get_bboxes(
odm_cls_branch_list,
odm_reg_branch_list,
refine_anchors,
nms_pre,
cls_out_channels=self.cls_out_channels,
use_sigmoid_cls=self.use_sigmoid_cls)
return pred_scores, pred_bboxes
def smooth_l1_loss(self, pred, label, delta=1.0 / 9.0):
"""
Args:
pred: pred score
label: label
delta: delta
Returns: loss
"""
assert pred.shape == label.shape and label.numel() > 0
assert delta > 0
diff = paddle.abs(pred - label)
loss = paddle.where(diff < delta, 0.5 * diff * diff / delta,
diff - 0.5 * delta)
return loss
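For reference, the piecewise definition used by `smooth_l1_loss` can be checked numerically. The NumPy restatement below (the function name `smooth_l1` and the sample errors are chosen for this example) is quadratic for errors below `delta` and linear above it, and the two branches meet at `0.5 * delta`.

```python
import numpy as np


def smooth_l1(pred, label, delta=1.0 / 9.0):
    """Element-wise smooth L1: quadratic below `delta`, linear above."""
    diff = np.abs(
        np.asarray(pred, dtype=np.float64) - np.asarray(label, dtype=np.float64))
    return np.where(diff < delta, 0.5 * diff * diff / delta, diff - 0.5 * delta)


errors = np.array([0.0, 0.05, 1.0 / 9.0, 1.0])
losses = smooth_l1(errors, np.zeros_like(errors))
print(losses)
```

With the default `delta = 1/9`, an error of 0.05 falls in the quadratic region and costs 0.01125, while an error of 1.0 costs 1 - 1/18.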
def get_fam_loss(self, fam_target, s2anet_head_out):
(labels, label_weights, bbox_targets, bbox_weights, pos_inds,
neg_inds) = fam_target
fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out
fam_cls_losses = []
fam_bbox_losses = []
st_idx = 0
featmap_sizes = [self.featmap_sizes[e] for e in self.featmap_sizes]
num_total_samples = len(pos_inds) + len(
neg_inds) if self.sampling else len(pos_inds)
num_total_samples = max(1, num_total_samples)
for idx, feat_size in enumerate(featmap_sizes):
feat_anchor_num = feat_size[0] * feat_size[1]
# step1: get data
feat_labels = labels[st_idx:st_idx + feat_anchor_num]
feat_label_weights = label_weights[st_idx:st_idx + feat_anchor_num]
feat_bbox_targets = bbox_targets[st_idx:st_idx + feat_anchor_num, :]
feat_bbox_weights = bbox_weights[st_idx:st_idx + feat_anchor_num, :]
st_idx += feat_anchor_num
# step2: calc cls loss
feat_labels = feat_labels.reshape(-1)
feat_label_weights = feat_label_weights.reshape(-1)
fam_cls_score = fam_cls_branch_list[idx]
fam_cls_score = paddle.squeeze(fam_cls_score, axis=0)
fam_cls_score1 = fam_cls_score
# gt_classes 0~14(data), feat_labels 0~14, sigmoid_focal_loss need class>=1
feat_labels = paddle.to_tensor(feat_labels)
feat_labels_one_hot = paddle.nn.functional.one_hot(
feat_labels, self.cls_out_channels + 1)
feat_labels_one_hot = feat_labels_one_hot[:, 1:]
feat_labels_one_hot.stop_gradient = True
num_total_samples = paddle.to_tensor(
num_total_samples, dtype='float32', stop_gradient=True)
fam_cls = F.sigmoid_focal_loss(
fam_cls_score1,
feat_labels_one_hot,
normalizer=num_total_samples,
reduction='none')
feat_label_weights = feat_label_weights.reshape(
feat_label_weights.shape[0], 1)
feat_label_weights = np.repeat(
feat_label_weights, self.cls_out_channels, axis=1)
feat_label_weights = paddle.to_tensor(
feat_label_weights, stop_gradient=True)
fam_cls = fam_cls * feat_label_weights
fam_cls_total = paddle.sum(fam_cls)
fam_cls_losses.append(fam_cls_total)
# step3: regression loss
feat_bbox_targets = paddle.to_tensor(
feat_bbox_targets, dtype='float32', stop_gradient=True)
feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5])
fam_bbox_pred = fam_reg_branch_list[idx]
fam_bbox_pred = paddle.squeeze(fam_bbox_pred, axis=0)
fam_bbox_pred = paddle.reshape(fam_bbox_pred, [-1, 5])
fam_bbox = self.smooth_l1_loss(fam_bbox_pred, feat_bbox_targets)
feat_bbox_weights = paddle.to_tensor(
feat_bbox_weights, stop_gradient=True)
fam_bbox = fam_bbox * feat_bbox_weights
fam_bbox_total = paddle.sum(fam_bbox) / num_total_samples
fam_bbox_losses.append(fam_bbox_total)
fam_cls_loss = paddle.add_n(fam_cls_losses)
fam_reg_loss = paddle.add_n(fam_bbox_losses)
return fam_cls_loss, fam_reg_loss
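Both loss functions walk a flat, level-concatenated target array with a running offset `st_idx`. Assuming one anchor per feature-map location (which is what `feat_size[0] * feat_size[1]` implies), the same split can be sketched with `np.cumsum` and `np.split`. The helper name `split_by_level` is made up for illustration and takes a plain list of `(h, w)` pairs rather than the dict stored on the head.

```python
import numpy as np


def split_by_level(flat, featmap_sizes):
    """Split a level-concatenated array into one chunk per feature level."""
    counts = [h * w for h, w in featmap_sizes]
    assert flat.shape[0] == sum(counts)
    # Cumulative offsets play the role of the running st_idx.
    offsets = np.cumsum(counts)[:-1]
    return np.split(flat, offsets, axis=0)
```

Each returned chunk corresponds to the slice `[st_idx:st_idx + feat_anchor_num]` taken inside the loop.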
def get_odm_loss(self, odm_target, s2anet_head_out):
(labels, label_weights, bbox_targets, bbox_weights, pos_inds,
neg_inds) = odm_target
fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list = s2anet_head_out
odm_cls_losses = []
odm_bbox_losses = []
st_idx = 0
featmap_sizes = [self.featmap_sizes[e] for e in self.featmap_sizes]
num_total_samples = len(pos_inds) + len(
neg_inds) if self.sampling else len(pos_inds)
num_total_samples = max(1, num_total_samples)
for idx, feat_size in enumerate(featmap_sizes):
feat_anchor_num = feat_size[0] * feat_size[1]
# step1: get data
feat_labels = labels[st_idx:st_idx + feat_anchor_num]
feat_label_weights = label_weights[st_idx:st_idx + feat_anchor_num]
feat_bbox_targets = bbox_targets[st_idx:st_idx + feat_anchor_num, :]
feat_bbox_weights = bbox_weights[st_idx:st_idx + feat_anchor_num, :]
st_idx += feat_anchor_num
# step2: calc cls loss
feat_labels = feat_labels.reshape(-1)
feat_label_weights = feat_label_weights.reshape(-1)
odm_cls_score = odm_cls_branch_list[idx]
odm_cls_score = paddle.squeeze(odm_cls_score, axis=0)
odm_cls_score1 = odm_cls_score
# gt_classes 0~14(data), feat_labels 0~14, sigmoid_focal_loss need class>=1
feat_labels = paddle.to_tensor(feat_labels)
feat_labels_one_hot = paddle.nn.functional.one_hot(
feat_labels, self.cls_out_channels + 1)
feat_labels_one_hot = feat_labels_one_hot[:, 1:]
feat_labels_one_hot.stop_gradient = True
num_total_samples = paddle.to_tensor(
num_total_samples, dtype='float32', stop_gradient=True)
odm_cls = F.sigmoid_focal_loss(
odm_cls_score1,
feat_labels_one_hot,
normalizer=num_total_samples,
reduction='none')
feat_label_weights = feat_label_weights.reshape(
feat_label_weights.shape[0], 1)
feat_label_weights = np.repeat(
feat_label_weights, self.cls_out_channels, axis=1)
feat_label_weights = paddle.to_tensor(feat_label_weights)
feat_label_weights.stop_gradient = True
odm_cls = odm_cls * feat_label_weights
odm_cls_total = paddle.sum(odm_cls)
odm_cls_losses.append(odm_cls_total)
# step3: regression loss
feat_bbox_targets = paddle.to_tensor(
feat_bbox_targets, dtype='float32')
feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5])
feat_bbox_targets.stop_gradient = True
odm_bbox_pred = odm_reg_branch_list[idx]
odm_bbox_pred = paddle.squeeze(odm_bbox_pred, axis=0)
odm_bbox_pred = paddle.reshape(odm_bbox_pred, [-1, 5])
odm_bbox = self.smooth_l1_loss(odm_bbox_pred, feat_bbox_targets)
feat_bbox_weights = paddle.to_tensor(
feat_bbox_weights, stop_gradient=True)
odm_bbox = odm_bbox * feat_bbox_weights
odm_bbox_total = paddle.sum(odm_bbox) / num_total_samples
odm_bbox_losses.append(odm_bbox_total)
odm_cls_loss = paddle.add_n(odm_cls_losses)
odm_reg_loss = paddle.add_n(odm_bbox_losses)
return odm_cls_loss, odm_reg_loss
def get_loss(self, inputs):
# inputs: im_id image im_shape scale_factor gt_bbox gt_rbox gt_class is_crowd
# compute loss
fam_cls_loss_lst = []
fam_reg_loss_lst = []
odm_cls_loss_lst = []
odm_reg_loss_lst = []
im_shape = inputs['im_shape']
for im_id in range(im_shape.shape[0]):
np_im_shape = inputs['im_shape'][im_id].numpy()
np_scale_factor = inputs['scale_factor'][im_id].numpy()
# data_format: (xc, yc, w, h, theta)
gt_bboxes = inputs['gt_rbox'][im_id].numpy()
gt_labels = inputs['gt_class'][im_id].numpy()
is_crowd = inputs['is_crowd'][im_id].numpy()
gt_labels = gt_labels + 1
# featmap_sizes
featmap_sizes = [self.featmap_sizes[e] for e in self.featmap_sizes]
anchors_list, valid_flag_list = self.get_init_anchors(featmap_sizes,
np_im_shape)
anchors_list_all = []
for ii, anchor in enumerate(anchors_list):
anchor = anchor.reshape(-1, 4)
anchor = bbox_utils.rect2rbox(anchor)
anchors_list_all.extend(anchor)
anchors_list_all = np.array(anchors_list_all)
# get im_feat
fam_cls_feats_list = [e[im_id] for e in self.s2anet_head_out[0]]
fam_reg_feats_list = [e[im_id] for e in self.s2anet_head_out[1]]
odm_cls_feats_list = [e[im_id] for e in self.s2anet_head_out[2]]
odm_reg_feats_list = [e[im_id] for e in self.s2anet_head_out[3]]
im_s2anet_head_out = (fam_cls_feats_list, fam_reg_feats_list,
odm_cls_feats_list, odm_reg_feats_list)
# FAM
im_fam_target = self.anchor_assign(anchors_list_all, gt_bboxes,
gt_labels, is_crowd)
if im_fam_target is not None:
im_fam_cls_loss, im_fam_reg_loss = self.get_fam_loss(
im_fam_target, im_s2anet_head_out)
fam_cls_loss_lst.append(im_fam_cls_loss)
fam_reg_loss_lst.append(im_fam_reg_loss)
# ODM
refine_anchors_list, valid_flag_list = self.get_refine_anchors(
featmap_sizes, image_shape=np_im_shape)
refine_anchors_list = np.array(refine_anchors_list)
im_odm_target = self.anchor_assign(refine_anchors_list, gt_bboxes,
gt_labels, is_crowd)
if im_odm_target is not None:
im_odm_cls_loss, im_odm_reg_loss = self.get_odm_loss(
im_odm_target, im_s2anet_head_out)
odm_cls_loss_lst.append(im_odm_cls_loss)
odm_reg_loss_lst.append(im_odm_reg_loss)
fam_cls_loss = paddle.add_n(fam_cls_loss_lst)
fam_reg_loss = paddle.add_n(fam_reg_loss_lst)
odm_cls_loss = paddle.add_n(odm_cls_loss_lst)
odm_reg_loss = paddle.add_n(odm_reg_loss_lst)
return {
'fam_cls_loss': fam_cls_loss,
'fam_reg_loss': fam_reg_loss,
'odm_cls_loss': odm_cls_loss,
'odm_reg_loss': odm_reg_loss
}
def get_init_anchors(self, featmap_sizes, image_shape):
"""Get anchors according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
image_shape (array-like): (h, w) of the input image.
Returns:
tuple: per-level anchors and per-level valid flags
"""
num_levels = len(featmap_sizes)
# since feature map sizes of all images are the same, we only compute
# anchors for one time
anchor_list = []
for i in range(num_levels):
anchors = self.anchor_generators[i].grid_anchors(
featmap_sizes[i], self.anchor_strides[i])
anchor_list.append(anchors)
# for each image, we compute valid flags of multi level anchors
valid_flag_list = []
for i in range(num_levels):
anchor_stride = self.anchor_strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w = image_shape
valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
flags = self.anchor_generators[i].valid_flags(
(feat_h, feat_w), (valid_feat_h, valid_feat_w))
valid_flag_list.append(flags)
return anchor_list, valid_flag_list
def get_refine_anchors(self, featmap_sizes, image_shape):
num_levels = len(featmap_sizes)
refine_anchors_list = []
for i in range(num_levels):
refine_anchor = self.refine_anchor_list[i]
refine_anchor = paddle.squeeze(refine_anchor, axis=0)
refine_anchor = refine_anchor.numpy()
refine_anchor = np.reshape(refine_anchor,
[-1, refine_anchor.shape[-1]])
refine_anchors_list.extend(refine_anchor)
# for each image, we compute valid flags of multi level anchors
valid_flag_list = []
for i in range(num_levels):
anchor_stride = self.anchor_strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w = image_shape
valid_feat_h = min(int(np.ceil(h / anchor_stride)), feat_h)
valid_feat_w = min(int(np.ceil(w / anchor_stride)), feat_w)
flags = self.anchor_generators[i].valid_flags(
(feat_h, feat_w), (valid_feat_h, valid_feat_w))
valid_flag_list.append(flags)
return refine_anchors_list, valid_flag_list
def rbox2poly_single(self, rrect, get_best_begin_point=False):
"""
rrect:[x_ctr,y_ctr,w,h,angle]
to
poly:[x0,y0,x1,y1,x2,y2,x3,y3]
"""
x_ctr, y_ctr, width, height, angle = rrect[:5]
tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2
# rect 2x4
rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]])
R = np.array([[np.cos(angle), -np.sin(angle)],
[np.sin(angle), np.cos(angle)]])
# poly
poly = R.dot(rect)
x0, x1, x2, x3 = poly[0, :4] + x_ctr
y0, y1, y2, y3 = poly[1, :4] + y_ctr
poly = np.array([x0, y0, x1, y1, x2, y2, x3, y3], dtype=np.float32)
return poly
def get_bboxes(self, cls_score_list, bbox_pred_list, mlvl_anchors, nms_pre,
cls_out_channels, use_sigmoid_cls):
assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
mlvl_bboxes = []
mlvl_scores = []
idx = 0
for cls_score, bbox_pred, anchors in zip(cls_score_list, bbox_pred_list,
mlvl_anchors):
cls_score = paddle.reshape(cls_score, [-1, cls_out_channels])
if use_sigmoid_cls:
scores = F.sigmoid(cls_score)
else:
scores = F.softmax(cls_score, axis=-1)
# bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 5)
bbox_pred = paddle.transpose(bbox_pred, [1, 2, 0])
bbox_pred = paddle.reshape(bbox_pred, [-1, 5])
anchors = paddle.reshape(anchors, [-1, 5])
if nms_pre > 0 and scores.shape[0] > nms_pre:
# Get maximum scores for foreground classes.
if use_sigmoid_cls:
max_scores = paddle.max(scores, axis=1)
else:
max_scores = paddle.max(scores[:, 1:], axis=1)
topk_val, topk_inds = paddle.topk(max_scores, nms_pre)
anchors = paddle.gather(anchors, topk_inds)
bbox_pred = paddle.gather(bbox_pred, topk_inds)
scores = paddle.gather(scores, topk_inds)
target_means = (.0, .0, .0, .0, .0)
target_stds = (1.0, 1.0, 1.0, 1.0, 1.0)
bboxes = bbox_utils.delta2rbox(anchors, bbox_pred, target_means,
target_stds)
mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores)
idx += 1
mlvl_bboxes = paddle.concat(mlvl_bboxes, axis=0)
mlvl_scores = paddle.concat(mlvl_scores)
if use_sigmoid_cls:
# Add a dummy background class to the front when using sigmoid
padding = paddle.zeros(
[mlvl_scores.shape[0], 1], dtype=mlvl_scores.dtype)
mlvl_scores = paddle.concat([padding, mlvl_scores], axis=1)
return mlvl_scores, mlvl_bboxes
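The pre-NMS filtering step in `get_bboxes` keeps, per feature level, only the `nms_pre` anchors with the highest foreground score. A rough NumPy sketch of that selection for the sigmoid case follows. The helper name `topk_filter` is hypothetical; the code in this commit uses `paddle.topk` and `paddle.gather` instead.

```python
import numpy as np


def topk_filter(scores, bbox_pred, anchors, nms_pre):
    """Keep the nms_pre rows with the largest per-anchor maximum class score."""
    if nms_pre <= 0 or scores.shape[0] <= nms_pre:
        return scores, bbox_pred, anchors
    max_scores = scores.max(axis=1)
    # Indices of the nms_pre largest scores, highest first.
    topk_inds = np.argsort(-max_scores, kind="stable")[:nms_pre]
    return scores[topk_inds], bbox_pred[topk_inds], anchors[topk_inds]
```

Filtering before decoding keeps the later `delta2rbox` and NMS calls cheap, since they only see the surviving rows.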
......@@ -218,3 +218,128 @@ class FCOSPostProcess(object):
centerness, scale_factor)
bbox_pred, bbox_num, _ = self.nms(bboxes, score)
return bbox_pred, bbox_num
@register
class S2ANetBBoxPostProcess(object):
__inject__ = ['nms']
def __init__(self, nms_pre=2000, min_bbox_size=0, nms=None):
super(S2ANetBBoxPostProcess, self).__init__()
self.nms_pre = nms_pre
self.min_bbox_size = min_bbox_size
self.nms = nms
self.origin_shape_list = []
def rbox2poly(self, rrect, get_best_begin_point=True):
"""
rrect: [N, 5] [x_ctr,y_ctr,w,h,angle]
to
poly:[x0,y0,x1,y1,x2,y2,x3,y3]
"""
bbox_num = rrect.shape[0]
x_ctr = rrect[:, 0]
y_ctr = rrect[:, 1]
width = rrect[:, 2]
height = rrect[:, 3]
angle = rrect[:, 4]
tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2
# rect 2x4
rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]])
R = np.array([[np.cos(angle), -np.sin(angle)],
[np.sin(angle), np.cos(angle)]])
# R:[2,2,M] rect:[2,4,M]
#poly = R.dot(rect)
poly = []
for i in range(R.shape[2]):
poly.append(R[:, :, i].dot(rect[:, :, i]))
# poly:[M, 2, 4]
poly = np.array(poly)
coor_x = poly[:, 0, :4] + x_ctr.reshape(bbox_num, 1)
coor_y = poly[:, 1, :4] + y_ctr.reshape(bbox_num, 1)
poly = np.stack(
[
coor_x[:, 0], coor_y[:, 0], coor_x[:, 1], coor_y[:, 1],
coor_x[:, 2], coor_y[:, 2], coor_x[:, 3], coor_y[:, 3]
],
axis=1)
if get_best_begin_point:
poly_lst = [get_best_begin_point_single(e) for e in poly]
poly = np.array(poly_lst)
return poly
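The per-box Python loop in `rbox2poly` multiplies each 2x2 rotation matrix with its 2x4 corner matrix. One possible vectorized alternative is sketched here with `np.einsum`. This is an illustration rather than part of the commit: `rbox2poly_vectorized` is a hypothetical name, the angle is assumed to be in radians (as the `np.cos`/`np.sin` calls imply), and the best-begin-point reordering is left out.

```python
import numpy as np


def rbox2poly_vectorized(rrect):
    """rrect: [N, 5] (x_ctr, y_ctr, w, h, angle) -> poly: [N, 8]."""
    rrect = np.asarray(rrect, dtype=np.float64)
    x_ctr, y_ctr, width, height, angle = rrect.T
    tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2
    # rect: [2, 4, N], rot: [2, 2, N]
    rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]])
    rot = np.array([[np.cos(angle), -np.sin(angle)],
                    [np.sin(angle), np.cos(angle)]])
    # Batched matrix product over the trailing box axis -> [N, 2, 4]
    poly = np.einsum("ijn,jkn->nik", rot, rect)
    poly[:, 0, :] += x_ctr[:, None]
    poly[:, 1, :] += y_ctr[:, None]
    # Interleave x and y: [N, 4, 2] -> [N, 8]
    return poly.transpose(0, 2, 1).reshape(-1, 8)
```

With angle 0 the corners come out in the order top-left, top-right, bottom-right, bottom-left (image coordinates); a rotation by pi/2 maps each corner offset `(x, y)` to `(-y, x)`.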
def get_prediction(self, pred_scores, pred_bboxes, im_shape, scale_factor):
"""
pred_scores : [N, M] score
pred_bboxes : [N, 5] xc, yc, w, h, a
im_shape : [N, 2] im_shape
scale_factor : [N, 2] scale_factor
"""
# TODO: support bs>1
pred_polys = self.rbox2poly(pred_bboxes.numpy(), False)
pred_polys = paddle.to_tensor(pred_polys)
pred_polys = paddle.reshape(
pred_polys, [1, pred_polys.shape[0], pred_polys.shape[1]])
pred_scores = paddle.to_tensor(pred_scores)
# pred_scores [NA, 16] --> [16, NA]
pred_scores = paddle.transpose(pred_scores, [1, 0])
pred_scores = paddle.reshape(
pred_scores, [1, pred_scores.shape[0], pred_scores.shape[1]])
pred_cls_score_bbox, bbox_num, index = self.nms(pred_polys, pred_scores)
# post process scale
# result [n, 10]
if bbox_num > 0:
pred_bbox, bbox_num = self.post_process(pred_cls_score_bbox[:, 2:],
bbox_num, im_shape[0],
scale_factor[0])
pred_cls_score_bbox = paddle.concat(
[pred_cls_score_bbox[:, 0:2], pred_bbox], axis=1)
else:
pred_cls_score_bbox = paddle.to_tensor(
np.array(
[[-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
dtype='float32'))
bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
return pred_cls_score_bbox, bbox_num, index
def post_process(self, bboxes, bbox_num, im_shape, scale_factor):
"""
Rescale, clip and filter the bbox from the output of NMS to
get final prediction.
Args:
bboxes(Tensor): bboxes [N, 8]
bbox_num(Tensor): bbox_num
im_shape(Tensor): [1 2]
scale_factor(Tensor): [1 2]
Returns:
tuple: (bbox, bbox_num), where bbox is a Tensor of shape [N, 8] holding
the four polygon corners rescaled to and clipped against the original
image, and bbox_num is passed through unchanged.
"""
origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
origin_h = origin_shape[0]
origin_w = origin_shape[1]
bboxes[:, 0::2] = bboxes[:, 0::2] / scale_factor[0]
bboxes[:, 1::2] = bboxes[:, 1::2] / scale_factor[1]
zeros = paddle.zeros_like(origin_h)
x1 = paddle.maximum(paddle.minimum(bboxes[:, 0], origin_w), zeros)
y1 = paddle.maximum(paddle.minimum(bboxes[:, 1], origin_h), zeros)
x2 = paddle.maximum(paddle.minimum(bboxes[:, 2], origin_w), zeros)
y2 = paddle.maximum(paddle.minimum(bboxes[:, 3], origin_h), zeros)
x3 = paddle.maximum(paddle.minimum(bboxes[:, 4], origin_w), zeros)
y3 = paddle.maximum(paddle.minimum(bboxes[:, 5], origin_h), zeros)
x4 = paddle.maximum(paddle.minimum(bboxes[:, 6], origin_w), zeros)
y4 = paddle.maximum(paddle.minimum(bboxes[:, 7], origin_h), zeros)
bbox = paddle.stack([x1, y1, x2, y2, x3, y3, x4, y4], axis=-1)
bboxes = (bbox, bbox_num)
return bboxes
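The rescale-and-clip logic in `post_process` boils down to dividing the x and y coordinates by the resize factors and clamping them to the original image extent. A minimal NumPy sketch under stated assumptions: `scale_x` and `scale_y` are explicit arguments here (how they map onto the `scale_factor` tensor depends on the data pipeline and is not asserted), and `rescale_and_clip` is a hypothetical helper name.

```python
import numpy as np


def rescale_and_clip(polys, origin_h, origin_w, scale_x, scale_y):
    """polys: [N, 8] as x1, y1, ..., x4, y4 in resized-image coordinates."""
    polys = np.asarray(polys, dtype=np.float64).copy()
    # Even columns are x coordinates, odd columns are y coordinates.
    polys[:, 0::2] = np.clip(polys[:, 0::2] / scale_x, 0, origin_w)
    polys[:, 1::2] = np.clip(polys[:, 1::2] / scale_y, 0, origin_h)
    return polys
```

Clipping each corner independently, as the original method also does, can distort a rotated box that crosses the image border; it no longer has to be a rectangle afterwards.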
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import paddle
from ppdet.core.workspace import register, serializable
from .target import rpn_anchor_target, generate_proposal_target, generate_mask_target
from ppdet.modeling import bbox_utils
import numpy as np
@register
......@@ -176,3 +176,170 @@ class MaskAssigner(object):
# mask_rois, mask_rois_num, tgt_classes, tgt_masks, mask_index, tgt_weights
return outs
@register
class RBoxAssigner(object):
"""
assigner of rbox
Args:
pos_iou_thr (float): anchors with IoU >= this value are positive samples
neg_iou_thr (float): anchors with IoU in [min_iou_thr, neg_iou_thr) are negative samples
min_iou_thr (float): lower IoU bound for negative samples
ignore_iof_thr (int): label value given to ignored anchors
"""
def __init__(self,
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_iou_thr=0.0,
ignore_iof_thr=-2):
super(RBoxAssigner, self).__init__()
self.pos_iou_thr = pos_iou_thr
self.neg_iou_thr = neg_iou_thr
self.min_iou_thr = min_iou_thr
self.ignore_iof_thr = ignore_iof_thr
def anchor_valid(self, anchors):
"""
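Args:
anchors: [M, 4] or [M, 5] array (3-D input is flattened to 2-D)
Returns:
indices of all anchors; every anchor is currently treated as valid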
"""
if anchors.ndim == 3:
anchors = anchors.reshape(-1, anchors.shape[-1])
assert anchors.ndim == 2
anchor_num = anchors.shape[0]
anchor_valid = np.ones((anchor_num), np.uint8)
anchor_inds = np.arange(anchor_num)
return anchor_inds
def assign_anchor(self,
anchors,
gt_bboxes,
gt_lables,
pos_iou_thr,
neg_iou_thr,
min_iou_thr=0.0,
ignore_iof_thr=-2):
"""
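Args:
anchors: [N, 4] or [N, 5] array
gt_bboxes: [M, 5] xc,yc,w,h,angle
gt_lables: [M] class label of each gt box
Returns:
anchor_gt_bbox_inds: index of the best-matching gt box for each anchor
anchor_gt_bbox_iou: IoU between each anchor and that gt box
labels: per-anchor label (gt label for positives, -1 for negatives,
ignore_iof_thr for ignored anchors)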
"""
assert anchors.shape[1] == 4 or anchors.shape[1] == 5
assert gt_bboxes.shape[1] == 4 or gt_bboxes.shape[1] == 5
anchors_xc_yc = anchors
gt_bboxes_xc_yc = gt_bboxes
# calc rbox iou
anchors_xc_yc = anchors_xc_yc.astype(np.float32)
gt_bboxes_xc_yc = gt_bboxes_xc_yc.astype(np.float32)
anchors_xc_yc = paddle.to_tensor(anchors_xc_yc, place=paddle.CPUPlace())
gt_bboxes_xc_yc = paddle.to_tensor(
gt_bboxes_xc_yc, place=paddle.CPUPlace())
try:
from rbox_iou_ops import rbox_iou
except Exception as e:
print('import custom_ops error', e)
sys.exit(-1)
iou = rbox_iou(gt_bboxes_xc_yc, anchors_xc_yc)
iou = iou.numpy()
iou = iou.T
# every gt's anchor's index
gt_bbox_anchor_inds = iou.argmax(axis=0)
gt_bbox_anchor_iou = iou[gt_bbox_anchor_inds, np.arange(iou.shape[1])]
gt_bbox_anchor_iou_inds = np.where(iou == gt_bbox_anchor_iou)[0]
# every anchor's gt bbox's index
anchor_gt_bbox_inds = iou.argmax(axis=1)
anchor_gt_bbox_iou = iou[np.arange(iou.shape[0]), anchor_gt_bbox_inds]
# (1) set labels=-2 as default
labels = np.ones((iou.shape[0], ), dtype=np.int32) * ignore_iof_thr
# (2) assign ignore
labels[anchor_gt_bbox_iou < min_iou_thr] = ignore_iof_thr
# (3) assign neg_ids -1
assign_neg_ids1 = anchor_gt_bbox_iou >= min_iou_thr
assign_neg_ids2 = anchor_gt_bbox_iou < neg_iou_thr
assign_neg_ids = np.logical_and(assign_neg_ids1, assign_neg_ids2)
labels[assign_neg_ids] = -1
# anchor_gt_bbox_iou_inds
# (4) assign max_iou as pos_ids >=0
anchor_gt_bbox_iou_inds = anchor_gt_bbox_inds[gt_bbox_anchor_iou_inds]
# gt_bbox_anchor_iou_inds = np.logical_and(gt_bbox_anchor_iou_inds, anchor_gt_bbox_iou >= min_iou_thr)
labels[gt_bbox_anchor_iou_inds] = gt_lables[anchor_gt_bbox_iou_inds]
# (5) assign >= pos_iou_thr as pos_ids
iou_pos_iou_thr_ids = anchor_gt_bbox_iou >= pos_iou_thr
iou_pos_iou_thr_ids_box_inds = anchor_gt_bbox_inds[iou_pos_iou_thr_ids]
labels[iou_pos_iou_thr_ids] = gt_lables[iou_pos_iou_thr_ids_box_inds]
return anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels
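The labelling rule in `assign_anchor` can be traced on a toy IoU matrix. The sketch below re-implements steps (1) to (5) in plain NumPy, taking a precomputed `iou` array of shape `[num_anchors, num_gt]` instead of calling the custom `rbox_iou` op. The function name `assign_labels` and the parameter name `ignore_label` are illustrative assumptions, not part of the commit.

```python
import numpy as np


def assign_labels(iou, gt_labels, pos_iou_thr=0.5, neg_iou_thr=0.4,
                  min_iou_thr=0.0, ignore_label=-2):
    """iou: [num_anchors, num_gt]; gt_labels: [num_gt]. Returns per-anchor labels."""
    num_anchors, num_gt = iou.shape
    # Best anchor for every gt box, plus every anchor that ties with that best IoU.
    best_anchor_per_gt = iou.argmax(axis=0)
    best_iou_per_gt = iou[best_anchor_per_gt, np.arange(num_gt)]
    best_anchor_rows = np.where(iou == best_iou_per_gt)[0]
    # Best gt box for every anchor.
    best_gt_per_anchor = iou.argmax(axis=1)
    best_iou_per_anchor = iou[np.arange(num_anchors), best_gt_per_anchor]

    # (1), (2): everything starts out as ignored.
    labels = np.full(num_anchors, ignore_label, dtype=np.int32)
    # (3): low-IoU anchors become negatives.
    negative = (best_iou_per_anchor >= min_iou_thr) & (best_iou_per_anchor < neg_iou_thr)
    labels[negative] = -1
    # (4): the best anchor of each gt box is positive regardless of its IoU.
    labels[best_anchor_rows] = gt_labels[best_gt_per_anchor[best_anchor_rows]]
    # (5): anchors above the positive threshold are positive.
    positive = best_iou_per_anchor >= pos_iou_thr
    labels[positive] = gt_labels[best_gt_per_anchor[positive]]
    return labels
```

With IoU rows `[0.7, 0.1]`, `[0.3, 0.45]`, `[0.2, 0.05]`, `[0.45, 0.3]` and gt labels `[3, 7]`, the first anchor is positive through step (5), the second only through step (4), the third is negative, and the fourth falls between the two thresholds and stays ignored.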
def __call__(self, anchors, gt_bboxes, gt_labels, is_crowd):
assert anchors.ndim == 2
assert anchors.shape[1] == 5
assert gt_bboxes.ndim == 2
assert gt_bboxes.shape[1] == 5
pos_iou_thr = self.pos_iou_thr
neg_iou_thr = self.neg_iou_thr
min_iou_thr = self.min_iou_thr
ignore_iof_thr = self.ignore_iof_thr
anchor_num = anchors.shape[0]
anchors_inds = self.anchor_valid(anchors)
anchors = anchors[anchors_inds]
gt_bboxes = gt_bboxes
is_crowd_slice = is_crowd
not_crowd_inds = np.where(is_crowd_slice == 0)
# Step1: match anchor and gt_bbox
anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels = self.assign_anchor(
anchors, gt_bboxes,
gt_labels.reshape(-1), pos_iou_thr, neg_iou_thr, min_iou_thr,
ignore_iof_thr)
# Step2: sample anchor
pos_inds = np.where(labels >= 0)[0]
neg_inds = np.where(labels == -1)[0]
# Step3: make output
anchors_num = anchors.shape[0]
bbox_targets = np.zeros_like(anchors)
bbox_weights = np.zeros_like(anchors)
pos_labels = np.ones(anchors_num, dtype=np.int32) * -1
pos_labels_weights = np.zeros(anchors_num, dtype=np.float32)
pos_sampled_anchors = anchors[pos_inds]
#print('ancho target pos_inds', pos_inds, len(pos_inds))
pos_sampled_gt_boxes = gt_bboxes[anchor_gt_bbox_inds[pos_inds]]
if len(pos_inds) > 0:
pos_bbox_targets = bbox_utils.rbox2delta(pos_sampled_anchors,
pos_sampled_gt_boxes)
bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_weights[pos_inds, :] = 1.0
pos_labels[pos_inds] = labels[pos_inds]
pos_labels_weights[pos_inds] = 1.0
if len(neg_inds) > 0:
pos_labels_weights[neg_inds] = 1.0
return (pos_labels, pos_labels_weights, bbox_targets, bbox_weights,
pos_inds, neg_inds)
......@@ -20,8 +20,9 @@ from __future__ import unicode_literals
import numpy as np
from PIL import Image, ImageDraw
import cv2
from .colormap import colormap
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = ['visualize_results']
......@@ -86,21 +87,32 @@ def draw_bbox(image, im_id, catid2name, bboxes, threshold):
if score < threshold:
continue
xmin, ymin, w, h = bbox
xmax = xmin + w
ymax = ymin + h
if catid not in catid2color:
idx = np.random.randint(len(color_list))
catid2color[catid] = color_list[idx]
color = tuple(catid2color[catid])
# draw bbox
if len(bbox) == 4:
# draw bbox
xmin, ymin, w, h = bbox
xmax = xmin + w
ymax = ymin + h
draw.line(
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
(xmin, ymin)],
width=2,
fill=color)
elif len(bbox) == 8:
x1, y1, x2, y2, x3, y3, x4, y4 = bbox
draw.line(
[(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x1, y1)],
width=2,
fill=color)
xmin = min(x1, x2, x3, x4)
ymin = min(y1, y2, y3, y4)
else:
logger.error('the shape of bbox must be [M, 4] or [M, 8]!')
# draw label
text = "{} {:.2f}".format(catid2name[catid], score)
......@@ -112,6 +124,23 @@ def draw_bbox(image, im_id, catid2name, bboxes, threshold):
return image
def save_result(save_path, bbox_res, catid2name, threshold):
"""
save result as txt
"""
with open(save_path, 'w') as f:
for dt in bbox_res:
catid, bbox, score = dt['category_id'], dt['bbox'], dt['score']
if score < threshold:
continue
# each bbox result as a line
# for rbox: classname score x1 y1 x2 y2 x3 y3 x4 y4
# for bbox: classname score x1 y1 w h
bbox_pred = '{} {} '.format(catid2name[catid], score) + ' '.join(
[str(e) for e in bbox])
f.write(bbox_pred + '\n')
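For reference, the line layout that `save_result` writes can be reproduced in isolation. The helper below is a hypothetical extraction of the formatting expression (the name `format_result_line` does not exist in the commit); it yields 4 coordinate fields for a horizontal box and 8 for a rotated one, after the class name and score.

```python
def format_result_line(class_name, score, bbox):
    """Build one output line: class name, score, then the box coordinates."""
    return '{} {} '.format(class_name, score) + ' '.join(str(e) for e in bbox)
```

Because every field is separated by a single space, a class name that itself contains spaces would make the line ambiguous to parse back.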
def draw_segm(image,
im_id,
catid2name,
......
......@@ -71,6 +71,11 @@ def parse_args():
type=str,
default="vdl_log_dir/image",
help='VisualDL logging directory for image.')
parser.add_argument(
"--save_txt",
type=bool,
default=False,
help="whether to save the inference results as a txt file.")
args = parser.parse_args()
return args
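One caveat with `type=bool` in argparse: the converter is applied to the raw string, and `bool('False')` is `True`, so any non-empty value passed to `--save_txt` enables the flag. A common workaround, sketched below with a hypothetical `str2bool` helper and `build_parser` function (neither is part of this commit), parses the text explicitly.

```python
import argparse


def str2bool(value):
    """Parse common textual booleans; raise a clean argparse error otherwise."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)


def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_txt",
        type=str2bool,
        default=False,
        help="whether to save the inference results as a txt file.")
    return parser
```

An alternative with the same effect is `action="store_true"`, at the cost of changing the command line from `--save_txt True` to a bare `--save_txt`.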
......@@ -120,7 +125,8 @@ def run(FLAGS, cfg):
trainer.predict(
images,
draw_threshold=FLAGS.draw_threshold,
output_dir=FLAGS.output_dir)
output_dir=FLAGS.output_dir,
save_txt=FLAGS.save_txt)
def main():
......