From f9447e9892e60994775b886aae95fcf5d18a55af Mon Sep 17 00:00:00 2001 From: LokeZhou Date: Mon, 20 Feb 2023 20:59:06 +0800 Subject: [PATCH] Cherry pick pipeline (#7798) * ppvehicle add customized docs, test=document_fix * fix pipeline readme.md test=document_fix * fix pipeline readme.md link test=document_fix * add licensed to lane_to_mask.py * ppvehicle_violation.md add args list test=document_fix --- deploy/pipeline/README.md | 4 +- deploy/pipeline/README_en.md | 4 +- deploy/pipeline/tools/create_dataset_list.py | 147 +++++ deploy/pipeline/tools/lane_to_mask.py | 508 ++++++++++++++++++ .../customization/ppvehicle_violation.md | 235 ++++++++ .../customization/ppvehicle_violation_en.md | 240 +++++++++ 6 files changed, 1134 insertions(+), 4 deletions(-) create mode 100644 deploy/pipeline/tools/create_dataset_list.py create mode 100644 deploy/pipeline/tools/lane_to_mask.py create mode 100644 docs/advanced_tutorials/customization/ppvehicle_violation.md create mode 100644 docs/advanced_tutorials/customization/ppvehicle_violation_en.md diff --git a/deploy/pipeline/README.md b/deploy/pipeline/README.md index 5d20d7358..f05510af9 100644 --- a/deploy/pipeline/README.md +++ b/deploy/pipeline/README.md @@ -155,10 +155,10 @@ - [快速开始](docs/tutorials/ppvehicle_press.md) -- [二次开发教程] +- [二次开发教程](../../docs/advanced_tutorials/customization/ppvehicle_violation.md) #### 车辆逆行 - [快速开始](docs/tutorials/ppvehicle_retrograde.md) -- [二次开发教程] +- [二次开发教程](../../docs/advanced_tutorials/customization/ppvehicle_violation.md) diff --git a/deploy/pipeline/README_en.md b/deploy/pipeline/README_en.md index 569f5dd93..ef5c0077a 100644 --- a/deploy/pipeline/README_en.md +++ b/deploy/pipeline/README_en.md @@ -152,10 +152,10 @@ Click to download the model, then unzip and save it in the `. /output_inference` - [A quick start](docs/tutorials/ppvehicle_press_en.md) -- [Customized development tutorials] +- [Customized development tutorials](../../docs/advanced_tutorials/customization/ppvehicle_violation_en.md) #### Vehicle Retrograde - [A quick start](docs/tutorials/ppvehicle_retrograde_en.md) -- [Customized development tutorials] +- [Customized development tutorials](../../docs/advanced_tutorials/customization/ppvehicle_violation_en.md) diff --git a/deploy/pipeline/tools/create_dataset_list.py b/deploy/pipeline/tools/create_dataset_list.py new file mode 100644 index 000000000..261e15e8f --- /dev/null +++ b/deploy/pipeline/tools/create_dataset_list.py @@ -0,0 +1,147 @@ +# coding: utf8 +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import os.path +import argparse +import warnings + + +def parse_args(): + parser = argparse.ArgumentParser( + description='PaddleSeg generate file list on cityscapes or your customized dataset.' 
+    )
+    parser.add_argument('dataset_root', help='dataset root directory', type=str)
+    parser.add_argument(
+        '--type',
+        help='dataset type: \n'
+        '- cityscapes \n'
+        '- custom (default)',
+        default="custom",
+        type=str)
+    parser.add_argument(
+        '--separator',
+        dest='separator',
+        help='file list separator',
+        default=" ",
+        type=str)
+    parser.add_argument(
+        '--folder',
+        help='the folder names of images and labels',
+        type=str,
+        nargs=2,
+        default=['images', 'labels'])
+    parser.add_argument(
+        '--second_folder',
+        help='the second-level folder names of train set, validation set, test set',
+        type=str,
+        nargs='*',
+        default=['train', 'val', 'test'])
+    parser.add_argument(
+        '--format',
+        help='data format of images and labels, e.g. jpg or png.',
+        type=str,
+        nargs=2,
+        default=['jpg', 'png'])
+    parser.add_argument(
+        '--postfix',
+        help='postfix of images or labels',
+        type=str,
+        nargs=2,
+        default=['', ''])
+
+    return parser.parse_args()
+
+
+def get_files(image_or_label, dataset_split, args):
+    dataset_root = args.dataset_root
+    postfix = args.postfix
+    fmt = args.format
+    folder = args.folder
+
+    pattern = '*%s.%s' % (postfix[image_or_label], fmt[image_or_label])
+
+    # Search the split directory itself and up to four levels of
+    # subdirectories below it.
+    search_files = os.path.join(dataset_root, folder[image_or_label],
+                                dataset_split, pattern)
+    search_files2 = os.path.join(dataset_root, folder[image_or_label],
+                                 dataset_split, "*", pattern)
+    search_files3 = os.path.join(dataset_root, folder[image_or_label],
+                                 dataset_split, "*", "*", pattern)
+    search_files4 = os.path.join(dataset_root, folder[image_or_label],
+                                 dataset_split, "*", "*", "*", pattern)
+    search_files5 = os.path.join(dataset_root, folder[image_or_label],
+                                 dataset_split, "*", "*", "*", "*", pattern)
+
+    filenames = glob.glob(search_files)
+    filenames2 = glob.glob(search_files2)
+    filenames3 = glob.glob(search_files3)
+    filenames4 = glob.glob(search_files4)
+    filenames5 = glob.glob(search_files5)
+
+    filenames = filenames + filenames2 + filenames3 + filenames4 + filenames5
+
+    return sorted(filenames)
+
+
+def generate_list(args):
+    dataset_root = args.dataset_root
+    separator = args.separator
+
+    for dataset_split in args.second_folder:
+        print("Creating {}.txt...".format(dataset_split))
+        image_files = get_files(0, dataset_split, args)
+        label_files = get_files(1, dataset_split, args)
+        if not image_files:
+            img_dir = os.path.join(dataset_root, args.folder[0], dataset_split)
+            warnings.warn("No images in {} !!!".format(img_dir))
+        num_images = len(image_files)
+
+        if not label_files:
+            label_dir = os.path.join(dataset_root, args.folder[1],
+                                     dataset_split)
+            warnings.warn("No labels in {} !!!".format(label_dir))
+        num_label = len(label_files)
+
+        if num_images != num_label and num_label > 0:
+            raise Exception(
+                "Number of images = {}, number of labels = {}.\n"
+                "The number of images must equal the number of labels, "
+                "or the number of labels must be 0.\n"
+                "Please check your dataset!".format(num_images, num_label))
+
+        file_list = os.path.join(dataset_root, dataset_split + '.txt')
+        with open(file_list, "w") as f:
+            for item in range(num_images):
+                left = image_files[item].replace(dataset_root, '', 1)
+                if left[0] == os.path.sep:
+                    left = left.lstrip(os.path.sep)
+
+                try:
+                    right = label_files[item].replace(dataset_root, '', 1)
+                    if right[0] == os.path.sep:
+                        right = right.lstrip(os.path.sep)
+                    line = left + separator + right + '\n'
+                except IndexError:
+                    # No matching label for this image (label list is empty).
+                    line = left + '\n'
+
+                f.write(line)
+                print(line)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    generate_list(args)
diff --git a/deploy/pipeline/tools/lane_to_mask.py b/deploy/pipeline/tools/lane_to_mask.py
new file mode 100644
index 000000000..ece2efb87
--- /dev/null
+++ b/deploy/pipeline/tools/lane_to_mask.py
@@ -0,0 +1,508 @@
+# coding: utf8
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert poly2d to mask/bitmask."""
+
+import os
+from functools import partial
+from multiprocessing import Pool
+from typing import Callable, Dict, List
+
+import matplotlib  # type: ignore
+import matplotlib.pyplot as plt  # type: ignore
+import numpy as np
+from PIL import Image
+from scalabel.common.parallel import NPROC
+from scalabel.common.typing import NDArrayU8
+from scalabel.label.io import group_and_sort, load
+from scalabel.label.transforms import poly_to_patch
+from scalabel.label.typing import Config, Frame, ImageSize, Label, Poly2D
+from scalabel.label.utils import (
+    check_crowd,
+    check_ignored,
+    get_leaf_categories, )
+from tqdm import tqdm
+
+from bdd100k.common.logger import logger
+from bdd100k.common.typing import BDD100KConfig
+from bdd100k.common.utils import get_bdd100k_instance_id, load_bdd100k_config
+from bdd100k.label.label import drivables, labels, lane_categories
+from bdd100k.label.to_coco import parse_args
+from bdd100k.label.to_scalabel import bdd100k_to_scalabel
+
+IGNORE_LABEL = 255
+STUFF_NUM = 30
+LANE_DIRECTION_MAP = {"parallel": 0, "vertical": 1}
+LANE_STYLE_MAP = {"solid": 0, "dashed": 1}
+
+
+def frame_to_mask(
+        out_path: str,
+        shape: ImageSize,
+        colors: List[NDArrayU8],
+        poly2ds: List[List[Poly2D]],
+        with_instances: bool=True,
+        back_color: int=0,
+        closed: bool=True, ) -> None:
+    """Convert a frame of poly2ds to a mask/bitmask."""
+    assert len(colors) == len(poly2ds)
+    height, width = shape.height, shape.width
+
+    assert back_color >= 0
+    if with_instances:
+        img: NDArrayU8 = (
+            np.ones(
+                [height, width, 4], dtype=np.uint8) * back_color  # type: ignore
+        )
+    else:
+        img = (
+            np.ones(
+                [height, width, 1], dtype=np.uint8) * back_color  # type: ignore
+        )
+
+    if len(colors) == 0:
+        pil_img = Image.fromarray(img.squeeze())
+        pil_img.save(out_path)
+
+    matplotlib.use("Agg")
+    fig = plt.figure(facecolor="0")
+    fig.set_size_inches((width / fig.get_dpi()), height / fig.get_dpi())
+    ax = fig.add_axes([0, 0, 1, 1])
+    ax.axis("off")
+    ax.set_xlim(0, width)
+    ax.set_ylim(0, height)
+    ax.set_facecolor((0, 0, 0, 0))
+    ax.invert_yaxis()
+
+    for i, poly2d in enumerate(poly2ds):
+        for poly in poly2d:
+            ax.add_patch(
+                poly_to_patch(
+                    poly.vertices,
+                    poly.types,
+                    # (0, 0, 0) is reserved for the background; encode the
+                    # instance index (i + 1) into the R/G channels
+                    color=(
+                        ((i + 1) >> 8) / 255.0,
+                        ((i + 1) % 255) / 255.0,
+                        0.0, ),
+                    closed=closed, ))
+
+    fig.canvas.draw()
+    out: NDArrayU8 = np.frombuffer(fig.canvas.tostring_rgb(), np.uint8)
+    out = out.reshape((height, width, -1)).astype(np.int32)
+    # Decode the instance index back from the rendered R/G channels.
+    out = (out[..., 0] << 8) + out[..., 1]
+    plt.close()
+
+    for i, color in enumerate(colors):
+        # 0 is for the background
+        img[out == i + 1] = color
+
+    # Map remaining ignore pixels (255) to the background class (0).
+    img[img == 255] = 0
+
+    pil_img = Image.fromarray(img.squeeze())
+    pil_img.save(out_path)
+
+
+def set_instance_color(label: Label, category_id: int,
+                       ann_id: int) -> NDArrayU8:
+    """Set the color for an instance given its attributes and ID."""
+    attributes = label.attributes
+    if attributes is None:
+        truncated, occluded, crowd, ignored = 0, 0, 0, 0
+    else:
+        truncated = int(attributes.get("truncated", False))
+        occluded = int(attributes.get("occluded", False))
+        crowd = int(check_crowd(label))
+        ignored = int(check_ignored(label))
+    # Pack category id, attribute flags and annotation id into the four
+    # channels of a bitmask pixel.
+    color: NDArrayU8 = np.array(
+        [
+            category_id & 255,
+            (truncated << 3) + (occluded << 2) + (crowd << 1) + ignored,
+            ann_id >> 8,
+            ann_id & 255,
+        ],
+        dtype=np.uint8, )
+    return color
+
+
+def set_lane_color(label: Label, category_id: int) -> NDArrayU8:
+    """Set the color for the lane given its attributes and category."""
+    attributes = label.attributes
+    if attributes is None:
+        lane_direction, lane_style = 0, 0
+    else:
+        lane_direction = LANE_DIRECTION_MAP[str(
+            attributes.get("laneDirection", "parallel"))]
+        lane_style = LANE_STYLE_MAP[str(attributes.get("laneStyle", "solid"))]
+
+    # Remap the bdd100k lane categories to the 4 classes used by PP-Vehicle:
+    # 0 background, 1 double yellow line, 2 solid line, 3 dashed line.
+    #value = category_id + (lane_direction << 5) + (lane_style << 4)
+    if lane_style == 0 and (category_id == 3 or category_id == 2):
+        value = 1  # solid double lines -> double yellow line
+    elif lane_style == 0:
+        value = 2  # other solid lines -> solid line
+    else:
+        value = 3  # dashed lines -> dashed line
+
+    color: NDArrayU8 = np.array([value], dtype=np.uint8)
+    return color
+
+
+def frames_to_masks(
+        nproc: int,
+        out_paths: List[str],
+        shapes: List[ImageSize],
+        colors_list: List[List[NDArrayU8]],
+        poly2ds_list: List[List[List[Poly2D]]],
+        with_instances: bool=True,
+        back_color: int=0,
+        closed: bool=True, ) -> None:
+    """Execute the mask conversion in parallel."""
+    with Pool(nproc) as pool:
+        pool.starmap(
+            partial(
+                frame_to_mask,
+                with_instances=with_instances,
+                back_color=back_color,
+                closed=closed, ),
+            tqdm(
+                zip(out_paths, shapes, colors_list, poly2ds_list),
+                total=len(out_paths), ), )
+
+
+def seg_to_masks(
+        frames: List[Frame],
+        out_base: str,
+        config: Config,
+        nproc: int=NPROC,
+        mode: str="sem_seg",
+        back_color: int=IGNORE_LABEL,
+        closed: bool=True, ) -> None:
+    """Convert segmentation poly2d to 1-channel masks."""
+    os.makedirs(out_base, exist_ok=True)
+    img_shape = config.imageSize
+
+    out_paths: List[str] = []
+    shapes: List[ImageSize] = []
+    colors_list: List[List[NDArrayU8]] = []
+    poly2ds_list: List[List[List[Poly2D]]] = []
+
+    categories = dict(
+        sem_seg=labels, drivable=drivables, lane_mark=lane_categories)[mode]
+    cat_name2id = {
+        cat.name: cat.trainId
+        for cat in categories if cat.trainId != IGNORE_LABEL
+    }
+
+    logger.info("Preparing annotations for Seg to Masks")
+
+    for image_anns in tqdm(frames):
+        # Mask in .png format
+        image_name = image_anns.name.replace(".jpg", ".png")
+        image_name = os.path.split(image_name)[-1]
+        out_path = os.path.join(out_base, image_name)
+        out_paths.append(out_path)
+
+        if img_shape is None:
+            if image_anns.size is not None:
+                img_shape = image_anns.size
+            else:
+                raise ValueError("Image shape not defined!")
+        shapes.append(img_shape)
+
+        colors: List[NDArrayU8] = []
+        poly2ds: List[List[Poly2D]] = []
+        colors_list.append(colors)
+        poly2ds_list.append(poly2ds)
+
+        if image_anns.labels is None:
+            continue
+
+        for label in image_anns.labels:
+            if label.category not in cat_name2id:
+                continue
+            if label.poly2d is None:
+                continue
+
+            category_id = cat_name2id[label.category]
+            if mode in ["sem_seg", "drivable"]:
+                color: NDArrayU8 = np.array([category_id], dtype=np.uint8)
+            else:
+                color = set_lane_color(label, category_id)
+
+            colors.append(color)
+            poly2ds.append(label.poly2d)
+
+    logger.info("Start Conversion for Seg to Masks")
+    frames_to_masks(
+        nproc,
+        out_paths,
+        shapes,
+        colors_list,
+        poly2ds_list,
+        with_instances=False,
+        back_color=back_color,
+        closed=closed, )
+
+
+ToMasksFunc = Callable[[List[Frame], str, Config, int], None]
+semseg_to_masks: ToMasksFunc = partial(
+    seg_to_masks, mode="sem_seg", back_color=IGNORE_LABEL, closed=True)
+drivable_to_masks: ToMasksFunc = partial(
+    seg_to_masks,
+    mode="drivable",
+    back_color=len(drivables) - 1,
+    closed=True, )
+lanemark_to_masks: ToMasksFunc = partial(
+    seg_to_masks, mode="lane_mark", back_color=IGNORE_LABEL, closed=False)
+
+
+def insseg_to_bitmasks(frames: List[Frame],
+                       out_base: str,
+                       config: Config,
+                       nproc: int=NPROC) -> None:
+    """Convert instance segmentation poly2d to bitmasks."""
+    os.makedirs(out_base, exist_ok=True)
+    img_shape = config.imageSize
+
+    out_paths: List[str] = []
+    shapes: List[ImageSize] = []
+    colors_list: List[List[NDArrayU8]] = []
+    poly2ds_list: List[List[List[Poly2D]]] = []
+
+    categories = get_leaf_categories(config.categories)
+    cat_name2id = {cat.name: i + 1 for i, cat in enumerate(categories)}
+
+    logger.info("Preparing annotations for InsSeg to Bitmasks")
+
+    for image_anns in tqdm(frames):
+        ann_id = 0
+
+        # Bitmask in .png format
+        image_name = image_anns.name.replace(".jpg", ".png")
+        image_name = os.path.split(image_name)[-1]
+        out_path = os.path.join(out_base, image_name)
+        out_paths.append(out_path)
+
+        if img_shape is None:
+            if image_anns.size is not None:
+                img_shape = image_anns.size
+            else:
+                raise ValueError("Image shape not defined!")
+        shapes.append(img_shape)
+
+        colors: List[NDArrayU8] = []
+        poly2ds: List[List[Poly2D]] = []
+        colors_list.append(colors)
+        poly2ds_list.append(poly2ds)
+
+        labels_ = image_anns.labels
+        if labels_ is None or len(labels_) == 0:
+            continue
+
+        # Sort ascending by score so higher-scored labels render later (on top).
+        if labels_[0].score is not None:
+            labels_ = sorted(labels_, key=lambda label: float(label.score))
+
+        for label in labels_:
+            if label.poly2d is None:
+                continue
+            if label.category not in cat_name2id:
+                continue
+
+            ann_id += 1
+            category_id = cat_name2id[label.category]
+            color = set_instance_color(label, category_id, ann_id)
+            colors.append(color)
+            poly2ds.append(label.poly2d)
+
+    logger.info("Start conversion for InsSeg to Bitmasks")
+    frames_to_masks(nproc, out_paths, shapes, colors_list, poly2ds_list)
+
+
+def panseg_to_bitmasks(frames: List[Frame],
+                       out_base: str,
+                       config: Config,
+                       nproc: int=NPROC) -> None:
+    """Convert panoptic segmentation poly2d to bitmasks."""
+    os.makedirs(out_base, exist_ok=True)
+    img_shape = config.imageSize
+
+    out_paths: List[str] = []
+    shapes: List[ImageSize] = []
+    colors_list: List[List[NDArrayU8]] = []
+    poly2ds_list: List[List[List[Poly2D]]] = []
+    cat_name2id = {cat.name: cat.id for cat in labels}
+
+    logger.info("Preparing annotations for PanSeg to Bitmasks")
+
+    for image_anns in tqdm(frames):
+        cur_ann_id = STUFF_NUM
+
+        # Bitmask in .png format
+        image_name = image_anns.name.replace(".jpg", ".png")
+        image_name = os.path.split(image_name)[-1]
+        out_path = os.path.join(out_base, image_name)
+        out_paths.append(out_path)
+
+        if img_shape is None:
+            if image_anns.size is not None:
+                img_shape = image_anns.size
+            else:
+                raise ValueError("Image shape not defined!")
+        shapes.append(img_shape)
+
+        colors: List[NDArrayU8] = []
+        poly2ds: List[List[Poly2D]] = []
+        colors_list.append(colors)
+        poly2ds_list.append(poly2ds)
+
+        labels_ = image_anns.labels
+        if labels_ is None or len(labels_) == 0:
+            continue
+
+        # Sort ascending by score so higher-scored labels render later (on top).
+        if labels_[0].score is not None:
+            labels_ = sorted(labels_, key=lambda label: float(label.score))
+
+        for label in labels_:
+            if label.poly2d is None:
+                continue
+            if label.category not in cat_name2id:
+                continue
+
+            category_id = cat_name2id[label.category]
+            if category_id == 0:
+                continue
+            # Stuff classes reuse the category id as annotation id; thing
+            # instances get fresh ids above STUFF_NUM.
+            if category_id <= STUFF_NUM:
+                ann_id = category_id
+            else:
+                cur_ann_id += 1
+                ann_id = cur_ann_id
+
+            color = set_instance_color(label, category_id, ann_id)
+            colors.append(color)
+            poly2ds.append(label.poly2d)
+
+    logger.info("Start conversion for PanSeg to Bitmasks")
+    frames_to_masks(nproc, out_paths, shapes, colors_list, poly2ds_list)
+
+
+def segtrack_to_bitmasks(frames: List[Frame],
+                         out_base: str,
+                         config: Config,
+                         nproc: int=NPROC) -> None:
+    """Convert segmentation tracking poly2d to bitmasks."""
+    frames_list = group_and_sort(frames)
+    img_shape = config.imageSize
+
+    out_paths: List[str] = []
+    shapes: List[ImageSize] = []
+    colors_list: List[List[NDArrayU8]] = []
+    poly2ds_list: List[List[List[Poly2D]]] = []
+
+    categories = get_leaf_categories(config.categories)
+    cat_name2id = {cat.name: i + 1 for i, cat in enumerate(categories)}
+
+    logger.info("Preparing annotations for SegTrack to Bitmasks")
+
+    for video_anns in tqdm(frames_list):
+        global_instance_id: int = 1
+        instance_id_maps: Dict[str, int] = {}
+
+        video_name = video_anns[0].videoName
+        out_dir = os.path.join(out_base, video_name)
+        if not os.path.isdir(out_dir):
+            os.makedirs(out_dir)
+
+        for image_anns in video_anns:
+            # Bitmask in .png format
+            image_name = image_anns.name.replace(".jpg", ".png")
+            image_name = os.path.split(image_name)[-1]
+            out_path = os.path.join(out_dir, image_name)
+            out_paths.append(out_path)
+
+            if img_shape is None:
+                if image_anns.size is not None:
+                    img_shape = image_anns.size
+                else:
+                    raise ValueError("Image shape not defined!")
+            shapes.append(img_shape)
+
+            colors: List[NDArrayU8] = []
+            poly2ds: List[List[Poly2D]] = []
+            colors_list.append(colors)
+            poly2ds_list.append(poly2ds)
+
+            labels_ = image_anns.labels
+            if labels_ is None or len(labels_) == 0:
+                continue
+
+            # Sort ascending by score so higher-scored labels render later (on top).
+            if labels_[0].score is not None:
+                labels_ = sorted(labels_, key=lambda label: float(label.score))
+
+            for label in labels_:
+                if label.poly2d is None:
+                    continue
+                if label.category not in cat_name2id:
+                    continue
+
+                instance_id, global_instance_id = get_bdd100k_instance_id(
+                    instance_id_maps, global_instance_id, label.id)
+                category_id = cat_name2id[label.category]
+                color = set_instance_color(label, category_id, instance_id)
+                colors.append(color)
+                poly2ds.append(label.poly2d)
+
+    logger.info("Start Conversion for SegTrack to Bitmasks")
+    frames_to_masks(nproc, out_paths, shapes, colors_list, poly2ds_list)
+
+
+def main() -> None:
+    """Main function."""
+    args = parse_args()
+    args.mode = "lane_mark"
+
+    os.environ["QT_QPA_PLATFORM"] = "offscreen"  # matplotlib offscreen render
+
+    convert_funcs: Dict[str, ToMasksFunc] = dict(
+        sem_seg=semseg_to_masks,
+        drivable=drivable_to_masks,
+        lane_mark=lanemark_to_masks,
+        pan_seg=panseg_to_bitmasks,
+        ins_seg=insseg_to_bitmasks,
+        seg_track=segtrack_to_bitmasks, )
+
+    dataset = load(args.input, args.nproc)
+    if args.config is not None:
+        bdd100k_config = load_bdd100k_config(args.config)
+    elif dataset.config is not None:
+        bdd100k_config = BDD100KConfig(config=dataset.config)
+    else:
+        bdd100k_config = load_bdd100k_config(args.mode)
+
+    if args.mode in ["ins_seg", "seg_track"]:
+        frames = bdd100k_to_scalabel(dataset.frames, bdd100k_config)
+    else:
+        frames = dataset.frames
+
+    convert_funcs[args.mode](frames, args.output, bdd100k_config.scalabel,
+                             args.nproc)
+
+    logger.info("Finished!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/docs/advanced_tutorials/customization/ppvehicle_violation.md b/docs/advanced_tutorials/customization/ppvehicle_violation.md
new file mode 100644
index 000000000..b82fe97d3
--- /dev/null
+++ b/docs/advanced_tutorials/customization/ppvehicle_violation.md
@@ -0,0 +1,235 @@
+简体中文 | [English](./ppvehicle_violation_en.md)
+
+# 车辆违章任务二次开发
+
+车辆违章任务的二次开发主要集中于车道线分割模型。该模型由PP-LiteSeg在车道线数据集bdd100k上fine-tune得到,过程可参考[PP-LiteSeg](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.7/configs/pp_liteseg/README.md)。
+
+## 数据准备
+
+ppvehicle违法分析将车道线类别分为4类:
+```
+0 背景
+1 双黄线
+2 实线
+3 虚线
+```
+
+1. 对于bdd100k数据集,可以结合我们提供的处理脚本[lane_to_mask.py](../../../deploy/pipeline/tools/lane_to_mask.py)和bdd100k官方[repo](https://github.com/bdd100k/bdd100k)将数据处理成分割需要的数据格式。
+
+```
+#首先执行以下命令clone bdd100k库:
+git clone https://github.com/bdd100k/bdd100k.git
+
+#拷贝lane_to_mask.py到bdd100k目录
+cp PaddleDetection/deploy/pipeline/tools/lane_to_mask.py bdd100k/
+
+#准备bdd100k环境
+cd bdd100k && pip install -r requirements.txt
+
+#数据转换
+python lane_to_mask.py -i dataset/labels/lane/polygons/lane_train.json -o /output_path
+
+# -i bdd100k数据集label的json路径
+# -o 生成的mask图像路径
+```
+
+2. 整理数据,按如下格式存放数据
+```
+dataset_root
+    |
+    |--images
+    |  |--train
+    |  |  |--image1.jpg
+    |  |  |--image2.jpg
+    |  |  |--...
+    |  |--val
+    |  |  |--image3.jpg
+    |  |  |--image4.jpg
+    |  |  |--...
+    |  |--test
+    |  |  |--image5.jpg
+    |  |  |--image6.jpg
+    |  |  |--...
+    |
+    |--labels
+    |  |--train
+    |  |  |--label1.jpg
+    |  |  |--label2.jpg
+    |  |  |--...
+    |  |--val
+    |  |  |--label3.jpg
+    |  |  |--label4.jpg
+    |  |  |--...
+    |  |--test
+    |  |  |--label5.jpg
+    |  |  |--label6.jpg
+    |  |  |--...
+    |
+```
+
+运行[create_dataset_list.py](../../../deploy/pipeline/tools/create_dataset_list.py)生成txt文件:
+```
+# 第一个参数为数据根目录;--type为数据类型,支持cityscapes、custom
+python create_dataset_list.py dataset_root --type custom
+```
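+
+生成的txt文件中,每行为一对以分隔符连接的图像、标签相对路径。以上述目录结构为例,train.txt的内容大致如下(具体文件名以实际数据为准):
+
+```
+images/train/image1.jpg labels/train/label1.jpg
+images/train/image2.jpg labels/train/label2.jpg
+```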
+
+其他数据以及数据标注,可参考PaddleSeg[准备自定义数据集](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.7/docs/data/marker/marker_cn.md)
+
+
+## 模型训练
+
+首先执行以下命令clone PaddleSeg库代码到训练机器:
+```
+git clone https://github.com/PaddlePaddle/PaddleSeg.git
+```
+
+安装相关依赖环境:
+```
+cd PaddleSeg
+pip install -r requirements.txt
+```
+
+### 准备配置文件
+详细可参考PaddleSeg[准备配置文件](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.7/docs/config/pre_config_cn.md)。
+本例以pp_liteseg_stdc2_bdd100k_1024x512.yml为例:
+
+```
+batch_size: 16
+iters: 50000
+
+train_dataset:
+  type: Dataset
+  dataset_root: data/bdd100k          #数据集路径
+  train_path: data/bdd100k/train.txt  #数据集训练txt文件
+  num_classes: 4                      #ppvehicle将车道线分为4类
+  mode: train
+  transforms:
+    - type: ResizeStepScaling
+      min_scale_factor: 0.5
+      max_scale_factor: 2.0
+      scale_step_size: 0.25
+    - type: RandomPaddingCrop
+      crop_size: [512, 1024]
+    - type: RandomHorizontalFlip
+    - type: RandomAffine
+    - type: RandomDistort
+      brightness_range: 0.5
+      contrast_range: 0.5
+      saturation_range: 0.5
+    - type: Normalize
+
+val_dataset:
+  type: Dataset
+  dataset_root: data/bdd100k      #数据集路径
+  val_path: data/bdd100k/val.txt  #数据集验证集txt文件
+  num_classes: 4
+  mode: val
+  transforms:
+    - type: Normalize
+
+optimizer:
+  type: sgd
+  momentum: 0.9
+  weight_decay: 4.0e-5
+
+lr_scheduler:
+  type: PolynomialDecay
+  learning_rate: 0.01
+  end_lr: 0
+  power: 0.9
+
+loss:
+  types:
+    - type: MixedLoss
+      losses:
+        - type: CrossEntropyLoss
+        - type: LovaszSoftmaxLoss
+      coef: [0.6, 0.4]
+    - type: MixedLoss
+      losses:
+        - type: CrossEntropyLoss
+        - type: LovaszSoftmaxLoss
+      coef: [0.6, 0.4]
+    - type: MixedLoss
+      losses:
+        - type: CrossEntropyLoss
+        - type: LovaszSoftmaxLoss
+      coef: [0.6, 0.4]
+  coef: [1, 1, 1]
+
+
+model:
+  type: PPLiteSeg
+  backbone:
+    type: STDC2
+    pretrained: https://bj.bcebos.com/paddleseg/dygraph/PP_STDCNet2.tar.gz #预训练模型
+```
+
+### 执行训练
+
+```
+#单卡训练
+export CUDA_VISIBLE_DEVICES=0 # Linux上设置1张可用的卡
+# set CUDA_VISIBLE_DEVICES=0 # Windows上设置1张可用的卡
+
+python train.py \
+       --config configs/pp_liteseg/pp_liteseg_stdc2_bdd100k_1024x512.yml \
+       --do_eval \
+       --use_vdl \
+       --save_interval 500 \
+       --save_dir output
+```
+
+### 训练参数解释
+```
+--do_eval            是否在保存模型时启动评估,启动时将会根据mIoU保存最佳模型至best_model
+--use_vdl            是否开启visualdl记录训练数据
+--save_interval 500  模型保存的间隔步数
+--save_dir output    模型输出路径
+```
+
+### 多卡训练
+如果要使用多卡训练,需要将环境变量CUDA_VISIBLE_DEVICES指定为多张卡(不指定时默认使用所有GPU),并使用paddle.distributed.launch启动训练脚本(Windows下不支持nccl,无法使用多卡训练):
+
+```
+export CUDA_VISIBLE_DEVICES=0,1,2,3 # 设置4张可用的卡
+python -m paddle.distributed.launch train.py \
+       --config configs/pp_liteseg/pp_liteseg_stdc2_bdd100k_1024x512.yml \
+       --do_eval \
+       --use_vdl \
+       --save_interval 500 \
+       --save_dir output
+```
+
+### 模型评估
+
+训练完成后可以执行以下命令进行性能评估:
+```
+#单卡评估
+python val.py \
+       --config configs/pp_liteseg/pp_liteseg_stdc2_bdd100k_1024x512.yml \
+       --model_path output/iter_1000/model.pdparams
+```
+
+
+### 模型导出
+
+使用下述命令将训练好的模型导出为预测部署模型。
+
+```
+python export.py \
+       --config configs/pp_liteseg/pp_liteseg_stdc2_bdd100k_1024x512.yml \
+       --model_path output/iter_1000/model.pdparams \
+       --save_dir output/inference_model
+```
+
+
+使用时,在PP-Vehicle的配置文件`./deploy/pipeline/config/infer_cfg_ppvehicle.yml`中修改`LANE_SEG`模块的`model_dir`项:
+```
+LANE_SEG:
+  lane_seg_config: deploy/pipeline/config/lane_seg_config.yml
+  model_dir: output/inference_model
+```
+
+至此,即完成车道线分割模型的更新。
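+
+替换模型后,可参考如下命令运行PP-Vehicle pipeline验证效果(示例命令,其中test.mp4为假设的测试视频,实际参数请以PP-Vehicle文档为准):
+
+```
+python deploy/pipeline/pipeline.py \
+       --config deploy/pipeline/config/infer_cfg_ppvehicle.yml \
+       --video_file=test.mp4 \
+       --device=gpu
+```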
diff --git a/docs/advanced_tutorials/customization/ppvehicle_violation_en.md b/docs/advanced_tutorials/customization/ppvehicle_violation_en.md
new file mode 100644
index 000000000..9b96e8a60
--- /dev/null
+++ b/docs/advanced_tutorials/customization/ppvehicle_violation_en.md
@@ -0,0 +1,240 @@
+English | [简体中文](./ppvehicle_violation.md)
+
+# Customized Vehicle Violation
+
+Customized development of the vehicle violation task mainly focuses on the lane line segmentation model. The model is obtained by fine-tuning PP-LiteSeg on the lane line dataset bdd100k; see [PP-LiteSeg](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.7/configs/pp_liteseg/README.md) for details.
+
+## Data preparation
+
+ppvehicle violation analysis divides lane lines into 4 categories:
+```
+0 background
+1 double yellow line
+2 solid line
+3 dashed line
+```
+
+1. For the bdd100k dataset, you can use our processing script [lane_to_mask.py](../../../deploy/pipeline/tools/lane_to_mask.py) together with the official bdd100k [repo](https://github.com/bdd100k/bdd100k) to convert the data into the format required for segmentation.
+
+```
+# clone bdd100k:
+git clone https://github.com/bdd100k/bdd100k.git
+
+# copy lane_to_mask.py to bdd100k/
+cp PaddleDetection/deploy/pipeline/tools/lane_to_mask.py bdd100k/
+
+# prepare the bdd100k environment
+cd bdd100k && pip install -r requirements.txt
+
+# convert bdd100k lane labels to masks
+python lane_to_mask.py -i dataset/labels/lane/polygons/lane_train.json -o /output_path
+
+# -i: path to the bdd100k label json
+# -o: output path for the generated masks
+```
+
+2. Organize the data and store it in the following format:
+```
+dataset_root
+    |
+    |--images
+    |  |--train
+    |  |  |--image1.jpg
+    |  |  |--image2.jpg
+    |  |  |--...
+    |  |--val
+    |  |  |--image3.jpg
+    |  |  |--image4.jpg
+    |  |  |--...
+    |  |--test
+    |  |  |--image5.jpg
+    |  |  |--image6.jpg
+    |  |  |--...
+    |
+    |--labels
+    |  |--train
+    |  |  |--label1.jpg
+    |  |  |--label2.jpg
+    |  |  |--...
+    |  |--val
+    |  |  |--label3.jpg
+    |  |  |--label4.jpg
+    |  |  |--...
+    |  |--test
+    |  |  |--label5.jpg
+    |  |  |--label6.jpg
+    |  |  |--...
+    |
+```
+
+Run [create_dataset_list.py](../../../deploy/pipeline/tools/create_dataset_list.py) to generate the txt files:
+
+```
+# the first argument is the dataset root; --type supports cityscapes and custom
+python create_dataset_list.py dataset_root --type custom
+```
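+
+Each line of the generated txt file pairs an image path with its label path, joined by the separator. With the directory layout above, train.txt would look roughly like this (the actual file names depend on your data):
+
+```
+images/train/image1.jpg labels/train/label1.jpg
+images/train/image2.jpg labels/train/label2.jpg
+```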
+
+For other data and data annotation, please refer to PaddleSeg [Prepare Custom Datasets](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.7/docs/data/marker/marker_cn.md)
+
+
+## Model training
+
+Clone PaddleSeg:
+```
+git clone https://github.com/PaddlePaddle/PaddleSeg.git
+```
+
+Prepare the environment:
+```
+cd PaddleSeg
+pip install -r requirements.txt
+```
+
+### Prepare the configuration file
+For details, please refer to PaddleSeg [prepare configuration file](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.7/docs/config/pre_config_cn.md).
+This example uses pp_liteseg_stdc2_bdd100k_1024x512.yml:
+
+```
+batch_size: 16
+iters: 50000
+
+train_dataset:
+  type: Dataset
+  dataset_root: data/bdd100k          #dataset path
+  train_path: data/bdd100k/train.txt  #dataset train txt
+  num_classes: 4                      #lane classes
+  mode: train
+  transforms:
+    - type: ResizeStepScaling
+      min_scale_factor: 0.5
+      max_scale_factor: 2.0
+      scale_step_size: 0.25
+    - type: RandomPaddingCrop
+      crop_size: [512, 1024]
+    - type: RandomHorizontalFlip
+    - type: RandomAffine
+    - type: RandomDistort
+      brightness_range: 0.5
+      contrast_range: 0.5
+      saturation_range: 0.5
+    - type: Normalize
+
+val_dataset:
+  type: Dataset
+  dataset_root: data/bdd100k      #dataset path
+  val_path: data/bdd100k/val.txt  #dataset val txt
+  num_classes: 4
+  mode: val
+  transforms:
+    - type: Normalize
+
+optimizer:
+  type: sgd
+  momentum: 0.9
+  weight_decay: 4.0e-5
+
+lr_scheduler:
+  type: PolynomialDecay
+  learning_rate: 0.01
+  end_lr: 0
+  power: 0.9
+
+loss:
+  types:
+    - type: MixedLoss
+      losses:
+        - type: CrossEntropyLoss
+        - type: LovaszSoftmaxLoss
+      coef: [0.6, 0.4]
+    - type: MixedLoss
+      losses:
+        - type: CrossEntropyLoss
+        - type: LovaszSoftmaxLoss
+      coef: [0.6, 0.4]
+    - type: MixedLoss
+      losses:
+        - type: CrossEntropyLoss
+        - type: LovaszSoftmaxLoss
+      coef: [0.6, 0.4]
+  coef: [1, 1, 1]
+
+
+model:
+  type: PPLiteSeg
+  backbone:
+    type: STDC2
+    pretrained: https://bj.bcebos.com/paddleseg/dygraph/PP_STDCNet2.tar.gz #pretrained model
+```
+
+### Training
+
+```
+# single-GPU training
+export CUDA_VISIBLE_DEVICES=0 # Linux
+# set CUDA_VISIBLE_DEVICES=0 # Windows
+python train.py \
+       --config configs/pp_liteseg/pp_liteseg_stdc2_bdd100k_1024x512.yml \
+       --do_eval \
+       --use_vdl \
+       --save_interval 500 \
+       --save_dir output
+```
+
+### Explanation of training parameters
+```
+--do_eval            Whether to run evaluation when saving the model; the best model (by mIoU) is saved to best_model
+--use_vdl            Whether to enable visualdl to record training data
+--save_interval 500  Number of steps between model saves
+--save_dir output    Model output path
+```
+
+### Multi-GPU training
+To train with multiple GPUs, set the environment variable CUDA_VISIBLE_DEVICES to the GPUs to use (all GPUs are used by default if it is not set) and start the training script with paddle.distributed.launch (multi-GPU training is not available on Windows, which does not support nccl):
+
+```
+export CUDA_VISIBLE_DEVICES=0,1,2,3 # 4 gpus
+python -m paddle.distributed.launch train.py \
+       --config configs/pp_liteseg/pp_liteseg_stdc2_bdd100k_1024x512.yml \
+       --do_eval \
+       --use_vdl \
+       --save_interval 500 \
+       --save_dir output
+```
+
+### Model evaluation
+
+After training, you can execute the following command for performance evaluation:
+```
+python val.py \
+       --config configs/pp_liteseg/pp_liteseg_stdc2_bdd100k_1024x512.yml \
+       --model_path output/iter_1000/model.pdparams
+```
+
+
+### Model export
+
+Use the following command to export the trained model as a prediction deployment model.
+
+```
+python export.py \
+       --config configs/pp_liteseg/pp_liteseg_stdc2_bdd100k_1024x512.yml \
+       --model_path output/iter_1000/model.pdparams \
+       --save_dir output/inference_model
+```
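+
+Before wiring the model into PP-Vehicle, you can optionally sanity-check the exported model with PaddleSeg's Python predictor (an illustrative command; `data/example.jpg` is a placeholder image and the exact flags may differ across PaddleSeg versions):
+
+```
+python deploy/python/infer.py \
+       --config output/inference_model/deploy.yaml \
+       --image_path data/example.jpg \
+       --save_dir output/result
+```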
+
+To use the exported model in PP-Vehicle, set `model_dir` under the `LANE_SEG` section of the config file `./deploy/pipeline/config/infer_cfg_ppvehicle.yml`:
+```
+LANE_SEG:
+  lane_seg_config: deploy/pipeline/config/lane_seg_config.yml
+  model_dir: output/inference_model
+```
+
+This completes the update of the lane line segmentation model.
-- 
GitLab