diff --git a/configs/clrnet/README.cn.md b/configs/clrnet/README.cn.md
new file mode 100644
index 0000000000000000000000000000000000000000..422709e43297e1f2b4746e52d1dbeaa5cef59d3f
--- /dev/null
+++ b/configs/clrnet/README.cn.md
@@ -0,0 +1,68 @@
+简体中文 | [English](README.md)
+
+# CLRNet (CLRNet: Cross Layer Refinement Network for Lane Detection)
+
+## 目录
+- [简介](#简介)
+- [模型库](#模型库)
+- [引用](#引用)
+
+## 简介
+
+[CLRNet](https://arxiv.org/abs/2203.10350)是一个车道线检测模型。CLRNet为车道线检测设计了车道线先验(lane prior)、Line IoU损失以及对应的NMS方法,将车道线的高层上下文特征与底层细节特征相融合,并利用FPN多尺度特征逐层refine,在车道线检测相关数据集上取得了SOTA性能。
+
+## 模型库
+
+### CLRNet在CULane上的结果
+
+| 骨架网络 | mF1 | F1@50 | F1@75 | 下载链接 | 配置文件 | 训练日志 |
+| :-------------- | :------- | :----: | :------: | :----: | :-----: | :-----: |
+| ResNet-18 | 54.98 | 79.46 | 62.10 | [下载链接](https://paddledet.bj.bcebos.com/models/clrnet_resnet18_culane.pdparams) | [配置文件](./clrnet_resnet18_culane.yml) | [训练日志](https://bj.bcebos.com/v1/paddledet/logs/train_clrnet_r18_15_culane.log) |
+
+### 数据集下载
+下载[CULane数据集](https://xingangpan.github.io/projects/CULane.html)并解压到`dataset/culane`目录。
+
+您的数据集目录结构如下:
+```shell
+culane/driver_xx_xxframe    # data folders x6
+culane/laneseg_label_w16    # lane segmentation labels
+culane/list                 # data lists
+```
+如果您使用百度云链接下载,注意确保`driver_23_30frame_part1.tar.gz`和`driver_23_30frame_part2.tar.gz`解压后的文件都在`driver_23_30frame`目录下。
+
+现已将用于测试的小数据集上传到PaddleDetection,运行训练脚本即可自动下载并解压;如需复现上表结果,请下载链接中的全量数据集进行训练。
+
+### 训练
+- GPU单卡训练
+```shell
+python tools/train.py -c configs/clrnet/clrnet_resnet18_culane.yml
+```
+- GPU多卡训练
+```shell
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/clrnet/clrnet_resnet18_culane.yml
+```
+
+### 评估
+```shell
+python tools/eval.py -c configs/clrnet/clrnet_resnet18_culane.yml -o weights=output/clrnet_resnet18_culane/model_final.pdparams
+```
+
+### 预测
+```shell
+python tools/infer_culane.py -c configs/clrnet/clrnet_resnet18_culane.yml -o weights=output/clrnet_resnet18_culane/model_final.pdparams --infer_img=demo/lane00000.jpg
+```
+
+注意:预测功能暂不支持模型静态图推理部署。
+
+## 引用
+```
+@InProceedings{Zheng_2022_CVPR,
+  author    = {Zheng, Tu and Huang, Yifei and Liu, Yang and Tang, Wenjian and Yang, Zheng and Cai, Deng and He, Xiaofei},
+  title     = {CLRNet: Cross Layer Refinement Network for Lane Detection},
+  booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+  month     = {June},
+  year      = {2022},
+  pages     = {898-907}
+}
+```
diff --git a/configs/clrnet/README.md b/configs/clrnet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f61b0c86c0df864d2cb1cd9a09e631638cda100b
--- /dev/null
+++ b/configs/clrnet/README.md
@@ -0,0 +1,68 @@
+English | [简体中文](README.cn.md)
+
+# CLRNet (CLRNet: Cross Layer Refinement Network for Lane Detection)
+
+## Table of Contents
+- [Introduction](#introduction)
+- [Model Zoo](#model-zoo)
+- [Citations](#citations)
+
+## Introduction
+
+[CLRNet](https://arxiv.org/abs/2203.10350) is a lane detection model. CLRNet introduces learnable lane line priors, a Line IoU loss and a matching lane NMS, fuses the high-level contextual features of lane lines with their low-level detail features, and refines the priors across FPN scales. The model achieves SOTA performance on common lane detection benchmarks.
+
+## Model Zoo
+
+### CLRNet Results on CULane dataset
+
+| backbone | mF1 | F1@50 | F1@75 | download | config |
+| :-------------- | :------- | :----: | :------: | :----: | :-----: |
+| ResNet-18 | 54.98 | 79.46 | 62.10 | [model](https://paddledet.bj.bcebos.com/models/clrnet_resnet18_culane.pdparams) | [config](./clrnet_resnet18_culane.yml) |
+
+### Download
+Download the [CULane dataset](https://xingangpan.github.io/projects/CULane.html) and extract it to `dataset/culane`.
+
+The dataset directory should then look like this:
+```shell
+culane/driver_xx_xxframe    # data folders x6
+culane/laneseg_label_w16    # lane segmentation labels
+culane/list                 # data lists
+```
+If you use the Baidu Cloud link, make sure that the images in `driver_23_30frame_part1.tar.gz` and `driver_23_30frame_part2.tar.gz` end up in a single `driver_23_30frame` folder, not in two separate folders, after you decompress them.
+
+A small subset of the CULane dataset has been uploaded to PaddleDetection for code checking; the training script below downloads and extracts it automatically. To reproduce the results above, download the full dataset from the link and train on it.
+
+### Training
+- single GPU
+```shell
+python tools/train.py -c configs/clrnet/clrnet_resnet18_culane.yml
+```
+- multi GPU
+```shell
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/clrnet/clrnet_resnet18_culane.yml
+```
+
+### Evaluation
+```shell
+python tools/eval.py -c configs/clrnet/clrnet_resnet18_culane.yml -o weights=output/clrnet_resnet18_culane/model_final.pdparams
+```
+
+### Inference
+```shell
+python tools/infer_culane.py -c configs/clrnet/clrnet_resnet18_culane.yml -o weights=output/clrnet_resnet18_culane/model_final.pdparams --infer_img=demo/lane00000.jpg
+```
+
+Note: static-graph inference deployment is not supported for this model at present.
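+
+### Line IoU (illustrative)
+
+Both the Line IoU loss used in training and the lane NMS used in post-processing compare two lanes through a "Line IoU": every lane point is widened into a horizontal segment, and the IoU is accumulated over the image rows the two lanes share. The snippet below is only a rough NumPy sketch of this idea (the function name, the clipping of negative overlaps, and the example values are ours, not part of this repo); the implementation actually used by this config is `line_iou` in `ppdet/modeling/losses`.
+
+```python
+import numpy as np
+
+def line_iou_sketch(pred_xs, gt_xs, img_w, radius=15):
+    # pred_xs / gt_xs: x coordinates (pixels) of two lanes sampled at the same
+    # image rows; points outside [0, img_w) are treated as invalid.
+    valid = (pred_xs >= 0) & (pred_xs < img_w) & (gt_xs >= 0) & (gt_xs < img_w)
+    # widen each lane point into a horizontal segment of half-width `radius`
+    ovr = np.minimum(pred_xs + radius, gt_xs + radius) - np.maximum(pred_xs - radius, gt_xs - radius)
+    union = np.maximum(pred_xs + radius, gt_xs + radius) - np.minimum(pred_xs - radius, gt_xs - radius)
+    ovr = np.clip(ovr, 0.0, None)  # rows where the segments do not overlap contribute 0
+    return float((ovr * valid).sum() / ((union * valid).sum() + 1e-9))
+
+# two nearly identical lanes -> IoU close to 1
+print(line_iou_sketch(np.array([100., 120., 140.]), np.array([102., 121., 143.]), img_w=1640))
+```
+
+During lane NMS, `1 - line_iou` serves as the distance between two lane proposals, so proposals that overlap an already selected, higher-scoring lane too much are suppressed.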
+ +## Citations +``` +@InProceedings{Zheng_2022_CVPR, + author = {Zheng, Tu and Huang, Yifei and Liu, Yang and Tang, Wenjian and Yang, Zheng and Cai, Deng and He, Xiaofei}, + title = {CLRNet: Cross Layer Refinement Network for Lane Detection}, + booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, + month = {June}, + year = {2022}, + pages = {898-907} +} +``` diff --git a/configs/clrnet/_base_/clrnet_r18_fpn.yml b/configs/clrnet/_base_/clrnet_r18_fpn.yml new file mode 100644 index 0000000000000000000000000000000000000000..5b109814c8a799fec9ad0de58ead196f5f6117b8 --- /dev/null +++ b/configs/clrnet/_base_/clrnet_r18_fpn.yml @@ -0,0 +1,41 @@ +architecture: CLRNet + +CLRNet: + backbone: CLRResNet + neck: CLRFPN + clr_head: CLRHead + +CLRResNet: + resnet: 'resnet18' + pretrained: True + +CLRFPN: + in_channels: [128,256,512] + out_channel: 64 + extra_stage: 0 + +CLRHead: + prior_feat_channels: 64 + fc_hidden_dim: 64 + num_priors: 192 + num_fc: 2 + refine_layers: 3 + sample_points: 36 + loss: CLRNetLoss + conf_threshold: 0.4 + nms_thres: 0.8 + +CLRNetLoss: + cls_loss_weight : 2.0 + xyt_loss_weight : 0.2 + iou_loss_weight : 2.0 + seg_loss_weight : 1.0 + refine_layers : 3 + ignore_label: 255 + bg_weight: 0.4 + +# for visualize lane detection results +sample_y: + start: 589 + end: 230 + step: -20 diff --git a/configs/clrnet/_base_/clrnet_reader.yml b/configs/clrnet/_base_/clrnet_reader.yml new file mode 100644 index 0000000000000000000000000000000000000000..b5eb77daed1fdfeb2d637dd9a3060674f919c08c --- /dev/null +++ b/configs/clrnet/_base_/clrnet_reader.yml @@ -0,0 +1,37 @@ +worker_num: 10 + +img_h: &img_h 320 +img_w: &img_w 800 +ori_img_h: &ori_img_h 590 +ori_img_w: &ori_img_w 1640 +num_points: &num_points 72 +max_lanes: &max_lanes 4 + +TrainReader: + batch_size: 24 + batch_transforms: + - CULaneTrainProcess: {img_h: *img_h, img_w: *img_w} + - CULaneDataProcess: {num_points: *num_points, max_lanes: *max_lanes, img_w: *img_w, img_h: *img_h} + shuffle: True + drop_last: False + + + + +EvalReader: + batch_size: 24 + batch_transforms: + - CULaneResize: {prob: 1.0, img_h: *img_h, img_w: *img_w} + - CULaneDataProcess: {num_points: *num_points, max_lanes: *max_lanes, img_w: *img_w, img_h: *img_h} + shuffle: False + drop_last: False + + + +TestReader: + batch_size: 24 + batch_transforms: + - CULaneResize: {prob: 1.0, img_h: *img_h, img_w: *img_w} + - CULaneDataProcess: {num_points: *num_points, max_lanes: *max_lanes, img_w: *img_w, img_h: *img_h} + shuffle: False + drop_last: False diff --git a/configs/clrnet/_base_/optimizer_1x.yml b/configs/clrnet/_base_/optimizer_1x.yml new file mode 100644 index 0000000000000000000000000000000000000000..f35407e1eddfc5cc38c8ba6aea1bdaf5d6df7da1 --- /dev/null +++ b/configs/clrnet/_base_/optimizer_1x.yml @@ -0,0 +1,14 @@ +epoch: 15 +snapshot_epoch: 5 + +LearningRate: + base_lr: 0.6e-3 + schedulers: + - !CosineDecay + max_epochs: 15 + use_warmup: False + +OptimizerBuilder: + regularizer: False + optimizer: + type: AdamW diff --git a/configs/clrnet/clrnet_resnet18_culane.yml b/configs/clrnet/clrnet_resnet18_culane.yml new file mode 100644 index 0000000000000000000000000000000000000000..f7e7acd34f8afa4883a5fa5840f571c001864909 --- /dev/null +++ b/configs/clrnet/clrnet_resnet18_culane.yml @@ -0,0 +1,9 @@ +_BASE_: [ + '../datasets/culane.yml', + '_base_/clrnet_reader.yml', + '_base_/clrnet_r18_fpn.yml', + '_base_/optimizer_1x.yml', + '../runtime.yml' +] + +weights: output/clr_resnet18_culane/model_final diff --git 
a/configs/datasets/culane.yml b/configs/datasets/culane.yml new file mode 100644 index 0000000000000000000000000000000000000000..79e59e3ebd56134eff1aa76282812496549fd995 --- /dev/null +++ b/configs/datasets/culane.yml @@ -0,0 +1,28 @@ +metric: CULaneMetric +num_classes: 5 # 4 lanes + background + +cut_height: &cut_height 270 +dataset_dir: &dataset_dir dataset/culane + +TrainDataset: + name: CULaneDataSet + dataset_dir: *dataset_dir + list_path: 'list/train_gt.txt' + split: train + cut_height: *cut_height + + +EvalDataset: + name: CULaneDataSet + dataset_dir: *dataset_dir + list_path: 'list/test.txt' + split: test + cut_height: *cut_height + + +TestDataset: + name: CULaneDataSet + dataset_dir: *dataset_dir + list_path: 'list/test.txt' + split: test + cut_height: *cut_height diff --git a/demo/lane00000.jpg b/demo/lane00000.jpg new file mode 100644 index 0000000000000000000000000000000000000000..01f3d1db1fee6eeb8ceb9ee084ebd4a666544061 Binary files /dev/null and b/demo/lane00000.jpg differ diff --git a/deploy/python/clrnet_postprocess.py b/deploy/python/clrnet_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..3648f33881349b4e299efa8e8224210da7299c54 --- /dev/null +++ b/deploy/python/clrnet_postprocess.py @@ -0,0 +1,180 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +import paddle.nn as nn +from scipy.special import softmax +from ppdet.modeling.lane_utils import Lane +from ppdet.modeling.losses import line_iou + + +class CLRNetPostProcess(object): + """ + Args: + input_shape (int): network input image size + ori_shape (int): ori image shape of before padding + scale_factor (float): scale factor of ori image + enable_mkldnn (bool): whether to open MKLDNN + """ + + def __init__(self, img_w, ori_img_h, cut_height, conf_threshold, nms_thres, + max_lanes, num_points): + self.img_w = img_w + self.conf_threshold = conf_threshold + self.nms_thres = nms_thres + self.max_lanes = max_lanes + self.num_points = num_points + self.n_strips = num_points - 1 + self.n_offsets = num_points + self.ori_img_h = ori_img_h + self.cut_height = cut_height + + self.prior_ys = paddle.linspace( + start=1, stop=0, num=self.n_offsets).astype('float64') + + def predictions_to_pred(self, predictions): + """ + Convert predictions to internal Lane structure for evaluation. + """ + lanes = [] + for lane in predictions: + lane_xs = lane[6:].clone() + start = min( + max(0, int(round(lane[2].item() * self.n_strips))), + self.n_strips) + length = int(round(lane[5].item())) + end = start + length - 1 + end = min(end, len(self.prior_ys) - 1) + if start > 0: + mask = ((lane_xs[:start] >= 0.) 
& + (lane_xs[:start] <= 1.)).cpu().detach().numpy()[::-1] + mask = ~((mask.cumprod()[::-1]).astype(np.bool)) + lane_xs[:start][mask] = -2 + if end < len(self.prior_ys) - 1: + lane_xs[end + 1:] = -2 + + lane_ys = self.prior_ys[lane_xs >= 0].clone() + lane_xs = lane_xs[lane_xs >= 0] + lane_xs = lane_xs.flip(axis=0).astype('float64') + lane_ys = lane_ys.flip(axis=0) + + lane_ys = (lane_ys * + (self.ori_img_h - self.cut_height) + self.cut_height + ) / self.ori_img_h + if len(lane_xs) <= 1: + continue + points = paddle.stack( + x=(lane_xs.reshape([-1, 1]), lane_ys.reshape([-1, 1])), + axis=1).squeeze(axis=2) + lane = Lane( + points=points.cpu().numpy(), + metadata={ + 'start_x': lane[3], + 'start_y': lane[2], + 'conf': lane[1] + }) + lanes.append(lane) + return lanes + + def lane_nms(self, predictions, scores, nms_overlap_thresh, top_k): + """ + NMS for lane detection. + predictions: paddle.Tensor [num_lanes,conf,y,x,lenght,72offsets] [12,77] + scores: paddle.Tensor [num_lanes] + nms_overlap_thresh: float + top_k: int + """ + # sort by scores to get idx + idx = scores.argsort(descending=True) + keep = [] + + condidates = predictions.clone() + condidates = condidates.index_select(idx) + + while len(condidates) > 0: + keep.append(idx[0]) + if len(keep) >= top_k or len(condidates) == 1: + break + + ious = [] + for i in range(1, len(condidates)): + ious.append(1 - line_iou( + condidates[i].unsqueeze(0), + condidates[0].unsqueeze(0), + img_w=self.img_w, + length=15)) + ious = paddle.to_tensor(ious) + + mask = ious <= nms_overlap_thresh + id = paddle.where(mask == False)[0] + + if id.shape[0] == 0: + break + condidates = condidates[1:].index_select(id) + idx = idx[1:].index_select(id) + keep = paddle.stack(keep) + + return keep + + def get_lanes(self, output, as_lanes=True): + """ + Convert model output to lanes. 
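+        Steps implemented below: keep priors whose foreground score from the
+        classification branch passes ``conf_threshold``; build a rescaled copy
+        (length in strips, x offsets in pixels) and run the Line-IoU based
+        ``lane_nms`` on it; round the kept priors' length field to whole
+        strips; and, when ``as_lanes`` is True, convert them to ``Lane``
+        objects via ``predictions_to_pred``.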
+ """ + softmax = nn.Softmax(axis=1) + decoded = [] + + for predictions in output: + if len(predictions) == 0: + decoded.append([]) + continue + threshold = self.conf_threshold + scores = softmax(predictions[:, :2])[:, 1] + keep_inds = scores >= threshold + predictions = predictions[keep_inds] + scores = scores[keep_inds] + + if predictions.shape[0] == 0: + decoded.append([]) + continue + nms_predictions = predictions.detach().clone() + nms_predictions = paddle.concat( + x=[nms_predictions[..., :4], nms_predictions[..., 5:]], axis=-1) + + nms_predictions[..., 4] = nms_predictions[..., 4] * self.n_strips + nms_predictions[..., 5:] = nms_predictions[..., 5:] * ( + self.img_w - 1) + + keep = self.lane_nms( + nms_predictions[..., 5:], + scores, + nms_overlap_thresh=self.nms_thres, + top_k=self.max_lanes) + + predictions = predictions.index_select(keep) + + if predictions.shape[0] == 0: + decoded.append([]) + continue + predictions[:, 5] = paddle.round(predictions[:, 5] * self.n_strips) + if as_lanes: + pred = self.predictions_to_pred(predictions) + else: + pred = predictions + decoded.append(pred) + return decoded + + def __call__(self, lanes_list): + lanes = self.get_lanes(lanes_list) + return lanes diff --git a/deploy/python/infer.py b/deploy/python/infer.py index 1c7be6373b8afd8a6975fed0f4d46f1fb91aff28..dc0922bb34b38947204d3804fd71efd53cd86031 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -33,9 +33,10 @@ sys.path.insert(0, parent_path) from benchmark_utils import PaddleInferBenchmark from picodet_postprocess import PicoDetPostProcess -from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image +from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image, CULaneResize from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop -from visualize import visualize_box_mask +from clrnet_postprocess import CLRNetPostProcess +from visualize import visualize_box_mask, imshow_lanes from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms, coco_clsid2catid # Global dictionary @@ -43,7 +44,7 @@ SUPPORT_MODELS = { 'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'YOLOF', 'PPHGNet', - 'PPLCNet', 'DETR', 'CenterTrack' + 'PPLCNet', 'DETR', 'CenterTrack', 'CLRNet' } @@ -713,6 +714,112 @@ class DetectorPicoDet(Detector): return result +class DetectorCLRNet(Detector): + """ + Args: + model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml + device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU + run_mode (str): mode of running(paddle/trt_fp32/trt_fp16) + batch_size (int): size of pre batch in inference + trt_min_shape (int): min shape for dynamic shape in trt + trt_max_shape (int): max shape for dynamic shape in trt + trt_opt_shape (int): opt shape for dynamic shape in trt + trt_calib_mode (bool): If the model is produced by TRT offline quantitative + calibration, trt_calib_mode need to set True + cpu_threads (int): cpu threads + enable_mkldnn (bool): whether to turn on MKLDNN + enable_mkldnn_bfloat16 (bool): whether to turn on MKLDNN_BFLOAT16 + """ + + def __init__( + self, + model_dir, + device='CPU', + run_mode='paddle', + batch_size=1, + trt_min_shape=1, + trt_max_shape=1280, + trt_opt_shape=640, + trt_calib_mode=False, 
+ cpu_threads=1, + enable_mkldnn=False, + enable_mkldnn_bfloat16=False, + output_dir='./', + threshold=0.5, ): + super(DetectorCLRNet, self).__init__( + model_dir=model_dir, + device=device, + run_mode=run_mode, + batch_size=batch_size, + trt_min_shape=trt_min_shape, + trt_max_shape=trt_max_shape, + trt_opt_shape=trt_opt_shape, + trt_calib_mode=trt_calib_mode, + cpu_threads=cpu_threads, + enable_mkldnn=enable_mkldnn, + enable_mkldnn_bfloat16=enable_mkldnn_bfloat16, + output_dir=output_dir, + threshold=threshold, ) + + deploy_file = os.path.join(model_dir, 'infer_cfg.yml') + with open(deploy_file) as f: + yml_conf = yaml.safe_load(f) + self.img_w = yml_conf['img_w'] + self.ori_img_h = yml_conf['ori_img_h'] + self.cut_height = yml_conf['cut_height'] + self.max_lanes = yml_conf['max_lanes'] + self.nms_thres = yml_conf['nms_thres'] + self.num_points = yml_conf['num_points'] + self.conf_threshold = yml_conf['conf_threshold'] + + def postprocess(self, inputs, result): + # postprocess output of predictor + lanes_list = result['lanes'] + postprocessor = CLRNetPostProcess( + img_w=self.img_w, + ori_img_h=self.ori_img_h, + cut_height=self.cut_height, + conf_threshold=self.conf_threshold, + nms_thres=self.nms_thres, + max_lanes=self.max_lanes, + num_points=self.num_points) + lanes = postprocessor(lanes_list) + result = dict(lanes=lanes) + return result + + def predict(self, repeats=1, run_benchmark=False): + ''' + Args: + repeats (int): repeat number for prediction + Returns: + result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box, + matix element:[class, score, x_min, y_min, x_max, y_max] + ''' + lanes_list = [] + + if run_benchmark: + for i in range(repeats): + self.predictor.run() + paddle.device.cuda.synchronize() + result = dict(lanes=lanes_list) + return result + + for i in range(repeats): + # TODO: check the output of predictor + self.predictor.run() + lanes_list.clear() + output_names = self.predictor.get_output_names() + num_outs = int(len(output_names) / 2) + if num_outs == 0: + lanes_list.append([]) + for out_idx in range(num_outs): + lanes_list.append( + self.predictor.get_output_handle(output_names[out_idx]) + .copy_to_cpu()) + result = dict(lanes=lanes_list) + return result + + def create_inputs(imgs, im_info): """generate input for different model type Args: @@ -965,6 +1072,16 @@ def get_test_images(infer_dir, infer_img): def visualize(image_list, result, labels, output_dir='output/', threshold=0.5): # visualize the predict result + if 'lanes' in result: + print(image_list) + for idx, image_file in enumerate(image_list): + lanes = result['lanes'][idx] + img = cv2.imread(image_file) + out_file = os.path.join(output_dir, os.path.basename(image_file)) + # hard code + lanes = [lane.to_array([], ) for lane in lanes] + imshow_lanes(img, lanes, out_file=out_file) + return start_idx = 0 for idx, image_file in enumerate(image_list): im_bboxes_num = result['boxes_num'][idx] @@ -1013,6 +1130,8 @@ def main(): detector_func = 'DetectorSOLOv2' elif arch == 'PicoDet': detector_func = 'DetectorPicoDet' + elif arch == "CLRNet": + detector_func = 'DetectorCLRNet' detector = eval(detector_func)( FLAGS.model_dir, diff --git a/deploy/python/preprocess.py b/deploy/python/preprocess.py index 6f1a5a2a1a0e38e3edbd9685ad4013b6579ddb87..1936d3e49ef6d9069edc974b4399a5cd71aa9f5d 100644 --- a/deploy/python/preprocess.py +++ b/deploy/python/preprocess.py @@ -14,6 +14,7 @@ import cv2 import numpy as np +import imgaug.augmenters as iaa from keypoint_preprocess import get_affine_transform from PIL 
import Image @@ -509,6 +510,32 @@ class WarpAffine(object): return inp, im_info +class CULaneResize(object): + def __init__(self, img_h, img_w, cut_height, prob=0.5): + super(CULaneResize, self).__init__() + self.img_h = img_h + self.img_w = img_w + self.cut_height = cut_height + self.prob = prob + + def __call__(self, im, im_info): + # cut + im = im[self.cut_height:, :, :] + # resize + transform = iaa.Sometimes(self.prob, + iaa.Resize({ + "height": self.img_h, + "width": self.img_w + })) + im = transform(image=im.copy().astype(np.uint8)) + + im = im.astype(np.float32) / 255. + # check transpose is need whether the func decode_image is equal to CULaneDataSet cv.imread + im = im.transpose(2, 0, 1) + + return im, im_info + + def preprocess(im, preprocess_ops): # process image by preprocess_ops im_info = { diff --git a/deploy/python/visualize.py b/deploy/python/visualize.py index 5d4ea4de12766dc067ae894def7374604484295d..e964ec05df3ea1e0c6759e080fd4c001c9dc78e5 100644 --- a/deploy/python/visualize.py +++ b/deploy/python/visualize.py @@ -577,3 +577,63 @@ def visualize_vehicle_retrograde(im, mot_res, vehicle_retrograde_res): draw.text((xmax + 1, ymin - th), text, fill=(0, 255, 0)) return im + + +COLORS = [ + (255, 0, 0), + (0, 255, 0), + (0, 0, 255), + (255, 255, 0), + (255, 0, 255), + (0, 255, 255), + (128, 255, 0), + (255, 128, 0), + (128, 0, 255), + (255, 0, 128), + (0, 128, 255), + (0, 255, 128), + (128, 255, 255), + (255, 128, 255), + (255, 255, 128), + (60, 180, 0), + (180, 60, 0), + (0, 60, 180), + (0, 180, 60), + (60, 0, 180), + (180, 0, 60), + (255, 0, 0), + (0, 255, 0), + (0, 0, 255), + (255, 255, 0), + (255, 0, 255), + (0, 255, 255), + (128, 255, 0), + (255, 128, 0), + (128, 0, 255), +] + + +def imshow_lanes(img, lanes, show=False, out_file=None, width=4): + lanes_xys = [] + for _, lane in enumerate(lanes): + xys = [] + for x, y in lane: + if x <= 0 or y <= 0: + continue + x, y = int(x), int(y) + xys.append((x, y)) + lanes_xys.append(xys) + lanes_xys.sort(key=lambda xys: xys[0][0] if len(xys) > 0 else 0) + + for idx, xys in enumerate(lanes_xys): + for i in range(1, len(xys)): + cv2.line(img, xys[i - 1], xys[i], COLORS[idx], thickness=width) + + if show: + cv2.imshow('view', img) + cv2.waitKey(0) + + if out_file: + if not os.path.exists(os.path.dirname(out_file)): + os.makedirs(os.path.dirname(out_file)) + cv2.imwrite(out_file, img) \ No newline at end of file diff --git a/ppdet/data/culane_utils.py b/ppdet/data/culane_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ea8c948096fa9efcb43cf642c15f43492a528501 --- /dev/null +++ b/ppdet/data/culane_utils.py @@ -0,0 +1,130 @@ +import math +import numpy as np +from imgaug.augmentables.lines import LineString +from scipy.interpolate import InterpolatedUnivariateSpline + + +def lane_to_linestrings(lanes): + lines = [] + for lane in lanes: + lines.append(LineString(lane)) + + return lines + + +def linestrings_to_lanes(lines): + lanes = [] + for line in lines: + lanes.append(line.coords) + + return lanes + + +def sample_lane(points, sample_ys, img_w): + # this function expects the points to be sorted + points = np.array(points) + if not np.all(points[1:, 1] < points[:-1, 1]): + raise Exception('Annotaion points have to be sorted') + x, y = points[:, 0], points[:, 1] + + # interpolate points inside domain + assert len(points) > 1 + interp = InterpolatedUnivariateSpline( + y[::-1], x[::-1], k=min(3, len(points) - 1)) + domain_min_y = y.min() + domain_max_y = y.max() + sample_ys_inside_domain = sample_ys[(sample_ys >= 
domain_min_y) & ( + sample_ys <= domain_max_y)] + assert len(sample_ys_inside_domain) > 0 + interp_xs = interp(sample_ys_inside_domain) + + # extrapolate lane to the bottom of the image with a straight line using the 2 points closest to the bottom + two_closest_points = points[:2] + extrap = np.polyfit( + two_closest_points[:, 1], two_closest_points[:, 0], deg=1) + extrap_ys = sample_ys[sample_ys > domain_max_y] + extrap_xs = np.polyval(extrap, extrap_ys) + all_xs = np.hstack((extrap_xs, interp_xs)) + + # separate between inside and outside points + inside_mask = (all_xs >= 0) & (all_xs < img_w) + xs_inside_image = all_xs[inside_mask] + xs_outside_image = all_xs[~inside_mask] + + return xs_outside_image, xs_inside_image + + +def filter_lane(lane): + assert lane[-1][1] <= lane[0][1] + filtered_lane = [] + used = set() + for p in lane: + if p[1] not in used: + filtered_lane.append(p) + used.add(p[1]) + + return filtered_lane + + +def transform_annotation(img_w, img_h, max_lanes, n_offsets, offsets_ys, + n_strips, strip_size, anno): + old_lanes = anno['lanes'] + + # removing lanes with less than 2 points + old_lanes = filter(lambda x: len(x) > 1, old_lanes) + # sort lane points by Y (bottom to top of the image) + old_lanes = [sorted(lane, key=lambda x: -x[1]) for lane in old_lanes] + # remove points with same Y (keep first occurrence) + old_lanes = [filter_lane(lane) for lane in old_lanes] + # normalize the annotation coordinates + old_lanes = [[[x * img_w / float(img_w), y * img_h / float(img_h)] + for x, y in lane] for lane in old_lanes] + # create tranformed annotations + lanes = np.ones( + (max_lanes, 2 + 1 + 1 + 2 + n_offsets), dtype=np.float32 + ) * -1e5 # 2 scores, 1 start_y, 1 start_x, 1 theta, 1 length, S+1 coordinates + lanes_endpoints = np.ones((max_lanes, 2)) + # lanes are invalid by default + lanes[:, 0] = 1 + lanes[:, 1] = 0 + for lane_idx, lane in enumerate(old_lanes): + if lane_idx >= max_lanes: + break + + try: + xs_outside_image, xs_inside_image = sample_lane(lane, offsets_ys, + img_w) + except AssertionError: + continue + if len(xs_inside_image) <= 1: + continue + all_xs = np.hstack((xs_outside_image, xs_inside_image)) + lanes[lane_idx, 0] = 0 + lanes[lane_idx, 1] = 1 + lanes[lane_idx, 2] = len(xs_outside_image) / n_strips + lanes[lane_idx, 3] = xs_inside_image[0] + + thetas = [] + for i in range(1, len(xs_inside_image)): + theta = math.atan( + i * strip_size / + (xs_inside_image[i] - xs_inside_image[0] + 1e-5)) / math.pi + theta = theta if theta > 0 else 1 - abs(theta) + thetas.append(theta) + + theta_far = sum(thetas) / len(thetas) + + # lanes[lane_idx, + # 4] = (theta_closest + theta_far) / 2 # averaged angle + lanes[lane_idx, 4] = theta_far + lanes[lane_idx, 5] = len(xs_inside_image) + lanes[lane_idx, 6:6 + len(all_xs)] = all_xs + lanes_endpoints[lane_idx, 0] = (len(all_xs) - 1) / n_strips + lanes_endpoints[lane_idx, 1] = xs_inside_image[-1] + + new_anno = { + 'label': lanes, + 'old_anno': anno, + 'lane_endpoints': lanes_endpoints + } + return new_anno diff --git a/ppdet/data/source/__init__.py b/ppdet/data/source/__init__.py index f4fef334ee2a87791a9838dabc19097486cb46ea..be723eab53e34399b4c0b3929c8e6e2d7588bf3c 100644 --- a/ppdet/data/source/__init__.py +++ b/ppdet/data/source/__init__.py @@ -19,6 +19,7 @@ from . import category from . import keypoint_coco from . import mot from . import sniper_coco +from . 
import culane from .coco import * from .voc import * @@ -29,3 +30,4 @@ from .mot import * from .sniper_coco import SniperCOCODataSet from .dataset import ImageFolder from .pose3d_cmb import * +from .culane import * diff --git a/ppdet/data/source/culane.py b/ppdet/data/source/culane.py new file mode 100644 index 0000000000000000000000000000000000000000..977d608bac73770dc101b35e68ac6734f967e08e --- /dev/null +++ b/ppdet/data/source/culane.py @@ -0,0 +1,206 @@ +from ppdet.core.workspace import register, serializable +import cv2 +import os +import tarfile +import numpy as np +import os.path as osp +from ppdet.data.source.dataset import DetDataset +from imgaug.augmentables.lines import LineStringsOnImage +from imgaug.augmentables.segmaps import SegmentationMapsOnImage +from ppdet.data.culane_utils import lane_to_linestrings +import pickle as pkl +from ppdet.utils.logger import setup_logger +try: + from collections.abc import Sequence +except Exception: + from collections import Sequence +from .dataset import DetDataset, _make_dataset, _is_valid_file +from ppdet.utils.download import download_dataset + +logger = setup_logger(__name__) + + +@register +@serializable +class CULaneDataSet(DetDataset): + def __init__( + self, + dataset_dir, + cut_height, + list_path, + split='train', + data_fields=['image'], + video_file=None, + frame_rate=-1, ): + super(CULaneDataSet, self).__init__( + dataset_dir=dataset_dir, + cut_height=cut_height, + split=split, + data_fields=data_fields) + self.dataset_dir = dataset_dir + self.list_path = osp.join(dataset_dir, list_path) + self.cut_height = cut_height + self.data_fields = data_fields + self.split = split + self.training = 'train' in split + self.data_infos = [] + self.video_file = video_file + self.frame_rate = frame_rate + self._imid2path = {} + self.predict_dir = None + + def __len__(self): + return len(self.data_infos) + + def check_or_download_dataset(self): + if not osp.exists(self.dataset_dir): + download_dataset("dataset", dataset="culane") + # extract .tar files in self.dataset_dir + for fname in os.listdir(self.dataset_dir): + logger.info("Decompressing {}...".format(fname)) + # ignore .* files + if fname.startswith('.'): + continue + if fname.find('.tar.gz') >= 0: + with tarfile.open(osp.join(self.dataset_dir, fname)) as tf: + tf.extractall(path=self.dataset_dir) + logger.info("Dataset files are ready.") + + def parse_dataset(self): + logger.info('Loading CULane annotations...') + if self.predict_dir is not None: + logger.info('switch to predict mode') + return + # Waiting for the dataset to load is tedious, let's cache it + os.makedirs('cache', exist_ok=True) + cache_path = 'cache/culane_paddle_{}.pkl'.format(self.split) + if os.path.exists(cache_path): + with open(cache_path, 'rb') as cache_file: + self.data_infos = pkl.load(cache_file) + self.max_lanes = max( + len(anno['lanes']) for anno in self.data_infos) + return + + with open(self.list_path) as list_file: + for line in list_file: + infos = self.load_annotation(line.split()) + self.data_infos.append(infos) + + # cache data infos to file + with open(cache_path, 'wb') as cache_file: + pkl.dump(self.data_infos, cache_file) + + def load_annotation(self, line): + infos = {} + img_line = line[0] + img_line = img_line[1 if img_line[0] == '/' else 0::] + img_path = os.path.join(self.dataset_dir, img_line) + infos['img_name'] = img_line + infos['img_path'] = img_path + if len(line) > 1: + mask_line = line[1] + mask_line = mask_line[1 if mask_line[0] == '/' else 0::] + mask_path = 
os.path.join(self.dataset_dir, mask_line) + infos['mask_path'] = mask_path + + if len(line) > 2: + exist_list = [int(l) for l in line[2:]] + infos['lane_exist'] = np.array(exist_list) + + anno_path = img_path[: + -3] + 'lines.txt' # remove sufix jpg and add lines.txt + with open(anno_path, 'r') as anno_file: + data = [ + list(map(float, line.split())) for line in anno_file.readlines() + ] + lanes = [[(lane[i], lane[i + 1]) for i in range(0, len(lane), 2) + if lane[i] >= 0 and lane[i + 1] >= 0] for lane in data] + lanes = [list(set(lane)) for lane in lanes] # remove duplicated points + lanes = [lane for lane in lanes + if len(lane) > 2] # remove lanes with less than 2 points + + lanes = [sorted( + lane, key=lambda x: x[1]) for lane in lanes] # sort by y + infos['lanes'] = lanes + + return infos + + def set_images(self, images): + self.predict_dir = images + self.data_infos = self._load_images() + + def _find_images(self): + predict_dir = self.predict_dir + if not isinstance(predict_dir, Sequence): + predict_dir = [predict_dir] + images = [] + for im_dir in predict_dir: + if os.path.isdir(im_dir): + im_dir = os.path.join(self.predict_dir, im_dir) + images.extend(_make_dataset(im_dir)) + elif os.path.isfile(im_dir) and _is_valid_file(im_dir): + images.append(im_dir) + return images + + def _load_images(self): + images = self._find_images() + ct = 0 + records = [] + for image in images: + assert image != '' and os.path.isfile(image), \ + "Image {} not found".format(image) + if self.sample_num > 0 and ct >= self.sample_num: + break + rec = { + 'im_id': np.array([ct]), + "img_path": os.path.abspath(image), + "img_name": os.path.basename(image), + "lanes": [] + } + self._imid2path[ct] = image + ct += 1 + records.append(rec) + assert len(records) > 0, "No image file found" + return records + + def get_imid2path(self): + return self._imid2path + + def __getitem__(self, idx): + data_info = self.data_infos[idx] + img = cv2.imread(data_info['img_path']) + img = img[self.cut_height:, :, :] + sample = data_info.copy() + sample.update({'image': img}) + img_org = sample['image'] + + if self.training: + label = cv2.imread(sample['mask_path'], cv2.IMREAD_UNCHANGED) + if len(label.shape) > 2: + label = label[:, :, 0] + label = label.squeeze() + label = label[self.cut_height:, :] + sample.update({'mask': label}) + if self.cut_height != 0: + new_lanes = [] + for i in sample['lanes']: + lanes = [] + for p in i: + lanes.append((p[0], p[1] - self.cut_height)) + new_lanes.append(lanes) + sample.update({'lanes': new_lanes}) + + sample['mask'] = SegmentationMapsOnImage( + sample['mask'], shape=img_org.shape) + + sample['full_img_path'] = data_info['img_path'] + sample['img_name'] = data_info['img_name'] + sample['im_id'] = np.array([idx]) + + sample['image'] = sample['image'].copy().astype(np.uint8) + sample['lanes'] = lane_to_linestrings(sample['lanes']) + sample['lanes'] = LineStringsOnImage( + sample['lanes'], shape=img_org.shape) + sample['seg'] = np.zeros(img_org.shape) + + return sample diff --git a/ppdet/data/transform/__init__.py b/ppdet/data/transform/__init__.py index 08d7f64d9e906a4b7aa47496e32cd787244a8c9f..56803581234303b72e1726c2b6a5c6f6dac0cdbc 100644 --- a/ppdet/data/transform/__init__.py +++ b/ppdet/data/transform/__init__.py @@ -18,6 +18,7 @@ from . import keypoint_operators from . import mot_operators from . import rotated_operators from . import keypoints_3d_operators +from . 
import culane_operators from .operators import * from .batch_operators import * @@ -25,8 +26,10 @@ from .keypoint_operators import * from .mot_operators import * from .rotated_operators import * from .keypoints_3d_operators import * +from .culane_operators import * __all__ = [] __all__ += registered_ops __all__ += keypoint_operators.__all__ __all__ += mot_operators.__all__ +__all__ += culane_operators.__all__ diff --git a/ppdet/data/transform/culane_operators.py b/ppdet/data/transform/culane_operators.py new file mode 100644 index 0000000000000000000000000000000000000000..47904357da2faa07413c9a401896c656d53079d2 --- /dev/null +++ b/ppdet/data/transform/culane_operators.py @@ -0,0 +1,366 @@ +import numpy as np +import imgaug.augmenters as iaa +from .operators import BaseOperator, register_op +from ppdet.utils.logger import setup_logger +from ppdet.data.culane_utils import linestrings_to_lanes, transform_annotation + +logger = setup_logger(__name__) + +__all__ = [ + "CULaneTrainProcess", "CULaneDataProcess", "HorizontalFlip", + "ChannelShuffle", "CULaneAffine", "CULaneResize", "OneOfBlur", + "MultiplyAndAddToBrightness", "AddToHueAndSaturation" +] + + +def trainTransforms(img_h, img_w): + transforms = [{ + 'name': 'Resize', + 'parameters': dict(size=dict( + height=img_h, width=img_w)), + 'p': 1.0 + }, { + 'name': 'HorizontalFlip', + 'parameters': dict(p=1.0), + 'p': 0.5 + }, { + 'name': 'ChannelShuffle', + 'parameters': dict(p=1.0), + 'p': 0.1 + }, { + 'name': 'MultiplyAndAddToBrightness', + 'parameters': dict( + mul=(0.85, 1.15), add=(-10, 10)), + 'p': 0.6 + }, { + 'name': 'AddToHueAndSaturation', + 'parameters': dict(value=(-10, 10)), + 'p': 0.7 + }, { + 'name': 'OneOf', + 'transforms': [ + dict( + name='MotionBlur', parameters=dict(k=(3, 5))), dict( + name='MedianBlur', parameters=dict(k=(3, 5))) + ], + 'p': 0.2 + }, { + 'name': 'Affine', + 'parameters': dict( + translate_percent=dict( + x=(-0.1, 0.1), y=(-0.1, 0.1)), + rotate=(-10, 10), + scale=(0.8, 1.2)), + 'p': 0.7 + }, { + 'name': 'Resize', + 'parameters': dict(size=dict( + height=img_h, width=img_w)), + 'p': 1.0 + }] + return transforms + + +@register_op +class CULaneTrainProcess(BaseOperator): + def __init__(self, img_w, img_h): + super(CULaneTrainProcess, self).__init__() + self.img_w = img_w + self.img_h = img_h + self.transforms = trainTransforms(self.img_h, self.img_w) + + if self.transforms is not None: + img_transforms = [] + for aug in self.transforms: + p = aug['p'] + if aug['name'] != 'OneOf': + img_transforms.append( + iaa.Sometimes( + p=p, + then_list=getattr(iaa, aug['name'])(**aug[ + 'parameters']))) + else: + img_transforms.append( + iaa.Sometimes( + p=p, + then_list=iaa.OneOf([ + getattr(iaa, aug_['name'])(**aug_['parameters']) + for aug_ in aug['transforms'] + ]))) + else: + img_transforms = [] + self.iaa_transform = iaa.Sequential(img_transforms) + + def apply(self, sample, context=None): + img, line_strings, seg = self.iaa_transform( + image=sample['image'], + line_strings=sample['lanes'], + segmentation_maps=sample['mask']) + sample['image'] = img + sample['lanes'] = line_strings + sample['mask'] = seg + return sample + + +@register_op +class CULaneDataProcess(BaseOperator): + def __init__(self, img_w, img_h, num_points, max_lanes): + super(CULaneDataProcess, self).__init__() + self.img_w = img_w + self.img_h = img_h + self.num_points = num_points + self.n_offsets = num_points + self.n_strips = num_points - 1 + self.strip_size = self.img_h / self.n_strips + + self.max_lanes = max_lanes + self.offsets_ys = 
np.arange(self.img_h, -1, -self.strip_size) + + def apply(self, sample, context=None): + data = {} + line_strings = sample['lanes'] + line_strings.clip_out_of_image_() + new_anno = {'lanes': linestrings_to_lanes(line_strings)} + + for i in range(30): + try: + annos = transform_annotation( + self.img_w, self.img_h, self.max_lanes, self.n_offsets, + self.offsets_ys, self.n_strips, self.strip_size, new_anno) + label = annos['label'] + lane_endpoints = annos['lane_endpoints'] + break + except: + if (i + 1) == 30: + logger.critical('Transform annotation failed 30 times :(') + exit() + + sample['image'] = sample['image'].astype(np.float32) / 255. + data['image'] = sample['image'].transpose(2, 0, 1) + data['lane_line'] = label + data['seg'] = sample['seg'] + data['full_img_path'] = sample['full_img_path'] + data['img_name'] = sample['img_name'] + data['im_id'] = sample['im_id'] + + if 'mask' in sample.keys(): + data['seg'] = sample['mask'].get_arr() + + data['im_shape'] = np.array([self.img_w, self.img_h], dtype=np.float32) + data['scale_factor'] = np.array([1., 1.], dtype=np.float32) + + return data + + +@register_op +class CULaneResize(BaseOperator): + def __init__(self, img_h, img_w, prob=0.5): + super(CULaneResize, self).__init__() + self.img_h = img_h + self.img_w = img_w + self.prob = prob + + def apply(self, sample, context=None): + transform = iaa.Sometimes(self.prob, + iaa.Resize({ + "height": self.img_h, + "width": self.img_w + })) + if 'mask' in sample.keys(): + img, line_strings, seg = transform( + image=sample['image'], + line_strings=sample['lanes'], + segmentation_maps=sample['mask']) + sample['image'] = img + sample['lanes'] = line_strings + sample['mask'] = seg + else: + img, line_strings = transform( + image=sample['image'].copy().astype(np.uint8), + line_strings=sample['lanes']) + sample['image'] = img + sample['lanes'] = line_strings + + return sample + + +@register_op +class HorizontalFlip(BaseOperator): + def __init__(self, prob=0.5): + super(HorizontalFlip, self).__init__() + self.prob = prob + + def apply(self, sample, context=None): + transform = iaa.Sometimes(self.prob, iaa.HorizontalFlip(1.0)) + if 'mask' in sample.keys(): + img, line_strings, seg = transform( + image=sample['image'], + line_strings=sample['lanes'], + segmentation_maps=sample['mask']) + sample['image'] = img + sample['lanes'] = line_strings + sample['mask'] = seg + else: + img, line_strings = transform( + image=sample['image'], line_strings=sample['lanes']) + sample['image'] = img + sample['lanes'] = line_strings + + return sample + + +@register_op +class ChannelShuffle(BaseOperator): + def __init__(self, prob=0.1): + super(ChannelShuffle, self).__init__() + self.prob = prob + + def apply(self, sample, context=None): + transform = iaa.Sometimes(self.prob, iaa.ChannelShuffle(1.0)) + if 'mask' in sample.keys(): + img, line_strings, seg = transform( + image=sample['image'], + line_strings=sample['lanes'], + segmentation_maps=sample['mask']) + sample['image'] = img + sample['lanes'] = line_strings + sample['mask'] = seg + else: + img, line_strings = transform( + image=sample['image'], line_strings=sample['lanes']) + sample['image'] = img + sample['lanes'] = line_strings + + return sample + + +@register_op +class MultiplyAndAddToBrightness(BaseOperator): + def __init__(self, mul=(0.85, 1.15), add=(-10, 10), prob=0.5): + super(MultiplyAndAddToBrightness, self).__init__() + self.mul = tuple(mul) + self.add = tuple(add) + self.prob = prob + + def apply(self, sample, context=None): + transform = iaa.Sometimes( + 
self.prob, + iaa.MultiplyAndAddToBrightness( + mul=self.mul, add=self.add)) + if 'mask' in sample.keys(): + img, line_strings, seg = transform( + image=sample['image'], + line_strings=sample['lanes'], + segmentation_maps=sample['mask']) + sample['image'] = img + sample['lanes'] = line_strings + sample['mask'] = seg + else: + img, line_strings = transform( + image=sample['image'], line_strings=sample['lanes']) + sample['image'] = img + sample['lanes'] = line_strings + + return sample + + +@register_op +class AddToHueAndSaturation(BaseOperator): + def __init__(self, value=(-10, 10), prob=0.5): + super(AddToHueAndSaturation, self).__init__() + self.value = tuple(value) + self.prob = prob + + def apply(self, sample, context=None): + transform = iaa.Sometimes( + self.prob, iaa.AddToHueAndSaturation(value=self.value)) + if 'mask' in sample.keys(): + img, line_strings, seg = transform( + image=sample['image'], + line_strings=sample['lanes'], + segmentation_maps=sample['mask']) + sample['image'] = img + sample['lanes'] = line_strings + sample['mask'] = seg + else: + img, line_strings = transform( + image=sample['image'], line_strings=sample['lanes']) + sample['image'] = img + sample['lanes'] = line_strings + + return sample + + +@register_op +class OneOfBlur(BaseOperator): + def __init__(self, MotionBlur_k=(3, 5), MedianBlur_k=(3, 5), prob=0.5): + super(OneOfBlur, self).__init__() + self.MotionBlur_k = tuple(MotionBlur_k) + self.MedianBlur_k = tuple(MedianBlur_k) + self.prob = prob + + def apply(self, sample, context=None): + transform = iaa.Sometimes( + self.prob, + iaa.OneOf([ + iaa.MotionBlur(k=self.MotionBlur_k), + iaa.MedianBlur(k=self.MedianBlur_k) + ])) + + if 'mask' in sample.keys(): + img, line_strings, seg = transform( + image=sample['image'], + line_strings=sample['lanes'], + segmentation_maps=sample['mask']) + sample['image'] = img + sample['lanes'] = line_strings + sample['mask'] = seg + else: + img, line_strings = transform( + image=sample['image'], line_strings=sample['lanes']) + sample['image'] = img + sample['lanes'] = line_strings + + return sample + + +@register_op +class CULaneAffine(BaseOperator): + def __init__(self, + translate_percent_x=(-0.1, 0.1), + translate_percent_y=(-0.1, 0.1), + rotate=(3, 5), + scale=(0.8, 1.2), + prob=0.5): + super(CULaneAffine, self).__init__() + self.translate_percent = { + 'x': tuple(translate_percent_x), + 'y': tuple(translate_percent_y) + } + self.rotate = tuple(rotate) + self.scale = tuple(scale) + self.prob = prob + + def apply(self, sample, context=None): + transform = iaa.Sometimes( + self.prob, + iaa.Affine( + translate_percent=self.translate_percent, + rotate=self.rotate, + scale=self.scale)) + + if 'mask' in sample.keys(): + img, line_strings, seg = transform( + image=sample['image'], + line_strings=sample['lanes'], + segmentation_maps=sample['mask']) + sample['image'] = img + sample['lanes'] = line_strings + sample['mask'] = seg + else: + img, line_strings = transform( + image=sample['image'], line_strings=sample['lanes']) + sample['image'] = img + sample['lanes'] = line_strings + + return sample diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py index 882dd5af65e3e6fa3fc89bdbd99096bf277478e9..daaa39a6294c79c811b69e3d2b41e0be8152b16c 100644 --- a/ppdet/engine/export_utils.py +++ b/ppdet/engine/export_utils.py @@ -54,10 +54,12 @@ TRT_MIN_SUBGRAPH = { 'YOLOF': 40, 'METRO_Body': 3, 'DETR': 3, + 'CLRNet': 3 } KEYPOINT_ARCH = ['HigherHRNet', 'TopDownHRNet'] MOT_ARCH = ['JDE', 'FairMOT', 'DeepSORT', 'ByteTrack', 
'CenterTrack'] +LANE_ARCH = ['CLRNet'] TO_STATIC_SPEC = { 'yolov3_darknet53_270e_coco': [{ @@ -215,12 +217,13 @@ def _prune_input_spec(input_spec, program, targets): def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape): preprocess_list = [] + label_list = [] + if arch != "lane_arch": + anno_file = dataset_cfg.get_anno() - anno_file = dataset_cfg.get_anno() + clsid2catid, catid2name = get_categories(metric, anno_file, arch) - clsid2catid, catid2name = get_categories(metric, anno_file, arch) - - label_list = [str(cat) for cat in catid2name.values()] + label_list = [str(cat) for cat in catid2name.values()] fuse_normalize = reader_cfg.get('fuse_normalize', False) sample_transforms = reader_cfg['sample_transforms'] @@ -246,6 +249,13 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape): 'stride': value['pad_to_stride'] }) break + elif key == "CULaneResize": + # cut and resize + p = {'type': key} + p.update(value) + p.update({"cut_height": dataset_cfg.cut_height}) + preprocess_list.append(p) + break return preprocess_list, label_list @@ -315,6 +325,20 @@ def _dump_infer_config(config, path, image_shape, model): if infer_arch in KEYPOINT_ARCH: label_arch = 'keypoint_arch' + if infer_arch in LANE_ARCH: + infer_cfg['arch'] = infer_arch + infer_cfg['min_subgraph_size'] = TRT_MIN_SUBGRAPH[infer_arch] + infer_cfg['img_w'] = config['img_w'] + infer_cfg['ori_img_h'] = config['ori_img_h'] + infer_cfg['cut_height'] = config['cut_height'] + label_arch = 'lane_arch' + head_name = "CLRHead" + infer_cfg['conf_threshold'] = config[head_name]['conf_threshold'] + infer_cfg['nms_thres'] = config[head_name]['nms_thres'] + infer_cfg['max_lanes'] = config[head_name]['max_lanes'] + infer_cfg['num_points'] = config[head_name]['num_points'] + arch_state = True + if infer_arch in MOT_ARCH: if config['metric'] in ['COCO', 'VOC']: # MOT model run as Detector diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index bfd92fd62fb14dbcdb99f7093a4ef1de310707d3..583712c31f7417e299c3adc5b40253200c09fa34 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -39,13 +39,14 @@ from ppdet.core.workspace import create from ppdet.utils.checkpoint import load_weight, load_pretrain_weight from ppdet.utils.visualizer import visualize_results, save_result from ppdet.metrics import get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownCOCOWholeBadyHandEval, KeyPointTopDownMPIIEval, Pose3DEval -from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, RBoxMetric, JDEDetMetric, SNIPERCOCOMetric +from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, RBoxMetric, JDEDetMetric, SNIPERCOCOMetric, CULaneMetric from ppdet.data.source.sniper_coco import SniperCOCODataSet from ppdet.data.source.category import get_categories import ppdet.utils.stats as stats from ppdet.utils.fuse_utils import fuse_conv_bn from ppdet.utils import profiler from ppdet.modeling.post_process import multiclass_nms +from ppdet.modeling.lane_utils import imshow_lanes from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator, WandbCallback from .export_utils import _dump_infer_config, _prune_input_spec, apply_to_static @@ -383,6 +384,15 @@ class Trainer(object): ] elif self.cfg.metric == 'MOTDet': self._metrics = [JDEDetMetric(), ] + elif self.cfg.metric == 'CULaneMetric': + output_eval = self.cfg.get('output_eval', None) + self._metrics = [ + CULaneMetric( + cfg=self.cfg, + output_eval=output_eval, + 
split=self.dataset.split, + dataset_dir=self.cfg.dataset_dir) + ] else: logger.warning("Metric not support for metric type {}".format( self.cfg.metric)) @@ -1139,6 +1149,12 @@ class Trainer(object): "crops": InputSpec( shape=[None, 3, 192, 64], name='crops') }) + + if self.cfg.architecture == 'CLRNet': + input_spec[0].update({ + "full_img_path": str, + "img_name": str, + }) if prune_input: static_model = paddle.jit.to_static( self.model, input_spec=input_spec) @@ -1277,3 +1293,107 @@ class Trainer(object): logger.info("Found {} inference images in total.".format( len(images))) return all_images + + def predict_culane(self, + images, + output_dir='output', + save_results=False, + visualize=True): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + self.dataset.set_images(images) + loader = create('TestReader')(self.dataset, 0) + + imid2path = self.dataset.get_imid2path() + + def setup_metrics_for_loader(): + # mem + metrics = copy.deepcopy(self._metrics) + mode = self.mode + save_prediction_only = self.cfg[ + 'save_prediction_only'] if 'save_prediction_only' in self.cfg else None + output_eval = self.cfg[ + 'output_eval'] if 'output_eval' in self.cfg else None + + # modify + self.mode = '_test' + self.cfg['save_prediction_only'] = True + self.cfg['output_eval'] = output_dir + self.cfg['imid2path'] = imid2path + self._init_metrics() + + # restore + self.mode = mode + self.cfg.pop('save_prediction_only') + if save_prediction_only is not None: + self.cfg['save_prediction_only'] = save_prediction_only + + self.cfg.pop('output_eval') + if output_eval is not None: + self.cfg['output_eval'] = output_eval + + self.cfg.pop('imid2path') + + _metrics = copy.deepcopy(self._metrics) + self._metrics = metrics + + return _metrics + + if save_results: + metrics = setup_metrics_for_loader() + else: + metrics = [] + + # Run Infer + self.status['mode'] = 'test' + self.model.eval() + if self.cfg.get('print_flops', False): + flops_loader = create('TestReader')(self.dataset, 0) + self._flops(flops_loader) + results = [] + for step_id, data in enumerate(tqdm(loader)): + self.status['step_id'] = step_id + # forward + outs = self.model(data) + + for _m in metrics: + _m.update(data, outs) + + for key in ['im_shape', 'scale_factor', 'im_id']: + if isinstance(data, typing.Sequence): + outs[key] = data[0][key] + else: + outs[key] = data[key] + for key, value in outs.items(): + if hasattr(value, 'numpy'): + outs[key] = value.numpy() + results.append(outs) + + for _m in metrics: + _m.accumulate() + _m.reset() + + if visualize: + import cv2 + + for outs in results: + for i in range(len(outs['img_path'])): + lanes = outs['lanes'][i] + img_path = outs['img_path'][i] + img = cv2.imread(img_path) + out_file = os.path.join(output_dir, + os.path.basename(img_path)) + lanes = [ + lane.to_array( + sample_y_range=[ + self.cfg['sample_y']['start'], + self.cfg['sample_y']['end'], + self.cfg['sample_y']['step'] + ], + img_w=self.cfg.ori_img_w, + img_h=self.cfg.ori_img_h) for lane in lanes + ] + imshow_lanes(img, lanes, out_file=out_file) + + return results diff --git a/ppdet/metrics/__init__.py b/ppdet/metrics/__init__.py index 3e1b83cca4425ae82c8af14096a1e91bd7e0503d..288f1581faf5df8ccb2c83a91e213c95f1ec563c 100644 --- a/ppdet/metrics/__init__.py +++ b/ppdet/metrics/__init__.py @@ -27,4 +27,8 @@ __all__ = metrics.__all__ + mot_metrics.__all__ from . 
import mcmot_metrics from .mcmot_metrics import * -__all__ = metrics.__all__ + mcmot_metrics.__all__ \ No newline at end of file +__all__ = metrics.__all__ + mcmot_metrics.__all__ + +from . import culane_metrics +from .culane_metrics import * +__all__ = metrics.__all__ + culane_metrics.__all__ \ No newline at end of file diff --git a/ppdet/metrics/culane_metrics.py b/ppdet/metrics/culane_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..848d2c186a8dcc695932a17bba495d7d37ce9968 --- /dev/null +++ b/ppdet/metrics/culane_metrics.py @@ -0,0 +1,327 @@ +import os +import cv2 +import numpy as np +import os.path as osp +from functools import partial +from .metrics import Metric +from scipy.interpolate import splprep, splev +from scipy.optimize import linear_sum_assignment +from shapely.geometry import LineString, Polygon +from ppdet.utils.logger import setup_logger + +logger = setup_logger(__name__) + +__all__ = [ + 'draw_lane', 'discrete_cross_iou', 'continuous_cross_iou', 'interp', + 'culane_metric', 'load_culane_img_data', 'load_culane_data', + 'eval_predictions', "CULaneMetric" +] + +LIST_FILE = { + 'train': 'list/train_gt.txt', + 'val': 'list/val.txt', + 'test': 'list/test.txt', +} + +CATEGORYS = { + 'normal': 'list/test_split/test0_normal.txt', + 'crowd': 'list/test_split/test1_crowd.txt', + 'hlight': 'list/test_split/test2_hlight.txt', + 'shadow': 'list/test_split/test3_shadow.txt', + 'noline': 'list/test_split/test4_noline.txt', + 'arrow': 'list/test_split/test5_arrow.txt', + 'curve': 'list/test_split/test6_curve.txt', + 'cross': 'list/test_split/test7_cross.txt', + 'night': 'list/test_split/test8_night.txt', +} + + +def draw_lane(lane, img=None, img_shape=None, width=30): + if img is None: + img = np.zeros(img_shape, dtype=np.uint8) + lane = lane.astype(np.int32) + for p1, p2 in zip(lane[:-1], lane[1:]): + cv2.line( + img, tuple(p1), tuple(p2), color=(255, 255, 255), thickness=width) + return img + + +def discrete_cross_iou(xs, ys, width=30, img_shape=(590, 1640, 3)): + xs = [draw_lane(lane, img_shape=img_shape, width=width) > 0 for lane in xs] + ys = [draw_lane(lane, img_shape=img_shape, width=width) > 0 for lane in ys] + + ious = np.zeros((len(xs), len(ys))) + for i, x in enumerate(xs): + for j, y in enumerate(ys): + ious[i, j] = (x & y).sum() / (x | y).sum() + return ious + + +def continuous_cross_iou(xs, ys, width=30, img_shape=(590, 1640, 3)): + h, w, _ = img_shape + image = Polygon([(0, 0), (0, h - 1), (w - 1, h - 1), (w - 1, 0)]) + xs = [ + LineString(lane).buffer( + distance=width / 2., cap_style=1, join_style=2).intersection(image) + for lane in xs + ] + ys = [ + LineString(lane).buffer( + distance=width / 2., cap_style=1, join_style=2).intersection(image) + for lane in ys + ] + + ious = np.zeros((len(xs), len(ys))) + for i, x in enumerate(xs): + for j, y in enumerate(ys): + ious[i, j] = x.intersection(y).area / x.union(y).area + + return ious + + +def interp(points, n=50): + x = [x for x, _ in points] + y = [y for _, y in points] + tck, u = splprep([x, y], s=0, t=n, k=min(3, len(points) - 1)) + + u = np.linspace(0., 1., num=(len(u) - 1) * n + 1) + return np.array(splev(u, tck)).T + + +def culane_metric(pred, + anno, + width=30, + iou_thresholds=[0.5], + official=True, + img_shape=(590, 1640, 3)): + _metric = {} + for thr in iou_thresholds: + tp = 0 + fp = 0 if len(anno) != 0 else len(pred) + fn = 0 if len(pred) != 0 else len(anno) + _metric[thr] = [tp, fp, fn] + + interp_pred = np.array( + [interp( + pred_lane, n=5) for pred_lane in pred], 
dtype=object) # (4, 50, 2) + interp_anno = np.array( + [interp( + anno_lane, n=5) for anno_lane in anno], dtype=object) # (4, 50, 2) + + if official: + ious = discrete_cross_iou( + interp_pred, interp_anno, width=width, img_shape=img_shape) + else: + ious = continuous_cross_iou( + interp_pred, interp_anno, width=width, img_shape=img_shape) + + row_ind, col_ind = linear_sum_assignment(1 - ious) + + _metric = {} + for thr in iou_thresholds: + tp = int((ious[row_ind, col_ind] > thr).sum()) + fp = len(pred) - tp + fn = len(anno) - tp + _metric[thr] = [tp, fp, fn] + return _metric + + +def load_culane_img_data(path): + with open(path, 'r') as data_file: + img_data = data_file.readlines() + img_data = [line.split() for line in img_data] + img_data = [list(map(float, lane)) for lane in img_data] + img_data = [[(lane[i], lane[i + 1]) for i in range(0, len(lane), 2)] + for lane in img_data] + img_data = [lane for lane in img_data if len(lane) >= 2] + + return img_data + + +def load_culane_data(data_dir, file_list_path): + with open(file_list_path, 'r') as file_list: + filepaths = [ + os.path.join(data_dir, + line[1 if line[0] == '/' else 0:].rstrip().replace( + '.jpg', '.lines.txt')) + for line in file_list.readlines() + ] + + data = [] + for path in filepaths: + img_data = load_culane_img_data(path) + data.append(img_data) + + return data + + +def eval_predictions(pred_dir, + anno_dir, + list_path, + iou_thresholds=[0.5], + width=30, + official=True, + sequential=False): + logger.info('Calculating metric for List: {}'.format(list_path)) + predictions = load_culane_data(pred_dir, list_path) + annotations = load_culane_data(anno_dir, list_path) + img_shape = (590, 1640, 3) + if sequential: + results = map(partial( + culane_metric, + width=width, + official=official, + iou_thresholds=iou_thresholds, + img_shape=img_shape), + predictions, + annotations) + else: + from multiprocessing import Pool, cpu_count + from itertools import repeat + with Pool(cpu_count()) as p: + results = p.starmap(culane_metric, + zip(predictions, annotations, + repeat(width), + repeat(iou_thresholds), + repeat(official), repeat(img_shape))) + + mean_f1, mean_prec, mean_recall, total_tp, total_fp, total_fn = 0, 0, 0, 0, 0, 0 + ret = {} + for thr in iou_thresholds: + tp = sum(m[thr][0] for m in results) + fp = sum(m[thr][1] for m in results) + fn = sum(m[thr][2] for m in results) + precision = float(tp) / (tp + fp) if tp != 0 else 0 + recall = float(tp) / (tp + fn) if tp != 0 else 0 + f1 = 2 * precision * recall / (precision + recall) if tp != 0 else 0 + logger.info('iou thr: {:.2f}, tp: {}, fp: {}, fn: {},' + 'precision: {}, recall: {}, f1: {}'.format( + thr, tp, fp, fn, precision, recall, f1)) + mean_f1 += f1 / len(iou_thresholds) + mean_prec += precision / len(iou_thresholds) + mean_recall += recall / len(iou_thresholds) + total_tp += tp + total_fp += fp + total_fn += fn + ret[thr] = { + 'TP': tp, + 'FP': fp, + 'FN': fn, + 'Precision': precision, + 'Recall': recall, + 'F1': f1 + } + if len(iou_thresholds) > 2: + logger.info( + 'mean result, total_tp: {}, total_fp: {}, total_fn: {},' + 'precision: {}, recall: {}, f1: {}'.format( + total_tp, total_fp, total_fn, mean_prec, mean_recall, mean_f1)) + ret['mean'] = { + 'TP': total_tp, + 'FP': total_fp, + 'FN': total_fn, + 'Precision': mean_prec, + 'Recall': mean_recall, + 'F1': mean_f1 + } + return ret + + +class CULaneMetric(Metric): + def __init__(self, + cfg, + output_eval=None, + split="test", + dataset_dir="dataset/CULane/"): + super(CULaneMetric, self).__init__() + 
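+        # This metric dumps one `<image>.lines.txt` prediction file per image
+        # under `output_eval`, then scores them against the CULane annotations
+        # with `eval_predictions`: per-category results at IoU 0.5 plus an
+        # overall sweep over IoU 0.5:0.95, with F1@50 reported as the key metric.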
self.output_eval = "evaluation" if output_eval is None else output_eval + self.dataset_dir = dataset_dir + self.split = split + self.list_path = osp.join(dataset_dir, LIST_FILE[split]) + self.predictions = [] + self.img_names = [] + self.lanes = [] + self.eval_results = {} + self.cfg = cfg + self.reset() + + def reset(self): + self.predictions = [] + self.img_names = [] + self.lanes = [] + self.eval_results = {} + + def get_prediction_string(self, pred): + ys = np.arange(270, 590, 8) / self.cfg.ori_img_h + out = [] + for lane in pred: + xs = lane(ys) + valid_mask = (xs >= 0) & (xs < 1) + xs = xs * self.cfg.ori_img_w + lane_xs = xs[valid_mask] + lane_ys = ys[valid_mask] * self.cfg.ori_img_h + lane_xs, lane_ys = lane_xs[::-1], lane_ys[::-1] + lane_str = ' '.join([ + '{:.5f} {:.5f}'.format(x, y) for x, y in zip(lane_xs, lane_ys) + ]) + if lane_str != '': + out.append(lane_str) + + return '\n'.join(out) + + def accumulate(self): + loss_lines = [[], [], [], []] + for idx, pred in enumerate(self.predictions): + output_dir = os.path.join(self.output_eval, + os.path.dirname(self.img_names[idx])) + output_filename = os.path.basename(self.img_names[ + idx])[:-3] + 'lines.txt' + os.makedirs(output_dir, exist_ok=True) + output = self.get_prediction_string(pred) + + # store loss lines + lanes = self.lanes[idx] + if len(lanes) - len(pred) in [1, 2, 3, 4]: + loss_lines[len(lanes) - len(pred) - 1].append(self.img_names[ + idx]) + + with open(os.path.join(output_dir, output_filename), + 'w') as out_file: + out_file.write(output) + + for i, names in enumerate(loss_lines): + with open( + os.path.join(output_dir, 'loss_{}_lines.txt'.format(i + 1)), + 'w') as f: + for name in names: + f.write(name + '\n') + + for cate, cate_file in CATEGORYS.items(): + result = eval_predictions( + self.output_eval, + self.dataset_dir, + os.path.join(self.dataset_dir, cate_file), + iou_thresholds=[0.5], + official=True) + + result = eval_predictions( + self.output_eval, + self.dataset_dir, + self.list_path, + iou_thresholds=np.linspace(0.5, 0.95, 10), + official=True) + self.eval_results['F1@50'] = result[0.5]['F1'] + self.eval_results['result'] = result + + def update(self, inputs, outputs): + assert len(inputs['img_name']) == len(outputs['lanes']) + self.predictions.extend(outputs['lanes']) + self.img_names.extend(inputs['img_name']) + self.lanes.extend(inputs['lane_line']) + + def log(self): + logger.info(self.eval_results) + + # abstract method for getting metric results + def get_results(self): + return self.eval_results diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py index eb5ff75c2f99860dc178f7d4b25aabf28e5b946d..ad60f0f24f35f43728ccf1c5e9f6b354c25e4453 100644 --- a/ppdet/modeling/architectures/__init__.py +++ b/ppdet/modeling/architectures/__init__.py @@ -42,6 +42,7 @@ from . import yolof from . import pose3d_metro from . import centertrack from . import queryinst +from . 
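
`get_prediction_string` above writes one `*.lines.txt` file per image, one lane per text line as space-separated `x y` pairs in pixels, which is exactly what `load_culane_img_data` parses back. A tiny round-trip sketch with made-up coordinates:

```python
# One text line of a CULane *.lines.txt file = one lane, "x1 y1 x2 y2 ..." in pixels.
line = "552.00000 589.00000 540.50000 581.00000 529.00000 573.00000"
values = [float(v) for v in line.split()]
points = [(values[i], values[i + 1]) for i in range(0, len(values), 2)]
print(points)   # [(552.0, 589.0), (540.5, 581.0), (529.0, 573.0)]
```
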
import clrnet from .meta_arch import * from .faster_rcnn import * @@ -74,4 +75,5 @@ from .yolof import * from .pose3d_metro import * from .centertrack import * from .queryinst import * -from .keypoint_petr import * \ No newline at end of file +from .keypoint_petr import * +from .clrnet import * \ No newline at end of file diff --git a/ppdet/modeling/architectures/clrnet.py b/ppdet/modeling/architectures/clrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..8336fd8887fe2c283ce1ee210ab38653ecf62079 --- /dev/null +++ b/ppdet/modeling/architectures/clrnet.py @@ -0,0 +1,67 @@ +from .meta_arch import BaseArch +from ppdet.core.workspace import register, create +from paddle import in_dynamic_mode + +__all__ = ['CLRNet'] + + +@register +class CLRNet(BaseArch): + __category__ = 'architecture' + + def __init__(self, + backbone="CLRResNet", + neck="CLRFPN", + clr_head="CLRHead", + post_process=None): + super(CLRNet, self).__init__() + self.backbone = backbone + self.neck = neck + self.heads = clr_head + self.post_process = post_process + + @classmethod + def from_config(cls, cfg, *args, **kwargs): + # backbone + backbone = create(cfg['backbone']) + # fpn + kwargs = {'input_shape': backbone.out_shape} + neck = create(cfg['neck'], **kwargs) + # head + kwargs = {'input_shape': neck.out_shape} + clr_head = create(cfg['clr_head'], **kwargs) + + return { + 'backbone': backbone, + 'neck': neck, + 'clr_head': clr_head, + } + + def _forward(self): + # Backbone + body_feats = self.backbone(self.inputs['image']) + # neck + neck_feats = self.neck(body_feats) + # CRL Head + + if self.training: + output = self.heads(neck_feats, self.inputs) + else: + output = self.heads(neck_feats) + output = {'lanes': output} + # TODO: hard code fix as_lanes=False problem in clrnet_head.py "get_lanes" function for static mode + if in_dynamic_mode(): + output = self.heads.get_lanes(output['lanes']) + output = { + "lanes": output, + "img_path": self.inputs['full_img_path'], + "img_name": self.inputs['img_name'] + } + + return output + + def get_loss(self): + return self._forward() + + def get_pred(self): + return self._forward() diff --git a/ppdet/modeling/assigners/clrnet_assigner.py b/ppdet/modeling/assigners/clrnet_assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..59c94a0a7e2d895ada3b06530900ebec947d4b53 --- /dev/null +++ b/ppdet/modeling/assigners/clrnet_assigner.py @@ -0,0 +1,147 @@ +import paddle +import paddle.nn.functional as F +from ppdet.modeling.losses.clrnet_line_iou_loss import line_iou + + +def distance_cost(predictions, targets, img_w): + """ + repeat predictions and targets to generate all combinations + use the abs distance as the new distance cost + """ + num_priors = predictions.shape[0] + num_targets = targets.shape[0] + predictions = paddle.repeat_interleave( + predictions, num_targets, axis=0)[..., 6:] + targets = paddle.concat(x=num_priors * [targets])[..., 6:] + invalid_masks = (targets < 0) | (targets >= img_w) + lengths = (~invalid_masks).sum(axis=1) + distances = paddle.abs(x=targets - predictions) + distances[invalid_masks] = 0.0 + distances = distances.sum(axis=1) / (lengths.cast("float32") + 1e-09) + distances = distances.reshape([num_priors, num_targets]) + return distances + + +def focal_cost(cls_pred, gt_labels, alpha=0.25, gamma=2, eps=1e-12): + """ + Args: + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). 
+ + Returns: + torch.Tensor: cls_cost value + """ + cls_pred = F.sigmoid(cls_pred) + neg_cost = -(1 - cls_pred + eps).log() * (1 - alpha) * cls_pred.pow(gamma) + pos_cost = -(cls_pred + eps).log() * alpha * (1 - cls_pred).pow(gamma) + cls_cost = pos_cost.index_select( + gt_labels, axis=1) - neg_cost.index_select( + gt_labels, axis=1) + return cls_cost + + +def dynamic_k_assign(cost, pair_wise_ious): + """ + Assign grouth truths with priors dynamically. + + Args: + cost: the assign cost. + pair_wise_ious: iou of grouth truth and priors. + + Returns: + prior_idx: the index of assigned prior. + gt_idx: the corresponding ground truth index. + """ + matching_matrix = paddle.zeros_like(cost) + ious_matrix = pair_wise_ious + ious_matrix[ious_matrix < 0] = 0.0 + n_candidate_k = 4 + topk_ious, _ = paddle.topk(ious_matrix, n_candidate_k, axis=0) + dynamic_ks = paddle.clip(x=topk_ious.sum(0).cast("int32"), min=1) + num_gt = cost.shape[1] + + for gt_idx in range(num_gt): + _, pos_idx = paddle.topk( + x=cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False) + matching_matrix[pos_idx, gt_idx] = 1.0 + del topk_ious, dynamic_ks, pos_idx + matched_gt = matching_matrix.sum(axis=1) + + if (matched_gt > 1).sum() > 0: + matched_gt_indices = paddle.nonzero(matched_gt > 1)[:, 0] + cost_argmin = paddle.argmin( + cost.index_select(matched_gt_indices), axis=1) + matching_matrix[matched_gt_indices][0] *= 0.0 + matching_matrix[matched_gt_indices, cost_argmin] = 1.0 + + prior_idx = matching_matrix.sum(axis=1).nonzero() + gt_idx = matching_matrix[prior_idx].argmax(axis=-1) + return prior_idx.flatten(), gt_idx.flatten() + + +def cdist_paddle(x1, x2, p=2): + assert x1.shape[1] == x2.shape[1] + B, M = x1.shape + # if p == np.inf: + # dist = np.max(np.abs(x1[:, np.newaxis, :] - x2[np.newaxis, :, :]), axis=-1) + if p == 1: + dist = paddle.sum( + paddle.abs(x1.unsqueeze(axis=1) - x2.unsqueeze(axis=0)), axis=-1) + else: + dist = paddle.pow(paddle.sum(paddle.pow( + paddle.abs(x1.unsqueeze(axis=1) - x2.unsqueeze(axis=0)), p), + axis=-1), + 1 / p) + return dist + + +def assign(predictions, + targets, + img_w, + img_h, + distance_cost_weight=3.0, + cls_cost_weight=1.0): + """ + computes dynamicly matching based on the cost, including cls cost and lane similarity cost + Args: + predictions (Tensor): predictions predicted by each stage, shape: (num_priors, 78) + targets (Tensor): lane targets, shape: (num_targets, 78) + return: + matched_row_inds (Tensor): matched predictions, shape: (num_targets) + matched_col_inds (Tensor): matched targets, shape: (num_targets) + """ + predictions = predictions.detach().clone() + predictions[:, 3] *= img_w - 1 + predictions[:, 6:] *= img_w - 1 + + targets = targets.detach().clone() + distances_score = distance_cost(predictions, targets, img_w) + distances_score = 1 - distances_score / paddle.max(x=distances_score) + 0.01 + + cls_score = focal_cost(predictions[:, :2], targets[:, 1].cast('int64')) + + num_priors = predictions.shape[0] + num_targets = targets.shape[0] + target_start_xys = targets[:, 2:4] + target_start_xys[..., 0] *= (img_h - 1) + prediction_start_xys = predictions[:, 2:4] + prediction_start_xys[..., 0] *= (img_h - 1) + start_xys_score = cdist_paddle( + prediction_start_xys, target_start_xys, + p=2).reshape([num_priors, num_targets]) + + start_xys_score = 1 - start_xys_score / paddle.max(x=start_xys_score) + 0.01 + + target_thetas = targets[:, 4].unsqueeze(axis=-1) + theta_score = cdist_paddle( + predictions[:, 4].unsqueeze(axis=-1), target_thetas, + p=1).reshape([num_priors, 
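
A hedged toy call of `dynamic_k_assign` above, assuming the patched `ppdet` package is importable; random costs and IoUs stand in for the real classification/geometry costs and pair-wise line IoUs. Each ground truth receives roughly `sum(top-4 IoUs)` priors (at least one), and priors matched to several ground truths are resolved by lowest cost:

```python
import paddle
from ppdet.modeling.assigners.clrnet_assigner import dynamic_k_assign

paddle.seed(0)
cost = paddle.rand([6, 2])    # (num_priors, num_targets), lower means a better match
ious = paddle.rand([6, 2])    # pair-wise line IoU between priors and targets
prior_idx, gt_idx = dynamic_k_assign(cost, ious)
print(prior_idx.numpy(), gt_idx.numpy())   # which prior is assigned to which ground truth
```
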
num_targets]) * 180 + theta_score = 1 - theta_score / paddle.max(x=theta_score) + 0.01 + + cost = -(distances_score * start_xys_score * theta_score + )**2 * distance_cost_weight + cls_score * cls_cost_weight + iou = line_iou(predictions[..., 6:], targets[..., 6:], img_w, aligned=False) + + matched_row_inds, matched_col_inds = dynamic_k_assign(cost, iou) + return matched_row_inds, matched_col_inds diff --git a/ppdet/modeling/backbones/__init__.py b/ppdet/modeling/backbones/__init__.py index e61ff711186e3191b14f4e41a6363dbf8886e6c6..bc000c72dadd49f0ca21d70ac10f1fa4ff6da8a6 100644 --- a/ppdet/modeling/backbones/__init__.py +++ b/ppdet/modeling/backbones/__init__.py @@ -38,6 +38,7 @@ from . import trans_encoder from . import focalnet from . import vit_mae from . import hgnet_v2 +from . import clrnet_resnet from .vgg import * from .resnet import * @@ -66,3 +67,4 @@ from .focalnet import * from .vitpose import * from .vit_mae import * from .hgnet_v2 import * +from .clrnet_resnet import * diff --git a/ppdet/modeling/backbones/clrnet_resnet.py b/ppdet/modeling/backbones/clrnet_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..00758df552dd6826e813aab93cd90573448424f1 --- /dev/null +++ b/ppdet/modeling/backbones/clrnet_resnet.py @@ -0,0 +1,697 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn + +from paddle.utils.download import get_weights_path_from_url +from ppdet.core.workspace import register, serializable +from ..shape_spec import ShapeSpec + +__all__ = ['CLRResNet'] + +model_urls = { + 'resnet18': + 'https://x2paddle.bj.bcebos.com/vision/models/resnet18-pt.pdparams', + 'resnet34': + 'https://x2paddle.bj.bcebos.com/vision/models/resnet34-pt.pdparams', + 'resnet50': + 'https://x2paddle.bj.bcebos.com/vision/models/resnet50-pt.pdparams', + 'resnet101': + 'https://x2paddle.bj.bcebos.com/vision/models/resnet101-pt.pdparams', + 'resnet152': + 'https://x2paddle.bj.bcebos.com/vision/models/resnet152-pt.pdparams', + 'resnext50_32x4d': + 'https://x2paddle.bj.bcebos.com/vision/models/resnext50_32x4d-pt.pdparams', + 'resnext101_32x8d': + 'https://x2paddle.bj.bcebos.com/vision/models/resnext101_32x8d-pt.pdparams', + 'wide_resnet50_2': + 'https://x2paddle.bj.bcebos.com/vision/models/wide_resnet50_2-pt.pdparams', + 'wide_resnet101_2': + 'https://x2paddle.bj.bcebos.com/vision/models/wide_resnet101_2-pt.pdparams', +} + + +class BasicBlock(nn.Layer): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2D + + if dilation > 1: + raise NotImplementedError( + "Dilation > 1 not supported in BasicBlock") + + self.conv1 = nn.Conv2D( + inplanes, planes, 3, padding=1, stride=stride, bias_attr=False) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class BottleneckBlock(nn.Layer): + + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None): + super(BottleneckBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2D + width = int(planes * (base_width / 64.)) * groups + + self.conv1 = nn.Conv2D(inplanes, width, 1, bias_attr=False) + self.bn1 = norm_layer(width) + + self.conv2 = nn.Conv2D( + width, + width, + 3, + padding=dilation, + stride=stride, + groups=groups, + dilation=dilation, + bias_attr=False) + self.bn2 = norm_layer(width) + + self.conv3 = nn.Conv2D( + width, planes * self.expansion, 1, bias_attr=False) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU() + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Layer): + """ResNet model from + `"Deep Residual Learning for Image Recognition" `_. + Args: + Block (BasicBlock|BottleneckBlock): Block module of model. + depth (int, optional): Layers of ResNet, Default: 50. 
+ width (int, optional): Base width per convolution group for each convolution block, Default: 64. + num_classes (int, optional): Output dim of last fc layer. If num_classes <= 0, last fc layer + will not be defined. Default: 1000. + with_pool (bool, optional): Use pool before the last fc layer or not. Default: True. + groups (int, optional): Number of groups for each convolution block, Default: 1. + Returns: + :ref:`api_paddle_nn_Layer`. An instance of ResNet model. + Examples: + .. code-block:: python + import paddle + from paddle.vision.models import ResNet + from paddle.vision.models.resnet import BottleneckBlock, BasicBlock + # build ResNet with 18 layers + resnet18 = ResNet(BasicBlock, 18) + # build ResNet with 50 layers + resnet50 = ResNet(BottleneckBlock, 50) + # build Wide ResNet model + wide_resnet50_2 = ResNet(BottleneckBlock, 50, width=64*2) + # build ResNeXt model + resnext50_32x4d = ResNet(BottleneckBlock, 50, width=4, groups=32) + x = paddle.rand([1, 3, 224, 224]) + out = resnet18(x) + print(out.shape) + # [1, 1000] + """ + + def __init__(self, block, depth=50, width=64, with_pool=True, groups=1): + super(ResNet, self).__init__() + layer_cfg = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3] + } + + layers = layer_cfg[depth] + self.groups = groups + self.base_width = width + self.with_pool = with_pool + self._norm_layer = nn.BatchNorm2D + + self.inplanes = 64 + self.dilation = 1 + + self.conv1 = nn.Conv2D( + 3, + self.inplanes, + kernel_size=7, + stride=2, + padding=3, + bias_attr=False) + self.bn1 = self._norm_layer(self.inplanes) + self.relu = nn.ReLU() + self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + if with_pool: + self.avgpool = nn.AdaptiveAvgPool2D((1, 1)) + + ch_out_list = [64, 128, 256, 512] + block = BottleneckBlock if depth >= 50 else BasicBlock + + self._out_channels = [block.expansion * v for v in ch_out_list] + self._out_strides = [4, 8, 16, 32] + self.return_idx = [0, 1, 2, 3] + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2D( + self.inplanes, + planes * block.expansion, + 1, + stride=stride, + bias_attr=False), + norm_layer(planes * block.expansion), ) + + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + @property + def out_shape(self): + return [ + ShapeSpec( + channels=self._out_channels[i], stride=self._out_strides[i]) + for i in self.return_idx + ] + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + out_layers = [] + x = self.layer1(x) + out_layers.append(x) + x = self.layer2(x) + out_layers.append(x) + x = self.layer3(x) + out_layers.append(x) + x = self.layer4(x) + 
out_layers.append(x) + + if self.with_pool: + x = self.avgpool(x) + + return out_layers + + +@register +@serializable +class CLRResNet(nn.Layer): + def __init__(self, + resnet='resnet18', + pretrained=True, + out_conv=False, + fea_stride=8, + out_channel=128, + in_channels=[64, 128, 256, 512], + cfg=None): + super(CLRResNet, self).__init__() + self.cfg = cfg + self.in_channels = in_channels + + self.model = eval(resnet)(pretrained=pretrained) + self.out = None + if out_conv: + out_channel = 512 + for chan in reversed(self.in_channels): + if chan < 0: continue + out_channel = chan + break + self.out = nn.Conv2D( + out_channel * self.model.expansion, + cfg.featuremap_out_channel, + kernel_size=1, + bias_attr=False) + + @property + def out_shape(self): + return self.model.out_shape + + def forward(self, x): + x = self.model(x) + if self.out: + x[-1] = self.out(x[-1]) + return x + + +def _resnet(arch, Block, depth, pretrained, **kwargs): + model = ResNet(Block, depth, **kwargs) + if pretrained: + assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format( + arch) + weight_path = get_weights_path_from_url(model_urls[arch]) + + param = paddle.load(weight_path) + model.set_dict(param) + + return model + + +def resnet18(pretrained=False, **kwargs): + """ResNet 18-layer model from + `"Deep Residual Learning for Image Recognition" `_. + Args: + pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained + on ImageNet. Default: False. + **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`ResNet `. + Returns: + :ref:`api_paddle_nn_Layer`. An instance of ResNet 18-layer model. + Examples: + .. code-block:: python + import paddle + from paddle.vision.models import resnet18 + # build model + model = resnet18() + # build model and load imagenet pretrained weight + # model = resnet18(pretrained=True) + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + print(out.shape) + # [1, 1000] + """ + return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs) + + +def resnet34(pretrained=False, **kwargs): + """ResNet 34-layer model from + `"Deep Residual Learning for Image Recognition" `_. + Args: + pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained + on ImageNet. Default: False. + **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`ResNet `. + Returns: + :ref:`api_paddle_nn_Layer`. An instance of ResNet 34-layer model. + Examples: + .. code-block:: python + import paddle + from paddle.vision.models import resnet34 + # build model + model = resnet34() + # build model and load imagenet pretrained weight + # model = resnet34(pretrained=True) + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + print(out.shape) + # [1, 1000] + """ + return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs) + + +def resnet50(pretrained=False, **kwargs): + """ResNet 50-layer model from + `"Deep Residual Learning for Image Recognition" `_. + Args: + pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained + on ImageNet. Default: False. + **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`ResNet `. + Returns: + :ref:`api_paddle_nn_Layer`. An instance of ResNet 50-layer model. + Examples: + .. 
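
As a sanity check on the backbone wiring, the sketch below (assuming this patch is applied and `ppdet` is importable) runs `CLRResNet` on a dummy input at the 320x800 resolution used by the head defaults; it should return four feature maps at strides 4/8/16/32:

```python
import paddle
from ppdet.modeling.backbones.clrnet_resnet import CLRResNet

backbone = CLRResNet(resnet='resnet18', pretrained=False)   # pretrained=False skips the weight download
feats = backbone(paddle.rand([1, 3, 320, 800]))
print([tuple(f.shape) for f in feats])
# expected: [(1, 64, 80, 200), (1, 128, 40, 100), (1, 256, 20, 50), (1, 512, 10, 25)]
```
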
code-block:: python + import paddle + from paddle.vision.models import resnet50 + # build model + model = resnet50() + # build model and load imagenet pretrained weight + # model = resnet50(pretrained=True) + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + print(out.shape) + # [1, 1000] + """ + return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs) + + +def resnet101(pretrained=False, **kwargs): + """ResNet 101-layer model from + `"Deep Residual Learning for Image Recognition" `_. + Args: + pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained + on ImageNet. Default: False. + **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`ResNet `. + Returns: + :ref:`api_paddle_nn_Layer`. An instance of ResNet 101-layer. + Examples: + .. code-block:: python + import paddle + from paddle.vision.models import resnet101 + # build model + model = resnet101() + # build model and load imagenet pretrained weight + # model = resnet101(pretrained=True) + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + print(out.shape) + # [1, 1000] + """ + return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs) + + +def resnet152(pretrained=False, **kwargs): + """ResNet 152-layer model from + `"Deep Residual Learning for Image Recognition" `_. + Args: + pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained + on ImageNet. Default: False. + **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`ResNet `. + Returns: + :ref:`api_paddle_nn_Layer`. An instance of ResNet 152-layer model. + Examples: + .. code-block:: python + import paddle + from paddle.vision.models import resnet152 + # build model + model = resnet152() + # build model and load imagenet pretrained weight + # model = resnet152(pretrained=True) + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + print(out.shape) + # [1, 1000] + """ + return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs) + + +def resnext50_32x4d(pretrained=False, **kwargs): + """ResNeXt-50 32x4d model from + `"Aggregated Residual Transformations for Deep Neural Networks" `_. + + Args: + pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained + on ImageNet. Default: False. + **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`ResNet `. + Returns: + :ref:`api_paddle_nn_Layer`. An instance of ResNeXt-50 32x4d model. + Examples: + .. code-block:: python + import paddle + from paddle.vision.models import resnext50_32x4d + # build model + model = resnext50_32x4d() + # build model and load imagenet pretrained weight + # model = resnext50_32x4d(pretrained=True) + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + print(out.shape) + # [1, 1000] + """ + kwargs['groups'] = 32 + kwargs['width'] = 4 + return _resnet('resnext50_32x4d', BottleneckBlock, 50, pretrained, **kwargs) + + +def resnext50_64x4d(pretrained=False, **kwargs): + """ResNeXt-50 64x4d model from + `"Aggregated Residual Transformations for Deep Neural Networks" `_. + + Args: + pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained + on ImageNet. Default: False. + **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`ResNet `. + Returns: + :ref:`api_paddle_nn_Layer`. An instance of ResNeXt-50 64x4d model. + Examples: + .. 
code-block:: python + import paddle + from paddle.vision.models import resnext50_64x4d + # build model + model = resnext50_64x4d() + # build model and load imagenet pretrained weight + # model = resnext50_64x4d(pretrained=True) + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + print(out.shape) + # [1, 1000] + """ + kwargs['groups'] = 64 + kwargs['width'] = 4 + return _resnet('resnext50_64x4d', BottleneckBlock, 50, pretrained, **kwargs) + + +def resnext101_32x4d(pretrained=False, **kwargs): + """ResNeXt-101 32x4d model from + `"Aggregated Residual Transformations for Deep Neural Networks" `_. + + Args: + pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained + on ImageNet. Default: False. + **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`ResNet `. + Returns: + :ref:`api_paddle_nn_Layer`. An instance of ResNeXt-101 32x4d model. + Examples: + .. code-block:: python + import paddle + from paddle.vision.models import resnext101_32x4d + # build model + model = resnext101_32x4d() + # build model and load imagenet pretrained weight + # model = resnext101_32x4d(pretrained=True) + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + print(out.shape) + # [1, 1000] + """ + kwargs['groups'] = 32 + kwargs['width'] = 4 + return _resnet('resnext101_32x4d', BottleneckBlock, 101, pretrained, + **kwargs) + + +def resnext101_64x4d(pretrained=False, **kwargs): + """ResNeXt-101 64x4d model from + `"Aggregated Residual Transformations for Deep Neural Networks" `_. + + Args: + pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained + on ImageNet. Default: False. + **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`ResNet `. + Returns: + :ref:`api_paddle_nn_Layer`. An instance of ResNeXt-101 64x4d model. + Examples: + .. code-block:: python + import paddle + from paddle.vision.models import resnext101_64x4d + # build model + model = resnext101_64x4d() + # build model and load imagenet pretrained weight + # model = resnext101_64x4d(pretrained=True) + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + print(out.shape) + # [1, 1000] + """ + kwargs['groups'] = 64 + kwargs['width'] = 4 + return _resnet('resnext101_64x4d', BottleneckBlock, 101, pretrained, + **kwargs) + + +def resnext152_32x4d(pretrained=False, **kwargs): + """ResNeXt-152 32x4d model from + `"Aggregated Residual Transformations for Deep Neural Networks" `_. + + Args: + pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained + on ImageNet. Default: False. + **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`ResNet `. + Returns: + :ref:`api_paddle_nn_Layer`. An instance of ResNeXt-152 32x4d model. + Examples: + .. code-block:: python + import paddle + from paddle.vision.models import resnext152_32x4d + # build model + model = resnext152_32x4d() + # build model and load imagenet pretrained weight + # model = resnext152_32x4d(pretrained=True) + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + print(out.shape) + # [1, 1000] + """ + kwargs['groups'] = 32 + kwargs['width'] = 4 + return _resnet('resnext152_32x4d', BottleneckBlock, 152, pretrained, + **kwargs) + + +def resnext152_64x4d(pretrained=False, **kwargs): + """ResNeXt-152 64x4d model from + `"Aggregated Residual Transformations for Deep Neural Networks" `_. + + Args: + pretrained (bool, optional): Whether to load pre-trained weights. 
If True, returns a model pre-trained + on ImageNet. Default: False. + **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`ResNet `. + Returns: + :ref:`api_paddle_nn_Layer`. An instance of ResNeXt-152 64x4d model. + Examples: + .. code-block:: python + import paddle + from paddle.vision.models import resnext152_64x4d + # build model + model = resnext152_64x4d() + # build model and load imagenet pretrained weight + # model = resnext152_64x4d(pretrained=True) + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + print(out.shape) + # [1, 1000] + """ + kwargs['groups'] = 64 + kwargs['width'] = 4 + return _resnet('resnext152_64x4d', BottleneckBlock, 152, pretrained, + **kwargs) + + +def wide_resnet50_2(pretrained=False, **kwargs): + """Wide ResNet-50-2 model from + `"Wide Residual Networks" `_. + Args: + pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained + on ImageNet. Default: False. + **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`ResNet `. + Returns: + :ref:`api_paddle_nn_Layer`. An instance of Wide ResNet-50-2 model. + Examples: + .. code-block:: python + import paddle + from paddle.vision.models import wide_resnet50_2 + # build model + model = wide_resnet50_2() + # build model and load imagenet pretrained weight + # model = wide_resnet50_2(pretrained=True) + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + print(out.shape) + # [1, 1000] + """ + kwargs['width'] = 64 * 2 + return _resnet('wide_resnet50_2', BottleneckBlock, 50, pretrained, **kwargs) + + +def wide_resnet101_2(pretrained=False, **kwargs): + """Wide ResNet-101-2 model from + `"Wide Residual Networks" `_. + Args: + pretrained (bool, optional): Whether to load pre-trained weights. If True, returns a model pre-trained + on ImageNet. Default: False. + **kwargs (optional): Additional keyword arguments. For details, please refer to :ref:`ResNet `. + Returns: + :ref:`api_paddle_nn_Layer`. An instance of Wide ResNet-101-2 model. + Examples: + .. 
code-block:: python + import paddle + from paddle.vision.models import wide_resnet101_2 + # build model + model = wide_resnet101_2() + # build model and load imagenet pretrained weight + # model = wide_resnet101_2(pretrained=True) + x = paddle.rand([1, 3, 224, 224]) + out = model(x) + print(out.shape) + # [1, 1000] + """ + kwargs['width'] = 64 * 2 + return _resnet('wide_resnet101_2', BottleneckBlock, 101, pretrained, + **kwargs) diff --git a/ppdet/modeling/clrnet_utils.py b/ppdet/modeling/clrnet_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..24ece5c2e89d694a7077116c014dff9ab2a79d27 --- /dev/null +++ b/ppdet/modeling/clrnet_utils.py @@ -0,0 +1,309 @@ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from ppdet.modeling.initializer import constant_ +from paddle.nn.initializer import KaimingNormal + + +class ConvModule(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=False, + norm_type='bn', + wtih_act=True): + super(ConvModule, self).__init__() + assert norm_type in ['bn', 'sync_bn', 'gn', None] + self.with_norm = norm_type is not None + self.wtih_act = wtih_act + self.conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias_attr=bias, + weight_attr=KaimingNormal()) + if self.with_norm: + if norm_type == 'bn': + self.bn = nn.BatchNorm2D(out_channels) + elif norm_type == 'gn': + self.bn = nn.GroupNorm(out_channels, out_channels) + + if self.wtih_act: + self.act = nn.ReLU() + + def forward(self, inputs): + x = self.conv(inputs) + if self.with_norm: + x = self.bn(x) + if self.wtih_act: + x = self.act(x) + return x + + +def LinearModule(hidden_dim): + return nn.LayerList( + [nn.Linear( + hidden_dim, hidden_dim, bias_attr=True), nn.ReLU()]) + + +class FeatureResize(nn.Layer): + def __init__(self, size=(10, 25)): + super(FeatureResize, self).__init__() + self.size = size + + def forward(self, x): + x = F.interpolate(x, self.size) + return x.flatten(2) + + +class ROIGather(nn.Layer): + ''' + ROIGather module for gather global information + Args: + in_channels: prior feature channels + num_priors: prior numbers we predefined + sample_points: the number of sampled points when we extract feature from line + fc_hidden_dim: the fc output channel + refine_layers: the total number of layers to build refine + ''' + + def __init__(self, + in_channels, + num_priors, + sample_points, + fc_hidden_dim, + refine_layers, + mid_channels=48): + super(ROIGather, self).__init__() + self.in_channels = in_channels + self.num_priors = num_priors + self.f_key = ConvModule( + in_channels=self.in_channels, + out_channels=self.in_channels, + kernel_size=1, + stride=1, + padding=0, + norm_type='bn') + + self.f_query = nn.Sequential( + nn.Conv1D( + in_channels=num_priors, + out_channels=num_priors, + kernel_size=1, + stride=1, + padding=0, + groups=num_priors), + nn.ReLU(), ) + self.f_value = nn.Conv2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + kernel_size=1, + stride=1, + padding=0) + self.W = nn.Conv1D( + in_channels=num_priors, + out_channels=num_priors, + kernel_size=1, + stride=1, + padding=0, + groups=num_priors) + + self.resize = FeatureResize() + constant_(self.W.weight, 0) + constant_(self.W.bias, 0) + + self.convs = nn.LayerList() + self.catconv = nn.LayerList() + for i in range(refine_layers): + self.convs.append( + 
ConvModule( + in_channels, + mid_channels, (9, 1), + padding=(4, 0), + bias=False, + norm_type='bn')) + + self.catconv.append( + ConvModule( + mid_channels * (i + 1), + in_channels, (9, 1), + padding=(4, 0), + bias=False, + norm_type='bn')) + + self.fc = nn.Linear( + sample_points * fc_hidden_dim, fc_hidden_dim, bias_attr=True) + + self.fc_norm = nn.LayerNorm(fc_hidden_dim) + + def roi_fea(self, x, layer_index): + feats = [] + for i, feature in enumerate(x): + feat_trans = self.convs[i](feature) + feats.append(feat_trans) + cat_feat = paddle.concat(feats, axis=1) + cat_feat = self.catconv[layer_index](cat_feat) + return cat_feat + + def forward(self, roi_features, x, layer_index): + ''' + Args: + roi_features: prior feature, shape: (Batch * num_priors, prior_feat_channel, sample_point, 1) + x: feature map + layer_index: currently on which layer to refine + Return: + roi: prior features with gathered global information, shape: (Batch, num_priors, fc_hidden_dim) + ''' + + roi = self.roi_fea(roi_features, layer_index) + # return roi + # print(roi.shape) + # return roi + bs = x.shape[0] + # print(bs) + #roi = roi.contiguous().view(bs * self.num_priors, -1) + roi = roi.reshape([bs * self.num_priors, -1]) + # roi = paddle.randn([192,2304]) + # return roi + # print(roi) + # print(self.fc) + # print(self.fc.weight) + roi = self.fc(roi) + roi = F.relu(self.fc_norm(roi)) + # return roi + #roi = roi.view(bs, self.num_priors, -1) + roi = roi.reshape([bs, self.num_priors, -1]) + query = roi + + value = self.resize(self.f_value(x)) # (B, C, N) global feature + query = self.f_query( + query) # (B, N, 1) sample context feature from prior roi + key = self.f_key(x) + value = value.transpose(perm=[0, 2, 1]) + key = self.resize(key) # (B, C, N) global feature + sim_map = paddle.matmul(query, key) + sim_map = (self.in_channels**-.5) * sim_map + sim_map = F.softmax(sim_map, axis=-1) + + context = paddle.matmul(sim_map, value) + context = self.W(context) + + roi = roi + F.dropout(context, p=0.1, training=self.training) + + return roi + + +class SegDecoder(nn.Layer): + ''' + Optionaly seg decoder + ''' + + def __init__(self, + image_height, + image_width, + num_class, + prior_feat_channels=64, + refine_layers=3): + super().__init__() + self.dropout = nn.Dropout2D(0.1) + self.conv = nn.Conv2D(prior_feat_channels * refine_layers, num_class, 1) + self.image_height = image_height + self.image_width = image_width + + def forward(self, x): + x = self.dropout(x) + x = self.conv(x) + x = F.interpolate( + x, + size=[self.image_height, self.image_width], + mode='bilinear', + align_corners=False) + return x + + +import paddle.nn as nn + + +def accuracy(pred, target, topk=1, thresh=None): + """Calculate accuracy according to the prediction and target. + + Args: + pred (torch.Tensor): The model prediction, shape (N, num_class) + target (torch.Tensor): The target of each prediction, shape (N, ) + topk (int | tuple[int], optional): If the predictions in ``topk`` + matches the target, the predictions will be regarded as + correct ones. Defaults to 1. + thresh (float, optional): If not None, predictions with scores under + this threshold are considered incorrect. Default to None. + + Returns: + float | tuple[float]: If the input ``topk`` is a single integer, + the function will return a single float as accuracy. If + ``topk`` is a tuple containing multiple integers, the + function will return a tuple containing accuracies of + each ``topk`` number. 
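
Stripped of the convolutions and reshapes, the core of `ROIGather.forward` above is one scaled dot-product attention step from per-prior ROI features (query) to the flattened global feature map (key/value), followed by a dropout residual. A minimal sketch with made-up shapes (192 priors, 64 channels, a 10x25 resized map), editor's illustration only:

```python
import paddle
import paddle.nn.functional as F

B, N, C, HW = 2, 192, 64, 10 * 25        # batch, priors, channels, flattened feature map
query = paddle.rand([B, N, C])           # per-prior ROI features after the fc layer
key = paddle.rand([B, C, HW])            # resized global feature map
value = paddle.rand([B, HW, C])

sim = paddle.matmul(query, key) * (C ** -0.5)   # (B, N, HW) similarity, scaled by 1/sqrt(C)
attn = F.softmax(sim, axis=-1)
context = paddle.matmul(attn, value)            # (B, N, C) global context per prior
roi = query + F.dropout(context, p=0.1)         # residual connection, as in ROIGather
print(roi.shape)                                # [2, 192, 64]
```
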
+ """ + assert isinstance(topk, (int, tuple)) + if isinstance(topk, int): + topk = (topk, ) + return_single = True + else: + return_single = False + + maxk = max(topk) + if pred.shape[0] == 0: + accu = [pred.new_tensor(0.) for i in range(len(topk))] + return accu[0] if return_single else accu + assert pred.ndim == 2 and target.ndim == 1 + assert pred.shape[0] == target.shape[0] + assert maxk <= pred.shape[1], \ + f'maxk {maxk} exceeds pred dimension {pred.shape[1]}' + pred_value, pred_label = pred.topk(maxk, axis=1) + pred_label = pred_label.t() # transpose to shape (maxk, N) + correct = pred_label.equal(target.reshape([1, -1]).expand_as(pred_label)) + if thresh is not None: + # Only prediction values larger than thresh are counted as correct + correct = correct & (pred_value > thresh).t() + res = [] + for k in topk: + correct_k = correct[:k].reshape([-1]).cast("float32").sum(0, + keepdim=True) + correct_k = correct_k * (100.0 / pred.shape[0]) + res.append(correct_k) + return res[0] if return_single else res + + +class Accuracy(nn.Layer): + def __init__(self, topk=(1, ), thresh=None): + """Module to calculate the accuracy. + + Args: + topk (tuple, optional): The criterion used to calculate the + accuracy. Defaults to (1,). + thresh (float, optional): If not None, predictions with scores + under this threshold are considered incorrect. Default to None. + """ + super().__init__() + self.topk = topk + self.thresh = thresh + + def forward(self, pred, target): + """Forward function to calculate accuracy. + + Args: + pred (torch.Tensor): Prediction of models. + target (torch.Tensor): Target for each prediction. + + Returns: + tuple[float]: The accuracies under different topk criterions. + """ + return accuracy(pred, target, self.topk, self.thresh) diff --git a/ppdet/modeling/heads/__init__.py b/ppdet/modeling/heads/__init__.py index 44a9fa85d19e512eb978bb9c80ab99857fc3b5ca..0d126c08fc3f66ed9c8ddd91c12b16f115e9da84 100644 --- a/ppdet/modeling/heads/__init__.py +++ b/ppdet/modeling/heads/__init__.py @@ -40,6 +40,7 @@ from . import ppyoloe_contrast_head from . import centertrack_head from . import sparse_roi_head from . import vitpose_head +from . 
import clrnet_head from .bbox_head import * from .mask_head import * @@ -69,4 +70,5 @@ from .ppyoloe_contrast_head import * from .centertrack_head import * from .sparse_roi_head import * from .petr_head import * -from .vitpose_head import * \ No newline at end of file +from .vitpose_head import * +from .clrnet_head import * \ No newline at end of file diff --git a/ppdet/modeling/heads/clrnet_head.py b/ppdet/modeling/heads/clrnet_head.py new file mode 100644 index 0000000000000000000000000000000000000000..14760b9ef04943925e74929e54c6ca086784f90b --- /dev/null +++ b/ppdet/modeling/heads/clrnet_head.py @@ -0,0 +1,399 @@ +import math +import paddle +import numpy as np +import paddle.nn as nn +import paddle.nn.functional as F +from ppdet.core.workspace import register + +from ppdet.modeling.initializer import normal_ +from ppdet.modeling.lane_utils import Lane +from ppdet.modeling.losses import line_iou +from ppdet.modeling.clrnet_utils import ROIGather, LinearModule, SegDecoder + +__all__ = ['CLRHead'] + + +@register +class CLRHead(nn.Layer): + __inject__ = ['loss'] + __shared__ = [ + 'img_w', 'img_h', 'ori_img_h', 'num_classes', 'cut_height', + 'num_points', "max_lanes" + ] + + def __init__(self, + num_points=72, + prior_feat_channels=64, + fc_hidden_dim=64, + num_priors=192, + img_w=800, + img_h=320, + ori_img_h=590, + cut_height=270, + num_classes=5, + num_fc=2, + refine_layers=3, + sample_points=36, + conf_threshold=0.4, + nms_thres=0.5, + max_lanes=4, + loss='CLRNetLoss'): + super(CLRHead, self).__init__() + self.img_w = img_w + self.img_h = img_h + self.n_strips = num_points - 1 + self.n_offsets = num_points + self.num_priors = num_priors + self.sample_points = sample_points + self.refine_layers = refine_layers + self.num_classes = num_classes + self.fc_hidden_dim = fc_hidden_dim + self.ori_img_h = ori_img_h + self.cut_height = cut_height + self.conf_threshold = conf_threshold + self.nms_thres = nms_thres + self.max_lanes = max_lanes + self.prior_feat_channels = prior_feat_channels + self.loss = loss + self.register_buffer( + name='sample_x_indexs', + tensor=(paddle.linspace( + start=0, stop=1, num=self.sample_points, + dtype=paddle.float32) * self.n_strips).astype(dtype='int64')) + self.register_buffer( + name='prior_feat_ys', + tensor=paddle.flip( + x=(1 - self.sample_x_indexs.astype('float32') / self.n_strips), + axis=[-1])) + self.register_buffer( + name='prior_ys', + tensor=paddle.linspace( + start=1, stop=0, num=self.n_offsets).astype('float32')) + self.prior_feat_channels = prior_feat_channels + self._init_prior_embeddings() + init_priors, priors_on_featmap = self.generate_priors_from_embeddings() + self.register_buffer(name='priors', tensor=init_priors) + self.register_buffer(name='priors_on_featmap', tensor=priors_on_featmap) + self.seg_decoder = SegDecoder(self.img_h, self.img_w, self.num_classes, + self.prior_feat_channels, + self.refine_layers) + reg_modules = list() + cls_modules = list() + for _ in range(num_fc): + reg_modules += [*LinearModule(self.fc_hidden_dim)] + cls_modules += [*LinearModule(self.fc_hidden_dim)] + self.reg_modules = nn.LayerList(sublayers=reg_modules) + self.cls_modules = nn.LayerList(sublayers=cls_modules) + self.roi_gather = ROIGather(self.prior_feat_channels, self.num_priors, + self.sample_points, self.fc_hidden_dim, + self.refine_layers) + self.reg_layers = nn.Linear( + in_features=self.fc_hidden_dim, + out_features=self.n_offsets + 1 + 2 + 1, + bias_attr=True) + self.cls_layers = nn.Linear( + in_features=self.fc_hidden_dim, out_features=2, 
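
For reference when reading the head code above and below: with the default `num_points=72`, every prior/prediction row is a 78-vector, and the regression branch outputs the 76 values that refine it (the layout follows the comment in `generate_priors_from_embeddings` below). A short annotated sketch:

```python
# Layout of one CLRNet prior / prediction row (78 = 2 + 2 + 2 + 72 values):
#   [0:2]   classification scores (background, lane)
#   [2]     normalized start_y
#   [3]     normalized start_x
#   [4]     normalized theta (angle of the line prior)
#   [5]     length, in strip units
#   [6:78]  72 normalized x coordinates, one per sampled y
# The regression head therefore predicts n_offsets + 1 + 2 + 1 = 76 values:
# 3 offsets for (start_y, start_x, theta), 1 for length, and 72 for the x coordinates.
import paddle

num_priors, n_offsets = 192, 72
predictions = paddle.rand([num_priors, 2 + 2 + 2 + n_offsets])
scores = paddle.nn.functional.softmax(predictions[:, :2], axis=1)[:, 1]
print(predictions.shape, scores.shape)   # [192, 78] [192]
```
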
bias_attr=True) + self.init_weights() + + def init_weights(self): + for m in self.cls_layers.parameters(): + normal_(m, mean=0.0, std=0.001) + for m in self.reg_layers.parameters(): + normal_(m, mean=0.0, std=0.001) + + def pool_prior_features(self, batch_features, num_priors, prior_xs): + """ + pool prior feature from feature map. + Args: + batch_features (Tensor): Input feature maps, shape: (B, C, H, W) + """ + batch_size = batch_features.shape[0] + prior_xs = prior_xs.reshape([batch_size, num_priors, -1, 1]) + + prior_ys = self.prior_feat_ys.tile(repeat_times=[ + batch_size * num_priors + ]).reshape([batch_size, num_priors, -1, 1]) + prior_xs = prior_xs * 2.0 - 1.0 + prior_ys = prior_ys * 2.0 - 1.0 + grid = paddle.concat(x=(prior_xs, prior_ys), axis=-1) + feature = F.grid_sample( + x=batch_features, grid=grid, + align_corners=True).transpose(perm=[0, 2, 1, 3]) + feature = feature.reshape([ + batch_size * num_priors, self.prior_feat_channels, + self.sample_points, 1 + ]) + return feature + + def generate_priors_from_embeddings(self): + predictions = self.prior_embeddings.weight + # 2 scores, 1 start_y, 1 start_x, 1 theta, 1 length, 72 coordinates, score[0] = negative prob, score[1] = positive prob + priors = paddle.zeros( + (self.num_priors, 2 + 2 + 2 + self.n_offsets), + dtype=predictions.dtype) + priors[:, 2:5] = predictions.clone() + priors[:, 6:] = ( + priors[:, 3].unsqueeze(1).clone().tile([1, self.n_offsets]) * + (self.img_w - 1) + + ((1 - self.prior_ys.tile([self.num_priors, 1]) - + priors[:, 2].unsqueeze(1).clone().tile([1, self.n_offsets])) * + self.img_h / paddle.tan(x=priors[:, 4].unsqueeze(1).clone().tile( + [1, self.n_offsets]) * math.pi + 1e-05))) / (self.img_w - 1) + priors_on_featmap = paddle.index_select( + priors, 6 + self.sample_x_indexs, axis=-1) + return priors, priors_on_featmap + + def _init_prior_embeddings(self): + self.prior_embeddings = nn.Embedding(self.num_priors, 3) + bottom_priors_nums = self.num_priors * 3 // 4 + left_priors_nums, _ = self.num_priors // 8, self.num_priors // 8 + strip_size = 0.5 / (left_priors_nums // 2 - 1) + bottom_strip_size = 1 / (bottom_priors_nums // 4 + 1) + + with paddle.no_grad(): + for i in range(left_priors_nums): + self.prior_embeddings.weight[i, 0] = i // 2 * strip_size + self.prior_embeddings.weight[i, 1] = 0.0 + self.prior_embeddings.weight[i, + 2] = 0.16 if i % 2 == 0 else 0.32 + + for i in range(left_priors_nums, + left_priors_nums + bottom_priors_nums): + self.prior_embeddings.weight[i, 0] = 0.0 + self.prior_embeddings.weight[i, 1] = ( + (i - left_priors_nums) // 4 + 1) * bottom_strip_size + self.prior_embeddings.weight[i, 2] = 0.2 * (i % 4 + 1) + + for i in range(left_priors_nums + bottom_priors_nums, + self.num_priors): + self.prior_embeddings.weight[i, 0] = ( + i - left_priors_nums - bottom_priors_nums) // 2 * strip_size + self.prior_embeddings.weight[i, 1] = 1.0 + self.prior_embeddings.weight[i, + 2] = 0.68 if i % 2 == 0 else 0.84 + + def forward(self, x, inputs=None): + """ + Take pyramid features as input to perform Cross Layer Refinement and finally output the prediction lanes. + Each feature is a 4D tensor. 
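
The x coordinates in `generate_priors_from_embeddings` above come from a straight-line equation: each prior is a ray from `(start_x, start_y)` at angle `theta`, evaluated at the 72 sampled ys and renormalized by the image width. A numpy-only sketch of that formula with one made-up prior:

```python
import numpy as np

img_w, img_h, n_offsets = 800, 320, 72
start_y, start_x, theta = 0.0, 0.4, 0.3      # made-up normalized prior parameters
prior_ys = np.linspace(1, 0, n_offsets)      # bottom (y=1) to top (y=0), normalized

xs = (start_x * (img_w - 1)
      + (1 - prior_ys - start_y) * img_h / np.tan(theta * np.pi + 1e-5)) / (img_w - 1)
print(xs[0], xs[-1])   # x at the bottom of the image vs. at the top of the prior line
```
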
+ Args: + x: input features (list[Tensor]) + Return: + prediction_list: each layer's prediction result + seg: segmentation result for auxiliary loss + """ + batch_features = list(x[len(x) - self.refine_layers:]) + batch_features.reverse() + batch_size = batch_features[-1].shape[0] + + if self.training: + self.priors, self.priors_on_featmap = self.generate_priors_from_embeddings( + ) + priors, priors_on_featmap = self.priors.tile( + [batch_size, 1, + 1]), self.priors_on_featmap.tile([batch_size, 1, 1]) + predictions_lists = [] + prior_features_stages = [] + + for stage in range(self.refine_layers): + num_priors = priors_on_featmap.shape[1] + prior_xs = paddle.flip(x=priors_on_featmap, axis=[2]) + batch_prior_features = self.pool_prior_features( + batch_features[stage], num_priors, prior_xs) + prior_features_stages.append(batch_prior_features) + + fc_features = self.roi_gather(prior_features_stages, + batch_features[stage], stage) + # return fc_features + fc_features = fc_features.reshape( + [num_priors, batch_size, -1]).reshape( + [batch_size * num_priors, self.fc_hidden_dim]) + cls_features = fc_features.clone() + reg_features = fc_features.clone() + + for cls_layer in self.cls_modules: + cls_features = cls_layer(cls_features) + + # return cls_features + for reg_layer in self.reg_modules: + reg_features = reg_layer(reg_features) + cls_logits = self.cls_layers(cls_features) + reg = self.reg_layers(reg_features) + + cls_logits = cls_logits.reshape( + [batch_size, -1, cls_logits.shape[1]]) + reg = reg.reshape([batch_size, -1, reg.shape[1]]) + predictions = priors.clone() + predictions[:, :, :2] = cls_logits + predictions[:, :, 2:5] += reg[:, :, :3] + predictions[:, :, 5] = reg[:, :, 3] + + def tran_tensor(t): + return t.unsqueeze(axis=2).clone().tile([1, 1, self.n_offsets]) + + predictions[..., 6:] = ( + tran_tensor(predictions[..., 3]) * (self.img_w - 1) + + ((1 - self.prior_ys.tile([batch_size, num_priors, 1]) - + tran_tensor(predictions[..., 2])) * self.img_h / paddle.tan( + tran_tensor(predictions[..., 4]) * math.pi + 1e-05))) / ( + self.img_w - 1) + + prediction_lines = predictions.clone() + predictions[..., 6:] += reg[..., 4:] + predictions_lists.append(predictions) + + if stage != self.refine_layers - 1: + priors = prediction_lines.detach().clone() + priors_on_featmap = priors.index_select( + 6 + self.sample_x_indexs, axis=-1) + + if self.training: + seg = None + seg_features = paddle.concat( + [ + F.interpolate( + feature, + size=[ + batch_features[-1].shape[2], + batch_features[-1].shape[3] + ], + mode='bilinear', + align_corners=False) for feature in batch_features + ], + axis=1) + + seg = self.seg_decoder(seg_features) + + output = {'predictions_lists': predictions_lists, 'seg': seg} + return self.loss(output, inputs) + return predictions_lists[-1] + + def predictions_to_pred(self, predictions): + """ + Convert predictions to internal Lane structure for evaluation. + """ + self.prior_ys = paddle.to_tensor(self.prior_ys) + self.prior_ys = self.prior_ys.astype('float64') + lanes = [] + for lane in predictions: + lane_xs = lane[6:].clone() + start = min( + max(0, int(round(lane[2].item() * self.n_strips))), + self.n_strips) + length = int(round(lane[5].item())) + end = start + length - 1 + end = min(end, len(self.prior_ys) - 1) + if start > 0: + mask = ((lane_xs[:start] >= 0.) 
& + (lane_xs[:start] <= 1.)).cpu().detach().numpy()[::-1] + mask = ~((mask.cumprod()[::-1]).astype(np.bool)) + lane_xs[:start][mask] = -2 + if end < len(self.prior_ys) - 1: + lane_xs[end + 1:] = -2 + + lane_ys = self.prior_ys[lane_xs >= 0].clone() + lane_xs = lane_xs[lane_xs >= 0] + lane_xs = lane_xs.flip(axis=0).astype('float64') + lane_ys = lane_ys.flip(axis=0) + + lane_ys = (lane_ys * + (self.ori_img_h - self.cut_height) + self.cut_height + ) / self.ori_img_h + if len(lane_xs) <= 1: + continue + points = paddle.stack( + x=(lane_xs.reshape([-1, 1]), lane_ys.reshape([-1, 1])), + axis=1).squeeze(axis=2) + lane = Lane( + points=points.cpu().numpy(), + metadata={ + 'start_x': lane[3], + 'start_y': lane[2], + 'conf': lane[1] + }) + lanes.append(lane) + return lanes + + def lane_nms(self, predictions, scores, nms_overlap_thresh, top_k): + """ + NMS for lane detection. + predictions: paddle.Tensor [num_lanes,conf,y,x,lenght,72offsets] [12,77] + scores: paddle.Tensor [num_lanes] + nms_overlap_thresh: float + top_k: int + """ + # sort by scores to get idx + idx = scores.argsort(descending=True) + keep = [] + + condidates = predictions.clone() + condidates = condidates.index_select(idx) + + while len(condidates) > 0: + keep.append(idx[0]) + if len(keep) >= top_k or len(condidates) == 1: + break + + ious = [] + for i in range(1, len(condidates)): + ious.append(1 - line_iou( + condidates[i].unsqueeze(0), + condidates[0].unsqueeze(0), + img_w=self.img_w, + length=15)) + ious = paddle.to_tensor(ious) + + mask = ious <= nms_overlap_thresh + id = paddle.where(mask == False)[0] + + if id.shape[0] == 0: + break + condidates = condidates[1:].index_select(id) + idx = idx[1:].index_select(id) + keep = paddle.stack(keep) + + return keep + + def get_lanes(self, output, as_lanes=True): + """ + Convert model output to lanes. 
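
One detail worth calling out in `predictions_to_pred` above: the network sees images whose top `cut_height` rows were cropped away, so normalized ys are mapped back into the full 590-pixel CULane frame before evaluation. A one-liner with the default values:

```python
ori_img_h, cut_height = 590, 270            # defaults used in the head above
y_net = 0.5                                 # normalized y inside the cropped 320-px image
y_full = (y_net * (ori_img_h - cut_height) + cut_height) / ori_img_h
print(y_full)                               # ~0.729: the same row, normalized in the full image
```
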
+ """ + softmax = nn.Softmax(axis=1) + decoded = [] + + for predictions in output: + threshold = self.conf_threshold + scores = softmax(predictions[:, :2])[:, 1] + keep_inds = scores >= threshold + predictions = predictions[keep_inds] + scores = scores[keep_inds] + + if predictions.shape[0] == 0: + decoded.append([]) + continue + nms_predictions = predictions.detach().clone() + nms_predictions = paddle.concat( + x=[nms_predictions[..., :4], nms_predictions[..., 5:]], axis=-1) + + nms_predictions[..., 4] = nms_predictions[..., 4] * self.n_strips + nms_predictions[..., 5:] = nms_predictions[..., 5:] * ( + self.img_w - 1) + + keep = self.lane_nms( + nms_predictions[..., 5:], + scores, + nms_overlap_thresh=self.nms_thres, + top_k=self.max_lanes) + + predictions = predictions.index_select(keep) + + if predictions.shape[0] == 0: + decoded.append([]) + continue + predictions[:, 5] = paddle.round(predictions[:, 5] * self.n_strips) + if as_lanes: + pred = self.predictions_to_pred(predictions) + else: + pred = predictions + decoded.append(pred) + return decoded diff --git a/ppdet/modeling/lane_utils.py b/ppdet/modeling/lane_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e3fb45c66d285664b9209e1a56a009d47a327971 --- /dev/null +++ b/ppdet/modeling/lane_utils.py @@ -0,0 +1,111 @@ +import os +import cv2 +import numpy as np +from scipy.interpolate import InterpolatedUnivariateSpline + + +class Lane: + def __init__(self, points=None, invalid_value=-2., metadata=None): + super(Lane, self).__init__() + self.curr_iter = 0 + self.points = points + self.invalid_value = invalid_value + self.function = InterpolatedUnivariateSpline( + points[:, 1], points[:, 0], k=min(3, len(points) - 1)) + self.min_y = points[:, 1].min() - 0.01 + self.max_y = points[:, 1].max() + 0.01 + self.metadata = metadata or {} + + def __repr__(self): + return '[Lane]\n' + str(self.points) + '\n[/Lane]' + + def __call__(self, lane_ys): + lane_xs = self.function(lane_ys) + + lane_xs[(lane_ys < self.min_y) | (lane_ys > self.max_y + )] = self.invalid_value + return lane_xs + + def to_array(self, sample_y_range, img_w, img_h): + self.sample_y = range(sample_y_range[0], sample_y_range[1], + sample_y_range[2]) + sample_y = self.sample_y + img_w, img_h = img_w, img_h + ys = np.array(sample_y) / float(img_h) + xs = self(ys) + valid_mask = (xs >= 0) & (xs < 1) + lane_xs = xs[valid_mask] * img_w + lane_ys = ys[valid_mask] * img_h + lane = np.concatenate( + (lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)), axis=1) + return lane + + def __iter__(self): + return self + + def __next__(self): + if self.curr_iter < len(self.points): + self.curr_iter += 1 + return self.points[self.curr_iter - 1] + self.curr_iter = 0 + raise StopIteration + + +COLORS = [ + (255, 0, 0), + (0, 255, 0), + (0, 0, 255), + (255, 255, 0), + (255, 0, 255), + (0, 255, 255), + (128, 255, 0), + (255, 128, 0), + (128, 0, 255), + (255, 0, 128), + (0, 128, 255), + (0, 255, 128), + (128, 255, 255), + (255, 128, 255), + (255, 255, 128), + (60, 180, 0), + (180, 60, 0), + (0, 60, 180), + (0, 180, 60), + (60, 0, 180), + (180, 0, 60), + (255, 0, 0), + (0, 255, 0), + (0, 0, 255), + (255, 255, 0), + (255, 0, 255), + (0, 255, 255), + (128, 255, 0), + (255, 128, 0), + (128, 0, 255), +] + + +def imshow_lanes(img, lanes, show=False, out_file=None, width=4): + lanes_xys = [] + for _, lane in enumerate(lanes): + xys = [] + for x, y in lane: + if x <= 0 or y <= 0: + continue + x, y = int(x), int(y) + xys.append((x, y)) + lanes_xys.append(xys) + lanes_xys.sort(key=lambda 
xys: xys[0][0] if len(xys) > 0 else 0) + + for idx, xys in enumerate(lanes_xys): + for i in range(1, len(xys)): + cv2.line(img, xys[i - 1], xys[i], COLORS[idx], thickness=width) + + if show: + cv2.imshow('view', img) + cv2.waitKey(0) + + if out_file: + if not os.path.exists(os.path.dirname(out_file)): + os.makedirs(os.path.dirname(out_file)) + cv2.imwrite(out_file, img) diff --git a/ppdet/modeling/losses/__init__.py b/ppdet/modeling/losses/__init__.py index 0e6b31de8a8fcacce2cfa62242f458565540d0b6..41b3ae0f13850f8c44602316579ae08da4aca6ae 100644 --- a/ppdet/modeling/losses/__init__.py +++ b/ppdet/modeling/losses/__init__.py @@ -31,6 +31,8 @@ from . import probiou_loss from . import cot_loss from . import supcontrast from . import queryinst_loss +from . import clrnet_loss +from . import clrnet_line_iou_loss from .yolo_loss import * from .iou_aware_loss import * @@ -52,3 +54,5 @@ from .probiou_loss import * from .cot_loss import * from .supcontrast import * from .queryinst_loss import * +from .clrnet_loss import * +from .clrnet_line_iou_loss import * \ No newline at end of file diff --git a/ppdet/modeling/losses/clrnet_line_iou_loss.py b/ppdet/modeling/losses/clrnet_line_iou_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..2a1973d788e36350906b75f5e61a4ca5066d1bc5 --- /dev/null +++ b/ppdet/modeling/losses/clrnet_line_iou_loss.py @@ -0,0 +1,41 @@ +import paddle + + +def line_iou(pred, target, img_w, length=15, aligned=True): + ''' + Calculate the line iou value between predictions and targets + Args: + pred: lane predictions, shape: (num_pred, 72) + target: ground truth, shape: (num_target, 72) + img_w: image width + length: extended radius + aligned: True for iou loss calculation, False for pair-wise ious in assign + ''' + px1 = pred - length + px2 = pred + length + tx1 = target - length + tx2 = target + length + + if aligned: + invalid_mask = target + ovr = paddle.minimum(px2, tx2) - paddle.maximum(px1, tx1) + union = paddle.maximum(px2, tx2) - paddle.minimum(px1, tx1) + else: + num_pred = pred.shape[0] + invalid_mask = target.tile([num_pred, 1, 1]) + + ovr = (paddle.minimum(px2[:, None, :], tx2[None, ...]) - paddle.maximum( + px1[:, None, :], tx1[None, ...])) + union = (paddle.maximum(px2[:, None, :], tx2[None, ...]) - + paddle.minimum(px1[:, None, :], tx1[None, ...])) + + invalid_masks = (invalid_mask < 0) | (invalid_mask >= img_w) + + ovr[invalid_masks] = 0. + union[invalid_masks] = 0. + iou = ovr.sum(axis=-1) / (union.sum(axis=-1) + 1e-9) + return iou + + +def liou_loss(pred, target, img_w, length=15): + return (1 - line_iou(pred, target, img_w, length)).mean() diff --git a/ppdet/modeling/losses/clrnet_loss.py b/ppdet/modeling/losses/clrnet_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ad39e58b82ed7495e8a2e295e29d897aff919e --- /dev/null +++ b/ppdet/modeling/losses/clrnet_loss.py @@ -0,0 +1,283 @@ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from ppdet.core.workspace import register +from ppdet.modeling.clrnet_utils import accuracy +from ppdet.modeling.assigners.clrnet_assigner import assign +from ppdet.modeling.losses.clrnet_line_iou_loss import liou_loss + +__all__ = ['CLRNetLoss'] + + +class SoftmaxFocalLoss(nn.Layer): + def __init__(self, gamma, ignore_lb=255, *args, **kwargs): + super(SoftmaxFocalLoss, self).__init__() + self.gamma = gamma + self.nll = nn.NLLLoss(ignore_index=ignore_lb) + + def forward(self, logits, labels): + scores = F.softmax(logits, dim=1) + factor = paddle.pow(1. 
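
The line IoU above treats each of the 72 x coordinates as a horizontal segment of half-width `length` (15 px by default) and compares overlap to union column by column. For two lanes offset by a constant 10 px, the IoU is (2*15 - 10)/(2*15 + 10) = 0.5, which the sketch below reproduces (assuming the patched `ppdet` is importable; the lanes are made up):

```python
import paddle
from ppdet.modeling.losses.clrnet_line_iou_loss import line_iou, liou_loss

img_w = 800
pred = paddle.linspace(100, 300, 72).unsqueeze(0)   # one lane: 72 x coordinates in pixels
target = pred + 10                                   # toy ground truth, shifted by 10 px

print(line_iou(pred, target, img_w))    # 0.5
print(liou_loss(pred, target, img_w))   # 1 - 0.5 = 0.5
```
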
- scores, self.gamma) + log_score = F.log_softmax(logits, dim=1) + log_score = factor * log_score + loss = self.nll(log_score, labels) + return loss + + +def focal_loss(input: paddle.Tensor, + target: paddle.Tensor, + alpha: float, + gamma: float=2.0, + reduction: str='none', + eps: float=1e-8) -> paddle.Tensor: + r"""Function that computes Focal loss. + + See :class:`~kornia.losses.FocalLoss` for details. + """ + if not paddle.is_tensor(input): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(input))) + + if not len(input.shape) >= 2: + raise ValueError("Invalid input shape, we expect BxCx*. Got: {}".format( + input.shape)) + + if input.shape[0] != target.shape[0]: + raise ValueError( + 'Expected input batch_size ({}) to match target batch_size ({}).'. + format(input.shape[0], target.shape[0])) + + n = input.shape[0] + out_size = (n, ) + tuple(input.shape[2:]) + if target.shape[1:] != input.shape[2:]: + raise ValueError('Expected target size {}, got {}'.format(out_size, + target.shape)) + if (isinstance(input.place, paddle.CUDAPlace) and + isinstance(target.place, paddle.CPUPlace)) | (isinstance( + input.place, paddle.CPUPlace) and isinstance(target.place, + paddle.CUDAPlace)): + raise ValueError( + "input and target must be in the same device. Got: {} and {}". + format(input.place, target.place)) + + # compute softmax over the classes axis + input_soft: paddle.Tensor = F.softmax(input, axis=1) + eps + + # create the labels one hot tensor + target_one_hot: paddle.Tensor = paddle.to_tensor( + F.one_hot( + target, num_classes=input.shape[1]).cast(input.dtype), + place=input.place) + + # compute the actual focal loss + weight = paddle.pow(-input_soft + 1., gamma) + + focal = -alpha * weight * paddle.log(input_soft) + loss_tmp = paddle.sum(target_one_hot * focal, axis=1) + + if reduction == 'none': + loss = loss_tmp + elif reduction == 'mean': + loss = paddle.mean(loss_tmp) + elif reduction == 'sum': + loss = paddle.sum(loss_tmp) + else: + raise NotImplementedError("Invalid reduction mode: {}".format( + reduction)) + return loss + + +class FocalLoss(nn.Layer): + r"""Criterion that computes Focal loss. + + According to [1], the Focal loss is computed as follows: + + .. math:: + + \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t) + + where: + - :math:`p_t` is the model's estimated probability for each class. + + + Arguments: + alpha (float): Weighting factor :math:`\alpha \in [0, 1]`. + gamma (float): Focusing parameter :math:`\gamma >= 0`. + reduction (str, optional): Specifies the reduction to apply to the + output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied, + ‘mean’: the sum of the output will be divided by the number of elements + in the output, ‘sum’: the output will be summed. Default: ‘none’. + + Shape: + - Input: :math:`(N, C, *)` where C = number of classes. + - Target: :math:`(N, *)` where each value is + :math:`0 ≤ targets[i] ≤ C−1`. 
+ + Examples: + >>> N = 5 # num_classes + >>> kwargs = {"alpha": 0.5, "gamma": 2.0, "reduction": 'mean'} + >>> loss = kornia.losses.FocalLoss(**kwargs) + >>> input = torch.randn(1, N, 3, 5, requires_grad=True) + >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N) + >>> output = loss(input, target) + >>> output.backward() + + References: + [1] https://arxiv.org/abs/1708.02002 + """ + + def __init__(self, alpha: float, gamma: float=2.0, + reduction: str='none') -> None: + super(FocalLoss, self).__init__() + self.alpha: float = alpha + self.gamma: float = gamma + self.reduction: str = reduction + self.eps: float = 1e-6 + + def forward( # type: ignore + self, input: paddle.Tensor, target: paddle.Tensor) -> paddle.Tensor: + return focal_loss(input, target, self.alpha, self.gamma, self.reduction, + self.eps) + + +@register +class CLRNetLoss(nn.Layer): + __shared__ = ['img_w', 'img_h', 'num_classes', 'num_points'] + + def __init__(self, + cls_loss_weight=2.0, + xyt_loss_weight=0.2, + iou_loss_weight=2.0, + seg_loss_weight=1.0, + refine_layers=3, + num_points=72, + img_w=800, + img_h=320, + num_classes=5, + ignore_label=255, + bg_weight=0.4): + super(CLRNetLoss, self).__init__() + self.cls_loss_weight = cls_loss_weight + self.xyt_loss_weight = xyt_loss_weight + self.iou_loss_weight = iou_loss_weight + self.seg_loss_weight = seg_loss_weight + self.refine_layers = refine_layers + self.img_w = img_w + self.img_h = img_h + self.n_strips = num_points - 1 + self.num_classes = num_classes + self.ignore_label = ignore_label + weights = paddle.ones(shape=[self.num_classes]) + weights[0] = bg_weight + self.criterion = nn.NLLLoss( + ignore_index=self.ignore_label, weight=weights) + + def forward(self, output, batch): + predictions_lists = output['predictions_lists'] + targets = batch['lane_line'].clone() + cls_criterion = FocalLoss(alpha=0.25, gamma=2.0) + cls_loss = paddle.to_tensor(0.0) + reg_xytl_loss = paddle.to_tensor(0.0) + iou_loss = paddle.to_tensor(0.0) + cls_acc = [] + cls_acc_stage = [] + for stage in range(self.refine_layers): + predictions_list = predictions_lists[stage] + for predictions, target in zip(predictions_list, targets): + target = target[target[:, 1] == 1] + + if len(target) == 0: + # If there are no targets, all predictions have to be negatives (i.e., 0 confidence) + cls_target = paddle.zeros( + [predictions.shape[0]], dtype='int64') + cls_pred = predictions[:, :2] + cls_loss = cls_loss + cls_criterion(cls_pred, + cls_target).sum() + continue + + with paddle.no_grad(): + matched_row_inds, matched_col_inds = assign( + predictions, target, self.img_w, self.img_h) + + # classification targets + cls_target = paddle.zeros([predictions.shape[0]], dtype='int64') + cls_target[matched_row_inds] = 1 + cls_pred = predictions[:, :2] + + # regression targets -> [start_y, start_x, theta] (all transformed to absolute values), only on matched pairs + reg_yxtl = predictions.index_select(matched_row_inds)[..., 2:6] + + reg_yxtl[:, 0] *= self.n_strips + reg_yxtl[:, 1] *= (self.img_w - 1) + reg_yxtl[:, 2] *= 180 + reg_yxtl[:, 3] *= self.n_strips + + target_yxtl = target.index_select(matched_col_inds)[..., 2: + 6].clone() + + # regression targets -> S coordinates (all transformed to absolute values) + reg_pred = predictions.index_select(matched_row_inds)[..., 6:] + reg_pred *= (self.img_w - 1) + reg_targets = target.index_select(matched_col_inds)[..., + 6:].clone() + + with paddle.no_grad(): + predictions_starts = paddle.clip( + (predictions.index_select(matched_row_inds)[..., 2] * + 
self.n_strips).round().cast("int64"), + min=0, + max=self. + n_strips) # ensure the predictions starts is valid + + target_starts = ( + target.index_select(matched_col_inds)[..., 2] * + self.n_strips).round().cast("int64") + target_yxtl[:, -1] -= ( + predictions_starts - target_starts) # reg length + + # Loss calculation + cls_loss = cls_loss + cls_criterion( + cls_pred, cls_target).sum() / target.shape[0] + + target_yxtl[:, 0] *= self.n_strips + target_yxtl[:, 2] *= 180 + + reg_xytl_loss = reg_xytl_loss + F.smooth_l1_loss( + input=reg_yxtl, label=target_yxtl, reduction='none').mean() + + iou_loss = iou_loss + liou_loss( + reg_pred, reg_targets, self.img_w, length=15) + + cls_accuracy = accuracy(cls_pred, cls_target) + cls_acc_stage.append(cls_accuracy) + + cls_acc.append(sum(cls_acc_stage) / (len(cls_acc_stage) + 1e-5)) + + # extra segmentation loss + seg_loss = self.criterion( + F.log_softmax( + output['seg'], axis=1), batch['seg'].cast('int64')) + + cls_loss /= (len(targets) * self.refine_layers) + reg_xytl_loss /= (len(targets) * self.refine_layers) + iou_loss /= (len(targets) * self.refine_layers) + + loss = cls_loss * self.cls_loss_weight \ + + reg_xytl_loss * self.xyt_loss_weight \ + + seg_loss * self.seg_loss_weight \ + + iou_loss * self.iou_loss_weight + + return_value = { + 'loss': loss, + 'cls_loss': cls_loss * self.cls_loss_weight, + 'reg_xytl_loss': reg_xytl_loss * self.xyt_loss_weight, + 'seg_loss': seg_loss * self.seg_loss_weight, + 'iou_loss': iou_loss * self.iou_loss_weight + } + + for i in range(self.refine_layers): + if not isinstance(cls_acc[i], paddle.Tensor): + cls_acc[i] = paddle.to_tensor(cls_acc[i]) + return_value['stage_{}_acc'.format(i)] = cls_acc[i] + + return return_value diff --git a/ppdet/modeling/necks/__init__.py b/ppdet/modeling/necks/__init__.py index 478efec98e324b213ad3f822b551f92265d91e25..afd2a954545af722a56c172be8e8890761a28811 100644 --- a/ppdet/modeling/necks/__init__.py +++ b/ppdet/modeling/necks/__init__.py @@ -23,6 +23,7 @@ from . import es_pan from . import lc_pan from . import custom_pan from . import dilated_encoder +from . import clrnet_fpn from .fpn import * from .yolo_fpn import * @@ -37,3 +38,4 @@ from .lc_pan import * from .custom_pan import * from .dilated_encoder import * from .channel_mapper import * +from .clrnet_fpn import * diff --git a/ppdet/modeling/necks/clrnet_fpn.py b/ppdet/modeling/necks/clrnet_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..936c7e7c7bc2d629c33d568834978e9b32dc8fbe --- /dev/null +++ b/ppdet/modeling/necks/clrnet_fpn.py @@ -0,0 +1,254 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
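+
+# CLRFPN: a top-down feature pyramid neck (1x1 lateral convs, nearest-neighbour
+# upsampling, 3x3 output convs) that fuses multi-scale backbone features for
+# the CLRNet head.
+#
+# Minimal usage sketch (channel numbers are illustrative, e.g. ResNet-18 C3-C5):
+#   neck = CLRFPN(in_channels=[128, 256, 512], out_channel=64)
+#   pyramid = neck([c3, c4, c5])  # list of feature maps, each with 64 channels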
+ +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.nn.initializer import XavierUniform +from ppdet.modeling.initializer import kaiming_normal_, constant_ +from ppdet.core.workspace import register, serializable +from ppdet.modeling.layers import ConvNormLayer +from ppdet.modeling.shape_spec import ShapeSpec + +__all__ = ['CLRFPN'] + + +@register +@serializable +class CLRFPN(nn.Layer): + """ + Feature Pyramid Network, see https://arxiv.org/abs/1612.03144 + Args: + in_channels (list[int]): input channels of each level which can be + derived from the output shape of backbone by from_config + out_channel (int): output channel of each level + spatial_scales (list[float]): the spatial scales between input feature + maps and original input image which can be derived from the output + shape of backbone by from_config + has_extra_convs (bool): whether to add extra conv to the last level. + default False + extra_stage (int): the number of extra stages added to the last level. + default 1 + use_c5 (bool): Whether to use c5 as the input of extra stage, + otherwise p5 is used. default True + norm_type (string|None): The normalization type in FPN module. If + norm_type is None, norm will not be used after conv and if + norm_type is string, bn, gn, sync_bn are available. default None + norm_decay (float): weight decay for normalization layer weights. + default 0. + freeze_norm (bool): whether to freeze normalization layer. + default False + relu_before_extra_convs (bool): whether to add relu before extra convs. + default False + + """ + + def __init__(self, + in_channels, + out_channel, + spatial_scales=[0.25, 0.125, 0.0625, 0.03125], + has_extra_convs=False, + extra_stage=1, + use_c5=True, + norm_type=None, + norm_decay=0., + freeze_norm=False, + relu_before_extra_convs=True): + super(CLRFPN, self).__init__() + self.out_channel = out_channel + for s in range(extra_stage): + spatial_scales = spatial_scales + [spatial_scales[-1] / 2.] 
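+ # each extra stage appends one coarser pyramid level, halving the previous spatial scale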
+ self.spatial_scales = spatial_scales + self.has_extra_convs = has_extra_convs + self.extra_stage = extra_stage + self.use_c5 = use_c5 + self.relu_before_extra_convs = relu_before_extra_convs + self.norm_type = norm_type + self.norm_decay = norm_decay + self.freeze_norm = freeze_norm + self.in_channels = in_channels + self.lateral_convs = [] + self.fpn_convs = [] + fan = out_channel * 3 * 3 + + # stage index 0,1,2,3 stands for res2,res3,res4,res5 on ResNet Backbone + # 0 <= st_stage < ed_stage <= 3 + st_stage = 4 - len(in_channels) + ed_stage = st_stage + len(in_channels) - 1 + + for i in range(st_stage, ed_stage + 1): + # if i == 3: + # lateral_name = 'fpn_inner_res5_sum' + # else: + # lateral_name = 'fpn_inner_res{}_sum_lateral'.format(i + 2) + lateral_name = "lateral_convs.{}.conv".format(i - 1) + in_c = in_channels[i - st_stage] + if self.norm_type is not None: + lateral = self.add_sublayer( + lateral_name, + ConvNormLayer( + ch_in=in_c, + ch_out=out_channel, + filter_size=1, + stride=1, + norm_type=self.norm_type, + norm_decay=self.norm_decay, + freeze_norm=self.freeze_norm, + initializer=XavierUniform(fan_out=in_c))) + else: + lateral = self.add_sublayer( + lateral_name, + nn.Conv2D( + in_channels=in_c, + out_channels=out_channel, + kernel_size=1, + weight_attr=ParamAttr( + initializer=XavierUniform(fan_out=in_c)))) + self.lateral_convs.append(lateral) + + fpn_name = "fpn_convs.{}.conv".format(i - 1) + if self.norm_type is not None: + fpn_conv = self.add_sublayer( + fpn_name, + ConvNormLayer( + ch_in=out_channel, + ch_out=out_channel, + filter_size=3, + stride=1, + norm_type=self.norm_type, + norm_decay=self.norm_decay, + freeze_norm=self.freeze_norm, + initializer=XavierUniform(fan_out=fan))) + else: + fpn_conv = self.add_sublayer( + fpn_name, + nn.Conv2D( + in_channels=out_channel, + out_channels=out_channel, + kernel_size=3, + padding=1, + weight_attr=ParamAttr( + initializer=XavierUniform(fan_out=fan)))) + self.fpn_convs.append(fpn_conv) + + # add extra conv levels for RetinaNet(use_c5)/FCOS(use_p5) + if self.has_extra_convs: + for i in range(self.extra_stage): + lvl = ed_stage + 1 + i + if i == 0 and self.use_c5: + in_c = in_channels[-1] + else: + in_c = out_channel + extra_fpn_name = 'fpn_{}'.format(lvl + 2) + if self.norm_type is not None: + extra_fpn_conv = self.add_sublayer( + extra_fpn_name, + ConvNormLayer( + ch_in=in_c, + ch_out=out_channel, + filter_size=3, + stride=2, + norm_type=self.norm_type, + norm_decay=self.norm_decay, + freeze_norm=self.freeze_norm, + initializer=XavierUniform(fan_out=fan))) + else: + extra_fpn_conv = self.add_sublayer( + extra_fpn_name, + nn.Conv2D( + in_channels=in_c, + out_channels=out_channel, + kernel_size=3, + stride=2, + padding=1, + weight_attr=ParamAttr( + initializer=XavierUniform(fan_out=fan)))) + self.fpn_convs.append(extra_fpn_conv) + self.init_weights() + + def init_weights(self): + for m in self.lateral_convs: + if isinstance(m, (nn.Conv1D, nn.Conv2D)): + kaiming_normal_( + m.weight, a=0, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + constant_(m.bias, value=0.) + elif isinstance(m, (nn.BatchNorm1D, nn.BatchNorm2D)): + constant_(m.weight, value=1) + constant_(m.bias, value=0) + for m in self.fpn_convs: + if isinstance(m, (nn.Conv1D, nn.Conv2D)): + kaiming_normal_( + m.weight, a=0, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + constant_(m.bias, value=0.) 
+ elif isinstance(m, (nn.BatchNorm1D, nn.BatchNorm2D)): + constant_(m.weight, value=1) + constant_(m.bias, value=0) + + @classmethod + def from_config(cls, cfg, input_shape): + return {} + + def forward(self, body_feats): + laterals = [] + if len(body_feats) > len(self.in_channels): + for _ in range(len(body_feats) - len(self.in_channels)): + del body_feats[0] + num_levels = len(body_feats) + # print("body_feats",num_levels) + for i in range(num_levels): + laterals.append(self.lateral_convs[i](body_feats[i])) + + for i in range(1, num_levels): + lvl = num_levels - i + upsample = F.interpolate( + laterals[lvl], + scale_factor=2., + mode='nearest', ) + laterals[lvl - 1] += upsample + + fpn_output = [] + for lvl in range(num_levels): + fpn_output.append(self.fpn_convs[lvl](laterals[lvl])) + + if self.extra_stage > 0: + # use max pool to get more levels on top of outputs (Faster R-CNN, Mask R-CNN) + if not self.has_extra_convs: + assert self.extra_stage == 1, 'extra_stage should be 1 if FPN has not extra convs' + fpn_output.append(F.max_pool2d(fpn_output[-1], 1, stride=2)) + # add extra conv levels for RetinaNet(use_c5)/FCOS(use_p5) + else: + if self.use_c5: + extra_source = body_feats[-1] + else: + extra_source = fpn_output[-1] + fpn_output.append(self.fpn_convs[num_levels](extra_source)) + + for i in range(1, self.extra_stage): + if self.relu_before_extra_convs: + fpn_output.append(self.fpn_convs[num_levels + i](F.relu( + fpn_output[-1]))) + else: + fpn_output.append(self.fpn_convs[num_levels + i]( + fpn_output[-1])) + return fpn_output + + @property + def out_shape(self): + return [ + ShapeSpec( + channels=self.out_channel, stride=1. / s) + for s in self.spatial_scales + ] diff --git a/ppdet/utils/download.py b/ppdet/utils/download.py index 8fb95afa36602ce9c6964ff05190216d01ffb235..5f3665fc8d3ea41b9d7605e917d628cc71439177 100644 --- a/ppdet/utils/download.py +++ b/ppdet/utils/download.py @@ -101,7 +101,8 @@ DATASETS = { '8a3a353c2c54a2284ad7d2780b65f6a6', ), ], ['annotations', 'images']), 'coco_ce': ([( 'https://paddledet.bj.bcebos.com/data/coco_ce.tar', - 'eadd1b79bc2f069f2744b1dd4e0c0329', ), ], []) + 'eadd1b79bc2f069f2744b1dd4e0c0329', ), ], []), + 'culane': ([('https://bj.bcebos.com/v1/paddledet/data/culane.tar', None, ), ], []) } DOWNLOAD_DATASETS_LIST = DATASETS.keys() diff --git a/requirements.txt b/requirements.txt index f6297b6f7849b4f8c33be4d4386bcddce4b26665..a094c54eaff7092a302c68c5bae3b76249b79a6c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,3 +18,6 @@ sklearn==0.0 # for vehicleplate in deploy/pipeline/ppvehicle pyclipper + +# for culane data augumetation +imgaug>=0.4.0 \ No newline at end of file diff --git a/tools/infer_culane.py b/tools/infer_culane.py new file mode 100644 index 0000000000000000000000000000000000000000..5f629467eacbca09f6fc98cfe174cf043632b27e --- /dev/null +++ b/tools/infer_culane.py @@ -0,0 +1,165 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
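+
+# Stand-alone inference entry point for CLRNet lane detection: loads a config and
+# trained weights, collects images from --infer_img / --infer_dir, and calls
+# Trainer.predict_culane to visualize (and optionally save) the predicted lanes.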
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +# add python path of PaddleDetection to sys.path +parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2))) +sys.path.insert(0, parent_path) + +# ignore warning log +import warnings +warnings.filterwarnings('ignore') +import glob +import ast + +import paddle +from ppdet.core.workspace import load_config, merge_config +from ppdet.engine import Trainer +from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_mlu, check_version, check_config +from ppdet.utils.cli import ArgsParser, merge_args +from ppdet.slim import build_slim_model + +from ppdet.utils.logger import setup_logger +logger = setup_logger('train') + + +def parse_args(): + parser = ArgsParser() + parser.add_argument( + "--infer_dir", + type=str, + default=None, + help="Directory for images to perform inference on.") + parser.add_argument( + "--infer_img", + type=str, + default=None, + help="Image path, has higher priority over --infer_dir") + parser.add_argument( + "--output_dir", + type=str, + default="output", + help="Directory for storing the output visualization files.") + parser.add_argument( + "--save_results", + type=bool, + default=False, + help="Whether to save inference results to output_dir.") + parser.add_argument( + "--visualize", + type=ast.literal_eval, + default=True, + help="Whether to save visualize results to output_dir.") + args = parser.parse_args() + return args + + +def get_test_images(infer_dir, infer_img): + """ + Get image path list in TEST mode + """ + assert infer_img is not None or infer_dir is not None, \ + "--infer_img or --infer_dir should be set" + assert infer_img is None or os.path.isfile(infer_img), \ + "{} is not a file".format(infer_img) + assert infer_dir is None or os.path.isdir(infer_dir), \ + "{} is not a directory".format(infer_dir) + + # infer_img has a higher priority + if infer_img and os.path.isfile(infer_img): + return [infer_img] + + images = set() + infer_dir = os.path.abspath(infer_dir) + assert os.path.isdir(infer_dir), \ + "infer_dir {} is not a directory".format(infer_dir) + exts = ['jpg', 'jpeg', 'png', 'bmp'] + exts += [ext.upper() for ext in exts] + for ext in exts: + images.update(glob.glob('{}/*.{}'.format(infer_dir, ext))) + images = list(images) + + assert len(images) > 0, "no image found in {}".format(infer_dir) + logger.info("Found {} inference images in total.".format(len(images))) + + return images + + +def run(FLAGS, cfg): + # build trainer + trainer = Trainer(cfg, mode='test') + + # load weights + trainer.load_weights(cfg.weights) + + # get inference images + images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img) + + trainer.predict_culane( + images, + output_dir=FLAGS.output_dir, + save_results=FLAGS.save_results, + visualize=FLAGS.visualize) + + +def main(): + FLAGS = parse_args() + cfg = load_config(FLAGS.config) + merge_args(cfg, FLAGS) + merge_config(FLAGS.opt) + + # disable npu in config by default + if 'use_npu' not in cfg: + cfg.use_npu = False + + # disable xpu in config by default + if 'use_xpu' not in cfg: + cfg.use_xpu = False + + if 'use_gpu' not in cfg: + cfg.use_gpu = False + + # disable mlu in config by default + if 'use_mlu' not in cfg: + cfg.use_mlu = False + + if cfg.use_gpu: + place = paddle.set_device('gpu') + elif cfg.use_npu: + place = paddle.set_device('npu') + elif cfg.use_xpu: + place = paddle.set_device('xpu') + elif cfg.use_mlu: + place = paddle.set_device('mlu') + else: 
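+ # fall back to CPU when no accelerator is enabled in the config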
+ place = paddle.set_device('cpu')
+
+ check_config(cfg)
+ check_gpu(cfg.use_gpu)
+ check_npu(cfg.use_npu)
+ check_xpu(cfg.use_xpu)
+ check_mlu(cfg.use_mlu)
+ check_version()
+
+ run(FLAGS, cfg)
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file