未验证 提交 129ddbb2 编写于 作者: K Kehan Yin 提交者: GitHub

【Hackathon + No.161】论文复现:CLRNet: Cross Layer Refinement Network for Lane Detection (#8278)

* Feat add CLRNet

* Update CLRNet README.md

* Update requirements: add imgaug>=0.4.0 for CLRNet

* Update CLRNet README.md

* Update README.md

* Update Rename clrnet_utils.py

* Update CLRNet demo & delete demo result

* Update README.md add weight for culane

* Update README.cn.md add training logs

* Feat add dataset download

* Fix bugs when lanes is empty

* Update README

* Update README for dataset info

* Fix export model

* Update configs & README

* style: update codestyle

* Style update op codestyple

* Fix eval process

* Fix eval process

* Update README&configs

* Fix deploy infer

* Fix mkdir in lane visualize

* Docs Update README

* Docs Rename configs

* Docs update weights

---------
Co-authored-by: NLokeZhou <aishenghuoaiqq@163.com>
上级 a694be1e
简体中文 | [English](README.md)
# CLRNet (CLRNet: Cross Layer Refinement Network for Lane Detection)
## 目录
- [简介](#简介)
- [模型库](#模型库)
- [引用](#引用)
## 介绍
[CLRNet](https://arxiv.org/abs/2203.10350)是一个车道线检测模型。CLRNet模型设计了车道线检测的直线先验轨迹,车道线iou以及nms方法,融合提取车道线轨迹的上下文高层特征与底层特征,利用FPN多尺度进行refine,在车道线检测相关数据集取得了SOTA的性能。
## 模型库
### CLRNet在CUlane上结果
| 骨架网络 | mF1 | F1@50 | F1@75 | 下载链接 | 配置文件 |训练日志|
| :--------------| :------- | :----: | :------: | :----: |:-----: |:-----: |
| ResNet-18 | 54.98 | 79.46 | 62.10 | [下载链接](https://paddledet.bj.bcebos.com/models/clrnet_resnet18_culane.pdparams) | [配置文件](./clrnet_resnet18_culane.yml) |[训练日志](https://bj.bcebos.com/v1/paddledet/logs/train_clrnet_r18_15_culane.log)|
### 数据集下载
下载[CULane数据集](https://xingangpan.github.io/projects/CULane.html)并解压到`dataset/culane`目录。
您的数据集目录结构如下:
```shell
culane/driver_xx_xxframe # data folders x6
culane/laneseg_label_w16 # lane segmentation labels
culane/list # data lists
```
如果您使用百度云链接下载,注意确保`driver_23_30frame_part1.tar.gz``driver_23_30frame_part2.tar.gz`解压后的文件都在`driver_23_30frame`目录下。
现已将用于测试的小数据集上传到PaddleDetection,可通过运行训练脚本,自动下载并解压数据,如需复现结果请下载链接中的全量数据集训练。
### 训练
- GPU单卡训练
```shell
python tools/train.py -c configs/clrnet/clr_resnet18_culane.yml
```
- GPU多卡训练
```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/clrnet/clr_resnet18_culane.yml
```
### 评估
```shell
python tools/eval.py -c configs/clrnet/clr_resnet18_culane.yml -o weights=output/clr_resnet18_culane/model_final.pdparams
```
### 预测
```shell
python tools/infer_culane.py -c configs/clrnet/clr_resnet18_culane.yml -o weights=output/clr_resnet18_culane/model_final.pdparams --infer_img=demo/lane00000.jpg
```
注意:预测功能暂不支持模型静态图推理部署。
## 引用
```
@InProceedings{Zheng_2022_CVPR,
author = {Zheng, Tu and Huang, Yifei and Liu, Yang and Tang, Wenjian and Yang, Zheng and Cai, Deng and He, Xiaofei},
title = {CLRNet: Cross Layer Refinement Network for Lane Detection},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2022},
pages = {898-907}
}
```
English | [简体中文](README_cn.md)
# CLRNet (CLRNet: Cross Layer Refinement Network for Lane Detection)
## Table of Contents
- [Introduction](#Introduction)
- [Model Zoo](#Model_Zoo)
- [Citations](#Citations)
## Introduction
[CLRNet](https://arxiv.org/abs/2203.10350) is a lane detection model. The CLRNet model is designed with line prior for lane detection, line iou loss as well as nms method, fused to extract contextual high-level features of lane line with low-level features, and refined by FPN multi-scale. Finally, the model achieved SOTA performance in lane detection datasets.
## Model Zoo
### CLRNet Results on CULane dataset
| backbone | mF1 | F1@50 | F1@75 | download | config |
| :--------------| :------- | :----: | :------: | :----: |:-----: |
| ResNet-18 | 54.98 | 79.46 | 62.10 | [model](https://paddledet.bj.bcebos.com/models/clrnet_resnet18_culane.pdparams) | [config](./clrnet_resnet18_culane.yml) |
### Download
Download [CULane](https://xingangpan.github.io/projects/CULane.html). Then extract them to `dataset/culane`.
For CULane, you should have structure like this:
```shell
culane/driver_xx_xxframe # data folders x6
culane/laneseg_label_w16 # lane segmentation labels
culane/list # data lists
```
If you use Baidu Cloud, make sure that images in `driver_23_30frame_part1.tar.gz` and `driver_23_30frame_part2.tar.gz` are located in one folder `driver_23_30frame` instead of two seperate folders after you decompress them.
Now we have uploaded a small subset of CULane dataset to PaddleDetection for code checking. You can simply run the training script below to download it automatically. If you want to implement the results, you need to download the full dataset at th link for training.
### Training
- single GPU
```shell
python tools/train.py -c configs/clrnet/clr_resnet18_culane.yml
```
- multi GPU
```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/clrnet/clr_resnet18_culane.yml
```
### Evaluation
```shell
python tools/eval.py -c configs/clrnet/clr_resnet18_culane.yml -o weights=output/clr_resnet18_culane/model_final.pdparams
```
### Inference
```shell
python tools/infer_culane.py -c configs/clrnet/clr_resnet18_culane.yml -o weights=output/clr_resnet18_culane/model_final.pdparams --infer_img=demo/lane00000.jpg
```
Notice: The inference phase does not support static model graph deploy at present.
## Citations
```
@InProceedings{Zheng_2022_CVPR,
author = {Zheng, Tu and Huang, Yifei and Liu, Yang and Tang, Wenjian and Yang, Zheng and Cai, Deng and He, Xiaofei},
title = {CLRNet: Cross Layer Refinement Network for Lane Detection},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2022},
pages = {898-907}
}
```
architecture: CLRNet
CLRNet:
backbone: CLRResNet
neck: CLRFPN
clr_head: CLRHead
CLRResNet:
resnet: 'resnet18'
pretrained: True
CLRFPN:
in_channels: [128,256,512]
out_channel: 64
extra_stage: 0
CLRHead:
prior_feat_channels: 64
fc_hidden_dim: 64
num_priors: 192
num_fc: 2
refine_layers: 3
sample_points: 36
loss: CLRNetLoss
conf_threshold: 0.4
nms_thres: 0.8
CLRNetLoss:
cls_loss_weight : 2.0
xyt_loss_weight : 0.2
iou_loss_weight : 2.0
seg_loss_weight : 1.0
refine_layers : 3
ignore_label: 255
bg_weight: 0.4
# for visualize lane detection results
sample_y:
start: 589
end: 230
step: -20
worker_num: 10
img_h: &img_h 320
img_w: &img_w 800
ori_img_h: &ori_img_h 590
ori_img_w: &ori_img_w 1640
num_points: &num_points 72
max_lanes: &max_lanes 4
TrainReader:
batch_size: 24
batch_transforms:
- CULaneTrainProcess: {img_h: *img_h, img_w: *img_w}
- CULaneDataProcess: {num_points: *num_points, max_lanes: *max_lanes, img_w: *img_w, img_h: *img_h}
shuffle: True
drop_last: False
EvalReader:
batch_size: 24
batch_transforms:
- CULaneResize: {prob: 1.0, img_h: *img_h, img_w: *img_w}
- CULaneDataProcess: {num_points: *num_points, max_lanes: *max_lanes, img_w: *img_w, img_h: *img_h}
shuffle: False
drop_last: False
TestReader:
batch_size: 24
batch_transforms:
- CULaneResize: {prob: 1.0, img_h: *img_h, img_w: *img_w}
- CULaneDataProcess: {num_points: *num_points, max_lanes: *max_lanes, img_w: *img_w, img_h: *img_h}
shuffle: False
drop_last: False
epoch: 15
snapshot_epoch: 5
LearningRate:
base_lr: 0.6e-3
schedulers:
- !CosineDecay
max_epochs: 15
use_warmup: False
OptimizerBuilder:
regularizer: False
optimizer:
type: AdamW
_BASE_: [
'../datasets/culane.yml',
'_base_/clrnet_reader.yml',
'_base_/clrnet_r18_fpn.yml',
'_base_/optimizer_1x.yml',
'../runtime.yml'
]
weights: output/clr_resnet18_culane/model_final
metric: CULaneMetric
num_classes: 5 # 4 lanes + background
cut_height: &cut_height 270
dataset_dir: &dataset_dir dataset/culane
TrainDataset:
name: CULaneDataSet
dataset_dir: *dataset_dir
list_path: 'list/train_gt.txt'
split: train
cut_height: *cut_height
EvalDataset:
name: CULaneDataSet
dataset_dir: *dataset_dir
list_path: 'list/test.txt'
split: test
cut_height: *cut_height
TestDataset:
name: CULaneDataSet
dataset_dir: *dataset_dir
list_path: 'list/test.txt'
split: test
cut_height: *cut_height
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
from scipy.special import softmax
from ppdet.modeling.lane_utils import Lane
from ppdet.modeling.losses import line_iou
class CLRNetPostProcess(object):
"""
Args:
input_shape (int): network input image size
ori_shape (int): ori image shape of before padding
scale_factor (float): scale factor of ori image
enable_mkldnn (bool): whether to open MKLDNN
"""
def __init__(self, img_w, ori_img_h, cut_height, conf_threshold, nms_thres,
max_lanes, num_points):
self.img_w = img_w
self.conf_threshold = conf_threshold
self.nms_thres = nms_thres
self.max_lanes = max_lanes
self.num_points = num_points
self.n_strips = num_points - 1
self.n_offsets = num_points
self.ori_img_h = ori_img_h
self.cut_height = cut_height
self.prior_ys = paddle.linspace(
start=1, stop=0, num=self.n_offsets).astype('float64')
def predictions_to_pred(self, predictions):
"""
Convert predictions to internal Lane structure for evaluation.
"""
lanes = []
for lane in predictions:
lane_xs = lane[6:].clone()
start = min(
max(0, int(round(lane[2].item() * self.n_strips))),
self.n_strips)
length = int(round(lane[5].item()))
end = start + length - 1
end = min(end, len(self.prior_ys) - 1)
if start > 0:
mask = ((lane_xs[:start] >= 0.) &
(lane_xs[:start] <= 1.)).cpu().detach().numpy()[::-1]
mask = ~((mask.cumprod()[::-1]).astype(np.bool))
lane_xs[:start][mask] = -2
if end < len(self.prior_ys) - 1:
lane_xs[end + 1:] = -2
lane_ys = self.prior_ys[lane_xs >= 0].clone()
lane_xs = lane_xs[lane_xs >= 0]
lane_xs = lane_xs.flip(axis=0).astype('float64')
lane_ys = lane_ys.flip(axis=0)
lane_ys = (lane_ys *
(self.ori_img_h - self.cut_height) + self.cut_height
) / self.ori_img_h
if len(lane_xs) <= 1:
continue
points = paddle.stack(
x=(lane_xs.reshape([-1, 1]), lane_ys.reshape([-1, 1])),
axis=1).squeeze(axis=2)
lane = Lane(
points=points.cpu().numpy(),
metadata={
'start_x': lane[3],
'start_y': lane[2],
'conf': lane[1]
})
lanes.append(lane)
return lanes
def lane_nms(self, predictions, scores, nms_overlap_thresh, top_k):
"""
NMS for lane detection.
predictions: paddle.Tensor [num_lanes,conf,y,x,lenght,72offsets] [12,77]
scores: paddle.Tensor [num_lanes]
nms_overlap_thresh: float
top_k: int
"""
# sort by scores to get idx
idx = scores.argsort(descending=True)
keep = []
condidates = predictions.clone()
condidates = condidates.index_select(idx)
while len(condidates) > 0:
keep.append(idx[0])
if len(keep) >= top_k or len(condidates) == 1:
break
ious = []
for i in range(1, len(condidates)):
ious.append(1 - line_iou(
condidates[i].unsqueeze(0),
condidates[0].unsqueeze(0),
img_w=self.img_w,
length=15))
ious = paddle.to_tensor(ious)
mask = ious <= nms_overlap_thresh
id = paddle.where(mask == False)[0]
if id.shape[0] == 0:
break
condidates = condidates[1:].index_select(id)
idx = idx[1:].index_select(id)
keep = paddle.stack(keep)
return keep
def get_lanes(self, output, as_lanes=True):
"""
Convert model output to lanes.
"""
softmax = nn.Softmax(axis=1)
decoded = []
for predictions in output:
if len(predictions) == 0:
decoded.append([])
continue
threshold = self.conf_threshold
scores = softmax(predictions[:, :2])[:, 1]
keep_inds = scores >= threshold
predictions = predictions[keep_inds]
scores = scores[keep_inds]
if predictions.shape[0] == 0:
decoded.append([])
continue
nms_predictions = predictions.detach().clone()
nms_predictions = paddle.concat(
x=[nms_predictions[..., :4], nms_predictions[..., 5:]], axis=-1)
nms_predictions[..., 4] = nms_predictions[..., 4] * self.n_strips
nms_predictions[..., 5:] = nms_predictions[..., 5:] * (
self.img_w - 1)
keep = self.lane_nms(
nms_predictions[..., 5:],
scores,
nms_overlap_thresh=self.nms_thres,
top_k=self.max_lanes)
predictions = predictions.index_select(keep)
if predictions.shape[0] == 0:
decoded.append([])
continue
predictions[:, 5] = paddle.round(predictions[:, 5] * self.n_strips)
if as_lanes:
pred = self.predictions_to_pred(predictions)
else:
pred = predictions
decoded.append(pred)
return decoded
def __call__(self, lanes_list):
lanes = self.get_lanes(lanes_list)
return lanes
......@@ -33,9 +33,10 @@ sys.path.insert(0, parent_path)
from benchmark_utils import PaddleInferBenchmark
from picodet_postprocess import PicoDetPostProcess
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image, CULaneResize
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
from visualize import visualize_box_mask
from clrnet_postprocess import CLRNetPostProcess
from visualize import visualize_box_mask, imshow_lanes
from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms, coco_clsid2catid
# Global dictionary
......@@ -43,7 +44,7 @@ SUPPORT_MODELS = {
'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'YOLOF', 'PPHGNet',
'PPLCNet', 'DETR', 'CenterTrack'
'PPLCNet', 'DETR', 'CenterTrack', 'CLRNet'
}
......@@ -713,6 +714,112 @@ class DetectorPicoDet(Detector):
return result
class DetectorCLRNet(Detector):
"""
Args:
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of pre batch in inference
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to turn on MKLDNN
enable_mkldnn_bfloat16 (bool): whether to turn on MKLDNN_BFLOAT16
"""
def __init__(
self,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False,
enable_mkldnn_bfloat16=False,
output_dir='./',
threshold=0.5, ):
super(DetectorCLRNet, self).__init__(
model_dir=model_dir,
device=device,
run_mode=run_mode,
batch_size=batch_size,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn,
enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
output_dir=output_dir,
threshold=threshold, )
deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
with open(deploy_file) as f:
yml_conf = yaml.safe_load(f)
self.img_w = yml_conf['img_w']
self.ori_img_h = yml_conf['ori_img_h']
self.cut_height = yml_conf['cut_height']
self.max_lanes = yml_conf['max_lanes']
self.nms_thres = yml_conf['nms_thres']
self.num_points = yml_conf['num_points']
self.conf_threshold = yml_conf['conf_threshold']
def postprocess(self, inputs, result):
# postprocess output of predictor
lanes_list = result['lanes']
postprocessor = CLRNetPostProcess(
img_w=self.img_w,
ori_img_h=self.ori_img_h,
cut_height=self.cut_height,
conf_threshold=self.conf_threshold,
nms_thres=self.nms_thres,
max_lanes=self.max_lanes,
num_points=self.num_points)
lanes = postprocessor(lanes_list)
result = dict(lanes=lanes)
return result
def predict(self, repeats=1, run_benchmark=False):
'''
Args:
repeats (int): repeat number for prediction
Returns:
result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
'''
lanes_list = []
if run_benchmark:
for i in range(repeats):
self.predictor.run()
paddle.device.cuda.synchronize()
result = dict(lanes=lanes_list)
return result
for i in range(repeats):
# TODO: check the output of predictor
self.predictor.run()
lanes_list.clear()
output_names = self.predictor.get_output_names()
num_outs = int(len(output_names) / 2)
if num_outs == 0:
lanes_list.append([])
for out_idx in range(num_outs):
lanes_list.append(
self.predictor.get_output_handle(output_names[out_idx])
.copy_to_cpu())
result = dict(lanes=lanes_list)
return result
def create_inputs(imgs, im_info):
"""generate input for different model type
Args:
......@@ -965,6 +1072,16 @@ def get_test_images(infer_dir, infer_img):
def visualize(image_list, result, labels, output_dir='output/', threshold=0.5):
# visualize the predict result
if 'lanes' in result:
print(image_list)
for idx, image_file in enumerate(image_list):
lanes = result['lanes'][idx]
img = cv2.imread(image_file)
out_file = os.path.join(output_dir, os.path.basename(image_file))
# hard code
lanes = [lane.to_array([], ) for lane in lanes]
imshow_lanes(img, lanes, out_file=out_file)
return
start_idx = 0
for idx, image_file in enumerate(image_list):
im_bboxes_num = result['boxes_num'][idx]
......@@ -1013,6 +1130,8 @@ def main():
detector_func = 'DetectorSOLOv2'
elif arch == 'PicoDet':
detector_func = 'DetectorPicoDet'
elif arch == "CLRNet":
detector_func = 'DetectorCLRNet'
detector = eval(detector_func)(
FLAGS.model_dir,
......
......@@ -14,6 +14,7 @@
import cv2
import numpy as np
import imgaug.augmenters as iaa
from keypoint_preprocess import get_affine_transform
from PIL import Image
......@@ -509,6 +510,32 @@ class WarpAffine(object):
return inp, im_info
class CULaneResize(object):
def __init__(self, img_h, img_w, cut_height, prob=0.5):
super(CULaneResize, self).__init__()
self.img_h = img_h
self.img_w = img_w
self.cut_height = cut_height
self.prob = prob
def __call__(self, im, im_info):
# cut
im = im[self.cut_height:, :, :]
# resize
transform = iaa.Sometimes(self.prob,
iaa.Resize({
"height": self.img_h,
"width": self.img_w
}))
im = transform(image=im.copy().astype(np.uint8))
im = im.astype(np.float32) / 255.
# check transpose is need whether the func decode_image is equal to CULaneDataSet cv.imread
im = im.transpose(2, 0, 1)
return im, im_info
def preprocess(im, preprocess_ops):
# process image by preprocess_ops
im_info = {
......
......@@ -577,3 +577,63 @@ def visualize_vehicle_retrograde(im, mot_res, vehicle_retrograde_res):
draw.text((xmax + 1, ymin - th), text, fill=(0, 255, 0))
return im
COLORS = [
(255, 0, 0),
(0, 255, 0),
(0, 0, 255),
(255, 255, 0),
(255, 0, 255),
(0, 255, 255),
(128, 255, 0),
(255, 128, 0),
(128, 0, 255),
(255, 0, 128),
(0, 128, 255),
(0, 255, 128),
(128, 255, 255),
(255, 128, 255),
(255, 255, 128),
(60, 180, 0),
(180, 60, 0),
(0, 60, 180),
(0, 180, 60),
(60, 0, 180),
(180, 0, 60),
(255, 0, 0),
(0, 255, 0),
(0, 0, 255),
(255, 255, 0),
(255, 0, 255),
(0, 255, 255),
(128, 255, 0),
(255, 128, 0),
(128, 0, 255),
]
def imshow_lanes(img, lanes, show=False, out_file=None, width=4):
lanes_xys = []
for _, lane in enumerate(lanes):
xys = []
for x, y in lane:
if x <= 0 or y <= 0:
continue
x, y = int(x), int(y)
xys.append((x, y))
lanes_xys.append(xys)
lanes_xys.sort(key=lambda xys: xys[0][0] if len(xys) > 0 else 0)
for idx, xys in enumerate(lanes_xys):
for i in range(1, len(xys)):
cv2.line(img, xys[i - 1], xys[i], COLORS[idx], thickness=width)
if show:
cv2.imshow('view', img)
cv2.waitKey(0)
if out_file:
if not os.path.exists(os.path.dirname(out_file)):
os.makedirs(os.path.dirname(out_file))
cv2.imwrite(out_file, img)
\ No newline at end of file
import math
import numpy as np
from imgaug.augmentables.lines import LineString
from scipy.interpolate import InterpolatedUnivariateSpline
def lane_to_linestrings(lanes):
lines = []
for lane in lanes:
lines.append(LineString(lane))
return lines
def linestrings_to_lanes(lines):
lanes = []
for line in lines:
lanes.append(line.coords)
return lanes
def sample_lane(points, sample_ys, img_w):
# this function expects the points to be sorted
points = np.array(points)
if not np.all(points[1:, 1] < points[:-1, 1]):
raise Exception('Annotaion points have to be sorted')
x, y = points[:, 0], points[:, 1]
# interpolate points inside domain
assert len(points) > 1
interp = InterpolatedUnivariateSpline(
y[::-1], x[::-1], k=min(3, len(points) - 1))
domain_min_y = y.min()
domain_max_y = y.max()
sample_ys_inside_domain = sample_ys[(sample_ys >= domain_min_y) & (
sample_ys <= domain_max_y)]
assert len(sample_ys_inside_domain) > 0
interp_xs = interp(sample_ys_inside_domain)
# extrapolate lane to the bottom of the image with a straight line using the 2 points closest to the bottom
two_closest_points = points[:2]
extrap = np.polyfit(
two_closest_points[:, 1], two_closest_points[:, 0], deg=1)
extrap_ys = sample_ys[sample_ys > domain_max_y]
extrap_xs = np.polyval(extrap, extrap_ys)
all_xs = np.hstack((extrap_xs, interp_xs))
# separate between inside and outside points
inside_mask = (all_xs >= 0) & (all_xs < img_w)
xs_inside_image = all_xs[inside_mask]
xs_outside_image = all_xs[~inside_mask]
return xs_outside_image, xs_inside_image
def filter_lane(lane):
assert lane[-1][1] <= lane[0][1]
filtered_lane = []
used = set()
for p in lane:
if p[1] not in used:
filtered_lane.append(p)
used.add(p[1])
return filtered_lane
def transform_annotation(img_w, img_h, max_lanes, n_offsets, offsets_ys,
n_strips, strip_size, anno):
old_lanes = anno['lanes']
# removing lanes with less than 2 points
old_lanes = filter(lambda x: len(x) > 1, old_lanes)
# sort lane points by Y (bottom to top of the image)
old_lanes = [sorted(lane, key=lambda x: -x[1]) for lane in old_lanes]
# remove points with same Y (keep first occurrence)
old_lanes = [filter_lane(lane) for lane in old_lanes]
# normalize the annotation coordinates
old_lanes = [[[x * img_w / float(img_w), y * img_h / float(img_h)]
for x, y in lane] for lane in old_lanes]
# create tranformed annotations
lanes = np.ones(
(max_lanes, 2 + 1 + 1 + 2 + n_offsets), dtype=np.float32
) * -1e5 # 2 scores, 1 start_y, 1 start_x, 1 theta, 1 length, S+1 coordinates
lanes_endpoints = np.ones((max_lanes, 2))
# lanes are invalid by default
lanes[:, 0] = 1
lanes[:, 1] = 0
for lane_idx, lane in enumerate(old_lanes):
if lane_idx >= max_lanes:
break
try:
xs_outside_image, xs_inside_image = sample_lane(lane, offsets_ys,
img_w)
except AssertionError:
continue
if len(xs_inside_image) <= 1:
continue
all_xs = np.hstack((xs_outside_image, xs_inside_image))
lanes[lane_idx, 0] = 0
lanes[lane_idx, 1] = 1
lanes[lane_idx, 2] = len(xs_outside_image) / n_strips
lanes[lane_idx, 3] = xs_inside_image[0]
thetas = []
for i in range(1, len(xs_inside_image)):
theta = math.atan(
i * strip_size /
(xs_inside_image[i] - xs_inside_image[0] + 1e-5)) / math.pi
theta = theta if theta > 0 else 1 - abs(theta)
thetas.append(theta)
theta_far = sum(thetas) / len(thetas)
# lanes[lane_idx,
# 4] = (theta_closest + theta_far) / 2 # averaged angle
lanes[lane_idx, 4] = theta_far
lanes[lane_idx, 5] = len(xs_inside_image)
lanes[lane_idx, 6:6 + len(all_xs)] = all_xs
lanes_endpoints[lane_idx, 0] = (len(all_xs) - 1) / n_strips
lanes_endpoints[lane_idx, 1] = xs_inside_image[-1]
new_anno = {
'label': lanes,
'old_anno': anno,
'lane_endpoints': lanes_endpoints
}
return new_anno
......@@ -19,6 +19,7 @@ from . import category
from . import keypoint_coco
from . import mot
from . import sniper_coco
from . import culane
from .coco import *
from .voc import *
......@@ -29,3 +30,4 @@ from .mot import *
from .sniper_coco import SniperCOCODataSet
from .dataset import ImageFolder
from .pose3d_cmb import *
from .culane import *
from ppdet.core.workspace import register, serializable
import cv2
import os
import tarfile
import numpy as np
import os.path as osp
from ppdet.data.source.dataset import DetDataset
from imgaug.augmentables.lines import LineStringsOnImage
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from ppdet.data.culane_utils import lane_to_linestrings
import pickle as pkl
from ppdet.utils.logger import setup_logger
try:
from collections.abc import Sequence
except Exception:
from collections import Sequence
from .dataset import DetDataset, _make_dataset, _is_valid_file
from ppdet.utils.download import download_dataset
logger = setup_logger(__name__)
@register
@serializable
class CULaneDataSet(DetDataset):
def __init__(
self,
dataset_dir,
cut_height,
list_path,
split='train',
data_fields=['image'],
video_file=None,
frame_rate=-1, ):
super(CULaneDataSet, self).__init__(
dataset_dir=dataset_dir,
cut_height=cut_height,
split=split,
data_fields=data_fields)
self.dataset_dir = dataset_dir
self.list_path = osp.join(dataset_dir, list_path)
self.cut_height = cut_height
self.data_fields = data_fields
self.split = split
self.training = 'train' in split
self.data_infos = []
self.video_file = video_file
self.frame_rate = frame_rate
self._imid2path = {}
self.predict_dir = None
def __len__(self):
return len(self.data_infos)
def check_or_download_dataset(self):
if not osp.exists(self.dataset_dir):
download_dataset("dataset", dataset="culane")
# extract .tar files in self.dataset_dir
for fname in os.listdir(self.dataset_dir):
logger.info("Decompressing {}...".format(fname))
# ignore .* files
if fname.startswith('.'):
continue
if fname.find('.tar.gz') >= 0:
with tarfile.open(osp.join(self.dataset_dir, fname)) as tf:
tf.extractall(path=self.dataset_dir)
logger.info("Dataset files are ready.")
def parse_dataset(self):
logger.info('Loading CULane annotations...')
if self.predict_dir is not None:
logger.info('switch to predict mode')
return
# Waiting for the dataset to load is tedious, let's cache it
os.makedirs('cache', exist_ok=True)
cache_path = 'cache/culane_paddle_{}.pkl'.format(self.split)
if os.path.exists(cache_path):
with open(cache_path, 'rb') as cache_file:
self.data_infos = pkl.load(cache_file)
self.max_lanes = max(
len(anno['lanes']) for anno in self.data_infos)
return
with open(self.list_path) as list_file:
for line in list_file:
infos = self.load_annotation(line.split())
self.data_infos.append(infos)
# cache data infos to file
with open(cache_path, 'wb') as cache_file:
pkl.dump(self.data_infos, cache_file)
def load_annotation(self, line):
infos = {}
img_line = line[0]
img_line = img_line[1 if img_line[0] == '/' else 0::]
img_path = os.path.join(self.dataset_dir, img_line)
infos['img_name'] = img_line
infos['img_path'] = img_path
if len(line) > 1:
mask_line = line[1]
mask_line = mask_line[1 if mask_line[0] == '/' else 0::]
mask_path = os.path.join(self.dataset_dir, mask_line)
infos['mask_path'] = mask_path
if len(line) > 2:
exist_list = [int(l) for l in line[2:]]
infos['lane_exist'] = np.array(exist_list)
anno_path = img_path[:
-3] + 'lines.txt' # remove sufix jpg and add lines.txt
with open(anno_path, 'r') as anno_file:
data = [
list(map(float, line.split())) for line in anno_file.readlines()
]
lanes = [[(lane[i], lane[i + 1]) for i in range(0, len(lane), 2)
if lane[i] >= 0 and lane[i + 1] >= 0] for lane in data]
lanes = [list(set(lane)) for lane in lanes] # remove duplicated points
lanes = [lane for lane in lanes
if len(lane) > 2] # remove lanes with less than 2 points
lanes = [sorted(
lane, key=lambda x: x[1]) for lane in lanes] # sort by y
infos['lanes'] = lanes
return infos
def set_images(self, images):
self.predict_dir = images
self.data_infos = self._load_images()
def _find_images(self):
predict_dir = self.predict_dir
if not isinstance(predict_dir, Sequence):
predict_dir = [predict_dir]
images = []
for im_dir in predict_dir:
if os.path.isdir(im_dir):
im_dir = os.path.join(self.predict_dir, im_dir)
images.extend(_make_dataset(im_dir))
elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
images.append(im_dir)
return images
def _load_images(self):
images = self._find_images()
ct = 0
records = []
for image in images:
assert image != '' and os.path.isfile(image), \
"Image {} not found".format(image)
if self.sample_num > 0 and ct >= self.sample_num:
break
rec = {
'im_id': np.array([ct]),
"img_path": os.path.abspath(image),
"img_name": os.path.basename(image),
"lanes": []
}
self._imid2path[ct] = image
ct += 1
records.append(rec)
assert len(records) > 0, "No image file found"
return records
def get_imid2path(self):
return self._imid2path
def __getitem__(self, idx):
data_info = self.data_infos[idx]
img = cv2.imread(data_info['img_path'])
img = img[self.cut_height:, :, :]
sample = data_info.copy()
sample.update({'image': img})
img_org = sample['image']
if self.training:
label = cv2.imread(sample['mask_path'], cv2.IMREAD_UNCHANGED)
if len(label.shape) > 2:
label = label[:, :, 0]
label = label.squeeze()
label = label[self.cut_height:, :]
sample.update({'mask': label})
if self.cut_height != 0:
new_lanes = []
for i in sample['lanes']:
lanes = []
for p in i:
lanes.append((p[0], p[1] - self.cut_height))
new_lanes.append(lanes)
sample.update({'lanes': new_lanes})
sample['mask'] = SegmentationMapsOnImage(
sample['mask'], shape=img_org.shape)
sample['full_img_path'] = data_info['img_path']
sample['img_name'] = data_info['img_name']
sample['im_id'] = np.array([idx])
sample['image'] = sample['image'].copy().astype(np.uint8)
sample['lanes'] = lane_to_linestrings(sample['lanes'])
sample['lanes'] = LineStringsOnImage(
sample['lanes'], shape=img_org.shape)
sample['seg'] = np.zeros(img_org.shape)
return sample
......@@ -18,6 +18,7 @@ from . import keypoint_operators
from . import mot_operators
from . import rotated_operators
from . import keypoints_3d_operators
from . import culane_operators
from .operators import *
from .batch_operators import *
......@@ -25,8 +26,10 @@ from .keypoint_operators import *
from .mot_operators import *
from .rotated_operators import *
from .keypoints_3d_operators import *
from .culane_operators import *
__all__ = []
__all__ += registered_ops
__all__ += keypoint_operators.__all__
__all__ += mot_operators.__all__
__all__ += culane_operators.__all__
import numpy as np
import imgaug.augmenters as iaa
from .operators import BaseOperator, register_op
from ppdet.utils.logger import setup_logger
from ppdet.data.culane_utils import linestrings_to_lanes, transform_annotation
logger = setup_logger(__name__)
__all__ = [
"CULaneTrainProcess", "CULaneDataProcess", "HorizontalFlip",
"ChannelShuffle", "CULaneAffine", "CULaneResize", "OneOfBlur",
"MultiplyAndAddToBrightness", "AddToHueAndSaturation"
]
def trainTransforms(img_h, img_w):
transforms = [{
'name': 'Resize',
'parameters': dict(size=dict(
height=img_h, width=img_w)),
'p': 1.0
}, {
'name': 'HorizontalFlip',
'parameters': dict(p=1.0),
'p': 0.5
}, {
'name': 'ChannelShuffle',
'parameters': dict(p=1.0),
'p': 0.1
}, {
'name': 'MultiplyAndAddToBrightness',
'parameters': dict(
mul=(0.85, 1.15), add=(-10, 10)),
'p': 0.6
}, {
'name': 'AddToHueAndSaturation',
'parameters': dict(value=(-10, 10)),
'p': 0.7
}, {
'name': 'OneOf',
'transforms': [
dict(
name='MotionBlur', parameters=dict(k=(3, 5))), dict(
name='MedianBlur', parameters=dict(k=(3, 5)))
],
'p': 0.2
}, {
'name': 'Affine',
'parameters': dict(
translate_percent=dict(
x=(-0.1, 0.1), y=(-0.1, 0.1)),
rotate=(-10, 10),
scale=(0.8, 1.2)),
'p': 0.7
}, {
'name': 'Resize',
'parameters': dict(size=dict(
height=img_h, width=img_w)),
'p': 1.0
}]
return transforms
@register_op
class CULaneTrainProcess(BaseOperator):
def __init__(self, img_w, img_h):
super(CULaneTrainProcess, self).__init__()
self.img_w = img_w
self.img_h = img_h
self.transforms = trainTransforms(self.img_h, self.img_w)
if self.transforms is not None:
img_transforms = []
for aug in self.transforms:
p = aug['p']
if aug['name'] != 'OneOf':
img_transforms.append(
iaa.Sometimes(
p=p,
then_list=getattr(iaa, aug['name'])(**aug[
'parameters'])))
else:
img_transforms.append(
iaa.Sometimes(
p=p,
then_list=iaa.OneOf([
getattr(iaa, aug_['name'])(**aug_['parameters'])
for aug_ in aug['transforms']
])))
else:
img_transforms = []
self.iaa_transform = iaa.Sequential(img_transforms)
def apply(self, sample, context=None):
img, line_strings, seg = self.iaa_transform(
image=sample['image'],
line_strings=sample['lanes'],
segmentation_maps=sample['mask'])
sample['image'] = img
sample['lanes'] = line_strings
sample['mask'] = seg
return sample
@register_op
class CULaneDataProcess(BaseOperator):
def __init__(self, img_w, img_h, num_points, max_lanes):
super(CULaneDataProcess, self).__init__()
self.img_w = img_w
self.img_h = img_h
self.num_points = num_points
self.n_offsets = num_points
self.n_strips = num_points - 1
self.strip_size = self.img_h / self.n_strips
self.max_lanes = max_lanes
self.offsets_ys = np.arange(self.img_h, -1, -self.strip_size)
def apply(self, sample, context=None):
data = {}
line_strings = sample['lanes']
line_strings.clip_out_of_image_()
new_anno = {'lanes': linestrings_to_lanes(line_strings)}
for i in range(30):
try:
annos = transform_annotation(
self.img_w, self.img_h, self.max_lanes, self.n_offsets,
self.offsets_ys, self.n_strips, self.strip_size, new_anno)
label = annos['label']
lane_endpoints = annos['lane_endpoints']
break
except:
if (i + 1) == 30:
logger.critical('Transform annotation failed 30 times :(')
exit()
sample['image'] = sample['image'].astype(np.float32) / 255.
data['image'] = sample['image'].transpose(2, 0, 1)
data['lane_line'] = label
data['seg'] = sample['seg']
data['full_img_path'] = sample['full_img_path']
data['img_name'] = sample['img_name']
data['im_id'] = sample['im_id']
if 'mask' in sample.keys():
data['seg'] = sample['mask'].get_arr()
data['im_shape'] = np.array([self.img_w, self.img_h], dtype=np.float32)
data['scale_factor'] = np.array([1., 1.], dtype=np.float32)
return data
@register_op
class CULaneResize(BaseOperator):
def __init__(self, img_h, img_w, prob=0.5):
super(CULaneResize, self).__init__()
self.img_h = img_h
self.img_w = img_w
self.prob = prob
def apply(self, sample, context=None):
transform = iaa.Sometimes(self.prob,
iaa.Resize({
"height": self.img_h,
"width": self.img_w
}))
if 'mask' in sample.keys():
img, line_strings, seg = transform(
image=sample['image'],
line_strings=sample['lanes'],
segmentation_maps=sample['mask'])
sample['image'] = img
sample['lanes'] = line_strings
sample['mask'] = seg
else:
img, line_strings = transform(
image=sample['image'].copy().astype(np.uint8),
line_strings=sample['lanes'])
sample['image'] = img
sample['lanes'] = line_strings
return sample
@register_op
class HorizontalFlip(BaseOperator):
def __init__(self, prob=0.5):
super(HorizontalFlip, self).__init__()
self.prob = prob
def apply(self, sample, context=None):
transform = iaa.Sometimes(self.prob, iaa.HorizontalFlip(1.0))
if 'mask' in sample.keys():
img, line_strings, seg = transform(
image=sample['image'],
line_strings=sample['lanes'],
segmentation_maps=sample['mask'])
sample['image'] = img
sample['lanes'] = line_strings
sample['mask'] = seg
else:
img, line_strings = transform(
image=sample['image'], line_strings=sample['lanes'])
sample['image'] = img
sample['lanes'] = line_strings
return sample
@register_op
class ChannelShuffle(BaseOperator):
def __init__(self, prob=0.1):
super(ChannelShuffle, self).__init__()
self.prob = prob
def apply(self, sample, context=None):
transform = iaa.Sometimes(self.prob, iaa.ChannelShuffle(1.0))
if 'mask' in sample.keys():
img, line_strings, seg = transform(
image=sample['image'],
line_strings=sample['lanes'],
segmentation_maps=sample['mask'])
sample['image'] = img
sample['lanes'] = line_strings
sample['mask'] = seg
else:
img, line_strings = transform(
image=sample['image'], line_strings=sample['lanes'])
sample['image'] = img
sample['lanes'] = line_strings
return sample
@register_op
class MultiplyAndAddToBrightness(BaseOperator):
def __init__(self, mul=(0.85, 1.15), add=(-10, 10), prob=0.5):
super(MultiplyAndAddToBrightness, self).__init__()
self.mul = tuple(mul)
self.add = tuple(add)
self.prob = prob
def apply(self, sample, context=None):
transform = iaa.Sometimes(
self.prob,
iaa.MultiplyAndAddToBrightness(
mul=self.mul, add=self.add))
if 'mask' in sample.keys():
img, line_strings, seg = transform(
image=sample['image'],
line_strings=sample['lanes'],
segmentation_maps=sample['mask'])
sample['image'] = img
sample['lanes'] = line_strings
sample['mask'] = seg
else:
img, line_strings = transform(
image=sample['image'], line_strings=sample['lanes'])
sample['image'] = img
sample['lanes'] = line_strings
return sample
@register_op
class AddToHueAndSaturation(BaseOperator):
def __init__(self, value=(-10, 10), prob=0.5):
super(AddToHueAndSaturation, self).__init__()
self.value = tuple(value)
self.prob = prob
def apply(self, sample, context=None):
transform = iaa.Sometimes(
self.prob, iaa.AddToHueAndSaturation(value=self.value))
if 'mask' in sample.keys():
img, line_strings, seg = transform(
image=sample['image'],
line_strings=sample['lanes'],
segmentation_maps=sample['mask'])
sample['image'] = img
sample['lanes'] = line_strings
sample['mask'] = seg
else:
img, line_strings = transform(
image=sample['image'], line_strings=sample['lanes'])
sample['image'] = img
sample['lanes'] = line_strings
return sample
@register_op
class OneOfBlur(BaseOperator):
def __init__(self, MotionBlur_k=(3, 5), MedianBlur_k=(3, 5), prob=0.5):
super(OneOfBlur, self).__init__()
self.MotionBlur_k = tuple(MotionBlur_k)
self.MedianBlur_k = tuple(MedianBlur_k)
self.prob = prob
def apply(self, sample, context=None):
transform = iaa.Sometimes(
self.prob,
iaa.OneOf([
iaa.MotionBlur(k=self.MotionBlur_k),
iaa.MedianBlur(k=self.MedianBlur_k)
]))
if 'mask' in sample.keys():
img, line_strings, seg = transform(
image=sample['image'],
line_strings=sample['lanes'],
segmentation_maps=sample['mask'])
sample['image'] = img
sample['lanes'] = line_strings
sample['mask'] = seg
else:
img, line_strings = transform(
image=sample['image'], line_strings=sample['lanes'])
sample['image'] = img
sample['lanes'] = line_strings
return sample
@register_op
class CULaneAffine(BaseOperator):
def __init__(self,
translate_percent_x=(-0.1, 0.1),
translate_percent_y=(-0.1, 0.1),
rotate=(3, 5),
scale=(0.8, 1.2),
prob=0.5):
super(CULaneAffine, self).__init__()
self.translate_percent = {
'x': tuple(translate_percent_x),
'y': tuple(translate_percent_y)
}
self.rotate = tuple(rotate)
self.scale = tuple(scale)
self.prob = prob
def apply(self, sample, context=None):
transform = iaa.Sometimes(
self.prob,
iaa.Affine(
translate_percent=self.translate_percent,
rotate=self.rotate,
scale=self.scale))
if 'mask' in sample.keys():
img, line_strings, seg = transform(
image=sample['image'],
line_strings=sample['lanes'],
segmentation_maps=sample['mask'])
sample['image'] = img
sample['lanes'] = line_strings
sample['mask'] = seg
else:
img, line_strings = transform(
image=sample['image'], line_strings=sample['lanes'])
sample['image'] = img
sample['lanes'] = line_strings
return sample
......@@ -54,10 +54,12 @@ TRT_MIN_SUBGRAPH = {
'YOLOF': 40,
'METRO_Body': 3,
'DETR': 3,
'CLRNet': 3
}
KEYPOINT_ARCH = ['HigherHRNet', 'TopDownHRNet']
MOT_ARCH = ['JDE', 'FairMOT', 'DeepSORT', 'ByteTrack', 'CenterTrack']
LANE_ARCH = ['CLRNet']
TO_STATIC_SPEC = {
'yolov3_darknet53_270e_coco': [{
......@@ -215,7 +217,8 @@ def _prune_input_spec(input_spec, program, targets):
def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
preprocess_list = []
label_list = []
if arch != "lane_arch":
anno_file = dataset_cfg.get_anno()
clsid2catid, catid2name = get_categories(metric, anno_file, arch)
......@@ -246,6 +249,13 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
'stride': value['pad_to_stride']
})
break
elif key == "CULaneResize":
# cut and resize
p = {'type': key}
p.update(value)
p.update({"cut_height": dataset_cfg.cut_height})
preprocess_list.append(p)
break
return preprocess_list, label_list
......@@ -315,6 +325,20 @@ def _dump_infer_config(config, path, image_shape, model):
if infer_arch in KEYPOINT_ARCH:
label_arch = 'keypoint_arch'
if infer_arch in LANE_ARCH:
infer_cfg['arch'] = infer_arch
infer_cfg['min_subgraph_size'] = TRT_MIN_SUBGRAPH[infer_arch]
infer_cfg['img_w'] = config['img_w']
infer_cfg['ori_img_h'] = config['ori_img_h']
infer_cfg['cut_height'] = config['cut_height']
label_arch = 'lane_arch'
head_name = "CLRHead"
infer_cfg['conf_threshold'] = config[head_name]['conf_threshold']
infer_cfg['nms_thres'] = config[head_name]['nms_thres']
infer_cfg['max_lanes'] = config[head_name]['max_lanes']
infer_cfg['num_points'] = config[head_name]['num_points']
arch_state = True
if infer_arch in MOT_ARCH:
if config['metric'] in ['COCO', 'VOC']:
# MOT model run as Detector
......
......@@ -39,13 +39,14 @@ from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.utils.visualizer import visualize_results, save_result
from ppdet.metrics import get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownCOCOWholeBadyHandEval, KeyPointTopDownMPIIEval, Pose3DEval
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, RBoxMetric, JDEDetMetric, SNIPERCOCOMetric
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, RBoxMetric, JDEDetMetric, SNIPERCOCOMetric, CULaneMetric
from ppdet.data.source.sniper_coco import SniperCOCODataSet
from ppdet.data.source.category import get_categories
import ppdet.utils.stats as stats
from ppdet.utils.fuse_utils import fuse_conv_bn
from ppdet.utils import profiler
from ppdet.modeling.post_process import multiclass_nms
from ppdet.modeling.lane_utils import imshow_lanes
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator, WandbCallback
from .export_utils import _dump_infer_config, _prune_input_spec, apply_to_static
......@@ -383,6 +384,15 @@ class Trainer(object):
]
elif self.cfg.metric == 'MOTDet':
self._metrics = [JDEDetMetric(), ]
elif self.cfg.metric == 'CULaneMetric':
output_eval = self.cfg.get('output_eval', None)
self._metrics = [
CULaneMetric(
cfg=self.cfg,
output_eval=output_eval,
split=self.dataset.split,
dataset_dir=self.cfg.dataset_dir)
]
else:
logger.warning("Metric not support for metric type {}".format(
self.cfg.metric))
......@@ -1139,6 +1149,12 @@ class Trainer(object):
"crops": InputSpec(
shape=[None, 3, 192, 64], name='crops')
})
if self.cfg.architecture == 'CLRNet':
input_spec[0].update({
"full_img_path": str,
"img_name": str,
})
if prune_input:
static_model = paddle.jit.to_static(
self.model, input_spec=input_spec)
......@@ -1277,3 +1293,107 @@ class Trainer(object):
logger.info("Found {} inference images in total.".format(
len(images)))
return all_images
def predict_culane(self,
images,
output_dir='output',
save_results=False,
visualize=True):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
self.dataset.set_images(images)
loader = create('TestReader')(self.dataset, 0)
imid2path = self.dataset.get_imid2path()
def setup_metrics_for_loader():
# mem
metrics = copy.deepcopy(self._metrics)
mode = self.mode
save_prediction_only = self.cfg[
'save_prediction_only'] if 'save_prediction_only' in self.cfg else None
output_eval = self.cfg[
'output_eval'] if 'output_eval' in self.cfg else None
# modify
self.mode = '_test'
self.cfg['save_prediction_only'] = True
self.cfg['output_eval'] = output_dir
self.cfg['imid2path'] = imid2path
self._init_metrics()
# restore
self.mode = mode
self.cfg.pop('save_prediction_only')
if save_prediction_only is not None:
self.cfg['save_prediction_only'] = save_prediction_only
self.cfg.pop('output_eval')
if output_eval is not None:
self.cfg['output_eval'] = output_eval
self.cfg.pop('imid2path')
_metrics = copy.deepcopy(self._metrics)
self._metrics = metrics
return _metrics
if save_results:
metrics = setup_metrics_for_loader()
else:
metrics = []
# Run Infer
self.status['mode'] = 'test'
self.model.eval()
if self.cfg.get('print_flops', False):
flops_loader = create('TestReader')(self.dataset, 0)
self._flops(flops_loader)
results = []
for step_id, data in enumerate(tqdm(loader)):
self.status['step_id'] = step_id
# forward
outs = self.model(data)
for _m in metrics:
_m.update(data, outs)
for key in ['im_shape', 'scale_factor', 'im_id']:
if isinstance(data, typing.Sequence):
outs[key] = data[0][key]
else:
outs[key] = data[key]
for key, value in outs.items():
if hasattr(value, 'numpy'):
outs[key] = value.numpy()
results.append(outs)
for _m in metrics:
_m.accumulate()
_m.reset()
if visualize:
import cv2
for outs in results:
for i in range(len(outs['img_path'])):
lanes = outs['lanes'][i]
img_path = outs['img_path'][i]
img = cv2.imread(img_path)
out_file = os.path.join(output_dir,
os.path.basename(img_path))
lanes = [
lane.to_array(
sample_y_range=[
self.cfg['sample_y']['start'],
self.cfg['sample_y']['end'],
self.cfg['sample_y']['step']
],
img_w=self.cfg.ori_img_w,
img_h=self.cfg.ori_img_h) for lane in lanes
]
imshow_lanes(img, lanes, out_file=out_file)
return results
......@@ -28,3 +28,7 @@ __all__ = metrics.__all__ + mot_metrics.__all__
from . import mcmot_metrics
from .mcmot_metrics import *
__all__ = metrics.__all__ + mcmot_metrics.__all__
from . import culane_metrics
from .culane_metrics import *
__all__ = metrics.__all__ + culane_metrics.__all__
\ No newline at end of file
import os
import cv2
import numpy as np
import os.path as osp
from functools import partial
from .metrics import Metric
from scipy.interpolate import splprep, splev
from scipy.optimize import linear_sum_assignment
from shapely.geometry import LineString, Polygon
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = [
'draw_lane', 'discrete_cross_iou', 'continuous_cross_iou', 'interp',
'culane_metric', 'load_culane_img_data', 'load_culane_data',
'eval_predictions', "CULaneMetric"
]
LIST_FILE = {
'train': 'list/train_gt.txt',
'val': 'list/val.txt',
'test': 'list/test.txt',
}
CATEGORYS = {
'normal': 'list/test_split/test0_normal.txt',
'crowd': 'list/test_split/test1_crowd.txt',
'hlight': 'list/test_split/test2_hlight.txt',
'shadow': 'list/test_split/test3_shadow.txt',
'noline': 'list/test_split/test4_noline.txt',
'arrow': 'list/test_split/test5_arrow.txt',
'curve': 'list/test_split/test6_curve.txt',
'cross': 'list/test_split/test7_cross.txt',
'night': 'list/test_split/test8_night.txt',
}
def draw_lane(lane, img=None, img_shape=None, width=30):
if img is None:
img = np.zeros(img_shape, dtype=np.uint8)
lane = lane.astype(np.int32)
for p1, p2 in zip(lane[:-1], lane[1:]):
cv2.line(
img, tuple(p1), tuple(p2), color=(255, 255, 255), thickness=width)
return img
def discrete_cross_iou(xs, ys, width=30, img_shape=(590, 1640, 3)):
xs = [draw_lane(lane, img_shape=img_shape, width=width) > 0 for lane in xs]
ys = [draw_lane(lane, img_shape=img_shape, width=width) > 0 for lane in ys]
ious = np.zeros((len(xs), len(ys)))
for i, x in enumerate(xs):
for j, y in enumerate(ys):
ious[i, j] = (x & y).sum() / (x | y).sum()
return ious
def continuous_cross_iou(xs, ys, width=30, img_shape=(590, 1640, 3)):
h, w, _ = img_shape
image = Polygon([(0, 0), (0, h - 1), (w - 1, h - 1), (w - 1, 0)])
xs = [
LineString(lane).buffer(
distance=width / 2., cap_style=1, join_style=2).intersection(image)
for lane in xs
]
ys = [
LineString(lane).buffer(
distance=width / 2., cap_style=1, join_style=2).intersection(image)
for lane in ys
]
ious = np.zeros((len(xs), len(ys)))
for i, x in enumerate(xs):
for j, y in enumerate(ys):
ious[i, j] = x.intersection(y).area / x.union(y).area
return ious
def interp(points, n=50):
x = [x for x, _ in points]
y = [y for _, y in points]
tck, u = splprep([x, y], s=0, t=n, k=min(3, len(points) - 1))
u = np.linspace(0., 1., num=(len(u) - 1) * n + 1)
return np.array(splev(u, tck)).T
def culane_metric(pred,
anno,
width=30,
iou_thresholds=[0.5],
official=True,
img_shape=(590, 1640, 3)):
_metric = {}
for thr in iou_thresholds:
tp = 0
fp = 0 if len(anno) != 0 else len(pred)
fn = 0 if len(pred) != 0 else len(anno)
_metric[thr] = [tp, fp, fn]
interp_pred = np.array(
[interp(
pred_lane, n=5) for pred_lane in pred], dtype=object) # (4, 50, 2)
interp_anno = np.array(
[interp(
anno_lane, n=5) for anno_lane in anno], dtype=object) # (4, 50, 2)
if official:
ious = discrete_cross_iou(
interp_pred, interp_anno, width=width, img_shape=img_shape)
else:
ious = continuous_cross_iou(
interp_pred, interp_anno, width=width, img_shape=img_shape)
row_ind, col_ind = linear_sum_assignment(1 - ious)
_metric = {}
for thr in iou_thresholds:
tp = int((ious[row_ind, col_ind] > thr).sum())
fp = len(pred) - tp
fn = len(anno) - tp
_metric[thr] = [tp, fp, fn]
return _metric
def load_culane_img_data(path):
with open(path, 'r') as data_file:
img_data = data_file.readlines()
img_data = [line.split() for line in img_data]
img_data = [list(map(float, lane)) for lane in img_data]
img_data = [[(lane[i], lane[i + 1]) for i in range(0, len(lane), 2)]
for lane in img_data]
img_data = [lane for lane in img_data if len(lane) >= 2]
return img_data
def load_culane_data(data_dir, file_list_path):
with open(file_list_path, 'r') as file_list:
filepaths = [
os.path.join(data_dir,
line[1 if line[0] == '/' else 0:].rstrip().replace(
'.jpg', '.lines.txt'))
for line in file_list.readlines()
]
data = []
for path in filepaths:
img_data = load_culane_img_data(path)
data.append(img_data)
return data
def eval_predictions(pred_dir,
anno_dir,
list_path,
iou_thresholds=[0.5],
width=30,
official=True,
sequential=False):
logger.info('Calculating metric for List: {}'.format(list_path))
predictions = load_culane_data(pred_dir, list_path)
annotations = load_culane_data(anno_dir, list_path)
img_shape = (590, 1640, 3)
if sequential:
results = map(partial(
culane_metric,
width=width,
official=official,
iou_thresholds=iou_thresholds,
img_shape=img_shape),
predictions,
annotations)
else:
from multiprocessing import Pool, cpu_count
from itertools import repeat
with Pool(cpu_count()) as p:
results = p.starmap(culane_metric,
zip(predictions, annotations,
repeat(width),
repeat(iou_thresholds),
repeat(official), repeat(img_shape)))
mean_f1, mean_prec, mean_recall, total_tp, total_fp, total_fn = 0, 0, 0, 0, 0, 0
ret = {}
for thr in iou_thresholds:
tp = sum(m[thr][0] for m in results)
fp = sum(m[thr][1] for m in results)
fn = sum(m[thr][2] for m in results)
precision = float(tp) / (tp + fp) if tp != 0 else 0
recall = float(tp) / (tp + fn) if tp != 0 else 0
f1 = 2 * precision * recall / (precision + recall) if tp != 0 else 0
logger.info('iou thr: {:.2f}, tp: {}, fp: {}, fn: {},'
'precision: {}, recall: {}, f1: {}'.format(
thr, tp, fp, fn, precision, recall, f1))
mean_f1 += f1 / len(iou_thresholds)
mean_prec += precision / len(iou_thresholds)
mean_recall += recall / len(iou_thresholds)
total_tp += tp
total_fp += fp
total_fn += fn
ret[thr] = {
'TP': tp,
'FP': fp,
'FN': fn,
'Precision': precision,
'Recall': recall,
'F1': f1
}
if len(iou_thresholds) > 2:
logger.info(
'mean result, total_tp: {}, total_fp: {}, total_fn: {},'
'precision: {}, recall: {}, f1: {}'.format(
total_tp, total_fp, total_fn, mean_prec, mean_recall, mean_f1))
ret['mean'] = {
'TP': total_tp,
'FP': total_fp,
'FN': total_fn,
'Precision': mean_prec,
'Recall': mean_recall,
'F1': mean_f1
}
return ret
class CULaneMetric(Metric):
def __init__(self,
cfg,
output_eval=None,
split="test",
dataset_dir="dataset/CULane/"):
super(CULaneMetric, self).__init__()
self.output_eval = "evaluation" if output_eval is None else output_eval
self.dataset_dir = dataset_dir
self.split = split
self.list_path = osp.join(dataset_dir, LIST_FILE[split])
self.predictions = []
self.img_names = []
self.lanes = []
self.eval_results = {}
self.cfg = cfg
self.reset()
def reset(self):
self.predictions = []
self.img_names = []
self.lanes = []
self.eval_results = {}
def get_prediction_string(self, pred):
ys = np.arange(270, 590, 8) / self.cfg.ori_img_h
out = []
for lane in pred:
xs = lane(ys)
valid_mask = (xs >= 0) & (xs < 1)
xs = xs * self.cfg.ori_img_w
lane_xs = xs[valid_mask]
lane_ys = ys[valid_mask] * self.cfg.ori_img_h
lane_xs, lane_ys = lane_xs[::-1], lane_ys[::-1]
lane_str = ' '.join([
'{:.5f} {:.5f}'.format(x, y) for x, y in zip(lane_xs, lane_ys)
])
if lane_str != '':
out.append(lane_str)
return '\n'.join(out)
def accumulate(self):
loss_lines = [[], [], [], []]
for idx, pred in enumerate(self.predictions):
output_dir = os.path.join(self.output_eval,
os.path.dirname(self.img_names[idx]))
output_filename = os.path.basename(self.img_names[
idx])[:-3] + 'lines.txt'
os.makedirs(output_dir, exist_ok=True)
output = self.get_prediction_string(pred)
# store loss lines
lanes = self.lanes[idx]
if len(lanes) - len(pred) in [1, 2, 3, 4]:
loss_lines[len(lanes) - len(pred) - 1].append(self.img_names[
idx])
with open(os.path.join(output_dir, output_filename),
'w') as out_file:
out_file.write(output)
for i, names in enumerate(loss_lines):
with open(
os.path.join(output_dir, 'loss_{}_lines.txt'.format(i + 1)),
'w') as f:
for name in names:
f.write(name + '\n')
for cate, cate_file in CATEGORYS.items():
result = eval_predictions(
self.output_eval,
self.dataset_dir,
os.path.join(self.dataset_dir, cate_file),
iou_thresholds=[0.5],
official=True)
result = eval_predictions(
self.output_eval,
self.dataset_dir,
self.list_path,
iou_thresholds=np.linspace(0.5, 0.95, 10),
official=True)
self.eval_results['F1@50'] = result[0.5]['F1']
self.eval_results['result'] = result
def update(self, inputs, outputs):
assert len(inputs['img_name']) == len(outputs['lanes'])
self.predictions.extend(outputs['lanes'])
self.img_names.extend(inputs['img_name'])
self.lanes.extend(inputs['lane_line'])
def log(self):
logger.info(self.eval_results)
# abstract method for getting metric results
def get_results(self):
return self.eval_results
......@@ -42,6 +42,7 @@ from . import yolof
from . import pose3d_metro
from . import centertrack
from . import queryinst
from . import clrnet
from .meta_arch import *
from .faster_rcnn import *
......@@ -75,3 +76,4 @@ from .pose3d_metro import *
from .centertrack import *
from .queryinst import *
from .keypoint_petr import *
from .clrnet import *
\ No newline at end of file
from .meta_arch import BaseArch
from ppdet.core.workspace import register, create
from paddle import in_dynamic_mode
__all__ = ['CLRNet']
@register
class CLRNet(BaseArch):
__category__ = 'architecture'
def __init__(self,
backbone="CLRResNet",
neck="CLRFPN",
clr_head="CLRHead",
post_process=None):
super(CLRNet, self).__init__()
self.backbone = backbone
self.neck = neck
self.heads = clr_head
self.post_process = post_process
@classmethod
def from_config(cls, cfg, *args, **kwargs):
# backbone
backbone = create(cfg['backbone'])
# fpn
kwargs = {'input_shape': backbone.out_shape}
neck = create(cfg['neck'], **kwargs)
# head
kwargs = {'input_shape': neck.out_shape}
clr_head = create(cfg['clr_head'], **kwargs)
return {
'backbone': backbone,
'neck': neck,
'clr_head': clr_head,
}
def _forward(self):
# Backbone
body_feats = self.backbone(self.inputs['image'])
# neck
neck_feats = self.neck(body_feats)
# CRL Head
if self.training:
output = self.heads(neck_feats, self.inputs)
else:
output = self.heads(neck_feats)
output = {'lanes': output}
# TODO: hard code fix as_lanes=False problem in clrnet_head.py "get_lanes" function for static mode
if in_dynamic_mode():
output = self.heads.get_lanes(output['lanes'])
output = {
"lanes": output,
"img_path": self.inputs['full_img_path'],
"img_name": self.inputs['img_name']
}
return output
def get_loss(self):
return self._forward()
def get_pred(self):
return self._forward()
import paddle
import paddle.nn.functional as F
from ppdet.modeling.losses.clrnet_line_iou_loss import line_iou
def distance_cost(predictions, targets, img_w):
"""
repeat predictions and targets to generate all combinations
use the abs distance as the new distance cost
"""
num_priors = predictions.shape[0]
num_targets = targets.shape[0]
predictions = paddle.repeat_interleave(
predictions, num_targets, axis=0)[..., 6:]
targets = paddle.concat(x=num_priors * [targets])[..., 6:]
invalid_masks = (targets < 0) | (targets >= img_w)
lengths = (~invalid_masks).sum(axis=1)
distances = paddle.abs(x=targets - predictions)
distances[invalid_masks] = 0.0
distances = distances.sum(axis=1) / (lengths.cast("float32") + 1e-09)
distances = distances.reshape([num_priors, num_targets])
return distances
def focal_cost(cls_pred, gt_labels, alpha=0.25, gamma=2, eps=1e-12):
"""
Args:
cls_pred (Tensor): Predicted classification logits, shape
[num_query, num_class].
gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
Returns:
torch.Tensor: cls_cost value
"""
cls_pred = F.sigmoid(cls_pred)
neg_cost = -(1 - cls_pred + eps).log() * (1 - alpha) * cls_pred.pow(gamma)
pos_cost = -(cls_pred + eps).log() * alpha * (1 - cls_pred).pow(gamma)
cls_cost = pos_cost.index_select(
gt_labels, axis=1) - neg_cost.index_select(
gt_labels, axis=1)
return cls_cost
def dynamic_k_assign(cost, pair_wise_ious):
"""
Assign grouth truths with priors dynamically.
Args:
cost: the assign cost.
pair_wise_ious: iou of grouth truth and priors.
Returns:
prior_idx: the index of assigned prior.
gt_idx: the corresponding ground truth index.
"""
matching_matrix = paddle.zeros_like(cost)
ious_matrix = pair_wise_ious
ious_matrix[ious_matrix < 0] = 0.0
n_candidate_k = 4
topk_ious, _ = paddle.topk(ious_matrix, n_candidate_k, axis=0)
dynamic_ks = paddle.clip(x=topk_ious.sum(0).cast("int32"), min=1)
num_gt = cost.shape[1]
for gt_idx in range(num_gt):
_, pos_idx = paddle.topk(
x=cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
matching_matrix[pos_idx, gt_idx] = 1.0
del topk_ious, dynamic_ks, pos_idx
matched_gt = matching_matrix.sum(axis=1)
if (matched_gt > 1).sum() > 0:
matched_gt_indices = paddle.nonzero(matched_gt > 1)[:, 0]
cost_argmin = paddle.argmin(
cost.index_select(matched_gt_indices), axis=1)
matching_matrix[matched_gt_indices][0] *= 0.0
matching_matrix[matched_gt_indices, cost_argmin] = 1.0
prior_idx = matching_matrix.sum(axis=1).nonzero()
gt_idx = matching_matrix[prior_idx].argmax(axis=-1)
return prior_idx.flatten(), gt_idx.flatten()
def cdist_paddle(x1, x2, p=2):
assert x1.shape[1] == x2.shape[1]
B, M = x1.shape
# if p == np.inf:
# dist = np.max(np.abs(x1[:, np.newaxis, :] - x2[np.newaxis, :, :]), axis=-1)
if p == 1:
dist = paddle.sum(
paddle.abs(x1.unsqueeze(axis=1) - x2.unsqueeze(axis=0)), axis=-1)
else:
dist = paddle.pow(paddle.sum(paddle.pow(
paddle.abs(x1.unsqueeze(axis=1) - x2.unsqueeze(axis=0)), p),
axis=-1),
1 / p)
return dist
def assign(predictions,
targets,
img_w,
img_h,
distance_cost_weight=3.0,
cls_cost_weight=1.0):
"""
computes dynamicly matching based on the cost, including cls cost and lane similarity cost
Args:
predictions (Tensor): predictions predicted by each stage, shape: (num_priors, 78)
targets (Tensor): lane targets, shape: (num_targets, 78)
return:
matched_row_inds (Tensor): matched predictions, shape: (num_targets)
matched_col_inds (Tensor): matched targets, shape: (num_targets)
"""
predictions = predictions.detach().clone()
predictions[:, 3] *= img_w - 1
predictions[:, 6:] *= img_w - 1
targets = targets.detach().clone()
distances_score = distance_cost(predictions, targets, img_w)
distances_score = 1 - distances_score / paddle.max(x=distances_score) + 0.01
cls_score = focal_cost(predictions[:, :2], targets[:, 1].cast('int64'))
num_priors = predictions.shape[0]
num_targets = targets.shape[0]
target_start_xys = targets[:, 2:4]
target_start_xys[..., 0] *= (img_h - 1)
prediction_start_xys = predictions[:, 2:4]
prediction_start_xys[..., 0] *= (img_h - 1)
start_xys_score = cdist_paddle(
prediction_start_xys, target_start_xys,
p=2).reshape([num_priors, num_targets])
start_xys_score = 1 - start_xys_score / paddle.max(x=start_xys_score) + 0.01
target_thetas = targets[:, 4].unsqueeze(axis=-1)
theta_score = cdist_paddle(
predictions[:, 4].unsqueeze(axis=-1), target_thetas,
p=1).reshape([num_priors, num_targets]) * 180
theta_score = 1 - theta_score / paddle.max(x=theta_score) + 0.01
cost = -(distances_score * start_xys_score * theta_score
)**2 * distance_cost_weight + cls_score * cls_cost_weight
iou = line_iou(predictions[..., 6:], targets[..., 6:], img_w, aligned=False)
matched_row_inds, matched_col_inds = dynamic_k_assign(cost, iou)
return matched_row_inds, matched_col_inds
......@@ -38,6 +38,7 @@ from . import trans_encoder
from . import focalnet
from . import vit_mae
from . import hgnet_v2
from . import clrnet_resnet
from .vgg import *
from .resnet import *
......@@ -66,3 +67,4 @@ from .focalnet import *
from .vitpose import *
from .vit_mae import *
from .hgnet_v2 import *
from .clrnet_resnet import *
此差异已折叠。
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.modeling.initializer import constant_
from paddle.nn.initializer import KaimingNormal
class ConvModule(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=False,
norm_type='bn',
wtih_act=True):
super(ConvModule, self).__init__()
assert norm_type in ['bn', 'sync_bn', 'gn', None]
self.with_norm = norm_type is not None
self.wtih_act = wtih_act
self.conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias_attr=bias,
weight_attr=KaimingNormal())
if self.with_norm:
if norm_type == 'bn':
self.bn = nn.BatchNorm2D(out_channels)
elif norm_type == 'gn':
self.bn = nn.GroupNorm(out_channels, out_channels)
if self.wtih_act:
self.act = nn.ReLU()
def forward(self, inputs):
x = self.conv(inputs)
if self.with_norm:
x = self.bn(x)
if self.wtih_act:
x = self.act(x)
return x
def LinearModule(hidden_dim):
return nn.LayerList(
[nn.Linear(
hidden_dim, hidden_dim, bias_attr=True), nn.ReLU()])
class FeatureResize(nn.Layer):
def __init__(self, size=(10, 25)):
super(FeatureResize, self).__init__()
self.size = size
def forward(self, x):
x = F.interpolate(x, self.size)
return x.flatten(2)
class ROIGather(nn.Layer):
'''
ROIGather module for gather global information
Args:
in_channels: prior feature channels
num_priors: prior numbers we predefined
sample_points: the number of sampled points when we extract feature from line
fc_hidden_dim: the fc output channel
refine_layers: the total number of layers to build refine
'''
def __init__(self,
in_channels,
num_priors,
sample_points,
fc_hidden_dim,
refine_layers,
mid_channels=48):
super(ROIGather, self).__init__()
self.in_channels = in_channels
self.num_priors = num_priors
self.f_key = ConvModule(
in_channels=self.in_channels,
out_channels=self.in_channels,
kernel_size=1,
stride=1,
padding=0,
norm_type='bn')
self.f_query = nn.Sequential(
nn.Conv1D(
in_channels=num_priors,
out_channels=num_priors,
kernel_size=1,
stride=1,
padding=0,
groups=num_priors),
nn.ReLU(), )
self.f_value = nn.Conv2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
kernel_size=1,
stride=1,
padding=0)
self.W = nn.Conv1D(
in_channels=num_priors,
out_channels=num_priors,
kernel_size=1,
stride=1,
padding=0,
groups=num_priors)
self.resize = FeatureResize()
constant_(self.W.weight, 0)
constant_(self.W.bias, 0)
self.convs = nn.LayerList()
self.catconv = nn.LayerList()
for i in range(refine_layers):
self.convs.append(
ConvModule(
in_channels,
mid_channels, (9, 1),
padding=(4, 0),
bias=False,
norm_type='bn'))
self.catconv.append(
ConvModule(
mid_channels * (i + 1),
in_channels, (9, 1),
padding=(4, 0),
bias=False,
norm_type='bn'))
self.fc = nn.Linear(
sample_points * fc_hidden_dim, fc_hidden_dim, bias_attr=True)
self.fc_norm = nn.LayerNorm(fc_hidden_dim)
def roi_fea(self, x, layer_index):
feats = []
for i, feature in enumerate(x):
feat_trans = self.convs[i](feature)
feats.append(feat_trans)
cat_feat = paddle.concat(feats, axis=1)
cat_feat = self.catconv[layer_index](cat_feat)
return cat_feat
def forward(self, roi_features, x, layer_index):
'''
Args:
roi_features: prior feature, shape: (Batch * num_priors, prior_feat_channel, sample_point, 1)
x: feature map
layer_index: currently on which layer to refine
Return:
roi: prior features with gathered global information, shape: (Batch, num_priors, fc_hidden_dim)
'''
roi = self.roi_fea(roi_features, layer_index)
# return roi
# print(roi.shape)
# return roi
bs = x.shape[0]
# print(bs)
#roi = roi.contiguous().view(bs * self.num_priors, -1)
roi = roi.reshape([bs * self.num_priors, -1])
# roi = paddle.randn([192,2304])
# return roi
# print(roi)
# print(self.fc)
# print(self.fc.weight)
roi = self.fc(roi)
roi = F.relu(self.fc_norm(roi))
# return roi
#roi = roi.view(bs, self.num_priors, -1)
roi = roi.reshape([bs, self.num_priors, -1])
query = roi
value = self.resize(self.f_value(x)) # (B, C, N) global feature
query = self.f_query(
query) # (B, N, 1) sample context feature from prior roi
key = self.f_key(x)
value = value.transpose(perm=[0, 2, 1])
key = self.resize(key) # (B, C, N) global feature
sim_map = paddle.matmul(query, key)
sim_map = (self.in_channels**-.5) * sim_map
sim_map = F.softmax(sim_map, axis=-1)
context = paddle.matmul(sim_map, value)
context = self.W(context)
roi = roi + F.dropout(context, p=0.1, training=self.training)
return roi
class SegDecoder(nn.Layer):
'''
Optionaly seg decoder
'''
def __init__(self,
image_height,
image_width,
num_class,
prior_feat_channels=64,
refine_layers=3):
super().__init__()
self.dropout = nn.Dropout2D(0.1)
self.conv = nn.Conv2D(prior_feat_channels * refine_layers, num_class, 1)
self.image_height = image_height
self.image_width = image_width
def forward(self, x):
x = self.dropout(x)
x = self.conv(x)
x = F.interpolate(
x,
size=[self.image_height, self.image_width],
mode='bilinear',
align_corners=False)
return x
import paddle.nn as nn
def accuracy(pred, target, topk=1, thresh=None):
"""Calculate accuracy according to the prediction and target.
Args:
pred (torch.Tensor): The model prediction, shape (N, num_class)
target (torch.Tensor): The target of each prediction, shape (N, )
topk (int | tuple[int], optional): If the predictions in ``topk``
matches the target, the predictions will be regarded as
correct ones. Defaults to 1.
thresh (float, optional): If not None, predictions with scores under
this threshold are considered incorrect. Default to None.
Returns:
float | tuple[float]: If the input ``topk`` is a single integer,
the function will return a single float as accuracy. If
``topk`` is a tuple containing multiple integers, the
function will return a tuple containing accuracies of
each ``topk`` number.
"""
assert isinstance(topk, (int, tuple))
if isinstance(topk, int):
topk = (topk, )
return_single = True
else:
return_single = False
maxk = max(topk)
if pred.shape[0] == 0:
accu = [pred.new_tensor(0.) for i in range(len(topk))]
return accu[0] if return_single else accu
assert pred.ndim == 2 and target.ndim == 1
assert pred.shape[0] == target.shape[0]
assert maxk <= pred.shape[1], \
f'maxk {maxk} exceeds pred dimension {pred.shape[1]}'
pred_value, pred_label = pred.topk(maxk, axis=1)
pred_label = pred_label.t() # transpose to shape (maxk, N)
correct = pred_label.equal(target.reshape([1, -1]).expand_as(pred_label))
if thresh is not None:
# Only prediction values larger than thresh are counted as correct
correct = correct & (pred_value > thresh).t()
res = []
for k in topk:
correct_k = correct[:k].reshape([-1]).cast("float32").sum(0,
keepdim=True)
correct_k = correct_k * (100.0 / pred.shape[0])
res.append(correct_k)
return res[0] if return_single else res
class Accuracy(nn.Layer):
def __init__(self, topk=(1, ), thresh=None):
"""Module to calculate the accuracy.
Args:
topk (tuple, optional): The criterion used to calculate the
accuracy. Defaults to (1,).
thresh (float, optional): If not None, predictions with scores
under this threshold are considered incorrect. Default to None.
"""
super().__init__()
self.topk = topk
self.thresh = thresh
def forward(self, pred, target):
"""Forward function to calculate accuracy.
Args:
pred (torch.Tensor): Prediction of models.
target (torch.Tensor): Target for each prediction.
Returns:
tuple[float]: The accuracies under different topk criterions.
"""
return accuracy(pred, target, self.topk, self.thresh)
......@@ -40,6 +40,7 @@ from . import ppyoloe_contrast_head
from . import centertrack_head
from . import sparse_roi_head
from . import vitpose_head
from . import clrnet_head
from .bbox_head import *
from .mask_head import *
......@@ -70,3 +71,4 @@ from .centertrack_head import *
from .sparse_roi_head import *
from .petr_head import *
from .vitpose_head import *
from .clrnet_head import *
\ No newline at end of file
import math
import paddle
import numpy as np
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ppdet.modeling.initializer import normal_
from ppdet.modeling.lane_utils import Lane
from ppdet.modeling.losses import line_iou
from ppdet.modeling.clrnet_utils import ROIGather, LinearModule, SegDecoder
__all__ = ['CLRHead']
@register
class CLRHead(nn.Layer):
__inject__ = ['loss']
__shared__ = [
'img_w', 'img_h', 'ori_img_h', 'num_classes', 'cut_height',
'num_points', "max_lanes"
]
def __init__(self,
num_points=72,
prior_feat_channels=64,
fc_hidden_dim=64,
num_priors=192,
img_w=800,
img_h=320,
ori_img_h=590,
cut_height=270,
num_classes=5,
num_fc=2,
refine_layers=3,
sample_points=36,
conf_threshold=0.4,
nms_thres=0.5,
max_lanes=4,
loss='CLRNetLoss'):
super(CLRHead, self).__init__()
self.img_w = img_w
self.img_h = img_h
self.n_strips = num_points - 1
self.n_offsets = num_points
self.num_priors = num_priors
self.sample_points = sample_points
self.refine_layers = refine_layers
self.num_classes = num_classes
self.fc_hidden_dim = fc_hidden_dim
self.ori_img_h = ori_img_h
self.cut_height = cut_height
self.conf_threshold = conf_threshold
self.nms_thres = nms_thres
self.max_lanes = max_lanes
self.prior_feat_channels = prior_feat_channels
self.loss = loss
self.register_buffer(
name='sample_x_indexs',
tensor=(paddle.linspace(
start=0, stop=1, num=self.sample_points,
dtype=paddle.float32) * self.n_strips).astype(dtype='int64'))
self.register_buffer(
name='prior_feat_ys',
tensor=paddle.flip(
x=(1 - self.sample_x_indexs.astype('float32') / self.n_strips),
axis=[-1]))
self.register_buffer(
name='prior_ys',
tensor=paddle.linspace(
start=1, stop=0, num=self.n_offsets).astype('float32'))
self.prior_feat_channels = prior_feat_channels
self._init_prior_embeddings()
init_priors, priors_on_featmap = self.generate_priors_from_embeddings()
self.register_buffer(name='priors', tensor=init_priors)
self.register_buffer(name='priors_on_featmap', tensor=priors_on_featmap)
self.seg_decoder = SegDecoder(self.img_h, self.img_w, self.num_classes,
self.prior_feat_channels,
self.refine_layers)
reg_modules = list()
cls_modules = list()
for _ in range(num_fc):
reg_modules += [*LinearModule(self.fc_hidden_dim)]
cls_modules += [*LinearModule(self.fc_hidden_dim)]
self.reg_modules = nn.LayerList(sublayers=reg_modules)
self.cls_modules = nn.LayerList(sublayers=cls_modules)
self.roi_gather = ROIGather(self.prior_feat_channels, self.num_priors,
self.sample_points, self.fc_hidden_dim,
self.refine_layers)
self.reg_layers = nn.Linear(
in_features=self.fc_hidden_dim,
out_features=self.n_offsets + 1 + 2 + 1,
bias_attr=True)
self.cls_layers = nn.Linear(
in_features=self.fc_hidden_dim, out_features=2, bias_attr=True)
self.init_weights()
def init_weights(self):
for m in self.cls_layers.parameters():
normal_(m, mean=0.0, std=0.001)
for m in self.reg_layers.parameters():
normal_(m, mean=0.0, std=0.001)
def pool_prior_features(self, batch_features, num_priors, prior_xs):
"""
pool prior feature from feature map.
Args:
batch_features (Tensor): Input feature maps, shape: (B, C, H, W)
"""
batch_size = batch_features.shape[0]
prior_xs = prior_xs.reshape([batch_size, num_priors, -1, 1])
prior_ys = self.prior_feat_ys.tile(repeat_times=[
batch_size * num_priors
]).reshape([batch_size, num_priors, -1, 1])
prior_xs = prior_xs * 2.0 - 1.0
prior_ys = prior_ys * 2.0 - 1.0
grid = paddle.concat(x=(prior_xs, prior_ys), axis=-1)
feature = F.grid_sample(
x=batch_features, grid=grid,
align_corners=True).transpose(perm=[0, 2, 1, 3])
feature = feature.reshape([
batch_size * num_priors, self.prior_feat_channels,
self.sample_points, 1
])
return feature
def generate_priors_from_embeddings(self):
predictions = self.prior_embeddings.weight
# 2 scores, 1 start_y, 1 start_x, 1 theta, 1 length, 72 coordinates, score[0] = negative prob, score[1] = positive prob
priors = paddle.zeros(
(self.num_priors, 2 + 2 + 2 + self.n_offsets),
dtype=predictions.dtype)
priors[:, 2:5] = predictions.clone()
priors[:, 6:] = (
priors[:, 3].unsqueeze(1).clone().tile([1, self.n_offsets]) *
(self.img_w - 1) +
((1 - self.prior_ys.tile([self.num_priors, 1]) -
priors[:, 2].unsqueeze(1).clone().tile([1, self.n_offsets])) *
self.img_h / paddle.tan(x=priors[:, 4].unsqueeze(1).clone().tile(
[1, self.n_offsets]) * math.pi + 1e-05))) / (self.img_w - 1)
priors_on_featmap = paddle.index_select(
priors, 6 + self.sample_x_indexs, axis=-1)
return priors, priors_on_featmap
def _init_prior_embeddings(self):
self.prior_embeddings = nn.Embedding(self.num_priors, 3)
bottom_priors_nums = self.num_priors * 3 // 4
left_priors_nums, _ = self.num_priors // 8, self.num_priors // 8
strip_size = 0.5 / (left_priors_nums // 2 - 1)
bottom_strip_size = 1 / (bottom_priors_nums // 4 + 1)
with paddle.no_grad():
for i in range(left_priors_nums):
self.prior_embeddings.weight[i, 0] = i // 2 * strip_size
self.prior_embeddings.weight[i, 1] = 0.0
self.prior_embeddings.weight[i,
2] = 0.16 if i % 2 == 0 else 0.32
for i in range(left_priors_nums,
left_priors_nums + bottom_priors_nums):
self.prior_embeddings.weight[i, 0] = 0.0
self.prior_embeddings.weight[i, 1] = (
(i - left_priors_nums) // 4 + 1) * bottom_strip_size
self.prior_embeddings.weight[i, 2] = 0.2 * (i % 4 + 1)
for i in range(left_priors_nums + bottom_priors_nums,
self.num_priors):
self.prior_embeddings.weight[i, 0] = (
i - left_priors_nums - bottom_priors_nums) // 2 * strip_size
self.prior_embeddings.weight[i, 1] = 1.0
self.prior_embeddings.weight[i,
2] = 0.68 if i % 2 == 0 else 0.84
def forward(self, x, inputs=None):
"""
Take pyramid features as input to perform Cross Layer Refinement and finally output the prediction lanes.
Each feature is a 4D tensor.
Args:
x: input features (list[Tensor])
Return:
prediction_list: each layer's prediction result
seg: segmentation result for auxiliary loss
"""
batch_features = list(x[len(x) - self.refine_layers:])
batch_features.reverse()
batch_size = batch_features[-1].shape[0]
if self.training:
self.priors, self.priors_on_featmap = self.generate_priors_from_embeddings(
)
priors, priors_on_featmap = self.priors.tile(
[batch_size, 1,
1]), self.priors_on_featmap.tile([batch_size, 1, 1])
predictions_lists = []
prior_features_stages = []
for stage in range(self.refine_layers):
num_priors = priors_on_featmap.shape[1]
prior_xs = paddle.flip(x=priors_on_featmap, axis=[2])
batch_prior_features = self.pool_prior_features(
batch_features[stage], num_priors, prior_xs)
prior_features_stages.append(batch_prior_features)
fc_features = self.roi_gather(prior_features_stages,
batch_features[stage], stage)
# return fc_features
fc_features = fc_features.reshape(
[num_priors, batch_size, -1]).reshape(
[batch_size * num_priors, self.fc_hidden_dim])
cls_features = fc_features.clone()
reg_features = fc_features.clone()
for cls_layer in self.cls_modules:
cls_features = cls_layer(cls_features)
# return cls_features
for reg_layer in self.reg_modules:
reg_features = reg_layer(reg_features)
cls_logits = self.cls_layers(cls_features)
reg = self.reg_layers(reg_features)
cls_logits = cls_logits.reshape(
[batch_size, -1, cls_logits.shape[1]])
reg = reg.reshape([batch_size, -1, reg.shape[1]])
predictions = priors.clone()
predictions[:, :, :2] = cls_logits
predictions[:, :, 2:5] += reg[:, :, :3]
predictions[:, :, 5] = reg[:, :, 3]
def tran_tensor(t):
return t.unsqueeze(axis=2).clone().tile([1, 1, self.n_offsets])
predictions[..., 6:] = (
tran_tensor(predictions[..., 3]) * (self.img_w - 1) +
((1 - self.prior_ys.tile([batch_size, num_priors, 1]) -
tran_tensor(predictions[..., 2])) * self.img_h / paddle.tan(
tran_tensor(predictions[..., 4]) * math.pi + 1e-05))) / (
self.img_w - 1)
prediction_lines = predictions.clone()
predictions[..., 6:] += reg[..., 4:]
predictions_lists.append(predictions)
if stage != self.refine_layers - 1:
priors = prediction_lines.detach().clone()
priors_on_featmap = priors.index_select(
6 + self.sample_x_indexs, axis=-1)
if self.training:
seg = None
seg_features = paddle.concat(
[
F.interpolate(
feature,
size=[
batch_features[-1].shape[2],
batch_features[-1].shape[3]
],
mode='bilinear',
align_corners=False) for feature in batch_features
],
axis=1)
seg = self.seg_decoder(seg_features)
output = {'predictions_lists': predictions_lists, 'seg': seg}
return self.loss(output, inputs)
return predictions_lists[-1]
def predictions_to_pred(self, predictions):
"""
Convert predictions to internal Lane structure for evaluation.
"""
self.prior_ys = paddle.to_tensor(self.prior_ys)
self.prior_ys = self.prior_ys.astype('float64')
lanes = []
for lane in predictions:
lane_xs = lane[6:].clone()
start = min(
max(0, int(round(lane[2].item() * self.n_strips))),
self.n_strips)
length = int(round(lane[5].item()))
end = start + length - 1
end = min(end, len(self.prior_ys) - 1)
if start > 0:
mask = ((lane_xs[:start] >= 0.) &
(lane_xs[:start] <= 1.)).cpu().detach().numpy()[::-1]
mask = ~((mask.cumprod()[::-1]).astype(np.bool))
lane_xs[:start][mask] = -2
if end < len(self.prior_ys) - 1:
lane_xs[end + 1:] = -2
lane_ys = self.prior_ys[lane_xs >= 0].clone()
lane_xs = lane_xs[lane_xs >= 0]
lane_xs = lane_xs.flip(axis=0).astype('float64')
lane_ys = lane_ys.flip(axis=0)
lane_ys = (lane_ys *
(self.ori_img_h - self.cut_height) + self.cut_height
) / self.ori_img_h
if len(lane_xs) <= 1:
continue
points = paddle.stack(
x=(lane_xs.reshape([-1, 1]), lane_ys.reshape([-1, 1])),
axis=1).squeeze(axis=2)
lane = Lane(
points=points.cpu().numpy(),
metadata={
'start_x': lane[3],
'start_y': lane[2],
'conf': lane[1]
})
lanes.append(lane)
return lanes
def lane_nms(self, predictions, scores, nms_overlap_thresh, top_k):
"""
NMS for lane detection.
predictions: paddle.Tensor [num_lanes,conf,y,x,lenght,72offsets] [12,77]
scores: paddle.Tensor [num_lanes]
nms_overlap_thresh: float
top_k: int
"""
# sort by scores to get idx
idx = scores.argsort(descending=True)
keep = []
condidates = predictions.clone()
condidates = condidates.index_select(idx)
while len(condidates) > 0:
keep.append(idx[0])
if len(keep) >= top_k or len(condidates) == 1:
break
ious = []
for i in range(1, len(condidates)):
ious.append(1 - line_iou(
condidates[i].unsqueeze(0),
condidates[0].unsqueeze(0),
img_w=self.img_w,
length=15))
ious = paddle.to_tensor(ious)
mask = ious <= nms_overlap_thresh
id = paddle.where(mask == False)[0]
if id.shape[0] == 0:
break
condidates = condidates[1:].index_select(id)
idx = idx[1:].index_select(id)
keep = paddle.stack(keep)
return keep
def get_lanes(self, output, as_lanes=True):
"""
Convert model output to lanes.
"""
softmax = nn.Softmax(axis=1)
decoded = []
for predictions in output:
threshold = self.conf_threshold
scores = softmax(predictions[:, :2])[:, 1]
keep_inds = scores >= threshold
predictions = predictions[keep_inds]
scores = scores[keep_inds]
if predictions.shape[0] == 0:
decoded.append([])
continue
nms_predictions = predictions.detach().clone()
nms_predictions = paddle.concat(
x=[nms_predictions[..., :4], nms_predictions[..., 5:]], axis=-1)
nms_predictions[..., 4] = nms_predictions[..., 4] * self.n_strips
nms_predictions[..., 5:] = nms_predictions[..., 5:] * (
self.img_w - 1)
keep = self.lane_nms(
nms_predictions[..., 5:],
scores,
nms_overlap_thresh=self.nms_thres,
top_k=self.max_lanes)
predictions = predictions.index_select(keep)
if predictions.shape[0] == 0:
decoded.append([])
continue
predictions[:, 5] = paddle.round(predictions[:, 5] * self.n_strips)
if as_lanes:
pred = self.predictions_to_pred(predictions)
else:
pred = predictions
decoded.append(pred)
return decoded
import os
import cv2
import numpy as np
from scipy.interpolate import InterpolatedUnivariateSpline
class Lane:
def __init__(self, points=None, invalid_value=-2., metadata=None):
super(Lane, self).__init__()
self.curr_iter = 0
self.points = points
self.invalid_value = invalid_value
self.function = InterpolatedUnivariateSpline(
points[:, 1], points[:, 0], k=min(3, len(points) - 1))
self.min_y = points[:, 1].min() - 0.01
self.max_y = points[:, 1].max() + 0.01
self.metadata = metadata or {}
def __repr__(self):
return '[Lane]\n' + str(self.points) + '\n[/Lane]'
def __call__(self, lane_ys):
lane_xs = self.function(lane_ys)
lane_xs[(lane_ys < self.min_y) | (lane_ys > self.max_y
)] = self.invalid_value
return lane_xs
def to_array(self, sample_y_range, img_w, img_h):
self.sample_y = range(sample_y_range[0], sample_y_range[1],
sample_y_range[2])
sample_y = self.sample_y
img_w, img_h = img_w, img_h
ys = np.array(sample_y) / float(img_h)
xs = self(ys)
valid_mask = (xs >= 0) & (xs < 1)
lane_xs = xs[valid_mask] * img_w
lane_ys = ys[valid_mask] * img_h
lane = np.concatenate(
(lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)), axis=1)
return lane
def __iter__(self):
return self
def __next__(self):
if self.curr_iter < len(self.points):
self.curr_iter += 1
return self.points[self.curr_iter - 1]
self.curr_iter = 0
raise StopIteration
COLORS = [
(255, 0, 0),
(0, 255, 0),
(0, 0, 255),
(255, 255, 0),
(255, 0, 255),
(0, 255, 255),
(128, 255, 0),
(255, 128, 0),
(128, 0, 255),
(255, 0, 128),
(0, 128, 255),
(0, 255, 128),
(128, 255, 255),
(255, 128, 255),
(255, 255, 128),
(60, 180, 0),
(180, 60, 0),
(0, 60, 180),
(0, 180, 60),
(60, 0, 180),
(180, 0, 60),
(255, 0, 0),
(0, 255, 0),
(0, 0, 255),
(255, 255, 0),
(255, 0, 255),
(0, 255, 255),
(128, 255, 0),
(255, 128, 0),
(128, 0, 255),
]
def imshow_lanes(img, lanes, show=False, out_file=None, width=4):
lanes_xys = []
for _, lane in enumerate(lanes):
xys = []
for x, y in lane:
if x <= 0 or y <= 0:
continue
x, y = int(x), int(y)
xys.append((x, y))
lanes_xys.append(xys)
lanes_xys.sort(key=lambda xys: xys[0][0] if len(xys) > 0 else 0)
for idx, xys in enumerate(lanes_xys):
for i in range(1, len(xys)):
cv2.line(img, xys[i - 1], xys[i], COLORS[idx], thickness=width)
if show:
cv2.imshow('view', img)
cv2.waitKey(0)
if out_file:
if not os.path.exists(os.path.dirname(out_file)):
os.makedirs(os.path.dirname(out_file))
cv2.imwrite(out_file, img)
......@@ -31,6 +31,8 @@ from . import probiou_loss
from . import cot_loss
from . import supcontrast
from . import queryinst_loss
from . import clrnet_loss
from . import clrnet_line_iou_loss
from .yolo_loss import *
from .iou_aware_loss import *
......@@ -52,3 +54,5 @@ from .probiou_loss import *
from .cot_loss import *
from .supcontrast import *
from .queryinst_loss import *
from .clrnet_loss import *
from .clrnet_line_iou_loss import *
\ No newline at end of file
import paddle
def line_iou(pred, target, img_w, length=15, aligned=True):
'''
Calculate the line iou value between predictions and targets
Args:
pred: lane predictions, shape: (num_pred, 72)
target: ground truth, shape: (num_target, 72)
img_w: image width
length: extended radius
aligned: True for iou loss calculation, False for pair-wise ious in assign
'''
px1 = pred - length
px2 = pred + length
tx1 = target - length
tx2 = target + length
if aligned:
invalid_mask = target
ovr = paddle.minimum(px2, tx2) - paddle.maximum(px1, tx1)
union = paddle.maximum(px2, tx2) - paddle.minimum(px1, tx1)
else:
num_pred = pred.shape[0]
invalid_mask = target.tile([num_pred, 1, 1])
ovr = (paddle.minimum(px2[:, None, :], tx2[None, ...]) - paddle.maximum(
px1[:, None, :], tx1[None, ...]))
union = (paddle.maximum(px2[:, None, :], tx2[None, ...]) -
paddle.minimum(px1[:, None, :], tx1[None, ...]))
invalid_masks = (invalid_mask < 0) | (invalid_mask >= img_w)
ovr[invalid_masks] = 0.
union[invalid_masks] = 0.
iou = ovr.sum(axis=-1) / (union.sum(axis=-1) + 1e-9)
return iou
def liou_loss(pred, target, img_w, length=15):
return (1 - line_iou(pred, target, img_w, length)).mean()
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ppdet.modeling.clrnet_utils import accuracy
from ppdet.modeling.assigners.clrnet_assigner import assign
from ppdet.modeling.losses.clrnet_line_iou_loss import liou_loss
__all__ = ['CLRNetLoss']
class SoftmaxFocalLoss(nn.Layer):
def __init__(self, gamma, ignore_lb=255, *args, **kwargs):
super(SoftmaxFocalLoss, self).__init__()
self.gamma = gamma
self.nll = nn.NLLLoss(ignore_index=ignore_lb)
def forward(self, logits, labels):
scores = F.softmax(logits, dim=1)
factor = paddle.pow(1. - scores, self.gamma)
log_score = F.log_softmax(logits, dim=1)
log_score = factor * log_score
loss = self.nll(log_score, labels)
return loss
def focal_loss(input: paddle.Tensor,
target: paddle.Tensor,
alpha: float,
gamma: float=2.0,
reduction: str='none',
eps: float=1e-8) -> paddle.Tensor:
r"""Function that computes Focal loss.
See :class:`~kornia.losses.FocalLoss` for details.
"""
if not paddle.is_tensor(input):
raise TypeError("Input type is not a torch.Tensor. Got {}".format(
type(input)))
if not len(input.shape) >= 2:
raise ValueError("Invalid input shape, we expect BxCx*. Got: {}".format(
input.shape))
if input.shape[0] != target.shape[0]:
raise ValueError(
'Expected input batch_size ({}) to match target batch_size ({}).'.
format(input.shape[0], target.shape[0]))
n = input.shape[0]
out_size = (n, ) + tuple(input.shape[2:])
if target.shape[1:] != input.shape[2:]:
raise ValueError('Expected target size {}, got {}'.format(out_size,
target.shape))
if (isinstance(input.place, paddle.CUDAPlace) and
isinstance(target.place, paddle.CPUPlace)) | (isinstance(
input.place, paddle.CPUPlace) and isinstance(target.place,
paddle.CUDAPlace)):
raise ValueError(
"input and target must be in the same device. Got: {} and {}".
format(input.place, target.place))
# compute softmax over the classes axis
input_soft: paddle.Tensor = F.softmax(input, axis=1) + eps
# create the labels one hot tensor
target_one_hot: paddle.Tensor = paddle.to_tensor(
F.one_hot(
target, num_classes=input.shape[1]).cast(input.dtype),
place=input.place)
# compute the actual focal loss
weight = paddle.pow(-input_soft + 1., gamma)
focal = -alpha * weight * paddle.log(input_soft)
loss_tmp = paddle.sum(target_one_hot * focal, axis=1)
if reduction == 'none':
loss = loss_tmp
elif reduction == 'mean':
loss = paddle.mean(loss_tmp)
elif reduction == 'sum':
loss = paddle.sum(loss_tmp)
else:
raise NotImplementedError("Invalid reduction mode: {}".format(
reduction))
return loss
class FocalLoss(nn.Layer):
r"""Criterion that computes Focal loss.
According to [1], the Focal loss is computed as follows:
.. math::
\text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
where:
- :math:`p_t` is the model's estimated probability for each class.
Arguments:
alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
gamma (float): Focusing parameter :math:`\gamma >= 0`.
reduction (str, optional): Specifies the reduction to apply to the
output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied,
‘mean’: the sum of the output will be divided by the number of elements
in the output, ‘sum’: the output will be summed. Default: ‘none’.
Shape:
- Input: :math:`(N, C, *)` where C = number of classes.
- Target: :math:`(N, *)` where each value is
:math:`0 ≤ targets[i] ≤ C−1`.
Examples:
>>> N = 5 # num_classes
>>> kwargs = {"alpha": 0.5, "gamma": 2.0, "reduction": 'mean'}
>>> loss = kornia.losses.FocalLoss(**kwargs)
>>> input = torch.randn(1, N, 3, 5, requires_grad=True)
>>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
>>> output = loss(input, target)
>>> output.backward()
References:
[1] https://arxiv.org/abs/1708.02002
"""
def __init__(self, alpha: float, gamma: float=2.0,
reduction: str='none') -> None:
super(FocalLoss, self).__init__()
self.alpha: float = alpha
self.gamma: float = gamma
self.reduction: str = reduction
self.eps: float = 1e-6
def forward( # type: ignore
self, input: paddle.Tensor, target: paddle.Tensor) -> paddle.Tensor:
return focal_loss(input, target, self.alpha, self.gamma, self.reduction,
self.eps)
@register
class CLRNetLoss(nn.Layer):
__shared__ = ['img_w', 'img_h', 'num_classes', 'num_points']
def __init__(self,
cls_loss_weight=2.0,
xyt_loss_weight=0.2,
iou_loss_weight=2.0,
seg_loss_weight=1.0,
refine_layers=3,
num_points=72,
img_w=800,
img_h=320,
num_classes=5,
ignore_label=255,
bg_weight=0.4):
super(CLRNetLoss, self).__init__()
self.cls_loss_weight = cls_loss_weight
self.xyt_loss_weight = xyt_loss_weight
self.iou_loss_weight = iou_loss_weight
self.seg_loss_weight = seg_loss_weight
self.refine_layers = refine_layers
self.img_w = img_w
self.img_h = img_h
self.n_strips = num_points - 1
self.num_classes = num_classes
self.ignore_label = ignore_label
weights = paddle.ones(shape=[self.num_classes])
weights[0] = bg_weight
self.criterion = nn.NLLLoss(
ignore_index=self.ignore_label, weight=weights)
def forward(self, output, batch):
predictions_lists = output['predictions_lists']
targets = batch['lane_line'].clone()
cls_criterion = FocalLoss(alpha=0.25, gamma=2.0)
cls_loss = paddle.to_tensor(0.0)
reg_xytl_loss = paddle.to_tensor(0.0)
iou_loss = paddle.to_tensor(0.0)
cls_acc = []
cls_acc_stage = []
for stage in range(self.refine_layers):
predictions_list = predictions_lists[stage]
for predictions, target in zip(predictions_list, targets):
target = target[target[:, 1] == 1]
if len(target) == 0:
# If there are no targets, all predictions have to be negatives (i.e., 0 confidence)
cls_target = paddle.zeros(
[predictions.shape[0]], dtype='int64')
cls_pred = predictions[:, :2]
cls_loss = cls_loss + cls_criterion(cls_pred,
cls_target).sum()
continue
with paddle.no_grad():
matched_row_inds, matched_col_inds = assign(
predictions, target, self.img_w, self.img_h)
# classification targets
cls_target = paddle.zeros([predictions.shape[0]], dtype='int64')
cls_target[matched_row_inds] = 1
cls_pred = predictions[:, :2]
# regression targets -> [start_y, start_x, theta] (all transformed to absolute values), only on matched pairs
reg_yxtl = predictions.index_select(matched_row_inds)[..., 2:6]
reg_yxtl[:, 0] *= self.n_strips
reg_yxtl[:, 1] *= (self.img_w - 1)
reg_yxtl[:, 2] *= 180
reg_yxtl[:, 3] *= self.n_strips
target_yxtl = target.index_select(matched_col_inds)[..., 2:
6].clone()
# regression targets -> S coordinates (all transformed to absolute values)
reg_pred = predictions.index_select(matched_row_inds)[..., 6:]
reg_pred *= (self.img_w - 1)
reg_targets = target.index_select(matched_col_inds)[...,
6:].clone()
with paddle.no_grad():
predictions_starts = paddle.clip(
(predictions.index_select(matched_row_inds)[..., 2] *
self.n_strips).round().cast("int64"),
min=0,
max=self.
n_strips) # ensure the predictions starts is valid
target_starts = (
target.index_select(matched_col_inds)[..., 2] *
self.n_strips).round().cast("int64")
target_yxtl[:, -1] -= (
predictions_starts - target_starts) # reg length
# Loss calculation
cls_loss = cls_loss + cls_criterion(
cls_pred, cls_target).sum() / target.shape[0]
target_yxtl[:, 0] *= self.n_strips
target_yxtl[:, 2] *= 180
reg_xytl_loss = reg_xytl_loss + F.smooth_l1_loss(
input=reg_yxtl, label=target_yxtl, reduction='none').mean()
iou_loss = iou_loss + liou_loss(
reg_pred, reg_targets, self.img_w, length=15)
cls_accuracy = accuracy(cls_pred, cls_target)
cls_acc_stage.append(cls_accuracy)
cls_acc.append(sum(cls_acc_stage) / (len(cls_acc_stage) + 1e-5))
# extra segmentation loss
seg_loss = self.criterion(
F.log_softmax(
output['seg'], axis=1), batch['seg'].cast('int64'))
cls_loss /= (len(targets) * self.refine_layers)
reg_xytl_loss /= (len(targets) * self.refine_layers)
iou_loss /= (len(targets) * self.refine_layers)
loss = cls_loss * self.cls_loss_weight \
+ reg_xytl_loss * self.xyt_loss_weight \
+ seg_loss * self.seg_loss_weight \
+ iou_loss * self.iou_loss_weight
return_value = {
'loss': loss,
'cls_loss': cls_loss * self.cls_loss_weight,
'reg_xytl_loss': reg_xytl_loss * self.xyt_loss_weight,
'seg_loss': seg_loss * self.seg_loss_weight,
'iou_loss': iou_loss * self.iou_loss_weight
}
for i in range(self.refine_layers):
if not isinstance(cls_acc[i], paddle.Tensor):
cls_acc[i] = paddle.to_tensor(cls_acc[i])
return_value['stage_{}_acc'.format(i)] = cls_acc[i]
return return_value
......@@ -23,6 +23,7 @@ from . import es_pan
from . import lc_pan
from . import custom_pan
from . import dilated_encoder
from . import clrnet_fpn
from .fpn import *
from .yolo_fpn import *
......@@ -37,3 +38,4 @@ from .lc_pan import *
from .custom_pan import *
from .dilated_encoder import *
from .channel_mapper import *
from .clrnet_fpn import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import XavierUniform
from ppdet.modeling.initializer import kaiming_normal_, constant_
from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import ConvNormLayer
from ppdet.modeling.shape_spec import ShapeSpec
__all__ = ['CLRFPN']
@register
@serializable
class CLRFPN(nn.Layer):
"""
Feature Pyramid Network, see https://arxiv.org/abs/1612.03144
Args:
in_channels (list[int]): input channels of each level which can be
derived from the output shape of backbone by from_config
out_channel (int): output channel of each level
spatial_scales (list[float]): the spatial scales between input feature
maps and original input image which can be derived from the output
shape of backbone by from_config
has_extra_convs (bool): whether to add extra conv to the last level.
default False
extra_stage (int): the number of extra stages added to the last level.
default 1
use_c5 (bool): Whether to use c5 as the input of extra stage,
otherwise p5 is used. default True
norm_type (string|None): The normalization type in FPN module. If
norm_type is None, norm will not be used after conv and if
norm_type is string, bn, gn, sync_bn are available. default None
norm_decay (float): weight decay for normalization layer weights.
default 0.
freeze_norm (bool): whether to freeze normalization layer.
default False
relu_before_extra_convs (bool): whether to add relu before extra convs.
default False
"""
def __init__(self,
in_channels,
out_channel,
spatial_scales=[0.25, 0.125, 0.0625, 0.03125],
has_extra_convs=False,
extra_stage=1,
use_c5=True,
norm_type=None,
norm_decay=0.,
freeze_norm=False,
relu_before_extra_convs=True):
super(CLRFPN, self).__init__()
self.out_channel = out_channel
for s in range(extra_stage):
spatial_scales = spatial_scales + [spatial_scales[-1] / 2.]
self.spatial_scales = spatial_scales
self.has_extra_convs = has_extra_convs
self.extra_stage = extra_stage
self.use_c5 = use_c5
self.relu_before_extra_convs = relu_before_extra_convs
self.norm_type = norm_type
self.norm_decay = norm_decay
self.freeze_norm = freeze_norm
self.in_channels = in_channels
self.lateral_convs = []
self.fpn_convs = []
fan = out_channel * 3 * 3
# stage index 0,1,2,3 stands for res2,res3,res4,res5 on ResNet Backbone
# 0 <= st_stage < ed_stage <= 3
st_stage = 4 - len(in_channels)
ed_stage = st_stage + len(in_channels) - 1
for i in range(st_stage, ed_stage + 1):
# if i == 3:
# lateral_name = 'fpn_inner_res5_sum'
# else:
# lateral_name = 'fpn_inner_res{}_sum_lateral'.format(i + 2)
lateral_name = "lateral_convs.{}.conv".format(i - 1)
in_c = in_channels[i - st_stage]
if self.norm_type is not None:
lateral = self.add_sublayer(
lateral_name,
ConvNormLayer(
ch_in=in_c,
ch_out=out_channel,
filter_size=1,
stride=1,
norm_type=self.norm_type,
norm_decay=self.norm_decay,
freeze_norm=self.freeze_norm,
initializer=XavierUniform(fan_out=in_c)))
else:
lateral = self.add_sublayer(
lateral_name,
nn.Conv2D(
in_channels=in_c,
out_channels=out_channel,
kernel_size=1,
weight_attr=ParamAttr(
initializer=XavierUniform(fan_out=in_c))))
self.lateral_convs.append(lateral)
fpn_name = "fpn_convs.{}.conv".format(i - 1)
if self.norm_type is not None:
fpn_conv = self.add_sublayer(
fpn_name,
ConvNormLayer(
ch_in=out_channel,
ch_out=out_channel,
filter_size=3,
stride=1,
norm_type=self.norm_type,
norm_decay=self.norm_decay,
freeze_norm=self.freeze_norm,
initializer=XavierUniform(fan_out=fan)))
else:
fpn_conv = self.add_sublayer(
fpn_name,
nn.Conv2D(
in_channels=out_channel,
out_channels=out_channel,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(
initializer=XavierUniform(fan_out=fan))))
self.fpn_convs.append(fpn_conv)
# add extra conv levels for RetinaNet(use_c5)/FCOS(use_p5)
if self.has_extra_convs:
for i in range(self.extra_stage):
lvl = ed_stage + 1 + i
if i == 0 and self.use_c5:
in_c = in_channels[-1]
else:
in_c = out_channel
extra_fpn_name = 'fpn_{}'.format(lvl + 2)
if self.norm_type is not None:
extra_fpn_conv = self.add_sublayer(
extra_fpn_name,
ConvNormLayer(
ch_in=in_c,
ch_out=out_channel,
filter_size=3,
stride=2,
norm_type=self.norm_type,
norm_decay=self.norm_decay,
freeze_norm=self.freeze_norm,
initializer=XavierUniform(fan_out=fan)))
else:
extra_fpn_conv = self.add_sublayer(
extra_fpn_name,
nn.Conv2D(
in_channels=in_c,
out_channels=out_channel,
kernel_size=3,
stride=2,
padding=1,
weight_attr=ParamAttr(
initializer=XavierUniform(fan_out=fan))))
self.fpn_convs.append(extra_fpn_conv)
self.init_weights()
def init_weights(self):
for m in self.lateral_convs:
if isinstance(m, (nn.Conv1D, nn.Conv2D)):
kaiming_normal_(
m.weight, a=0, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
constant_(m.bias, value=0.)
elif isinstance(m, (nn.BatchNorm1D, nn.BatchNorm2D)):
constant_(m.weight, value=1)
constant_(m.bias, value=0)
for m in self.fpn_convs:
if isinstance(m, (nn.Conv1D, nn.Conv2D)):
kaiming_normal_(
m.weight, a=0, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
constant_(m.bias, value=0.)
elif isinstance(m, (nn.BatchNorm1D, nn.BatchNorm2D)):
constant_(m.weight, value=1)
constant_(m.bias, value=0)
@classmethod
def from_config(cls, cfg, input_shape):
return {}
def forward(self, body_feats):
laterals = []
if len(body_feats) > len(self.in_channels):
for _ in range(len(body_feats) - len(self.in_channels)):
del body_feats[0]
num_levels = len(body_feats)
# print("body_feats",num_levels)
for i in range(num_levels):
laterals.append(self.lateral_convs[i](body_feats[i]))
for i in range(1, num_levels):
lvl = num_levels - i
upsample = F.interpolate(
laterals[lvl],
scale_factor=2.,
mode='nearest', )
laterals[lvl - 1] += upsample
fpn_output = []
for lvl in range(num_levels):
fpn_output.append(self.fpn_convs[lvl](laterals[lvl]))
if self.extra_stage > 0:
# use max pool to get more levels on top of outputs (Faster R-CNN, Mask R-CNN)
if not self.has_extra_convs:
assert self.extra_stage == 1, 'extra_stage should be 1 if FPN has not extra convs'
fpn_output.append(F.max_pool2d(fpn_output[-1], 1, stride=2))
# add extra conv levels for RetinaNet(use_c5)/FCOS(use_p5)
else:
if self.use_c5:
extra_source = body_feats[-1]
else:
extra_source = fpn_output[-1]
fpn_output.append(self.fpn_convs[num_levels](extra_source))
for i in range(1, self.extra_stage):
if self.relu_before_extra_convs:
fpn_output.append(self.fpn_convs[num_levels + i](F.relu(
fpn_output[-1])))
else:
fpn_output.append(self.fpn_convs[num_levels + i](
fpn_output[-1]))
return fpn_output
@property
def out_shape(self):
return [
ShapeSpec(
channels=self.out_channel, stride=1. / s)
for s in self.spatial_scales
]
......@@ -101,7 +101,8 @@ DATASETS = {
'8a3a353c2c54a2284ad7d2780b65f6a6', ), ], ['annotations', 'images']),
'coco_ce': ([(
'https://paddledet.bj.bcebos.com/data/coco_ce.tar',
'eadd1b79bc2f069f2744b1dd4e0c0329', ), ], [])
'eadd1b79bc2f069f2744b1dd4e0c0329', ), ], []),
'culane': ([('https://bj.bcebos.com/v1/paddledet/data/culane.tar', None, ), ], [])
}
DOWNLOAD_DATASETS_LIST = DATASETS.keys()
......
......@@ -18,3 +18,6 @@ sklearn==0.0
# for vehicleplate in deploy/pipeline/ppvehicle
pyclipper
# for culane data augumetation
imgaug>=0.4.0
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
# add python path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
# ignore warning log
import warnings
warnings.filterwarnings('ignore')
import glob
import ast
import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.engine import Trainer
from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_mlu, check_version, check_config
from ppdet.utils.cli import ArgsParser, merge_args
from ppdet.slim import build_slim_model
from ppdet.utils.logger import setup_logger
logger = setup_logger('train')
def parse_args():
parser = ArgsParser()
parser.add_argument(
"--infer_dir",
type=str,
default=None,
help="Directory for images to perform inference on.")
parser.add_argument(
"--infer_img",
type=str,
default=None,
help="Image path, has higher priority over --infer_dir")
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="Directory for storing the output visualization files.")
parser.add_argument(
"--save_results",
type=bool,
default=False,
help="Whether to save inference results to output_dir.")
parser.add_argument(
"--visualize",
type=ast.literal_eval,
default=True,
help="Whether to save visualize results to output_dir.")
args = parser.parse_args()
return args
def get_test_images(infer_dir, infer_img):
"""
Get image path list in TEST mode
"""
assert infer_img is not None or infer_dir is not None, \
"--infer_img or --infer_dir should be set"
assert infer_img is None or os.path.isfile(infer_img), \
"{} is not a file".format(infer_img)
assert infer_dir is None or os.path.isdir(infer_dir), \
"{} is not a directory".format(infer_dir)
# infer_img has a higher priority
if infer_img and os.path.isfile(infer_img):
return [infer_img]
images = set()
infer_dir = os.path.abspath(infer_dir)
assert os.path.isdir(infer_dir), \
"infer_dir {} is not a directory".format(infer_dir)
exts = ['jpg', 'jpeg', 'png', 'bmp']
exts += [ext.upper() for ext in exts]
for ext in exts:
images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
images = list(images)
assert len(images) > 0, "no image found in {}".format(infer_dir)
logger.info("Found {} inference images in total.".format(len(images)))
return images
def run(FLAGS, cfg):
# build trainer
trainer = Trainer(cfg, mode='test')
# load weights
trainer.load_weights(cfg.weights)
# get inference images
images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
trainer.predict_culane(
images,
output_dir=FLAGS.output_dir,
save_results=FLAGS.save_results,
visualize=FLAGS.visualize)
def main():
FLAGS = parse_args()
cfg = load_config(FLAGS.config)
merge_args(cfg, FLAGS)
merge_config(FLAGS.opt)
# disable npu in config by default
if 'use_npu' not in cfg:
cfg.use_npu = False
# disable xpu in config by default
if 'use_xpu' not in cfg:
cfg.use_xpu = False
if 'use_gpu' not in cfg:
cfg.use_gpu = False
# disable mlu in config by default
if 'use_mlu' not in cfg:
cfg.use_mlu = False
if cfg.use_gpu:
place = paddle.set_device('gpu')
elif cfg.use_npu:
place = paddle.set_device('npu')
elif cfg.use_xpu:
place = paddle.set_device('xpu')
elif cfg.use_mlu:
place = paddle.set_device('mlu')
else:
place = paddle.set_device('cpu')
check_config(cfg)
check_gpu(cfg.use_gpu)
check_npu(cfg.use_npu)
check_xpu(cfg.use_xpu)
check_mlu(cfg.use_mlu)
check_version()
run(FLAGS, cfg)
if __name__ == '__main__':
main()
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册