未验证 提交 f9447e98 编写于 作者: L LokeZhou 提交者: GitHub

Cherry pick pipeline (#7798)

* ppvehicle add customized docs, test=document_fix

* fix pipeline readme.md test=document_fix

* fix pipeline readme.md link test=document_fix

* add licensed to lane_to_mask.py

* ppvehicle_violation.md add args list test=document_fix
上级 c35db066
......@@ -155,10 +155,10 @@
- [快速开始](docs/tutorials/ppvehicle_press.md)
- [二次开发教程]
- [二次开发教程](../../docs/advanced_tutorials/customization/ppvehicle_violation.md)
#### 车辆逆行
- [快速开始](docs/tutorials/ppvehicle_retrograde.md)
- [二次开发教程]
- [二次开发教程](../../docs/advanced_tutorials/customization/ppvehicle_violation.md)
......@@ -152,10 +152,10 @@ Click to download the model, then unzip and save it in the `. /output_inference`
- [A quick start](docs/tutorials/ppvehicle_press_en.md)
- [Customized development tutorials]
- [Customized development tutorials](../../docs/advanced_tutorials/customization/ppvehicle_violation_en.md)
#### Vehicle Retrograde
- [A quick start](docs/tutorials/ppvehicle_retrograde_en.md)
- [Customized development tutorials]
- [Customized development tutorials](../../docs/advanced_tutorials/customization/ppvehicle_violation_en.md)
# coding: utf8
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os.path
import argparse
import warnings
def parse_args():
parser = argparse.ArgumentParser(
description='PaddleSeg generate file list on cityscapes or your customized dataset.'
)
parser.add_argument('dataset_root', help='dataset root directory', type=str)
parser.add_argument(
'--type',
help='dataset type: \n'
'- cityscapes \n'
'- custom(default)',
default="custom",
type=str)
parser.add_argument(
'--separator',
dest='separator',
help='file list separator',
default=" ",
type=str)
parser.add_argument(
'--folder',
help='the folder names of images and labels',
type=str,
nargs=2,
default=['images', 'labels'])
parser.add_argument(
'--second_folder',
help='the second-level folder names of train set, validation set, test set',
type=str,
nargs='*',
default=['train', 'val', 'test'])
parser.add_argument(
'--format',
help='data format of images and labels, e.g. jpg or png.',
type=str,
nargs=2,
default=['jpg', 'png'])
parser.add_argument(
'--postfix',
help='postfix of images or labels',
type=str,
nargs=2,
default=['', ''])
return parser.parse_args()
def get_files(image_or_label, dataset_split, args):
dataset_root = args.dataset_root
postfix = args.postfix
format = args.format
folder = args.folder
pattern = '*%s.%s' % (postfix[image_or_label], format[image_or_label])
search_files = os.path.join(dataset_root, folder[image_or_label],
dataset_split, pattern)
search_files2 = os.path.join(dataset_root, folder[image_or_label],
dataset_split, "*", pattern) # 包含子目录
search_files3 = os.path.join(dataset_root, folder[image_or_label],
dataset_split, "*", "*", pattern) # 包含三级目录
search_files4 = os.path.join(dataset_root, folder[image_or_label],
dataset_split, "*", "*", "*",
pattern) # 包含四级目录
search_files5 = os.path.join(dataset_root, folder[image_or_label],
dataset_split, "*", "*", "*", "*",
pattern) # 包含五级目录
filenames = glob.glob(search_files)
filenames2 = glob.glob(search_files2)
filenames3 = glob.glob(search_files3)
filenames4 = glob.glob(search_files4)
filenames5 = glob.glob(search_files5)
filenames = filenames + filenames2 + filenames3 + filenames4 + filenames5
return sorted(filenames)
def generate_list(args):
dataset_root = args.dataset_root
separator = args.separator
for dataset_split in args.second_folder:
print("Creating {}.txt...".format(dataset_split))
image_files = get_files(0, dataset_split, args)
label_files = get_files(1, dataset_split, args)
if not image_files:
img_dir = os.path.join(dataset_root, args.folder[0], dataset_split)
warnings.warn("No images in {} !!!".format(img_dir))
num_images = len(image_files)
if not label_files:
label_dir = os.path.join(dataset_root, args.folder[1],
dataset_split)
warnings.warn("No labels in {} !!!".format(label_dir))
num_label = len(label_files)
if num_images != num_label and num_label > 0:
raise Exception(
"Number of images = {} number of labels = {} \n"
"Either number of images is equal to number of labels, "
"or number of labels is equal to 0.\n"
"Please check your dataset!".format(num_images, num_label))
file_list = os.path.join(dataset_root, dataset_split + '.txt')
with open(file_list, "w") as f:
for item in range(num_images):
left = image_files[item].replace(dataset_root, '', 1)
if left[0] == os.path.sep:
left = left.lstrip(os.path.sep)
try:
right = label_files[item].replace(dataset_root, '', 1)
if right[0] == os.path.sep:
right = right.lstrip(os.path.sep)
line = left + separator + right + '\n'
except:
line = left + '\n'
f.write(line)
print(line)
if __name__ == '__main__':
args = parse_args()
generate_list(args)
# coding: utf8
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert poly2d to mask/bitmask."""
import os
from functools import partial
from multiprocessing import Pool
from typing import Callable, Dict, List
import matplotlib # type: ignore
import matplotlib.pyplot as plt # type: ignore
import numpy as np
from PIL import Image
from scalabel.common.parallel import NPROC
from scalabel.common.typing import NDArrayU8
from scalabel.label.io import group_and_sort, load
from scalabel.label.transforms import poly_to_patch
from scalabel.label.typing import Config, Frame, ImageSize, Label, Poly2D
from scalabel.label.utils import (
check_crowd,
check_ignored,
get_leaf_categories, )
from tqdm import tqdm
from bdd100k.common.logger import logger
from bdd100k.common.typing import BDD100KConfig
from bdd100k.common.utils import get_bdd100k_instance_id, load_bdd100k_config
from bdd100k.label.label import drivables, labels, lane_categories
from bdd100k.label.to_coco import parse_args
from bdd100k.label.to_scalabel import bdd100k_to_scalabel
IGNORE_LABEL = 255
STUFF_NUM = 30
LANE_DIRECTION_MAP = {"parallel": 0, "vertical": 1}
LANE_STYLE_MAP = {"solid": 0, "dashed": 1}
def frame_to_mask(
out_path: str,
shape: ImageSize,
colors: List[NDArrayU8],
poly2ds: List[List[Poly2D]],
with_instances: bool=True,
back_color: int=0,
closed: bool=True, ) -> None:
"""Converting a frame of poly2ds to mask/bitmask."""
assert len(colors) == len(poly2ds)
height, width = shape.height, shape.width
assert back_color >= 0
if with_instances:
img: NDArrayU8 = (
np.ones(
[height, width, 4], dtype=np.uint8) * back_color # type: ignore
)
else:
img = (
np.ones(
[height, width, 1], dtype=np.uint8) * back_color # type: ignore
)
if len(colors) == 0:
pil_img = Image.fromarray(img.squeeze())
pil_img.save(out_path)
matplotlib.use("Agg")
fig = plt.figure(facecolor="0")
fig.set_size_inches((width / fig.get_dpi()), height / fig.get_dpi())
ax = fig.add_axes([0, 0, 1, 1])
ax.axis("off")
ax.set_xlim(0, width)
ax.set_ylim(0, height)
ax.set_facecolor((0, 0, 0, 0))
ax.invert_yaxis()
for i, poly2d in enumerate(poly2ds):
for poly in poly2d:
ax.add_patch(
poly_to_patch(
poly.vertices,
poly.types,
# (0, 0, 0) for the background
color=(
((i + 1) >> 8) / 255.0,
((i + 1) % 255) / 255.0,
0.0, ),
closed=closed, ))
fig.canvas.draw()
out: NDArrayU8 = np.frombuffer(fig.canvas.tostring_rgb(), np.uint8)
out = out.reshape((height, width, -1)).astype(np.int32)
out = (out[..., 0] << 8) + out[..., 1]
plt.close()
for i, color in enumerate(colors):
# 0 is for the background
img[out == i + 1] = color
img[img == 255] = 0
pil_img = Image.fromarray(img.squeeze())
pil_img.save(out_path)
def set_instance_color(label: Label, category_id: int,
ann_id: int) -> NDArrayU8:
"""Set the color for an instance given its attributes and ID."""
attributes = label.attributes
if attributes is None:
truncated, occluded, crowd, ignored = 0, 0, 0, 0
else:
truncated = int(attributes.get("truncated", False))
occluded = int(attributes.get("occluded", False))
crowd = int(check_crowd(label))
ignored = int(check_ignored(label))
color: NDArrayU8 = np.array(
[
category_id & 255,
(truncated << 3) + (occluded << 2) + (crowd << 1) + ignored,
ann_id >> 8,
ann_id & 255,
],
dtype=np.uint8, )
return color
def set_lane_color(label: Label, category_id: int) -> NDArrayU8:
"""Set the color for the lane given its attributes and category."""
attributes = label.attributes
if attributes is None:
lane_direction, lane_style = 0, 0
else:
lane_direction = LANE_DIRECTION_MAP[str(
attributes.get("laneDirection", "parallel"))]
lane_style = LANE_STYLE_MAP[str(attributes.get("laneStyle", "solid"))]
#value = category_id + (lane_direction << 5) + (lane_style << 4)
value = category_id
if lane_style == 0 and (category_id == 3 or category_id == 2):
value = 1
if lane_style == 0:
value = 2
else:
value = 3
color: NDArrayU8 = np.array([value], dtype=np.uint8)
return color
def frames_to_masks(
nproc: int,
out_paths: List[str],
shapes: List[ImageSize],
colors_list: List[List[NDArrayU8]],
poly2ds_list: List[List[List[Poly2D]]],
with_instances: bool=True,
back_color: int=0,
closed: bool=True, ) -> None:
"""Execute the mask conversion in parallel."""
with Pool(nproc) as pool:
pool.starmap(
partial(
frame_to_mask,
with_instances=with_instances,
back_color=back_color,
closed=closed, ),
tqdm(
zip(out_paths, shapes, colors_list, poly2ds_list),
total=len(out_paths), ), )
def seg_to_masks(
frames: List[Frame],
out_base: str,
config: Config,
nproc: int=NPROC,
mode: str="sem_seg",
back_color: int=IGNORE_LABEL,
closed: bool=True, ) -> None:
"""Converting segmentation poly2d to 1-channel masks."""
os.makedirs(out_base, exist_ok=True)
img_shape = config.imageSize
out_paths: List[str] = []
shapes: List[ImageSize] = []
colors_list: List[List[NDArrayU8]] = []
poly2ds_list: List[List[List[Poly2D]]] = []
categories = dict(
sem_seg=labels, drivable=drivables, lane_mark=lane_categories)[mode]
cat_name2id = {
cat.name: cat.trainId
for cat in categories if cat.trainId != IGNORE_LABEL
}
logger.info("Preparing annotations for Semseg to Bitmasks")
for image_anns in tqdm(frames):
# Mask in .png format
image_name = image_anns.name.replace(".jpg", ".png")
image_name = os.path.split(image_name)[-1]
out_path = os.path.join(out_base, image_name)
out_paths.append(out_path)
if img_shape is None:
if image_anns.size is not None:
img_shape = image_anns.size
else:
raise ValueError("Image shape not defined!")
shapes.append(img_shape)
colors: List[NDArrayU8] = []
poly2ds: List[List[Poly2D]] = []
colors_list.append(colors)
poly2ds_list.append(poly2ds)
if image_anns.labels is None:
continue
for label in image_anns.labels:
if label.category not in cat_name2id:
continue
if label.poly2d is None:
continue
category_id = cat_name2id[label.category]
if mode in ["sem_seg", "drivable"]:
color: NDArrayU8 = np.array([category_id], dtype=np.uint8)
else:
color = set_lane_color(label, category_id)
colors.append(color)
poly2ds.append(label.poly2d)
logger.info("Start Conversion for Seg to Masks")
frames_to_masks(
nproc,
out_paths,
shapes,
colors_list,
poly2ds_list,
with_instances=False,
back_color=back_color,
closed=closed, )
ToMasksFunc = Callable[[List[Frame], str, Config, int], None]
semseg_to_masks: ToMasksFunc = partial(
seg_to_masks, mode="sem_seg", back_color=IGNORE_LABEL, closed=True)
drivable_to_masks: ToMasksFunc = partial(
seg_to_masks,
mode="drivable",
back_color=len(drivables) - 1,
closed=True, )
lanemark_to_masks: ToMasksFunc = partial(
seg_to_masks, mode="lane_mark", back_color=IGNORE_LABEL, closed=False)
def insseg_to_bitmasks(frames: List[Frame],
out_base: str,
config: Config,
nproc: int=NPROC) -> None:
"""Converting instance segmentation poly2d to bitmasks."""
os.makedirs(out_base, exist_ok=True)
img_shape = config.imageSize
out_paths: List[str] = []
shapes: List[ImageSize] = []
colors_list: List[List[NDArrayU8]] = []
poly2ds_list: List[List[List[Poly2D]]] = []
categories = get_leaf_categories(config.categories)
cat_name2id = {cat.name: i + 1 for i, cat in enumerate(categories)}
logger.info("Preparing annotations for InsSeg to Bitmasks")
for image_anns in tqdm(frames):
ann_id = 0
# Bitmask in .png format
image_name = image_anns.name.replace(".jpg", ".png")
image_name = os.path.split(image_name)[-1]
out_path = os.path.join(out_base, image_name)
out_paths.append(out_path)
if img_shape is None:
if image_anns.size is not None:
img_shape = image_anns.size
else:
raise ValueError("Image shape not defined!")
shapes.append(img_shape)
colors: List[NDArrayU8] = []
poly2ds: List[List[Poly2D]] = []
colors_list.append(colors)
poly2ds_list.append(poly2ds)
labels_ = image_anns.labels
if labels_ is None or len(labels_) == 0:
continue
# Scores higher, rendering later
if labels_[0].score is not None:
labels_ = sorted(labels_, key=lambda label: float(label.score))
for label in labels_:
if label.poly2d is None:
continue
if label.category not in cat_name2id:
continue
ann_id += 1
category_id = cat_name2id[label.category]
color = set_instance_color(label, category_id, ann_id)
colors.append(color)
poly2ds.append(label.poly2d)
logger.info("Start conversion for InsSeg to Bitmasks")
frames_to_masks(nproc, out_paths, shapes, colors_list, poly2ds_list)
def panseg_to_bitmasks(frames: List[Frame],
out_base: str,
config: Config,
nproc: int=NPROC) -> None:
"""Converting panoptic segmentation poly2d to bitmasks."""
os.makedirs(out_base, exist_ok=True)
img_shape = config.imageSize
out_paths: List[str] = []
shapes: List[ImageSize] = []
colors_list: List[List[NDArrayU8]] = []
poly2ds_list: List[List[List[Poly2D]]] = []
cat_name2id = {cat.name: cat.id for cat in labels}
logger.info("Preparing annotations for InsSeg to Bitmasks")
for image_anns in tqdm(frames):
cur_ann_id = STUFF_NUM
# Bitmask in .png format
image_name = image_anns.name.replace(".jpg", ".png")
image_name = os.path.split(image_name)[-1]
out_path = os.path.join(out_base, image_name)
out_paths.append(out_path)
if img_shape is None:
if image_anns.size is not None:
img_shape = image_anns.size
else:
raise ValueError("Image shape not defined!")
shapes.append(img_shape)
colors: List[NDArrayU8] = []
poly2ds: List[List[Poly2D]] = []
colors_list.append(colors)
poly2ds_list.append(poly2ds)
labels_ = image_anns.labels
if labels_ is None or len(labels_) == 0:
continue
# Scores higher, rendering later
if labels_[0].score is not None:
labels_ = sorted(labels_, key=lambda label: float(label.score))
for label in labels_:
if label.poly2d is None:
continue
if label.category not in cat_name2id:
continue
category_id = cat_name2id[label.category]
if category_id == 0:
continue
if category_id <= STUFF_NUM:
ann_id = category_id
else:
cur_ann_id += 1
ann_id = cur_ann_id
color = set_instance_color(label, category_id, ann_id)
colors.append(color)
poly2ds.append(label.poly2d)
logger.info("Start conversion for PanSeg to Bitmasks")
frames_to_masks(nproc, out_paths, shapes, colors_list, poly2ds_list)
def segtrack_to_bitmasks(frames: List[Frame],
out_base: str,
config: Config,
nproc: int=NPROC) -> None:
"""Converting segmentation tracking poly2d to bitmasks."""
frames_list = group_and_sort(frames)
img_shape = config.imageSize
out_paths: List[str] = []
shapes: List[ImageSize] = []
colors_list: List[List[NDArrayU8]] = []
poly2ds_list: List[List[List[Poly2D]]] = []
categories = get_leaf_categories(config.categories)
cat_name2id = {cat.name: i + 1 for i, cat in enumerate(categories)}
logger.info("Preparing annotations for SegTrack to Bitmasks")
for video_anns in tqdm(frames_list):
global_instance_id: int = 1
instance_id_maps: Dict[str, int] = {}
video_name = video_anns[0].videoName
out_dir = os.path.join(out_base, video_name)
if not os.path.isdir(out_dir):
os.makedirs(out_dir)
for image_anns in video_anns:
# Bitmask in .png format
image_name = image_anns.name.replace(".jpg", ".png")
image_name = os.path.split(image_name)[-1]
out_path = os.path.join(out_dir, image_name)
out_paths.append(out_path)
if img_shape is None:
if image_anns.size is not None:
img_shape = image_anns.size
else:
raise ValueError("Image shape not defined!")
shapes.append(img_shape)
colors: List[NDArrayU8] = []
poly2ds: List[List[Poly2D]] = []
colors_list.append(colors)
poly2ds_list.append(poly2ds)
labels_ = image_anns.labels
if labels_ is None or len(labels_) == 0:
continue
# Scores higher, rendering later
if labels_[0].score is not None:
labels_ = sorted(labels_, key=lambda label: float(label.score))
for label in labels_:
if label.poly2d is None:
continue
if label.category not in cat_name2id:
continue
instance_id, global_instance_id = get_bdd100k_instance_id(
instance_id_maps, global_instance_id, label.id)
category_id = cat_name2id[label.category]
color = set_instance_color(label, category_id, instance_id)
colors.append(color)
poly2ds.append(label.poly2d)
logger.info("Start Conversion for SegTrack to Bitmasks")
frames_to_masks(nproc, out_paths, shapes, colors_list, poly2ds_list)
def main() -> None:
"""Main function."""
args = parse_args()
args.mode = "lane_mark"
os.environ["QT_QPA_PLATFORM"] = "offscreen" # matplotlib offscreen render
convert_funcs: Dict[str, ToMasksFunc] = dict(
sem_seg=semseg_to_masks,
drivable=drivable_to_masks,
lane_mark=lanemark_to_masks,
pan_seg=panseg_to_bitmasks,
ins_seg=insseg_to_bitmasks,
seg_track=segtrack_to_bitmasks, )
dataset = load(args.input, args.nproc)
if args.config is not None:
bdd100k_config = load_bdd100k_config(args.config)
elif dataset.config is not None:
bdd100k_config = BDD100KConfig(config=dataset.config)
else:
bdd100k_config = load_bdd100k_config(args.mode)
if args.mode in ["ins_seg", "seg_track"]:
frames = bdd100k_to_scalabel(dataset.frames, bdd100k_config)
else:
frames = dataset.frames
convert_funcs[args.mode](frames, args.output, bdd100k_config.scalabel,
args.nproc)
logger.info("Finished!")
if __name__ == "__main__":
main()
简体中文 | [English](./ppvehicle_violation_en.md)
# 车辆违章任务二次开发
车辆违章任务的二次开发,主要集中于车道线分割模型任务。采用PP-LiteSeg模型在车道线数据集bdd100k,上进行fine-tune得到,过程参考[PP-LiteSeg](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.7/configs/pp_liteseg/README.md)
## 数据准备
ppvehicle违法分析将车道线类别分为4类
```
0 背景
1 双黄线
2 实线
3 虚线
```
1. 对于bdd100k数据集,可以结合我们的提供的处理脚本[lane_to_mask.py](../../../deploy/pipeline/tools/lane_to_mask.py)和bdd100k官方[repo](https://github.com/bdd100k/bdd100k)将数据处理成分割需要的数据格式.
```
#首先执行以下命令clone bdd100k库:
git clone https://github.com/bdd100k/bdd100k.git
#拷贝lane_to_mask.py到bdd100k目录
cp PaddleDetection/deploy/pipeline/tools/lane_to_mask.py bdd100k/
#准备bdd100k环境
cd bdd100k && pip install -r requirements.txt
#数据转换
python lane_to_mask.py -i dataset/labels/lane/polygons/lane_train.json -o /output_path
# -i bdd100k数据集label的json路径,
# -o 生成的mask图像路径
```
2. 整理数据,按如下格式存放数据
```
dataset_root
|
|--images
| |--train
| |--image1.jpg
| |--image2.jpg
| |--...
| |--val
| |--image3.jpg
| |--image4.jpg
| |--...
| |--test
| |--image5.jpg
| |--image6.jpg
| |--...
|
|--labels
| |--train
| |--label1.jpg
| |--label2.jpg
| |--...
| |--val
| |--label3.jpg
| |--label4.jpg
| |--...
| |--test
| |--label5.jpg
| |--label6.jpg
| |--...
|
```
运行[create_dataset_list.py](../../../deploy/pipeline/tools/create_dataset_list.py)生成txt文件
```
python create_dataset_list.py <dataset_root> #数据根目录
--type custom #数据类型,支持cityscapes、custom
```
其他数据以及数据标注,可参考PaddleSeg[准备自定义数据集](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.7/docs/data/marker/marker_cn.md)
## 模型训练
首先执行以下命令clone PaddleSeg库代码到训练机器:
```
git clone https://github.com/PaddlePaddle/PaddleSeg.git
```
安装相关依赖环境:
```
cd PaddleSeg
pip install -r requirements.txt
```
### 准备配置文件
详细可参考PaddleSeg[准备配置文件](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.7/docs/config/pre_config_cn.md).
本例用pp_liteseg_stdc2_bdd100k_1024x512.yml示例
```
batch_size: 16
iters: 50000
train_dataset:
type: Dataset
dataset_root: data/bdd100k #数据集路径
train_path: data/bdd100k/train.txt #数据集训练txt文件
num_classes: 4 #ppvehicle将道路分为4类
mode: train
transforms:
- type: ResizeStepScaling
min_scale_factor: 0.5
max_scale_factor: 2.0
scale_step_size: 0.25
- type: RandomPaddingCrop
crop_size: [512, 1024]
- type: RandomHorizontalFlip
- type: RandomAffine
- type: RandomDistort
brightness_range: 0.5
contrast_range: 0.5
saturation_range: 0.5
- type: Normalize
val_dataset:
type: Dataset
dataset_root: data/bdd100k #数据集路径
val_path: data/bdd100k/val.txt #数据集验证集txt文件
num_classes: 4
mode: val
transforms:
- type: Normalize
optimizer:
type: sgd
momentum: 0.9
weight_decay: 4.0e-5
lr_scheduler:
type: PolynomialDecay
learning_rate: 0.01 #0.01
end_lr: 0
power: 0.9
loss:
types:
- type: MixedLoss
losses:
- type: CrossEntropyLoss
- type: LovaszSoftmaxLoss
coef: [0.6, 0.4]
- type: MixedLoss
losses:
- type: CrossEntropyLoss
- type: LovaszSoftmaxLoss
coef: [0.6, 0.4]
- type: MixedLoss
losses:
- type: CrossEntropyLoss
- type: LovaszSoftmaxLoss
coef: [0.6, 0.4]
coef: [1, 1,1]
model:
type: PPLiteSeg
backbone:
type: STDC2
pretrained: https://bj.bcebos.com/paddleseg/dygraph/PP_STDCNet2.tar.gz #预训练模型
```
### 执行训练
```
#单卡训练
export CUDA_VISIBLE_DEVICES=0 # Linux上设置1张可用的卡
# set CUDA_VISIBLE_DEVICES=0 # Windows上设置1张可用的卡
python train.py \
--config configs/pp_liteseg/pp_liteseg_stdc2_bdd100k_1024x512.yml \
--do_eval \
--use_vdl \
--save_interval 500 \
--save_dir output
```
### 训练参数解释
```
--do_eval 是否在保存模型时启动评估, 启动时将会根据mIoU保存最佳模型至best_model
--use_vdl 是否开启visualdl记录训练数据
--save_interval 500 模型保存的间隔步数
--save_dir output 模型输出路径
```
## 2、多卡训练
如果想要使用多卡训练的话,需要将环境变量CUDA_VISIBLE_DEVICES指定为多卡(不指定时默认使用所有的gpu),并使用paddle.distributed.launch启动训练脚本(windows下由于不支持nccl,无法使用多卡训练):
```
export CUDA_VISIBLE_DEVICES=0,1,2,3 # 设置4张可用的卡
python -m paddle.distributed.launch train.py \
--config configs/pp_liteseg/pp_liteseg_stdc2_bdd100k_1024x512.yml \
--do_eval \
--use_vdl \
--save_interval 500 \
--save_dir output
```
训练完成后可以执行以下命令进行性能评估:
```
#单卡评估
python val.py \
--config configs/pp_liteseg/pp_liteseg_stdc2_bdd100k_1024x512.yml \
--model_path output/iter_1000/model.pdparams
```
### 模型导出
使用下述命令将训练好的模型导出为预测部署模型。
```
python export.py \
--config configs/pp_liteseg/pp_liteseg_stdc2_bdd100k_1024x512.yml \
--model_path output/iter_1000/model.pdparams \
--save_dir output/inference_model
```
使用时在PP-Vehicle中的配置文件`./deploy/pipeline/config/infer_cfg_ppvehicle.yml`中修改`LANE_SEG`模块中的`model_dir`项.
```
LANE_SEG:
lane_seg_config: deploy/pipeline/config/lane_seg_config.yml
model_dir: output/inference_model
```
然后可以使用-->至此即完成更新车道线分割模型任务。
English | [简体中文](./ppvehicle_violation.md)
# Customized Vehicle Violation
The secondary development of vehicle violation task mainly focuses on the task of lane line segmentation model. PP-LiteSeg model is used to get the lane line data set bdd100k through fine-tune. The process is referred to [PP-LiteSeg](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.7/configs/pp_liteseg/README.md)
## Data preparation
ppvehicle violation analysis divides the lane line into 4 categories
```
0 Background
1 double yellow line
2 Solid line
3 Dashed line
```
1. For the bdd100k data set, we can combine the processing script provided by [lane_to_mask.py](../../../deploy/pipeline/tools/lane_to_mask.py) and bdd100k [repo](https://github.com/bdd100k/bdd100k) to process the data into the data format required for segmentation.
```
# clone bdd100k:
git clone https://github.com/bdd100k/bdd100k.git
# copy lane_to_mask.py to bdd100k/
cp PaddleDetection/deploy/pipeline/tools/lane_to_mask.py bdd100k/
# preparation bdd100k env
cd bdd100k && pip install -r requirements.txt
#bdd100k to mask
python lane_to_mask.py -i dataset/labels/lane/polygons/lane_train.json -o /output_path
# -i means input path for bdd100k dataset label json,
# -o for output patn
```
2. Organize data and store data in the following format:
```
dataset_root
|
|--images
| |--train
| |--image1.jpg
| |--image2.jpg
| |--...
| |--val
| |--image3.jpg
| |--image4.jpg
| |--...
| |--test
| |--image5.jpg
| |--image6.jpg
| |--...
|
|--labels
| |--train
| |--label1.jpg
| |--label2.jpg
| |--...
| |--val
| |--label3.jpg
| |--label4.jpg
| |--...
| |--test
| |--label5.jpg
| |--label6.jpg
| |--...
|
```
run [create_dataset_list.py](../../../deploy/pipeline/tools/create_dataset_list.py) create txt file
```
python create_dataset_list.py <dataset_root> #dataset path
--type custom #dataset type,support cityscapes、custom
```
For other data and data annotation, please refer to PaddleSeg [Prepare Custom Datasets](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.7/docs/data/marker/marker_cn.md)
## model training
clone PaddleSeg:
```
git clone https://github.com/PaddlePaddle/PaddleSeg.git
```
prepapation env:
```
cd PaddleSeg
pip install -r requirements.txt
```
### Prepare configuration file
For details, please refer to PaddleSeg [prepare configuration file](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.7/docs/config/pre_config_cn.md).
exp: pp_liteseg_stdc2_bdd100k_1024x512.yml
```
batch_size: 16
iters: 50000
train_dataset:
type: Dataset
dataset_root: data/bdd100k #dataset path
train_path: data/bdd100k/train.txt #dataset train txt
num_classes: 4 #lane classes
mode: train
transforms:
- type: ResizeStepScaling
min_scale_factor: 0.5
max_scale_factor: 2.0
scale_step_size: 0.25
- type: RandomPaddingCrop
crop_size: [512, 1024]
- type: RandomHorizontalFlip
- type: RandomAffine
- type: RandomDistort
brightness_range: 0.5
contrast_range: 0.5
saturation_range: 0.5
- type: Normalize
val_dataset:
type: Dataset
dataset_root: data/bdd100k #dataset path
val_path: data/bdd100k/val.txt #dataset val txt
num_classes: 4
mode: val
transforms:
- type: Normalize
optimizer:
type: sgd
momentum: 0.9
weight_decay: 4.0e-5
lr_scheduler:
type: PolynomialDecay
learning_rate: 0.01 #0.01
end_lr: 0
power: 0.9
loss:
types:
- type: MixedLoss
losses:
- type: CrossEntropyLoss
- type: LovaszSoftmaxLoss
coef: [0.6, 0.4]
- type: MixedLoss
losses:
- type: CrossEntropyLoss
- type: LovaszSoftmaxLoss
coef: [0.6, 0.4]
- type: MixedLoss
losses:
- type: CrossEntropyLoss
- type: LovaszSoftmaxLoss
coef: [0.6, 0.4]
coef: [1, 1,1]
model:
type: PPLiteSeg
backbone:
type: STDC2
pretrained: https://bj.bcebos.com/paddleseg/dygraph/PP_STDCNet2.tar.gz #Pre-training model
```
### training model
```
#Single GPU training
export CUDA_VISIBLE_DEVICES=0 # Linux
# set CUDA_VISIBLE_DEVICES=0 # Windows
python train.py \
--config configs/pp_liteseg/pp_liteseg_stdc2_bdd100k_1024x512.yml \
--do_eval \
--use_vdl \
--save_interval 500 \
--save_dir output
```
### Explanation of training parameters
```
--do_eval Whether to start the evaluation when saving the model. When starting, the best model will be saved to best according to mIoU model
--use_vdl Whether to enable visualdl to record training data
--save_interval 500 Number of steps between model saving
--save_dir output Model output path
```
## 2、Multiple GPUs training
if you want to use multiple gpus training, you need to set the environment variable CUDA_VISIBLE_DEVICES is specified as multiple gpus (if not specified, all gpus will be used by default), and the training script will be started using paddle.distributed.launch (because nccl is not supported under windows, multi-card training cannot be used):
```
export CUDA_VISIBLE_DEVICES=0,1,2,3 # 4 gpus
python -m paddle.distributed.launch train.py \
--config configs/pp_liteseg/pp_liteseg_stdc2_bdd100k_1024x512.yml \
--do_eval \
--use_vdl \
--save_interval 500 \
--save_dir output
```
After training, you can execute the following commands for performance evaluation:
```
python val.py \
--config configs/pp_liteseg/pp_liteseg_stdc2_bdd100k_1024x512.yml \
--model_path output/iter_1000/model.pdparams
```
### Model export
Use the following command to export the trained model as a prediction deployment model.
```
python export.py \
--config configs/pp_liteseg/pp_liteseg_stdc2_bdd100k_1024x512.yml \
--model_path output/iter_1000/model.pdparams \
--save_dir output/inference_model
```
Profile in PP-Vehicle when used `./deploy/pipeline/config/infer_cfg_ppvehicle.yml` set `model_dir` in `LANE_SEG`.
```
LANE_SEG:
lane_seg_config: deploy/pipeline/config/lane_seg_config.yml
model_dir: output/inference_model
```
Then you can use -->to finish the task of updating the lane line segmentation model.
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册