diff --git a/contrib/LaneNet/README.md b/contrib/LaneNet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b86777305c160edae7a55349d719c9df2a2da4f9
--- /dev/null
+++ b/contrib/LaneNet/README.md
@@ -0,0 +1,138 @@
+# LaneNet Model Training Tutorial
+
+* This tutorial shows how to perform lane line detection with PaddleSeg.
+
+* Before reading this tutorial, please make sure you have read the PaddleSeg [Quick Start](../README.md#快速入门) and [Basic Features](../README.md#基础功能) chapters, so that you have a basic understanding of PaddleSeg.
+
+## Requirements
+
+* PaddlePaddle >= 1.7.0 (or the develop branch)
+* Python 3.5+
+
+Install the Python package dependencies with the following command, and make sure it has been executed at least once on this branch:
+```shell
+$ pip install -r requirements.txt
+```
+
+## 1. Prepare the Training Data
+
+We have prepared a preprocessed dataset, converted from the TuSimple lane detection dataset, which can be downloaded with the command below. You can also download the original dataset from this [page](https://github.com/TuSimple/tusimple-benchmark/issues/3).
+
+```shell
+python dataset/download_tusimple.py
+```
+
+Dataset directory structure:
+```
+LaneNet
+|-- dataset
+ |-- tusimple_lane_detection
+ |-- training
+ |-- gt_binary_image
+ |-- gt_image
+ |-- gt_instance_image
+ |-- train_part.txt
+ |-- val_part.txt
+```
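+
+Each line of `train_part.txt` / `val_part.txt` describes one sample as three paths separated by `DATASET.SEPARATOR` (a space, see the configuration in step 3): the original image, the binary segmentation label, and the instance label. A sketch of one line (file names are illustrative):
+
+```
+training/gt_image/0000.png training/gt_binary_image/0000.png training/gt_instance_image/0000.png
+```
+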
+## 2. Download the Pretrained Model
+
+Download the [VGG pretrained model](https://paddle-imagenet-models-name.bj.bcebos.com/VGG16_pretrained.tar) and place it in the ```pretrained_models``` folder.
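+
+A minimal sketch of this step, assuming a Linux shell with `wget` and `tar` available, and that the archive unpacks into a `VGG16_pretrained` directory matching `TRAIN.PRETRAINED_MODEL_DIR` in the configuration:
+
+```shell
+mkdir -p pretrained_models
+wget https://paddle-imagenet-models-name.bj.bcebos.com/VGG16_pretrained.tar
+tar -xf VGG16_pretrained.tar -C pretrained_models/
+```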
+
+
+## 3. Prepare the Configuration
+
+Next, we need to set up the configuration. For the purposes of this tutorial, it falls into three parts:
+
+* Dataset
+  * Training set root directory
+  * Training set file list
+  * Test set file list
+  * Validation set file list
+* Pretrained model
+  * Pretrained model name
+  * Backbone network of the pretrained model
+  * Pretrained model path
+* Others
+  * Learning rate
+  * Batch size
+  * ...
+
+Among the three, the pretrained model configuration matters most: if the model or BACKBONE is configured incorrectly, the pretrained parameters will not be loaded, which slows down convergence. The pretrained-model settings are as shown in step 2.
+
+The dataset configuration depends on where the data is stored; in this tutorial the data lives in `dataset/tusimple_lane_detection`.
+
+The remaining options can be tuned according to the dataset and your machine. Finally, we save a YAML configuration file with the following content to **configs/lanenet.yaml**:
+
+```yaml
+# Dataset configuration
+DATASET:
+ DATA_DIR: "./dataset/tusimple_lane_detection"
+  IMAGE_TYPE: "rgb" # choose rgb or rgba
+ NUM_CLASSES: 2
+ TEST_FILE_LIST: "./dataset/tusimple_lane_detection/training/val_part.txt"
+ TRAIN_FILE_LIST: "./dataset/tusimple_lane_detection/training/train_part.txt"
+ VAL_FILE_LIST: "./dataset/tusimple_lane_detection/training/val_part.txt"
+ SEPARATOR: " "
+
+# Pretrained model configuration
+MODEL:
+ MODEL_NAME: "lanenet"
+
+# Other configuration
+EVAL_CROP_SIZE: (512, 256)
+TRAIN_CROP_SIZE: (512, 256)
+AUG:
+  AUG_METHOD: "unpadding" # choose from unpadding, rangescaling and stepscaling
+ FIX_RESIZE_SIZE: (512, 256) # (width, height), for unpadding
+ MIRROR: False
+ RICH_CROP:
+ ENABLE: False
+BATCH_SIZE: 4
+TEST:
+ TEST_MODEL: "./saved_model/lanenet/final/"
+TRAIN:
+ MODEL_SAVE_DIR: "./saved_model/lanenet/"
+ PRETRAINED_MODEL_DIR: "./pretrained_models/VGG16_pretrained"
+ SNAPSHOT_EPOCH: 5
+SOLVER:
+ NUM_EPOCHS: 100
+ LR: 0.0005
+ LR_POLICY: "poly"
+ OPTIMIZER: "sgd"
+ WEIGHT_DECAY: 0.001
+```
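+
+These values can also be overridden from the command line without editing the file: the training, evaluation and visualization scripts accept trailing `KEY VALUE` pairs (see `utils/config.py` for all options), as the visualization command in step 6 demonstrates. For example (the values below are illustrative):
+
+```shell
+CUDA_VISIBLE_DEVICES=0 python -u train.py --cfg configs/lanenet.yaml --use_gpu \
+BATCH_SIZE 2 SOLVER.LR 0.001
+```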
+
+
+## 4. Start Training
+
+Start training with the following command. Checkpoints are written to `TRAIN.MODEL_SAVE_DIR` every `TRAIN.SNAPSHOT_EPOCH` epochs, and `--do_eval` evaluates each snapshot on the validation set:
+
+```shell
+CUDA_VISIBLE_DEVICES=0 python -u train.py --cfg configs/lanenet.yaml --use_gpu --use_mpio --do_eval
+```
+
+## 5. Evaluation
+
+After training completes, start evaluation with the following command:
+
+```shell
+CUDA_VISIBLE_DEVICES=0 python -u eval.py --use_gpu --cfg configs/lanenet.yaml
+```
+
+## 6. Visualization
+First download the file needed to convert between the front view and the bird's-eye view from this [link](https://paddleseg.bj.bcebos.com/resources/tusimple_ipm_remap.tar) and place it under ```./utils```. We also provide a trained model: download it from this [link](https://paddleseg.bj.bcebos.com/models/lanenet_vgg_tusimple.tar) and place it under ```./pretrained_models/```. Then run the following command for visualization:
+```shell
+CUDA_VISIBLE_DEVICES=0 python -u ./vis.py --cfg configs/lanenet.yaml --use_gpu --vis_dir vis_result \
+TEST.TEST_MODEL pretrained_models/LaneNet_vgg_tusimple/
+```
+
+Example visualization results:
+
+  Prediction result:
+  ![](imgs/0005_pred_lane.png)
+  Segmentation result:
+  ![](imgs/0005_pred_binary.png)
+  Lane instance prediction result:
+  ![](imgs/0005_pred_instance.png)
+
+
diff --git a/contrib/LaneNet/configs/lanenet.yaml b/contrib/LaneNet/configs/lanenet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1445e8803e638b2a44a7170c5020c4a0c56dcd67
--- /dev/null
+++ b/contrib/LaneNet/configs/lanenet.yaml
@@ -0,0 +1,52 @@
+EVAL_CROP_SIZE: (512, 256) # (width, height), for unpadding rangescaling and stepscaling
+TRAIN_CROP_SIZE: (512, 256) # (width, height), for unpadding rangescaling and stepscaling
+AUG:
+  AUG_METHOD: "unpadding" # choose from unpadding, rangescaling and stepscaling
+ FIX_RESIZE_SIZE: (512, 256) # (width, height), for unpadding
+ INF_RESIZE_VALUE: 500 # for rangescaling
+ MAX_RESIZE_VALUE: 600 # for rangescaling
+ MIN_RESIZE_VALUE: 400 # for rangescaling
+ MAX_SCALE_FACTOR: 2.0 # for stepscaling
+ MIN_SCALE_FACTOR: 0.5 # for stepscaling
+ SCALE_STEP_SIZE: 0.25 # for stepscaling
+ MIRROR: False
+ RICH_CROP:
+ ENABLE: False
+
+BATCH_SIZE: 4
+
+DATALOADER:
+ BUF_SIZE: 256
+ NUM_WORKERS: 4
+DATASET:
+ DATA_DIR: "./dataset/tusimple_lane_detection"
+  IMAGE_TYPE: "rgb" # choose rgb or rgba
+ NUM_CLASSES: 2
+ TEST_FILE_LIST: "./dataset/tusimple_lane_detection/training/val_part.txt"
+ TEST_TOTAL_IMAGES: 362
+ TRAIN_FILE_LIST: "./dataset/tusimple_lane_detection/training/train_part.txt"
+ TRAIN_TOTAL_IMAGES: 3264
+ VAL_FILE_LIST: "./dataset/tusimple_lane_detection/training/val_part.txt"
+ VAL_TOTAL_IMAGES: 362
+ SEPARATOR: " "
+ IGNORE_INDEX: 255
+
+FREEZE:
+ MODEL_FILENAME: "__model__"
+ PARAMS_FILENAME: "__params__"
+MODEL:
+ MODEL_NAME: "lanenet"
+ DEFAULT_NORM_TYPE: "bn"
+TEST:
+ TEST_MODEL: "./saved_model/lanenet/final/"
+TRAIN:
+ MODEL_SAVE_DIR: "./saved_model/lanenet/"
+ PRETRAINED_MODEL_DIR: "./pretrained_models/VGG16_pretrained"
+ SNAPSHOT_EPOCH: 1
+SOLVER:
+ NUM_EPOCHS: 100
+ LR: 0.0005
+ LR_POLICY: "poly"
+ OPTIMIZER: "sgd"
+ WEIGHT_DECAY: 0.001
+
diff --git a/contrib/LaneNet/data_aug.py b/contrib/LaneNet/data_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..bffb956c657e8d93edcffe5c7946e3b1437a1ef1
--- /dev/null
+++ b/contrib/LaneNet/data_aug.py
@@ -0,0 +1,83 @@
+# coding: utf8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import cv2
+import numpy as np
+from utils.config import cfg
+from models.model_builder import ModelPhase
+from pdseg.data_aug import get_random_scale, randomly_scale_image_and_label, random_rotation, \
+ rand_scale_aspect, hsv_color_jitter, rand_crop
+
+def resize(img, grt=None, grt_instance=None, mode=ModelPhase.TRAIN):
+ """
+ 改变图像及标签图像尺寸
+ AUG.AUG_METHOD为unpadding,所有模式均直接resize到AUG.FIX_RESIZE_SIZE的尺寸
+ AUG.AUG_METHOD为stepscaling, 按比例resize,训练时比例范围AUG.MIN_SCALE_FACTOR到AUG.MAX_SCALE_FACTOR,间隔为AUG.SCALE_STEP_SIZE,其他模式返回原图
+ AUG.AUG_METHOD为rangescaling,长边对齐,短边按比例变化,训练时长边对齐范围AUG.MIN_RESIZE_VALUE到AUG.MAX_RESIZE_VALUE,其他模式长边对齐AUG.INF_RESIZE_VALUE
+
+ Args:
+ img(numpy.ndarray): 输入图像
+ grt(numpy.ndarray): 标签图像,默认为None
+ mode(string): 模式, 默认训练模式,即ModelPhase.TRAIN
+
+ Returns:
+ resize后的图像和标签图
+
+ """
+
+ if cfg.AUG.AUG_METHOD == 'unpadding':
+ target_size = cfg.AUG.FIX_RESIZE_SIZE
+ img = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)
+ if grt is not None:
+ grt = cv2.resize(grt, target_size, interpolation=cv2.INTER_NEAREST)
+ if grt_instance is not None:
+ grt_instance = cv2.resize(grt_instance, target_size, interpolation=cv2.INTER_NEAREST)
+ elif cfg.AUG.AUG_METHOD == 'stepscaling':
+ if mode == ModelPhase.TRAIN:
+ min_scale_factor = cfg.AUG.MIN_SCALE_FACTOR
+ max_scale_factor = cfg.AUG.MAX_SCALE_FACTOR
+ step_size = cfg.AUG.SCALE_STEP_SIZE
+ scale_factor = get_random_scale(min_scale_factor, max_scale_factor,
+ step_size)
+ img, grt = randomly_scale_image_and_label(
+ img, grt, scale=scale_factor)
+ elif cfg.AUG.AUG_METHOD == 'rangescaling':
+ min_resize_value = cfg.AUG.MIN_RESIZE_VALUE
+ max_resize_value = cfg.AUG.MAX_RESIZE_VALUE
+ if mode == ModelPhase.TRAIN:
+ if min_resize_value == max_resize_value:
+ random_size = min_resize_value
+ else:
+ random_size = int(
+ np.random.uniform(min_resize_value, max_resize_value) + 0.5)
+ else:
+ random_size = cfg.AUG.INF_RESIZE_VALUE
+
+ value = max(img.shape[0], img.shape[1])
+ scale = float(random_size) / float(value)
+ img = cv2.resize(
+ img, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
+ if grt is not None:
+ grt = cv2.resize(
+ grt, (0, 0),
+ fx=scale,
+ fy=scale,
+ interpolation=cv2.INTER_NEAREST)
+ else:
+ raise Exception("Unexpect data augmention method: {}".format(
+ cfg.AUG.AUG_METHOD))
+
+ return img, grt, grt_instance
diff --git a/contrib/LaneNet/dataset/download_tusimple.py b/contrib/LaneNet/dataset/download_tusimple.py
new file mode 100644
index 0000000000000000000000000000000000000000..1549cafdea4bbc97aca0401d84cd8844165324c8
--- /dev/null
+++ b/contrib/LaneNet/dataset/download_tusimple.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import os
+
+LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
+TEST_PATH = os.path.join(LOCAL_PATH, "../../../", "test")
+sys.path.append(TEST_PATH)
+
+from test_utils import download_file_and_uncompress
+
+
+def download_tusimple_dataset(savepath, extrapath):
+ url = "https://paddleseg.bj.bcebos.com/dataset/tusimple_lane_detection.tar"
+ download_file_and_uncompress(
+ url=url, savepath=savepath, extrapath=extrapath)
+
+
+if __name__ == "__main__":
+ download_tusimple_dataset(LOCAL_PATH, LOCAL_PATH)
+ print("Dataset download finish!")
diff --git a/contrib/LaneNet/eval.py b/contrib/LaneNet/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..9256c4f024e7d15c9c018c4fe5930e5b7865c7e0
--- /dev/null
+++ b/contrib/LaneNet/eval.py
@@ -0,0 +1,182 @@
+# coding: utf8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+# GPU memory garbage collection optimization flags
+os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
+
+import sys
+
+cur_path = os.path.abspath(os.path.dirname(__file__))
+root_path = os.path.split(os.path.split(cur_path)[0])[0]
+LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
+SEG_PATH = os.path.join(LOCAL_PATH, "../../../")
+sys.path.append(SEG_PATH)
+sys.path.append(root_path)
+
+import time
+import argparse
+import functools
+import pprint
+import cv2
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+
+from utils.config import cfg
+from pdseg.utils.timer import Timer, calculate_eta
+from models.model_builder import build_model
+from models.model_builder import ModelPhase
+from reader import LaneNetDataset
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='PaddleSeg model evaluation')
+ parser.add_argument(
+ '--cfg',
+ dest='cfg_file',
+ help='Config file for training (and optionally testing)',
+ default=None,
+ type=str)
+ parser.add_argument(
+ '--use_gpu',
+ dest='use_gpu',
+ help='Use gpu or cpu',
+ action='store_true',
+ default=False)
+ parser.add_argument(
+ '--use_mpio',
+ dest='use_mpio',
+ help='Use multiprocess IO or not',
+ action='store_true',
+ default=False)
+ parser.add_argument(
+ 'opts',
+ help='See utils/config.py for all options',
+ default=None,
+ nargs=argparse.REMAINDER)
+ if len(sys.argv) == 1:
+ parser.print_help()
+ sys.exit(1)
+ return parser.parse_args()
+
+
+def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs):
+ np.set_printoptions(precision=5, suppress=True)
+
+ startup_prog = fluid.Program()
+ test_prog = fluid.Program()
+
+ dataset = LaneNetDataset(
+ file_list=cfg.DATASET.VAL_FILE_LIST,
+ mode=ModelPhase.TRAIN,
+ shuffle=True,
+ data_dir=cfg.DATASET.DATA_DIR)
+
+ def data_generator():
+        # TODO: check whether the batch reader is compatible with Windows
+ if use_mpio:
+ data_gen = dataset.multiprocess_generator(
+ num_processes=cfg.DATALOADER.NUM_WORKERS,
+ max_queue_size=cfg.DATALOADER.BUF_SIZE)
+ else:
+ data_gen = dataset.generator()
+
+ for b in data_gen:
+ yield b
+
+ py_reader, pred, grts, masks, accuracy, fp, fn = build_model(
+ test_prog, startup_prog, phase=ModelPhase.EVAL)
+
+ py_reader.decorate_sample_generator(
+ data_generator, drop_last=False, batch_size=cfg.BATCH_SIZE)
+
+ # Get device environment
+ places = fluid.cuda_places() if use_gpu else fluid.cpu_places()
+ place = places[0]
+ dev_count = len(places)
+ print("#Device count: {}".format(dev_count))
+
+ exe = fluid.Executor(place)
+ exe.run(startup_prog)
+
+ test_prog = test_prog.clone(for_test=True)
+
+ ckpt_dir = cfg.TEST.TEST_MODEL if not ckpt_dir else ckpt_dir
+
+ if ckpt_dir is not None:
+ print('load test model:', ckpt_dir)
+ fluid.io.load_params(exe, ckpt_dir, main_program=test_prog)
+
+    # Set numpy print options for metric logging
+ np.set_printoptions(
+ precision=4, suppress=True, linewidth=160, floatmode="fixed")
+ fetch_list = [pred.name, grts.name, masks.name, accuracy.name, fp.name, fn.name]
+ num_images = 0
+ step = 0
+ avg_acc = 0.0
+ avg_fp = 0.0
+ avg_fn = 0.0
+ # cur_images = 0
+ all_step = cfg.DATASET.TEST_TOTAL_IMAGES // cfg.BATCH_SIZE + 1
+ timer = Timer()
+ timer.start()
+ py_reader.start()
+ while True:
+ try:
+ step += 1
+ pred, grts, masks, out_acc, out_fp, out_fn = exe.run(
+ test_prog, fetch_list=fetch_list, return_numpy=True)
+
+ avg_acc += np.mean(out_acc) * pred.shape[0]
+ avg_fp += np.mean(out_fp) * pred.shape[0]
+ avg_fn += np.mean(out_fn) * pred.shape[0]
+ num_images += pred.shape[0]
+
+ speed = 1.0 / timer.elapsed_time()
+
+ print(
+ "[EVAL]step={} accuracy={:.4f} fp={:.4f} fn={:.4f} step/sec={:.2f} | ETA {}"
+ .format(step, avg_acc / num_images, avg_fp / num_images, avg_fn / num_images, speed,
+ calculate_eta(all_step - step, speed)))
+
+ timer.restart()
+ sys.stdout.flush()
+ except fluid.core.EOFException:
+ break
+
+ print("[EVAL]#image={} accuracy={:.4f} fp={:.4f} fn={:.4f}".format(
+ num_images, avg_acc / num_images, avg_fp / num_images, avg_fn / num_images))
+
+ return avg_acc / num_images, avg_fp / num_images, avg_fn / num_images
+
+
+def main():
+ args = parse_args()
+ if args.cfg_file is not None:
+ cfg.update_from_file(args.cfg_file)
+ if args.opts:
+ cfg.update_from_list(args.opts)
+ cfg.check_and_infer()
+ print(pprint.pformat(cfg))
+ evaluate(cfg, **args.__dict__)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/contrib/LaneNet/imgs/0005_pred_binary.png b/contrib/LaneNet/imgs/0005_pred_binary.png
new file mode 100644
index 0000000000000000000000000000000000000000..77f66b2510683d3b94e6e7f1c219365546d8ca37
Binary files /dev/null and b/contrib/LaneNet/imgs/0005_pred_binary.png differ
diff --git a/contrib/LaneNet/imgs/0005_pred_instance.png b/contrib/LaneNet/imgs/0005_pred_instance.png
new file mode 100644
index 0000000000000000000000000000000000000000..ec99b30e49db0d0f02e198f75785618ff12b3bb6
Binary files /dev/null and b/contrib/LaneNet/imgs/0005_pred_instance.png differ
diff --git a/contrib/LaneNet/imgs/0005_pred_lane.png b/contrib/LaneNet/imgs/0005_pred_lane.png
new file mode 100644
index 0000000000000000000000000000000000000000..18c656f734c2276eaf03a07daf4f018db505d8ea
Binary files /dev/null and b/contrib/LaneNet/imgs/0005_pred_lane.png differ
diff --git a/contrib/LaneNet/loss.py b/contrib/LaneNet/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..e888374582d0c83357bb652d7188dbf429832604
--- /dev/null
+++ b/contrib/LaneNet/loss.py
@@ -0,0 +1,138 @@
+# coding: utf8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.fluid as fluid
+import numpy as np
+from utils.config import cfg
+
+
+def unsorted_segment_sum(data, segment_ids, unique_labels, feature_dims):
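+    # Emulates TensorFlow's unsorted_segment_sum with scatter_nd_add: rows of
+    # `data` that share the same id in `segment_ids` are summed together.
+    # `zeros` has one row per unique label (its batch dim is copied from
+    # `unique_labels` by fill_constant_batch_size_like).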
+ zeros = fluid.layers.fill_constant_batch_size_like(unique_labels, shape=[1, feature_dims],
+ dtype='float32', value=0)
+ segment_ids = fluid.layers.unsqueeze(segment_ids, axes=[1])
+ segment_ids.stop_gradient = True
+ segment_sum = fluid.layers.scatter_nd_add(zeros, segment_ids, data)
+ zeros.stop_gradient = True
+
+ return segment_sum
+
+
+def norm(x, axis=-1):
+ distance = fluid.layers.reduce_sum(fluid.layers.abs(x), dim=axis, keep_dim=True)
+ return distance
+
+def discriminative_loss_single(
+ prediction,
+ correct_label,
+ feature_dim,
+ label_shape,
+ delta_v,
+ delta_d,
+ param_var,
+ param_dist,
+ param_reg):
+
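+    # Discriminative (instance embedding) loss for a single sample, in the spirit of
+    # "Semantic Instance Segmentation with a Discriminative Loss Function":
+    #   l_var  pulls pixel embeddings to within delta_v of their instance mean,
+    #   l_dist pushes instance means at least 2 * delta_d apart (hinged),
+    #   l_reg  keeps the embedding means small.
+    # Distances are L1 norms here (see `norm` above).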
+ correct_label = fluid.layers.reshape(
+ correct_label, [
+ label_shape[1] * label_shape[0]])
+ prediction = fluid.layers.transpose(prediction, [1, 2, 0])
+ reshaped_pred = fluid.layers.reshape(
+ prediction, [
+ label_shape[1] * label_shape[0], feature_dim])
+
+ unique_labels, unique_id, counts = fluid.layers.unique_with_counts(correct_label)
+ correct_label.stop_gradient = True
+ counts = fluid.layers.cast(counts, 'float32')
+ num_instances = fluid.layers.shape(unique_labels)
+
+ segmented_sum = unsorted_segment_sum(
+ reshaped_pred, unique_id, unique_labels, feature_dims=feature_dim)
+
+ counts_rsp = fluid.layers.reshape(counts, (-1, 1))
+ mu = fluid.layers.elementwise_div(segmented_sum, counts_rsp)
+ counts_rsp.stop_gradient = True
+ mu_expand = fluid.layers.gather(mu, unique_id)
+ tmp = fluid.layers.elementwise_sub(mu_expand, reshaped_pred)
+
+ distance = norm(tmp)
+ distance = distance - delta_v
+
+ distance_pos = fluid.layers.greater_equal(distance, fluid.layers.zeros_like(distance))
+ distance_pos = fluid.layers.cast(distance_pos, 'float32')
+ distance = distance * distance_pos
+
+ distance = fluid.layers.square(distance)
+
+ l_var = unsorted_segment_sum(distance, unique_id, unique_labels, feature_dims=1)
+ l_var = fluid.layers.elementwise_div(l_var, counts_rsp)
+ l_var = fluid.layers.reduce_sum(l_var)
+ l_var = l_var / fluid.layers.cast(num_instances * (num_instances - 1), 'float32')
+
+ mu_interleaved_rep = fluid.layers.expand(mu, [num_instances, 1])
+ mu_band_rep = fluid.layers.expand(mu, [1, num_instances])
+ mu_band_rep = fluid.layers.reshape(mu_band_rep, (num_instances * num_instances, feature_dim))
+
+ mu_diff = fluid.layers.elementwise_sub(mu_band_rep, mu_interleaved_rep)
+
+ intermediate_tensor = fluid.layers.reduce_sum(fluid.layers.abs(mu_diff), dim=1)
+ intermediate_tensor.stop_gradient = True
+ zero_vector = fluid.layers.zeros([1], 'float32')
+ bool_mask = fluid.layers.not_equal(intermediate_tensor, zero_vector)
+ temp = fluid.layers.where(bool_mask)
+ mu_diff_bool = fluid.layers.gather(mu_diff, temp)
+
+ mu_norm = norm(mu_diff_bool)
+ mu_norm = 2. * delta_d - mu_norm
+ mu_norm_pos = fluid.layers.greater_equal(mu_norm, fluid.layers.zeros_like(mu_norm))
+ mu_norm_pos = fluid.layers.cast(mu_norm_pos, 'float32')
+ mu_norm = mu_norm * mu_norm_pos
+ mu_norm_pos.stop_gradient = True
+
+ mu_norm = fluid.layers.square(mu_norm)
+
+ l_dist = fluid.layers.reduce_mean(mu_norm)
+
+ l_reg = fluid.layers.reduce_mean(norm(mu, axis=1))
+
+ l_var = param_var * l_var
+ l_dist = param_dist * l_dist
+ l_reg = param_reg * l_reg
+ loss = l_var + l_dist + l_reg
+ return loss, l_var, l_dist, l_reg
+
+
+def discriminative_loss(prediction, correct_label, feature_dim, image_shape,
+ delta_v, delta_d, param_var, param_dist, param_reg):
+ batch_size = int(cfg.BATCH_SIZE_PER_DEV)
+ output_ta_loss = 0.
+ output_ta_var = 0.
+ output_ta_dist = 0.
+ output_ta_reg = 0.
+ for i in range(batch_size):
+ disc_loss_single, l_var_single, l_dist_single, l_reg_single = discriminative_loss_single(
+ prediction[i], correct_label[i], feature_dim, image_shape, delta_v, delta_d, param_var, param_dist,
+ param_reg)
+ output_ta_loss += disc_loss_single
+ output_ta_var += l_var_single
+ output_ta_dist += l_dist_single
+ output_ta_reg += l_reg_single
+
+ disc_loss = output_ta_loss / batch_size
+ l_var = output_ta_var / batch_size
+ l_dist = output_ta_dist / batch_size
+ l_reg = output_ta_reg / batch_size
+ return disc_loss, l_var, l_dist, l_reg
+
+
diff --git a/contrib/LaneNet/models/__init__.py b/contrib/LaneNet/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f750f6f7b42bb028f81a24edb2bb9e30c190578e
--- /dev/null
+++ b/contrib/LaneNet/models/__init__.py
@@ -0,0 +1,17 @@
+# coding: utf8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import models.modeling
+#import models.backbone
diff --git a/contrib/LaneNet/models/model_builder.py b/contrib/LaneNet/models/model_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed6c275ecd51a2fc9f7f2fdf125300ce026c0a0a
--- /dev/null
+++ b/contrib/LaneNet/models/model_builder.py
@@ -0,0 +1,261 @@
+# coding: utf8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+sys.path.append("..")
+import struct
+
+import paddle.fluid as fluid
+from paddle.fluid.proto.framework_pb2 import VarType
+
+from pdseg import solver
+from utils.config import cfg
+from pdseg.loss import multi_softmax_with_loss
+from loss import discriminative_loss
+from models.modeling import lanenet
+
+class ModelPhase(object):
+ """
+ Standard name for model phase in PaddleSeg
+
+ The following standard keys are defined:
+ * `TRAIN`: training mode.
+ * `EVAL`: testing/evaluation mode.
+ * `PREDICT`: prediction/inference mode.
+ * `VISUAL` : visualization mode
+ """
+
+ TRAIN = 'train'
+ EVAL = 'eval'
+ PREDICT = 'predict'
+ VISUAL = 'visual'
+
+ @staticmethod
+ def is_train(phase):
+ return phase == ModelPhase.TRAIN
+
+ @staticmethod
+ def is_predict(phase):
+ return phase == ModelPhase.PREDICT
+
+ @staticmethod
+ def is_eval(phase):
+ return phase == ModelPhase.EVAL
+
+ @staticmethod
+ def is_visual(phase):
+ return phase == ModelPhase.VISUAL
+
+ @staticmethod
+ def is_valid_phase(phase):
+ """ Check valid phase """
+ if ModelPhase.is_train(phase) or ModelPhase.is_predict(phase) \
+ or ModelPhase.is_eval(phase) or ModelPhase.is_visual(phase):
+ return True
+
+ return False
+
+
+def seg_model(image, class_num):
+ model_name = cfg.MODEL.MODEL_NAME
+ if model_name == 'lanenet':
+ logits = lanenet.lanenet(image, class_num)
+ else:
+ raise Exception(
+ "unknow model name, only support unet, deeplabv3p, icnet, pspnet, hrnet"
+ )
+ return logits
+
+
+def softmax(logit):
+ logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
+ logit = fluid.layers.softmax(logit)
+ logit = fluid.layers.transpose(logit, [0, 3, 1, 2])
+ return logit
+
+
+def sigmoid_to_softmax(logit):
+ """
+ one channel to two channel
+ """
+ logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
+ logit = fluid.layers.sigmoid(logit)
+ logit_back = 1 - logit
+ logit = fluid.layers.concat([logit_back, logit], axis=-1)
+ logit = fluid.layers.transpose(logit, [0, 3, 1, 2])
+ return logit
+
+
+def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
+ if not ModelPhase.is_valid_phase(phase):
+ raise ValueError("ModelPhase {} is not valid!".format(phase))
+ if ModelPhase.is_train(phase):
+ width = cfg.TRAIN_CROP_SIZE[0]
+ height = cfg.TRAIN_CROP_SIZE[1]
+ else:
+ width = cfg.EVAL_CROP_SIZE[0]
+ height = cfg.EVAL_CROP_SIZE[1]
+
+ image_shape = [cfg.DATASET.DATA_DIM, height, width]
+ grt_shape = [1, height, width]
+ class_num = cfg.DATASET.NUM_CLASSES
+
+ with fluid.program_guard(main_prog, start_prog):
+ with fluid.unique_name.guard():
+ image = fluid.layers.data(
+ name='image', shape=image_shape, dtype='float32')
+ label = fluid.layers.data(
+ name='label', shape=grt_shape, dtype='int32')
+ if cfg.MODEL.MODEL_NAME == 'lanenet':
+ label_instance = fluid.layers.data(
+ name='label_instance', shape=grt_shape, dtype='int32')
+ mask = fluid.layers.data(
+ name='mask', shape=grt_shape, dtype='int32')
+
+            # use PyReader when doing training and evaluation
+ if ModelPhase.is_train(phase) or ModelPhase.is_eval(phase):
+ py_reader = fluid.io.PyReader(
+ feed_list=[image, label, label_instance, mask],
+ capacity=cfg.DATALOADER.BUF_SIZE,
+ iterable=False,
+ use_double_buffer=True)
+
+
+ loss_type = cfg.SOLVER.LOSS
+ if not isinstance(loss_type, list):
+                loss_type = [loss_type]
+
+ logits = seg_model(image, class_num)
+
+ if ModelPhase.is_train(phase):
+ loss_valid = False
+ valid_loss = []
+ if cfg.MODEL.MODEL_NAME == 'lanenet':
+ embeding_logit = logits[1]
+ logits = logits[0]
+ disc_loss, _, _, l_reg = discriminative_loss(embeding_logit, label_instance, 4,
+ image_shape[1:], 0.5, 3.0, 1.0, 1.0, 0.001)
+
+ if "softmax_loss" in loss_type:
+ weight = None
+ if cfg.MODEL.MODEL_NAME == 'lanenet':
+ weight = get_dynamic_weight(label)
+ seg_loss = multi_softmax_with_loss(logits, label, mask, class_num, weight)
+ loss_valid = True
+ valid_loss.append("softmax_loss")
+
+                if not loss_valid:
+                    raise Exception("SOLVER.LOSS: {} is set incorrectly. It should "
+                                    "include at least one of (softmax_loss, bce_loss, dice_loss),"
+                                    " for example: ['softmax_loss']".format(cfg.SOLVER.LOSS))
+
+                invalid_loss = [x for x in loss_type if x not in valid_loss]
+                if len(invalid_loss) > 0:
+                    print("Warning: the loss {} you set is invalid, so it will not be included in the computed loss.".format(invalid_loss))
+
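+                # Total training loss: the discriminative (instance embedding) loss, a
+                # small extra weight on its mean-embedding regularizer, and the
+                # class-weighted softmax segmentation loss.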
+                avg_loss = disc_loss + 0.00001 * l_reg + seg_loss
+
+            # get pred result in original size
+ if isinstance(logits, tuple):
+ logit = logits[0]
+ else:
+ logit = logits
+
+ if logit.shape[2:] != label.shape[2:]:
+ logit = fluid.layers.resize_bilinear(logit, label.shape[2:])
+
+ # return image input and logit output for inference graph prune
+ if ModelPhase.is_predict(phase):
+ if class_num == 1:
+ logit = sigmoid_to_softmax(logit)
+ else:
+ logit = softmax(logit)
+ return image, logit
+
+ if class_num == 1:
+ out = sigmoid_to_softmax(logit)
+ out = fluid.layers.transpose(out, [0, 2, 3, 1])
+ else:
+ out = fluid.layers.transpose(logit, [0, 2, 3, 1])
+
+ pred = fluid.layers.argmax(out, axis=3)
+ pred = fluid.layers.unsqueeze(pred, axes=[3])
+ if ModelPhase.is_visual(phase):
+ if cfg.MODEL.MODEL_NAME == 'lanenet':
+ return pred, logits[1]
+ if class_num == 1:
+ logit = sigmoid_to_softmax(logit)
+ else:
+ logit = softmax(logit)
+ return pred, logit
+
+ accuracy, fp, fn = compute_metric(pred, label)
+ if ModelPhase.is_eval(phase):
+ return py_reader, pred, label, mask, accuracy, fp, fn
+
+ if ModelPhase.is_train(phase):
+ optimizer = solver.Solver(main_prog, start_prog)
+ decayed_lr = optimizer.optimise(avg_loss)
+ return py_reader, avg_loss, decayed_lr, pred, label, mask, disc_loss, seg_loss, accuracy, fp, fn
+
+
+def compute_metric(pred, label):
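+    # Pixel-level lane metrics:
+    #   accuracy: correctly predicted lane pixels / all ground-truth lane pixels
+    #   fp: predicted lane pixels that are not lane / all predicted lane pixels
+    #   fn: missed ground-truth lane pixels / all ground-truth lane pixels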
+ label = fluid.layers.transpose(label, [0, 2, 3, 1])
+
+ idx = fluid.layers.where(pred == 1)
+ pix_cls_ret = fluid.layers.gather_nd(label, idx)
+
+ correct_num = fluid.layers.reduce_sum(fluid.layers.cast(pix_cls_ret, 'float32'))
+
+ gt_num = fluid.layers.cast(fluid.layers.shape(fluid.layers.gather_nd(label,
+ fluid.layers.where(label == 1)))[0], 'int64')
+ pred_num = fluid.layers.cast(fluid.layers.shape(fluid.layers.gather_nd(pred, idx))[0], 'int64')
+ accuracy = correct_num / gt_num
+
+ false_pred = pred_num - correct_num
+ fp = fluid.layers.cast(false_pred, 'float32') / fluid.layers.cast(fluid.layers.shape(pix_cls_ret)[0], 'int64')
+
+ label_cls_ret = fluid.layers.gather_nd(label, fluid.layers.where(label == 1))
+ mis_pred = fluid.layers.cast(fluid.layers.shape(label_cls_ret)[0], 'int64') - correct_num
+ fn = fluid.layers.cast(mis_pred, 'float32') / fluid.layers.cast(fluid.layers.shape(label_cls_ret)[0], 'int64')
+ accuracy.stop_gradient = True
+ fp.stop_gradient = True
+ fn.stop_gradient = True
+ return accuracy, fp, fn
+
+
+def get_dynamic_weight(label):
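+    # Dynamic class weighting: w_c = 1 / log(p_c + 1.02), where p_c is the fraction
+    # of pixels belonging to class c in the batch; rarer classes get larger weights.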
+ label = fluid.layers.reshape(label, [-1])
+ unique_labels, unique_id, counts = fluid.layers.unique_with_counts(label)
+ counts = fluid.layers.cast(counts, 'float32')
+ weight = 1.0 / fluid.layers.log((counts / fluid.layers.reduce_sum(counts) + 1.02))
+ return weight
+
+
+def to_int(string, dest="I"):
+ return struct.unpack(dest, string)[0]
+
+
+def parse_shape_from_file(filename):
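+    # Read just enough of a Paddle-saved parameter file (version, LoD info, and the
+    # serialized tensor descriptor) to recover the stored tensor's shape.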
+ with open(filename, "rb") as file:
+ version = file.read(4)
+ lod_level = to_int(file.read(8), dest="Q")
+ for i in range(lod_level):
+ _size = to_int(file.read(8), dest="Q")
+ _ = file.read(_size)
+ version = file.read(4)
+ tensor_desc_size = to_int(file.read(4))
+ tensor_desc = VarType.TensorDesc()
+ tensor_desc.ParseFromString(file.read(tensor_desc_size))
+ return tuple(tensor_desc.dims)
diff --git a/contrib/LaneNet/models/modeling/__init__.py b/contrib/LaneNet/models/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/contrib/LaneNet/models/modeling/lanenet.py b/contrib/LaneNet/models/modeling/lanenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..68837983f08ce4220bfb5bb0ea7d96404687b259
--- /dev/null
+++ b/contrib/LaneNet/models/modeling/lanenet.py
@@ -0,0 +1,440 @@
+# coding: utf8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import paddle.fluid as fluid
+
+
+from utils.config import cfg
+from pdseg.models.libs.model_libs import scope, name_scope
+from pdseg.models.libs.model_libs import bn, bn_relu, relu
+from pdseg.models.libs.model_libs import conv, max_pool, deconv
+from pdseg.models.backbone.vgg import VGGNet as vgg_backbone
+#from models.backbone.vgg import VGGNet as vgg_backbone
+
+# Bottleneck type
+REGULAR = 1
+DOWNSAMPLING = 2
+UPSAMPLING = 3
+DILATED = 4
+ASYMMETRIC = 5
+
+
+def prelu(x, decoder=False):
+ # If decoder, then perform relu else perform prelu
+ if decoder:
+ return fluid.layers.relu(x)
+ return fluid.layers.prelu(x, 'channel')
+
+
+def initial_block(inputs, name_scope='initial_block'):
+ '''
+ The initial block for Enet has 2 branches: The convolution branch and Maxpool branch.
+ The conv branch has 13 filters, while the maxpool branch gives 3 channels corresponding to the RGB channels.
+ Both output layers are then concatenated to give an output of 16 channels.
+
+ :param inputs(Tensor): A 4D tensor of shape [batch_size, height, width, channels]
+ :return net_concatenated(Tensor): a 4D Tensor of new shape [batch_size, height, width, channels]
+ '''
+ # Convolutional branch
+ with scope(name_scope):
+ net_conv = conv(inputs, 13, 3, stride=2, padding=1)
+ net_conv = bn(net_conv)
+ net_conv = fluid.layers.prelu(net_conv, 'channel')
+
+ # Max pool branch
+ net_pool = max_pool(inputs, [2, 2], stride=2, padding='SAME')
+
+ # Concatenated output - does it matter max pool comes first or conv comes first? probably not.
+ net_concatenated = fluid.layers.concat([net_conv, net_pool], axis=1)
+ return net_concatenated
+
+
+def bottleneck(inputs,
+ output_depth,
+ filter_size,
+ regularizer_prob,
+ projection_ratio=4,
+ type=REGULAR,
+ seed=0,
+ output_shape=None,
+ dilation_rate=None,
+ decoder=False,
+ name_scope='bottleneck'):
+
+ # Calculate the depth reduction based on the projection ratio used in 1x1 convolution.
+ reduced_depth = int(inputs.shape[1] / projection_ratio)
+
+ # DOWNSAMPLING BOTTLENECK
+ if type == DOWNSAMPLING:
+ #=============MAIN BRANCH=============
+        # Downsample the main branch (a stride-2 conv stands in for max pooling here)
+ with scope('down_sample'):
+ inputs_shape = inputs.shape
+ with scope('main_max_pool'):
+ net_main = fluid.layers.conv2d(inputs, inputs_shape[1], filter_size=3, stride=2, padding='SAME')
+
+ #First get the difference in depth to pad, then pad with zeros only on the last dimension.
+ depth_to_pad = abs(inputs_shape[1] - output_depth)
+ paddings = [0, 0, 0, depth_to_pad, 0, 0, 0, 0]
+ with scope('main_padding'):
+ net_main = fluid.layers.pad(net_main, paddings=paddings)
+
+ with scope('block1'):
+ net = conv(inputs, reduced_depth, [2, 2], stride=2, padding='same')
+ net = bn(net)
+ net = prelu(net, decoder=decoder)
+
+ with scope('block2'):
+ net = conv(net, reduced_depth, [filter_size, filter_size], padding='same')
+ net = bn(net)
+ net = prelu(net, decoder=decoder)
+
+ with scope('block3'):
+ net = conv(net, output_depth, [1, 1], padding='same')
+ net = bn(net)
+ net = prelu(net, decoder=decoder)
+
+ # Regularizer
+ net = fluid.layers.dropout(net, regularizer_prob, seed=seed)
+
+ # Finally, combine the two branches together via an element-wise addition
+ net = fluid.layers.elementwise_add(net, net_main)
+ net = prelu(net, decoder=decoder)
+
+ return net, inputs_shape
+
+ # DILATION CONVOLUTION BOTTLENECK
+ # Everything is the same as a regular bottleneck except for the dilation rate argument
+ elif type == DILATED:
+ #Check if dilation rate is given
+ if not dilation_rate:
+ raise ValueError('Dilation rate is not given.')
+
+ with scope('dilated'):
+ # Save the main branch for addition later
+ net_main = inputs
+
+ # First projection with 1x1 kernel (dimensionality reduction)
+ with scope('block1'):
+ net = conv(inputs, reduced_depth, [1, 1])
+ net = bn(net)
+ net = prelu(net, decoder=decoder)
+
+ # Second conv block --- apply dilated convolution here
+ with scope('block2'):
+ net = conv(net, reduced_depth, filter_size, padding='SAME', dilation=dilation_rate)
+ net = bn(net)
+ net = prelu(net, decoder=decoder)
+
+ # Final projection with 1x1 kernel (Expansion)
+ with scope('block3'):
+ net = conv(net, output_depth, [1,1])
+ net = bn(net)
+ net = prelu(net, decoder=decoder)
+
+ # Regularizer
+ net = fluid.layers.dropout(net, regularizer_prob, seed=seed)
+ net = prelu(net, decoder=decoder)
+
+ # Add the main branch
+ net = fluid.layers.elementwise_add(net_main, net)
+ net = prelu(net, decoder=decoder)
+
+ return net
+
+ # ASYMMETRIC CONVOLUTION BOTTLENECK
+ # Everything is the same as a regular bottleneck except for a [5,5] kernel decomposed into two [5,1] then [1,5]
+ elif type == ASYMMETRIC:
+ # Save the main branch for addition later
+ with scope('asymmetric'):
+ net_main = inputs
+ # First projection with 1x1 kernel (dimensionality reduction)
+ with scope('block1'):
+ net = conv(inputs, reduced_depth, [1, 1])
+ net = bn(net)
+ net = prelu(net, decoder=decoder)
+
+ # Second conv block --- apply asymmetric conv here
+ with scope('block2'):
+ with scope('asymmetric_conv2a'):
+ net = conv(net, reduced_depth, [filter_size, 1], padding='same')
+ with scope('asymmetric_conv2b'):
+ net = conv(net, reduced_depth, [1, filter_size], padding='same')
+ net = bn(net)
+ net = prelu(net, decoder=decoder)
+
+ # Final projection with 1x1 kernel
+ with scope('block3'):
+ net = conv(net, output_depth, [1, 1])
+ net = bn(net)
+ net = prelu(net, decoder=decoder)
+
+ # Regularizer
+ net = fluid.layers.dropout(net, regularizer_prob, seed=seed)
+ net = prelu(net, decoder=decoder)
+
+ # Add the main branch
+ net = fluid.layers.elementwise_add(net_main, net)
+ net = prelu(net, decoder=decoder)
+
+ return net
+
+ # UPSAMPLING BOTTLENECK
+ # Everything is the same as a regular one, except convolution becomes transposed.
+ elif type == UPSAMPLING:
+ #Check if pooling indices is given
+
+ #Check output_shape given or not
+ if output_shape is None:
+            raise ValueError('Output shape is not given')
+
+ #=======MAIN BRANCH=======
+ #Main branch to upsample. output shape must match with the shape of the layer that was pooled initially, in order
+ #for the pooling indices to work correctly. However, the initial pooled layer was padded, so need to reduce dimension
+ #before unpooling. In the paper, padding is replaced with convolution for this purpose of reducing the depth!
+ with scope('upsampling'):
+ with scope('unpool'):
+ net_unpool = conv(inputs, output_depth, [1, 1])
+ net_unpool = bn(net_unpool)
+ net_unpool = fluid.layers.resize_bilinear(net_unpool, out_shape=output_shape[2:])
+
+ # First 1x1 projection to reduce depth
+ with scope('block1'):
+ net = conv(inputs, reduced_depth, [1, 1])
+ net = bn(net)
+ net = prelu(net, decoder=decoder)
+
+ with scope('block2'):
+ net = deconv(net, reduced_depth, filter_size=filter_size, stride=2, padding='same')
+ net = bn(net)
+ net = prelu(net, decoder=decoder)
+
+ # Final projection with 1x1 kernel
+ with scope('block3'):
+ net = conv(net, output_depth, [1, 1])
+ net = bn(net)
+ net = prelu(net, decoder=decoder)
+
+ # Regularizer
+ net = fluid.layers.dropout(net, regularizer_prob, seed=seed)
+ net = prelu(net, decoder=decoder)
+
+ # Finally, add the unpooling layer and the sub branch together
+ net = fluid.layers.elementwise_add(net, net_unpool)
+ net = prelu(net, decoder=decoder)
+
+ return net
+
+ # REGULAR BOTTLENECK
+ else:
+ with scope('regular'):
+ net_main = inputs
+
+ # First projection with 1x1 kernel
+ with scope('block1'):
+ net = conv(inputs, reduced_depth, [1, 1])
+ net = bn(net)
+ net = prelu(net, decoder=decoder)
+
+ # Second conv block
+ with scope('block2'):
+ net = conv(net, reduced_depth, [filter_size, filter_size], padding='same')
+ net = bn(net)
+ net = prelu(net, decoder=decoder)
+
+ # Final projection with 1x1 kernel
+ with scope('block3'):
+ net = conv(net, output_depth, [1, 1])
+ net = bn(net)
+ net = prelu(net, decoder=decoder)
+
+ # Regularizer
+ net = fluid.layers.dropout(net, regularizer_prob, seed=seed)
+ net = prelu(net, decoder=decoder)
+
+ # Add the main branch
+ net = fluid.layers.elementwise_add(net_main, net)
+ net = prelu(net, decoder=decoder)
+
+ return net
+
+
+def ENet_stage1(inputs, name_scope='stage1_block'):
+ with scope(name_scope):
+ with scope('bottleneck1_0'):
+ net, inputs_shape_1 \
+ = bottleneck(inputs, output_depth=64, filter_size=3, regularizer_prob=0.01, type=DOWNSAMPLING,
+ name_scope='bottleneck1_0')
+ with scope('bottleneck1_1'):
+ net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01,
+ name_scope='bottleneck1_1')
+ with scope('bottleneck1_2'):
+ net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01,
+ name_scope='bottleneck1_2')
+ with scope('bottleneck1_3'):
+ net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01,
+ name_scope='bottleneck1_3')
+ with scope('bottleneck1_4'):
+ net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01,
+ name_scope='bottleneck1_4')
+ return net, inputs_shape_1
+
+
+def ENet_stage2(inputs, name_scope='stage2_block'):
+ with scope(name_scope):
+ net, inputs_shape_2 \
+ = bottleneck(inputs, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DOWNSAMPLING,
+ name_scope='bottleneck2_0')
+ for i in range(2):
+ with scope('bottleneck2_{}'.format(str(4 * i + 1))):
+ net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1,
+ name_scope='bottleneck2_{}'.format(str(4 * i + 1)))
+ with scope('bottleneck2_{}'.format(str(4 * i + 2))):
+ net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+1)),
+ name_scope='bottleneck2_{}'.format(str(4 * i + 2)))
+ with scope('bottleneck2_{}'.format(str(4 * i + 3))):
+ net = bottleneck(net, output_depth=128, filter_size=5, regularizer_prob=0.1, type=ASYMMETRIC,
+ name_scope='bottleneck2_{}'.format(str(4 * i + 3)))
+ with scope('bottleneck2_{}'.format(str(4 * i + 4))):
+ net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+2)),
+ name_scope='bottleneck2_{}'.format(str(4 * i + 4)))
+ return net, inputs_shape_2
+
+
+def ENet_stage3(inputs, name_scope='stage3_block'):
+ with scope(name_scope):
+ for i in range(2):
+ with scope('bottleneck3_{}'.format(str(4 * i + 0))):
+ net = bottleneck(inputs, output_depth=128, filter_size=3, regularizer_prob=0.1,
+ name_scope='bottleneck3_{}'.format(str(4 * i + 0)))
+ with scope('bottleneck3_{}'.format(str(4 * i + 1))):
+ net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+1)),
+ name_scope='bottleneck3_{}'.format(str(4 * i + 1)))
+ with scope('bottleneck3_{}'.format(str(4 * i + 2))):
+ net = bottleneck(net, output_depth=128, filter_size=5, regularizer_prob=0.1, type=ASYMMETRIC,
+ name_scope='bottleneck3_{}'.format(str(4 * i + 2)))
+ with scope('bottleneck3_{}'.format(str(4 * i + 3))):
+ net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+2)),
+ name_scope='bottleneck3_{}'.format(str(4 * i + 3)))
+ return net
+
+
+def ENet_stage4(inputs, inputs_shape, connect_tensor,
+ skip_connections=True, name_scope='stage4_block'):
+ with scope(name_scope):
+ with scope('bottleneck4_0'):
+ net = bottleneck(inputs, output_depth=64, filter_size=3, regularizer_prob=0.1,
+ type=UPSAMPLING, decoder=True, output_shape=inputs_shape,
+ name_scope='bottleneck4_0')
+
+ if skip_connections:
+ net = fluid.layers.elementwise_add(net, connect_tensor)
+ with scope('bottleneck4_1'):
+ net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.1, decoder=True,
+ name_scope='bottleneck4_1')
+ with scope('bottleneck4_2'):
+ net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.1, decoder=True,
+ name_scope='bottleneck4_2')
+
+ return net
+
+
+def ENet_stage5(inputs, inputs_shape, connect_tensor, skip_connections=True,
+ name_scope='stage5_block'):
+ with scope(name_scope):
+ net = bottleneck(inputs, output_depth=16, filter_size=3, regularizer_prob=0.1, type=UPSAMPLING,
+ decoder=True, output_shape=inputs_shape,
+ name_scope='bottleneck5_0')
+
+ if skip_connections:
+ net = fluid.layers.elementwise_add(net, connect_tensor)
+ with scope('bottleneck5_1'):
+ net = bottleneck(net, output_depth=16, filter_size=3, regularizer_prob=0.1, decoder=True,
+ name_scope='bottleneck5_1')
+ return net
+
+
+def decoder(input, num_classes):
+
+ if 'enet' in cfg.MODEL.LANENET.BACKBONE:
+ # Segmentation branch
+ with scope('LaneNetSeg'):
+ initial, stage1, stage2, inputs_shape_1, inputs_shape_2 = input
+ segStage3 = ENet_stage3(stage2)
+ segStage4 = ENet_stage4(segStage3, inputs_shape_2, stage1)
+ segStage5 = ENet_stage5(segStage4, inputs_shape_1, initial)
+ segLogits = deconv(segStage5, num_classes, filter_size=2, stride=2, padding='SAME')
+
+ # Embedding branch
+ with scope('LaneNetEm'):
+ emStage3 = ENet_stage3(stage2)
+ emStage4 = ENet_stage4(emStage3, inputs_shape_2, stage1)
+ emStage5 = ENet_stage5(emStage4, inputs_shape_1, initial)
+ emLogits = deconv(emStage5, 4, filter_size=2, stride=2, padding='SAME')
+
+ elif 'vgg' in cfg.MODEL.LANENET.BACKBONE:
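+        # FCN-style skip decoder over the VGG features: score each of
+        # pool5/pool4/pool3 with a 1x1 conv, repeatedly 2x-upsample and add,
+        # then 8x-upsample back to input resolution for the two output heads.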
+ encoder_list = ['pool5', 'pool4', 'pool3']
+ # score stage
+ input_tensor = input[encoder_list[0]]
+ with scope('score_origin'):
+ score = conv(input_tensor, 64, 1)
+ encoder_list = encoder_list[1:]
+ for i in range(len(encoder_list)):
+ with scope('deconv_{:d}'.format(i + 1)):
+ deconv_out = deconv(score, 64, filter_size=4, stride=2, padding='SAME')
+ input_tensor = input[encoder_list[i]]
+ with scope('score_{:d}'.format(i + 1)):
+ score = conv(input_tensor, 64, 1)
+ score = fluid.layers.elementwise_add(deconv_out, score)
+
+ with scope('deconv_final'):
+ emLogits = deconv(score, 64, filter_size=16, stride=8, padding='SAME')
+ with scope('score_final'):
+ segLogits = conv(emLogits, num_classes, 1)
+ emLogits = relu(conv(emLogits, 4, 1))
+ return segLogits, emLogits
+
+
+def encoder(input):
+ if 'vgg' in cfg.MODEL.LANENET.BACKBONE:
+ model = vgg_backbone(layers=16)
+ #output = model.net(input)
+
+ _, encode_feature_dict = model.net(input, end_points=13, decode_points=[7, 10, 13])
+ output = {}
+ output['pool3'] = encode_feature_dict[7]
+ output['pool4'] = encode_feature_dict[10]
+ output['pool5'] = encode_feature_dict[13]
+    elif 'enet' in cfg.MODEL.LANENET.BACKBONE:
+ with scope('LaneNetBase'):
+            initial = initial_block(input)
+ stage1, inputs_shape_1 = ENet_stage1(initial)
+ stage2, inputs_shape_2 = ENet_stage2(stage1)
+ output = (initial, stage1, stage2, inputs_shape_1, inputs_shape_2)
+ else:
+ raise Exception("LaneNet expect enet and vgg backbone, but received {}".
+ format(cfg.MODEL.LANENET.BACKBONE))
+ return output
+
+
+def lanenet(img, num_classes):
+
+ output = encoder(img)
+ segLogits, emLogits = decoder(output, num_classes)
+
+ return segLogits, emLogits
diff --git a/contrib/LaneNet/reader.py b/contrib/LaneNet/reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..29af37b8caf15da8cdabc73a847c30ca88d65c4a
--- /dev/null
+++ b/contrib/LaneNet/reader.py
@@ -0,0 +1,321 @@
+# coding: utf8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import sys
+import os
+import time
+import codecs
+
+import numpy as np
+import cv2
+
+from utils.config import cfg
+import data_aug as aug
+from pdseg.data_utils import GeneratorEnqueuer
+from models.model_builder import ModelPhase
+import copy
+
+
+def cv2_imread(file_path, flag=cv2.IMREAD_COLOR):
+    # Work around cv2.imread failing on file paths with Chinese characters on Windows.
+ return cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag)
+
+
+class LaneNetDataset():
+ def __init__(self,
+ file_list,
+ data_dir,
+ shuffle=False,
+ mode=ModelPhase.TRAIN):
+ self.mode = mode
+ self.shuffle = shuffle
+ self.data_dir = data_dir
+
+ self.shuffle_seed = 0
+        # NOTE: Please ensure the file list is saved in UTF-8 encoding
+ with codecs.open(file_list, 'r', 'utf-8') as flist:
+ self.lines = [line.strip() for line in flist]
+ self.all_lines = copy.deepcopy(self.lines)
+ if shuffle and cfg.NUM_TRAINERS > 1:
+ np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
+ elif shuffle:
+ np.random.shuffle(self.lines)
+
+ def generator(self):
+ if self.shuffle and cfg.NUM_TRAINERS > 1:
+ np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
+ num_lines = len(self.all_lines) // cfg.NUM_TRAINERS
+ self.lines = self.all_lines[num_lines * cfg.TRAINER_ID: num_lines * (cfg.TRAINER_ID + 1)]
+ self.shuffle_seed += 1
+ elif self.shuffle:
+ np.random.shuffle(self.lines)
+
+ for line in self.lines:
+ yield self.process_image(line, self.data_dir, self.mode)
+
+ def sharding_generator(self, pid=0, num_processes=1):
+ """
+ Use line id as shard key for multiprocess io
+ It's a normal generator if pid=0, num_processes=1
+ """
+ for index, line in enumerate(self.lines):
+ # Use index and pid to shard file list
+ if index % num_processes == pid:
+ yield self.process_image(line, self.data_dir, self.mode)
+
+ def batch_reader(self, batch_size):
+ br = self.batch(self.reader, batch_size)
+ for batch in br:
+ yield batch[0], batch[1], batch[2]
+
+ def multiprocess_generator(self, max_queue_size=32, num_processes=8):
+ # Re-shuffle file list
+ if self.shuffle and cfg.NUM_TRAINERS > 1:
+ np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
+            num_lines = len(self.all_lines) // cfg.NUM_TRAINERS
+            self.lines = self.all_lines[num_lines * cfg.TRAINER_ID: num_lines * (cfg.TRAINER_ID + 1)]
+ self.shuffle_seed += 1
+ elif self.shuffle:
+ np.random.shuffle(self.lines)
+
+ # Create multiple sharding generators according to num_processes for multiple processes
+ generators = []
+ for pid in range(num_processes):
+ generators.append(self.sharding_generator(pid, num_processes))
+
+ try:
+ enqueuer = GeneratorEnqueuer(generators)
+ enqueuer.start(max_queue_size=max_queue_size, workers=num_processes)
+ while True:
+ generator_out = None
+ while enqueuer.is_running():
+ if not enqueuer.queue.empty():
+ generator_out = enqueuer.queue.get(timeout=5)
+ break
+ else:
+ time.sleep(0.01)
+ if generator_out is None:
+ break
+ yield generator_out
+ finally:
+ if enqueuer is not None:
+ enqueuer.stop()
+
+ def batch(self, reader, batch_size, is_test=False, drop_last=False):
+ def batch_reader(is_test=False, drop_last=drop_last):
+ if is_test:
+ imgs, grts, grts_instance, img_names, valid_shapes, org_shapes = [], [], [], [], [], []
+ for img, grt, grt_instance, img_name, valid_shape, org_shape in reader():
+ imgs.append(img)
+ grts.append(grt)
+ grts_instance.append(grt_instance)
+ img_names.append(img_name)
+ valid_shapes.append(valid_shape)
+ org_shapes.append(org_shape)
+ if len(imgs) == batch_size:
+ yield np.array(imgs), np.array(
+ grts), np.array(grts_instance), img_names, np.array(valid_shapes), np.array(
+ org_shapes)
+ imgs, grts, grts_instance, img_names, valid_shapes, org_shapes = [], [], [], [], [], []
+
+ if not drop_last and len(imgs) > 0:
+ yield np.array(imgs), np.array(grts), np.array(grts_instance), img_names, np.array(
+ valid_shapes), np.array(org_shapes)
+ else:
+ imgs, labs, labs_instance, ignore = [], [], [], []
+ bs = 0
+ for img, lab, lab_instance, ig in reader():
+ imgs.append(img)
+ labs.append(lab)
+ labs_instance.append(lab_instance)
+ ignore.append(ig)
+ bs += 1
+ if bs == batch_size:
+ yield np.array(imgs), np.array(labs), np.array(labs_instance), np.array(ignore)
+ bs = 0
+ imgs, labs, labs_instance, ignore = [], [], [], []
+
+ if not drop_last and bs > 0:
+ yield np.array(imgs), np.array(labs), np.array(labs_instance), np.array(ignore)
+
+ return batch_reader(is_test, drop_last)
+
+ def load_image(self, line, src_dir, mode=ModelPhase.TRAIN):
+ # original image cv2.imread flag setting
+ cv2_imread_flag = cv2.IMREAD_COLOR
+ if cfg.DATASET.IMAGE_TYPE == "rgba":
+            # If the RGBA 4-channel image type is used, read with the
+            # IMREAD_UNCHANGED flag to preserve the alpha channel
+ cv2_imread_flag = cv2.IMREAD_UNCHANGED
+
+ parts = line.strip().split(cfg.DATASET.SEPARATOR)
+ if len(parts) != 3:
+ if mode == ModelPhase.TRAIN or mode == ModelPhase.EVAL:
+ raise Exception("File list format incorrect! It should be"
+ " image_name{}label_name\\n".format(
+ cfg.DATASET.SEPARATOR))
+ img_name, grt_name, grt_instance_name = parts[0], None, None
+ else:
+ img_name, grt_name, grt_instance_name = parts[0], parts[1], parts[2]
+
+ img_path = os.path.join(src_dir, img_name)
+ img = cv2_imread(img_path, cv2_imread_flag)
+
+ if grt_name is not None:
+ grt_path = os.path.join(src_dir, grt_name)
+ grt_instance_path = os.path.join(src_dir, grt_instance_name)
+ grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
+ grt[grt == 255] = 1
+ grt[grt != 1] = 0
+ grt_instance = cv2_imread(grt_instance_path, cv2.IMREAD_GRAYSCALE)
+ else:
+ grt = None
+ grt_instance = None
+
+ if img is None:
+ raise Exception(
+ "Empty image, src_dir: {}, img: {} & lab: {}".format(
+ src_dir, img_path, grt_path))
+
+ img_height = img.shape[0]
+ img_width = img.shape[1]
+
+ if grt is not None:
+ grt_height = grt.shape[0]
+ grt_width = grt.shape[1]
+
+ if img_height != grt_height or img_width != grt_width:
+ raise Exception(
+ "source img and label img must has the same size")
+ else:
+ if mode == ModelPhase.TRAIN or mode == ModelPhase.EVAL:
+ raise Exception(
+ "Empty image, src_dir: {}, img: {} & lab: {}".format(
+ src_dir, img_path, grt_path))
+
+ if len(img.shape) < 3:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+ img_channels = img.shape[2]
+ if img_channels < 3:
+ raise Exception("PaddleSeg only supports gray, rgb or rgba image")
+ if img_channels != cfg.DATASET.DATA_DIM:
+ raise Exception(
+ "Input image channel({}) is not match cfg.DATASET.DATA_DIM({}), img_name={}"
+ .format(img_channels, cfg.DATASET.DATADIM, img_name))
+ if img_channels != len(cfg.MEAN):
+ raise Exception(
+ "img name {}, img chns {} mean size {}, size unequal".format(
+ img_name, img_channels, len(cfg.MEAN)))
+ if img_channels != len(cfg.STD):
+ raise Exception(
+ "img name {}, img chns {} std size {}, size unequal".format(
+ img_name, img_channels, len(cfg.STD)))
+
+ return img, grt, grt_instance, img_name, grt_name
+
+ def normalize_image(self, img):
+ """ 像素归一化后减均值除方差 """
+ img = img.transpose((2, 0, 1)).astype('float32') / 255.0
+ img_mean = np.array(cfg.MEAN).reshape((len(cfg.MEAN), 1, 1))
+ img_std = np.array(cfg.STD).reshape((len(cfg.STD), 1, 1))
+ img -= img_mean
+ img /= img_std
+
+ return img
+
+ def process_image(self, line, data_dir, mode):
+ """ process_image """
+ img, grt, grt_instance, img_name, grt_name = self.load_image(
+ line, data_dir, mode=mode)
+ if mode == ModelPhase.TRAIN:
+ img, grt, grt_instance = aug.resize(img, grt, grt_instance, mode)
+ if cfg.AUG.RICH_CROP.ENABLE:
+ if cfg.AUG.RICH_CROP.BLUR:
+ if cfg.AUG.RICH_CROP.BLUR_RATIO <= 0:
+ n = 0
+ elif cfg.AUG.RICH_CROP.BLUR_RATIO >= 1:
+ n = 1
+ else:
+ n = int(1.0 / cfg.AUG.RICH_CROP.BLUR_RATIO)
+ if n > 0:
+ if np.random.randint(0, n) == 0:
+ radius = np.random.randint(3, 10)
+ if radius % 2 != 1:
+ radius = radius + 1
+ if radius > 9:
+ radius = 9
+ img = cv2.GaussianBlur(img, (radius, radius), 0, 0)
+
+ img, grt = aug.random_rotation(
+ img,
+ grt,
+ rich_crop_max_rotation=cfg.AUG.RICH_CROP.MAX_ROTATION,
+ mean_value=cfg.DATASET.PADDING_VALUE)
+
+ img, grt = aug.rand_scale_aspect(
+ img,
+ grt,
+ rich_crop_min_scale=cfg.AUG.RICH_CROP.MIN_AREA_RATIO,
+ rich_crop_aspect_ratio=cfg.AUG.RICH_CROP.ASPECT_RATIO)
+ img = aug.hsv_color_jitter(
+ img,
+ brightness_jitter_ratio=cfg.AUG.RICH_CROP.
+ BRIGHTNESS_JITTER_RATIO,
+ saturation_jitter_ratio=cfg.AUG.RICH_CROP.
+ SATURATION_JITTER_RATIO,
+ contrast_jitter_ratio=cfg.AUG.RICH_CROP.
+ CONTRAST_JITTER_RATIO)
+
+ if cfg.AUG.FLIP:
+ if cfg.AUG.FLIP_RATIO <= 0:
+ n = 0
+ elif cfg.AUG.FLIP_RATIO >= 1:
+ n = 1
+ else:
+ n = int(1.0 / cfg.AUG.FLIP_RATIO)
+ if n > 0:
+ if np.random.randint(0, n) == 0:
+ img = img[::-1, :, :]
+ grt = grt[::-1, :]
+
+ if cfg.AUG.MIRROR:
+ if np.random.randint(0, 2) == 1:
+ img = img[:, ::-1, :]
+ grt = grt[:, ::-1]
+
+ img, grt = aug.rand_crop(img, grt, mode=mode)
+ elif ModelPhase.is_eval(mode):
+ img, grt, grt_instance = aug.resize(img, grt, grt_instance, mode=mode)
+ elif ModelPhase.is_visual(mode):
+ ori_img = img.copy()
+ img, grt, grt_instance = aug.resize(img, grt, grt_instance, mode=mode)
+ valid_shape = [img.shape[0], img.shape[1]]
+ else:
+ raise ValueError("Dataset mode={} Error!".format(mode))
+
+ # Normalize image
+ img = self.normalize_image(img)
+
+ if ModelPhase.is_train(mode) or ModelPhase.is_eval(mode):
+ grt = np.expand_dims(np.array(grt).astype('int32'), axis=0)
+ ignore = (grt != cfg.DATASET.IGNORE_INDEX).astype('int32')
+ if ModelPhase.is_train(mode):
+ return (img, grt, grt_instance, ignore)
+ elif ModelPhase.is_eval(mode):
+ return (img, grt, grt_instance, ignore)
+ elif ModelPhase.is_visual(mode):
+ return (img, grt, grt_instance, img_name, valid_shape, ori_img)
diff --git a/contrib/LaneNet/requirements.txt b/contrib/LaneNet/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2b5eb8643803e1177297d2a766227e274dcdc29d
--- /dev/null
+++ b/contrib/LaneNet/requirements.txt
@@ -0,0 +1,13 @@
+pre-commit
+yapf == 0.26.0
+flake8
+pyyaml >= 5.1
+tb-paddle
+tensorboard >= 1.15.0
+Pillow
+numpy
+six
+opencv-python
+tqdm
+requests
+scikit-learn
diff --git a/contrib/LaneNet/train.py b/contrib/LaneNet/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ee9489c9b18b19b6b84615a400815a3bc33ccb2
--- /dev/null
+++ b/contrib/LaneNet/train.py
@@ -0,0 +1,470 @@
+# coding: utf8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+# GPU memory garbage collection optimization flags
+os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
+
+import sys
+
+cur_path = os.path.abspath(os.path.dirname(__file__))
+root_path = os.path.split(os.path.split(cur_path)[0])[0]
+SEG_PATH = os.path.join(cur_path, "../../../")
+sys.path.append(SEG_PATH)
+sys.path.append(root_path)
+
+import argparse
+import pprint
+
+import numpy as np
+import paddle.fluid as fluid
+
+from utils.config import cfg
+from pdseg.utils.timer import Timer, calculate_eta
+from reader import LaneNetDataset
+from models.model_builder import build_model
+from models.model_builder import ModelPhase
+from models.model_builder import parse_shape_from_file
+from eval import evaluate
+from vis import visualize
+from utils import dist_utils
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='PaddleSeg training')
+ parser.add_argument(
+ '--cfg',
+ dest='cfg_file',
+ help='Config file for training (and optionally testing)',
+ default=None,
+ type=str)
+ parser.add_argument(
+ '--use_gpu',
+ dest='use_gpu',
+ help='Use gpu or cpu',
+ action='store_true',
+ default=False)
+ parser.add_argument(
+ '--use_mpio',
+ dest='use_mpio',
+ help='Use multiprocess I/O or not',
+ action='store_true',
+ default=False)
+ parser.add_argument(
+ '--log_steps',
+ dest='log_steps',
+ help='Display logging information at every log_steps',
+ default=10,
+ type=int)
+ parser.add_argument(
+ '--debug',
+ dest='debug',
+ help='debug mode, display detail information of training',
+ action='store_true')
+ parser.add_argument(
+ '--use_tb',
+ dest='use_tb',
+ help='whether to record the data during training to Tensorboard',
+ action='store_true')
+ parser.add_argument(
+ '--tb_log_dir',
+ dest='tb_log_dir',
+ help='Tensorboard logging directory',
+ default=None,
+ type=str)
+ parser.add_argument(
+ '--do_eval',
+ dest='do_eval',
+ help='Evaluation models result on every new checkpoint',
+ action='store_true')
+ parser.add_argument(
+ 'opts',
+ help='See utils/config.py for all options',
+ default=None,
+ nargs=argparse.REMAINDER)
+ return parser.parse_args()
+
+
+def save_vars(executor, dirname, program=None, vars=None):
+ """
+    Temporary workaround for Windows save-variables compatibility.
+    Will be fixed in PaddlePaddle v1.5.2
+ """
+
+ save_program = fluid.Program()
+ save_block = save_program.global_block()
+
+ for each_var in vars:
+ # NOTE: don't save the variable which type is RAW
+ if each_var.type == fluid.core.VarDesc.VarType.RAW:
+ continue
+ new_var = save_block.create_var(
+ name=each_var.name,
+ shape=each_var.shape,
+ dtype=each_var.dtype,
+ type=each_var.type,
+ lod_level=each_var.lod_level,
+ persistable=True)
+ file_path = os.path.join(dirname, new_var.name)
+ file_path = os.path.normpath(file_path)
+ save_block.append_op(
+ type='save',
+ inputs={'X': [new_var]},
+ outputs={},
+ attrs={'file_path': file_path})
+
+ executor.run(save_program)
+
+
+def save_checkpoint(exe, program, ckpt_name):
+ """
+ Save checkpoint for evaluation or resume training
+ """
+ ckpt_dir = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, str(ckpt_name))
+ print("Save model checkpoint to {}".format(ckpt_dir))
+ if not os.path.isdir(ckpt_dir):
+ os.makedirs(ckpt_dir)
+
+ save_vars(
+ exe,
+ ckpt_dir,
+ program,
+ vars=list(filter(fluid.io.is_persistable, program.list_vars())))
+
+ return ckpt_dir
+
+
+def load_checkpoint(exe, program):
+ """
+    Load checkpoint from the pretrained model directory to resume training
+ """
+
+ print('Resume model training from:', cfg.TRAIN.RESUME_MODEL_DIR)
+ if not os.path.exists(cfg.TRAIN.RESUME_MODEL_DIR):
+        raise ValueError("TRAIN.RESUME_MODEL_DIR {} does not exist!".format(
+ cfg.TRAIN.RESUME_MODEL_DIR))
+
+ fluid.io.load_persistables(
+ exe, cfg.TRAIN.RESUME_MODEL_DIR, main_program=program)
+
+ model_path = cfg.TRAIN.RESUME_MODEL_DIR
+    # Strip the trailing path separator if present
+ if model_path[-1] == os.sep:
+ model_path = model_path[0:-1]
+ epoch_name = os.path.basename(model_path)
+ # If resume model is final model
+ if epoch_name == 'final':
+ begin_epoch = cfg.SOLVER.NUM_EPOCHS
+    # If resume model path ends with digits, restore the epoch status
+ elif epoch_name.isdigit():
+ epoch = int(epoch_name)
+ begin_epoch = epoch + 1
+ else:
+ raise ValueError("Resume model path is not valid!")
+ print("Model checkpoint loaded successfully!")
+
+ return begin_epoch
+
+
+def print_info(*msg):
+ if cfg.TRAINER_ID == 0:
+ print(*msg)
+
+
+def train(cfg):
+ startup_prog = fluid.Program()
+ train_prog = fluid.Program()
+ drop_last = True
+
+ dataset = LaneNetDataset(
+ file_list=cfg.DATASET.TRAIN_FILE_LIST,
+ mode=ModelPhase.TRAIN,
+ shuffle=True,
+ data_dir=cfg.DATASET.DATA_DIR)
+
+ def data_generator():
+ if args.use_mpio:
+ data_gen = dataset.multiprocess_generator(
+ num_processes=cfg.DATALOADER.NUM_WORKERS,
+ max_queue_size=cfg.DATALOADER.BUF_SIZE)
+ else:
+ data_gen = dataset.generator()
+
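+        # Collect one per-trainer batch, then yield it sample by sample;
+        # py_reader.decorate_sample_generator re-batches the samples downstream.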
+ batch_data = []
+ for b in data_gen:
+ batch_data.append(b)
+ if len(batch_data) == (cfg.BATCH_SIZE // cfg.NUM_TRAINERS):
+ for item in batch_data:
+ yield item
+ batch_data = []
+
+ # Get device environment
+ gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
+ place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
+ places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
+
+ # Get number of GPU
+ dev_count = cfg.NUM_TRAINERS if cfg.NUM_TRAINERS > 1 else len(places)
+ print_info("#Device count: {}".format(dev_count))
+
+    # Make sure BATCH_SIZE is divisible by the number of devices
+    assert cfg.BATCH_SIZE % dev_count == 0, (
+        'BATCH_SIZE:{} not divisible by number of GPUs:{}'.format(
+ cfg.BATCH_SIZE, dev_count))
+    # In multi-GPU training mode, batch data is evenly allocated to each GPU
+ batch_size_per_dev = cfg.BATCH_SIZE // dev_count
+ cfg.BATCH_SIZE_PER_DEV = batch_size_per_dev
+ print_info("batch_size_per_dev: {}".format(batch_size_per_dev))
+
+ py_reader, avg_loss, lr, pred, grts, masks, emb_loss, seg_loss, accuracy, fp, fn = build_model(
+ train_prog, startup_prog, phase=ModelPhase.TRAIN)
+ py_reader.decorate_sample_generator(
+ data_generator, batch_size=batch_size_per_dev, drop_last=drop_last)
+
+ exe = fluid.Executor(place)
+ exe.run(startup_prog)
+
+ exec_strategy = fluid.ExecutionStrategy()
+    # Clear temporary variables every 100 iterations
+ if args.use_gpu:
+ exec_strategy.num_threads = fluid.core.get_cuda_device_count()
+ exec_strategy.num_iteration_per_drop_scope = 100
+ build_strategy = fluid.BuildStrategy()
+
+ if cfg.NUM_TRAINERS > 1 and args.use_gpu:
+ dist_utils.prepare_for_multi_process(exe, build_strategy, train_prog)
+ exec_strategy.num_threads = 1
+
+ if cfg.TRAIN.SYNC_BATCH_NORM and args.use_gpu:
+ if dev_count > 1:
+ # Apply sync batch norm strategy
+ print_info("Sync BatchNorm strategy is effective.")
+ build_strategy.sync_batch_norm = True
+ else:
+ print_info(
+ "Sync BatchNorm strategy will not be effective if GPU device"
+ " count <= 1")
+ compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
+ loss_name=avg_loss.name,
+ exec_strategy=exec_strategy,
+ build_strategy=build_strategy)
+
+ # Resume training
+ begin_epoch = cfg.SOLVER.BEGIN_EPOCH
+ if cfg.TRAIN.RESUME_MODEL_DIR:
+ begin_epoch = load_checkpoint(exe, train_prog)
+ # Load pretrained model
+ elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL_DIR):
+ print_info('Pretrained model dir: ', cfg.TRAIN.PRETRAINED_MODEL_DIR)
+ load_vars = []
+ load_fail_vars = []
+
+ def var_shape_matched(var, shape):
+ """
+            Check whether the persistable variable shape matches the current network
+ """
+ var_exist = os.path.exists(
+ os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
+ if var_exist:
+ var_shape = parse_shape_from_file(
+ os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
+ if var_shape != shape:
+ print(var.name, var_shape, shape)
+ return var_shape == shape
+ return False
+
+ for x in train_prog.list_vars():
+ if isinstance(x, fluid.framework.Parameter):
+ shape = tuple(fluid.global_scope().find_var(
+ x.name).get_tensor().shape())
+ if var_shape_matched(x, shape):
+ load_vars.append(x)
+ else:
+ load_fail_vars.append(x)
+
+ fluid.io.load_vars(
+ exe, dirname=cfg.TRAIN.PRETRAINED_MODEL_DIR, vars=load_vars)
+ for var in load_vars:
+            print_info("Parameter[{}] loaded successfully!".format(var.name))
+ for var in load_fail_vars:
+ print_info(
+                "Parameter[{}] doesn't exist or its shape does not match the current network, skip"
+ " to load it.".format(var.name))
+ print_info("{}/{} pretrained parameters loaded successfully!".format(
+ len(load_vars),
+ len(load_vars) + len(load_fail_vars)))
+ else:
+ print_info(
+            'Pretrained model dir {} does not exist, training from scratch...'.
+ format(cfg.TRAIN.PRETRAINED_MODEL_DIR))
+
+ fetch_list = [avg_loss.name, lr.name, seg_loss.name, emb_loss.name, accuracy.name, fp.name, fn.name]
+ if args.debug:
+        # Fetch more variable info in debug mode
+ np.set_printoptions(
+ precision=4, suppress=True, linewidth=160, floatmode="fixed")
+ fetch_list.extend([pred.name, grts.name, masks.name])
+
+ if args.use_tb:
+ if not args.tb_log_dir:
+ print_info("Please specify the log directory by --tb_log_dir.")
+ exit(1)
+
+ from tb_paddle import SummaryWriter
+ log_writer = SummaryWriter(args.tb_log_dir)
+
+ global_step = 0
+ all_step = cfg.DATASET.TRAIN_TOTAL_IMAGES // cfg.BATCH_SIZE
+    if cfg.DATASET.TRAIN_TOTAL_IMAGES % cfg.BATCH_SIZE and not drop_last:
+ all_step += 1
+ all_step *= (cfg.SOLVER.NUM_EPOCHS - begin_epoch + 1)
+
+ avg_loss = 0.0
+ avg_seg_loss = 0.0
+ avg_emb_loss = 0.0
+ avg_acc = 0.0
+ avg_fp = 0.0
+ avg_fn = 0.0
+ timer = Timer()
+ timer.start()
+ if begin_epoch > cfg.SOLVER.NUM_EPOCHS:
+ raise ValueError(
+ ("begin epoch[{}] is larger than cfg.SOLVER.NUM_EPOCHS[{}]").format(
+ begin_epoch, cfg.SOLVER.NUM_EPOCHS))
+
+ if args.use_mpio:
+ print_info("Use multiprocess reader")
+ else:
+ print_info("Use multi-thread reader")
+
+ for epoch in range(begin_epoch, cfg.SOLVER.NUM_EPOCHS + 1):
+ py_reader.start()
+ while True:
+ try:
+                # If not in debug mode, avoid unnecessary logging and computation
+ loss, lr, out_seg_loss, out_emb_loss, out_acc, out_fp, out_fn = exe.run(
+ program=compiled_train_prog,
+ fetch_list=fetch_list,
+ return_numpy=True)
+
+ avg_loss += np.mean(np.array(loss))
+ avg_seg_loss += np.mean(np.array(out_seg_loss))
+ avg_emb_loss += np.mean(np.array(out_emb_loss))
+ avg_acc += np.mean(out_acc)
+ avg_fp += np.mean(out_fp)
+ avg_fn += np.mean(out_fn)
+ global_step += 1
+
+ if global_step % args.log_steps == 0 and cfg.TRAINER_ID == 0:
+ avg_loss /= args.log_steps
+ avg_seg_loss /= args.log_steps
+ avg_emb_loss /= args.log_steps
+ avg_acc /= args.log_steps
+ avg_fp /= args.log_steps
+ avg_fn /= args.log_steps
+ speed = args.log_steps / timer.elapsed_time()
+ print((
+ "epoch={} step={} lr={:.5f} loss={:.4f} seg_loss={:.4f} emb_loss={:.4f} accuracy={:.4} fp={:.4} fn={:.4} step/sec={:.3f} | ETA {}"
+ ).format(epoch, global_step, lr[0], avg_loss, avg_seg_loss, avg_emb_loss, avg_acc, avg_fp, avg_fn, speed,
+ calculate_eta(all_step - global_step, speed)))
+ if args.use_tb:
+ log_writer.add_scalar('Train/loss', avg_loss,
+ global_step)
+ log_writer.add_scalar('Train/lr', lr[0],
+ global_step)
+ log_writer.add_scalar('Train/speed', speed,
+ global_step)
+ sys.stdout.flush()
+ avg_loss = 0.0
+ avg_seg_loss = 0.0
+ avg_emb_loss = 0.0
+ avg_acc = 0.0
+ avg_fp = 0.0
+ avg_fn = 0.0
+ timer.restart()
+
+ except fluid.core.EOFException:
+ py_reader.reset()
+ break
+ except Exception as e:
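+                # log and continue on sporadic reader errors instead of aborting the epoch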
+ print(e)
+
+ if epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0 and cfg.TRAINER_ID == 0:
+ ckpt_dir = save_checkpoint(exe, train_prog, epoch)
+
+ if args.do_eval:
+ print("Evaluation start")
+ accuracy, fp, fn = evaluate(
+ cfg=cfg,
+ ckpt_dir=ckpt_dir,
+ use_gpu=args.use_gpu,
+ use_mpio=args.use_mpio)
+ if args.use_tb:
+ log_writer.add_scalar('Evaluate/accuracy', accuracy,
+ global_step)
+ log_writer.add_scalar('Evaluate/fp', fp,
+ global_step)
+ log_writer.add_scalar('Evaluate/fn', fn,
+ global_step)
+
+ # Use Tensorboard to visualize results
+ if args.use_tb and cfg.DATASET.VIS_FILE_LIST is not None:
+ visualize(
+ cfg=cfg,
+ use_gpu=args.use_gpu,
+ vis_file_list=cfg.DATASET.VIS_FILE_LIST,
+ vis_dir="visual",
+ ckpt_dir=ckpt_dir,
+ log_writer=log_writer)
+
+ # save final model
+ if cfg.TRAINER_ID == 0:
+ save_checkpoint(exe, train_prog, 'final')
+
+
+def main(args):
+ if args.cfg_file is not None:
+ cfg.update_from_file(args.cfg_file)
+ if args.opts:
+ cfg.update_from_list(args.opts)
+
+ cfg.TRAINER_ID = int(os.getenv("PADDLE_TRAINER_ID", 0))
+ cfg.NUM_TRAINERS = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
+
+ cfg.check_and_infer()
+ print_info(pprint.pformat(cfg))
+ train(cfg)
+
+
+if __name__ == '__main__':
+ args = parse_args()
+    if not fluid.core.is_compiled_with_cuda() and args.use_gpu:
+        print(
+            "You cannot set use_gpu=True because you are using paddlepaddle-cpu."
+        )
+ print(
+ "Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_gpu=False to run models on CPU."
+ )
+ sys.exit(1)
+ main(args)
diff --git a/contrib/LaneNet/utils/__init__.py b/contrib/LaneNet/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/contrib/LaneNet/utils/config.py b/contrib/LaneNet/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1186636c7d2b8004756bdfbaaca74aa47d32b7f
--- /dev/null
+++ b/contrib/LaneNet/utils/config.py
@@ -0,0 +1,233 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import os
+import sys
+
+
+from pdseg.utils.collect import SegConfig
+import numpy as np
+
+cfg = SegConfig()
+
+########################## Basic configuration ###############################
+# Mean subtracted from the image during preprocessing
+cfg.MEAN = [0.5, 0.5, 0.5]
+# Standard deviation the image is divided by during preprocessing
+cfg.STD = [0.5, 0.5, 0.5]
+# Batch size
+cfg.BATCH_SIZE = 1
+# Image crop size for evaluation (width, height)
+cfg.EVAL_CROP_SIZE = tuple()
+# Image crop size for training (width, height)
+cfg.TRAIN_CROP_SIZE = tuple()
+# Total number of processes in multi-process training
+cfg.NUM_TRAINERS = 1
+# Process ID in multi-process training
+cfg.TRAINER_ID = 0
+# Batch size per GPU; no need to set, it is derived from BATCH_SIZE automatically
+cfg.BATCH_SIZE_PER_DEV = 1
+########################## Data loading configuration ########################
+# Number of concurrent data loading workers, recommended value 8
+cfg.DATALOADER.NUM_WORKERS = 8
+# Buffer queue size for data loading, recommended value 256
+cfg.DATALOADER.BUF_SIZE = 256
+
+########################## Dataset configuration #############################
+# Dataset root directory
+cfg.DATASET.DATA_DIR = './dataset/cityscapes/'
+# Training set file list
+cfg.DATASET.TRAIN_FILE_LIST = './dataset/cityscapes/train.list'
+# Number of training images
+cfg.DATASET.TRAIN_TOTAL_IMAGES = 2975
+# Validation set file list
+cfg.DATASET.VAL_FILE_LIST = './dataset/cityscapes/val.list'
+# Number of validation images
+cfg.DATASET.VAL_TOTAL_IMAGES = 500
+# Test set file list
+cfg.DATASET.TEST_FILE_LIST = './dataset/cityscapes/test.list'
+# Number of test images
+cfg.DATASET.TEST_TOTAL_IMAGES = 500
+# File list visualized in Tensorboard
+cfg.DATASET.VIS_FILE_LIST = None
+# Number of classes (background class included)
+cfg.DATASET.NUM_CLASSES = 19
+# Input image type: three-channel 'rgb', four-channel 'rgba' or single-channel 'gray'
+cfg.DATASET.IMAGE_TYPE = 'rgb'
+# Number of input image channels
+cfg.DATASET.DATA_DIM = 3
+# Separator used in the data file lists, default is a space
+cfg.DATASET.SEPARATOR = ' '
+# Ignored pixel label value, default 255, usually no need to change
+cfg.DATASET.IGNORE_INDEX = 255
+# Padding value used during data augmentation
+cfg.DATASET.PADDING_VALUE = [127.5, 127.5, 127.5]
+
+########################### Data augmentation configuration ##################
+# Horizontal (left-right) mirror flip
+cfg.AUG.MIRROR = True
+# Vertical (up-down) flip switch, True/False
+cfg.AUG.FLIP = False
+# Probability of the vertical flip, 0-1
+cfg.AUG.FLIP_RATIO = 0.5
+# Fixed resize size (width, height), non-negative
+cfg.AUG.FIX_RESIZE_SIZE = tuple()
+# Three resize methods:
+# unpadding (fixed size), stepscaling (scale by ratio), rangescaling (align the long edge)
+cfg.AUG.AUG_METHOD = 'rangescaling'
+# Minimum scale for stepscaling resize, non-negative
+cfg.AUG.MIN_SCALE_FACTOR = 0.5
+# Maximum scale for stepscaling resize, not smaller than MIN_SCALE_FACTOR
+cfg.AUG.MAX_SCALE_FACTOR = 2.0
+# Scale step for stepscaling resize, non-negative
+cfg.AUG.SCALE_STEP_SIZE = 0.25
+# Minimum long-edge length for rangescaling resize during training, non-negative
+cfg.AUG.MIN_RESIZE_VALUE = 400
+# Maximum long-edge length for rangescaling resize during training,
+# not smaller than MIN_RESIZE_VALUE
+cfg.AUG.MAX_RESIZE_VALUE = 600
+# Long-edge length for rangescaling resize in test/eval/visualization mode,
+# within [MIN_RESIZE_VALUE, MAX_RESIZE_VALUE]
+cfg.AUG.INF_RESIZE_VALUE = 500
+
+# RichCrop augmentation switch, used to improve model robustness
+cfg.AUG.RICH_CROP.ENABLE = False
+# Maximum rotation angle, 0-90
+cfg.AUG.RICH_CROP.MAX_ROTATION = 15
+# Minimum area ratio between the crop and the original image, 0-1
+cfg.AUG.RICH_CROP.MIN_AREA_RATIO = 0.5
+# Aspect ratio range of the crop, non-negative
+cfg.AUG.RICH_CROP.ASPECT_RATIO = 0.33
+# Brightness jitter range, 0-1
+cfg.AUG.RICH_CROP.BRIGHTNESS_JITTER_RATIO = 0.5
+# Saturation jitter range, 0-1
+cfg.AUG.RICH_CROP.SATURATION_JITTER_RATIO = 0.5
+# Contrast jitter range, 0-1
+cfg.AUG.RICH_CROP.CONTRAST_JITTER_RATIO = 0.5
+# Image blur switch, True/False
+cfg.AUG.RICH_CROP.BLUR = False
+# Probability of applying blur, 0-1
+cfg.AUG.RICH_CROP.BLUR_RATIO = 0.1
+
+########################### Training configuration ###########################
+# Model save directory
+cfg.TRAIN.MODEL_SAVE_DIR = ''
+# Pretrained model directory
+cfg.TRAIN.PRETRAINED_MODEL_DIR = ''
+# Checkpoint directory to resume training from
+cfg.TRAIN.RESUME_MODEL_DIR = ''
+# Whether to synchronize BatchNorm mean and variance across GPUs
+cfg.TRAIN.SYNC_BATCH_NORM = False
+# Epoch interval between checkpoints; checkpoints can be used to resume interrupted training
+cfg.TRAIN.SNAPSHOT_EPOCH = 10
+
+########################### Solver configuration #############################
+# Initial learning rate
+cfg.SOLVER.LR = 0.1
+# Learning rate decay policy, supports poly, piecewise and cosine
+cfg.SOLVER.LR_POLICY = "poly"
+# Optimizer, supports SGD and Adam
+cfg.SOLVER.OPTIMIZER = "sgd"
+# Momentum
+cfg.SOLVER.MOMENTUM = 0.9
+# Exponential decay rate of the second-moment estimate
+cfg.SOLVER.MOMENTUM2 = 0.999
+# Exponent of the poly learning rate decay
+cfg.SOLVER.POWER = 0.9
+# Decay factor of the step (piecewise) decay
+cfg.SOLVER.GAMMA = 0.1
+# Decay epochs of the step (piecewise) decay
+cfg.SOLVER.DECAY_EPOCH = [10, 20]
+# Weight decay, 0-1
+cfg.SOLVER.WEIGHT_DECAY = 0.00004
+# Starting epoch, default 1
+cfg.SOLVER.BEGIN_EPOCH = 1
+# Number of training epochs, positive integer
+cfg.SOLVER.NUM_EPOCHS = 30
+# Loss selection, supports softmax_loss, bce_loss and dice_loss
+cfg.SOLVER.LOSS = ["softmax_loss"]
+# Cross entropy weight, default None. If set to 'dynamic', class weights are
+# adjusted dynamically based on the number of pixels of each class in each batch.
+# A static weight list can also be given, e.g. [0.1, 2.0, 0.9] for 3 classes.
+cfg.SOLVER.CROSS_ENTROPY_WEIGHT = None
+########################## Test configuration ################################
+# Test model path
+cfg.TEST.TEST_MODEL = ''
+
+########################## Common model configuration ########################
+# Model name, supports deeplab, unet, icnet and lanenet
+cfg.MODEL.MODEL_NAME = ''
+# BatchNorm type: bn or gn (group_norm)
+cfg.MODEL.DEFAULT_NORM_TYPE = 'bn'
+# Multi-loss weighting
+cfg.MODEL.MULTI_LOSS_WEIGHT = [1.0]
+# Number of groups when DEFAULT_NORM_TYPE is gn
+cfg.MODEL.DEFAULT_GROUP_NUMBER = 32
+# Tiny value guarding against division by zero, usually no need to change
+cfg.MODEL.DEFAULT_EPSILON = 1e-5
+# BatchNorm momentum, usually no need to change
+cfg.MODEL.BN_MOMENTUM = 0.99
+# Whether to train with FP16
+cfg.MODEL.FP16 = False
+# Loss scaling for mixed-precision training; default is dynamic, a static scale such as 512.0 can also be set
+cfg.MODEL.SCALE_LOSS = "DYNAMIC"
+
+########################## DeepLab configuration #############################
+# DeepLab backbone, options: xception_65, mobilenetv2
+cfg.MODEL.DEEPLAB.BACKBONE = "xception_65"
+# DeepLab output stride
+cfg.MODEL.DEEPLAB.OUTPUT_STRIDE = 16
+# MobileNet backbone scale setting
+cfg.MODEL.DEEPLAB.DEPTH_MULTIPLIER = 1.0
+# Whether the encoder uses ASPP
+cfg.MODEL.DEEPLAB.ENCODER_WITH_ASPP = True
+# Whether to enable the decoder
+cfg.MODEL.DEEPLAB.ENABLE_DECODER = True
+# Whether ASPP uses separable convolutions
+cfg.MODEL.DEEPLAB.ASPP_WITH_SEP_CONV = True
+# Whether the decoder uses separable convolutions
+cfg.MODEL.DEEPLAB.DECODER_USE_SEP_CONV = True
+
+########################## UNet configuration ################################
+# Upsample mode, default is bilinear interpolation
+cfg.MODEL.UNET.UPSAMPLE_MODE = 'bilinear'
+
+########################## ICNet configuration ###############################
+# ResNet backbone scale setting
+cfg.MODEL.ICNET.DEPTH_MULTIPLIER = 0.5
+# Number of ResNet layers
+cfg.MODEL.ICNET.LAYERS = 50
+
+########################## LaneNet configuration #############################
+# LaneNet backbone network
+cfg.MODEL.LANENET.BACKBONE = "vgg"
+
+########################## Inference and deployment configuration ############
+# Saved model filename for inference
+cfg.FREEZE.MODEL_FILENAME = '__model__'
+# Saved parameters filename for inference
+cfg.FREEZE.PARAMS_FILENAME = '__params__'
+# Save directory of the inference model
+cfg.FREEZE.SAVE_DIR = 'freeze_model'
diff --git a/contrib/LaneNet/utils/dist_utils.py b/contrib/LaneNet/utils/dist_utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..64c8800fd2010d4e1e5def6cc4ea2e1ad673b4a3
--- /dev/null
+++ b/contrib/LaneNet/utils/dist_utils.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import os
+import paddle.fluid as fluid
+
+
+def nccl2_prepare(args, startup_prog, main_prog):
+ config = fluid.DistributeTranspilerConfig()
+ config.mode = "nccl2"
+ t = fluid.DistributeTranspiler(config=config)
+
+ envs = args.dist_env
+
+ t.transpile(
+ envs["trainer_id"],
+ trainers=','.join(envs["trainer_endpoints"]),
+ current_endpoint=envs["current_endpoint"],
+ startup_program=startup_prog,
+ program=main_prog)
+
+
+def pserver_prepare(args, train_prog, startup_prog):
+ config = fluid.DistributeTranspilerConfig()
+ config.slice_var_up = args.split_var
+ t = fluid.DistributeTranspiler(config=config)
+ envs = args.dist_env
+ training_role = envs["training_role"]
+
+ t.transpile(
+ envs["trainer_id"],
+ program=train_prog,
+ pservers=envs["pserver_endpoints"],
+ trainers=envs["num_trainers"],
+ sync_mode=not args.async_mode,
+ startup_program=startup_prog)
+ if training_role == "PSERVER":
+ pserver_program = t.get_pserver_program(envs["current_endpoint"])
+ pserver_startup_program = t.get_startup_program(
+ envs["current_endpoint"],
+ pserver_program,
+ startup_program=startup_prog)
+ return pserver_program, pserver_startup_program
+ elif training_role == "TRAINER":
+ train_program = t.get_trainer_program()
+ return train_program, startup_prog
+ else:
+ raise ValueError(
+ 'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
+ )
+
+
+def nccl2_prepare_paddle(trainer_id, startup_prog, main_prog):
+ config = fluid.DistributeTranspilerConfig()
+ config.mode = "nccl2"
+ t = fluid.DistributeTranspiler(config=config)
+ t.transpile(
+ trainer_id,
+ trainers=os.environ.get('PADDLE_TRAINER_ENDPOINTS'),
+ current_endpoint=os.environ.get('PADDLE_CURRENT_ENDPOINT'),
+ startup_program=startup_prog,
+ program=main_prog)
+
+
+def prepare_for_multi_process(exe, build_strategy, train_prog):
+ # prepare for multi-process
+ trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0))
+ num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
+ if num_trainers < 2: return
+
+ build_strategy.num_trainers = num_trainers
+ build_strategy.trainer_id = trainer_id
+ # NOTE(zcd): use multi processes to train the model,
+ # and each process use one GPU card.
+ startup_prog = fluid.Program()
+ nccl2_prepare_paddle(trainer_id, startup_prog, train_prog)
+ # the startup_prog are run two times, but it doesn't matter.
+ exe.run(startup_prog)
diff --git a/contrib/LaneNet/utils/generate_tusimple_dataset.py b/contrib/LaneNet/utils/generate_tusimple_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea2a89584b6aa3eab45e0fb8516d935f13afe644
--- /dev/null
+++ b/contrib/LaneNet/utils/generate_tusimple_dataset.py
@@ -0,0 +1,165 @@
+"""
+Generate the TuSimple training dataset (images, binary/instance ground truth and file lists)
+"""
+import argparse
+import glob
+import json
+import os
+import os.path as ops
+import shutil
+
+import cv2
+import numpy as np
+
+
+def init_args():
+ parser = argparse.ArgumentParser()
+    parser.add_argument('--src_dir', type=str, help='The root path of the unzipped TuSimple dataset')
+
+ return parser.parse_args()
+
+
+def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, instance_dst_dir):
+
+ assert ops.exists(json_file_path), '{:s} not exist'.format(json_file_path)
+
+ image_nums = len(os.listdir(os.path.join(src_dir, ori_dst_dir)))
+
+ with open(json_file_path, 'r') as file:
+ for line_index, line in enumerate(file):
+ info_dict = json.loads(line)
+
+ image_dir = ops.split(info_dict['raw_file'])[0]
+ image_dir_split = image_dir.split('/')[1:]
+ image_dir_split.append(ops.split(info_dict['raw_file'])[1])
+ image_name = '_'.join(image_dir_split)
+ image_path = ops.join(src_dir, info_dict['raw_file'])
+ assert ops.exists(image_path), '{:s} not exist'.format(image_path)
+
+ h_samples = info_dict['h_samples']
+ lanes = info_dict['lanes']
+
+ image_name_new = '{:s}.png'.format('{:d}'.format(line_index + image_nums).zfill(4))
+
+ src_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
+ dst_binary_image = np.zeros([src_image.shape[0], src_image.shape[1]], np.uint8)
+ dst_instance_image = np.zeros([src_image.shape[0], src_image.shape[1]], np.uint8)
+
+ for lane_index, lane in enumerate(lanes):
+ assert len(h_samples) == len(lane)
+ lane_x = []
+ lane_y = []
+ for index in range(len(lane)):
+ if lane[index] == -2:
+ continue
+ else:
+ ptx = lane[index]
+ pty = h_samples[index]
+ lane_x.append(ptx)
+ lane_y.append(pty)
+ if not lane_x:
+ continue
+ lane_pts = np.vstack((lane_x, lane_y)).transpose()
+ lane_pts = np.array([lane_pts], np.int64)
+
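+                # binary GT: every lane is drawn white (255); instance GT: each
+                # lane gets a distinct gray level (20, 70, 120, ...)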
+ cv2.polylines(dst_binary_image, lane_pts, isClosed=False,
+ color=255, thickness=5)
+ cv2.polylines(dst_instance_image, lane_pts, isClosed=False,
+ color=lane_index * 50 + 20, thickness=5)
+
+ dst_binary_image_path = ops.join(src_dir, binary_dst_dir, image_name_new)
+ dst_instance_image_path = ops.join(src_dir, instance_dst_dir, image_name_new)
+ dst_rgb_image_path = ops.join(src_dir, ori_dst_dir, image_name_new)
+
+ cv2.imwrite(dst_binary_image_path, dst_binary_image)
+ cv2.imwrite(dst_instance_image_path, dst_instance_image)
+ cv2.imwrite(dst_rgb_image_path, src_image)
+
+            print('Processed {:s} successfully'.format(image_name))
+
+
+def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train', split=False):
+
+ label_list = []
+ with open('{:s}/{}ing/{}.txt'.format(src_dir, phase, phase), 'w') as file:
+
+        for image_name in os.listdir(ops.join(src_dir, b_gt_image_dir)):
+ if not image_name.endswith('.png'):
+ continue
+
+ binary_gt_image_path = ops.join(b_gt_image_dir, image_name)
+ instance_gt_image_path = ops.join(i_gt_image_dir, image_name)
+ image_path = ops.join(image_dir, image_name)
+
+            assert ops.exists(ops.join(src_dir, image_path)), '{:s} not exist'.format(image_path)
+            assert ops.exists(ops.join(src_dir, instance_gt_image_path)), '{:s} not exist'.format(instance_gt_image_path)
+
+            b_gt_image = cv2.imread(ops.join(src_dir, binary_gt_image_path), cv2.IMREAD_COLOR)
+            i_gt_image = cv2.imread(ops.join(src_dir, instance_gt_image_path), cv2.IMREAD_COLOR)
+            image = cv2.imread(ops.join(src_dir, image_path), cv2.IMREAD_COLOR)
+
+ if b_gt_image is None or image is None or i_gt_image is None:
+ print('image: {:s} corrupt'.format(image_name))
+ continue
+ else:
+ info = '{:s} {:s} {:s}'.format(image_path, binary_gt_image_path, instance_gt_image_path)
+ file.write(info + '\n')
+ label_list.append(info)
+ if phase == 'train' and split:
+ np.random.RandomState(0).shuffle(label_list)
+ val_list_len = len(label_list) // 10
+ val_label_list = label_list[:val_list_len]
+ train_label_list = label_list[val_list_len:]
+        with open('{:s}/{}ing/train_part.txt'.format(src_dir, phase), 'w') as file:
+ for info in train_label_list:
+ file.write(info + '\n')
+        with open('{:s}/{}ing/val_part.txt'.format(src_dir, phase), 'w') as file:
+ for info in val_label_list:
+ file.write(info + '\n')
+ return
+
+
+def process_tusimple_dataset(src_dir):
+
+    training_folder_path = ops.join(src_dir, 'training')
+    testing_folder_path = ops.join(src_dir, 'testing')
+
+    os.makedirs(training_folder_path, exist_ok=True)
+    os.makedirs(testing_folder_path, exist_ok=True)
+
+ for json_label_path in glob.glob('{:s}/label*.json'.format(src_dir)):
+ json_label_name = ops.split(json_label_path)[1]
+
+        shutil.copyfile(json_label_path, ops.join(training_folder_path, json_label_name))
+
+ for json_label_path in glob.glob('{:s}/test_label.json'.format(src_dir)):
+ json_label_name = ops.split(json_label_path)[1]
+
+ shutil.copyfile(json_label_path, ops.join(testing_folder_path, json_label_name))
+
+ train_gt_image_dir = ops.join('training', 'gt_image')
+ train_gt_binary_dir = ops.join('training', 'gt_binary_image')
+ train_gt_instance_dir = ops.join('training', 'gt_instance_image')
+
+ test_gt_image_dir = ops.join('testing', 'gt_image')
+ test_gt_binary_dir = ops.join('testing', 'gt_binary_image')
+ test_gt_instance_dir = ops.join('testing', 'gt_instance_image')
+
+ os.makedirs(os.path.join(src_dir, train_gt_image_dir), exist_ok=True)
+ os.makedirs(os.path.join(src_dir, train_gt_binary_dir), exist_ok=True)
+ os.makedirs(os.path.join(src_dir, train_gt_instance_dir), exist_ok=True)
+
+ os.makedirs(os.path.join(src_dir, test_gt_image_dir), exist_ok=True)
+ os.makedirs(os.path.join(src_dir, test_gt_binary_dir), exist_ok=True)
+ os.makedirs(os.path.join(src_dir, test_gt_instance_dir), exist_ok=True)
+
+    for json_label_path in glob.glob('{:s}/*.json'.format(training_folder_path)):
+ process_json_file(json_label_path, src_dir, train_gt_image_dir, train_gt_binary_dir, train_gt_instance_dir)
+
+ gen_sample(src_dir, train_gt_binary_dir, train_gt_instance_dir, train_gt_image_dir, 'train', True)
+
+
+if __name__ == '__main__':
+ args = init_args()
+
+ process_tusimple_dataset(args.src_dir)
diff --git a/contrib/LaneNet/utils/lanenet_postprocess.py b/contrib/LaneNet/utils/lanenet_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..21230279f7c4b4e1042e51f7583a3a13d4ebc5d7
--- /dev/null
+++ b/contrib/LaneNet/utils/lanenet_postprocess.py
@@ -0,0 +1,376 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# this code is heavily based on https://github.com/MaybeShewill-CV/lanenet-lane-detection/blob/master/lanenet_model/lanenet_postprocess.py
+"""
+LaneNet model post process
+"""
+import os.path as ops
+import math
+
+import cv2
+import time
+import numpy as np
+from sklearn.cluster import DBSCAN
+from sklearn.preprocessing import StandardScaler
+
+
+def _morphological_process(image, kernel_size=5):
+ """
+    morphological processing to fill holes in the binary segmentation result
+ :param image:
+ :param kernel_size:
+ :return:
+ """
+ if len(image.shape) == 3:
+ raise ValueError('Binary segmentation result image should be a single channel image')
+
+    if image.dtype != np.uint8:
+ image = np.array(image, np.uint8)
+
+ kernel = cv2.getStructuringElement(shape=cv2.MORPH_ELLIPSE, ksize=(kernel_size, kernel_size))
+
+    # closing operation fills holes
+ closing = cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel, iterations=1)
+
+ return closing
+
+
+def _connect_components_analysis(image):
+ """
+    connected components analysis to remove small components
+ :param image:
+ :return:
+ """
+ if len(image.shape) == 3:
+ gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+ else:
+ gray_image = image
+
+ return cv2.connectedComponentsWithStats(gray_image, connectivity=8, ltype=cv2.CV_32S)
+
+
+class _LaneFeat(object):
+ """
+
+ """
+ def __init__(self, feat, coord, class_id=-1):
+ """
+ lane feat object
+        :param feat: lane embedding feats [feature_1, feature_2, ...]
+ :param coord: lane coordinates [x, y]
+ :param class_id: lane class id
+ """
+ self._feat = feat
+ self._coord = coord
+ self._class_id = class_id
+
+ @property
+ def feat(self):
+ return self._feat
+
+ @feat.setter
+ def feat(self, value):
+ if not isinstance(value, np.ndarray):
+ value = np.array(value, dtype=np.float64)
+
+        if value.dtype != np.float64:
+            value = np.array(value, dtype=np.float64)
+
+ self._feat = value
+
+ @property
+ def coord(self):
+ return self._coord
+
+ @coord.setter
+ def coord(self, value):
+ if not isinstance(value, np.ndarray):
+ value = np.array(value)
+
+ if value.dtype != np.int32:
+ value = np.array(value, dtype=np.int32)
+
+ self._coord = value
+
+ @property
+ def class_id(self):
+ return self._class_id
+
+ @class_id.setter
+ def class_id(self, value):
+ if not isinstance(value, np.int64):
+ raise ValueError('Class id must be integer')
+
+ self._class_id = value
+
+
+class _LaneNetCluster(object):
+ """
+ Instance segmentation result cluster
+ """
+ def __init__(self):
+ """
+
+ """
+ self._color_map = [np.array([255, 0, 0]),
+ np.array([0, 255, 0]),
+ np.array([0, 0, 255]),
+ np.array([125, 125, 0]),
+ np.array([0, 125, 125]),
+ np.array([125, 0, 125]),
+ np.array([50, 100, 50]),
+ np.array([100, 50, 100])]
+
+ @staticmethod
+ def _embedding_feats_dbscan_cluster(embedding_image_feats):
+ """
+ dbscan cluster
+ """
+ db = DBSCAN(eps=0.4, min_samples=500)
+
+ try:
+ features = StandardScaler().fit_transform(embedding_image_feats)
+ db.fit(features)
+ except Exception as err:
+ print(err)
+ ret = {
+ 'origin_features': None,
+ 'cluster_nums': 0,
+ 'db_labels': None,
+ 'unique_labels': None,
+ 'cluster_center': None
+ }
+ return ret
+ db_labels = db.labels_
+ unique_labels = np.unique(db_labels)
+ num_clusters = len(unique_labels)
+ cluster_centers = db.components_
+
+ ret = {
+ 'origin_features': features,
+ 'cluster_nums': num_clusters,
+ 'db_labels': db_labels,
+ 'unique_labels': unique_labels,
+ 'cluster_center': cluster_centers
+ }
+
+ return ret
+
+ @staticmethod
+ def _get_lane_embedding_feats(binary_seg_ret, instance_seg_ret):
+ """
+        get lane embedding features according to the binary seg result
+ """
+
+ idx = np.where(binary_seg_ret == 255)
+ lane_embedding_feats = instance_seg_ret[idx]
+
+ lane_coordinate = np.vstack((idx[1], idx[0])).transpose()
+
+ assert lane_embedding_feats.shape[0] == lane_coordinate.shape[0]
+
+ ret = {
+ 'lane_embedding_feats': lane_embedding_feats,
+ 'lane_coordinates': lane_coordinate
+ }
+
+ return ret
+
+ def apply_lane_feats_cluster(self, binary_seg_result, instance_seg_result):
+ """
+        Cluster lane pixels into instances and paint each instance into a color mask
+        :param binary_seg_result: HxW binary lane mask
+        :param instance_seg_result: HxWxC pixel embedding map
+        :return: (mask_image, lane_coords), or (None, None) if clustering failed
+ """
+ # get embedding feats and coords
+ get_lane_embedding_feats_result = self._get_lane_embedding_feats(
+ binary_seg_ret=binary_seg_result,
+ instance_seg_ret=instance_seg_result
+ )
+
+ # dbscan cluster
+ dbscan_cluster_result = self._embedding_feats_dbscan_cluster(
+ embedding_image_feats=get_lane_embedding_feats_result['lane_embedding_feats']
+ )
+
+ mask = np.zeros(shape=[binary_seg_result.shape[0], binary_seg_result.shape[1], 3], dtype=np.uint8)
+ db_labels = dbscan_cluster_result['db_labels']
+ unique_labels = dbscan_cluster_result['unique_labels']
+ coord = get_lane_embedding_feats_result['lane_coordinates']
+
+ if db_labels is None:
+ return None, None
+
+ lane_coords = []
+
+ for index, label in enumerate(unique_labels.tolist()):
+ if label == -1:
+ continue
+ idx = np.where(db_labels == label)
+ pix_coord_idx = tuple((coord[idx][:, 1], coord[idx][:, 0]))
+ mask[pix_coord_idx] = self._color_map[index]
+ lane_coords.append(coord[idx])
+
+ return mask, lane_coords
+
+
+class LaneNetPostProcessor(object):
+ """
+ lanenet post process for lane generation
+ """
+ def __init__(self, ipm_remap_file_path='./utils/tusimple_ipm_remap.yml'):
+ """
+ convert front car view to bird view
+ """
+ assert ops.exists(ipm_remap_file_path), '{:s} not exist'.format(ipm_remap_file_path)
+
+ self._cluster = _LaneNetCluster()
+ self._ipm_remap_file_path = ipm_remap_file_path
+
+ remap_file_load_ret = self._load_remap_matrix()
+ self._remap_to_ipm_x = remap_file_load_ret['remap_to_ipm_x']
+ self._remap_to_ipm_y = remap_file_load_ret['remap_to_ipm_y']
+
+ self._color_map = [np.array([255, 0, 0]),
+ np.array([0, 255, 0]),
+ np.array([0, 0, 255]),
+ np.array([125, 125, 0]),
+ np.array([0, 125, 125]),
+ np.array([125, 0, 125]),
+ np.array([50, 100, 50]),
+ np.array([100, 50, 100])]
+
+ def _load_remap_matrix(self):
+ fs = cv2.FileStorage(self._ipm_remap_file_path, cv2.FILE_STORAGE_READ)
+
+ remap_to_ipm_x = fs.getNode('remap_ipm_x').mat()
+ remap_to_ipm_y = fs.getNode('remap_ipm_y').mat()
+
+ ret = {
+ 'remap_to_ipm_x': remap_to_ipm_x,
+ 'remap_to_ipm_y': remap_to_ipm_y,
+ }
+
+ fs.release()
+
+ return ret
+
+ def postprocess(self, binary_seg_result, instance_seg_result=None,
+ min_area_threshold=100, source_image=None,
+ data_source='tusimple'):
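+        """
+        binary_seg_result: HxW lane/background mask from the binary branch
+        instance_seg_result: HxWxC pixel embeddings from the embedding branch
+        """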
+
+ # convert binary_seg_result
+ binary_seg_result = np.array(binary_seg_result * 255, dtype=np.uint8)
+        # apply morphological operations to fill holes and remove small areas
+ morphological_ret = _morphological_process(binary_seg_result, kernel_size=5)
+ connect_components_analysis_ret = _connect_components_analysis(image=morphological_ret)
+
+ labels = connect_components_analysis_ret[1]
+ stats = connect_components_analysis_ret[2]
+ for index, stat in enumerate(stats):
+ if stat[4] <= min_area_threshold:
+ idx = np.where(labels == index)
+ morphological_ret[idx] = 0
+
+ # apply embedding features cluster
+ mask_image, lane_coords = self._cluster.apply_lane_feats_cluster(
+ binary_seg_result=morphological_ret,
+ instance_seg_result=instance_seg_result
+ )
+
+ if mask_image is None:
+ return {
+ 'mask_image': None,
+ 'fit_params': None,
+ 'source_image': None,
+ }
+
+ # lane line fit
+ fit_params = []
+ src_lane_pts = []
+ for lane_index, coords in enumerate(lane_coords):
+ if data_source == 'tusimple':
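+                # map the coords from the 512x256 network output back to the
+                # native 1280x720 TuSimple resolution before the IPM remap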
+ tmp_mask = np.zeros(shape=(720, 1280), dtype=np.uint8)
+ tmp_mask[tuple((np.int_(coords[:, 1] * 720 / 256), np.int_(coords[:, 0] * 1280 / 512)))] = 255
+ else:
+                raise ValueError('Wrong data source, only tusimple is supported for now')
+ tmp_ipm_mask = cv2.remap(
+ tmp_mask,
+ self._remap_to_ipm_x,
+ self._remap_to_ipm_y,
+ interpolation=cv2.INTER_NEAREST
+ )
+ nonzero_y = np.array(tmp_ipm_mask.nonzero()[0])
+ nonzero_x = np.array(tmp_ipm_mask.nonzero()[1])
+
+ fit_param = np.polyfit(nonzero_y, nonzero_x, 2)
+ fit_params.append(fit_param)
+
+ [ipm_image_height, ipm_image_width] = tmp_ipm_mask.shape
+ plot_y = np.linspace(10, ipm_image_height, ipm_image_height - 10)
+ fit_x = fit_param[0] * plot_y ** 2 + fit_param[1] * plot_y + fit_param[2]
+
+ lane_pts = []
+ for index in range(0, plot_y.shape[0], 5):
+ src_x = self._remap_to_ipm_x[
+ int(plot_y[index]), int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
+ if src_x <= 0:
+ continue
+ src_y = self._remap_to_ipm_y[
+ int(plot_y[index]), int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
+ src_y = src_y if src_y > 0 else 0
+
+ lane_pts.append([src_x, src_y])
+
+ src_lane_pts.append(lane_pts)
+
+        # TuSimple test data samples lane points along the y axis every 10 pixels
+ source_image_width = source_image.shape[1]
+ for index, single_lane_pts in enumerate(src_lane_pts):
+ single_lane_pt_x = np.array(single_lane_pts, dtype=np.float32)[:, 0]
+ single_lane_pt_y = np.array(single_lane_pts, dtype=np.float32)[:, 1]
+ if data_source == 'tusimple':
+ start_plot_y = 240
+ end_plot_y = 720
+ else:
+                raise ValueError('Wrong data source, only tusimple is supported for now')
+ step = int(math.floor((end_plot_y - start_plot_y) / 10))
+ for plot_y in np.linspace(start_plot_y, end_plot_y, step):
+ diff = single_lane_pt_y - plot_y
+ fake_diff_bigger_than_zero = diff.copy()
+ fake_diff_smaller_than_zero = diff.copy()
+ fake_diff_bigger_than_zero[np.where(diff <= 0)] = float('inf')
+ fake_diff_smaller_than_zero[np.where(diff > 0)] = float('-inf')
+ idx_low = np.argmax(fake_diff_smaller_than_zero)
+ idx_high = np.argmin(fake_diff_bigger_than_zero)
+
+ previous_src_pt_x = single_lane_pt_x[idx_low]
+ previous_src_pt_y = single_lane_pt_y[idx_low]
+ last_src_pt_x = single_lane_pt_x[idx_high]
+ last_src_pt_y = single_lane_pt_y[idx_high]
+
+ if previous_src_pt_y < start_plot_y or last_src_pt_y < start_plot_y or \
+ fake_diff_smaller_than_zero[idx_low] == float('-inf') or \
+ fake_diff_bigger_than_zero[idx_high] == float('inf'):
+ continue
+
+ interpolation_src_pt_x = (abs(previous_src_pt_y - plot_y) * previous_src_pt_x +
+ abs(last_src_pt_y - plot_y) * last_src_pt_x) / \
+ (abs(previous_src_pt_y - plot_y) + abs(last_src_pt_y - plot_y))
+ interpolation_src_pt_y = (abs(previous_src_pt_y - plot_y) * previous_src_pt_y +
+ abs(last_src_pt_y - plot_y) * last_src_pt_y) / \
+ (abs(previous_src_pt_y - plot_y) + abs(last_src_pt_y - plot_y))
+
+ if interpolation_src_pt_x > source_image_width or interpolation_src_pt_x < 10:
+ continue
+
+ lane_color = self._color_map[index].tolist()
+ cv2.circle(source_image, (int(interpolation_src_pt_x),
+ int(interpolation_src_pt_y)), 5, lane_color, -1)
+ ret = {
+ 'mask_image': mask_image,
+ 'fit_params': fit_params,
+ 'source_image': source_image,
+ }
+ return ret
diff --git a/contrib/LaneNet/vis.py b/contrib/LaneNet/vis.py
new file mode 100644
index 0000000000000000000000000000000000000000..594258758316cb8945463962c71a52b42314faa6
--- /dev/null
+++ b/contrib/LaneNet/vis.py
@@ -0,0 +1,207 @@
+# coding: utf8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+# GPU memory garbage collection optimization flags
+os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
+
+import sys
+
+cur_path = os.path.abspath(os.path.dirname(__file__))
+root_path = os.path.split(os.path.split(cur_path)[0])[0]
+SEG_PATH = os.path.join(cur_path, "../../../")
+sys.path.append(SEG_PATH)
+sys.path.append(root_path)
+
+import matplotlib
+matplotlib.use('Agg')
+import time
+import argparse
+import pprint
+import cv2
+import numpy as np
+import paddle.fluid as fluid
+
+from utils.config import cfg
+from reader import LaneNetDataset
+from models.model_builder import build_model
+from models.model_builder import ModelPhase
+from utils import lanenet_postprocess
+import matplotlib.pyplot as plt
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='PaddleSeg visualization tools')
+ parser.add_argument(
+ '--cfg',
+ dest='cfg_file',
+ help='Config file for training (and optionally testing)',
+ default=None,
+ type=str)
+ parser.add_argument(
+ '--use_gpu', dest='use_gpu', help='Use gpu or cpu', action='store_true')
+ parser.add_argument(
+ '--vis_dir',
+ dest='vis_dir',
+ help='visual save dir',
+ type=str,
+ default='visual')
+ parser.add_argument(
+ '--also_save_raw_results',
+ dest='also_save_raw_results',
+ help='whether to save raw result',
+ action='store_true')
+ parser.add_argument(
+ '--local_test',
+ dest='local_test',
+ help='if in local test mode, only visualize 5 images for testing',
+ action='store_true')
+ parser.add_argument(
+ 'opts',
+ help='See config.py for all options',
+ default=None,
+ nargs=argparse.REMAINDER)
+ if len(sys.argv) == 1:
+ parser.print_help()
+ sys.exit(1)
+ return parser.parse_args()
+
+
+def makedirs(directory):
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+
+
+def to_png_fn(fn, name=""):
+ """
+    Build a .png filename from fn, appending an optional name suffix
+ """
+ directory, filename = os.path.split(fn)
+ basename, ext = os.path.splitext(filename)
+
+ return basename + name + ".png"
+
+
+def minmax_scale(input_arr):
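+    # linearly rescale values to [0, 255]; assumes input_arr is not constant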
+ min_val = np.min(input_arr)
+ max_val = np.max(input_arr)
+
+ output_arr = (input_arr - min_val) * 255.0 / (max_val - min_val)
+
+ return output_arr
+
+
+
+def visualize(cfg,
+ vis_file_list=None,
+ use_gpu=False,
+ vis_dir="visual",
+ also_save_raw_results=False,
+ ckpt_dir=None,
+ log_writer=None,
+ local_test=False,
+ **kwargs):
+ if vis_file_list is None:
+ vis_file_list = cfg.DATASET.TEST_FILE_LIST
+
+
+ dataset = LaneNetDataset(
+ file_list=vis_file_list,
+ mode=ModelPhase.VISUAL,
+ shuffle=True,
+ data_dir=cfg.DATASET.DATA_DIR)
+
+ startup_prog = fluid.Program()
+ test_prog = fluid.Program()
+ pred, logit = build_model(test_prog, startup_prog, phase=ModelPhase.VISUAL)
+ # Clone forward graph
+ test_prog = test_prog.clone(for_test=True)
+
+ # Get device environment
+ place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
+ exe = fluid.Executor(place)
+ exe.run(startup_prog)
+
+ ckpt_dir = cfg.TEST.TEST_MODEL if not ckpt_dir else ckpt_dir
+
+ fluid.io.load_params(exe, ckpt_dir, main_program=test_prog)
+
+ save_dir = os.path.join(vis_dir, 'visual_results')
+ makedirs(save_dir)
+ if also_save_raw_results:
+ raw_save_dir = os.path.join(vis_dir, 'raw_results')
+ makedirs(raw_save_dir)
+
+ fetch_list = [pred.name, logit.name]
+ test_reader = dataset.batch(dataset.generator, batch_size=1, is_test=True)
+
+ postprocessor = lanenet_postprocess.LaneNetPostProcessor()
+ for imgs, grts, grts_instance, img_names, valid_shapes, org_imgs in test_reader:
+ segLogits, emLogits = exe.run(
+ program=test_prog,
+ feed={'image': imgs},
+ fetch_list=fetch_list,
+ return_numpy=True)
+ num_imgs = segLogits.shape[0]
+
+ for i in range(num_imgs):
+ gt_image = org_imgs[i]
+ binary_seg_image, instance_seg_image = segLogits[i].squeeze(-1), emLogits[i].transpose((1,2,0))
+
+ postprocess_result = postprocessor.postprocess(
+ binary_seg_result=binary_seg_image,
+ instance_seg_result=instance_seg_image,
+ source_image=gt_image
+ )
+ pred_binary_fn = os.path.join(save_dir, to_png_fn(img_names[i], name='_pred_binary'))
+ pred_lane_fn = os.path.join(save_dir, to_png_fn(img_names[i], name='_pred_lane'))
+ pred_instance_fn = os.path.join(save_dir, to_png_fn(img_names[i], name='_pred_instance'))
+ dirname = os.path.dirname(pred_binary_fn)
+
+ makedirs(dirname)
+ mask_image = postprocess_result['mask_image']
+            for c in range(4):
+                instance_seg_image[:, :, c] = minmax_scale(instance_seg_image[:, :, c])
+ embedding_image = np.array(instance_seg_image).astype(np.uint8)
+
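+            # NOTE: the Agg backend forced at import renders off-screen,
+            # so plt.show() below is effectively a no-op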
+ plt.figure('mask_image')
+ plt.imshow(mask_image[:, :, (2, 1, 0)])
+ plt.figure('src_image')
+ plt.imshow(gt_image[:, :, (2, 1, 0)])
+ plt.figure('instance_image')
+ plt.imshow(embedding_image[:, :, (2, 1, 0)])
+ plt.figure('binary_image')
+ plt.imshow(binary_seg_image * 255, cmap='gray')
+ plt.show()
+
+ cv2.imwrite(pred_binary_fn, np.array(binary_seg_image * 255).astype(np.uint8))
+ cv2.imwrite(pred_lane_fn, postprocess_result['source_image'])
+ cv2.imwrite(pred_instance_fn, mask_image)
+ print(pred_lane_fn, 'saved!')
+
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ if args.cfg_file is not None:
+ cfg.update_from_file(args.cfg_file)
+ if args.opts:
+ cfg.update_from_list(args.opts)
+ cfg.check_and_infer()
+ print(pprint.pformat(cfg))
+ visualize(cfg, **args.__dict__)
diff --git a/pdseg/loss.py b/pdseg/loss.py
index 36ba43b27fca957a31f9ba68160f66792686c619..66f04f4ad412b115fef04b637ea5a544fa0c2da4 100644
--- a/pdseg/loss.py
+++ b/pdseg/loss.py
@@ -20,7 +20,7 @@ import importlib
from utils.config import cfg
-def softmax_with_loss(logit, label, ignore_mask=None, num_classes=2):
+def softmax_with_loss(logit, label, ignore_mask=None, num_classes=2, weight=None):
ignore_mask = fluid.layers.cast(ignore_mask, 'float32')
label = fluid.layers.elementwise_min(
label, fluid.layers.assign(np.array([num_classes - 1], dtype=np.int32)))
@@ -29,12 +29,40 @@ def softmax_with_loss(logit, label, ignore_mask=None, num_classes=2):
label = fluid.layers.reshape(label, [-1, 1])
label = fluid.layers.cast(label, 'int64')
ignore_mask = fluid.layers.reshape(ignore_mask, [-1, 1])
-
- loss, probs = fluid.layers.softmax_with_cross_entropy(
- logit,
- label,
- ignore_index=cfg.DATASET.IGNORE_INDEX,
- return_softmax=True)
+ if weight is None:
+ loss, probs = fluid.layers.softmax_with_cross_entropy(
+ logit,
+ label,
+ ignore_index=cfg.DATASET.IGNORE_INDEX,
+ return_softmax=True)
+ else:
+ label_one_hot = fluid.layers.one_hot(input=label, depth=num_classes)
+ if isinstance(weight, list):
+ assert len(weight) == num_classes, "weight length must equal num of classes"
+ weight = fluid.layers.assign(np.array([weight], dtype='float32'))
+ elif isinstance(weight, str):
+ assert weight.lower() == 'dynamic', 'if weight is string, must be dynamic!'
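+        # dynamic weighting: weight_i ~ total_pixels / (pixels of class i + 1),
+        # normalized afterwards so the weights sum to num_classes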
+ tmp = []
+ total_num = fluid.layers.cast(fluid.layers.shape(label)[0], 'float32')
+ for i in range(num_classes):
+ cls_pixel_num = fluid.layers.reduce_sum(label_one_hot[:, i])
+ ratio = total_num / (cls_pixel_num + 1)
+ tmp.append(ratio)
+ weight = fluid.layers.concat(tmp)
+ weight = weight / fluid.layers.reduce_sum(weight) * num_classes
+ elif isinstance(weight, fluid.layers.Variable):
+ pass
+ else:
+ raise ValueError('Expect weight is a list, string or Variable, but receive {}'.format(type(weight)))
+ weight = fluid.layers.reshape(weight, [1, num_classes])
+ weighted_label_one_hot = fluid.layers.elementwise_mul(label_one_hot, weight)
+ probs = fluid.layers.softmax(logit)
+ loss = fluid.layers.cross_entropy(
+ probs,
+ weighted_label_one_hot,
+ soft_label=True,
+ ignore_index=cfg.DATASET.IGNORE_INDEX)
+ weighted_label_one_hot.stop_gradient = True
loss = loss * ignore_mask
avg_loss = fluid.layers.mean(loss) / fluid.layers.mean(ignore_mask)
@@ -80,7 +108,7 @@ def bce_loss(logit, label, ignore_mask=None):
return loss
-def multi_softmax_with_loss(logits, label, ignore_mask=None, num_classes=2):
+def multi_softmax_with_loss(logits, label, ignore_mask=None, num_classes=2, weight=None):
if isinstance(logits, tuple):
avg_loss = 0
for i, logit in enumerate(logits):
@@ -91,7 +119,7 @@ def multi_softmax_with_loss(logits, label, ignore_mask=None, num_classes=2):
num_classes)
avg_loss += cfg.MODEL.MULTI_LOSS_WEIGHT[i] * loss
else:
- avg_loss = softmax_with_loss(logits, label, ignore_mask, num_classes)
+ avg_loss = softmax_with_loss(logits, label, ignore_mask, num_classes, weight=weight)
return avg_loss
def multi_dice_loss(logits, label, ignore_mask=None):
diff --git a/pdseg/models/__init__.py b/pdseg/models/__init__.py
index f2a9093490fc284154c8e09dc5c58e638c567d26..f1465913991c5aaffefff26c1f5a5d668edd1596 100644
--- a/pdseg/models/__init__.py
+++ b/pdseg/models/__init__.py
@@ -14,5 +14,3 @@
# limitations under the License.
import models.modeling
-import models.libs
-import models.backbone
diff --git a/pdseg/models/backbone/vgg.py b/pdseg/models/backbone/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9df0a66cd85b291aad8846eed30c9bb7b4e947
--- /dev/null
+++ b/pdseg/models/backbone/vgg.py
@@ -0,0 +1,81 @@
+# coding: utf8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid import ParamAttr
+
+__all__ = ["VGGNet"]
+
+
+def check_points(count, points):
+    if points is None:
+        return False
+    if isinstance(points, list):
+        return count in points
+    return count == points
+
+
+class VGGNet():
+ def __init__(self, layers=16):
+ self.layers = layers
+
+ def net(self, input, class_dim=1000, end_points=None, decode_points=None):
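+        # end_points: conv-layer count at which to stop and return early;
+        # decode_points: conv-layer counts whose outputs are kept as short-cuts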
+ short_cuts = dict()
+ layers_count = 0
+ layers = self.layers
+ vgg_spec = {
+ 11: ([1, 1, 2, 2, 2]),
+ 13: ([2, 2, 2, 2, 2]),
+ 16: ([2, 2, 3, 3, 3]),
+ 19: ([2, 2, 4, 4, 4])
+ }
+ assert layers in vgg_spec.keys(), \
+ "supported layers are {} but input layer is {}".format(vgg_spec.keys(), layers)
+
+ nums = vgg_spec[layers]
+ channels = [64, 128, 256, 512, 512]
+ conv = input
+ for i in range(len(nums)):
+ conv = self.conv_block(conv, channels[i], nums[i], name="conv" + str(i + 1) + "_")
+ layers_count += nums[i]
+ if check_points(layers_count, decode_points):
+ short_cuts[layers_count] = conv
+ if check_points(layers_count, end_points):
+ return conv, short_cuts
+
+ return conv
+
+ def conv_block(self, input, num_filter, groups, name=None):
+ conv = input
+ for i in range(groups):
+ conv = fluid.layers.conv2d(
+ input=conv,
+ num_filters=num_filter,
+ filter_size=3,
+ stride=1,
+ padding=1,
+ act='relu',
+ param_attr=fluid.param_attr.ParamAttr(
+ name=name + str(i + 1) + "_weights"),
+ bias_attr=False)
+ return fluid.layers.pool2d(
+ input=conv, pool_size=2, pool_type='max', pool_stride=2)
diff --git a/pdseg/models/model_builder.py b/pdseg/models/model_builder.py
index 495652464f8cd14fef650bf5bdc77c14ebdbb4e7..65483b336b59440589f5c2fa27fd8ae456df176a 100644
--- a/pdseg/models/model_builder.py
+++ b/pdseg/models/model_builder.py
@@ -223,8 +223,9 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
avg_loss_list = []
valid_loss = []
if "softmax_loss" in loss_type:
+ weight = cfg.SOLVER.CROSS_ENTROPY_WEIGHT
avg_loss_list.append(
- multi_softmax_with_loss(logits, label, mask, class_num))
+ multi_softmax_with_loss(logits, label, mask, class_num, weight))
loss_valid = True
valid_loss.append("softmax_loss")
if "dice_loss" in loss_type:
diff --git a/pdseg/utils/config.py b/pdseg/utils/config.py
index 5d66c2f076ca964fcdf23d1cfd427e61acf68876..d321aa4f8475fcdef7645fbd051aa26deeed3221 100644
--- a/pdseg/utils/config.py
+++ b/pdseg/utils/config.py
@@ -158,7 +158,10 @@ cfg.SOLVER.LOSS = ["softmax_loss"]
cfg.SOLVER.LR_WARMUP = False
# warmup的迭代次数
cfg.SOLVER.LR_WARMUP_STEPS = 2000
-
+# Cross entropy weight, default None. If set to 'dynamic', class weights are
+# adjusted dynamically based on the number of pixels of each class in each batch.
+# A static weight list can also be given, e.g. [0.1, 2.0, 0.9] for 3 classes.
+cfg.SOLVER.CROSS_ENTROPY_WEIGHT = None
########################## 测试配置 ###########################################
# 测试模型路径
cfg.TEST.TEST_MODEL = ''