未验证 提交 35bd189b 编写于 作者: C cnn 提交者: GitHub

[dev] add small dataset(spine) for s2anet, and add eval (#3401)

* support --eval for s2anet

* add spine dataset and config yml for s2anet

* add doc for s2anet

* update doc

* fix typo, test=document_fix

* add some comments, and update import
上级 d214d9ec
metric: COCO
metric: RBOX
num_classes: 15
TrainDataset:
......
metric: RBOX
num_classes: 9
TrainDataset:
!COCODataSet
image_dir: images
anno_path: annotations/train.json
dataset_dir: dataset/spine_coco
data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd', 'gt_rbox']
EvalDataset:
!COCODataSet
image_dir: images
anno_path: annotations/valid.json
dataset_dir: dataset/spine_coco
TestDataset:
!ImageFolder
anno_path: annotations/valid.json
dataset_dir: dataset/spine_coco
......@@ -2,16 +2,19 @@
## 内容
- [简介](#简介)
- [DOTA数据集](#DOTA数据集)
- [准备数据](#准备数据)
- [开始训练](#开始训练)
- [模型库](#模型库)
- [训练说明](#训练说明)
- [预测部署](#预测部署)
## 简介
[S2ANet](https://arxiv.org/pdf/2008.09397.pdf)是用于检测旋转框的模型,要求使用PaddlePaddle 2.0.1(可使用pip安装) 或适当的[develop版本](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/install/Tables.html#whl-release)
## DOTA数据集
## 准备数据
### DOTA数据
[DOTA Dataset]是航空影像中物体检测的数据集,包含2806张图像,每张图像4000*4000分辨率。
| 数据版本 | 类别数 | 图像数 | 图像尺寸 | 实例数 | 标注方式 |
......@@ -27,19 +30,22 @@ DOTA数据集中总共有2806张图像,其中1411张图像作为训练集,45
设置`crop_size=1024, stride=824, gap=200`参数切割数据后,训练集15749张图像,评估集5297张图像,测试集10833张图像。
## 模型库
### 自定义数据
### S2ANet模型
数据标注有两种方式:
| 模型 | GPU个数 | Conv类型 | mAP | 模型下载 | 配置文件 |
|:-----------:|:-------:|:----------:|:--------:| :----------:| :---------: |
| S2ANet | 8 | Conv | 71.42 | [model](https://paddledet.bj.bcebos.com/models/s2anet_conv_1x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/dota/s2anet_conv_1x_dota.yml) |
- 第一种是标注旋转矩形,可以通过旋转矩形标注工具[roLabelImg](https://github.com/cgvict/roLabelImg) 来标注旋转矩形框。
**注意:**这里使用`multiclass_nms`,与原作者使用nms略有不同,精度相比原始论文中高0.15 (71.27-->71.42)。
- 第二种是标注四边形,通过脚本转成外接旋转矩形,这样得到的标注可能跟真实的物体框有一定误差。
然后将标注结果转换成coco标注格式,其中每个`bbox`的格式为 `[x_center, y_center, width, height, angle]`,这里角度以弧度表示。
参考[脊椎间盘数据集](https://aistudio.baidu.com/aistudio/datasetdetail/85885) ,我们将数据集划分为训练集(230)、测试集(57),数据地址为:[spine_coco](https://paddledet.bj.bcebos.com/data/spine_coco.tar) 。该数据集图像数量比较少,使用这个数据集可以快速训练S2ANet模型。
## 训练说明
### 1. 旋转框IOU计算OP
## 开始训练
### 1. 安装旋转框IOU计算OP
旋转框IOU计算OP[ext_op](../../ppdet/ext_op)是参考Paddle[自定义外部算子](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/07_new_op/new_custom_op.html) 的方式开发。
......@@ -82,27 +88,59 @@ cd PaddleDetecetion/ppdet/ext_op
python3.7 test.py
```
### 2. 数据格式
DOTA 数据集中实例是按照任意四边形标注,在进行训练模型前,需要参考[DOTA2COCO](https://github.com/CAPTAIN-WHU/DOTA_devkit/blob/master/DOTA2COCO.py) 转换成`[xc, yc, bow_w, bow_h, angle]`格式,并以coco数据格式存储。
### 2. 训练
**注意:**
配置文件中学习率是按照8卡GPU训练设置的,如果使用单卡GPU训练,请将学习率设置为原来的1/8。
GPU单卡训练
```bash
export CUDA_VISIBLE_DEVICES=0
python3.7 tools/train.py -c configs/dota/s2anet_1x_spine.yml
```
GPU多卡训练
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python3.7 -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/dota/s2anet_1x_spine.yml
```
可以通过`--eval`开启边训练边测试。
### 2. 评估
```bash
python3.7 tools/eval.py -c configs/dota/s2anet_1x_spine.yml -o weitghts=output/s2anet_1x_spine/model_final.pdparams
```
## 评估
### 3. 预测
执行如下命令,会将图像预测结果保存到`output_dir`文件夹下。
```bash
python3.7 tools/infer.py -c configs/dota/s2anet_1x_spine.yml -o weitghts=output/s2anet_1x_spine/model_final.pdparams --infer_img=demo/39006.jpg
```
### 4. DOTA数据评估
执行如下命令,会在`output_dir`文件夹下将每个图像预测结果保存到同文件夹名的txt文本中。
```
python3.7 tools/infer.py -c configs/dota/s2anet_1x_dota.yml -o weights=./weights/s2anet_1x_dota.pdparams --infer_dir=dota_test_images --draw_threshold=0.05 --save_txt=True --output_dir=output
```
请参考[DOTA_devkit](https://github.com/CAPTAIN-WHU/DOTA_devkit) 生成评估文件,评估文件格式请参考[DOTA Test](http://captain.whu.edu.cn/DOTAweb/tasks.html) ,生成zip文件,每个类一个txt文件,txt文件中每行格式为:`image_id score x1 y1 x2 y2 x3 y3 x4 y4`,提交服务器进行评估。
## 模型库
### S2ANet模型
| 模型 | GPU个数 | Conv类型 | mAP | 模型下载 | 配置文件 |
|:-----------:|:-------:|:----------:|:--------:| :----------:| :---------: |
| S2ANet | 8 | Conv | 71.42 | [model](https://paddledet.bj.bcebos.com/models/s2anet_conv_1x_dota.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/dota/s2anet_conv_1x_dota.yml) |
**注意:**这里使用`multiclass_nms`,与原作者使用nms略有不同,精度相比原始论文中高0.15 (71.27-->71.42)。
## 预测部署
Paddle中`multiclass_nms`算子的输入支持四边形输入,因此部署时可以不需要依赖旋转框IOU计算算子。
```bash
# 预测
CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/dota/s2anet_1x_dota.yml -o weights=model.pdparams --infer_img=demo/P0072__1.0__0___0.png
```
部署教程请参考[预测部署](../../deploy/README.md)
## Citations
......
_BASE_: [
'../datasets/spine_coco.yml',
'../runtime.yml',
'_base_/s2anet_optimizer_1x.yml',
'_base_/s2anet.yml',
'_base_/s2anet_reader.yml',
]
weights: output/s2anet_1x_spine/model_final
# for 8 card
LearningRate:
base_lr: 0.01
S2ANetHead:
anchor_strides: [8, 16, 32, 64, 128]
anchor_scales: [4]
anchor_ratios: [1.0]
anchor_assign: RBoxAssigner
stacked_convs: 2
feat_in: 256
feat_out: 256
num_classes: 9
align_conv_type: 'DCN' # AlignConv Conv
align_conv_size: 3
use_sigmoid_cls: True
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os.path as osp
import logging
# add python path of PadleDetection to sys.path
parent_path = osp.abspath(osp.join(__file__, *(['..'] * 3)))
if parent_path not in sys.path:
sys.path.append(parent_path)
from ppdet.utils.download import download_dataset
logging.basicConfig(level=logging.INFO)
download_path = osp.split(osp.realpath(sys.argv[0]))[0]
download_dataset(download_path, 'spine_coco')
......@@ -39,7 +39,7 @@ def get_categories(metric_type, anno_file=None, arch=None):
if arch == 'keypoint_arch':
return (None, {'id': 'keypoint'})
if metric_type.lower() == 'coco':
if metric_type.lower() == 'coco' or metric_type.lower() == 'rbox':
if anno_file and os.path.isfile(anno_file):
# lazy import pycocotools here
from pycocotools.coco import COCO
......
......@@ -36,6 +36,7 @@ from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.utils.visualizer import visualize_results, save_result
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval
from ppdet.metrics import RBoxMetric
from ppdet.data.source.category import get_categories
import ppdet.utils.stats as stats
......@@ -178,6 +179,35 @@ class Trainer(object):
IouType=IouType,
save_prediction_only=save_prediction_only)
]
elif self.cfg.metric == 'RBOX':
# TODO: bias should be unified
bias = self.cfg['bias'] if 'bias' in self.cfg else 0
output_eval = self.cfg['output_eval'] \
if 'output_eval' in self.cfg else None
save_prediction_only = self.cfg.get('save_prediction_only', False)
# pass clsid2catid info to metric instance to avoid multiple loading
# annotation file
clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
if self.mode == 'eval' else None
# when do validation in train, annotation file should be get from
# EvalReader instead of self.dataset(which is TrainReader)
anno_file = self.dataset.get_anno()
if self.mode == 'train' and validate:
eval_dataset = self.cfg['EvalDataset']
eval_dataset.check_or_download_dataset()
anno_file = eval_dataset.get_anno()
self._metrics = [
RBoxMetric(
anno_file=anno_file,
clsid2catid=clsid2catid,
classwise=classwise,
output_eval=output_eval,
bias=bias,
save_prediction_only=save_prediction_only)
]
elif self.cfg.metric == 'VOC':
self._metrics = [
VOCMetric(
......
......@@ -21,6 +21,8 @@ import os
import sys
import numpy as np
import itertools
import paddle
from ppdet.modeling.bbox_utils import poly2rbox, rbox2poly_np
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
......@@ -89,6 +91,50 @@ def jaccard_overlap(pred, gt, is_bbox_normalized=False):
return overlap
def calc_rbox_iou(pred, gt_rbox):
"""
calc iou between rotated bbox
"""
# calc iou of bounding box for speedup
pred = np.array(pred, np.float32).reshape(-1, 8)
pred = pred.reshape(-1, 2)
gt_poly = rbox2poly_np(np.array(gt_rbox).reshape(-1, 5))[0]
gt_poly = gt_poly.reshape(-1, 2)
pred_rect = [
np.min(pred[:, 0]), np.min(pred[:, 1]), np.max(pred[:, 0]),
np.max(pred[:, 1])
]
gt_rect = [
np.min(gt_poly[:, 0]), np.min(gt_poly[:, 1]), np.max(gt_poly[:, 0]),
np.max(gt_poly[:, 1])
]
iou = jaccard_overlap(pred_rect, gt_rect, False)
if iou <= 0:
return iou
# calc rbox iou
pred = pred.reshape(-1, 8)
pred = np.array(pred, np.float32).reshape(-1, 8)
pred_rbox = poly2rbox(pred)
pred_rbox = pred_rbox.reshape(-1, 5)
pred_rbox = pred_rbox.reshape(-1, 5)
try:
from rbox_iou_ops import rbox_iou
except Exception as e:
print("import custom_ops error, try install rbox_iou_ops " \
"following ppdet/ext_op/README.md", e)
sys.stdout.flush()
sys.exit(-1)
gt_rbox = np.array(gt_rbox, np.float32).reshape(-1, 5)
pd_gt_rbox = paddle.to_tensor(gt_rbox, dtype='float32')
pd_pred_rbox = paddle.to_tensor(pred_rbox, dtype='float32')
iou = rbox_iou(pd_gt_rbox, pd_pred_rbox)
iou = iou.numpy()
return iou[0][0]
def prune_zero_padding(gt_box, gt_label, difficult=None):
valid_cnt = 0
for i in range(len(gt_box)):
......@@ -161,14 +207,16 @@ class DetectionMAP(object):
# record class score positive
visited = [False] * len(gt_label)
for b, s, l in zip(bbox, score, label):
xmin, ymin, xmax, ymax = b.tolist()
pred = [xmin, ymin, xmax, ymax]
pred = b.tolist() if isinstance(b, np.ndarray) else b
max_idx = -1
max_overlap = -1.0
for i, gl in enumerate(gt_label):
if int(gl) == int(l):
overlap = jaccard_overlap(pred, gt_box[i],
self.is_bbox_normalized)
if len(gt_box[i]) == 5:
overlap = calc_rbox_iou(pred, gt_box[i])
else:
overlap = jaccard_overlap(pred, gt_box[i],
self.is_bbox_normalized)
if overlap > max_overlap:
max_overlap = overlap
max_idx = i
......
......@@ -31,7 +31,12 @@ from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = [
'Metric', 'COCOMetric', 'VOCMetric', 'WiderFaceMetric', 'get_infer_results'
'Metric',
'COCOMetric',
'VOCMetric',
'WiderFaceMetric',
'get_infer_results',
'RBoxMetric',
]
COCO_SIGMAS = np.array([
......@@ -299,3 +304,94 @@ class WiderFaceMetric(Metric):
pred_dir='output/pred',
eval_mode='widerface',
multi_scale=self.multi_scale)
class RBoxMetric(Metric):
def __init__(self, anno_file, **kwargs):
assert os.path.isfile(anno_file), \
"anno_file {} not a file".format(anno_file)
assert os.path.exists(anno_file), "anno_file {} not exists".format(
anno_file)
self.anno_file = anno_file
self.gt_anno = json.load(open(self.anno_file))
cats = self.gt_anno['categories']
self.clsid2catid = {i: cat['id'] for i, cat in enumerate(cats)}
self.catid2clsid = {cat['id']: i for i, cat in enumerate(cats)}
self.catid2name = {cat['id']: cat['name'] for cat in cats}
self.classwise = kwargs.get('classwise', False)
self.output_eval = kwargs.get('output_eval', None)
# TODO: bias should be unified
self.bias = kwargs.get('bias', 0)
self.save_prediction_only = kwargs.get('save_prediction_only', False)
self.iou_type = kwargs.get('IouType', 'bbox')
self.overlap_thresh = kwargs.get('overlap_thresh', 0.5)
self.map_type = kwargs.get('map_type', '11point')
self.evaluate_difficult = kwargs.get('evaluate_difficult', False)
class_num = len(self.catid2name)
self.detection_map = DetectionMAP(
class_num=class_num,
overlap_thresh=self.overlap_thresh,
map_type=self.map_type,
is_bbox_normalized=False,
evaluate_difficult=self.evaluate_difficult,
catid2name=self.catid2name,
classwise=self.classwise)
self.reset()
def reset(self):
self.result_bbox = []
self.detection_map.reset()
def update(self, inputs, outputs):
outs = {}
# outputs Tensor -> numpy.ndarray
for k, v in outputs.items():
outs[k] = v.numpy() if isinstance(v, paddle.Tensor) else v
im_id = inputs['im_id']
outs['im_id'] = im_id.numpy() if isinstance(im_id,
paddle.Tensor) else im_id
infer_results = get_infer_results(
outs, self.clsid2catid, bias=self.bias)
self.result_bbox += infer_results[
'bbox'] if 'bbox' in infer_results else []
bbox = [b['bbox'] for b in self.result_bbox]
score = [b['score'] for b in self.result_bbox]
label = [b['category_id'] for b in self.result_bbox]
label = [self.catid2clsid[e] for e in label]
gt_box = [
e['bbox'] for e in self.gt_anno['annotations']
if e['image_id'] == outs['im_id']
]
gt_label = [
e['category_id'] for e in self.gt_anno['annotations']
if e['image_id'] == outs['im_id']
]
gt_label = [self.catid2clsid[e] for e in gt_label]
self.detection_map.update(bbox, score, label, gt_box, gt_label)
def accumulate(self):
if len(self.result_bbox) > 0:
output = "bbox.json"
if self.output_eval:
output = os.path.join(self.output_eval, output)
with open(output, 'w') as f:
json.dump(self.result_bbox, f)
logger.info('The bbox result is saved to bbox.json.')
if self.save_prediction_only:
logger.info('The bbox result is saved to {} and do not '
'evaluate the mAP.'.format(output))
else:
logger.info("Accumulating evaluatation results...")
self.detection_map.accumulate()
def log(self):
map_stat = 100. * self.detection_map.get_map()
logger.info("mAP({:.2f}, {}) = {:.2f}%".format(self.overlap_thresh,
self.map_type, map_stat))
def get_results(self):
return {'bbox': [self.detection_map.get_map()]}
......@@ -618,8 +618,6 @@ class S2ANetHead(nn.Layer):
fam_cls_score = paddle.squeeze(fam_cls_score, axis=0)
fam_cls_score1 = fam_cls_score
# gt_classes 0~14(data), feat_labels 0~14, sigmoid_focal_loss need class>=1
feat_labels = feat_labels + 1
feat_labels = paddle.to_tensor(feat_labels)
feat_labels_one_hot = F.one_hot(feat_labels, self.cls_out_channels + 1)
feat_labels_one_hot = feat_labels_one_hot[:, 1:]
......@@ -681,9 +679,6 @@ class S2ANetHead(nn.Layer):
odm_cls_score = paddle.squeeze(odm_cls_score, axis=0)
odm_cls_score1 = odm_cls_score
# gt_classes 0~14(data), feat_labels 0~14, sigmoid_focal_loss need class>=1
# for debug 0426
feat_labels = feat_labels + 1
feat_labels = paddle.to_tensor(feat_labels)
feat_labels_one_hot = F.one_hot(feat_labels, self.cls_out_channels + 1)
feat_labels_one_hot = feat_labels_one_hot[:, 1:]
......@@ -833,10 +828,5 @@ class S2ANetHead(nn.Layer):
mlvl_bboxes = paddle.concat(mlvl_bboxes, axis=0)
mlvl_scores = paddle.concat(mlvl_scores)
if use_sigmoid_cls:
# Add a dummy background class to the front when using sigmoid
padding = paddle.zeros(
[mlvl_scores.shape[0], 1], dtype=mlvl_scores.dtype)
mlvl_scores = paddle.concat([padding, mlvl_scores], axis=1)
return mlvl_scores, mlvl_bboxes
......@@ -93,6 +93,9 @@ DATASETS = {
'roadsign_coco': ([(
'https://paddlemodels.bj.bcebos.com/object_detection/roadsign_coco.tar',
'49ce5a9b5ad0d6266163cd01de4b018e', ), ], ['annotations', 'images']),
'spine_coco': ([(
'https://paddledet.bj.bcebos.com/data/spine_coco.tar',
'03030f42d9b6202a6e425d4becefda0d', ), ], ['annotations', 'images']),
'mot': (),
'objects365': ()
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册