Unverified commit 9b67ef53, authored by Kaipeng Deng, committed by GitHub

Add PointRCNN model (#3967)

* add PointRCNN model
   by heavengate, FDInSky, tink2123
Parent commit: 3584aeec
*log*
checkpoints*
build
output
result_dir
pp_pointrcnn*
data/gt_database
utils/pts_utils/dist
utils/pts_utils/build
utils/pts_utils/pts_utils.egg-info
utils/cyops/*.c
utils/cyops/*.so
ext_op/src/*.o
ext_op/src/*.so
# PointRCNN: 3D Object Detection Model
---
## Contents
- [Introduction](#introduction)
- [Quick Start](#quick-start)
- [References](#references)
- [Release Notes](#release-notes)
## Introduction
[PointRCNN](https://arxiv.org/abs/1812.04244), proposed by Shaoshuai Shi, Xiaogang Wang, and Hongsheng Li, is the first two-stage 3D object detector that works directly on raw point clouds. The first stage uses PointNet++ with MSG (Multi-Scale Grouping) as its backbone to segment the raw point cloud into foreground and background points, and generates bounding-box proposals from the foreground points. The second stage filters and refines these bounding boxes in canonical coordinates. The paper also proposes a bin-based scheme that turns the regression problem into a classification problem and demonstrates its effectiveness for 3D bounding-box regression. PointRCNN was evaluated on the KITTI dataset and achieved the best performance on the KITTI 3D object detection leaderboard at the time of publication.
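The bin-based encoding is simple to state concretely. Below is a minimal NumPy sketch of the target construction, mirroring `get_reg_loss` in `models/loss_utils.py` with the RPN localization values from `cfgs/default.yml`; it is only an illustration, not part of the training code:
```
import numpy as np

loc_scope, loc_bin_size = 3.0, 0.5   # RPN localization values from cfgs/default.yml
x_offset = 1.23                      # a ground-truth x offset to encode

# classify the offset into a bin, then regress a normalized residual within that bin
shift = np.clip(x_offset + loc_scope, 0.0, loc_scope * 2 - 1e-3)
bin_label = int(shift / loc_bin_size)                                # classification target
res_label = shift - (bin_label * loc_bin_size + loc_bin_size / 2.0)
res_norm_label = res_label / loc_bin_size                            # regression target
print(bin_label, res_norm_label)     # -> 8, about -0.04
```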
The network architecture is shown below:
<p align="center">
<img src="images/teaser.png" height=300 width=800 hspace='10'/> <br />
PointRCNN, an object detector for point clouds
</p>
**Note:** PointRCNN relies on custom C++ operators, which can currently be compiled only for GPU devices on Linux/Unix systems. This model **cannot run on Windows or on CPU devices**.
## Quick Start
### Installation
**Install [PaddlePaddle](https://github.com/PaddlePaddle/Paddle):**
Running the sample code in this directory requires a PaddlePaddle Fluid [develop daily build](https://www.paddlepaddle.org.cn/install/doc/tables#多版本whl包列表-dev-11) or PaddlePaddle compiled from source on the [develop branch](https://github.com/PaddlePaddle/Paddle/tree/develop).
To keep the custom operators compatible with your Paddle version, we recommend **compiling Paddle from source**; see [Compile and Install](https://www.paddlepaddle.org.cn/install/doc/source/ubuntu) for instructions.
**Install PointRCNN:**
1. Download the [PaddlePaddle/models](https://github.com/PaddlePaddle/models) repository
Download the Paddle models repository with:
```
git clone https://github.com/PaddlePaddle/models
```
2. Download [pybind11](https://github.com/pybind/pybind11) into the `PaddleCV/Paddle3D/PointRCNN` directory
`pts_utils` is compiled against `pybind11`, so the `pybind11` repository must be cloned into the `PaddleCV/Paddle3D/PointRCNN` directory:
```
cd PaddleCV/Paddle3D/PointRCNN
git clone https://github.com/pybind/pybind11
```
3. Build and install the `pts_utils`, `kitti_utils`, `roipool3d_utils`, and `iou_utils` modules
Build and install the `pts_utils`, `kitti_utils`, `roipool3d_utils`, and `iou_utils` modules with:
```
sh build_and_install.sh
```
4. Install the Python dependencies
Install the Python dependencies with:
```
pip install -r requirement.txt
```
**Note:** The KITTI mAP evaluation tool requires Python 3.6 or later, and `scikit-image`, `Numba`, and `fire` must be installed in the Python 3 environment.
The `scikit-image`, `Numba`, and `fire` entries in `requirement.txt` are exactly the dependencies of the KITTI mAP evaluation tool.
### Compiling the Custom Operators
Make sure your PaddlePaddle is a Fluid develop daily build or was compiled from source on the Paddle develop branch; **compiling from source is recommended**.
Compile the custom operators as follows:
Enter the `ext_op/src` directory and run the build script:
```
cd ext_op/src
sh make.sh
```
After a successful build, `pointnet2_lib.so` is generated under `ext_op/src`.
Run the following to verify that the custom operators compiled correctly:
```
# add the Paddle dynamic library path to LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:`python -c 'import paddle; print(paddle.sysconfig.get_lib())'`
# go back to the ext_op directory and add it to PYTHONPATH
cd ..
export PYTHONPATH=$PYTHONPATH:`pwd`
# run the unit tests
python tests/test_farthest_point_sampling_op.py
python tests/test_gather_point_op.py
python tests/test_group_points_op.py
python tests/test_query_ball_op.py
python tests/test_three_interp_op.py
python tests/test_three_nn_op.py
```
Each unit test prints a message like the following on success:
```
.
----------------------------------------------------------------------
Ran 1 test in 13.205s
OK
```
**Note:** The custom-operator build is the same as in [PointNet++](../PointNet++); for more details, see [Custom Operator Compilation](../PointNet++/ext_op/README.md).
### Data Preparation
**KITTI 3D object detection dataset:**
PointRCNN is trained on the [KITTI 3D object detection](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d) dataset.
Download the dataset with:
```
cd data/KITTI/object
sh download.sh
```
The images here are used only for visualization. Training additionally uses the [road planes](https://drive.google.com/file/d/1d5mq0RXRnvHPVeKx6Q612z0YRO1t2wAp/view?usp=sharing) data for data augmentation;
please download and extract it into `./data/KITTI/object/training`.
The data directory layout is as follows:
```
PointRCNN
├── data
│ ├── KITTI
│ │ ├── ImageSets
│ │ ├── object
│ │ │ ├──training
│ │ │ │ ├──calib & velodyne & label_2 & image_2 & planes
│ │ │ ├──testing
│ │ │ │ ├──calib & velodyne & image_2
```
### Training
**PointRCNN model:**
Train the PointRCNN model as follows:
1. Select a single GPU and set the dynamic library path
```
# train on a single GPU
export CUDA_VISIBLE_DEVICES=0
# add the Paddle dynamic library path to LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:`python -c 'import paddle; print(paddle.sysconfig.get_lib())'`
```
2. Generate the ground-truth sampling database:
```
python tools/generate_gt_database.py --class_name 'Car' --split train
```
3. Train the RPN model
```
python train.py --cfg=./cfgs/default.yml \
--train_mode=rpn \
--batch_size=16 \
--epoch=200 \
--save_dir=checkpoints
```
RPN checkpoints are saved to `checkpoints/rpn` by default; the directory can be changed with `--save_dir`.
4. Generate augmented offline scene data and save the RPN model's output features and ROIs for offline RCNN training
Generate the augmented offline scene data with:
```
python tools/generate_aug_scene.py --class_name 'Car' --split train --aug_times 4
```
To save the RPN model's output features and ROIs on the offline augmented data, specify the directory of the final RPN weights with `--ckpt_dir`; RPN weights are saved to `checkpoints/rpn` by default.
When saving features and ROIs, set `TEST.SPLIT` to `train_aug`, `TEST.RPN_POST_NMS_TOP_N` to `300`, and `TEST.RPN_NMS_THRESH` to `0.85`.
Use `--output_dir` to choose where the features and ROIs are saved; the default is `./output`.
```
python eval.py --cfg=cfgs/default.yml \
--eval_mode=rpn \
--ckpt_dir=./checkpoints/rpn/199 \
--save_rpn_feature \
--output_dir=output \
--set TEST.SPLIT train_aug TEST.RPN_POST_NMS_TOP_N 300 TEST.RPN_NMS_THRESH 0.85
```
The data saved under `--output_dir` is laid out as follows:
```
output
├── detections
│ ├── data # ROI data
│ │ ├── 000000.txt
│ │ ├── 000003.txt
│ │ ├── ...
├── features # output features
│ ├── 000000_intensity.npy
│ ├── 000000.npy
│ ├── 000000_rawscore.npy
│ ├── 000000_seg.npy
│ ├── 000000_xyz.npy
│ ├── ...
├── seg_result # semantic segmentation results
│ ├── 000000.npy
│ ├── 000003.npy
│ ├── ...
```
5. Train the RCNN model offline, pointing `--rcnn_training_roi_dir` and `--rcnn_training_feature_dir` at the output features and ROIs saved by the RPN model.
```
python train.py --cfg=./cfgs/default.yml \
--train_mode=rcnn_offline \
--batch_size=4 \
--epoch=30 \
--save_dir=checkpoints \
--rcnn_training_roi_dir=output/detections/data \
--rcnn_training_feature_dir=output/features
```
RCNN weights are saved to `checkpoints/rcnn` by default; the directory can be changed with `--save_dir`.
**Note**: The best model is obtained by saving the RPN output features and ROIs and training the RCNN model offline on the augmented data; this is currently the only mode supported by default.
### Evaluation
**PointRCNN model:**
Evaluate the PointRCNN model as follows:
1. Select a single GPU and set the dynamic library path
```
# evaluate on a single GPU
export CUDA_VISIBLE_DEVICES=0
# add the Paddle dynamic library path to LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:`python -c 'import paddle; print(paddle.sysconfig.get_lib())'`
```
2. Save the RPN model's output features and ROIs on the evaluation data
Save the RPN model's output features and ROIs on the evaluation data with the command below; use `--ckpt_dir` to point at the final RPN weights, which are saved to `checkpoints/rpn` by default.
Use `--output_dir` to choose where the features and ROIs are saved; the default is `./output`.
```
python eval.py --cfg=cfgs/default.yml \
--eval_mode=rpn \
--ckpt_dir=./checkpoints/rpn/199 \
--save_rpn_feature \
--output_dir=output/val
```
The directory layout of the saved features and ROIs matches the layout of the offline augmented data described above.
3. Evaluate the offline RCNN model
Evaluate the offline RCNN model with:
```
python eval.py --cfg=cfgs/default.yml \
--eval_mode=rcnn_offline \
--ckpt_dir=./checkpoints/rcnn_offline/29 \
--rcnn_eval_roi_dir=output/val/detections/data \
--rcnn_eval_feature_dir=output/val/features \
--save_result
```
The final detection results are saved under the `final_result` folder in `./result_dir`; pass `--save_result` to also save the `roi_output` and `refine_output` result files.
The `result_dir` layout is as follows:
```
result_dir
├── final_result
│ ├── data # final detection results
│ │ ├── 000001.txt
│ │ ├── 000002.txt
│ │ ├── ...
├── roi_output
│ ├── data # detection ROIs output by the RCNN model
│ │ ├── 000001.txt
│ │ ├── 000002.txt
│ │ ├── ...
├── refine_output
│ ├── data # decoded detection results
│ │ ├── 000001.txt
│ │ ├── 000002.txt
│ │ ├── ...
```
4. Compute the evaluation results with the KITTI mAP tool
If evaluation runs under Python 3.6 or later, KITTI mAP evaluation runs automatically. If your Python version is below 3.6,
run the evaluation separately with a suitable Python version, since the KITTI mAP tool supports only Python 3.6+:
```
python3 kitti_map.py
```
Evaluation results with the final trained weights, [RPN model](https://paddlemodels.bj.bcebos.com/Paddle3D/pointrcnn_rpn.tar) and [RCNN model](https://paddlemodels.bj.bcebos.com/Paddle3D/pointrcnn_rcnn_offline.tar), are as follows:
| Car AP@ | 0.70(easy) | 0.70(moderate) | 0.70(hard) |
| :------- | :--------: | :------------: | :--------: |
| bbox AP: | 90.20 | 88.85 | 88.59 |
| bev AP: | 89.50 | 86.97 | 85.58 |
| 3d AP: | 86.66 | 76.65 | 75.90 |
| aos AP: | 90.10 | 88.64 | 88.26 |
## References
- [PointRCNN: 3D Object Proposal Generation and Detection From Point Cloud](https://arxiv.org/abs/1812.04244), Shaoshuai Shi, Xiaogang Wang, Hongsheng Li.
- [PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space](https://arxiv.org/abs/1706.02413), Charles R. Qi, Li Yi, Hao Su, Leonidas J. Guibas.
- [PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation](https://www.semanticscholar.org/paper/PointNet%3A-Deep-Learning-on-Point-Sets-for-3D-and-Qi-Su/d997beefc0922d97202789d2ac307c55c2c52fba), Charles Ruizhongtai Qi, Hao Su, Kaichun Mo, Leonidas J. Guibas.
## Release Notes
- 11/2019: Added the PointRCNN model.
# compile cyops
python utils/cyops/setup.py develop
# compile and install pts_utils
cd utils/pts_utils
python setup.py install
cd ../..
# This config is based on https://github.com/sshaoshuai/PointRCNN/blob/master/tools/cfgs/default.yaml
CLASSES: Car
INCLUDE_SIMILAR_TYPE: True
# config of augmentation
AUG_DATA: True
AUG_METHOD_LIST: ['rotation', 'scaling', 'flip']
AUG_METHOD_PROB: [1.0, 1.0, 0.5]
AUG_ROT_RANGE: 18
GT_AUG_ENABLED: True
GT_EXTRA_NUM: 15
GT_AUG_RAND_NUM: True
GT_AUG_APPLY_PROB: 1.0
GT_AUG_HARD_RATIO: 0.6
PC_REDUCE_BY_RANGE: True
PC_AREA_SCOPE: [[-40, 40], [-1, 3], [0, 70.4]] # x, y, z scope in rect camera coords
CLS_MEAN_SIZE: [[1.52563191462, 1.62856739989, 3.88311640418]]
# 1. config of rpn network
RPN:
ENABLED: True
FIXED: False
# config of input
USE_INTENSITY: False
# config of bin-based loss
LOC_XZ_FINE: True
LOC_SCOPE: 3.0
LOC_BIN_SIZE: 0.5
NUM_HEAD_BIN: 12
# config of network structure
BACKBONE: pointnet2_msg
USE_BN: True
NUM_POINTS: 16384
SA_CONFIG:
NPOINTS: [4096, 1024, 256, 64]
RADIUS: [[0.1, 0.5], [0.5, 1.0], [1.0, 2.0], [2.0, 4.0]]
NSAMPLE: [[16, 32], [16, 32], [16, 32], [16, 32]]
MLPS: [[[16, 16, 32], [32, 32, 64]],
[[64, 64, 128], [64, 96, 128]],
[[128, 196, 256], [128, 196, 256]],
[[256, 256, 512], [256, 384, 512]]]
FP_MLPS: [[128, 128], [256, 256], [512, 512], [512, 512]]
CLS_FC: [128]
REG_FC: [128]
DP_RATIO: 0.5
# config of training
LOSS_CLS: SigmoidFocalLoss
FG_WEIGHT: 15
FOCAL_ALPHA: [0.25, 0.75]
FOCAL_GAMMA: 2.0
REG_LOSS_WEIGHT: [1.0, 1.0, 1.0, 1.0]
LOSS_WEIGHT: [1.0, 1.0]
NMS_TYPE: normal
# config of testing
SCORE_THRESH: 0.3
# 2. config of rcnn network
RCNN:
ENABLED: True
# config of input
ROI_SAMPLE_JIT: False
REG_AUG_METHOD: multiple # multiple, single, normal
ROI_FG_AUG_TIMES: 10
USE_RPN_FEATURES: True
USE_MASK: True
MASK_TYPE: seg
USE_INTENSITY: False
USE_DEPTH: True
USE_SEG_SCORE: False
POOL_EXTRA_WIDTH: 1.0
# config of bin-based loss
LOC_SCOPE: 1.5
LOC_BIN_SIZE: 0.5
NUM_HEAD_BIN: 9
LOC_Y_BY_BIN: False
LOC_Y_SCOPE: 0.5
LOC_Y_BIN_SIZE: 0.25
SIZE_RES_ON_ROI: False
# config of network structure
USE_BN: False
DP_RATIO: 0.0
BACKBONE: pointnet # pointnet
XYZ_UP_LAYER: [128, 128]
NUM_POINTS: 512
SA_CONFIG:
NPOINTS: [128, 32, -1]
RADIUS: [0.2, 0.4, 100]
NSAMPLE: [64, 64, 64]
MLPS: [[128, 128, 128],
[128, 128, 256],
[256, 256, 512]]
CLS_FC: [256, 256]
REG_FC: [256, 256]
# config of training
LOSS_CLS: BinaryCrossEntropy
FOCAL_ALPHA: [0.25, 0.75]
FOCAL_GAMMA: 2.0
CLS_WEIGHT: [1.0, 1.0, 1.0]
CLS_FG_THRESH: 0.6
CLS_BG_THRESH: 0.45
CLS_BG_THRESH_LO: 0.05
REG_FG_THRESH: 0.55
FG_RATIO: 0.5
ROI_PER_IMAGE: 64
HARD_BG_RATIO: 0.8
# config of testing
SCORE_THRESH: 0.3
NMS_THRESH: 0.1
# general training config
TRAIN:
SPLIT: train
VAL_SPLIT: smallval
LR: 0.002
LR_CLIP: 0.00001
LR_DECAY: 0.5
DECAY_STEP_LIST: [100, 150, 180, 200]
LR_WARMUP: True
WARMUP_MIN: 0.0002
WARMUP_EPOCH: 1
BN_MOMENTUM: 0.1
BN_DECAY: 0.5
BNM_CLIP: 0.01
BN_DECAY_STEP_LIST: [1000]
OPTIMIZER: adam # adam, adam_onecycle
WEIGHT_DECAY: 0.001 # L2 regularization
MOMENTUM: 0.9
MOMS: [0.95, 0.85]
DIV_FACTOR: 10.0
PCT_START: 0.4
GRAD_NORM_CLIP: 1.0
RPN_PRE_NMS_TOP_N: 9000
RPN_POST_NMS_TOP_N: 512
RPN_NMS_THRESH: 0.85
RPN_DISTANCE_BASED_PROPOSE: True
TEST:
SPLIT: val
RPN_PRE_NMS_TOP_N: 9000
RPN_POST_NMS_TOP_N: 100
RPN_NMS_THRESH: 0.8
RPN_DISTANCE_BASED_PROPOSE: True
DIR="$( cd "$(dirname "$0")" ; pwd -P )"
cd "$DIR"
echo "Downloading https://s3.eu-central-1.amazonaws.com/avg-kitti/data_object_velodyne.zip"
wget https://s3.eu-central-1.amazonaws.com/avg-kitti/data_object_velodyne.zip
echo "https://s3.eu-central-1.amazonaws.com/avg-kitti/data_object_image_2.zip"
wget https://s3.eu-central-1.amazonaws.com/avg-kitti/data_object_image_2.zip
echo "https://s3.eu-central-1.amazonaws.com/avg-kitti/data_object_calib.zip"
wget https://s3.eu-central-1.amazonaws.com/avg-kitti/data_object_calib.zip
echo "https://s3.eu-central-1.amazonaws.com/avg-kitti/data_object_label_2.zip"
wget https://s3.eu-central-1.amazonaws.com/avg-kitti/data_object_label_2.zip
echo "Decompressing data_object_velodyne.zip"
unzip data_object_velodyne.zip
echo "Decompressing data_object_image_2.zip"
unzip "data_object_image_2.zip"
echo "Decompressing data_object_calib.zip"
unzip data_object_calib.zip
echo "Decompressing data_object_label_2.zip"
unzip data_object_label_2.zip
echo "Download KITTI ImageSets"
wget https://paddlemodels.bj.bcebos.com/Paddle3D/pointrcnn_kitti_imagesets.tar
tar xf pointrcnn_kitti_imagesets.tar
mv ImageSets ..
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
This code is based on https://github.com/sshaoshuai/PointRCNN/blob/master/lib/datasets/kitti_dataset.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import cv2
import numpy as np
import utils.calibration as calibration
from utils.object3d import get_objects_from_label
from PIL import Image
__all__ = ["KittiDataset"]
class KittiDataset(object):
def __init__(self, data_dir, split='train'):
assert split in ['train', 'train_aug', 'val', 'test'], "unknown split {}".format(split)
self.split = split
self.is_test = self.split == 'test'
self.imageset_dir = os.path.join(data_dir, 'KITTI', 'object', 'testing' if self.is_test else 'training')
split_dir = os.path.join(data_dir, 'KITTI', 'ImageSets', split + '.txt')
self.image_idx_list = [x.strip() for x in open(split_dir).readlines()]
self.num_sample = self.image_idx_list.__len__()
self.image_dir = os.path.join(self.imageset_dir, 'image_2')
self.lidar_dir = os.path.join(self.imageset_dir, 'velodyne')
self.calib_dir = os.path.join(self.imageset_dir, 'calib')
self.label_dir = os.path.join(self.imageset_dir, 'label_2')
self.plane_dir = os.path.join(self.imageset_dir, 'planes')
def get_image(self, idx):
img_file = os.path.join(self.image_dir, '%06d.png' % idx)
assert os.path.exists(img_file)
return cv2.imread(img_file) # (H, W, 3) BGR mode
def get_image_shape(self, idx):
img_file = os.path.join(self.image_dir, '%06d.png' % idx)
assert os.path.exists(img_file)
im = Image.open(img_file)
width, height = im.size
return height, width, 3
def get_lidar(self, idx):
lidar_file = os.path.join(self.lidar_dir, '%06d.bin' % idx)
assert os.path.exists(lidar_file)
return np.fromfile(lidar_file, dtype=np.float32).reshape(-1, 4)
def get_calib(self, idx):
calib_file = os.path.join(self.calib_dir, '%06d.txt' % idx)
assert os.path.exists(calib_file)
return calibration.Calibration(calib_file)
def get_label(self, idx):
label_file = os.path.join(self.label_dir, '%06d.txt' % idx)
assert os.path.exists(label_file)
# return kitti_utils.get_objects_from_label(label_file)
return get_objects_from_label(label_file)
def get_road_plane(self, idx):
plane_file = os.path.join(self.plane_dir, '%06d.txt' % idx)
with open(plane_file, 'r') as f:
lines = f.readlines()
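# the 4th line of the planes file holds the plane coefficients [a, b, c, d]
# of ax + by + cz + d = 0 in rectified camera coordinates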
lines = [float(i) for i in lines[3].split()]
plane = np.asarray(lines)
# Ensure the normal always faces up (in the rectified camera coordinate system)
if plane[1] > 0:
plane = -plane
norm = np.linalg.norm(plane[0:3])
plane = plane / norm
return plane
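# Usage sketch (illustrative, not part of the original file); assumes the
# ./data layout described in the README and that this module is data/kitti_dataset.py.
if __name__ == '__main__':
    dataset = KittiDataset(data_dir='./data', split='val')
    sample_idx = int(dataset.image_idx_list[0])
    pts = dataset.get_lidar(sample_idx)    # (N, 4) array: x, y, z, intensity
    calib = dataset.get_calib(sample_idx)  # calibration object for coordinate transforms
    objs = dataset.get_label(sample_idx)   # list of labeled 3D objects
    print(pts.shape, len(objs))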
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import shutil
import argparse
import logging
import multiprocessing
import numpy as np
from collections import OrderedDict
import paddle
import paddle.fluid as fluid
from models.point_rcnn import PointRCNN
from data.kitti_rcnn_reader import KittiRCNNReader
from utils.run_utils import *
from utils.config import cfg, load_config, set_config_from_list
from utils.metric_utils import calc_iou_recall, rpn_metric, rcnn_metric
logging.root.handlers = []
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
np.random.seed(1024) # use same seed
METRIC_PROC_NUM = 4
def parse_args():
parser = argparse.ArgumentParser(
"PointRCNN semantic segmentation train script")
parser.add_argument(
'--cfg',
type=str,
default='cfgs/default.yml',
help='specify the config for training')
parser.add_argument(
'--eval_mode',
type=str,
default='rpn',
required=True,
help='specify the evaluation mode')
parser.add_argument(
'--batch_size',
type=int,
default=1,
help='evaluation batch size, default 1')
parser.add_argument(
'--ckpt_dir',
type=str,
default='checkpoints/199',
help='specify a ckpt directory to be evaluated if needed')
parser.add_argument(
'--data_dir',
type=str,
default='./data',
help='KITTI dataset root directory')
parser.add_argument(
'--output_dir',
type=str,
default='output',
help='output directory')
parser.add_argument(
'--save_rpn_feature',
action='store_true',
default=False,
help='save features for separate rcnn training and evaluation')
parser.add_argument(
'--save_result',
action='store_true',
default=False,
help='save roi and refine result of evaluation')
parser.add_argument(
'--rcnn_eval_roi_dir',
type=str,
default=None,
help='specify the saved rois for rcnn evaluation when using rcnn_offline mode')
parser.add_argument(
'--rcnn_eval_feature_dir',
type=str,
default=None,
help='specify the saved features for rcnn evaluation when using rcnn_offline mode')
parser.add_argument(
'--log_interval',
type=int,
default=1,
help='mini-batch interval to log.')
parser.add_argument(
'--set',
dest='set_cfgs',
default=None,
nargs=argparse.REMAINDER,
help='set extra config keys if needed.')
args = parser.parse_args()
return args
def eval():
args = parse_args()
print_arguments(args)
# check whether the installed paddle is compiled with GPU
# PointRCNN model can only run on GPU
check_gpu(True)
load_config(args.cfg)
if args.set_cfgs is not None:
set_config_from_list(args.set_cfgs)
if not os.path.isdir(args.output_dir):
os.makedirs(args.output_dir)
if args.eval_mode == 'rpn':
cfg.RPN.ENABLED = True
cfg.RCNN.ENABLED = False
elif args.eval_mode == 'rcnn':
cfg.RCNN.ENABLED = True
cfg.RPN.ENABLED = cfg.RPN.FIXED = True
assert args.batch_size == 1, "batch size must be 1 in rcnn evaluation"
elif args.eval_mode == 'rcnn_offline':
cfg.RCNN.ENABLED = True
cfg.RPN.ENABLED = False
assert args.batch_size == 1, "batch size must be 1 in rcnn_offline evaluation"
else:
raise NotImplementedError("unknown eval mode: {}".format(args.eval_mode))
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
# build model
startup = fluid.Program()
eval_prog = fluid.Program()
with fluid.program_guard(eval_prog, startup):
with fluid.unique_name.guard():
eval_model = PointRCNN(cfg, args.batch_size, True, 'TEST')
eval_model.build()
eval_pyreader = eval_model.get_pyreader()
eval_feeds = eval_model.get_feeds()
eval_outputs = eval_model.get_outputs()
eval_prog = eval_prog.clone(True)
extra_keys = []
if args.eval_mode == 'rpn':
extra_keys.extend(['sample_id', 'rpn_cls_label', 'gt_boxes3d'])
if args.save_rpn_feature:
extra_keys.extend(['pts_rect', 'pts_features', 'pts_input',])
eval_keys, eval_values = parse_outputs(
eval_outputs, prog=eval_prog, extra_keys=extra_keys)
eval_compile_prog = fluid.compiler.CompiledProgram(
eval_prog).with_data_parallel()
exe.run(startup)
# load checkpoint
assert os.path.isdir(
args.ckpt_dir), "ckpt_dir {} not a directory".format(args.ckpt_dir)
def if_exist(var):
return os.path.exists(os.path.join(args.ckpt_dir, var.name))
fluid.io.load_vars(exe, args.ckpt_dir, eval_prog, predicate=if_exist)
kitti_feature_dir = os.path.join(args.output_dir, 'features')
kitti_output_dir = os.path.join(args.output_dir, 'detections', 'data')
seg_output_dir = os.path.join(args.output_dir, 'seg_result')
if args.save_rpn_feature:
if os.path.exists(kitti_feature_dir):
shutil.rmtree(kitti_feature_dir)
os.makedirs(kitti_feature_dir)
if os.path.exists(kitti_output_dir):
shutil.rmtree(kitti_output_dir)
os.makedirs(kitti_output_dir)
if os.path.exists(seg_output_dir):
shutil.rmtree(seg_output_dir)
os.makedirs(seg_output_dir)
# must make sure these dirs existing
roi_output_dir = os.path.join('./result_dir', 'roi_result', 'data')
refine_output_dir = os.path.join('./result_dir', 'refine_result', 'data')
final_output_dir = os.path.join("./result_dir", 'final_result', 'data')
if not os.path.exists(final_output_dir):
os.makedirs(final_output_dir)
if args.save_result:
if not os.path.exists(roi_output_dir):
os.makedirs(roi_output_dir)
if not os.path.exists(refine_output_dir):
os.makedirs(refine_output_dir)
# get reader
kitti_rcnn_reader = KittiRCNNReader(data_dir=args.data_dir,
npoints=cfg.RPN.NUM_POINTS,
split=cfg.TEST.SPLIT,
mode='EVAL',
classes=cfg.CLASSES,
rcnn_eval_roi_dir=args.rcnn_eval_roi_dir,
rcnn_eval_feature_dir=args.rcnn_eval_feature_dir)
eval_reader = kitti_rcnn_reader.get_multiprocess_reader(args.batch_size, eval_feeds)
eval_pyreader.decorate_sample_list_generator(eval_reader, place)
thresh_list = [0.1, 0.3, 0.5, 0.7, 0.9]
queue = multiprocessing.Queue(128)
mgr = multiprocessing.Manager()
lock = multiprocessing.Lock()
mdict = mgr.dict()
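# metric computation runs in METRIC_PROC_NUM worker processes: the fetch loop
# below pushes result batches into `queue`, and the workers accumulate shared
# counters in the manager dict `mdict`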
if cfg.RPN.ENABLED:
mdict['exit_proc'] = 0
mdict['total_gt_bbox'] = 0
mdict['total_cnt'] = 0
mdict['total_rpn_iou'] = 0
for i in range(len(thresh_list)):
mdict['total_recalled_bbox_list_{}'.format(i)] = 0
p_list = []
for i in range(METRIC_PROC_NUM):
p_list.append(multiprocessing.Process(
target=rpn_metric,
args=(queue, mdict, lock, thresh_list, args.save_rpn_feature, kitti_feature_dir,
seg_output_dir, kitti_output_dir, kitti_rcnn_reader, cfg.CLASSES)))
p_list[-1].start()
if cfg.RCNN.ENABLED:
for i in range(len(thresh_list)):
mdict['total_recalled_bbox_list_{}'.format(i)] = 0
mdict['total_roi_recalled_bbox_list_{}'.format(i)] = 0
mdict['exit_proc'] = 0
mdict['total_cls_acc'] = 0
mdict['total_cls_acc_refined'] = 0
mdict['total_det_num'] = 0
mdict['total_gt_bbox'] = 0
p_list = []
for i in range(METRIC_PROC_NUM):
p_list.append(multiprocessing.Process(
target=rcnn_metric,
args=(queue, mdict, lock, thresh_list, kitti_rcnn_reader, roi_output_dir,
refine_output_dir, final_output_dir, args.save_result)
))
p_list[-1].start()
try:
eval_pyreader.start()
eval_iter = 0
start_time = time.time()
cur_time = time.time()
while True:
eval_outs = exe.run(eval_compile_prog, fetch_list=eval_values, return_numpy=False)
rets_dict = {k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(eval_keys, eval_outs)}
run_time = time.time() - cur_time
cur_time = time.time()
queue.put(rets_dict)
eval_iter += 1
logger.info("[EVAL] iter {}, time: {:.2f}".format(
eval_iter, run_time))
except fluid.core.EOFException:
# terminate metric process
for i in range(METRIC_PROC_NUM):
queue.put(None)
while mdict['exit_proc'] < METRIC_PROC_NUM:
time.sleep(1)
for p in p_list:
if p.is_alive():
p.join()
end_time = time.time()
logger.info("[EVAL] total {} iter finished, average time: {:.2f}".format(
eval_iter, (end_time - start_time) / float(eval_iter)))
if cfg.RPN.ENABLED:
avg_rpn_iou = mdict['total_rpn_iou'] / max(len(kitti_rcnn_reader), 1.)
logger.info("average rpn iou: {:.3f}".format(avg_rpn_iou))
total_gt_bbox = float(max(mdict['total_gt_bbox'], 1.0))
for idx, thresh in enumerate(thresh_list):
recall = mdict['total_recalled_bbox_list_{}'.format(idx)] / total_gt_bbox
logger.info("total bbox recall(thresh={:.3f}): {} / {} = {:.3f}".format(
thresh, mdict['total_recalled_bbox_list_{}'.format(idx)], mdict['total_gt_bbox'], recall))
if cfg.RCNN.ENABLED:
cnt = float(max(eval_iter, 1.0))
avg_cls_acc = mdict['total_cls_acc'] / cnt
avg_cls_acc_refined = mdict['total_cls_acc_refined'] / cnt
avg_det_num = mdict['total_det_num'] / cnt
logger.info("avg_cls_acc: {}".format(avg_cls_acc))
logger.info("avg_cls_acc_refined: {}".format(avg_cls_acc_refined))
logger.info("avg_det_num: {}".format(avg_det_num))
total_gt_bbox = float(max(mdict['total_gt_bbox'], 1.0))
for idx, thresh in enumerate(thresh_list):
cur_roi_recall = mdict['total_roi_recalled_bbox_list_{}'.format(idx)] / total_gt_bbox
logger.info('total roi bbox recall(thresh=%.3f): %d / %d = %f' % (
thresh, mdict['total_roi_recalled_bbox_list_{}'.format(idx)], total_gt_bbox, cur_roi_recall))
for idx, thresh in enumerate(thresh_list):
cur_recall = mdict['total_recalled_bbox_list_{}'.format(idx)] / total_gt_bbox
logger.info('total bbox recall(thresh=%.2f) %d / %.2f = %.4f' % (
thresh, mdict['total_recalled_bbox_list_{}'.format(idx)], total_gt_bbox, cur_recall))
split_file = os.path.join('./data/KITTI', 'ImageSets', 'val.txt')
image_idx_list = [x.strip() for x in open(split_file).readlines()]
for k in range(image_idx_list.__len__()):
cur_file = os.path.join(final_output_dir, '%s.txt' % image_idx_list[k])
if not os.path.exists(cur_file):
with open(cur_file, 'w') as temp_f:
pass
if sys.version_info >= (3, 6):
label_dir = os.path.join('./data/KITTI/object/training', 'label_2')
split_file = os.path.join('./data/KITTI', 'ImageSets', 'val.txt')
final_output_dir = os.path.join("./result_dir", 'final_result', 'data')
name_to_class = {'Car': 0, 'Pedestrian': 1, 'Cyclist': 2}
from tools.kitti_object_eval_python.evaluate import evaluate as kitti_evaluate
ap_result_str, ap_dict = kitti_evaluate(
label_dir, final_output_dir, label_split_file=split_file,
current_class=name_to_class["Car"])
logger.info("KITTI evaluate: {}, {}".format(ap_result_str, ap_dict))
else:
logger.info("KITTI mAP only support python version >= 3.6, users can "
"run 'python3 tools/kitti_eval.py' to evaluate KITTI mAP.")
finally:
eval_pyreader.reset()
if __name__ == "__main__":
eval()
../PointNet++/ext_op
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant
__all__ = ["get_reg_loss"]
def sigmoid_focal_loss(logits, labels, weights, gamma=2.0, alpha=0.25):
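# focal loss (Lin et al., 2017): per-element sigmoid cross-entropy scaled by a
# modulating factor (1 - p_t)^gamma and a class-balancing factor alpha_t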
sce_loss = fluid.layers.sigmoid_cross_entropy_with_logits(logits, labels)
prob = fluid.layers.sigmoid(logits)
p_t = labels * prob + (1.0 - labels) * (1.0 - prob)
modulating_factor = fluid.layers.pow(1.0 - p_t, gamma)
alpha_weight_factor = labels * alpha + (1.0 - labels) * (1.0 - alpha)
return modulating_factor * alpha_weight_factor * sce_loss * weights
def get_reg_loss(pred_reg, reg_label, fg_mask, point_num, loc_scope,
loc_bin_size, num_head_bin, anchor_size,
get_xz_fine=True, get_y_by_bin=False, loc_y_scope=0.5,
loc_y_bin_size=0.25, get_ry_fine=False):
"""
Bin-based 3D bounding boxes regression loss. See https://arxiv.org/abs/1812.04244 for more details.
:param pred_reg: (N, C)
:param reg_label: (N, 7) [dx, dy, dz, h, w, l, ry]
:param loc_scope: constant
:param loc_bin_size: constant
:param num_head_bin: constant
:param anchor_size: (N, 3) or (3)
:param get_xz_fine:
:param get_y_by_bin:
:param loc_y_scope:
:param loc_y_bin_size:
:param get_ry_fine:
:return:
"""
fg_num = fluid.layers.cast(fluid.layers.reduce_sum(fg_mask), dtype=pred_reg.dtype)
fg_num = fluid.layers.clip(fg_num, min=1.0, max=point_num)
fg_scale = float(point_num) / fg_num
per_loc_bin_num = int(loc_scope / loc_bin_size) * 2
loc_y_bin_num = int(loc_y_scope / loc_y_bin_size) * 2
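# pred_reg channel layout, consumed slice by slice below:
# [x-bin | z-bin | x-res | z-res | y (bin + res, or a single offset) | ry-bin | ry-res | 3 size residuals]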
reg_loss_dict = {}
# xz localization loss
x_offset_label, y_offset_label, z_offset_label = reg_label[:, 0:1], reg_label[:, 1:2], reg_label[:, 2:3]
x_shift = fluid.layers.clip(x_offset_label + loc_scope, 0., loc_scope * 2 - 1e-3)
z_shift = fluid.layers.clip(z_offset_label + loc_scope, 0., loc_scope * 2 - 1e-3)
x_bin_label = fluid.layers.cast(x_shift / loc_bin_size, dtype='int64')
z_bin_label = fluid.layers.cast(z_shift / loc_bin_size, dtype='int64')
x_bin_l, x_bin_r = 0, per_loc_bin_num
z_bin_l, z_bin_r = per_loc_bin_num, per_loc_bin_num * 2
start_offset = z_bin_r
loss_x_bin = fluid.layers.softmax_with_cross_entropy(pred_reg[:, x_bin_l: x_bin_r], x_bin_label)
loss_x_bin = fluid.layers.reduce_mean(loss_x_bin * fg_mask) * fg_scale
loss_z_bin = fluid.layers.softmax_with_cross_entropy(pred_reg[:, z_bin_l: z_bin_r], z_bin_label)
loss_z_bin = fluid.layers.reduce_mean(loss_z_bin * fg_mask) * fg_scale
reg_loss_dict['loss_x_bin'] = loss_x_bin
reg_loss_dict['loss_z_bin'] = loss_z_bin
loc_loss = loss_x_bin + loss_z_bin
if get_xz_fine:
x_res_l, x_res_r = per_loc_bin_num * 2, per_loc_bin_num * 3
z_res_l, z_res_r = per_loc_bin_num * 3, per_loc_bin_num * 4
start_offset = z_res_r
x_res_label = x_shift - (fluid.layers.cast(x_bin_label, dtype=x_shift.dtype) * loc_bin_size + loc_bin_size / 2.)
z_res_label = z_shift - (fluid.layers.cast(z_bin_label, dtype=z_shift.dtype) * loc_bin_size + loc_bin_size / 2.)
x_res_norm_label = x_res_label / loc_bin_size
z_res_norm_label = z_res_label / loc_bin_size
x_bin_onehot = fluid.layers.one_hot(x_bin_label, depth=per_loc_bin_num)
z_bin_onehot = fluid.layers.one_hot(z_bin_label, depth=per_loc_bin_num)
loss_x_res = fluid.layers.smooth_l1(fluid.layers.reduce_sum(pred_reg[:, x_res_l: x_res_r] * x_bin_onehot, dim=1, keep_dim=True), x_res_norm_label)
loss_x_res = fluid.layers.reduce_mean(loss_x_res * fg_mask) * fg_scale
loss_z_res = fluid.layers.smooth_l1(fluid.layers.reduce_sum(pred_reg[:, z_res_l: z_res_r] * z_bin_onehot, dim=1, keep_dim=True), z_res_norm_label)
loss_z_res = fluid.layers.reduce_mean(loss_z_res * fg_mask) * fg_scale
reg_loss_dict['loss_x_res'] = loss_x_res
reg_loss_dict['loss_z_res'] = loss_z_res
loc_loss += loss_x_res + loss_z_res
# y localization loss
if get_y_by_bin:
y_bin_l, y_bin_r = start_offset, start_offset + loc_y_bin_num
y_res_l, y_res_r = y_bin_r, y_bin_r + loc_y_bin_num
start_offset = y_res_r
y_shift = fluid.layers.clip(y_offset_label + loc_y_scope, 0., loc_y_scope * 2 - 1e-3)
y_bin_label = fluid.layers.cast(y_shift / loc_y_bin_size, dtype='int64')
y_res_label = y_shift - (fluid.layers.cast(y_bin_label, dtype=y_shift.dtype) * loc_y_bin_size + loc_y_bin_size / 2.)
y_res_norm_label = y_res_label / loc_y_bin_size
y_bin_onehot = fluid.layers.one_hot(y_bin_label, depth=loc_y_bin_num)  # depth must match the number of y bins
loss_y_bin = fluid.layers.softmax_with_cross_entropy(pred_reg[:, y_bin_l: y_bin_r], y_bin_label)
loss_y_bin = fluid.layers.reduce_mean(loss_y_bin * fg_mask) * fg_scale
loss_y_res = fluid.layers.smooth_l1(fluid.layers.reduce_sum(pred_reg[:, y_res_l: y_res_r] * y_bin_onehot, dim=1, keep_dim=True), y_res_norm_label)
loss_y_res = fluid.layers.reduce_mean(loss_y_res * fg_mask) * fg_scale
reg_loss_dict['loss_y_bin'] = loss_y_bin
reg_loss_dict['loss_y_res'] = loss_y_res
loc_loss += loss_y_bin + loss_y_res
else:
y_offset_l, y_offset_r = start_offset, start_offset + 1
start_offset = y_offset_r
loss_y_offset = fluid.layers.smooth_l1(fluid.layers.reduce_sum(pred_reg[:, y_offset_l: y_offset_r], dim=1, keep_dim=True), y_offset_label)
loss_y_offset = fluid.layers.reduce_mean(loss_y_offset * fg_mask) * fg_scale
reg_loss_dict['loss_y_offset'] = loss_y_offset
loc_loss += loss_y_offset
# angle loss
ry_bin_l, ry_bin_r = start_offset, start_offset + num_head_bin
ry_res_l, ry_res_r = ry_bin_r, ry_bin_r + num_head_bin
ry_label = reg_label[:, 6:7]
if get_ry_fine:
# divide pi/2 into several bins
angle_per_class = (np.pi / 2) / num_head_bin
ry_label = ry_label % (2 * np.pi) # 0 ~ 2pi
opposite_flag = fluid.layers.logical_and(ry_label > np.pi * 0.5, ry_label < np.pi * 1.5)
opposite_flag = fluid.layers.cast(opposite_flag, dtype=ry_label.dtype)
shift_angle = (ry_label + opposite_flag * np.pi + np.pi * 0.5) % (2 * np.pi) # (0 ~ pi)
shift_angle.stop_gradient = True
shift_angle = fluid.layers.clip(shift_angle - np.pi * 0.25, min=1e-3, max=np.pi * 0.5 - 1e-3) # (0, pi/2)
# bin center is (5, 10, 15, ..., 85)
ry_bin_label = fluid.layers.cast(shift_angle / angle_per_class, dtype='int64')
ry_res_label = shift_angle - (fluid.layers.cast(ry_bin_label, dtype=shift_angle.dtype) * angle_per_class + angle_per_class / 2)
ry_res_norm_label = ry_res_label / (angle_per_class / 2)
else:
# divide 2pi into several bins
angle_per_class = (2 * np.pi) / num_head_bin
heading_angle = ry_label % (2 * np.pi) # 0 ~ 2pi
shift_angle = (heading_angle + angle_per_class / 2) % (2 * np.pi)
shift_angle.stop_gradient = True
ry_bin_label = fluid.layers.cast(shift_angle / angle_per_class, dtype='int64')
ry_res_label = shift_angle - (fluid.layers.cast(ry_bin_label, dtype=shift_angle.dtype) * angle_per_class + angle_per_class / 2)
ry_res_norm_label = ry_res_label / (angle_per_class / 2)
ry_bin_onehot = fluid.layers.one_hot(ry_bin_label, depth=num_head_bin)
loss_ry_bin = fluid.layers.softmax_with_cross_entropy(pred_reg[:, ry_bin_l:ry_bin_r], ry_bin_label)
loss_ry_bin = fluid.layers.reduce_mean(loss_ry_bin * fg_mask) * fg_scale
loss_ry_res = fluid.layers.smooth_l1(fluid.layers.reduce_sum(pred_reg[:, ry_res_l: ry_res_r] * ry_bin_onehot, dim=1, keep_dim=True), ry_res_norm_label)
loss_ry_res = fluid.layers.reduce_mean(loss_ry_res * fg_mask) * fg_scale
reg_loss_dict['loss_ry_bin'] = loss_ry_bin
reg_loss_dict['loss_ry_res'] = loss_ry_res
angle_loss = loss_ry_bin + loss_ry_res
# size loss
size_res_l, size_res_r = ry_res_r, ry_res_r + 3
assert pred_reg.shape[1] == size_res_r, '%d vs %d' % (pred_reg.shape[1], size_res_r)
anchor_size_var = fluid.layers.zeros(shape=[3], dtype=reg_label.dtype)
fluid.layers.assign(np.array(anchor_size).astype('float32'), anchor_size_var)
size_res_norm_label = (reg_label[:, 3:6] - anchor_size_var) / anchor_size_var
size_res_norm_label = fluid.layers.reshape(size_res_norm_label, shape=[-1, 1], inplace=True)
size_res_norm = pred_reg[:, size_res_l:size_res_r]
size_res_norm = fluid.layers.reshape(size_res_norm, shape=[-1, 1], inplace=True)
size_loss = fluid.layers.smooth_l1(size_res_norm, size_res_norm_label)
size_loss = fluid.layers.reduce_mean(fluid.layers.reshape(size_loss, [-1, 3]) * fg_mask) * fg_scale
# Total regression loss
reg_loss_dict['loss_loc'] = loc_loss
reg_loss_dict['loss_angle'] = angle_loss
reg_loss_dict['loss_size'] = size_loss
return loc_loss, angle_loss, size_loss, reg_loss_dict
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from collections import OrderedDict
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant
from models.rpn import RPN
from models.rcnn import RCNN
__all__ = ["PointRCNN"]
class PointRCNN(object):
def __init__(self, cfg, batch_size, use_xyz=True, mode='TRAIN', prog=None):
self.cfg = cfg
self.batch_size = batch_size
self.use_xyz = use_xyz
self.mode = mode
self.is_train = mode == 'TRAIN'
self.num_points = self.cfg.RPN.NUM_POINTS
self.prog = prog
self.inputs = None
self.pyreader = None
def build_inputs(self):
self.inputs = OrderedDict()
if self.cfg.RPN.ENABLED:
self.inputs['sample_id'] = fluid.layers.data(name='sample_id', shape=[1], dtype='int32')
self.inputs['pts_input'] = fluid.layers.data(name='pts_input', shape=[self.num_points, 3], dtype='float32')
self.inputs['pts_rect'] = fluid.layers.data(name='pts_rect', shape=[self.num_points, 3], dtype='float32')
self.inputs['pts_features'] = fluid.layers.data(name='pts_features', shape=[self.num_points, 1], dtype='float32')
self.inputs['rpn_cls_label'] = fluid.layers.data(name='rpn_cls_label', shape=[self.num_points], dtype='int32')
self.inputs['rpn_reg_label'] = fluid.layers.data(name='rpn_reg_label', shape=[self.num_points, 7], dtype='float32')
self.inputs['gt_boxes3d'] = fluid.layers.data(name='gt_boxes3d', shape=[7], lod_level=1, dtype='float32')
if self.cfg.RCNN.ENABLED:
if self.cfg.RCNN.ROI_SAMPLE_JIT:
self.inputs['sample_id'] = fluid.layers.data(name='sample_id', shape=[1], dtype='int32', append_batch_size=False)
self.inputs['rpn_xyz'] = fluid.layers.data(name='rpn_xyz', shape=[self.num_points, 3], dtype='float32', append_batch_size=False)
self.inputs['rpn_features'] = fluid.layers.data(name='rpn_features', shape=[self.num_points,128], dtype='float32', append_batch_size=False)
self.inputs['rpn_intensity'] = fluid.layers.data(name='rpn_intensity', shape=[self.num_points], dtype='float32', append_batch_size=False)
self.inputs['seg_mask'] = fluid.layers.data(name='seg_mask', shape=[self.num_points], dtype='float32', append_batch_size=False)
self.inputs['roi_boxes3d'] = fluid.layers.data(name='roi_boxes3d', shape=[-1, -1, 7], dtype='float32', append_batch_size=False, lod_level=0)
self.inputs['pts_depth'] = fluid.layers.data(name='pts_depth', shape=[self.num_points], dtype='float32', append_batch_size=False)
self.inputs['gt_boxes3d'] = fluid.layers.data(name='gt_boxes3d', shape=[-1, -1, 7], dtype='float32', append_batch_size=False, lod_level=0)
else:
self.inputs['sample_id'] = fluid.layers.data(name='sample_id', shape=[-1], dtype='int32', append_batch_size=False)
self.inputs['pts_input'] = fluid.layers.data(name='pts_input', shape=[-1,512,133], dtype='float32', append_batch_size=False)
self.inputs['pts_feature'] = fluid.layers.data(name='pts_feature', shape=[-1,512,128], dtype='float32', append_batch_size=False)
self.inputs['roi_boxes3d'] = fluid.layers.data(name='roi_boxes3d', shape=[-1,7], dtype='float32', append_batch_size=False)
if self.is_train:
self.inputs['cls_label'] = fluid.layers.data(name='cls_label', shape=[-1], dtype='float32', append_batch_size=False)
self.inputs['reg_valid_mask'] = fluid.layers.data(name='reg_valid_mask', shape=[-1], dtype='float32', append_batch_size=False)
self.inputs['gt_boxes3d_ct'] = fluid.layers.data(name='gt_boxes3d_ct', shape=[-1,7], dtype='float32', append_batch_size=False)
self.inputs['gt_of_rois'] = fluid.layers.data(name='gt_of_rois', shape=[-1,7], dtype='float32', append_batch_size=False)
else:
self.inputs['roi_scores'] = fluid.layers.data(name='roi_scores', shape=[-1,], dtype='float32', append_batch_size=False)
self.inputs['gt_iou'] = fluid.layers.data(name='gt_iou', shape=[-1], dtype='float32', append_batch_size=False)
self.inputs['gt_boxes3d'] = fluid.layers.data(name='gt_boxes3d', shape=[-1,-1,7], dtype='float32', append_batch_size=False, lod_level=0)
self.pyreader = fluid.io.PyReader(
feed_list=list(self.inputs.values()),
capacity=64,
use_double_buffer=True,
iterable=False)
def build(self):
self.build_inputs()
if self.cfg.RPN.ENABLED:
self.rpn = RPN(self.cfg, self.batch_size, self.use_xyz,
self.mode, self.prog)
self.rpn.build(self.inputs)
self.rpn_outputs = self.rpn.get_outputs()
self.outputs = self.rpn_outputs
if self.cfg.RCNN.ENABLED:
self.rcnn = RCNN(self.cfg, 1, self.batch_size, self.mode)
self.rcnn.build_model(self.inputs)
self.outputs = self.rcnn.get_outputs()
if self.mode == 'TRAIN':
if self.cfg.RPN.ENABLED:
self.outputs['rpn_loss'], self.outputs['rpn_loss_cls'], \
self.outputs['rpn_loss_reg'] = self.rpn.get_loss()
if self.cfg.RCNN.ENABLED:
self.outputs['rcnn_loss'], self.outputs['rcnn_loss_cls'], \
self.outputs['rcnn_loss_reg'] = self.rcnn.get_loss()
self.outputs['loss'] = self.outputs.get('rpn_loss', 0.) \
+ self.outputs.get('rcnn_loss', 0.)
def get_feeds(self):
return list(self.inputs.keys())
def get_outputs(self):
return self.outputs
def get_loss(self):
rpn_loss, _, _ = self.rpn.get_loss()
rcnn_loss, _, _ = self.rcnn.get_loss()
return rpn_loss + rcnn_loss
def get_pyreader(self):
return self.pyreader
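# Usage sketch (illustrative, mirrors eval.py): build the network for RPN evaluation.
if __name__ == '__main__':
    from utils.config import cfg, load_config
    load_config('cfgs/default.yml')
    cfg.RPN.ENABLED, cfg.RCNN.ENABLED = True, False  # rpn-only evaluation, as in eval.py
    startup, eval_prog = fluid.Program(), fluid.Program()
    with fluid.program_guard(eval_prog, startup):
        with fluid.unique_name.guard():
            model = PointRCNN(cfg, batch_size=1, use_xyz=True, mode='TEST')
            model.build()
            outputs = model.get_outputs()
    eval_prog = eval_prog.clone(True)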
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
Contains PointNet++ utility functions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant
from ext_op import *
__all__ = ["conv_bn", "pointnet_sa_module", "pointnet_fp_module", "MLP"]
def query_and_group(xyz, new_xyz, radius, nsample, features=None, use_xyz=True):
"""
Perform query_ball and group_points
Args:
xyz (Variable): xyz coordiantes features with shape [B, N, 3]
new_xyz (Variable): centriods features with shape [B, npoint, 3]
radius (float32): radius of ball
nsample (int32): maximum number of gather features
features (Variable): features with shape [B, C, N]
use_xyz (bool): whether use xyz coordiantes features
Returns:
out (Variable): features with shape [B, C + 3, npoint, nsample]
"""
idx = query_ball(xyz, new_xyz, radius, nsample)
idx.stop_gradient = True
xyz = fluid.layers.transpose(xyz,perm=[0, 2, 1])
grouped_xyz = group_points(xyz, idx)
expand_new_xyz = fluid.layers.unsqueeze(fluid.layers.transpose(new_xyz, perm=[0, 2, 1]), axes=[-1])
expand_new_xyz = fluid.layers.expand(expand_new_xyz, [1, 1, 1, grouped_xyz.shape[3]])
grouped_xyz -= expand_new_xyz
if features is not None:
grouped_features = group_points(features, idx)
return fluid.layers.concat([grouped_xyz, grouped_features], axis=1) \
if use_xyz else grouped_features
else:
assert use_xyz, "use_xyz should be True when features is None"
return grouped_xyz
def group_all(xyz, features=None, use_xyz=True):
"""
Group all xyz and features when npoint is None
See query_and_group
"""
xyz = fluid.layers.transpose(xyz,perm=[0, 2, 1])
grouped_xyz = fluid.layers.unsqueeze(xyz, axes=[2])
if features is not None:
grouped_features = fluid.layers.unsqueeze(features, axes=[2])
return fluid.layers.concat([grouped_xyz, grouped_features], axis=1) if use_xyz else grouped_features
else:
return grouped_xyz
def conv_bn(input, out_channels, bn=True, bn_momentum=0.95, act='relu', name=None):
param_attr = ParamAttr(name='{}_conv_weight'.format(name),)
bias_attr = ParamAttr(name='{}_conv_bias'.format(name)) \
if not bn else False
out = fluid.layers.conv2d(input,
num_filters=out_channels,
filter_size=1,
stride=1,
padding=0,
dilation=1,
param_attr=param_attr,
bias_attr=bias_attr,
act=act if not bn else None)
if bn:
bn_name = name + "_bn"
out = fluid.layers.batch_norm(out,
act=act,
momentum=bn_momentum,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(name=bn_name + "_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_var')
return out
def MLP(features, out_channels_list, bn=True, bn_momentum=0.95, act='relu', name=None):
out = features
for i, out_channels in enumerate(out_channels_list):
out = conv_bn(out, out_channels, bn=bn, act=act, bn_momentum=bn_momentum, name=name + "_{}".format(i))
return out
def pointnet_sa_module(xyz,
npoint=None,
radiuss=[],
nsamples=[],
mlps=[],
feature=None,
bn=True,
bn_momentum=0.95,
use_xyz=True,
name=None):
"""
PointNet MSG(Multi-Scale Group) Set Abstraction Module.
Call with radiuss, nsamples, mlps as single element list for
SSG(Single-Scale Group).
Args:
xyz (Variable): xyz coordiantes features with shape [B, N, 3]
radiuss ([float32]): list of radius of ball
nsamples ([int32]): list of maximum number of gather features
mlps ([[int32]]): list of out_channels_list
feature (Variable): features with shape [B, C, N]
bn (bool): whether perform batch norm after conv2d
bn_momentum (float): momentum of batch norm
use_xyz (bool): whether use xyz coordiantes features
Returns:
new_xyz (Variable): centriods features with shape [B, npoint, 3]
out (Variable): features with shape [B, \sum_i{mlps[i][-1]}, npoint]
"""
assert len(radiuss) == len(nsamples) == len(mlps), \
"radiuss, nsamples, mlps length should be same"
farthest_idx = farthest_point_sampling(xyz, npoint)
farthest_idx.stop_gradient = True
new_xyz = gather_point(xyz, farthest_idx) if npoint is not None else None
outs = []
for i, (radius, nsample, mlp) in enumerate(zip(radiuss, nsamples, mlps)):
out = query_and_group(xyz, new_xyz, radius, nsample, feature, use_xyz) if npoint is not None else group_all(xyz, feature, use_xyz)
out = MLP(out, mlp, bn=bn, bn_momentum=bn_momentum, name=name + '_mlp{}'.format(i))
out = fluid.layers.pool2d(out, pool_size=[1, out.shape[3]], pool_type='max')
out = fluid.layers.squeeze(out, axes=[-1])
outs.append(out)
out = fluid.layers.concat(outs, axis=1)
return (new_xyz, out)
def pointnet_fp_module(unknown, known, unknown_feats, known_feats, mlp, bn=True, bn_momentum=0.95, name=None):
"""
PointNet Feature Propagation Module
Args:
unknown (Variable): unknown xyz coordiantes features with shape [B, N, 3]
known (Variable): known xyz coordiantes features with shape [B, M, 3]
unknown_feats (Variable): unknown features with shape [B, N, C1] to be propagated to
known_feats (Variable): known features with shape [B, M, C2] to be propagated from
mlp ([int32]): out_channels_list
bn (bool): whether perform batch norm after conv2d
Returns:
new_features (Variable): new features with shape [B, N, mlp[-1]]
"""
if known is None:
raise NotImplementedError("Not implement known as None currently.")
else:
dist, idx = three_nn(unknown, known, eps=0.)
dist.stop_gradient = True
idx.stop_gradient = True
dist = fluid.layers.sqrt(dist)
ones = fluid.layers.fill_constant_batch_size_like(dist, dist.shape, dist.dtype, 1)
dist_recip = ones / (dist + 1e-8)  # 1.0 / dist
norm = fluid.layers.reduce_sum(dist_recip, dim=-1, keep_dim=True)
weight = dist_recip / norm
weight.stop_gradient = True
interp_feats = three_interp(known_feats, weight, idx)
new_features = interp_feats if unknown_feats is None else \
fluid.layers.concat([interp_feats, unknown_feats], axis=-1)
new_features = fluid.layers.transpose(new_features, perm=[0, 2, 1])
new_features = fluid.layers.unsqueeze(new_features, axes=[-1])
new_features = MLP(new_features, mlp, bn=bn, bn_momentum=bn_momentum, name=name + '_mlp')
new_features = fluid.layers.squeeze(new_features, axes=[-1])
new_features = fluid.layers.transpose(new_features, perm=[0, 2, 1])
return new_features
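# Usage sketch (illustrative): one MSG set-abstraction level, configured like the
# first level of RPN.SA_CONFIG in cfgs/default.yml.
if __name__ == '__main__':
    xyz = fluid.layers.data(name='xyz', shape=[16384, 3], dtype='float32')
    new_xyz, new_feature = pointnet_sa_module(
        xyz=xyz,
        npoint=4096,
        radiuss=[0.1, 0.5],
        nsamples=[16, 32],
        mlps=[[16, 16, 32], [32, 32, 64]],
        feature=None,
        name="sa_0")
    # new_xyz: [B, 4096, 3]; new_feature concatenates the per-scale MLP outputs
    # (32 + 64 = 96 channels) over the 4096 sampled centroids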
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
Contains PointNet++ SSG/MSG semantic segmentation models
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant
from models.pointnet2_modules import *
__all__ = ["PointNet2MSG"]
class PointNet2MSG(object):
def __init__(self, cfg, xyz, feature=None, use_xyz=True):
self.cfg = cfg
self.xyz = xyz
self.feature = feature
self.use_xyz = use_xyz
self.model_config()
def model_config(self):
self.SA_confs = []
for i in range(self.cfg.RPN.SA_CONFIG.NPOINTS.__len__()):
self.SA_confs.append({
"npoint": self.cfg.RPN.SA_CONFIG.NPOINTS[i],
"radiuss": self.cfg.RPN.SA_CONFIG.RADIUS[i],
"nsamples": self.cfg.RPN.SA_CONFIG.NSAMPLE[i],
"mlps": self.cfg.RPN.SA_CONFIG.MLPS[i],
})
self.FP_confs = []
for i in range(self.cfg.RPN.FP_MLPS.__len__()):
self.FP_confs.append({"mlp": self.cfg.RPN.FP_MLPS[i]})
def build(self, bn_momentum=0.95):
xyzs, features = [self.xyz], [self.feature]
xyzi, featurei = self.xyz, self.feature
for i, SA_conf in enumerate(self.SA_confs):
xyzi, featurei = pointnet_sa_module(
xyz=xyzi,
feature=featurei,
bn_momentum=bn_momentum,
use_xyz=self.use_xyz,
name="sa_{}".format(i),
**SA_conf)
xyzs.append(xyzi)
features.append(fluid.layers.transpose(featurei, perm=[0, 2, 1]))
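# propagate features top-down: each FP step interpolates the coarser level's
# features back onto the next finer level and fuses them with that level's skip features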
for i in range(-1, -(len(self.FP_confs) + 1), -1):
features[i - 1] = pointnet_fp_module(
unknown=xyzs[i - 1],
known=xyzs[i],
unknown_feats=features[i - 1],
known_feats=features[i],
bn_momentum=bn_momentum,
name="fp_{}".format(i + len(self.FP_confs)),
**self.FP_confs[i])
return xyzs[0], features[0]
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import sys
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant
from models.pointnet2_modules import MLP, pointnet_sa_module, conv_bn
from models.loss_utils import sigmoid_focal_loss, get_reg_loss
from utils.proposal_target import get_proposal_target_func
from utils.cyops.kitti_utils import rotate_pc_along_y
__all__ = ['RCNN']
class RCNN(object):
def __init__(self, cfg, num_classes, batch_size, mode='TRAIN', use_xyz=True, input_channels=0):
self.cfg = cfg
self.use_xyz = use_xyz
self.num_classes = num_classes
self.input_channels = input_channels
self.inputs = None
self.training = mode == 'TRAIN'
self.batch_size = batch_size
def create_tmp_var(self, name, dtype, shape):
return fluid.default_main_program().current_block().create_var(
name=name, dtype=dtype, shape=shape
)
def build_model(self, inputs):
self.inputs = inputs
if self.cfg.RCNN.ROI_SAMPLE_JIT:
if self.training:
proposal_target = get_proposal_target_func(self.cfg)
tmp_list = [
self.inputs['seg_mask'],
self.inputs['rpn_features'],
self.inputs['gt_boxes3d'],
self.inputs['rpn_xyz'],
self.inputs['pts_depth'],
self.inputs['roi_boxes3d'],
self.inputs['rpn_intensity'],
]
out_name = ['reg_valid_mask' ,'sampled_pts' ,'roi_boxes3d', 'gt_of_rois', 'pts_feature' ,'cls_label','gt_iou']
reg_valid_mask = self.create_tmp_var(name="reg_valid_mask",dtype='float32',shape=[-1,])
sampled_pts = self.create_tmp_var(name="sampled_pts",dtype='float32',shape=[-1, self.cfg.RCNN.NUM_POINTS, 3])
new_roi_boxes3d = self.create_tmp_var(name="new_roi_boxes3d",dtype='float32',shape=[-1, 7])
gt_of_rois = self.create_tmp_var(name="gt_of_rois", dtype='float32', shape=[-1,7])
pts_feature = self.create_tmp_var(name="pts_feature", dtype='float32',shape=[-1,512,130])
cls_label = self.create_tmp_var(name="cls_label",dtype='int64',shape=[-1])
gt_iou = self.create_tmp_var(name="gt_iou",dtype='float32',shape=[-1])
out_list = [reg_valid_mask, sampled_pts, new_roi_boxes3d, gt_of_rois, pts_feature, cls_label, gt_iou]
out = fluid.layers.py_func(func=proposal_target,x=tmp_list,out=out_list)
self.target_dict = {}
for i,item in enumerate(out):
self.target_dict[out_name[i]] = item
pts = fluid.layers.concat(input=[self.target_dict['sampled_pts'],self.target_dict['pts_feature']], axis=2)
self.debug = pts
self.target_dict['pts_input'] = pts
else:
rpn_xyz, rpn_features = inputs['rpn_xyz'], inputs['rpn_features']
batch_rois = inputs['roi_boxes3d']
rpn_intensity = inputs['rpn_intensity']
rpn_intensity = fluid.layers.unsqueeze(rpn_intensity,axes=[2])
seg_mask = fluid.layers.unsqueeze(inputs['seg_mask'],axes=[2])
if self.cfg.RCNN.USE_INTENSITY:
pts_extra_input_list = [rpn_intensity, seg_mask]
else:
pts_extra_input_list = [seg_mask]
if self.cfg.RCNN.USE_DEPTH:
pts_depth = inputs['pts_depth'] / 70.0 - 0.5
pts_depth = fluid.layers.unsqueeze(pts_depth,axes=[2])
pts_extra_input_list.append(pts_depth)
pts_extra_input = fluid.layers.concat(pts_extra_input_list, axis=2)
pts_feature = fluid.layers.concat([pts_extra_input, rpn_features],axis=2)
pooled_features, pooled_empty_flag = fluid.layers.roi_pool_3d(rpn_xyz,pts_feature,batch_rois,
self.cfg.RCNN.POOL_EXTRA_WIDTH,
sampled_pt_num=self.cfg.RCNN.NUM_POINTS)
# canonical transformation
batch_size = batch_rois.shape[0]
roi_center = batch_rois[:, :, 0:3]
tmp = pooled_features[:, :, :, 0:3] - fluid.layers.unsqueeze(roi_center,axes=[2])
pooled_features = fluid.layers.concat(input=[tmp,pooled_features[:,:,:,3:]],axis=3)
concat_list = []
for i in range(batch_size):
tmp = rotate_pc_along_y(pooled_features[i, :, :, 0:3],
batch_rois[i, :, 6])
concat = fluid.layers.concat([tmp,pooled_features[i,:,:,3:]],axis=-1)
concat = fluid.layers.unsqueeze(concat,axes=[0])
concat_list.append(concat)
pooled_features = fluid.layers.concat(concat_list,axis=0)
pts = fluid.layers.reshape(pooled_features,shape=[-1,pooled_features.shape[2],pooled_features.shape[3]])
else:
pts = inputs['pts_input']
self.target_dict = {}
self.target_dict['pts_input'] = inputs['pts_input']
self.target_dict['roi_boxes3d'] = inputs['roi_boxes3d']
if self.training:
self.target_dict['cls_label'] = inputs['cls_label']
self.target_dict['reg_valid_mask'] = inputs['reg_valid_mask']
self.target_dict['gt_of_rois'] = inputs['gt_boxes3d_ct']
xyz = pts[:, :, 0:3]
feature = fluid.layers.transpose(pts[:, :, 3:], [0, 2, 1]) if pts.shape[-1] > 3 else None
if self.cfg.RCNN.USE_RPN_FEATURES:
self.rcnn_input_channel = 3 + int(self.cfg.RCNN.USE_INTENSITY) + \
int(self.cfg.RCNN.USE_MASK) + int(self.cfg.RCNN.USE_DEPTH)
c_out = self.cfg.RCNN.XYZ_UP_LAYER[-1]
xyz_input = pts[:, :, :self.rcnn_input_channel]
xyz_input = fluid.layers.transpose(xyz_input, [0, 2, 1])
xyz_input = fluid.layers.unsqueeze(xyz_input, axes=[3])
rpn_feature = pts[:, :, self.rcnn_input_channel:]
rpn_feature = fluid.layers.transpose(rpn_feature, [0, 2, 1])
rpn_feature = fluid.layers.unsqueeze(rpn_feature, axes=[3])
xyz_feature = MLP(
xyz_input,
out_channels_list=self.cfg.RCNN.XYZ_UP_LAYER,
bn=self.cfg.RCNN.USE_BN,
name="xyz_up_layer")
merged_feature = fluid.layers.concat([xyz_feature, rpn_feature], axis=1)
merged_feature = MLP(
merged_feature,
out_channels_list=[c_out],
bn=self.cfg.RCNN.USE_BN,
name="xyz_down_layer")
xyzs = [xyz]
features = [fluid.layers.squeeze(merged_feature, axes=[3])]
else:
xyzs = [xyz]
features = [feature]
# forward
xyzi, featurei = xyzs[-1], features[-1]
for k in range(len(self.cfg.RCNN.SA_CONFIG.NPOINTS)):
mlps = self.cfg.RCNN.SA_CONFIG.MLPS[k]
npoint = self.cfg.RCNN.SA_CONFIG.NPOINTS[k] if self.cfg.RCNN.SA_CONFIG.NPOINTS[k] != -1 else None
xyzi, featurei = pointnet_sa_module(
xyz=xyzi,
feature=featurei,
bn=self.cfg.RCNN.USE_BN,
use_xyz=self.use_xyz,
name="sa_{}".format(k),
npoint=npoint,
mlps=[mlps],
radiuss=[self.cfg.RCNN.SA_CONFIG.RADIUS[k]],
nsamples=[self.cfg.RCNN.SA_CONFIG.NSAMPLE[k]]
)
xyzs.append(xyzi)
features.append(featurei)
head_in = features[-1]
head_in = fluid.layers.unsqueeze(head_in, axes=[2])
cls_out = head_in
reg_out = cls_out
for i in range(0, self.cfg.RCNN.CLS_FC.__len__()):
cls_out = conv_bn(cls_out, self.cfg.RCNN.CLS_FC[i], bn=self.cfg.RCNN.USE_BN, name='rcnn_cls_{}'.format(i))
if i == 0 and self.cfg.RCNN.DP_RATIO >= 0:
cls_out = fluid.layers.dropout(cls_out, self.cfg.RCNN.DP_RATIO, dropout_implementation="upscale_in_train")
cls_channel = 1 if self.num_classes == 2 else self.num_classes
cls_out = conv_bn(cls_out, cls_channel, act=None, name="cls_out", bn=self.cfg.RCNN.USE_BN)
self.cls_out = fluid.layers.squeeze(cls_out, axes=[1, 3])
per_loc_bin_num = int(self.cfg.RCNN.LOC_SCOPE / self.cfg.RCNN.LOC_BIN_SIZE) * 2
loc_y_bin_num = int(self.cfg.RCNN.LOC_Y_SCOPE / self.cfg.RCNN.LOC_Y_BIN_SIZE) * 2
reg_channel = per_loc_bin_num * 4 + self.cfg.RCNN.NUM_HEAD_BIN * 2 + 3
reg_channel += (1 if not self.cfg.RCNN.LOC_Y_BY_BIN else loc_y_bin_num * 2)
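# Channel budget for the bin-based box encoding: each of the per_loc_bin_num
# x/z bins contributes a classification and a residual channel (hence * 4),
# heading uses NUM_HEAD_BIN bin + residual pairs (* 2), size regression takes
# 3 channels (h, w, l), and y is either a single direct residual channel or
# loc_y_bin_num bin + residual pairs when LOC_Y_BY_BIN is set.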
for i in range(0, self.cfg.RCNN.REG_FC.__len__()):
reg_out = conv_bn(reg_out, self.cfg.RCNN.REG_FC[i], bn=self.cfg.RCNN.USE_BN, name='rcnn_reg_{}'.format(i))
if i == 0 and self.cfg.RCNN.DP_RATIO >= 0:
reg_out = fluid.layers.dropout(reg_out, self.cfg.RCNN.DP_RATIO, dropout_implementation="upscale_in_train")
reg_out = conv_bn(reg_out, reg_channel, act=None, name="reg_out", bn=self.cfg.RCNN.USE_BN)
self.reg_out = fluid.layers.squeeze(reg_out, axes=[2, 3])
self.outputs = {
'rcnn_cls': self.cls_out,
'rcnn_reg': self.reg_out,
}
if self.training:
self.outputs.update(self.target_dict)
else:
self.outputs['sample_id'] = inputs['sample_id']
self.outputs['pts_input'] = inputs['pts_input']
self.outputs['roi_boxes3d'] = inputs['roi_boxes3d']
self.outputs['roi_scores'] = inputs['roi_scores']
self.outputs['gt_iou'] = inputs['gt_iou']
self.outputs['gt_boxes3d'] = inputs['gt_boxes3d']
if self.cls_out.shape[1] == 1:
raw_scores = fluid.layers.reshape(self.cls_out, shape=[-1])
norm_scores = fluid.layers.sigmoid(raw_scores)
else:
norm_scores = fluid.layers.softmax(self.cls_out, axis=1)
self.outputs['norm_scores'] = norm_scores
def get_outputs(self):
return self.outputs
def get_loss(self):
assert self.inputs is not None, \
"please call build() first"
rcnn_cls_label = self.outputs['cls_label']
reg_valid_mask = self.outputs['reg_valid_mask']
roi_boxes3d = self.outputs['roi_boxes3d']
roi_size = roi_boxes3d[:, 3:6]
gt_boxes3d_ct = self.outputs['gt_of_rois']
pts_input = self.outputs['pts_input']
rcnn_cls = self.cls_out
rcnn_reg = self.reg_out
# RCNN classification loss
assert self.cfg.RCNN.LOSS_CLS in ["SigmoidFocalLoss", "BinaryCrossEntropy"], \
"unsupported RCNN cls loss type {}".format(self.cfg.RCNN.LOSS_CLS)
if self.cfg.RCNN.LOSS_CLS == "SigmoidFocalLoss":
cls_flat = fluid.layers.reshape(self.cls_out, shape=[-1])
cls_label_flat = fluid.layers.reshape(rcnn_cls_label, shape=[-1])
cls_label_flat = fluid.layers.cast(cls_label_flat, dtype=cls_flat.dtype)
cls_target = fluid.layers.cast(cls_label_flat > 0, dtype=cls_flat.dtype)
cls_label_flat.stop_gradient = True
pos = fluid.layers.cast(cls_label_flat > 0, dtype=cls_flat.dtype)
pos.stop_gradient = True
pos_normalizer = fluid.layers.reduce_sum(pos)
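# normalize per-point weights by the positive count (clipped to >= 1);
# points labeled < 0 are ignored by receiving zero weight below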
cls_weights = fluid.layers.cast(cls_label_flat >= 0, dtype=cls_flat.dtype)
cls_weights = cls_weights / fluid.layers.clip(pos_normalizer, min=1.0, max=1e10)
cls_weights.stop_gradient = True
rcnn_loss_cls = sigmoid_focal_loss(cls_flat, cls_target, cls_weights)
rcnn_loss_cls = fluid.layers.reduce_sum(rcnn_loss_cls)
else: # BinaryCrossEntropy
cls_label = fluid.layers.reshape(rcnn_cls_label, shape=self.cls_out.shape)
cls_valid_mask = fluid.layers.cast(cls_label >= 0, dtype=self.cls_out.dtype)
cls_label = fluid.layers.cast(cls_label, dtype=self.cls_out.dtype)
cls_label.stop_gradient = True
rcnn_loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(self.cls_out, cls_label)
cls_mask_normalizer = fluid.layers.reduce_sum(cls_valid_mask)
rcnn_loss_cls = fluid.layers.reduce_sum(rcnn_loss_cls * cls_valid_mask) \
/ fluid.layers.clip(cls_mask_normalizer, min=1.0, max=1e10)
# RCNN regression loss
reg_out = self.reg_out
fg_mask = fluid.layers.cast(reg_valid_mask > 0, dtype=reg_out.dtype)
fg_mask.stop_gradient = True
gt_boxes3d_ct = fluid.layers.reshape(gt_boxes3d_ct, [-1, 7])
all_anchor_size = roi_size
anchor_size = all_anchor_size[fg_mask] if self.cfg.RCNN.SIZE_RES_ON_ROI else self.cfg.CLS_MEAN_SIZE[0]
loc_loss, angle_loss, size_loss, loss_dict = get_reg_loss(
reg_out * fg_mask,
gt_boxes3d_ct,
fg_mask,
point_num=float(self.batch_size*64),
loc_scope=self.cfg.RCNN.LOC_SCOPE,
loc_bin_size=self.cfg.RCNN.LOC_BIN_SIZE,
num_head_bin=self.cfg.RCNN.NUM_HEAD_BIN,
anchor_size=anchor_size,
get_xz_fine=True,
get_y_by_bin=self.cfg.RCNN.LOC_Y_BY_BIN,
loc_y_scope=self.cfg.RCNN.LOC_Y_SCOPE,
loc_y_bin_size=self.cfg.RCNN.LOC_Y_BIN_SIZE,
get_ry_fine=True
)
rcnn_loss_reg = loc_loss + angle_loss + size_loss * 3
rcnn_loss = rcnn_loss_cls + rcnn_loss_reg
return rcnn_loss, rcnn_loss_cls, rcnn_loss_reg
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal, Constant
from utils.proposal_utils import get_proposal_func
from models.pointnet2_msg import PointNet2MSG
from models.pointnet2_modules import conv_bn
from models.loss_utils import sigmoid_focal_loss, get_reg_loss
__all__ = ["RPN"]
class RPN(object):
def __init__(self, cfg, batch_size, use_xyz=True, mode='TRAIN', prog=None):
self.cfg = cfg
self.batch_size = batch_size
self.use_xyz = use_xyz
self.mode = mode
self.is_train = mode == 'TRAIN'
self.inputs = None
self.prog = fluid.default_main_program() if prog is None else prog
def build(self, inputs):
assert self.cfg.RPN.BACKBONE == 'pointnet2_msg', \
"RPN backbone only support pointnet2_msg"
self.inputs = inputs
self.outputs = {}
xyz = inputs["pts_input"]
assert not self.cfg.RPN.USE_INTENSITY, \
"RPN.USE_INTENSITY not support now"
feature = None
msg = PointNet2MSG(self.cfg, xyz, feature, self.use_xyz)
backbone_xyz, backbone_feature = msg.build()
self.outputs['backbone_xyz'] = backbone_xyz
self.outputs['backbone_feature'] = backbone_feature
backbone_feature = fluid.layers.transpose(backbone_feature, perm=[0, 2, 1])
cls_out = fluid.layers.unsqueeze(backbone_feature, axes=[-1])
reg_out = cls_out
# classification branch
for i in range(self.cfg.RPN.CLS_FC.__len__()):
cls_out = conv_bn(cls_out, self.cfg.RPN.CLS_FC[i], bn=self.cfg.RPN.USE_BN, name='rpn_cls_{}'.format(i))
if i == 0 and self.cfg.RPN.DP_RATIO > 0:
cls_out = fluid.layers.dropout(cls_out, self.cfg.RPN.DP_RATIO, dropout_implementation="upscale_in_train")
cls_out = fluid.layers.conv2d(cls_out,
num_filters=1,
filter_size=1,
stride=1,
padding=0,
dilation=1,
param_attr=ParamAttr(name='rpn_cls_out_conv_weight'),
bias_attr=ParamAttr(name='rpn_cls_out_conv_bias',
initializer=Constant(-np.log(99))))
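# the -np.log(99) bias above makes the initial sigmoid score
# 1 / (1 + 99) = 0.01, the usual focal-loss prior for rare foreground points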
cls_out = fluid.layers.squeeze(cls_out, axes=[1, 3])
self.outputs['rpn_cls'] = cls_out
# regression branch
per_loc_bin_num = int(self.cfg.RPN.LOC_SCOPE / self.cfg.RPN.LOC_BIN_SIZE) * 2
if self.cfg.RPN.LOC_XZ_FINE:
reg_channel = per_loc_bin_num * 4 + self.cfg.RPN.NUM_HEAD_BIN * 2 + 3
else:
reg_channel = per_loc_bin_num * 2 + self.cfg.RPN.NUM_HEAD_BIN * 2 + 3
reg_channel += 1 # reg y
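# with LOC_XZ_FINE each x/z bin carries a classification and a residual
# channel (* 4), otherwise bin classification only (* 2); heading adds
# NUM_HEAD_BIN bin + residual pairs, 3 channels cover size, and the final
# extra channel regresses y directly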
for i in range(self.cfg.RPN.REG_FC.__len__()):
reg_out = conv_bn(reg_out, self.cfg.RPN.REG_FC[i], bn=self.cfg.RPN.USE_BN, name='rpn_reg_{}'.format(i))
if i == 0 and self.cfg.RPN.DP_RATIO > 0:
reg_out = fluid.layers.dropout(reg_out, self.cfg.RPN.DP_RATIO, dropout_implementation="upscale_in_train")
reg_out = fluid.layers.conv2d(reg_out,
num_filters=reg_channel,
filter_size=1,
stride=1,
padding=0,
dilation=1,
param_attr=ParamAttr(name='rpn_reg_out_conv_weight',
initializer=Normal(0., 0.001),),
bias_attr=ParamAttr(name='rpn_reg_out_conv_bias'))
reg_out = fluid.layers.squeeze(reg_out, axes=[3])
reg_out = fluid.layers.transpose(reg_out, [0, 2, 1])
self.outputs['rpn_reg'] = reg_out
if self.mode != 'TRAIN' or self.cfg.RCNN.ENABLED:
rpn_scores_row = cls_out
rpn_scores_norm = fluid.layers.sigmoid(rpn_scores_row)
seg_mask = fluid.layers.cast(rpn_scores_norm > self.cfg.RPN.SCORE_THRESH, dtype='float32')
pts_depth = fluid.layers.sqrt(fluid.layers.reduce_sum(backbone_xyz * backbone_xyz, dim=2))
proposal_func = get_proposal_func(self.cfg, self.mode)
proposal_input = fluid.layers.concat([fluid.layers.unsqueeze(rpn_scores_row, axes=[-1]),
backbone_xyz, reg_out], axis=-1)
proposal = self.prog.current_block().create_var(name='proposal',
shape=[-1, proposal_input.shape[1], 8],
dtype='float32')
fluid.layers.py_func(proposal_func, proposal_input, proposal)
rois, roi_scores_row = proposal[:, :, :7], proposal[:, :, -1]
self.outputs['rois'] = rois
self.outputs['roi_scores_row'] = roi_scores_row
self.outputs['seg_mask'] = seg_mask
self.outputs['pts_depth'] = pts_depth
def get_outputs(self):
return self.outputs
def get_loss(self):
assert self.inputs is not None, \
"please call build() first"
rpn_cls_label = self.inputs['rpn_cls_label']
rpn_reg_label = self.inputs['rpn_reg_label']
rpn_cls = self.outputs['rpn_cls']
rpn_reg = self.outputs['rpn_reg']
# RPN classification loss
assert self.cfg.RPN.LOSS_CLS == "SigmoidFocalLoss", \
"unsupported RPN cls loss type {}".format(self.cfg.RPN.LOSS_CLS)
cls_flat = fluid.layers.reshape(rpn_cls, shape=[-1])
cls_label_flat = fluid.layers.reshape(rpn_cls_label, shape=[-1])
cls_label_pos = fluid.layers.cast(cls_label_flat > 0, dtype=cls_flat.dtype)
pos_normalizer = fluid.layers.reduce_sum(cls_label_pos)
cls_weights = fluid.layers.cast(cls_label_flat >= 0, dtype=cls_flat.dtype)
cls_weights = cls_weights / fluid.layers.clip(pos_normalizer, min=1.0, max=1e10)
cls_weights.stop_gradient = True
cls_label_flat = fluid.layers.cast(cls_label_flat, dtype=cls_flat.dtype)
cls_label_flat.stop_gradient = True
rpn_loss_cls = sigmoid_focal_loss(cls_flat, cls_label_pos, cls_weights)
rpn_loss_cls = fluid.layers.reduce_sum(rpn_loss_cls)
# RPN regression loss
rpn_reg = fluid.layers.reshape(rpn_reg, [-1, rpn_reg.shape[-1]])
reg_label = fluid.layers.reshape(rpn_reg_label, [-1, rpn_reg_label.shape[-1]])
fg_mask = fluid.layers.cast(cls_label_flat > 0, dtype=rpn_reg.dtype)
fg_mask.stop_gradient = True
loc_loss, angle_loss, size_loss, loss_dict = get_reg_loss(
rpn_reg * fg_mask, reg_label, fg_mask,
float(self.batch_size * self.cfg.RPN.NUM_POINTS),
loc_scope=self.cfg.RPN.LOC_SCOPE,
loc_bin_size=self.cfg.RPN.LOC_BIN_SIZE,
num_head_bin=self.cfg.RPN.NUM_HEAD_BIN,
anchor_size=self.cfg.CLS_MEAN_SIZE[0],
get_xz_fine=self.cfg.RPN.LOC_XZ_FINE,
get_y_by_bin=False,
get_ry_fine=False)
rpn_loss_reg = loc_loss + angle_loss + size_loss * 3
self.rpn_loss = rpn_loss_cls * self.cfg.RPN.LOSS_WEIGHT[0] + rpn_loss_reg * self.cfg.RPN.LOSS_WEIGHT[1]
return self.rpn_loss, rpn_loss_cls, rpn_loss_reg
Cython
opencv-python
shapely
scikit-image
Numba
fire
"""
Generate augmented scenes
This code is based on https://github.com/sshaoshuai/PointRCNN/blob/master/tools/generate_aug_scene.py
"""
import os
import numpy as np
import pickle
import pts_utils
import utils.cyops.kitti_utils as kitti_utils
from utils.box_utils import boxes_iou3d
from utils import calibration as calib
from data.kitti_dataset import KittiDataset
import argparse
np.random.seed(1024)
parser = argparse.ArgumentParser()
parser.add_argument('--mode', type=str, default='generator')
parser.add_argument('--class_name', type=str, default='Car')
parser.add_argument('--data_dir', type=str, default='./data')
parser.add_argument('--save_dir', type=str, default='./data/KITTI/aug_scene/training')
parser.add_argument('--split', type=str, default='train')
parser.add_argument('--gt_database_dir', type=str, default='./data/gt_database/train_gt_database_3level_Car.pkl')
parser.add_argument('--include_similar', action='store_true', default=False)
parser.add_argument('--aug_times', type=int, default=4)
args = parser.parse_args()
PC_REDUCE_BY_RANGE = True
if args.class_name == 'Car':
PC_AREA_SCOPE = np.array([[-40, 40], [-1, 3], [0, 70.4]]) # x, y, z scope in rect camera coords
else:
PC_AREA_SCOPE = np.array([[-30, 30], [-1, 3], [0, 50]])
def log_print(info, fp=None):
print(info)
if fp is not None:
# print(info, file=fp)
fp.write(info+"\n")
def save_kitti_format(calib, bbox3d, obj_list, img_shape, save_fp):
corners3d = kitti_utils.boxes3d_to_corners3d(bbox3d)
img_boxes, _ = calib.corners3d_to_img_boxes(corners3d)
img_boxes[:, 0] = np.clip(img_boxes[:, 0], 0, img_shape[1] - 1)
img_boxes[:, 1] = np.clip(img_boxes[:, 1], 0, img_shape[0] - 1)
img_boxes[:, 2] = np.clip(img_boxes[:, 2], 0, img_shape[1] - 1)
img_boxes[:, 3] = np.clip(img_boxes[:, 3], 0, img_shape[0] - 1)
# Discard boxes that are larger than 80% of the image width OR height
img_boxes_w = img_boxes[:, 2] - img_boxes[:, 0]
img_boxes_h = img_boxes[:, 3] - img_boxes[:, 1]
box_valid_mask = np.logical_and(img_boxes_w < img_shape[1] * 0.8, img_boxes_h < img_shape[0] * 0.8)
for k in range(bbox3d.shape[0]):
if box_valid_mask[k] == 0:
continue
x, z, ry = bbox3d[k, 0], bbox3d[k, 2], bbox3d[k, 6]
beta = np.arctan2(z, x)
alpha = -np.sign(beta) * np.pi / 2 + beta + ry
save_fp.write('%s %.2f %d %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f\n' %
(args.class_name, obj_list[k].trucation, int(obj_list[k].occlusion), alpha, img_boxes[k, 0], img_boxes[k, 1],
img_boxes[k, 2], img_boxes[k, 3],
bbox3d[k, 3], bbox3d[k, 4], bbox3d[k, 5], bbox3d[k, 0], bbox3d[k, 1], bbox3d[k, 2],
bbox3d[k, 6]))
class AugSceneGenerator(KittiDataset):
def __init__(self, root_dir, gt_database=None, split='train', classes=args.class_name):
super(AugSceneGenerator, self).__init__(root_dir, split=split)
self.gt_database = None
if classes == 'Car':
self.classes = ('Background', 'Car')
elif classes == 'People':
self.classes = ('Background', 'Pedestrian', 'Cyclist')
elif classes == 'Pedestrian':
self.classes = ('Background', 'Pedestrian')
elif classes == 'Cyclist':
self.classes = ('Background', 'Cyclist')
else:
assert False, "Invalid classes: %s" % classes
self.gt_database = gt_database
def __len__(self):
raise NotImplementedError
def __getitem__(self, item):
raise NotImplementedError
def filtrate_dc_objects(self, obj_list):
valid_obj_list = []
for obj in obj_list:
if obj.cls_type in ['DontCare']:
continue
valid_obj_list.append(obj)
return valid_obj_list
def filtrate_objects(self, obj_list):
valid_obj_list = []
type_whitelist = self.classes
if args.include_similar:
type_whitelist = list(self.classes)
if 'Car' in self.classes:
type_whitelist.append('Van')
if 'Pedestrian' in self.classes or 'Cyclist' in self.classes:
type_whitelist.append('Person_sitting')
for obj in obj_list:
if obj.cls_type in type_whitelist:
valid_obj_list.append(obj)
return valid_obj_list
@staticmethod
def get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape):
"""
Valid point should be in the image (and in the PC_AREA_SCOPE)
:param pts_rect:
:param pts_img:
:param pts_rect_depth:
:param img_shape:
:return:
"""
val_flag_1 = np.logical_and(pts_img[:, 0] >= 0, pts_img[:, 0] < img_shape[1])
val_flag_2 = np.logical_and(pts_img[:, 1] >= 0, pts_img[:, 1] < img_shape[0])
val_flag_merge = np.logical_and(val_flag_1, val_flag_2)
pts_valid_flag = np.logical_and(val_flag_merge, pts_rect_depth >= 0)
if PC_REDUCE_BY_RANGE:
x_range, y_range, z_range = PC_AREA_SCOPE
pts_x, pts_y, pts_z = pts_rect[:, 0], pts_rect[:, 1], pts_rect[:, 2]
range_flag = (pts_x >= x_range[0]) & (pts_x <= x_range[1]) \
& (pts_y >= y_range[0]) & (pts_y <= y_range[1]) \
& (pts_z >= z_range[0]) & (pts_z <= z_range[1])
pts_valid_flag = pts_valid_flag & range_flag
return pts_valid_flag
@staticmethod
def check_pc_range(xyz):
"""
:param xyz: [x, y, z]
:return:
"""
x_range, y_range, z_range = PC_AREA_SCOPE
if (x_range[0] <= xyz[0] <= x_range[1]) and (y_range[0] <= xyz[1] <= y_range[1]) and \
(z_range[0] <= xyz[2] <= z_range[1]):
return True
return False
def aug_one_scene(self, sample_id, pts_rect, pts_intensity, all_gt_boxes3d):
"""
:param sample_id: sample index
:param pts_rect: (N, 3)
:param pts_intensity: (N,)
:param all_gt_boxes3d: (M, 7)
:return:
"""
assert self.gt_database is not None
extra_gt_num = np.random.randint(10, 15)
try_times = 50
cnt = 0
cur_gt_boxes3d = all_gt_boxes3d.copy()
cur_gt_boxes3d[:, 4] += 0.5
cur_gt_boxes3d[:, 5] += 0.5 # enlarge new added box to avoid too nearby boxes
extra_gt_obj_list = []
extra_gt_boxes3d_list = []
new_pts_list, new_pts_intensity_list = [], []
src_pts_flag = np.ones(pts_rect.shape[0], dtype=np.int32)
road_plane = self.get_road_plane(sample_id)
a, b, c, d = road_plane
while try_times > 0:
try_times -= 1
rand_idx = np.random.randint(0, self.gt_database.__len__() - 1)
new_gt_dict = self.gt_database[rand_idx]
new_gt_box3d = new_gt_dict['gt_box3d'].copy()
new_gt_points = new_gt_dict['points'].copy()
new_gt_intensity = new_gt_dict['intensity'].copy()
new_gt_obj = new_gt_dict['obj']
center = new_gt_box3d[0:3]
if PC_REDUCE_BY_RANGE and (self.check_pc_range(center) is False):
continue
if cnt > extra_gt_num:
break
if new_gt_points.__len__() < 5: # too few points
continue
# put it on the road plane
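# the road plane satisfies a*x + b*y + c*z + d = 0; solve for y at the box center (x, z)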
cur_height = (-d - a * center[0] - c * center[2]) / b
move_height = new_gt_box3d[1] - cur_height
new_gt_box3d[1] -= move_height
new_gt_points[:, 1] -= move_height
cnt += 1
iou3d = boxes_iou3d(new_gt_box3d.reshape(1, 7), cur_gt_boxes3d)
valid_flag = iou3d.max() < 1e-8
if not valid_flag:
continue
enlarged_box3d = new_gt_box3d.copy()
enlarged_box3d[3] += 2 # remove the points above and below the object
boxes_pts_mask_list = pts_utils.pts_in_boxes3d(pts_rect, enlarged_box3d.reshape(1, 7))
pt_mask_flag = (boxes_pts_mask_list[0] == 1)
src_pts_flag[pt_mask_flag] = 0 # remove the original points which are inside the new box
new_pts_list.append(new_gt_points)
new_pts_intensity_list.append(new_gt_intensity)
enlarged_box3d = new_gt_box3d.copy()
enlarged_box3d[4] += 0.5
enlarged_box3d[5] += 0.5 # enlarge new added box to avoid too nearby boxes
cur_gt_boxes3d = np.concatenate((cur_gt_boxes3d, enlarged_box3d.reshape(1, 7)), axis=0)
extra_gt_boxes3d_list.append(new_gt_box3d.reshape(1, 7))
extra_gt_obj_list.append(new_gt_obj)
if new_pts_list.__len__() == 0:
return False, pts_rect, pts_intensity, None, None
extra_gt_boxes3d = np.concatenate(extra_gt_boxes3d_list, axis=0)
# remove original points and add new points
pts_rect = pts_rect[src_pts_flag == 1]
pts_intensity = pts_intensity[src_pts_flag == 1]
new_pts_rect = np.concatenate(new_pts_list, axis=0)
new_pts_intensity = np.concatenate(new_pts_intensity_list, axis=0)
pts_rect = np.concatenate((pts_rect, new_pts_rect), axis=0)
pts_intensity = np.concatenate((pts_intensity, new_pts_intensity), axis=0)
return True, pts_rect, pts_intensity, extra_gt_boxes3d, extra_gt_obj_list
def aug_one_epoch_scene(self, base_id, data_save_dir, label_save_dir, split_list, log_fp=None):
for idx, sample_id in enumerate(self.image_idx_list):
sample_id = int(sample_id)
print('process gt sample (%s, id=%06d)' % (args.split, sample_id))
pts_lidar = self.get_lidar(sample_id)
calib = self.get_calib(sample_id)
pts_rect = calib.lidar_to_rect(pts_lidar[:, 0:3])
pts_img, pts_rect_depth = calib.rect_to_img(pts_rect)
img_shape = self.get_image_shape(sample_id)
pts_valid_flag = self.get_valid_flag(pts_rect, pts_img, pts_rect_depth, img_shape)
pts_rect = pts_rect[pts_valid_flag][:, 0:3]
pts_intensity = pts_lidar[pts_valid_flag][:, 3]
# all labels for checking overlapping
all_obj_list = self.filtrate_dc_objects(self.get_label(sample_id))
all_gt_boxes3d = np.zeros((all_obj_list.__len__(), 7), dtype=np.float32)
for k, obj in enumerate(all_obj_list):
all_gt_boxes3d[k, 0:3], all_gt_boxes3d[k, 3], all_gt_boxes3d[k, 4], all_gt_boxes3d[k, 5], \
all_gt_boxes3d[k, 6] = obj.pos, obj.h, obj.w, obj.l, obj.ry
# gt_boxes3d of current label
obj_list = self.filtrate_objects(self.get_label(sample_id))
if args.class_name != 'Car' and obj_list.__len__() == 0:
continue
# augment one scene
aug_flag, pts_rect, pts_intensity, extra_gt_boxes3d, extra_gt_obj_list = \
self.aug_one_scene(sample_id, pts_rect, pts_intensity, all_gt_boxes3d)
# save augment result to file
pts_info = np.concatenate((pts_rect, pts_intensity.reshape(-1, 1)), axis=1)
bin_file = os.path.join(data_save_dir, '%06d.bin' % (base_id + sample_id))
pts_info.astype(np.float32).tofile(bin_file)
# save filtered original gt_boxes3d
label_save_file = os.path.join(label_save_dir, '%06d.txt' % (base_id + sample_id))
with open(label_save_file, 'w') as f:
for obj in obj_list:
f.write(obj.to_kitti_format() + '\n')
if aug_flag:
# augment successfully
save_kitti_format(calib, extra_gt_boxes3d, extra_gt_obj_list, img_shape=img_shape, save_fp=f)
else:
extra_gt_boxes3d = np.zeros((0, 7), dtype=np.float32)
log_print('Save to file (new_obj: %s): %s' % (extra_gt_boxes3d.__len__(), label_save_file), fp=log_fp)
split_list.append('%06d' % (base_id + sample_id))
def generate_aug_scene(self, aug_times, log_fp=None):
data_save_dir = os.path.join(args.save_dir, 'rectified_data')
label_save_dir = os.path.join(args.save_dir, 'aug_label')
if not os.path.isdir(data_save_dir):
os.makedirs(data_save_dir)
if not os.path.isdir(label_save_dir):
os.makedirs(label_save_dir)
split_file = os.path.join(args.save_dir, '%s_aug.txt' % args.split)
split_list = self.image_idx_list[:]
for epoch in range(aug_times):
base_id = (epoch + 1) * 10000
self.aug_one_epoch_scene(base_id, data_save_dir, label_save_dir, split_list, log_fp=log_fp)
with open(split_file, 'w') as f:
for idx, sample_id in enumerate(split_list):
f.write(str(sample_id) + '\n')
log_print('Save split file to %s' % split_file, fp=log_fp)
target_dir = os.path.join(args.data_dir, 'KITTI/ImageSets/')
os.system('cp %s %s' % (split_file, target_dir))
log_print('Copy split file from %s to %s' % (split_file, target_dir), fp=log_fp)
if __name__ == '__main__':
if not os.path.isdir(args.save_dir):
os.makedirs(args.save_dir)
info_file = os.path.join(args.save_dir, 'log_info.txt')
if args.mode == 'generator':
log_fp = open(info_file, 'w')
gt_database = pickle.load(open(args.gt_database_dir, 'rb'))
log_print('Loading gt_database(%d) from %s' % (gt_database.__len__(), args.gt_database_dir), fp=log_fp)
dataset = AugSceneGenerator(root_dir=args.data_dir, gt_database=gt_database, split=args.split)
dataset.generate_aug_scene(aug_times=args.aug_times, log_fp=log_fp)
log_fp.close()
else:
pass
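# Example invocation (hypothetical; flags mirror the argparse defaults above,
# and the gt database must have been generated first):
#   python generate_aug_scene.py --class_name=Car --split=train --aug_times=4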
"""
Generate GT database
This code is based on https://github.com/sshaoshuai/PointRCNN/blob/master/tools/generate_gt_database.py
"""
import os
import numpy as np
import pickle
from data.kitti_dataset import KittiDataset
import pts_utils
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default='./data')
parser.add_argument('--save_dir', type=str, default='./data/gt_database')
parser.add_argument('--class_name', type=str, default='Car')
parser.add_argument('--split', type=str, default='train')
args = parser.parse_args()
class GTDatabaseGenerator(KittiDataset):
def __init__(self, root_dir, split='train', classes=args.class_name):
super(GTDatabaseGenerator, self).__init__(root_dir, split=split)
self.gt_database = None
if classes == 'Car':
self.classes = ('Background', 'Car')
elif classes == 'People':
self.classes = ('Background', 'Pedestrian', 'Cyclist')
elif classes == 'Pedestrian':
self.classes = ('Background', 'Pedestrian')
elif classes == 'Cyclist':
self.classes = ('Background', 'Cyclist')
else:
assert False, "Invalid classes: %s" % classes
def __len__(self):
raise NotImplementedError
def __getitem__(self, item):
raise NotImplementedError
def filtrate_objects(self, obj_list):
valid_obj_list = []
for obj in obj_list:
if obj.cls_type not in self.classes:
continue
if obj.level_str not in ['Easy', 'Moderate', 'Hard']:
continue
valid_obj_list.append(obj)
return valid_obj_list
def generate_gt_database(self):
gt_database = []
for idx, sample_id in enumerate(self.image_idx_list):
sample_id = int(sample_id)
print('process gt sample (id=%06d)' % sample_id)
pts_lidar = self.get_lidar(sample_id)
calib = self.get_calib(sample_id)
pts_rect = calib.lidar_to_rect(pts_lidar[:, 0:3])
pts_intensity = pts_lidar[:, 3]
obj_list = self.filtrate_objects(self.get_label(sample_id))
gt_boxes3d = np.zeros((obj_list.__len__(), 7), dtype=np.float32)
for k, obj in enumerate(obj_list):
gt_boxes3d[k, 0:3], gt_boxes3d[k, 3], gt_boxes3d[k, 4], gt_boxes3d[k, 5], gt_boxes3d[k, 6] \
= obj.pos, obj.h, obj.w, obj.l, obj.ry
if gt_boxes3d.__len__() == 0:
print('No gt object')
continue
boxes_pts_mask_list = pts_utils.pts_in_boxes3d(pts_rect, gt_boxes3d)
for k in range(boxes_pts_mask_list.shape[0]):
pt_mask_flag = (boxes_pts_mask_list[k] == 1)
cur_pts = pts_rect[pt_mask_flag].astype(np.float32)
cur_pts_intensity = pts_intensity[pt_mask_flag].astype(np.float32)
sample_dict = {'sample_id': sample_id,
'cls_type': obj_list[k].cls_type,
'gt_box3d': gt_boxes3d[k],
'points': cur_pts,
'intensity': cur_pts_intensity,
'obj': obj_list[k]}
gt_database.append(sample_dict)
save_file_name = os.path.join(args.save_dir, '%s_gt_database_3level_%s.pkl' % (args.split, self.classes[-1]))
with open(save_file_name, 'wb') as f:
pickle.dump(gt_database, f)
self.gt_database = gt_database
print('Saved refined training sample info file to %s' % save_file_name)
if __name__ == '__main__':
dataset = GTDatabaseGenerator(root_dir=args.data_dir, split=args.split)
if not os.path.isdir(args.save_dir):
os.makedirs(args.save_dir)
dataset.generate_gt_database()
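# Example invocation (hypothetical; assumes data/KITTI/object has been prepared):
#   python generate_gt_database.py --class_name=Car --split=train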
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse
def parse_args():
parser = argparse.ArgumentParser(
"KITTI mAP evaluation script")
parser.add_argument(
'--result_dir',
type=str,
default='./result_dir',
help='detection result directory to evaluate')
parser.add_argument(
'--data_dir',
type=str,
default='./data',
help='KITTI dataset root directory')
parser.add_argument(
'--split',
type=str,
default='val',
help='evaluation split, default val')
parser.add_argument(
'--class_name',
type=str,
default='Car',
help='evaluation class name, default Car')
args = parser.parse_args()
return args
def kitti_eval():
if sys.version_info < (3, 6):
print("KITTI mAP evaluation can only run with python3.6+")
sys.exit(1)
args = parse_args()
label_dir = os.path.join(args.data_dir, 'KITTI/object/training', 'label_2')
split_file = os.path.join(args.data_dir, 'KITTI/ImageSets',
'{}.txt'.format(args.split))
final_output_dir = os.path.join(args.result_dir, 'final_result', 'data')
name_to_class = {'Car': 0, 'Pedestrian': 1, 'Cyclist': 2}
from tools.kitti_object_eval_python.evaluate import evaluate as kitti_evaluate
ap_result_str, ap_dict = kitti_evaluate(
label_dir, final_output_dir, label_split_file=split_file,
current_class=name_to_class[args.class_name])
print("KITTI evaluate: ", ap_result_str, ap_dict)
if __name__ == "__main__":
kitti_eval()
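# Example invocation from the repository root (hypothetical script path; flags
# mirror the argparse defaults above, and the eval tool needs Python 3.6+):
#   python tools/kitti_eval.py --result_dir=./result_dir --split=val --class_name=Car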
MIT License
Copyright (c) 2018
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# kitti-object-eval-python
**NOTE**: This is borrowed from [traveller59/kitti-object-eval-python](https://github.com/traveller59/kitti-object-eval-python)
Fast KITTI object detection evaluation in Python (finishes in under 10 seconds), with support for 2D/BEV/3D/AOS and COCO-style AP. If you use the command line interface, numba needs some time to compile JIT functions.
## Dependencies
Only Python 3.6+ is supported; `numpy`, `skimage`, `numba`, and `fire` are required. If you have Anaconda, just install `cudatoolkit` in Anaconda. Otherwise, please refer to this [page](https://github.com/numba/numba#custom-python-environments) to set up LLVM and CUDA for numba.
* Install by conda:
```
conda install -c numba cudatoolkit=x.x (8.0, 9.0, 9.1, depending on your environment)
```
## Usage
* commandline interface:
```
python evaluate.py evaluate --label_path=/path/to/your_gt_label_folder --result_path=/path/to/your_result_folder --label_split_file=/path/to/val.txt --current_class=0 --coco=False
```
* python interface:
```Python
import kitti_common as kitti
from eval import get_official_eval_result, get_coco_eval_result
def _read_imageset_file(path):
with open(path, 'r') as f:
lines = f.readlines()
return [int(line) for line in lines]
det_path = "/path/to/your_result_folder"
dt_annos = kitti.get_label_annos(det_path)
gt_path = "/path/to/your_gt_label_folder"
gt_split_file = "/path/to/val.txt" # from https://xiaozhichen.github.io/files/mv3d/imagesets.tar.gz
val_image_ids = _read_imageset_file(gt_split_file)
gt_annos = kitti.get_label_annos(gt_path, val_image_ids)
print(get_official_eval_result(gt_annos, dt_annos, 0)) # takes ~6s on my machine
print(get_coco_eval_result(gt_annos, dt_annos, 0)) # takes ~18s on my machine
```
import time
import fire
import tools.kitti_object_eval_python.kitti_common as kitti
from tools.kitti_object_eval_python.eval import get_official_eval_result, get_coco_eval_result
def _read_imageset_file(path):
with open(path, 'r') as f:
lines = f.readlines()
return [int(line) for line in lines]
def evaluate(label_path,
result_path,
label_split_file,
current_class=0,
coco=False,
score_thresh=-1):
dt_annos = kitti.get_label_annos(result_path)
if score_thresh > 0:
dt_annos = kitti.filter_annos_low_score(dt_annos, score_thresh)
val_image_ids = _read_imageset_file(label_split_file)
gt_annos = kitti.get_label_annos(label_path, val_image_ids)
if coco:
return get_coco_eval_result(gt_annos, dt_annos, current_class)
else:
return get_official_eval_result(gt_annos, dt_annos, current_class)
if __name__ == '__main__':
fire.Fire()
#####################
# Based on https://github.com/hongzhenwang/RRPN-revise
# Licensed under The MIT License
# Author: yanyan, scrin@foxmail.com
#####################
import math
import numba
import numpy as np
from numba import cuda
@numba.jit(nopython=True)
def div_up(m, n):
return m // n + (m % n > 0)
@cuda.jit('(float32[:], float32[:], float32[:])', device=True, inline=True)
def trangle_area(a, b, c):
return ((a[0] - c[0]) * (b[1] - c[1]) - (a[1] - c[1]) *
(b[0] - c[0])) / 2.0
@cuda.jit('(float32[:], int32)', device=True, inline=True)
def area(int_pts, num_of_inter):
area_val = 0.0
for i in range(num_of_inter - 2):
area_val += abs(
trangle_area(int_pts[:2], int_pts[2 * i + 2:2 * i + 4],
int_pts[2 * i + 4:2 * i + 6]))
return area_val
@cuda.jit('(float32[:], int32)', device=True, inline=True)
def sort_vertex_in_convex_polygon(int_pts, num_of_inter):
if num_of_inter > 0:
center = cuda.local.array((2, ), dtype=numba.float32)
center[:] = 0.0
for i in range(num_of_inter):
center[0] += int_pts[2 * i]
center[1] += int_pts[2 * i + 1]
center[0] /= num_of_inter
center[1] /= num_of_inter
v = cuda.local.array((2, ), dtype=numba.float32)
vs = cuda.local.array((16, ), dtype=numba.float32)
for i in range(num_of_inter):
v[0] = int_pts[2 * i] - center[0]
v[1] = int_pts[2 * i + 1] - center[1]
d = math.sqrt(v[0] * v[0] + v[1] * v[1])
v[0] = v[0] / d
v[1] = v[1] / d
if v[1] < 0:
v[0] = -2 - v[0]
vs[i] = v[0]
j = 0
temp = 0
for i in range(1, num_of_inter):
if vs[i - 1] > vs[i]:
temp = vs[i]
tx = int_pts[2 * i]
ty = int_pts[2 * i + 1]
j = i
while j > 0 and vs[j - 1] > temp:
vs[j] = vs[j - 1]
int_pts[j * 2] = int_pts[j * 2 - 2]
int_pts[j * 2 + 1] = int_pts[j * 2 - 1]
j -= 1
vs[j] = temp
int_pts[j * 2] = tx
int_pts[j * 2 + 1] = ty
@cuda.jit(
'(float32[:], float32[:], int32, int32, float32[:])',
device=True,
inline=True)
def line_segment_intersection(pts1, pts2, i, j, temp_pts):
A = cuda.local.array((2, ), dtype=numba.float32)
B = cuda.local.array((2, ), dtype=numba.float32)
C = cuda.local.array((2, ), dtype=numba.float32)
D = cuda.local.array((2, ), dtype=numba.float32)
A[0] = pts1[2 * i]
A[1] = pts1[2 * i + 1]
B[0] = pts1[2 * ((i + 1) % 4)]
B[1] = pts1[2 * ((i + 1) % 4) + 1]
C[0] = pts2[2 * j]
C[1] = pts2[2 * j + 1]
D[0] = pts2[2 * ((j + 1) % 4)]
D[1] = pts2[2 * ((j + 1) % 4) + 1]
BA0 = B[0] - A[0]
BA1 = B[1] - A[1]
DA0 = D[0] - A[0]
CA0 = C[0] - A[0]
DA1 = D[1] - A[1]
CA1 = C[1] - A[1]
acd = DA1 * CA0 > CA1 * DA0
bcd = (D[1] - B[1]) * (C[0] - B[0]) > (C[1] - B[1]) * (D[0] - B[0])
if acd != bcd:
abc = CA1 * BA0 > BA1 * CA0
abd = DA1 * BA0 > BA1 * DA0
if abc != abd:
DC0 = D[0] - C[0]
DC1 = D[1] - C[1]
ABBA = A[0] * B[1] - B[0] * A[1]
CDDC = C[0] * D[1] - D[0] * C[1]
DH = BA1 * DC0 - BA0 * DC1
Dx = ABBA * DC0 - BA0 * CDDC
Dy = ABBA * DC1 - BA1 * CDDC
temp_pts[0] = Dx / DH
temp_pts[1] = Dy / DH
return True
return False
@cuda.jit(
'(float32[:], float32[:], int32, int32, float32[:])',
device=True,
inline=True)
def line_segment_intersection_v1(pts1, pts2, i, j, temp_pts):
a = cuda.local.array((2, ), dtype=numba.float32)
b = cuda.local.array((2, ), dtype=numba.float32)
c = cuda.local.array((2, ), dtype=numba.float32)
d = cuda.local.array((2, ), dtype=numba.float32)
a[0] = pts1[2 * i]
a[1] = pts1[2 * i + 1]
b[0] = pts1[2 * ((i + 1) % 4)]
b[1] = pts1[2 * ((i + 1) % 4) + 1]
c[0] = pts2[2 * j]
c[1] = pts2[2 * j + 1]
d[0] = pts2[2 * ((j + 1) % 4)]
d[1] = pts2[2 * ((j + 1) % 4) + 1]
area_abc = trangle_area(a, b, c)
area_abd = trangle_area(a, b, d)
if area_abc * area_abd >= 0:
return False
area_cda = trangle_area(c, d, a)
area_cdb = area_cda + area_abc - area_abd
if area_cda * area_cdb >= 0:
return False
t = area_cda / (area_abd - area_abc)
dx = t * (b[0] - a[0])
dy = t * (b[1] - a[1])
temp_pts[0] = a[0] + dx
temp_pts[1] = a[1] + dy
return True
@cuda.jit('(float32, float32, float32[:])', device=True, inline=True)
def point_in_quadrilateral(pt_x, pt_y, corners):
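# inside test for a convex quadrilateral: the point's projections onto edges
# AB and AD must fall within [0, |AB|^2] and [0, |AD|^2] respectively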
ab0 = corners[2] - corners[0]
ab1 = corners[3] - corners[1]
ad0 = corners[6] - corners[0]
ad1 = corners[7] - corners[1]
ap0 = pt_x - corners[0]
ap1 = pt_y - corners[1]
abab = ab0 * ab0 + ab1 * ab1
abap = ab0 * ap0 + ab1 * ap1
adad = ad0 * ad0 + ad1 * ad1
adap = ad0 * ap0 + ad1 * ap1
return abab >= abap and abap >= 0 and adad >= adap and adap >= 0
@cuda.jit('(float32[:], float32[:], float32[:])', device=True, inline=True)
def quadrilateral_intersection(pts1, pts2, int_pts):
num_of_inter = 0
for i in range(4):
if point_in_quadrilateral(pts1[2 * i], pts1[2 * i + 1], pts2):
int_pts[num_of_inter * 2] = pts1[2 * i]
int_pts[num_of_inter * 2 + 1] = pts1[2 * i + 1]
num_of_inter += 1
if point_in_quadrilateral(pts2[2 * i], pts2[2 * i + 1], pts1):
int_pts[num_of_inter * 2] = pts2[2 * i]
int_pts[num_of_inter * 2 + 1] = pts2[2 * i + 1]
num_of_inter += 1
temp_pts = cuda.local.array((2, ), dtype=numba.float32)
for i in range(4):
for j in range(4):
has_pts = line_segment_intersection(pts1, pts2, i, j, temp_pts)
if has_pts:
int_pts[num_of_inter * 2] = temp_pts[0]
int_pts[num_of_inter * 2 + 1] = temp_pts[1]
num_of_inter += 1
return num_of_inter
@cuda.jit('(float32[:], float32[:])', device=True, inline=True)
def rbbox_to_corners(corners, rbbox):
# generate clockwise corners and rotate it clockwise
angle = rbbox[4]
a_cos = math.cos(angle)
a_sin = math.sin(angle)
center_x = rbbox[0]
center_y = rbbox[1]
x_d = rbbox[2]
y_d = rbbox[3]
corners_x = cuda.local.array((4, ), dtype=numba.float32)
corners_y = cuda.local.array((4, ), dtype=numba.float32)
corners_x[0] = -x_d / 2
corners_x[1] = -x_d / 2
corners_x[2] = x_d / 2
corners_x[3] = x_d / 2
corners_y[0] = -y_d / 2
corners_y[1] = y_d / 2
corners_y[2] = y_d / 2
corners_y[3] = -y_d / 2
for i in range(4):
corners[2 *
i] = a_cos * corners_x[i] + a_sin * corners_y[i] + center_x
corners[2 * i
+ 1] = -a_sin * corners_x[i] + a_cos * corners_y[i] + center_y
@cuda.jit('(float32[:], float32[:])', device=True, inline=True)
def inter(rbbox1, rbbox2):
corners1 = cuda.local.array((8, ), dtype=numba.float32)
corners2 = cuda.local.array((8, ), dtype=numba.float32)
intersection_corners = cuda.local.array((16, ), dtype=numba.float32)
rbbox_to_corners(corners1, rbbox1)
rbbox_to_corners(corners2, rbbox2)
num_intersection = quadrilateral_intersection(corners1, corners2,
intersection_corners)
sort_vertex_in_convex_polygon(intersection_corners, num_intersection)
# print(intersection_corners.reshape([-1, 2])[:num_intersection])
return area(intersection_corners, num_intersection)
@cuda.jit('(float32[:], float32[:], int32)', device=True, inline=True)
def devRotateIoUEval(rbox1, rbox2, criterion=-1):
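# criterion: -1 -> IoU; 0 -> overlap / area of rbox1; 1 -> overlap / area of rbox2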
area1 = rbox1[2] * rbox1[3]
area2 = rbox2[2] * rbox2[3]
area_inter = inter(rbox1, rbox2)
if criterion == -1:
return area_inter / (area1 + area2 - area_inter)
elif criterion == 0:
return area_inter / area1
elif criterion == 1:
return area_inter / area2
else:
return area_inter
@cuda.jit('(int64, int64, float32[:], float32[:], float32[:], int32)', fastmath=False)
def rotate_iou_kernel_eval(N, K, dev_boxes, dev_query_boxes, dev_iou, criterion=-1):
threadsPerBlock = 8 * 8
row_start = cuda.blockIdx.x
col_start = cuda.blockIdx.y
tx = cuda.threadIdx.x
row_size = min(N - row_start * threadsPerBlock, threadsPerBlock)
col_size = min(K - col_start * threadsPerBlock, threadsPerBlock)
block_boxes = cuda.shared.array(shape=(64 * 5, ), dtype=numba.float32)
block_qboxes = cuda.shared.array(shape=(64 * 5, ), dtype=numba.float32)
dev_query_box_idx = threadsPerBlock * col_start + tx
dev_box_idx = threadsPerBlock * row_start + tx
if (tx < col_size):
block_qboxes[tx * 5 + 0] = dev_query_boxes[dev_query_box_idx * 5 + 0]
block_qboxes[tx * 5 + 1] = dev_query_boxes[dev_query_box_idx * 5 + 1]
block_qboxes[tx * 5 + 2] = dev_query_boxes[dev_query_box_idx * 5 + 2]
block_qboxes[tx * 5 + 3] = dev_query_boxes[dev_query_box_idx * 5 + 3]
block_qboxes[tx * 5 + 4] = dev_query_boxes[dev_query_box_idx * 5 + 4]
if (tx < row_size):
block_boxes[tx * 5 + 0] = dev_boxes[dev_box_idx * 5 + 0]
block_boxes[tx * 5 + 1] = dev_boxes[dev_box_idx * 5 + 1]
block_boxes[tx * 5 + 2] = dev_boxes[dev_box_idx * 5 + 2]
block_boxes[tx * 5 + 3] = dev_boxes[dev_box_idx * 5 + 3]
block_boxes[tx * 5 + 4] = dev_boxes[dev_box_idx * 5 + 4]
cuda.syncthreads()
if tx < row_size:
for i in range(col_size):
offset = row_start * threadsPerBlock * K + col_start * threadsPerBlock + tx * K + i
dev_iou[offset] = devRotateIoUEval(block_qboxes[i * 5:i * 5 + 5],
block_boxes[tx * 5:tx * 5 + 5], criterion)
def rotate_iou_gpu_eval(boxes, query_boxes, criterion=-1, device_id=0):
"""rotated box iou running in gpu. 500x faster than cpu version
(take 5ms in one example with numba.cuda code).
convert from [this project](
https://github.com/hongzhenwang/RRPN-revise/tree/master/lib/rotation).
Args:
boxes (float tensor: [N, 5]): rbboxes. format: centers, dims,
angles(clockwise when positive)
query_boxes (float tensor: [K, 5]): [description]
device_id (int, optional): Defaults to 0. [description]
Returns:
[type]: [description]
"""
box_dtype = boxes.dtype
boxes = boxes.astype(np.float32)
query_boxes = query_boxes.astype(np.float32)
N = boxes.shape[0]
K = query_boxes.shape[0]
iou = np.zeros((N, K), dtype=np.float32)
if N == 0 or K == 0:
return iou
threadsPerBlock = 8 * 8
cuda.select_device(device_id)
blockspergrid = (div_up(N, threadsPerBlock), div_up(K, threadsPerBlock))
stream = cuda.stream()
with stream.auto_synchronize():
boxes_dev = cuda.to_device(boxes.reshape([-1]), stream)
query_boxes_dev = cuda.to_device(query_boxes.reshape([-1]), stream)
iou_dev = cuda.to_device(iou.reshape([-1]), stream)
rotate_iou_kernel_eval[blockspergrid, threadsPerBlock, stream](
N, K, boxes_dev, query_boxes_dev, iou_dev, criterion)
iou_dev.copy_to_host(iou.reshape([-1]), stream=stream)
return iou.astype(boxes.dtype)
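# Minimal usage sketch (assumes a CUDA-capable GPU and a numba build with CUDA
# support; the box values are made up for illustration):
#   boxes = np.array([[0., 0., 4., 2., 0.]], dtype=np.float32)  # cx, cy, w, h, angle
#   query_boxes = np.array([[0., 0., 4., 2., np.pi / 4]], dtype=np.float32)
#   iou = rotate_iou_gpu_eval(boxes, query_boxes)  # -> (1, 1) overlap matrix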
cmake_minimum_required(VERSION 2.8.12)
project(pts_utils)
add_subdirectory(pybind11)
pybind11_add_module(pts_utils pts_utils.cpp)
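# Typical out-of-source build (assumes pybind11 has been cloned into this
# directory, as described in the README):
#   mkdir build && cd build && cmake .. && make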
from setuptools import setup
from setuptools import Extension
setup(
name='pts_utils',
ext_modules = [Extension(
name='pts_utils',
sources=['pts_utils.cpp'],
include_dirs=[r'../../pybind11/include'],
extra_compile_args=['-std=c++11']
)],
)
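# This setup script is normally driven by build_and_install.sh (see the README);
# a manual build would look something like: python setup.py install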
import numpy as np
import pts_utils
# smoke test for the compiled extension: 16384 random points against 64
# random 3D boxes; pts_in_boxes3d returns one 0/1 point mask per box
a = np.random.random((16384, 3)).astype('float32')
b = np.random.random((64, 7)).astype('float32')
c = pts_utils.pts_in_boxes3d(a, b)
print(a, b, c, c.shape, np.sum(c))